[
  {
    "path": ".github/scripts/build.sh",
    "content": "#!/bin/bash\n\nset -eoxu pipefail\n\n# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6\n# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810\n# However this still fails so I am using a newer version of setuptools\npip install setuptools==68.0.0\npip install ninja packaging wheel\nexport PATH=/usr/local/cuda/bin:/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH\nexport LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH\n\n# Limit MAX_JOBS otherwise the github runner goes OOM\nexport MAX_JOBS=2 \nexport MAMBA_FORCE_BUILD=\"TRUE\" \nexport MAMBA_FORCE_CXX11_ABI=$CXX11_ABI \n\n# 5h timeout since GH allows max 6h and we want some buffer\nEXIT_CODE=0\ntimeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$?\n\nif [ $EXIT_CODE -eq 0 ]; then\ntmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi$CXX11_ABI\nwheel_name=$(ls dist/*whl | xargs -n 1 basename | sed \"s/-/+$tmpname-/2\")\nls dist/*whl |xargs -I {} mv {} dist/${wheel_name}\necho \"wheel_name=${wheel_name}\" >> $GITHUB_ENV\nfi\n\necho $EXIT_CODE"
  },
  {
    "path": ".github/scripts/check_for_ngc_images.sh",
    "content": "#!/bin/bash\n\n# Configuration\nBASE_IMAGE=\"nvcr.io/nvidia/pytorch\"\nTAG_SUFFIX=\"-py3\"\nMONTHS_TO_CHECK=7 # Check current month and previous 6 months (total 7)\n\n# Initialize an array to store existing tags\nEXISTING_TAGS=()\n\necho \"Checking for existence of the last ${MONTHS_TO_CHECK} NGC PyTorch images: ${BASE_IMAGE}:YY.MM${TAG_SUFFIX}\"\necho \"---------------------------------------------------------------------\"\n\n# Loop through the last N months\nfor i in $(seq 0 $((MONTHS_TO_CHECK - 1))); do\n    # Calculate Year and Month for the tag\n    CURRENT_YEAR=$(date +%Y)\n    CURRENT_MONTH=$(date +%m)\n    \n    # Calculate target month and year\n    TARGET_DATE=$(date -d \"$CURRENT_YEAR-$CURRENT_MONTH-01 -$i months\" +%y.%m)\n    \n    # Construct the full image tag and the tag-only string\n    IMAGE_TAG=\"${TARGET_DATE}${TAG_SUFFIX}\"\n    FULL_IMAGE=\"${BASE_IMAGE}:${IMAGE_TAG}\"\n\n    echo \"Checking: ${FULL_IMAGE}\"\n\n    # Use 'docker manifest inspect' to check for image existence without pulling.\n    if docker manifest inspect \"${FULL_IMAGE}\" > /dev/null 2>&1; then\n        echo \"✅ EXISTS: Found.\"\n        # Add the tag-only string to the array\n        EXISTING_TAGS+=(\"nvcr.io/nvidia/pytorch:${IMAGE_TAG}\")\n    else\n        echo \"❌ MISSING: Not found.\"\n    fi\ndone\n\necho \"---------------------------------------------------------------------\"\n\n## JSON Output Generation\n# This uses the collected array to build a JSON string.\n\n# 1. Convert the shell array to a newline-separated string.\nTAGS_NL_SEP=$(printf \"%s\\n\" \"${EXISTING_TAGS[@]}\")\n\n# 2. Use jq to read the newline-separated list and format it into a JSON array.\n# . 
| split(\"\\n\") | map(select(length > 0)) reads the input, splits it by newline, and drops empty elements\n# (the trailing one, or the only one when no image was found).\nif command -v jq &> /dev/null; then\n    JSON_STRING=$(echo -e \"${TAGS_NL_SEP}\" | jq -R -s 'split(\"\\n\") | map(select(length > 0))')\n\n    echo \"Generated JSON String of Existing Tags:\"\n    echo \"${JSON_STRING}\"\n\n    # Optional: Save the JSON string to a variable for further use\n    # echo \"JSON_STRING is now available in the shell if you source this script.\"\nelse\n    echo \"WARNING: 'jq' is not installed. Cannot format output as JSON.\"\n    echo \"Found Tags: ${EXISTING_TAGS[*]}\"\nfi\n\necho \"---\"\necho \"Check complete.\"\n\necho \"${JSON_STRING}\" > ngc_images.json"
  },
  {
    "path": ".github/scripts/test.sh",
    "content": "#!/bin/bash\n\nset -exou pipefail\n\npip install dist/*.whl\npython -c \"import mamba_ssm; print(mamba_ssm.__version__)\""
  },
  {
    "path": ".github/workflows/_build.yml",
    "content": "name: ~Build wheel template\n\non:\n  workflow_call:\n    inputs:\n      runs-on:\n        description: \"The runner to use for the build\"\n        required: true\n        type: string\n      python-version:\n        description: \"The Python version to use for the build\"\n        required: true\n        type: string\n      cuda-version:\n        description: \"The CUDA version to use for the build\"\n        required: true\n        type: string\n      torch-version:\n        description: \"The PyTorch version to use for the build\"\n        required: true\n        type: string\n      cxx11_abi:\n        description: \"The C++11 ABI to use for the build\"\n        required: true\n        type: string\n      upload-to-release:\n        description: \"Upload wheel to this release\"\n        required: false\n        type: boolean\n        default: false\n      release-version:\n        description: \"Upload wheel to this release\"\n        required: false\n        type: string\n\ndefaults:\n  run:\n    shell: bash -x -e -u -o pipefail {0}\n\njobs:\n  build-wheel:\n    runs-on: ${{ inputs.runs-on }}\n    name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }})\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.release-version }}\n          submodules: recursive\n      \n      - name: Checkout build scripts\n        uses: actions/checkout@v4\n        with:\n          path: build-scripts/\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ inputs.python-version }}\n\n      - name: Set CUDA and PyTorch versions\n        run: |\n          echo \"MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \\. {'print $1 $2'})\" >> $GITHUB_ENV\n          echo \"MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \\. 
{'print $1 \".\" $2'})\" >> $GITHUB_ENV\n          echo \"WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \\. {'print $1'})\" >> $GITHUB_ENV\n          echo \"MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \\. {'print $1 $2'})\" >> $GITHUB_ENV\n\n      - name: Free up disk space\n        if: ${{ runner.os == 'Linux' }}\n        # https://github.com/easimon/maximize-build-space/blob/master/action.yml\n        # https://github.com/easimon/maximize-build-space/tree/test-report\n        run: |\n          sudo rm -rf /usr/share/dotnet\n          sudo rm -rf /opt/ghc\n          sudo rm -rf /opt/hostedtoolcache/CodeQL\n\n      - name: Set up swap space\n        if: runner.os == 'Linux'\n        uses: pierotofy/set-swap-space@v1.0\n        with:\n          swap-size-gb: 10\n\n      - name: Install CUDA ${{ inputs.cuda-version }}\n        if: ${{ inputs.cuda-version != 'cpu' }}\n        uses: Jimver/cuda-toolkit@v0.2.30\n        id: cuda-toolkit\n        with:\n          cuda: ${{ inputs.cuda-version }}\n          linux-local-args: '[\"--toolkit\"]'\n          # default method is \"local\", and we're hitting some error with caching for CUDA 11.8 and 12.1\n          # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }}\n          method: \"network\"\n          sub-packages: '[\"nvcc\"]'\n\n      - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }}\n        run: |\n          pip install --upgrade pip\n          # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools\n          pip install setuptools==68.0.0\n          # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error\n          # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable\n          pip install typing-extensions==4.12.2\n          # Pick the highest available PyTorch wheel CUDA version that doesn't exceed system 
CUDA\n          export TORCH_CUDA_VERSION=$(python -c \"from os import environ as env; \\\n            available = { \\\n              '2.6': [118, 124, 126], \\\n              '2.7': [118, 126, 128], \\\n              '2.8': [126, 128, 129], \\\n              '2.9': [126, 128, 130], \\\n              '2.10': [126, 128, 130], \\\n            }[env['MATRIX_TORCH_VERSION']]; \\\n            sys_cuda = int(env['MATRIX_CUDA_VERSION']); \\\n            print(max(v for v in available if v <= sys_cuda))\" \\\n          )\n          if [[ ${{ inputs.torch-version }} == *\"dev\"* ]]; then\n            # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}\n            # Can't use --no-deps because we need cudnn etc.\n            # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001\n            pip install jinja2\n            pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl\n            pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl\n          else\n            pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}\n          fi\n          nvcc --version\n          python --version\n          python -c \"import torch; print('PyTorch:', torch.__version__)\"\n          python -c \"import torch; print('CUDA:', torch.version.cuda)\"\n          python -c \"from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)\"\n        shell: bash\n\n      - name: Build wheel\n        id: build_wheel\n        env: \n          CXX11_ABI: ${{ inputs.cxx11_abi }}\n          MATRIX_TORCH_VERSION: ${{ 
env.MATRIX_TORCH_VERSION}}\n          WHEEL_CUDA_VERSION: ${{ env.WHEEL_CUDA_VERSION }}\n          MATRIX_PYTHON_VERSION: ${{ env.MATRIX_PYTHON_VERSION }}\n        run: |\n          EXIT_CODE=$(bash build-scripts/.github/scripts/build.sh | tail -n 1)\n\n          # Store exit code in GitHub env for later steps\n          echo \"build_exit_code=$EXIT_CODE\" | tee -a \"$GITHUB_OUTPUT\"\n\n          exit $EXIT_CODE\n\n      - name: Log Built Wheels\n        run: |\n          ls dist\n\n      - name: Get Release with tag\n        id: get_current_release\n        uses: joutvhu/get-release@v1\n        with:\n          tag_name: ${{ inputs.release-version }}\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Upload Release Asset\n        id: upload_release_asset\n        if: inputs.upload-to-release\n        uses: actions/upload-release-asset@v1\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n        with:\n          upload_url: ${{ steps.get_current_release.outputs.upload_url }}\n          asset_path: ./dist/${{env.wheel_name}}\n          asset_name: ${{env.wheel_name}}\n          asset_content_type: application/*\n"
  },
  {
    "path": ".github/workflows/_build_in_container.yml",
    "content": "name: ~Build wheel template\n\non:\n  workflow_call:\n    inputs:\n      runs-on:\n        description: \"The runner to use for the build\"\n        required: true\n        type: string\n      container-image:\n        description: \"Container image\"\n        required: true\n        type: string\n      upload-to-release:\n        description: \"Upload wheel to this release\"\n        required: false\n        type: boolean\n        default: false\n      release-version:\n        description: \"Upload wheel to this release\"\n        required: false\n        type: string\n\ndefaults:\n  run:\n    shell: bash -x -e -u -o pipefail {0}\n\njobs:\n  build-wheel:\n    runs-on: ${{ inputs.runs-on }}\n    name: Build wheel (${{ inputs.container-image }})\n    steps:\n      - name: Move /var/lib/containerd/\n        run: |\n          mkdir -p \"${GITHUB_WORKSPACE}/docker/containerd\"\n          sudo mv /var/lib/containerd/ \"${GITHUB_WORKSPACE}/docker/containerd\"\n\n      - name: Move /var/lib/containerd/\n        run: |\n          mkdir -p \"${GITHUB_WORKSPACE}/docker/docker\"\n          sudo mv /var/lib/docker/ \"${GITHUB_WORKSPACE}/docker/docker\"\n\n      - name: Maximize build space\n        uses: easimon/maximize-build-space@master\n        with:\n          root-reserve-mb: 5120\n          temp-reserve-mb: 32\n          swap-size-mb: 10240\n          remove-dotnet: \"true\"\n          remove-android: \"true\"\n          remove-haskell: \"true\"\n          remove-codeql: \"true\"\n          build-mount-path: \"/var/lib/\"\n\n      - name: Restore /var/lib/containerd/\n        run: sudo sh -c \"mv ${GITHUB_WORKSPACE}/docker/containerd/* /var/lib/containerd\"\n\n      - name: Restore /var/lib/docker/\n        run: sudo sh -c \"mv ${GITHUB_WORKSPACE}/docker/docker/* /var/lib/docker\"\n\n      - name: Checkout source\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ inputs.release-version }}\n          submodules: recursive\n\n      - 
name: Checkout build scripts\n        uses: actions/checkout@v4\n        with:\n          path: build-scripts/\n\n      - name: Build\n        run: |\n          echo \"Free space:\"\n          df -h\n\n      - name: Pull the container\n        run: docker pull ${{ inputs.container-image }}\n\n      - name: Set CUDA and PyTorch versions\n        run: |\n          cat <<'EOF' >> script.sh\n          #!/bin/bash\n\n          set -eoxu pipefail\n\n          echo \"MATRIX_CUDA_VERSION=$(echo $CUDA_VERSION | awk -F \\. {'print $1 $2'})\" >> $GITHUB_ENV\n          echo \"MATRIX_TORCH_VERSION=$NVIDIA_PYTORCH_VERSION\" >> $GITHUB_ENV\n          echo \"WHEEL_CUDA_VERSION=$(echo $CUDA_VERSION | awk -F \\. {'print $1'})\" >> $GITHUB_ENV\n          echo \"MATRIX_PYTHON_VERSION=$(python -c \"import sys; print('{}.{}'.format(sys.version_info[0], sys.version_info[1]))\" | awk -F \\. {'print $1 $2'})\" >> $GITHUB_ENV\n          echo \"CXX11_ABI=$(python -c 'import torch; print(str(torch._C._GLIBCXX_USE_CXX11_ABI).upper())')\" >> $GITHUB_ENV\n\n          cat $GITHUB_ENV\n          EOF\n\n          docker run \\\n            --rm \\\n            --shm-size=64g \\\n            --workdir /workspace \\\n            --volume $(pwd):/workspace \\\n            --volume $GITHUB_ENV:$GITHUB_ENV \\\n            -e GITHUB_ENV=$GITHUB_ENV \\\n            ${{ inputs.container-image }} bash /workspace/script.sh\n\n      - name: Build wheel\n        id: build_wheel\n        env:\n          CXX11_ABI: ${{ env.CXX11_ABI }}\n          MATRIX_TORCH_VERSION: ${{ env.MATRIX_TORCH_VERSION}}\n          WHEEL_CUDA_VERSION: ${{ env.WHEEL_CUDA_VERSION }}\n          MATRIX_PYTHON_VERSION: ${{ env.MATRIX_PYTHON_VERSION }}\n        run: |\n          EXIT_CODE=$(docker run \\\n            --rm \\\n            --shm-size=64g \\\n            --workdir /workspace \\\n            --volume $(pwd):/workspace \\\n            --volume $GITHUB_ENV:$GITHUB_ENV \\\n            -e PIP_CONSTRAINT= \\\n            -e 
GITHUB_ENV=$GITHUB_ENV \\\n            -e CXX11_ABI=$CXX11_ABI \\\n            -e MATRIX_TORCH_VERSION=$MATRIX_TORCH_VERSION \\\n            -e WHEEL_CUDA_VERSION=$WHEEL_CUDA_VERSION \\\n            -e MATRIX_PYTHON_VERSION=$MATRIX_PYTHON_VERSION \\\n            ${{ inputs.container-image }} bash /workspace/build-scripts/.github/scripts/build.sh | tail -n 1)\n\n          # Store the exit code as a step output and fail this step if the build failed\n          echo \"build_exit_code=$EXIT_CODE\" | tee -a \"$GITHUB_OUTPUT\"\n\n          exit $EXIT_CODE\n\n      - name: Test wheels\n        run: |\n          docker run \\\n            --rm \\\n            --shm-size=64g \\\n            --workdir /workspace \\\n            --volume $(pwd):/workspace \\\n            --volume $GITHUB_ENV:$GITHUB_ENV \\\n            -e GITHUB_ENV=$GITHUB_ENV \\\n            ${{ inputs.container-image }} bash /workspace/build-scripts/.github/scripts/test.sh\n\n      - name: Log Built Wheels\n        run: |\n          ls dist\n\n      - name: Get Release with tag\n        id: get_current_release\n        uses: joutvhu/get-release@v1\n        with:\n          tag_name: ${{ inputs.release-version }}\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Upload Release Asset\n        id: upload_release_asset\n        if: inputs.upload-to-release\n        uses: actions/upload-release-asset@v1\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n        with:\n          upload_url: ${{ steps.get_current_release.outputs.upload_url }}\n          asset_path: ./dist/${{env.wheel_name}}\n          asset_name: ${{env.wheel_name}}\n          asset_content_type: application/*\n"
  },
  {
    "path": ".github/workflows/build.yml",
    "content": "name: Build wheels\n\non:\n  workflow_dispatch:\n    inputs:\n      runs-on:\n        description: \"The runner to use for the build\"\n        required: true\n        type: string\n        default: ubuntu-22.04\n      python-version:\n        description: \"The Python version to use for the build\"\n        required: true\n        type: string\n      cuda-version:\n        description: \"The CUDA version to use for the build\"\n        required: true\n        type: string\n      torch-version:\n        description: \"The PyTorch version to use for the build\"\n        required: true\n        type: string\n      cxx11_abi:\n        description: \"Enable torch flag C++11 ABI (TRUE/FALSE)\"\n        required: true\n        type: string\n      upload-to-release:\n        description: \"Upload wheel to this release\"\n        required: false\n        type: boolean\n        default: false\n      release-version:\n        description: \"Upload wheel to this release\"\n        required: false\n        type: string\n\njobs:\n  build-wheels:\n    uses: ./.github/workflows/_build.yml\n    with:\n      runs-on: ${{ inputs.runs-on || 'ubuntu-22.04' }}\n      python-version: ${{ inputs.python-version || '3.12' }}\n      cuda-version: ${{ inputs.cuda-version || '12.9.1' }}\n      torch-version: ${{ inputs.torch-version || '2.10.0' }}\n      cxx11_abi: ${{ inputs.cxx11_abi || 'TRUE' }}\n      upload-to-release: ${{ inputs.upload-to-release || false }}\n      release-version: ${{ inputs.release-version || 'v2.2.6.post3' }}\n"
  },
  {
    "path": ".github/workflows/build_in_container.yml",
    "content": "name: Build wheels in a container\n\non:\n  workflow_dispatch:\n    inputs:\n      runs-on:\n        description: \"The runner to use for the build\"\n        required: true\n        type: string\n        default: ubuntu-22.04\n      container-image:\n        description: \"Container image\"\n        required: true\n        type: string\n      upload-to-release:\n        description: \"Upload wheel to this release\"\n        required: false\n        type: boolean\n        default: false\n      release-version:\n        description: \"Release version tag to checkout and upload to\"\n        required: false\n        type: string\n\n  push:\n    tags-ignore:\n      - v*\n\njobs:\n  get_version:\n    runs-on: ubuntu-latest\n    outputs:\n      version: ${{ steps.get_version.outputs.version }}\n    steps:\n      - name: Get version from input or git\n        id: get_version\n        run: |\n          if [ -n \"${{ inputs.release-version }}\" ]; then\n            echo \"version=${{ inputs.release-version }}\" >> $GITHUB_OUTPUT\n          else\n            # Get the latest tag from the repo\n            git clone --filter=blob:none --no-checkout $GITHUB_SERVER_URL/$GITHUB_REPOSITORY.git repo\n            cd repo\n            echo \"version=$(git describe --tags --abbrev=0)\" >> $GITHUB_OUTPUT\n          fi\n        shell: bash\n\n  check_for_ngc_images:\n    runs-on: ubuntu-latest\n    outputs:\n      images: ${{ steps.check_for_ngc_images.outputs.IMAGES }}\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Check for NGC PyTorch images\n        id: check_for_ngc_images\n        run: |\n          bash ./.github/scripts/check_for_ngc_images.sh\n          echo \"IMAGES=$(cat ngc_images.json| jq -cr)\" >> $GITHUB_OUTPUT\n\n  build-wheels:\n    needs: [get_version, check_for_ngc_images]\n    uses: ./.github/workflows/_build_in_container.yml\n    strategy:\n      fail-fast: false\n      matrix:\n        
container-image: ${{ fromJson(needs.check_for_ngc_images.outputs.images) }}\n    with:\n      runs-on: ${{ inputs.runs-on || 'ubuntu-22.04' }}\n      container-image: ${{ matrix.container-image }}\n      upload-to-release: ${{ inputs.upload-to-release || false }}\n      release-version: ${{ needs.get_version.outputs.version }}\n"
  },
  {
    "path": ".github/workflows/publish.yaml",
    "content": "# This workflow will:\n# - Create a new Github release\n# - Build wheels for supported architectures\n# - Deploy the wheels to the Github release\n# - Release the static code to PyPi\n# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries\n\nname: Build wheels and deploy\n\non:\n  push:\n    tags:\n      - v*\n\njobs:\n  setup_release:\n    name: Create Release\n    runs-on: ubuntu-latest\n    outputs:\n      release-version: ${{ steps.extract_branch.outputs.branch }}\n    steps:\n      - name: Get the tag version\n        id: extract_branch\n        run: echo \"branch=${GITHUB_REF#refs/tags/}\" >> $GITHUB_OUTPUT\n        shell: bash\n\n      - name: Create Release\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n        run: gh release create ${{ steps.extract_branch.outputs.branch }} --repo $GITHUB_REPOSITORY --title ${{ steps.extract_branch.outputs.branch }} --generate-notes\n        shell: bash\n\n  build_wheels:\n    name: Build Wheel\n    needs: setup_release\n\n    strategy:\n      fail-fast: false\n      matrix:\n        # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). 
Ideally we'd use the\n        # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.\n        os: [ubuntu-22.04, ubuntu-22.04-arm]\n        python-version: [\"3.10\", \"3.11\", \"3.12\", \"3.13\"]\n        torch-version: [\"2.6.0\", \"2.7.1\", \"2.8.0\", \"2.9.1\", \"2.10.0\"]\n        cuda-version: [\"11.8.0\", \"12.9.1\", \"13.0.1\"]\n        # We need separate wheels that either use the C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.\n        # PyTorch pip wheels before 2.7 don't use it, but nvcr images have PyTorch compiled with the C++11 ABI.\n        # Without this we get an import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)\n        # when building without the C++11 ABI and using the wheel on nvcr images.\n        cxx11_abi: [\"FALSE\", \"TRUE\"]\n        exclude:\n          # CUDA 11.8 is not supported by PyTorch 2.8+\n          - torch-version: \"2.8.0\"\n            cuda-version: \"11.8.0\"\n          - torch-version: \"2.9.1\"\n            cuda-version: \"11.8.0\"\n          - torch-version: \"2.10.0\"\n            cuda-version: \"11.8.0\"\n          # CUDA 13.0 is only supported by PyTorch 2.9+\n          - torch-version: \"2.6.0\"\n            cuda-version: \"13.0.1\"\n          - torch-version: \"2.7.1\"\n            cuda-version: \"13.0.1\"\n          - torch-version: \"2.8.0\"\n            cuda-version: \"13.0.1\"\n          # No aarch64 PyTorch wheels for 2.6.0, or 2.7.1+cu118\n          - torch-version: \"2.6.0\"\n            os: ubuntu-22.04-arm\n          - torch-version: \"2.7.1\"\n            cuda-version: \"11.8.0\"\n            os: ubuntu-22.04-arm\n          # PyTorch 2.7+ pip wheels use CXX11_ABI=1 by default, no need for FALSE\n          - torch-version: \"2.7.1\"\n            cxx11_abi: \"FALSE\"\n          - torch-version: \"2.8.0\"\n            cxx11_abi: \"FALSE\"\n          - torch-version: \"2.9.1\"\n            cxx11_abi: \"FALSE\"\n          - torch-version: \"2.10.0\"\n            cxx11_abi: \"FALSE\"\n    
uses: ./.github/workflows/_build.yml\n    with:\n      runs-on: ${{ matrix.os }}\n      python-version: ${{ matrix.python-version }}\n      cuda-version: ${{ matrix.cuda-version }}\n      torch-version: ${{ matrix.torch-version }}\n      cxx11_abi: ${{ matrix.cxx11_abi }}\n      release-version: ${{ needs.setup_release.outputs.release-version }}\n      upload-to-release: true\n\n  check_for_ngc_images:\n    runs-on: ubuntu-latest\n    outputs:\n      images: ${{ steps.check_for_ngc_images.outputs.IMAGES }}\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Check for NGC PyTorch images\n        id: check_for_ngc_images\n        run: |\n          bash ./.github/scripts/check_for_ngc_images.sh\n          echo \"IMAGES=$(cat ngc_images.json| jq -cr)\" | tee -a $GITHUB_OUTPUT\n\n  build_ngc_wheels:\n    name: Build Wheel for NGC PyTorch\n    needs: [setup_release, check_for_ngc_images]\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [ubuntu-22.04, ubuntu-22.04-arm]\n        container-image: ${{ fromJson(needs.check_for_ngc_images.outputs.images) }}\n    uses: ./.github/workflows/_build_in_container.yml\n    with:\n      runs-on: ${{ matrix.os }}\n      container-image: ${{ matrix.container-image }}\n      release-version: ${{ needs.setup_release.outputs.release-version }}\n      upload-to-release: true\n\n  publish_package:\n    name: Publish package\n    needs: [build_wheels]\n    if: always() && !cancelled()\n\n    runs-on: ubuntu-latest\n\n    steps:\n      - uses: actions/checkout@v4\n\n      - uses: actions/setup-python@v5\n        with:\n          python-version: \"3.10\"\n\n      - name: Install dependencies\n        run: |\n          pip install ninja packaging setuptools wheel twine\n          # We don't want to download anything CUDA-related here\n          pip install torch --index-url https://download.pytorch.org/whl/cpu\n\n      - name: Build core package\n        env:\n          
MAMBA_SKIP_CUDA_BUILD: \"TRUE\"\n        run: |\n          python setup.py sdist --dist-dir=dist\n\n      - name: Deploy\n        env:\n          TWINE_USERNAME: \"__token__\"\n          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}\n        run: |\n          python -m twine upload dist/*\n"
  },
  {
    "path": ".gitignore",
    "content": "*__pycache__/\n*.egg-info/\nbuild/\n**.so\n*.hip\n*_hip.*"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"3rdparty/lm-evaluation-harness\"]\n\tpath = 3rdparty/lm-evaluation-harness\n\turl = https://github.com/EleutherAI/lm-evaluation-harness/\n"
  },
  {
    "path": "AUTHORS",
    "content": "Tri Dao, tri@tridao.me\nAlbert Gu, agu@andrew.cmu.edu\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2023 Tri Dao, Albert Gu\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "recursive-include csrc *\nrecursive-include csrc *\nREADME.md\n"
  },
  {
    "path": "README.md",
    "content": "# Mamba\n\n![Mamba](assets/selection.png \"Selective State Space\")\n> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\\\n> Albert Gu*, Tri Dao*\\\n> Paper: https://arxiv.org/abs/2312.00752\n\n![Mamba-2](assets/ssd_algorithm.png \"State Space Dual Model\")\n> **Transformers are SSMs: Generalized Models and Efficient Algorithms**\\\n>     **Through Structured State Space Duality**\\\n> Tri Dao*, Albert Gu*\\\n> Paper: https://arxiv.org/abs/2405.21060\n\n![Mamba-3](assets/mamba3.png \"Inference-first State Space Model\")\n> **Mamba-3: Improved Sequence Modeling using State Space Principles**\\\n>     **Through Structured State Space Duality**\\\n> Aakash Lahoti*, Kevin Y. Li*, Berlin Chen*, Caitlin Wang*, Aviv Bick, J. Zico Kolter, Tri Dao†, Albert Gu†\\\n> Paper: https://arxiv.org/abs/2603.15569\n\n## About\n\nMamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.\nIt is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),\nwith an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).\n\n## Installation\n\nInstall PyTorch first, then:\n- [Option] `pip install causal-conv1d>=1.4.0 --no-build-isolation`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.\n- `pip install mamba-ssm --no-build-isolation`: the core Mamba package.\n- `pip install mamba-ssm[causal-conv1d] --no-build-isolation`: To install core Mamba package and causal-conv1d.\n\n`--no-build-isolation` is required so that pip uses your existing CUDA-enabled PyTorch instead of installing torch-cpu in an isolated build environment.\n\nOther requirements:\n- Linux\n- NVIDIA GPU\n- PyTorch 1.12+\n- CUDA 11.6+\n\nFor AMD cards, see additional prerequisites below.\n\n## Usage\n\nWe 
expose several levels of interface with the Mamba model.\n\n### Selective SSM\n\nMamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).\n\nSource: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).\n\n### Mamba Block\n\nThe main module of this repository is the Mamba architecture block wrapping the selective SSM.\n\nSource: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).\n\nUsage:\n``` python\nimport torch\nfrom mamba_ssm import Mamba\n\nbatch, length, dim = 2, 64, 16\nx = torch.randn(batch, length, dim).to(\"cuda\")\nmodel = Mamba(\n    # This module uses roughly 3 * expand * d_model^2 parameters\n    d_model=dim, # Model dimension d_model\n    d_state=16,  # SSM state expansion factor\n    d_conv=4,    # Local convolution width\n    expand=2,    # Block expansion factor\n).to(\"cuda\")\ny = model(x)\nassert y.shape == x.shape\n```\n\n### Mamba-2\n\nThe Mamba-2 block is implemented at [modules/mamba2.py](mamba_ssm/modules/mamba2.py).\n\nA simpler version is at [modules/mamba2_simple.py](mamba_ssm/modules/mamba2_simple.py)\n\nThe usage is similar to Mamba(-1):\n``` python\nfrom mamba_ssm import Mamba2\nmodel = Mamba2(\n    # This module uses roughly 3 * expand * d_model^2 parameters\n    d_model=dim, # Model dimension d_model\n    d_state=64,  # SSM state expansion factor, typically 64 or 128\n    d_conv=4,    # Local convolution width\n    expand=2,    # Block expansion factor\n).to(\"cuda\")\ny = model(x)\nassert y.shape == x.shape\n```\n\n#### SSD\n\nA minimal version of the inner SSD module (Listing 1 from the Mamba-2 paper) with conversion between \"discrete\" and \"continuous\" SSM versions\nis at [modules/ssd_minimal.py](mamba_ssm/modules/ssd_minimal.py).\n\n### Mamba Language Model\n\nFinally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.\n\nSource: 
[models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).\n\nThis is an example of how to integrate Mamba into an end-to-end neural network.\nThis example is used in the generation scripts below.\n\n\n## Pretrained Models\n\nPretrained models are uploaded to\n[Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,\n`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, `mamba2-130m`, `mamba2-370m`,\n`mamba2-780m`, `mamba2-1.3b`, `mamba2-2.7b`, `transformerpp-2.7b`, `mamba2attn-2.7b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj`\n(trained on 600B tokens on the SlimPajama dataset).\n\n\nThe models will be downloaded automatically by the generation script below.\n\nThese models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and use the standard model dimensions described by GPT-3 and adopted by many open-source models:\n\n| Parameters | Layers | Model dim. |\n|------------|--------|------------|\n| 130M       | 24     | 768        |\n| 370M       | 48     | 1024       |\n| 790M       | 48     | 1536       |\n| 1.4B       | 48     | 2048       |\n| 2.8B       | 64     | 2560       |\n\n(The layer count of Mamba is double that of a Transformer of similar size, as two Mamba blocks are needed for each \"layer\" (MHA block + MLP block) of a Transformer.)\n\nNote: these are base models trained for only 300B tokens, without any form of downstream modification (instruction tuning, etc.).\nPerformance is expected to be comparable to or better than other architectures trained on similar data, but not to match larger or fine-tuned models.\n\n\n## Evaluations\n\nTo run zero-shot evaluations of models (corresponding to Table 3 of the paper),\nwe use the\n[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)\nlibrary.\n\n1. Install `lm-evaluation-harness` with `pip install lm-eval==0.4.2`.\n2. 
Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):\n``` sh\nlm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256\npython evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64\n```\n\nTo reproduce the results on the `mamba-2.8b-slimpj` model reported in the blogposts:\n``` sh\nlm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256\nlm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256\n```\n\nTo run evaluations on Mamba-2 models, simply replace the model names:\n``` sh\nlm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256\nlm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256\nlm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256\n```\n\nNote that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.\n\n## Inference\n\nThe script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)\n1. autoloads a model from the Hugging Face Hub,\n2. 
generates completions of a user-specified prompt,\n3. benchmarks the inference speed of this generation.\n\nOther configurable options include the top-p (nucleus sampling) probability and the softmax temperature.\n\n### Examples\n\nTo test generation latency (e.g. batch size = 1) with different sampling strategies:\n\n``` sh\npython benchmarks/benchmark_generation_mamba_simple.py --model-name \"state-spaces/mamba-2.8b\" --prompt \"My cat wrote all this CUDA code for a new language model and\" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2\npython benchmarks/benchmark_generation_mamba_simple.py --model-name \"EleutherAI/pythia-2.8b\" --prompt \"My cat wrote all this CUDA code for a new language model and\" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2\npython benchmarks/benchmark_generation_mamba_simple.py --model-name \"state-spaces/mamba-2.8b\" --prompt \"My cat wrote all this CUDA code for a new language model and\" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2\n```\n\nTo test generation throughput with random prompts (e.g. large batch size):\n``` sh\npython benchmarks/benchmark_generation_mamba_simple.py --model-name \"state-spaces/mamba-2.8b\" --batch 64\npython benchmarks/benchmark_generation_mamba_simple.py --model-name \"EleutherAI/pythia-2.8b\" --batch 64\n```\n\nWith Mamba-2, you just need to change the model name:\n``` sh\npython benchmarks/benchmark_generation_mamba_simple.py --model-name \"state-spaces/mamba2-2.7b\" --prompt \"My cat wrote all this CUDA code for a new language model and\" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2\n```\n\n\n## Troubleshooting\n\n### Precision\nOur models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary.\nOn the other hand, other frameworks like DeepSpeed store parameters in float16 and upcast when necessary (e.g. 
for optimizer accumulation).\n\nWe've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities,\nas a first step please try a framework storing parameters in fp32 (such as AMP).\n\n### Initialization\nSome parts of the model have initializations inherited from prior work on S4 models.\nFor [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\\Delta$ parameter has a targeted range by initializing the bias of its linear projection.\nHowever, some frameworks may have post-initialization hooks (e.g. setting all bias terms in `nn.Linear` modules to zero).\nIf this is the case, you may have to add custom logic (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initializing in our trainer, but would be a no-op in any other framework)\nthat is specific to the training framework.\n\n## Additional Prerequisites for AMD cards\n\n### Patching ROCm\n\nIf you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.\n\n1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.\n\n2. Apply the Patch. 
Run with `sudo` in case you encounter permission issues.\n   ```bash\n    patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch \n   ```\n\n\n## Citation\n\nIf you use this codebase, or otherwise find our work valuable, please cite Mamba:\n```\n@article{mamba,\n  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},\n  author={Gu, Albert and Dao, Tri},\n  journal={arXiv preprint arXiv:2312.00752},\n  year={2023}\n}\n\n@inproceedings{mamba2,\n  title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},\n  author={Dao, Tri and Gu, Albert},\n  booktitle={International Conference on Machine Learning (ICML)},\n  year={2024}\n}\n\n@misc{lahoti2026mamba3improvedsequencemodeling,\n      title={Mamba-3: Improved Sequence Modeling using State Space Principles}, \n      author={Aakash Lahoti and Kevin Y. Li and Berlin Chen and Caitlin Wang and Aviv Bick and J. Zico Kolter and Tri Dao and Albert Gu},\n      year={2026},\n      eprint={2603.15569},\n      archivePrefix={arXiv},\n      primaryClass={cs.LG},\n      url={https://arxiv.org/abs/2603.15569}, \n}\n```\n"
  },
  {
    "path": "benchmarks/benchmark_generation_mamba_simple.py",
    "content": "# Copyright (c) 2023, Tri Dao, Albert Gu.\n\nimport argparse\nimport time\nimport json\n\nimport torch\nimport torch.nn.functional as F\n\nfrom einops import rearrange\n\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\n\nfrom mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel\n\n\nparser = argparse.ArgumentParser(description=\"Generation benchmarking\")\nparser.add_argument(\"--model-name\", type=str, default=\"state-spaces/mamba-130m\")\nparser.add_argument(\"--prompt\", type=str, default=None)\nparser.add_argument(\"--promptlen\", type=int, default=100)\nparser.add_argument(\"--genlen\", type=int, default=100)\nparser.add_argument(\"--temperature\", type=float, default=1.0)\nparser.add_argument(\"--topk\", type=int, default=1)\nparser.add_argument(\"--topp\", type=float, default=1.0)\nparser.add_argument(\"--minp\", type=float, default=0.0)\nparser.add_argument(\"--repetition-penalty\", type=float, default=1.0)\nparser.add_argument(\"--batch\", type=int, default=1)\nargs = parser.parse_args()\n\nrepeats = 3\ndevice = \"cuda\"\ndtype = torch.float16\n\nprint(f\"Loading model {args.model_name}\")\nis_mamba = args.model_name.startswith(\"state-spaces/mamba\") or args.model_name.startswith(\"state-spaces/transformerpp\")\nif is_mamba:\n    tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neox-20b\")\n    model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)\nelse:\n    tokenizer = AutoTokenizer.from_pretrained(args.model_name)\n    model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={\"\": device}, torch_dtype=dtype)\nmodel.eval()\nprint(f\"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}\")\n\ntorch.random.manual_seed(0)\nif args.prompt is None:\n    input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device=\"cuda\")\n    attn_mask = torch.ones_like(input_ids, dtype=torch.long, 
device=\"cuda\")\nelse:\n    tokens = tokenizer(args.prompt, return_tensors=\"pt\")\n    input_ids = tokens.input_ids.to(device=device)\n    attn_mask = tokens.attention_mask.to(device=device)\nmax_length = input_ids.shape[1] + args.genlen\n\nif is_mamba:\n    fn = lambda: model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        cg=True,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=False,\n        temperature=args.temperature,\n        top_k=args.topk,\n        top_p=args.topp,\n        min_p=args.minp,\n        repetition_penalty=args.repetition_penalty,\n    )\nelse:\n    fn = lambda: model.generate(\n        input_ids=input_ids,\n        attention_mask=attn_mask,\n        max_length=max_length,\n        return_dict_in_generate=True,\n        pad_token_id=tokenizer.eos_token_id,\n        do_sample=True,\n        temperature=args.temperature,\n        top_k=args.topk,\n        top_p=args.topp,\n        repetition_penalty=args.repetition_penalty,\n    )\nout = fn()\nif args.prompt is not None:\n    print(tokenizer.batch_decode(out.sequences.tolist()))\n\ntorch.cuda.synchronize()\nstart = time.time()\nfor _ in range(repeats):\n    fn()\ntorch.cuda.synchronize()\nprint(f\"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}\")\nprint(f\"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms\")\n"
  },
  {
    "path": "csrc/selective_scan/reverse_scan.cuh",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#ifndef USE_ROCM\n    #include <cub/config.cuh>\n    \n    #include <cub/util_ptx.cuh>\n    #include <cub/util_type.cuh>\n    #include <cub/block/block_raking_layout.cuh>\n    // #include <cub/detail/uninitialized_copy.cuh>\n#else\n    #include <hipcub/hipcub.hpp>\n    namespace cub = hipcub;\n#endif\n#include \"uninitialized_copy.cuh\"\n\n/**\n * Perform a reverse sequential reduction over \\p LENGTH elements of the \\p input array.  The aggregate is returned.\n */\ntemplate <\n    int         LENGTH,\n    typename    T,\n    typename    ReductionOp>\n__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {\n    static_assert(LENGTH > 0);\n    T retval = input[LENGTH - 1];\n    #pragma unroll\n    for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }\n    return retval;\n}\n\n/**\n * Perform a sequential inclusive postfix reverse scan over the statically-sized \\p input array, seeded with the specified \\p postfix.  The aggregate is returned.\n */\ntemplate <\n    int         LENGTH,\n    typename    T,\n    typename    ScanOp>\n__device__ __forceinline__ T ThreadReverseScanInclusive(\n    const T (&input)[LENGTH],\n    T (&output)[LENGTH],\n    ScanOp scan_op,\n    const T postfix)\n{\n    T inclusive = postfix;\n    #pragma unroll\n    for (int i = LENGTH - 1; i >= 0; --i) {\n        inclusive = scan_op(inclusive, input[i]);\n        output[i] = inclusive;\n    }\n    return inclusive; \n}\n\n/**\n * Perform a sequential exclusive postfix reverse scan over the statically-sized \\p input array, seeded with the specified \\p postfix.  
The aggregate is returned.\n */\ntemplate <\n    int         LENGTH,\n    typename    T,\n    typename    ScanOp>\n__device__ __forceinline__ T ThreadReverseScanExclusive(\n    const T (&input)[LENGTH],\n    T (&output)[LENGTH],\n    ScanOp scan_op,\n    const T postfix)\n{\n    // Careful, output may be aliased to input\n    T exclusive = postfix;\n    T inclusive;\n    #pragma unroll\n    for (int i = LENGTH - 1; i >= 0; --i) {\n        inclusive = scan_op(exclusive, input[i]);\n        output[i] = exclusive;\n        exclusive = inclusive;\n    }\n    return inclusive;\n}\n\n\n/**\n * \\brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.\n *\n * LOGICAL_WARP_THREADS must be a power-of-two\n */\ntemplate <\n    typename    T,                      ///< Data type being scanned\n    int         LOGICAL_WARP_THREADS    ///< Number of threads per logical warp\n    >\nstruct WarpReverseScan {\n    //---------------------------------------------------------------------\n    // Constants and type definitions\n    //---------------------------------------------------------------------\n\n    /// Whether the logical warp size and the PTX warp size coincide\n\n    // In hipcub, warp_threads is defined as HIPCUB_WARP_THREADS ::rocprim::warp_size()\n    // While in cub, it's defined as a macro that takes a redundant unused argument.\n    #ifndef USE_ROCM\n        #define WARP_THREADS CUB_WARP_THREADS(0)\n    #else\n        // ROCm 7.0+: HIPCUB_WARP_THREADS (rocprim::warp_size()) is no longer constexpr.\n        // We need a compile-time constant for IS_ARCH_WARP below.\n        // See: https://rocm.docs.amd.com/en/latest/about/release-notes.html\n        #if defined(__AMDGCN_WAVEFRONT_SIZE)\n            // Deprecated but still available and constexpr in ROCm 7.x\n            #define WARP_THREADS __AMDGCN_WAVEFRONT_SIZE\n        #elif defined(__gfx942__) || defined(__gfx941__) || defined(__gfx940__)\n    
        // AMD Instinct MI300 series (CDNA3) - 64-wide wavefronts\n            #define WARP_THREADS 64\n        #elif defined(__gfx90a__)\n            // AMD Instinct MI200 series (CDNA2) - 64-wide wavefronts\n            #define WARP_THREADS 64\n        #elif defined(__gfx908__)\n            // AMD Instinct MI100 (CDNA1) - 64-wide wavefronts\n            #define WARP_THREADS 64\n        #elif defined(__gfx906__) || defined(__gfx900__)\n            // AMD Instinct MI50/MI60 (Vega) - 64-wide wavefronts\n            #define WARP_THREADS 64\n        #elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)\n            // AMD Radeon RX 7000 series (RDNA3) - 32-wide wavefronts\n            #define WARP_THREADS 32\n        #elif defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1034__)\n            // AMD Radeon RX 6000 series (RDNA2) - 32-wide wavefronts\n            #define WARP_THREADS 32\n        #elif defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)\n            // AMD Radeon RX 5000 series (RDNA1) - 32-wide wavefronts\n            #define WARP_THREADS 32\n        #else\n            // Unknown architecture - default to 64 (CDNA/GCN)\n            // This may not be optimal for RDNA GPUs\n            #pragma message(\"Warning: Unknown AMD GPU architecture. Defaulting WARP_THREADS to 64. 
\" \\\n                            \"For RDNA GPUs (gfx10xx/gfx11xx), this should be 32.\")\n            #define WARP_THREADS 64\n        #endif\n    #endif\n    static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS);\n    /// The number of warp scan steps\n    static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;\n    static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);\n\n\n    //---------------------------------------------------------------------\n    // Thread fields\n    //---------------------------------------------------------------------\n\n    /// Lane index in logical warp\n    unsigned int lane_id;\n\n    /// Logical warp index in 32-thread physical warp\n    unsigned int warp_id;\n\n    /// 32-thread physical warp member mask of logical warp\n    unsigned int member_mask;\n\n    //---------------------------------------------------------------------\n    // Construction\n    //---------------------------------------------------------------------\n\n    /// Constructor\n    explicit __device__ __forceinline__\n    WarpReverseScan()\n#ifndef USE_ROCM\n        : lane_id(threadIdx.x & 0x1f)  // CUDA: 32-thread warps, mask = 31\n#else\n        : lane_id(threadIdx.x & (WARP_THREADS - 1))  // ROCm: use actual wavefront size (64 or 32)\n#endif\n        , warp_id(IS_ARCH_WARP ? 
0 : (lane_id / LOGICAL_WARP_THREADS))\n        , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))\n    {\n        if (!IS_ARCH_WARP) {\n            lane_id = lane_id % LOGICAL_WARP_THREADS;\n        }\n    }\n\n\n    /// Broadcast\n    __device__ __forceinline__ T Broadcast(\n        T               input,              ///< [in] The value to broadcast\n        int             src_lane)           ///< [in] Which warp lane is to do the broadcasting\n    {\n        return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);\n    }\n\n\n    /// Inclusive scan\n    template <typename ScanOpT>\n    __device__ __forceinline__ void InclusiveReverseScan(\n        T               input,              ///< [in] Calling thread's input item.\n        T               &inclusive_output,  ///< [out] Calling thread's output item.  May be aliased with \\p input.\n        ScanOpT         scan_op)            ///< [in] Binary scan operator\n    {\n        inclusive_output = input;\n        #pragma unroll\n        for (int STEP = 0; STEP < STEPS; STEP++) {\n            int offset = 1 << STEP;\n            T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(\n                inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask\n            );\n            // Perform scan op if from a valid peer\n            inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset\n                ? inclusive_output : scan_op(temp, inclusive_output);\n        }\n    }\n\n    /// Exclusive scan\n    // Get exclusive from inclusive\n    template <typename ScanOpT>\n    __device__ __forceinline__ void ExclusiveReverseScan(\n        T              input,              ///< [in] Calling thread's input item.\n        T              &exclusive_output,  ///< [out] Calling thread's output item.  
May be aliased with \\p input.\n        ScanOpT        scan_op,            ///< [in] Binary scan operator\n        T              &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.\n    {\n        T inclusive_output;\n        InclusiveReverseScan(input, inclusive_output, scan_op);\n        warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);\n        // initial value unknown\n        exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(\n            inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask\n        );\n    }\n\n    /**\n     * \\brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp.  Because no initial value is supplied, the \\p exclusive_output computed for the last <em>warp-lane</em> is undefined.\n     */\n    template <typename ScanOpT>\n    __device__ __forceinline__ void ReverseScan(\n        T               input,              ///< [in] Calling thread's input item.\n        T               &inclusive_output,  ///< [out] Calling thread's inclusive-scan output item.\n        T               &exclusive_output,  ///< [out] Calling thread's exclusive-scan output item.\n        ScanOpT         scan_op)            ///< [in] Binary scan operator\n    {\n        InclusiveReverseScan(input, inclusive_output, scan_op);\n        // initial value unknown\n        exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(\n            inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask\n        );\n    }\n\n};\n\n/**\n * \\brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.\n */\ntemplate <\n    typename    T,              ///< Data type being scanned\n    int         BLOCK_DIM_X,    ///< The thread block length in threads along the X dimension\n    bool        MEMOIZE=false   ///< Whether or not to buffer outer raking scan partials to incur fewer shared 
memory reads at the expense of higher register pressure\n    >\nstruct BlockReverseScan {\n    //---------------------------------------------------------------------\n    // Types and constants\n    //---------------------------------------------------------------------\n\n    /// Constants\n    /// The thread block size in threads\n    static constexpr int BLOCK_THREADS = BLOCK_DIM_X;\n\n    /// Layout type for padded thread block raking grid\n    using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;\n    // Only the unguarded case is supported for now: the number of reduction elements must be\n    // a multiple of the number of raking threads, so raking needs no bounds checking\n    static_assert(BlockRakingLayout::UNGUARDED);\n\n    /// Number of raking threads\n    static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;\n    /// Number of raking elements per warp synchronous raking thread\n    static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;\n    /// Cooperative work can be entirely warp synchronous\n    static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));\n\n    ///  WarpReverseScan utility type\n    using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;\n\n    /// Shared memory storage layout type\n    struct _TempStorage {\n        typename BlockRakingLayout::TempStorage raking_grid;     ///< Padded thread block raking grid\n    };\n\n\n    /// Alias wrapper allowing storage to be unioned\n    struct TempStorage : cub::Uninitialized<_TempStorage> {};\n\n\n    //---------------------------------------------------------------------\n    // Per-thread fields\n    //---------------------------------------------------------------------\n\n    // Thread fields\n    _TempStorage    &temp_storage;\n    unsigned int    linear_tid;\n    T               cached_segment[SEGMENT_LENGTH];\n\n\n    //---------------------------------------------------------------------\n    // Utility methods\n    
//---------------------------------------------------------------------\n\n    /// Performs upsweep raking reduction, returning the aggregate\n    template <typename ScanOp>\n    __device__ __forceinline__ T Upsweep(ScanOp scan_op) {\n        T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);\n        // Read data into registers\n        #pragma unroll\n        for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }\n        T raking_partial = cached_segment[SEGMENT_LENGTH - 1];\n        #pragma unroll\n        for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {\n            raking_partial = scan_op(raking_partial, cached_segment[i]);\n        }\n        return raking_partial;\n    }\n\n\n    /// Performs exclusive downsweep raking scan\n    template <typename ScanOp>\n    __device__ __forceinline__ void ExclusiveDownsweep(\n        ScanOp          scan_op,\n        T               raking_partial)\n    {\n        T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);\n        // Read data back into registers\n        if (!MEMOIZE) {\n            #pragma unroll\n            for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }\n        }\n        ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);\n        // Write data back to smem\n        #pragma unroll\n        for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }\n    }\n\n\n    //---------------------------------------------------------------------\n    // Constructors\n    //---------------------------------------------------------------------\n\n    /// Constructor\n    __device__ __forceinline__ BlockReverseScan(\n        TempStorage &temp_storage)\n    :\n        temp_storage(temp_storage.Alias()),\n        linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))\n    {}\n\n\n    /// Computes an exclusive thread block-wide 
postfix scan using the specified binary \\p scan_op functor.  Each thread contributes one input element.  the call-back functor \\p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the \"seed\" value that logically postfixes the thread block's scan inputs.  Also provides every thread with the block-wide \\p block_aggregate of all inputs.\n    template <\n        typename ScanOp,\n        typename BlockPostfixCallbackOp>\n    __device__ __forceinline__ void ExclusiveReverseScan(\n        T                       input,                          ///< [in] Calling thread's input item\n        T                       &exclusive_output,              ///< [out] Calling thread's output item (may be aliased to \\p input)\n        ScanOp                  scan_op,                        ///< [in] Binary scan operator\n        BlockPostfixCallbackOp  &block_postfix_callback_op)     ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.\n    {\n        if (WARP_SYNCHRONOUS) {\n            // Short-circuit directly to warp-synchronous scan\n            T block_aggregate;\n            WarpReverseScan warp_scan;\n            warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);\n            // Obtain warp-wide postfix in lane0, then broadcast to other lanes\n            T block_postfix = block_postfix_callback_op(block_aggregate);\n            block_postfix = warp_scan.Broadcast(block_postfix, 0);\n            exclusive_output = linear_tid == BLOCK_THREADS - 1 ? 
block_postfix : scan_op(block_postfix, exclusive_output);\n        } else {\n            // Place thread partial into shared memory raking grid\n            T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);\n            detail::uninitialized_copy(placement_ptr, input);\n            __syncthreads();\n            // Reduce parallelism down to just raking threads\n            if (linear_tid < RAKING_THREADS) {\n                WarpReverseScan warp_scan;\n                // Raking upsweep reduction across shared partials\n                T upsweep_partial = Upsweep(scan_op);\n                // Warp-synchronous scan\n                T exclusive_partial, block_aggregate;\n                warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);\n                // Obtain block-wide postfix in lane0, then broadcast to other lanes\n                T block_postfix = block_postfix_callback_op(block_aggregate);\n                block_postfix = warp_scan.Broadcast(block_postfix, 0);\n                // Update postfix with warpscan exclusive partial\n                T downsweep_postfix = linear_tid == RAKING_THREADS - 1\n                    ? block_postfix : scan_op(block_postfix, exclusive_partial);\n                // Exclusive raking downsweep scan\n                ExclusiveDownsweep(scan_op, downsweep_postfix);\n            }\n            __syncthreads();\n            // Grab thread postfix from shared memory\n            exclusive_output = *placement_ptr;\n\n            // // Compute warp scan in each warp.\n            // // The exclusive output from the last lane in each warp is invalid.\n            // T inclusive_output;\n            // WarpReverseScan warp_scan;\n            // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);\n\n            // // Compute the warp-wide postfix and block-wide aggregate for each warp.  
Warp postfix for the last warp is invalid.\n            // T block_aggregate;\n            // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);\n\n            // // Apply warp postfix to our lane's partial\n            // if (warp_id != 0) {\n            //     exclusive_output = scan_op(warp_postfix, exclusive_output);\n            //     if (lane_id == 0) { exclusive_output = warp_postfix; }\n            // }\n\n            // // Use the first warp to determine the thread block postfix, returning the result in lane0\n            // if (warp_id == 0) {\n            //     T block_postfix = block_postfix_callback_op(block_aggregate);\n            //     if (lane_id == 0) {\n            //         // Share the postfix with all threads\n            //         detail::uninitialized_copy(&temp_storage.block_postfix,\n            //                                   block_postfix);\n\n            //         exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0\n            //     }\n            // }\n\n            // __syncthreads();\n\n            // // Incorporate thread block postfix into outputs\n            // T block_postfix = temp_storage.block_postfix;\n            // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }\n        }\n    }\n\n\n    /**\n     * \\brief Computes an inclusive block-wide postfix scan using the specified binary \\p scan_op functor.  Each thread contributes an array of consecutive input elements.  the call-back functor \\p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the \"seed\" value that logically postfixes the thread block's scan inputs.  
Also provides every thread with the block-wide \\p block_aggregate of all inputs.\n     */\n    template <\n        int             ITEMS_PER_THREAD,\n        typename        ScanOp,\n        typename        BlockPostfixCallbackOp>\n    __device__ __forceinline__ void InclusiveReverseScan(\n        T                       (&input)[ITEMS_PER_THREAD],     ///< [in] Calling thread's input items\n        T                       (&output)[ITEMS_PER_THREAD],    ///< [out] Calling thread's output items (may be aliased to \\p input)\n        ScanOp                  scan_op,                        ///< [in] Binary scan functor\n        BlockPostfixCallbackOp   &block_postfix_callback_op)    ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.\n    {\n        // Reduce consecutive thread items in registers\n        T thread_postfix = ThreadReverseReduce(input, scan_op);\n        // Exclusive thread block-scan\n        ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);\n        // Inclusive scan in registers with postfix as seed\n        ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);\n    }\n\n};"
  },
  {
    "path": "csrc/selective_scan/selective_scan.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n#include <c10/cuda/CUDAGuard.h>\n#include <c10/cuda/CUDAStream.h>\n#include <torch/python.h>\n#include <vector>\n\n#include \"selective_scan.h\"\n\n#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x \" must have shape (\" #__VA_ARGS__ \")\")\n\n#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)                    \\\n    if (ITYPE == at::ScalarType::Half) {                                            \\\n        using input_t = at::Half;                                                   \\\n        __VA_ARGS__();                                                              \\\n    } else if (ITYPE == at::ScalarType::BFloat16) {                                 \\\n        using input_t = at::BFloat16;                                               \\\n        __VA_ARGS__();                                                              \\\n    } else if (ITYPE == at::ScalarType::Float)  {                                   \\\n        using input_t = float;                                                      \\\n        __VA_ARGS__();                                                              \\\n    } else {                                                                        \\\n        AT_ERROR(#NAME, \" not implemented for input type '\", toString(ITYPE), \"'\"); \\\n    }\n\n#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...)                     
\\\n    if (WTYPE == at::ScalarType::Half) {                                             \\\n        using weight_t = at::Half;                                                   \\\n        __VA_ARGS__();                                                               \\\n    } else if (WTYPE == at::ScalarType::BFloat16) {                                  \\\n        using weight_t = at::BFloat16;                                               \\\n        __VA_ARGS__();                                                               \\\n    } else if (WTYPE == at::ScalarType::Float)  {                                    \\\n        using weight_t = float;                                                      \\\n        __VA_ARGS__();                                                               \\\n    } else {                                                                         \\\n        AT_ERROR(#NAME, \" not implemented for weight type '\", toString(WTYPE), \"'\"); \\\n    }\n\n#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...)                           
\\\n    if (WTYPE == at::ScalarType::Float) {                                            \\\n       using weight_t = float;                                                       \\\n        __VA_ARGS__();                                                               \\\n    } else if (WTYPE == at::ScalarType::ComplexFloat) {                              \\\n        using weight_t = c10::complex<float>;                                        \\\n        __VA_ARGS__();                                                               \\\n    } else {                                                                         \\\n        AT_ERROR(#NAME, \" not implemented for weight type '\", toString(WTYPE), \"'\"); \\\n    }\n\ntemplate<typename input_t, typename weight_t>\nvoid selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);\n\ntemplate <typename input_t, typename weight_t>\nvoid selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);\n\nvoid set_ssm_params_fwd(SSMParamsBase &params,\n                        // sizes\n                        const size_t batch,\n                        const size_t dim,\n                        const size_t seqlen,\n                        const size_t dstate,\n                        const size_t n_groups,\n                        const size_t n_chunks,\n                        const bool is_variable_B,\n                        const bool is_variable_C,\n                        // device pointers\n                        const at::Tensor u,\n                        const at::Tensor delta,\n                        const at::Tensor A,\n                        const at::Tensor B,\n                        const at::Tensor C,\n                        const at::Tensor out,\n                        const at::Tensor z,\n                        const at::Tensor out_z,\n                        void* D_ptr,\n                        void* delta_bias_ptr,\n                        void* x_ptr,\n                        
bool has_z,\n                        bool delta_softplus) {\n\n    // Reset the parameters\n    memset(&params, 0, sizeof(params));\n\n    params.batch = batch;\n    params.dim = dim;\n    params.seqlen = seqlen;\n    params.dstate = dstate;\n    params.n_groups = n_groups;\n    params.n_chunks = n_chunks;\n    params.dim_ngroups_ratio = dim / n_groups;\n\n    params.delta_softplus = delta_softplus;\n\n    params.is_variable_B = is_variable_B;\n    params.is_variable_C = is_variable_C;\n\n    // Set the pointers and strides.\n    params.u_ptr = u.data_ptr();\n    params.delta_ptr = delta.data_ptr();\n    params.A_ptr = A.data_ptr();\n    params.B_ptr = B.data_ptr();\n    params.C_ptr = C.data_ptr();\n    params.D_ptr = D_ptr;\n    params.delta_bias_ptr = delta_bias_ptr;\n    params.out_ptr = out.data_ptr();\n    params.x_ptr = x_ptr;\n    params.z_ptr = has_z ? z.data_ptr() : nullptr;\n    params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;\n    // All stride are in elements, not bytes.\n    params.A_d_stride = A.stride(0);\n    params.A_dstate_stride = A.stride(1);\n    if (!is_variable_B) {\n        params.B_d_stride = B.stride(0);\n    } else {\n        params.B_batch_stride = B.stride(0);\n        params.B_group_stride = B.stride(1);\n    }\n    params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);\n    if (!is_variable_C) {\n        params.C_d_stride = C.stride(0);\n    } else {\n        params.C_batch_stride = C.stride(0);\n        params.C_group_stride = C.stride(1);\n    }\n    params.C_dstate_stride = !is_variable_C ? 
C.stride(1) : C.stride(2);\n    params.u_batch_stride = u.stride(0);\n    params.u_d_stride = u.stride(1);\n    params.delta_batch_stride = delta.stride(0);\n    params.delta_d_stride = delta.stride(1);\n    if (has_z) {\n        params.z_batch_stride = z.stride(0);\n        params.z_d_stride = z.stride(1);\n        params.out_z_batch_stride = out_z.stride(0);\n        params.out_z_d_stride = out_z.stride(1);\n    }\n    params.out_batch_stride = out.stride(0);\n    params.out_d_stride = out.stride(1);\n}\n\nvoid set_ssm_params_bwd(SSMParamsBwd &params,\n                        // sizes\n                        const size_t batch,\n                        const size_t dim,\n                        const size_t seqlen,\n                        const size_t dstate,\n                        const size_t n_groups,\n                        const size_t n_chunks,\n                        const bool is_variable_B,\n                        const bool is_variable_C,\n                        // device pointers\n                        const at::Tensor u,\n                        const at::Tensor delta,\n                        const at::Tensor A,\n                        const at::Tensor B,\n                        const at::Tensor C,\n                        const at::Tensor z,\n                        const at::Tensor out,\n                        const at::Tensor out_z,\n                        void* D_ptr,\n                        void* delta_bias_ptr,\n                        void* x_ptr,\n                        const at::Tensor dout,\n                        const at::Tensor du,\n                        const at::Tensor ddelta,\n                        const at::Tensor dA,\n                        const at::Tensor dB,\n                        const at::Tensor dC,\n                        const at::Tensor dz,\n                        void* dD_ptr,\n                        void* ddelta_bias_ptr,\n                        bool has_z,\n                        bool 
delta_softplus,\n                        bool recompute_out_z) {\n    // Pass in \"dout\" instead of \"out\", we're not gonna use \"out\" unless we have z\n    set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,\n                       u, delta, A, B, C, has_z ? out : dout,\n                       has_z ? z : dout,\n                       // If not recompute_out_z, pass dout instead of out_z.\n                       // This won't be used by the bwd kernel\n                       recompute_out_z ? out_z : dout,\n                       D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);\n    if (!recompute_out_z) { params.out_z_ptr = nullptr; }\n\n    // Set the pointers and strides.\n    params.dout_ptr = dout.data_ptr();\n    params.du_ptr = du.data_ptr();\n    params.dA_ptr = dA.data_ptr();\n    params.dB_ptr = dB.data_ptr();\n    params.dC_ptr = dC.data_ptr();\n    params.dD_ptr = dD_ptr;\n    params.ddelta_ptr = ddelta.data_ptr();\n    params.ddelta_bias_ptr = ddelta_bias_ptr;\n    params.dz_ptr = has_z ? dz.data_ptr() : nullptr;\n    // All stride are in elements, not bytes.\n    params.dout_batch_stride = dout.stride(0);\n    params.dout_d_stride = dout.stride(1);\n    params.dA_d_stride = dA.stride(0);\n    params.dA_dstate_stride = dA.stride(1);\n    if (!is_variable_B) {\n        params.dB_d_stride = dB.stride(0);\n    } else {\n        params.dB_batch_stride = dB.stride(0);\n        params.dB_group_stride = dB.stride(1);\n    }\n    params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2);\n    if (!is_variable_C) {\n        params.dC_d_stride = dC.stride(0);\n    } else {\n        params.dC_batch_stride = dC.stride(0);\n        params.dC_group_stride = dC.stride(1);\n    }\n    params.dC_dstate_stride = !is_variable_C ? 
dC.stride(1) : dC.stride(2);\n    params.du_batch_stride = du.stride(0);\n    params.du_d_stride = du.stride(1);\n    params.ddelta_batch_stride = ddelta.stride(0);\n    params.ddelta_d_stride = ddelta.stride(1);\n    if (has_z) {\n        params.dz_batch_stride = dz.stride(0);\n        params.dz_d_stride = dz.stride(1);\n    }\n}\n\nstd::vector<at::Tensor>\nselective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,\n                  const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,\n                  const c10::optional<at::Tensor> &D_,\n                  const c10::optional<at::Tensor> &z_,\n                  const c10::optional<at::Tensor> &delta_bias_,\n                  bool delta_softplus) {\n    auto input_type = u.scalar_type();\n    auto weight_type = A.scalar_type();\n    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);\n    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);\n\n    const bool is_variable_B = B.dim() >= 3;\n    const bool is_variable_C = C.dim() >= 3;\n    const bool is_complex = weight_type == at::ScalarType::ComplexFloat;\n\n    TORCH_CHECK(delta.scalar_type() == input_type);\n    TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));\n    TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));\n\n    TORCH_CHECK(u.is_cuda());\n    TORCH_CHECK(delta.is_cuda());\n    TORCH_CHECK(A.is_cuda());\n    TORCH_CHECK(B.is_cuda());\n    TORCH_CHECK(C.is_cuda());\n\n    TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);\n    TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);\n\n    const auto sizes = u.sizes();\n    const int batch_size = sizes[0];\n    const int dim = sizes[1];\n    const int seqlen = sizes[2];\n    const int dstate = A.size(1);\n    const int n_groups = is_variable_B ? 
B.size(1) : 1;\n\n    TORCH_CHECK(dstate <= 256, \"selective_scan only supports state dimension <= 256\");\n\n    CHECK_SHAPE(u, batch_size, dim, seqlen);\n    CHECK_SHAPE(delta, batch_size, dim, seqlen);\n    CHECK_SHAPE(A, dim, dstate);\n    if (!is_variable_B) {\n        CHECK_SHAPE(B, dim, dstate);\n    } else {\n        CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);\n        TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);\n    }\n    if (!is_variable_C) {\n        CHECK_SHAPE(C, dim, dstate);\n    } else {\n        CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);\n        TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);\n    }\n\n    if (D_.has_value()) {\n        auto D = D_.value();\n        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);\n        TORCH_CHECK(D.is_cuda());\n        TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);\n        CHECK_SHAPE(D, dim);\n    }\n\n    if (delta_bias_.has_value()) {\n        auto delta_bias = delta_bias_.value();\n        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);\n        TORCH_CHECK(delta_bias.is_cuda());\n        TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);\n        CHECK_SHAPE(delta_bias, dim);\n    }\n\n    at::Tensor z, out_z;\n    const bool has_z = z_.has_value();\n    if (has_z) {\n        z = z_.value();\n        TORCH_CHECK(z.scalar_type() == input_type);\n        TORCH_CHECK(z.is_cuda());\n        TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);\n        CHECK_SHAPE(z, batch_size, dim, seqlen);\n        out_z = torch::empty_like(z);\n    }\n\n    const int n_chunks = (seqlen + 2048 - 1) / 2048;\n    // const int n_chunks = (seqlen + 1024 - 1) / 1024;\n    // at::Tensor out = torch::empty_like(u);\n    // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout\n    at::Tensor out = torch::empty_like(delta);\n    at::Tensor x;\n    x = 
torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));\n\n    SSMParamsBase params;\n    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,\n                       u, delta, A, B, C, out, z, out_z,\n                       D_.has_value() ? D_.value().data_ptr() : nullptr,\n                       delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,\n                       x.data_ptr(),\n                       has_z,\n                       delta_softplus);\n\n    // Guard the device of u; otherwise the kernel would be launched on cuda:0\n    at::cuda::CUDAGuard device_guard{u.device()};\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), \"selective_scan_fwd\", [&] {\n        DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), \"selective_scan_fwd\", [&] {\n            selective_scan_fwd_cuda<input_t, weight_t>(params, stream);\n        });\n    });\n    std::vector<at::Tensor> result = {out, x};\n    if (has_z) { result.push_back(out_z); }\n    return result;\n}\n\nstd::vector<at::Tensor>\nselective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,\n                  const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,\n                  const c10::optional<at::Tensor> &D_,\n                  const c10::optional<at::Tensor> &z_,\n                  const c10::optional<at::Tensor> &delta_bias_,\n                  const at::Tensor &dout,\n                  const c10::optional<at::Tensor> &x_,\n                  const c10::optional<at::Tensor> &out_,\n                  c10::optional<at::Tensor> &dz_,\n                  bool delta_softplus,\n                  bool recompute_out_z) {\n    auto input_type = u.scalar_type();\n    auto weight_type = A.scalar_type();\n    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == 
at::ScalarType::Half || input_type == at::ScalarType::BFloat16);\n    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);\n\n    const bool is_variable_B = B.dim() >= 3;\n    const bool is_variable_C = C.dim() >= 3;\n    const bool is_complex = weight_type == at::ScalarType::ComplexFloat;\n\n    TORCH_CHECK(delta.scalar_type() == input_type);\n    TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));\n    TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));\n    TORCH_CHECK(dout.scalar_type() == input_type);\n\n    TORCH_CHECK(u.is_cuda());\n    TORCH_CHECK(delta.is_cuda());\n    TORCH_CHECK(A.is_cuda());\n    TORCH_CHECK(B.is_cuda());\n    TORCH_CHECK(C.is_cuda());\n    TORCH_CHECK(dout.is_cuda());\n\n    TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);\n    TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);\n    TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);\n\n    const auto sizes = u.sizes();\n    const int batch_size = sizes[0];\n    const int dim = sizes[1];\n    const int seqlen = sizes[2];\n    const int dstate = A.size(1);\n    const int n_groups = is_variable_B ? B.size(1) : 1;\n\n    TORCH_CHECK(dstate <= 256, \"selective_scan only supports state dimension <= 256\");\n\n    CHECK_SHAPE(u, batch_size, dim, seqlen);\n    CHECK_SHAPE(delta, batch_size, dim, seqlen);\n    CHECK_SHAPE(A, dim, dstate);\n    if (!is_variable_B) {\n        CHECK_SHAPE(B, dim, dstate);\n    } else {\n        CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);\n        TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);\n    }\n    if (!is_variable_C) {\n        CHECK_SHAPE(C, dim, dstate);\n    } else {\n        CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? 
seqlen: seqlen * 2);\n        TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);\n    }\n    CHECK_SHAPE(dout, batch_size, dim, seqlen);\n\n    if (D_.has_value()) {\n        auto D = D_.value();\n        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);\n        TORCH_CHECK(D.is_cuda());\n        TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);\n        CHECK_SHAPE(D, dim);\n    }\n\n    if (delta_bias_.has_value()) {\n        auto delta_bias = delta_bias_.value();\n        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);\n        TORCH_CHECK(delta_bias.is_cuda());\n        TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);\n        CHECK_SHAPE(delta_bias, dim);\n    }\n\n    at::Tensor z, out, dz, out_z;\n    const bool has_z = z_.has_value();\n    if (has_z) {\n        z = z_.value();\n        TORCH_CHECK(z.scalar_type() == input_type);\n        TORCH_CHECK(z.is_cuda());\n        TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);\n        CHECK_SHAPE(z, batch_size, dim, seqlen);\n\n        TORCH_CHECK(out_.has_value());\n        out = out_.value();\n        TORCH_CHECK(out.scalar_type() == input_type);\n        TORCH_CHECK(out.is_cuda());\n        TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1);\n        CHECK_SHAPE(out, batch_size, dim, seqlen);\n\n        if (dz_.has_value()) {\n            dz = dz_.value();\n            TORCH_CHECK(dz.scalar_type() == input_type);\n            TORCH_CHECK(dz.is_cuda());\n            TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1);\n            CHECK_SHAPE(dz, batch_size, dim, seqlen);\n        } else {\n            dz = torch::empty_like(z);\n        }\n        if (recompute_out_z) {\n            out_z = torch::empty_like(out);\n        }\n    }\n\n    const int n_chunks = (seqlen + 2048 - 1) / 2048;\n    // const int n_chunks = (seqlen + 1024 - 1) / 1024;\n    if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }\n    if (x_.has_value()) {\n        auto x = x_.value();\n        
TORCH_CHECK(x.scalar_type() == weight_type);\n        TORCH_CHECK(x.is_cuda());\n        TORCH_CHECK(x.is_contiguous());\n        CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);\n    }\n\n    at::Tensor du = torch::empty_like(u);\n    at::Tensor ddelta = torch::empty_like(delta);\n    at::Tensor dA = torch::zeros_like(A);\n    at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32));\n    at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32));\n    at::Tensor dD;\n    if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }\n    at::Tensor ddelta_bias;\n    if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }\n\n    SSMParamsBwd params;\n    set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,\n                       u, delta, A, B, C, z, out, out_z,\n                       D_.has_value() ? D_.value().data_ptr() : nullptr,\n                       delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,\n                       x_.has_value() ? x_.value().data_ptr() : nullptr,\n                       dout, du, ddelta, dA, dB, dC, dz,\n                       D_.has_value() ? dD.data_ptr() : nullptr,\n                       delta_bias_.has_value() ? 
ddelta_bias.data_ptr() : nullptr,\n                       has_z, delta_softplus, recompute_out_z);\n\n    // Guard the device of u; otherwise the kernel would be launched on cuda:0\n    at::cuda::CUDAGuard device_guard{u.device()};\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), \"selective_scan_bwd\", [&] {\n        DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), \"selective_scan_bwd\", [&] {\n            selective_scan_bwd_cuda<input_t, weight_t>(params, stream);\n        });\n    });\n    std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};\n    if (has_z) { result.push_back(dz); }\n    if (recompute_out_z) { result.push_back(out_z); }\n    return result;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"fwd\", &selective_scan_fwd, \"Selective scan forward\");\n    m.def(\"bwd\", &selective_scan_bwd, \"Selective scan backward\");\n}\n"
  },
  {
    "path": "csrc/selective_scan/selective_scan.h",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct SSMScanParamsBase {\n    using index_t = uint32_t;\n\n    int batch, seqlen, n_chunks;\n    index_t a_batch_stride;\n    index_t b_batch_stride;\n    index_t out_batch_stride;\n\n    // Common data pointers.\n    void *__restrict__ a_ptr;\n    void *__restrict__ b_ptr;\n    void *__restrict__ out_ptr;\n    void *__restrict__ x_ptr;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct SSMParamsBase {\n    using index_t = uint32_t;\n\n    int batch, dim, seqlen, dstate, n_groups, n_chunks;\n    int dim_ngroups_ratio;\n    bool is_variable_B;\n    bool is_variable_C;\n\n    bool delta_softplus;\n\n    index_t A_d_stride;\n    index_t A_dstate_stride;\n    index_t B_batch_stride;\n    index_t B_d_stride;\n    index_t B_dstate_stride;\n    index_t B_group_stride;\n    index_t C_batch_stride;\n    index_t C_d_stride;\n    index_t C_dstate_stride;\n    index_t C_group_stride;\n    index_t u_batch_stride;\n    index_t u_d_stride;\n    index_t delta_batch_stride;\n    index_t delta_d_stride;\n    index_t z_batch_stride;\n    index_t z_d_stride;\n    index_t out_batch_stride;\n    index_t out_d_stride;\n    index_t out_z_batch_stride;\n    index_t out_z_d_stride;\n\n    // Common data pointers.\n    void *__restrict__ A_ptr;\n    void *__restrict__ B_ptr;\n    void *__restrict__ C_ptr;\n    void *__restrict__ D_ptr;\n    void *__restrict__ u_ptr;\n    void *__restrict__ delta_ptr;\n    void *__restrict__ delta_bias_ptr;\n    void *__restrict__ out_ptr;\n    void *__restrict__ x_ptr;\n    void *__restrict__ z_ptr;\n    void *__restrict__ out_z_ptr;\n};\n\nstruct SSMParamsBwd: 
public SSMParamsBase {\n    index_t dout_batch_stride;\n    index_t dout_d_stride;\n    index_t dA_d_stride;\n    index_t dA_dstate_stride;\n    index_t dB_batch_stride;\n    index_t dB_group_stride;\n    index_t dB_d_stride;\n    index_t dB_dstate_stride;\n    index_t dC_batch_stride;\n    index_t dC_group_stride;\n    index_t dC_d_stride;\n    index_t dC_dstate_stride;\n    index_t du_batch_stride;\n    index_t du_d_stride;\n    index_t dz_batch_stride;\n    index_t dz_d_stride;\n    index_t ddelta_batch_stride;\n    index_t ddelta_d_stride;\n\n    // Common data pointers.\n    void *__restrict__ dout_ptr;\n    void *__restrict__ dA_ptr;\n    void *__restrict__ dB_ptr;\n    void *__restrict__ dC_ptr;\n    void *__restrict__ dD_ptr;\n    void *__restrict__ du_ptr;\n    void *__restrict__ dz_ptr;\n    void *__restrict__ ddelta_ptr;\n    void *__restrict__ ddelta_bias_ptr;\n};\n"
  },
  {
    "path": "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// Split into multiple files to compile in paralell\n\n#include \"selective_scan_bwd_kernel.cuh\"\n\ntemplate void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);"
  },
  {
    "path": "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// Split into multiple files to compile in paralell\n\n#include \"selective_scan_bwd_kernel.cuh\"\n\ntemplate void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);"
  },
  {
    "path": "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// Split into multiple files to compile in paralell\n\n#include \"selective_scan_bwd_kernel.cuh\"\n\ntemplate void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);"
  },
  {
    "path": "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// Split into multiple files to compile in paralell\n\n#include \"selective_scan_bwd_kernel.cuh\"\n\ntemplate void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);"
  },
  {
    "path": "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// Split into multiple files to compile in paralell\n\n#include \"selective_scan_bwd_kernel.cuh\"\n\ntemplate void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);"
  },
  {
    "path": "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// Split into multiple files to compile in paralell\n\n#include \"selective_scan_bwd_kernel.cuh\"\n\ntemplate void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);"
  },
  {
    "path": "csrc/selective_scan/selective_scan_bwd_kernel.cuh",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <c10/util/BFloat16.h>\n#include <c10/util/Half.h>\n#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK\n#include <ATen/cuda/Atomic.cuh>  // For atomicAdd on complex\n\n#ifndef USE_ROCM\n    #include <cub/block/block_load.cuh>\n    #include <cub/block/block_store.cuh>\n    #include <cub/block/block_scan.cuh>\n    #include <cub/block/block_reduce.cuh>\n#else\n    #include <hipcub/hipcub.hpp>\n    namespace cub = hipcub;\n#endif\n\n#include \"selective_scan.h\"\n#include \"selective_scan_common.h\"\n#include \"reverse_scan.cuh\"\n#include \"static_switch.h\"\n\ntemplate<typename scalar_t> __device__ __forceinline__ scalar_t conj(scalar_t x);\ntemplate<> __device__ __forceinline__ float conj<float>(float x) { return x; }\ntemplate<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); }\n\ntemplate<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,\n         bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>\nstruct Selective_Scan_bwd_kernel_traits {\n    static_assert(kNItems_ % 4 == 0);\n    using input_t = input_t_;\n    using weight_t = weight_t_;\n    static constexpr int kNThreads = kNThreads_;\n    static constexpr int kNItems = kNItems_;\n    static constexpr int kNBytes = sizeof(input_t);\n    static_assert(kNBytes == 2 || kNBytes == 4);\n    static constexpr int kNElts = kNBytes == 4 ? 
4 : constexpr_min(8, kNItems);\n    static_assert(kNItems % kNElts == 0);\n    static constexpr int kNLoads = kNItems / kNElts;\n    static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;\n    static constexpr bool kIsEvenLen = kIsEvenLen_;\n    static constexpr bool kIsVariableB = kIsVariableB_;\n    static constexpr bool kIsVariableC = kIsVariableC_;\n    static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;\n    static constexpr bool kHasZ = kHasZ_;\n    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.\n    // For complex this would lead to massive register spilling, so we keep it at 2.\n    static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2;\n    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;\n    using scan_t = std::conditional_t<!kIsComplex, float2, float4>;\n    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;\n    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;\n    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;\n    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? 
kNLoads : kNLoads * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;\n    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;\n    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;\n    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;\n    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;\n    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;\n    using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;\n    using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;\n    using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;\n    using BlockReduceComplexT = cub::BlockReduce<complex_t, kNThreads>;\n    using BlockExchangeT = cub::BlockExchange<float, kNThreads, !kIsComplex ? kNItems : kNItems * 2>;\n\n    static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),\n                                                    sizeof(typename BlockLoadVecT::TempStorage),\n                                                    (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),\n                                                    (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),\n                                                    sizeof(typename BlockStoreT::TempStorage),\n                                                    sizeof(typename BlockStoreVecT::TempStorage)});\n    static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage);\n    static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);\n    static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename 
BlockReverseScanT::TempStorage);\n};\n\ntemplate<typename Ktraits>\n__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)\nvoid selective_scan_bwd_kernel(SSMParamsBwd params) {\n    constexpr bool kIsComplex = Ktraits::kIsComplex;\n    constexpr bool kIsVariableB = Ktraits::kIsVariableB;\n    constexpr bool kIsVariableC = Ktraits::kIsVariableC;\n    constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;\n    constexpr bool kHasZ = Ktraits::kHasZ;\n    constexpr int kNThreads = Ktraits::kNThreads;\n    constexpr int kNItems = Ktraits::kNItems;\n    using input_t = typename Ktraits::input_t;\n    using weight_t = typename Ktraits::weight_t;\n    using scan_t = typename Ktraits::scan_t;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n    // cast to lvalue reference of expected type\n    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);\n    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));\n    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);\n    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);\n    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);\n    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));\n    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);\n    auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);\n    auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));\n    auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char 
*>(&smem_exchange) + Ktraits::kSmemExchangeSize);\n    auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);\n    auto& smem_reduce_complex = *reinterpret_cast<typename Ktraits::BlockReduceComplexT::TempStorage*>(&smem_reduce);\n    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);\n    auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));\n    weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);\n    scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * MAX_DSTATE + kNThreads);\n    weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + MAX_DSTATE);\n    weight_t *smem_dbc = reinterpret_cast<weight_t *>(smem_da + MAX_DSTATE);\n\n    const int batch_id = blockIdx.x;\n    const int dim_id = blockIdx.y;\n    const int group_id = dim_id / (params.dim_ngroups_ratio);\n    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride\n        + dim_id * params.u_d_stride;\n    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride\n        + dim_id * params.delta_d_stride;\n    input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride\n        + dim_id * params.dout_d_stride;\n    weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;\n    weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * params.B_d_stride;\n    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;\n    weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * params.C_d_stride;\n    input_t *Cvar = 
reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;\n    weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;\n    weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)\n        + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride);\n    weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)\n        + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride);\n    float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;\n    float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];\n    float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;\n    float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];\n    scan_t *x = params.x_ptr == nullptr\n        ? nullptr\n        : reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;\n    float dD_val = 0;\n    float ddelta_bias_val = 0;\n\n    constexpr int kChunkSize = kNThreads * kNItems;\n    u += (params.n_chunks - 1) * kChunkSize;\n    delta += (params.n_chunks - 1) * kChunkSize;\n    dout += (params.n_chunks - 1) * kChunkSize;\n    Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);\n    Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 
1 : 2);\n    for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {\n        input_t u_vals[kNItems];\n        input_t delta_vals_load[kNItems];\n        input_t dout_vals_load[kNItems];\n        __syncthreads();\n        load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);\n        u -= kChunkSize;\n        __syncthreads();\n        load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);\n        // Will reload delta at the same location if kDeltaSoftplus\n        if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }\n        __syncthreads();\n        load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);\n        dout -= kChunkSize;\n\n        float dout_vals[kNItems], delta_vals[kNItems];\n        #pragma unroll\n        for (int i = 0; i < kNItems; ++i) {\n            dout_vals[i] = float(dout_vals_load[i]);\n            delta_vals[i] = float(delta_vals_load[i]) + delta_bias;\n            if constexpr (kDeltaSoftplus) {\n                delta_vals[i] = delta_vals[i] <= 20.f ? 
log1pf(expf(delta_vals[i])) : delta_vals[i];\n            }\n        }\n\n        if constexpr (kHasZ) {\n            input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride\n                + dim_id * params.z_d_stride + chunk * kChunkSize;\n            input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride\n                + dim_id * params.out_d_stride + chunk * kChunkSize;\n            input_t *dz = reinterpret_cast<input_t *>(params.dz_ptr) + batch_id * params.dz_batch_stride\n                + dim_id * params.dz_d_stride + chunk * kChunkSize;\n            input_t z_vals[kNItems], out_vals[kNItems];\n            __syncthreads();\n            load_input<Ktraits>(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize);\n            __syncthreads();\n            load_input<Ktraits>(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize);\n            float dz_vals[kNItems], z_silu_vals[kNItems];\n            #pragma unroll\n            for (int i = 0; i < kNItems; ++i) {\n                float z_val = z_vals[i];\n                float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val));\n                z_silu_vals[i] = z_val * z_sigmoid_val;\n                dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val\n                             * (1.0f + z_val * (1.0f - z_sigmoid_val));\n                dout_vals[i] *= z_silu_vals[i];\n            }\n            __syncthreads();\n            store_output<Ktraits>(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize);\n            if (params.out_z_ptr != nullptr) {  // Recompute and store out_z\n                float out_z_vals[kNItems];\n                #pragma unroll\n                for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; }\n                // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {\n                    // printf(\"out_val=%f, z_silu_val = %f, out_z_val 
= %f\\n\", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]);\n                // }\n                input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride\n                    + dim_id * params.out_z_d_stride + chunk * kChunkSize;\n                __syncthreads();\n                store_output<Ktraits>(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize);\n            }\n        }\n\n        float du_vals[kNItems];\n        #pragma unroll\n        for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }\n        #pragma unroll\n        for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }\n\n        float ddelta_vals[kNItems] = {0};\n        __syncthreads();\n        for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {\n            const weight_t A_val = A[state_idx * params.A_dstate_stride];\n            // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.\n            weight_t A_scaled;\n            constexpr float kLog2e = M_LOG2E;\n            if constexpr (!kIsComplex) {\n                A_scaled = A_val * kLog2e;\n            } else {\n                A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_);\n            }\n            weight_t B_val, C_val;\n            weight_t B_vals[kNItems], C_vals[kNItems];\n            if constexpr (!kIsVariableB) {\n                B_val = B[state_idx * params.B_dstate_stride];\n            } else {\n                load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,\n                    smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));\n            }\n            if constexpr (!kIsVariableC) {\n                C_val = C[state_idx * params.C_dstate_stride];\n            } else {\n                auto &smem_load_weight_C = !kIsVariableB ? 
smem_load_weight : smem_load_weight1;\n                load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,\n                    smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));\n            }\n            // const weight_t A_val = smem_a[state_idx];\n            scan_t thread_data[kNItems], thread_reverse_data[kNItems];\n            if constexpr (!kIsComplex) {\n                #pragma unroll\n                for (int i = 0; i < kNItems; ++i) {\n                    const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);\n                    thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);\n                    if (i == 0) {\n                        smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;\n                    } else {\n                        thread_reverse_data[i - 1].x = delta_a_exp;\n                    }\n                    thread_reverse_data[i].y = dout_vals[i] *\n                        (!kIsVariableC\n                         ? (!kIsVariableB ? B_val * C_val : C_val)\n                         : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));\n                }\n                __syncthreads();\n                thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1\n                    ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])\n                    : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];\n                // Initialize running total\n                scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? 
x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);\n                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);\n                typename Ktraits::BlockScanT(smem_scan).InclusiveScan(\n                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op\n                );\n                scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);\n                SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);\n                typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(\n                    thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op\n                );\n                if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }\n                weight_t dA_val = 0, dBC_val = 0;\n                weight_t dB_vals[kNItems], dC_vals[kNItems];\n                #pragma unroll\n                for (int i = 0; i < kNItems; ++i) {\n                    const float dx = thread_reverse_data[i].y;\n                    const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i];\n                    du_vals[i] += ddelta_u * delta_vals[i];\n                    const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);\n                    ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;\n                    dA_val += dx * delta_vals[i] * a;\n                    if constexpr (!kIsVariableB || !kIsVariableC) {\n                        if constexpr (!kIsVariableB) {  // dBC_val is dB_val\n                            dBC_val += dout_vals[i] * (!kIsVariableC ? 
thread_data[i].y : thread_data[i].y * C_vals[i]);\n                        } else {  // dBC_val is dC_val\n                            dBC_val += dout_vals[i] * thread_data[i].y;\n                        }\n                    }\n                    if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }\n                    if constexpr (kIsVariableC) {\n                        dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y);\n                    }\n                }\n                // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower\n                if constexpr (kIsVariableB || kIsVariableC) {\n                    if constexpr (kIsVariableB) {\n                        typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);\n                    }\n                    if constexpr (kIsVariableC) {\n                        auto &smem_exchange_C = !kIsVariableB ? 
smem_exchange : smem_exchange1;\n                        typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);\n                    }\n                    const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;\n                    weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;\n                    weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;\n                    #pragma unroll\n                    for (int i = 0; i < kNItems; ++i) {\n                        if (i * kNThreads < seqlen_remaining) {\n                            if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }\n                            if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }\n                        }\n                    }\n                }\n                if constexpr (!kIsVariableB || !kIsVariableC) {\n                    float2 dA_dBC_val = make_float2(dA_val, dBC_val);\n                    dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);\n                    dA_val = dA_dBC_val.x;\n                    if (threadIdx.x == 0) {\n                        smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];\n                    }\n                } else {\n                    dA_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);\n                }\n                if (threadIdx.x == 0) {\n                    smem_da[state_idx] = chunk == params.n_chunks - 1 ? 
dA_val : dA_val + smem_da[state_idx];\n                }\n            } else {\n                #pragma unroll\n                for (int i = 0; i < kNItems; ++i) {\n                    // Pytorch's implementation of complex exp (which calls thrust) is very slow\n                    complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);\n                    weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);\n                    thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);\n                    if (i == 0) {\n                        smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;\n                    } else {\n                        thread_reverse_data[i - 1].x = delta_a_exp.real_;\n                        thread_reverse_data[i - 1].y = -delta_a_exp.imag_;\n                    }\n                    complex_t dout_BC = 2 * dout_vals[i]\n                        * conj(!kIsVariableC\n                                ? (!kIsVariableB ? B_val * C_val : C_val)\n                                : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));\n                    thread_reverse_data[i].z = dout_BC.real_;\n                    thread_reverse_data[i].w = dout_BC.imag_;\n                }\n                __syncthreads();\n                complex_t delta_a_exp = threadIdx.x == kNThreads - 1\n                    ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])\n                    : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];\n                thread_reverse_data[kNItems - 1].x = delta_a_exp.real_;\n                thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_;\n                // Initialize running total\n                scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? 
x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);\n                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);\n                typename Ktraits::BlockScanT(smem_scan).InclusiveScan(\n                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op\n                );\n                scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);\n                SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);\n                typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(\n                    thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op\n                );\n                if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }\n                weight_t dA_val = 0, dBC_val = 0;\n                weight_t dB_vals[kNItems], dC_vals[kNItems];\n                #pragma unroll\n                for (int i = 0; i < kNItems; ++i) {\n                    complex_t x = complex_t(thread_data[i].z, thread_data[i].w);\n                    complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w);\n                    float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_;\n                    if constexpr (!kIsVariableB || !kIsVariableC) {\n                        if constexpr (!kIsVariableB) {  // dBC_val is dB_val\n                            dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]);\n                        } else {  // dBC_val is dC_val\n                            dBC_val += (2 * dout_vals[i]) * conj(x);\n                        }\n                    }\n                    const complex_t a_conj = conj(x - (!kIsVariableB ? 
delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]));\n                    du_vals[i] += ddelta_u * delta_vals[i];\n                    ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_;\n                    dA_val += delta_vals[i] * dx * a_conj;\n                    if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }\n                    if constexpr (kIsVariableC) {\n                        dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x);\n                    }\n                }\n                // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower\n                if constexpr (kIsVariableB || kIsVariableC) {\n                    float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2];\n                    if constexpr (kIsVariableB) {\n                        #pragma unroll\n                        for (int i = 0; i < kNItems; ++i) {\n                            dB_vals_f[i * 2] = dB_vals[i].real_;\n                            dB_vals_f[i * 2 + 1] = dB_vals[i].imag_;\n                        }\n                        typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);\n                    }\n                    if constexpr (kIsVariableC) {\n                        #pragma unroll\n                        for (int i = 0; i < kNItems; ++i) {\n                            dC_vals_f[i * 2] = dC_vals[i].real_;\n                            dC_vals_f[i * 2 + 1] = dC_vals[i].imag_;\n                        }\n                        auto &smem_exchange_C = !kIsVariableB ? 
smem_exchange : smem_exchange1;\n                        typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);\n                    }\n                    const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x;\n                    float *dB_cur = reinterpret_cast<float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;\n                    float *dC_cur = reinterpret_cast<float *>(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;\n                    #pragma unroll\n                    for (int i = 0; i < kNItems * 2; ++i) {\n                        if (i * kNThreads < seqlen_remaining) {\n                            if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); }\n                            if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); }\n                        }\n                    }\n                }\n                if constexpr (!kIsVariableB || !kIsVariableC) {\n                    float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_);\n                    dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);\n                    dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);\n                    dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);\n                    if (threadIdx.x == 0) {\n                        smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];\n                    }\n                } else {\n                    dA_val = typename Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);\n                }\n                if (threadIdx.x == 0) {\n                    smem_da[state_idx] = chunk == params.n_chunks - 1 ? 
dA_val : dA_val + smem_da[state_idx];\n                }\n            }\n        }\n\n        if constexpr (kDeltaSoftplus) {\n            __syncthreads();\n            input_t delta_vals_load[kNItems];\n            load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);\n            delta -= kChunkSize;\n            #pragma unroll\n            for (int i = 0; i < kNItems; ++i) {\n                float delta_val = float(delta_vals_load[i]) + delta_bias;\n                float delta_val_neg_exp = expf(-delta_val);\n                ddelta_vals[i] = delta_val <= 20.f\n                    ? ddelta_vals[i] / (1.f + delta_val_neg_exp)\n                    : ddelta_vals[i];\n            }\n        }\n        for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }\n\n        input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride\n            + dim_id * params.du_d_stride + chunk * kChunkSize;\n        input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride\n            + dim_id * params.ddelta_d_stride + chunk * kChunkSize;\n        __syncthreads();\n        store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);\n        __syncthreads();\n        store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);\n\n        Bvar -= kChunkSize * (!kIsComplex ? 1 : 2);\n        Cvar -= kChunkSize * (!kIsComplex ? 
1 : 2);\n    }\n    if (params.dD_ptr != nullptr) {\n        dD_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);\n        if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }\n    }\n    if (params.ddelta_bias_ptr != nullptr) {\n        __syncthreads();\n        ddelta_bias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);\n        if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }\n    }\n    for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {\n        gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);\n        weight_t dBC_val;\n        if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; }\n        if constexpr (!kIsVariableB) {\n            gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]),\n                         !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val);\n        }\n        if constexpr (!kIsVariableC) {\n            gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]),\n                        !kIsVariableB ? 
dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val);\n        }\n    }\n}\n\ntemplate<int kNThreads, int kNItems, typename input_t, typename weight_t>\nvoid selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {\n    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {\n        BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {\n            BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {\n                BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {\n                    BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {\n                        using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;\n                        // using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;\n                        // TODO: check this\n                        constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);\n\n                        dim3 grid(params.batch, params.dim);\n                        \n                        auto kernel = &selective_scan_bwd_kernel<Ktraits>;\n\n                        if (kSmemSize >= 48 * 1024) {\n\n                            #ifndef USE_ROCM\n                            C10_CUDA_CHECK(cudaFuncSetAttribute(\n                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));\n                            #else\n                            C10_CUDA_CHECK(cudaFuncSetAttribute(\n                                (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));\n                            std::cerr << \"Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in 
ROCm versions <= 6.1). This might lead to undefined behavior. \\n\" << std::endl;\n                            #endif\n\n                        }\n\n                        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);\n                        C10_CUDA_KERNEL_LAUNCH_CHECK();\n                    });\n                });\n            });\n        });\n    });\n}\n\ntemplate<typename input_t, typename weight_t>\nvoid selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {\n\n    #ifndef USE_ROCM\n        if (params.seqlen <= 128) {\n            selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);\n        } else if (params.seqlen <= 256) {\n            selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);\n        } else if (params.seqlen <= 512) {\n            selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);\n        } else if (params.seqlen <= 1024) {\n            selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);\n        } else {\n            selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);\n        }\n    #else \n        if (params.seqlen <= 256) {\n            selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream);\n        } else if (params.seqlen <= 512) {\n            selective_scan_bwd_launch<64, 8, input_t, weight_t>(params, stream);\n        } else if (params.seqlen <= 1024) {\n            selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);\n        } else {\n            selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);\n        }\n    #endif\n}"
  },
  {
    "path": "csrc/selective_scan/selective_scan_common.h",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#ifndef USE_ROCM\n    #include <cuda_bf16.h>\n#else\n    #include <hip/hip_bf16.h>\n#endif\n#include <cuda_fp16.h>\n#include <c10/util/complex.h>  // For scalar_value_type\n\n\n#ifndef USE_ROCM\n\n    constexpr size_t custom_max(std::initializer_list<size_t> ilist) \n    {\n        return std::max(ilist);\n    }\n\n    template<typename T>\n    constexpr T constexpr_min(T a, T b) {\n        return std::min(a, b);\n    }\n\n#else\n    constexpr size_t custom_max(std::initializer_list<size_t> ilist) \n    {\n        return *std::max_element(ilist.begin(), ilist.end());\n    }\n\n    template<typename T>\n    constexpr T constexpr_min(T a, T b) {\n        return a < b ? a : b;\n    }\n#endif\n\n\n#define MAX_DSTATE 256\n\nusing complex_t = c10::complex<float>;\n\ninline __device__ float2 operator+(const float2 & a, const float2 & b){\n    return {a.x + b.x, a.y + b.y};\n}\n\ninline __device__ float3 operator+(const float3 &a, const float3 &b) {\n  return {a.x + b.x, a.y + b.y, a.z + b.z};\n}\n\ninline __device__ float4 operator+(const float4 & a, const float4 & b){\n    return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int BYTES> struct BytesToType {};\n\ntemplate<> struct BytesToType<16> {\n    using Type = uint4;\n    static_assert(sizeof(Type) == 16);\n};\n\ntemplate<> struct BytesToType<8> {\n    using Type = uint64_t;\n    static_assert(sizeof(Type) == 8);\n};\n\ntemplate<> struct BytesToType<4> {\n    using Type = uint32_t;\n    static_assert(sizeof(Type) == 4);\n};\n\ntemplate<> struct BytesToType<2> {\n    using Type = uint16_t;\n    static_assert(sizeof(Type) == 2);\n};\n\ntemplate<> struct BytesToType<1> {\n    
using Type = uint8_t;\n    static_assert(sizeof(Type) == 1);\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename scalar_t, int N>\nstruct Converter{\n    static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {\n        #pragma unroll\n        for (int i = 0; i < N; ++i) { dst[i] = src[i]; }\n    }\n};\n\ntemplate<int N>\nstruct Converter<at::Half, N>{\n    static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {\n        static_assert(N % 2 == 0);\n        auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);\n        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);\n        #pragma unroll\n        for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }\n    }\n};\n\n#if __CUDA_ARCH__ >= 800\ntemplate<int N>\nstruct Converter<at::BFloat16, N>{\n    static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {\n        static_assert(N % 2 == 0);\n        auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);\n        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);\n        #pragma unroll\n        for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }\n    }\n};\n#endif\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp\n// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696\n__device__ __forceinline__ complex_t cexp2f(complex_t z) {\n    float t = exp2f(z.real_);\n    float c, s;\n    sincosf(z.imag_, &s, &c);\n    return complex_t(c * t, s * t);\n}\n\n__device__ __forceinline__ complex_t cexpf(complex_t z) {\n    float t = expf(z.real_);\n    float c, s;\n    sincosf(z.imag_, &s, &c);\n    return complex_t(c * t, s * t);\n}\n\ntemplate<typename scalar_t> struct 
SSMScanOp;\n\ntemplate<>\nstruct SSMScanOp<float> {\n    __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {\n        return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);\n    }\n};\n\ntemplate<>\nstruct SSMScanOp<complex_t> {\n    __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {\n        complex_t a0 = complex_t(ab0.x, ab0.y);\n        complex_t b0 = complex_t(ab0.z, ab0.w);\n        complex_t a1 = complex_t(ab1.x, ab1.y);\n        complex_t b1 = complex_t(ab1.z, ab1.w);\n        complex_t out_a = a1 * a0;\n        complex_t out_b = a1 * b0 + b1;\n        return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);\n    }\n};\n\n// A stateful callback functor that maintains a running prefix to be applied\n// during consecutive scan operations.\ntemplate <typename scalar_t> struct SSMScanPrefixCallbackOp {\n    using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;\n    scan_t running_prefix;\n    // Constructor\n    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}\n    // Callback operator to be entered by the first warp of threads in the block.\n    // Thread-0 is responsible for returning a value for seeding the block-wide scan.\n    __device__ scan_t operator()(scan_t block_aggregate) {\n        scan_t old_prefix = running_prefix;\n        running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);\n        return old_prefix;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Ktraits>\ninline __device__ void load_input(typename Ktraits::input_t *u,\n                                  typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],\n                                  typename Ktraits::BlockLoadT::TempStorage &smem_load,\n                                  int seqlen) {\n    if constexpr 
(Ktraits::kIsEvenLen) {\n        auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);\n        using vec_t = typename Ktraits::vec_t;\n        typename Ktraits::BlockLoadVecT(smem_load_vec).Load(\n            reinterpret_cast<vec_t*>(u),\n            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)\n            #ifdef USE_ROCM\n                , Ktraits::kNThreads * Ktraits::kNLoads\n            #endif\n            \n       );\n    } else {\n        typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);\n    }\n}\n\ntemplate<typename Ktraits>\ninline __device__ void load_weight(typename Ktraits::input_t *Bvar,\n                                   typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],\n                                   typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,\n                                   int seqlen) {\n    constexpr int kNItems = Ktraits::kNItems;\n    if constexpr (!Ktraits::kIsComplex) {\n        typename Ktraits::input_t B_vals_load[kNItems];\n        if constexpr (Ktraits::kIsEvenLen) {\n            auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);\n            using vec_t = typename Ktraits::vec_t;\n            typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(\n                reinterpret_cast<vec_t*>(Bvar),\n                reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)\n          );\n        } else {\n            typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);\n        }\n        // #pragma unroll\n        // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }\n        Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);\n    } else {\n        typename Ktraits::input_t B_vals_load[kNItems * 2];\n        if constexpr (Ktraits::kIsEvenLen) {\n            auto& 
smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);\n            using vec_t = typename Ktraits::vec_t;\n            typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(\n                reinterpret_cast<vec_t*>(Bvar),\n                reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)\n          );\n        } else {\n            typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);\n        }\n        #pragma unroll\n        for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }\n    }\n}\n\ntemplate<typename Ktraits>\ninline __device__ void store_output(typename Ktraits::input_t *out,\n                                    const float (&out_vals)[Ktraits::kNItems],\n                                    typename Ktraits::BlockStoreT::TempStorage &smem_store,\n                                    int seqlen) {\n    typename Ktraits::input_t write_vals[Ktraits::kNItems];\n    #pragma unroll\n    for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }\n    if constexpr (Ktraits::kIsEvenLen) {\n        auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);\n        using vec_t = typename Ktraits::vec_t;\n        typename Ktraits::BlockStoreVecT(smem_store_vec).Store(\n            reinterpret_cast<vec_t*>(out),\n            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)\n       );\n    } else {\n        typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);\n    }\n}\n"
  },
  {
    "path": "csrc/selective_scan/selective_scan_fwd_bf16.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// Split into multiple files to compile in parallel\n\n#include \"selective_scan_fwd_kernel.cuh\"\n\ntemplate void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);\ntemplate void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);"
  },
  {
    "path": "csrc/selective_scan/selective_scan_fwd_fp16.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// Split into multiple files to compile in parallel\n\n#include \"selective_scan_fwd_kernel.cuh\"\n\ntemplate void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);\ntemplate void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);"
  },
  {
    "path": "csrc/selective_scan/selective_scan_fwd_fp32.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n// Split into multiple files to compile in parallel\n\n#include \"selective_scan_fwd_kernel.cuh\"\n\ntemplate void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);\ntemplate void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);"
  },
  {
    "path": "csrc/selective_scan/selective_scan_fwd_kernel.cuh",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <c10/util/BFloat16.h>\n#include <c10/util/Half.h>\n#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK\n\n#ifndef USE_ROCM\n    #include <cub/block/block_load.cuh>\n    #include <cub/block/block_store.cuh>\n    #include <cub/block/block_scan.cuh>\n#else\n    #include <hipcub/hipcub.hpp>\n    namespace cub = hipcub;\n#endif\n\n#include \"selective_scan.h\"\n#include \"selective_scan_common.h\"\n#include \"static_switch.h\"\n\ntemplate<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,\n         bool kIsVariableB_, bool kIsVariableC_,\n         bool kHasZ_, typename input_t_, typename weight_t_>\nstruct Selective_Scan_fwd_kernel_traits {\n    static_assert(kNItems_ % 4 == 0);\n    using input_t = input_t_;\n    using weight_t = weight_t_;\n    static constexpr int kNThreads = kNThreads_;\n    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.\n    static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;\n    static constexpr int kNItems = kNItems_;\n    static constexpr int kNRows = kNRows_;\n    static constexpr int kNBytes = sizeof(input_t);\n    static_assert(kNBytes == 2 || kNBytes == 4);\n    static constexpr int kNElts = kNBytes == 4 ? 
4 : constexpr_min(8, kNItems);\n    static_assert(kNItems % kNElts == 0);\n    static constexpr int kNLoads = kNItems / kNElts;\n    static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;\n    static constexpr bool kIsEvenLen = kIsEvenLen_;\n    static constexpr bool kIsVariableB = kIsVariableB_;\n    static constexpr bool kIsVariableC = kIsVariableC_;\n    static constexpr bool kHasZ = kHasZ_;\n\n    static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;\n\n    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;\n    using scan_t = std::conditional_t<!kIsComplex, float2, float4>;\n    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;\n    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,\n        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;\n    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;\n    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2,\n        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE  : cub::BLOCK_LOAD_DIRECT>;\n    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;\n    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,\n        !kDirectIO ? 
cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;\n    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;\n    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;\n    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;\n    static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),\n                                                 sizeof(typename BlockLoadVecT::TempStorage),\n                                                 (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),\n                                                 (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),\n                                                 sizeof(typename BlockStoreT::TempStorage),\n                                                 sizeof(typename BlockStoreVecT::TempStorage)});\n    static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);\n};\n\ntemplate<typename Ktraits>\n__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)\nvoid selective_scan_fwd_kernel(SSMParamsBase params) {\n    constexpr bool kIsComplex = Ktraits::kIsComplex;\n    constexpr bool kIsVariableB = Ktraits::kIsVariableB;\n    constexpr bool kIsVariableC = Ktraits::kIsVariableC;\n    constexpr bool kHasZ = Ktraits::kHasZ;\n    constexpr int kNThreads = Ktraits::kNThreads;\n    constexpr int kNItems = Ktraits::kNItems;\n    constexpr int kNRows = Ktraits::kNRows;\n    constexpr bool kDirectIO = Ktraits::kDirectIO;\n    using input_t = typename Ktraits::input_t;\n    using weight_t = typename Ktraits::weight_t;\n    using scan_t = typename Ktraits::scan_t;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n    // cast to lvalue reference of expected type\n    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);\n    // 
auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));\n    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);\n    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);\n    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);\n    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));\n    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);\n    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);\n    // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);\n    // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);\n    scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);\n\n    const int batch_id = blockIdx.x;\n    const int dim_id = blockIdx.y;\n    const int group_id = dim_id / (params.dim_ngroups_ratio);\n    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride\n        + dim_id * kNRows * params.u_d_stride;\n    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride\n        + dim_id * kNRows * params.delta_d_stride;\n    weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;\n    weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;\n    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;\n    weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;\n    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * 
params.C_batch_stride + group_id * params.C_group_stride;\n    scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;\n\n    float D_val[kNRows] = {0};\n    if (params.D_ptr != nullptr) {\n        #pragma unroll\n        for (int r = 0; r < kNRows; ++r) {\n            D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];\n        }\n    }\n    float delta_bias[kNRows] = {0};\n    if (params.delta_bias_ptr != nullptr) {\n        #pragma unroll\n        for (int r = 0; r < kNRows; ++r) {\n            delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];\n        }\n    }\n\n    // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {\n    //     smem_a[state_idx] = A[state_idx * params.A_dstate_stride];\n    //     smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];\n    // }\n\n    constexpr int kChunkSize = kNThreads * kNItems;\n    for (int chunk = 0; chunk < params.n_chunks; ++chunk) {\n        input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];\n        __syncthreads();\n        #pragma unroll\n        for (int r = 0; r < kNRows; ++r) {\n            if constexpr (!kDirectIO) {\n                if (r > 0) { __syncthreads(); }\n            }\n            load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);\n            if constexpr (!kDirectIO) { __syncthreads(); }\n            load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);\n        }\n        u += kChunkSize;\n        delta += kChunkSize;\n    \n        float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];\n        #pragma unroll\n        for (int r = 0; r < kNRows; ++r) {\n            #pragma unroll\n            for (int i = 
0; i < kNItems; ++i) {\n                float u_val = float(u_vals[r][i]);\n                delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];\n                if (params.delta_softplus) {\n                    delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];\n                }\n                delta_u_vals[r][i] = delta_vals[r][i] * u_val;\n                out_vals[r][i] = D_val[r] * u_val;\n            }\n        }\n\n        __syncthreads();\n        for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {\n            weight_t A_val[kNRows];\n            #pragma unroll\n            for (int r = 0; r < kNRows; ++r) {\n                A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];\n                // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.\n                constexpr float kLog2e = M_LOG2E;\n                if constexpr (!kIsComplex) {\n                    A_val[r] *= kLog2e;\n                } else {\n                    A_val[r].real_ *= kLog2e;\n                }\n            }\n            // This variable holds B * C if both B and C are constant across seqlen. If only B varies\n            // across seqlen, this holds C. If only C varies across seqlen, this holds B.\n            // If both B and C vary, this is unused.\n            weight_t BC_val[kNRows];\n            weight_t B_vals[kNItems], C_vals[kNItems];\n            if constexpr (kIsVariableB) {\n                load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,\n                    smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 
1 : 2));\n                if constexpr (!kIsVariableC) {\n                    #pragma unroll\n                    for (int r = 0; r < kNRows; ++r) {\n                        BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];\n                    }\n                }\n            }\n            if constexpr (kIsVariableC) {\n                auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;\n                load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,\n                    smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));\n                if constexpr (!kIsVariableB) {\n                    #pragma unroll\n                    for (int r = 0; r < kNRows; ++r) {\n                        BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];\n                    }\n                }\n            }\n            if constexpr (!kIsVariableB && !kIsVariableC) {\n                #pragma unroll\n                for (int r = 0; r < kNRows; ++r) {\n                    BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];\n                }\n            }\n\n            #pragma unroll\n            for (int r = 0; r < kNRows; ++r) {\n                if (r > 0) { __syncthreads(); }  // Scan could be using the same smem\n                scan_t thread_data[kNItems];\n                #pragma unroll\n                for (int i = 0; i < kNItems; ++i) {\n                    if constexpr (!kIsComplex) {\n                        thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),\n                                                     !kIsVariableB ? 
delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);\n                        if constexpr (!Ktraits::kIsEvenLen) {  // So that the last state is correct\n                            if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {\n                                thread_data[i] = make_float2(1.f, 0.f);\n                            }\n                        }\n                    } else {\n                        // Pytorch's implementation of complex exp (which calls thrust) is very slow\n                        complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);\n                        weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];\n                        thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);\n                        if constexpr (!Ktraits::kIsEvenLen) {  // So that the last state is correct\n                            if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {\n                                thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);\n                            }\n                        }\n                    }\n                }\n                // Initialize running total\n                scan_t running_prefix;\n                if constexpr (!kIsComplex) {\n                    // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read\n                    running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);\n                    // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);\n                } else {\n                    running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? 
smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);\n                    // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);\n                }\n                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);\n                typename Ktraits::BlockScanT(smem_scan).InclusiveScan(\n                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op\n                );\n                // There's a syncthreads in the scan op, so we don't need to sync here.\n                // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.\n                if (threadIdx.x == 0) {\n                    smem_running_prefix[state_idx] = prefix_op.running_prefix;\n                    x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;\n                }\n                #pragma unroll\n                for (int i = 0; i < kNItems; ++i) {\n                    const weight_t C_val = !kIsVariableC\n                        ? BC_val[r]\n                        : (!kIsVariableB ? 
BC_val[r] * C_vals[i] : C_vals[i]);\n                    if constexpr (!kIsComplex) {\n                        out_vals[r][i] += thread_data[i].y * C_val;\n                    } else {\n                        out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2;\n                    }\n                }\n            }\n        }\n        \n        input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride\n            + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;\n        __syncthreads();\n        #pragma unroll\n        for (int r = 0; r < kNRows; ++r) {\n            if constexpr (!kDirectIO) {\n                if (r > 0) { __syncthreads(); }\n            }\n            store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);\n        }\n\n        if constexpr (kHasZ) {\n            input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride\n                + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;\n            input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride\n                + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;\n            #pragma unroll\n            for (int r = 0; r < kNRows; ++r) {\n                input_t z_vals[kNItems];\n                __syncthreads();\n                load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);\n                #pragma unroll\n                for (int i = 0; i < kNItems; ++i) {\n                    float z_val = z_vals[i];\n                    out_vals[r][i] *= z_val / (1 + expf(-z_val));\n                }\n                __syncthreads();\n                store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);\n            }\n        }\n\n        
Bvar += kChunkSize * (!kIsComplex ? 1 : 2);\n        Cvar += kChunkSize * (!kIsComplex ? 1 : 2);\n    }\n}\n\ntemplate<int kNThreads, int kNItems, typename input_t, typename weight_t>\nvoid selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {\n    // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block\n    // processing 1 row.\n    constexpr int kNRows = 1;\n    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {\n        BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {\n            BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {\n                BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {\n                    using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;\n                    \n                    constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);\n                    dim3 grid(params.batch, params.dim / kNRows);\n\n                    // Had to change this substantially since potentially the hip \n                    // interface for setting kernel launch attributes is slightly different from \n                    // cuda's. 
In particular, it seems to expect a plain const void * pointer.\n\n                    auto kernel = &selective_scan_fwd_kernel<Ktraits>;\n\n                    \n                    if (kSmemSize >= 48 * 1024) {\n                        #ifndef USE_ROCM\n                        C10_CUDA_CHECK(cudaFuncSetAttribute(\n                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));\n                        #else\n                        C10_CUDA_CHECK(cudaFuncSetAttribute(\n                            (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));\n                            std::cerr << \"Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a no-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \\n\" << std::endl;\n                        #endif\n                    }\n\n                    kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);\n                    C10_CUDA_KERNEL_LAUNCH_CHECK();\n                });\n            });\n        });\n    });\n}\n\ntemplate<typename input_t, typename weight_t>\nvoid selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {\n\n    #ifndef USE_ROCM\n        if (params.seqlen <= 128) {           \n            selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);\n        } else if (params.seqlen <= 256) {\n            selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);\n        } else if (params.seqlen <= 512) {\n            selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);\n        } else if (params.seqlen <= 1024) {\n            selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);\n        } else {\n            selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);\n        }\n    #else\n        if (params.seqlen <= 256) {\n            selective_scan_fwd_launch<64, 4, input_t, 
weight_t>(params, stream);\n        } else if (params.seqlen <= 512) {\n            selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);\n        } else if (params.seqlen <= 1024) {\n            selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);\n        } else {\n            selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);\n        }\n    #endif\n}\n"
  },
  {
    "path": "csrc/selective_scan/static_switch.h",
    "content": "// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h\n// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h\n\n#pragma once\n\n/// @param COND       - a boolean expression to switch by\n/// @param CONST_NAME - a name given for the constexpr bool variable.\n/// @param ...       - code to execute for true and false\n///\n/// Usage:\n/// ```\n/// BOOL_SWITCH(flag, BoolConst, [&] {\n///     some_function<BoolConst>(...);\n/// });\n/// ```\n#define BOOL_SWITCH(COND, CONST_NAME, ...)                                           \\\n    [&] {                                                                            \\\n        if (COND) {                                                                  \\\n            constexpr bool CONST_NAME = true;                                        \\\n            return __VA_ARGS__();                                                    \\\n        } else {                                                                     \\\n            constexpr bool CONST_NAME = false;                                       \\\n            return __VA_ARGS__();                                                    \\\n        }                                                                            \\\n    }()\n"
  },
  {
    "path": "csrc/selective_scan/uninitialized_copy.cuh",
    "content": "/******************************************************************************\n * Copyright (c) 2011-2022, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright\n *       notice, this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n *       notice, this list of conditions and the following disclaimer in the\n *       documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the\n *       names of its contributors may be used to endorse or promote products\n *       derived from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n * ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n ******************************************************************************/\n\n#pragma once\n\n#ifndef USE_ROCM\n    #include <cub/config.cuh>\n\n    #include <cuda/std/type_traits>\n#else\n    #include <hipcub/hipcub.hpp>\n    // Map ::cuda::std to the standard std namespace\n    namespace cuda {\n        namespace std = ::std;\n    }\n#endif\n\n\nnamespace detail\n{\n\n#if defined(_NVHPC_CUDA)\ntemplate <typename T, typename U>\n__host__ __device__ void uninitialized_copy(T *ptr, U &&val)\n{\n  // NVBug 3384810\n  new (ptr) T(::cuda::std::forward<U>(val));\n}\n#else\ntemplate <typename T,\n          typename U,\n          typename ::cuda::std::enable_if<\n            ::cuda::std::is_trivially_copyable<T>::value,\n            int\n          >::type = 0>\n__host__ __device__ void uninitialized_copy(T *ptr, U &&val)\n{\n  *ptr = ::cuda::std::forward<U>(val);\n}\n\ntemplate <typename T,\n         typename U,\n         typename ::cuda::std::enable_if<\n           !::cuda::std::is_trivially_copyable<T>::value,\n           int\n         >::type = 0>\n__host__ __device__ void uninitialized_copy(T *ptr, U &&val)\n{\n  new (ptr) T(::cuda::std::forward<U>(val));\n}\n#endif\n\n} // namespace detail\n"
  },
  {
    "path": "evals/lm_harness_eval.py",
    "content": "import torch\n\nimport transformers\nfrom transformers import AutoTokenizer\n\nfrom mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel\n\nfrom lm_eval.api.model import LM\nfrom lm_eval.models.huggingface import HFLM\nfrom lm_eval.api.registry import register_model\nfrom lm_eval.__main__ import cli_evaluate\n\n\n@register_model(\"mamba\")\nclass MambaEvalWrapper(HFLM):\n\n    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM\n\n    def __init__(self, pretrained=\"state-spaces/mamba-2.8b\", max_length=2048, batch_size=None, device=\"cuda\",\n                 dtype=torch.float16):\n        LM.__init__(self)\n        self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype)\n        self.tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neox-20b\")\n        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id\n        self.vocab_size = self.tokenizer.vocab_size\n        self._batch_size = int(batch_size) if batch_size is not None else 64\n        self._max_length = max_length\n        self._device = torch.device(device)\n\n    @property\n    def batch_size(self):\n        return self._batch_size\n\n    def _model_generate(self, context, max_length, stop, **generation_kwargs):\n        raise NotImplementedError()\n\n\nif __name__ == \"__main__\":\n    cli_evaluate()\n"
  },
  {
    "path": "mamba_ssm/__init__.py",
    "content": "__version__ = \"2.3.1\"\n\nfrom mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn\nfrom mamba_ssm.modules.mamba_simple import Mamba\nfrom mamba_ssm.modules.mamba2 import Mamba2\nfrom mamba_ssm.modules.mamba3 import Mamba3\nfrom mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel\n"
  },
  {
    "path": "mamba_ssm/distributed/__init__.py",
    "content": ""
  },
  {
    "path": "mamba_ssm/distributed/distributed_utils.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.distributed import ProcessGroup\n\n# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for\n# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent\n# version of PyTorch. The following 4 lines are for backward compatibility with\n# older PyTorch.\nif \"all_gather_into_tensor\" not in dir(torch.distributed):\n    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base\nif \"reduce_scatter_tensor\" not in dir(torch.distributed):\n    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base\n\n\n# Raw operation, does not support autograd, but does support async\ndef all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):\n    world_size = torch.distributed.get_world_size(process_group)\n    output = torch.empty(\n        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device\n    )\n    handle = torch.distributed.all_gather_into_tensor(\n        output, input_.contiguous(), group=process_group, async_op=async_op\n    )\n    return output, handle\n\n\n# Raw operation, does not support autograd, but does support async\ndef reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):\n    world_size = torch.distributed.get_world_size(process_group)\n    assert input_.shape[0] % world_size == 0\n    output = torch.empty(\n        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device\n    )\n    handle = torch.distributed.reduce_scatter_tensor(\n        output, input_.contiguous(), group=process_group, async_op=async_op\n    )\n    return output, handle\n\n\n# Raw operation, does not support autograd, but does support async\ndef all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):\n    input_ = input_.contiguous()\n    handle = 
torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)\n    return input_, handle\n\n\nclass AllGatherFunc(torch.autograd.Function):\n    \"\"\"Gather the input from sequence parallel region and concatenate.\"\"\"\n\n    @staticmethod\n    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:\n        ctx.process_group = process_group\n        output, _ = all_gather_raw(input_, process_group)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output: Tensor):\n        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)\n        return grad_input, None\n\n\n# Supports autograd, but does not support async\nall_gather = AllGatherFunc.apply\n\n\nclass ReduceScatterFunc(torch.autograd.Function):\n    \"\"\"Reduce scatter the input from the sequence parallel region and concatenate.\"\"\"\n\n    @staticmethod\n    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:\n        ctx.process_group = process_group\n        output, _ = reduce_scatter_raw(input_, process_group)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output: Tensor):\n        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)\n        return grad_input, None\n\n\n# Supports autograd, but does not support async\nreduce_scatter = ReduceScatterFunc.apply\n\n\nclass AllReduceFunc(torch.autograd.Function):\n    \"\"\"All-reduce the input across the process group.\"\"\"\n\n    @staticmethod\n    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:\n        ctx.process_group = process_group\n        output, _ = all_reduce_raw(input_, process_group)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output: Tensor):\n        return grad_output, None\n\n\n# Supports autograd, but does not support async\nall_reduce = AllReduceFunc.apply\n\n\ndef sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):\n    # We 
want to iterate over parameters with _shared_params=True in the same order,\n    # as different ranks might have a different number of parameters (e.g., only rank 0 has bias).\n    params_shared = {\n        name: p for name, p in model.named_parameters() if getattr(p, \"_shared_params\", False)\n    }\n    for _, p in sorted(params_shared.items()):\n        with torch.no_grad():\n            # Broadcast needs src to be global rank, not group rank\n            torch.distributed.broadcast(\n                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group\n            )\n\n\n# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256\ndef allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):\n    # We want to iterate over parameters with _sequence_parallel=True in the same order,\n    # as different ranks might have a different number of parameters (e.g., only rank 0 has bias).\n    params_seqparallel = {\n        name: p for name, p in model.named_parameters() if getattr(p, \"_sequence_parallel\", False)\n    }\n    grads = [p.grad for _, p in sorted(params_seqparallel.items())]\n    if grads:\n        with torch.no_grad():\n            coalesced = torch._utils._flatten_dense_tensors(grads)\n            torch.distributed.all_reduce(coalesced, group=process_group)\n            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):\n                buf.copy_(synced)\n\n\ndef get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:\n    \"\"\"Get the dim for the local rank derived from splitting dim on world_size processes.\n\n    The split may not be even across the world_size processes.\n    \"\"\"\n    multiple = dim // multiple_of\n    div = multiple // world_size\n    mod = multiple % world_size\n    local_multiple = div + int(local_rank < mod)\n    return 
local_multiple * multiple_of\n"
  },
  {
    "path": "mamba_ssm/distributed/tensor_parallel.py",
    "content": "# Copyright (c) 2024, Tri Dao.\n# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.distributed import ProcessGroup\nfrom mamba_ssm.utils.torch import custom_bwd, custom_fwd\n\nfrom einops import rearrange\n\nfrom mamba_ssm.distributed.distributed_utils import (\n    all_gather_raw,\n    all_reduce,\n    all_reduce_raw,\n    reduce_scatter,\n    reduce_scatter_raw,\n)\n\n\nclass ParallelLinearFunc(torch.autograd.Function):\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):\n        \"\"\"\n        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel\n        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.\n        \"\"\"\n        ctx.compute_weight_gradient = weight.requires_grad\n        ctx.process_group = process_group\n        ctx.sequence_parallel = sequence_parallel\n\n        if torch.is_autocast_enabled():\n            x = x.to(dtype=torch.get_autocast_gpu_dtype())\n        x = x.contiguous()\n        if process_group is not None and sequence_parallel:\n            # We want to kick off the all_gather early, before weight dtype conversion\n            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)\n        else:\n            total_x = x\n\n        if torch.is_autocast_enabled():\n            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())\n            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None\n        weight = weight.contiguous()\n        if process_group is not None and sequence_parallel:\n            handle_x.wait()\n        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]\n        batch_dim = 
batch_shape.numel()\n        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174\n        output = F.linear(total_x, weight, bias)\n        if ctx.compute_weight_gradient:\n            ctx.save_for_backward(x, weight)\n        else:\n            ctx.save_for_backward(weight)\n        return output\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_output):\n        grad_output = grad_output.contiguous()\n        process_group = ctx.process_group\n        sequence_parallel = ctx.sequence_parallel\n        if ctx.compute_weight_gradient:\n            x, weight = ctx.saved_tensors\n            if process_group is not None and sequence_parallel:\n                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)\n            else:\n                total_x = x\n        else:\n            (weight,) = ctx.saved_tensors\n            total_x = None\n        batch_shape = grad_output.shape[:-1]\n        batch_dim = batch_shape.numel()\n        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])\n        if ctx.needs_input_grad[0]:\n            grad_input = F.linear(grad_output, weight.t())\n            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])\n            if process_group is not None:\n                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw\n                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)\n        else:\n            grad_input = None\n        if ctx.needs_input_grad[1]:\n            assert ctx.compute_weight_gradient\n            if process_group is not None and sequence_parallel:\n                handle_x.wait()\n            grad_weight = torch.einsum(\n                \"bo,bi->oi\", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])\n            )\n        else:\n            grad_weight = None\n        grad_bias = grad_output.sum(dim=0) if 
ctx.needs_input_grad[2] else None\n        if process_group is not None and ctx.needs_input_grad[0]:\n            handle_grad_input.wait()\n        return grad_input, grad_weight, grad_bias, None, None\n\n\ndef parallel_linear_func(\n    x: Tensor,\n    weight: Tensor,\n    bias: Optional[Tensor] = None,\n    process_group: Optional[ProcessGroup] = None,\n    sequence_parallel: bool = True,\n):\n    return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)\n\n\nclass ColumnParallelLinear(nn.Linear):\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int,\n        process_group: ProcessGroup,\n        bias: bool = True,\n        sequence_parallel=True,\n        multiple_of=1,\n        device=None,\n        dtype=None,\n    ) -> None:\n        world_size = torch.distributed.get_world_size(process_group)\n        if out_features % multiple_of:\n            raise ValueError(f\"out_features ({out_features}) must be a multiple of {multiple_of}\")\n        multiple = out_features // multiple_of\n        # We want to split @multiple across world_size, but it could be an uneven split\n        div = multiple // world_size\n        mod = multiple % world_size\n        # The first @mod ranks get @div + 1 copies, the rest get @div copies\n        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)\n        super().__init__(\n            in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype\n        )\n        self.process_group = process_group\n        self.sequence_parallel = sequence_parallel\n\n    def forward(self, x):\n        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:\n        # we do an all_gather of x before doing the matmul.\n        # If not, then the input is already gathered.\n        return parallel_linear_func(\n            x,\n            self.weight,\n            self.bias,\n            
process_group=self.process_group,\n            sequence_parallel=self.sequence_parallel,\n        )\n\n\nclass RowParallelLinear(nn.Linear):\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int,\n        process_group: ProcessGroup,\n        bias: bool = True,\n        sequence_parallel=True,\n        multiple_of=1,\n        device=None,\n        dtype=None,\n    ) -> None:\n        world_size = torch.distributed.get_world_size(process_group)\n        rank = torch.distributed.get_rank(process_group)\n        if in_features % multiple_of:\n            raise ValueError(f\"in_features ({in_features}) must be a multiple of {multiple_of}\")\n        multiple = in_features // multiple_of\n        # We want to split @multiple across world_size, but it could be an uneven split\n        div = multiple // world_size\n        mod = multiple % world_size\n        # The first @mod ranks get @div + 1 copies, the rest get @div copies\n        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)\n        # Only rank 0 will have bias\n        super().__init__(\n            local_multiple * multiple_of,\n            out_features,\n            bias=bias and rank == 0,\n            device=device,\n            dtype=dtype,\n        )\n        self.process_group = process_group\n        self.sequence_parallel = sequence_parallel\n\n    def forward(self, x):\n        \"\"\"\n        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then\n        a reduce_scatter of the result.\n        \"\"\"\n        out = parallel_linear_func(x, self.weight, self.bias)\n        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce\n        return reduce_fn(out, self.process_group)\n\n\nclass VocabParallelEmbedding(nn.Embedding):\n    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):\n        self.process_group = process_group\n        if process_group is not 
None:\n            world_size = torch.distributed.get_world_size(process_group)\n            if num_embeddings % world_size != 0:\n                raise ValueError(\n                    f\"num_embeddings ({num_embeddings}) must be divisible by \"\n                    f\"world_size ({world_size})\"\n                )\n            if world_size > 1 and padding_idx is not None:\n                raise RuntimeError(\"ParallelEmbedding does not support padding_idx\")\n        else:\n            world_size = 1\n        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)\n\n    def forward(self, input: Tensor) -> Tensor:\n        if self.process_group is None:\n            return super().forward(input)\n        else:\n            rank = torch.distributed.get_rank(self.process_group)\n            vocab_size = self.num_embeddings\n            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size\n            # Create a mask of valid vocab ids (1 means it needs to be masked).\n            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)\n            input = input - vocab_start_index\n            input[input_ids_mask] = 0\n            embeddings = super().forward(input)\n            embeddings[input_ids_mask] = 0.0\n            return embeddings\n\n\nclass ColumnParallelEmbedding(nn.Embedding):\n    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):\n        self.process_group = process_group\n        if process_group is not None:\n            world_size = torch.distributed.get_world_size(process_group)\n            if embedding_dim % world_size != 0:\n                raise ValueError(\n                    f\"embedding_dim ({embedding_dim}) must be divisible by \"\n                    f\"world_size ({world_size})\"\n                )\n        else:\n            world_size = 1\n        super().__init__(num_embeddings, embedding_dim // world_size, *args, 
**kwargs)\n\n\nclass ParallelEmbeddings(nn.Module):\n    def __init__(\n        self,\n        embed_dim,\n        vocab_size,\n        max_position_embeddings,\n        process_group,\n        padding_idx=None,\n        sequence_parallel=True,\n        device=None,\n        dtype=None,\n    ):\n        \"\"\"\n        If max_position_embeddings <= 0, there's no position embeddings\n        \"\"\"\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.process_group = process_group\n        self.sequence_parallel = sequence_parallel\n        self.word_embeddings = VocabParallelEmbedding(\n            vocab_size,\n            embed_dim,\n            padding_idx=padding_idx,\n            process_group=process_group,\n            **factory_kwargs,\n        )\n        self.max_position_embeddings = max_position_embeddings\n        if self.max_position_embeddings > 0:\n            self.position_embeddings = ColumnParallelEmbedding(\n                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs\n            )\n\n    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):\n        \"\"\"\n        input_ids: (batch, seqlen)\n        position_ids: (batch, seqlen)\n        \"\"\"\n        batch_size, seqlen = input_ids.shape\n        world_size = torch.distributed.get_world_size(self.process_group)\n        embeddings = self.word_embeddings(input_ids)\n        if self.max_position_embeddings > 0:\n            if position_ids is None:\n                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)\n            position_embeddings = self.position_embeddings(position_ids)\n            if world_size <= 1:\n                embeddings = embeddings + position_embeddings\n            else:\n                partition_dim = self.position_embeddings.embedding_dim\n                rank = torch.distributed.get_rank(self.process_group)\n           
     embeddings[\n                    ..., rank * partition_dim : (rank + 1) * partition_dim\n                ] += position_embeddings\n        if combine_batch_seqlen_dim:\n            embeddings = rearrange(embeddings, \"b s d -> (b s) d\")\n        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce\n        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)\n"
  },
  {
    "path": "mamba_ssm/models/__init__.py",
    "content": ""
  },
  {
    "path": "mamba_ssm/models/config_mamba.py",
    "content": "from dataclasses import dataclass, field\n\n\n@dataclass\nclass MambaConfig:\n\n    d_model: int = 2560\n    d_intermediate: int = 0\n    n_layer: int = 64\n    vocab_size: int = 50277\n    ssm_cfg: dict = field(default_factory=dict)\n    attn_layer_idx: list = field(default_factory=list)\n    attn_cfg: dict = field(default_factory=dict)\n    rms_norm: bool = True\n    residual_in_fp32: bool = True\n    fused_add_norm: bool = True\n    pad_vocab_size_multiple: int = 8\n    tie_embeddings: bool = True\n"
  },
  {
    "path": "mamba_ssm/models/mixer_seq_simple.py",
    "content": "# Copyright (c) 2023, Albert Gu, Tri Dao.\n\nimport math\nfrom functools import partial\nimport json\nimport os\nimport copy\n\nfrom collections import namedtuple\n\nimport torch\nimport torch.nn as nn\n\nfrom mamba_ssm.models.config_mamba import MambaConfig\nfrom mamba_ssm.modules.mamba_simple import Mamba\nfrom mamba_ssm.modules.mamba2 import Mamba2\nfrom mamba_ssm.modules.mha import MHA\nfrom mamba_ssm.modules.mlp import GatedMLP\nfrom mamba_ssm.modules.block import Block\nfrom mamba_ssm.utils.generation import GenerationMixin\nfrom mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf\n\ntry:\n    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn\nexcept ImportError:\n    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None\n\n\ndef create_block(\n    d_model,\n    d_intermediate,\n    ssm_cfg=None,\n    attn_layer_idx=None,\n    attn_cfg=None,\n    norm_epsilon=1e-5,\n    rms_norm=False,\n    residual_in_fp32=False,\n    fused_add_norm=False,\n    layer_idx=None,\n    device=None,\n    dtype=None,\n):\n    if ssm_cfg is None:\n        ssm_cfg = {}\n    if attn_layer_idx is None:\n        attn_layer_idx = []\n    if attn_cfg is None:\n        attn_cfg = {}\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n    if layer_idx not in attn_layer_idx:\n        # Create a copy of the config to modify\n        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}\n        ssm_layer = ssm_cfg.pop(\"layer\", \"Mamba1\")\n        if ssm_layer not in [\"Mamba1\", \"Mamba2\"]:\n            raise ValueError(f\"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2\")\n        mixer_cls = partial(\n            Mamba2 if ssm_layer == \"Mamba2\" else Mamba,\n            layer_idx=layer_idx,\n            **ssm_cfg,\n            **factory_kwargs\n        )\n    else:\n        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)\n    norm_cls = partial(\n        nn.LayerNorm 
if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs\n    )\n    if d_intermediate == 0:\n        mlp_cls = nn.Identity\n    else:\n        mlp_cls = partial(\n            GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs\n        )\n    block = Block(\n        d_model,\n        mixer_cls,\n        mlp_cls,\n        norm_cls=norm_cls,\n        fused_add_norm=fused_add_norm,\n        residual_in_fp32=residual_in_fp32,\n    )\n    block.layer_idx = layer_idx\n    return block\n\n\n# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454\ndef _init_weights(\n    module,\n    n_layer,\n    initializer_range=0.02,  # Now only used for embedding layer.\n    rescale_prenorm_residual=True,\n    n_residuals_per_layer=1,  # Change to 2 if we have MLP\n):\n    if isinstance(module, nn.Linear):\n        if module.bias is not None:\n            if not getattr(module.bias, \"_no_reinit\", False):\n                nn.init.zeros_(module.bias)\n    elif isinstance(module, nn.Embedding):\n        nn.init.normal_(module.weight, std=initializer_range)\n\n    if rescale_prenorm_residual:\n        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:\n        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. 
Scale\n        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.\n        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/\n        #\n        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py\n        for name, p in module.named_parameters():\n            if name in [\"out_proj.weight\", \"fc2.weight\"]:\n                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block\n                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)\n                # We need to reinit p since this code could be called multiple times\n                # Having just p *= scale would repeatedly scale it down\n                nn.init.kaiming_uniform_(p, a=math.sqrt(5))\n                with torch.no_grad():\n                    p /= math.sqrt(n_residuals_per_layer * n_layer)\n\n\nclass MixerModel(nn.Module):\n    def __init__(\n        self,\n        d_model: int,\n        n_layer: int,\n        d_intermediate: int,\n        vocab_size: int,\n        ssm_cfg=None,\n        attn_layer_idx=None,\n        attn_cfg=None,\n        norm_epsilon: float = 1e-5,\n        rms_norm: bool = False,\n        initializer_cfg=None,\n        fused_add_norm=False,\n        residual_in_fp32=False,\n        device=None,\n        dtype=None,\n    ) -> None:\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.residual_in_fp32 = residual_in_fp32\n\n        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)\n\n        # We change the order of residual and layer norm:\n        # Instead of LN -> Attn / MLP -> Add, we do:\n        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and\n        # the main branch (output of MLP / Mixer). 
The model definition is unchanged.\n        # This is for performance reason: we can fuse add + layer_norm.\n        self.fused_add_norm = fused_add_norm\n        if self.fused_add_norm:\n            if layer_norm_fn is None or rms_norm_fn is None:\n                raise ImportError(\"Failed to import Triton LayerNorm / RMSNorm kernels\")\n\n        self.layers = nn.ModuleList(\n            [\n                create_block(\n                    d_model,\n                    d_intermediate=d_intermediate,\n                    ssm_cfg=ssm_cfg,\n                    attn_layer_idx=attn_layer_idx,\n                    attn_cfg=attn_cfg,\n                    norm_epsilon=norm_epsilon,\n                    rms_norm=rms_norm,\n                    residual_in_fp32=residual_in_fp32,\n                    fused_add_norm=fused_add_norm,\n                    layer_idx=i,\n                    **factory_kwargs,\n                )\n                for i in range(n_layer)\n            ]\n        )\n\n        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(\n            d_model, eps=norm_epsilon, **factory_kwargs\n        )\n\n        self.apply(\n            partial(\n                _init_weights,\n                n_layer=n_layer,\n                **(initializer_cfg if initializer_cfg is not None else {}),\n                n_residuals_per_layer=1 if d_intermediate == 0 else 2,  # 2 if we have MLP\n            )\n        )\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n        return {\n            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n            for i, layer in enumerate(self.layers)\n        }\n\n    def forward(self, input_ids, inference_params=None, **mixer_kwargs):\n        hidden_states = self.embedding(input_ids)\n        residual = None\n        for layer in self.layers:\n            hidden_states, residual = layer(\n                hidden_states, residual, 
inference_params=inference_params, **mixer_kwargs\n            )\n        if not self.fused_add_norm:\n            residual = (hidden_states + residual) if residual is not None else hidden_states\n            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))\n        else:\n            # Set prenorm=False here since we don't need the residual\n            hidden_states = layer_norm_fn(\n                hidden_states,\n                self.norm_f.weight,\n                self.norm_f.bias,\n                eps=self.norm_f.eps,\n                residual=residual,\n                prenorm=False,\n                residual_in_fp32=self.residual_in_fp32,\n                is_rms_norm=isinstance(self.norm_f, RMSNorm)\n            )\n        return hidden_states\n\n\nclass MambaLMHeadModel(nn.Module, GenerationMixin):\n\n    def __init__(\n        self,\n        config: MambaConfig,\n        initializer_cfg=None,\n        device=None,\n        dtype=None,\n    ) -> None:\n        self.config = config\n        d_model = config.d_model\n        n_layer = config.n_layer\n        d_intermediate = config.d_intermediate\n        vocab_size = config.vocab_size\n        ssm_cfg = config.ssm_cfg\n        attn_layer_idx = config.attn_layer_idx\n        attn_cfg = config.attn_cfg\n        rms_norm = config.rms_norm\n        residual_in_fp32 = config.residual_in_fp32\n        fused_add_norm = config.fused_add_norm\n        pad_vocab_size_multiple = config.pad_vocab_size_multiple\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n\n        super().__init__()\n        if vocab_size % pad_vocab_size_multiple != 0:\n            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)\n        self.backbone = MixerModel(\n            d_model=d_model,\n            n_layer=n_layer,\n            d_intermediate=d_intermediate,\n            vocab_size=vocab_size,\n            ssm_cfg=ssm_cfg,\n            attn_layer_idx=attn_layer_idx,\n 
           attn_cfg=attn_cfg,\n            rms_norm=rms_norm,\n            initializer_cfg=initializer_cfg,\n            fused_add_norm=fused_add_norm,\n            residual_in_fp32=residual_in_fp32,\n            **factory_kwargs,\n        )\n        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)\n\n        # Initialize weights and apply final processing\n        self.apply(\n            partial(\n                _init_weights,\n                n_layer=n_layer,\n                **(initializer_cfg if initializer_cfg is not None else {}),\n            )\n        )\n        self.tie_weights()\n\n    def tie_weights(self):\n        if self.config.tie_embeddings:\n            self.lm_head.weight = self.backbone.embedding.weight\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n\n    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):\n        \"\"\"\n        \"position_ids\" is just to be compatible with Transformer generation. 
We don't use it.\n        num_last_tokens: if > 0, only return the logits for the last n tokens\n        \"\"\"\n        hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)\n        if num_last_tokens > 0:\n            hidden_states = hidden_states[:, -num_last_tokens:]\n        lm_logits = self.lm_head(hidden_states)\n        CausalLMOutput = namedtuple(\"CausalLMOutput\", [\"logits\"])\n        return CausalLMOutput(logits=lm_logits)\n\n    @classmethod\n    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):\n        config_data = load_config_hf(pretrained_model_name)\n        config = MambaConfig(**config_data)\n        model = cls(config, device=device, dtype=dtype, **kwargs)\n        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))\n        return model\n\n    def save_pretrained(self, save_directory):\n        \"\"\"\n        Minimal implementation of save_pretrained for MambaLMHeadModel.\n        Save the model and its configuration file to a directory.\n        \"\"\"\n        # Ensure save_directory exists\n        os.makedirs(save_directory, exist_ok=True)\n\n        # Save the model's state_dict\n        model_path = os.path.join(save_directory, 'pytorch_model.bin')\n        torch.save(self.state_dict(), model_path)\n\n        # Save the configuration of the model\n        config_path = os.path.join(save_directory, 'config.json')\n        with open(config_path, 'w') as f:\n            json.dump(self.config.__dict__, f, indent=4)\n"
  },
  {
    "path": "mamba_ssm/modules/__init__.py",
    "content": ""
  },
  {
    "path": "mamba_ssm/modules/block.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\nfrom typing import Optional\n\nimport torch\nfrom torch import nn, Tensor\n\nfrom mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn\n\n\nclass Block(nn.Module):\n    def __init__(\n        self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False\n    ):\n        \"\"\"\n        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.\n\n        This Block has a slightly different structure compared to a regular\n        prenorm Transformer block.\n        The standard block is: LN -> MHA/MLP -> Add.\n        [Ref: https://arxiv.org/abs/2002.04745]\n        Here we have: Add -> LN -> Mixer, returning both\n        the hidden_states (output of the mixer) and the residual.\n        This is purely for performance reasons, as we can fuse add and LayerNorm.\n        The residual needs to be provided (except for the very first block).\n        \"\"\"\n        super().__init__()\n        self.residual_in_fp32 = residual_in_fp32\n        self.fused_add_norm = fused_add_norm\n        self.norm = norm_cls(dim)\n        self.mixer = mixer_cls(dim)\n        if mlp_cls is not nn.Identity:\n            self.norm2 = norm_cls(dim)\n            self.mlp = mlp_cls(dim)\n        else:\n            self.mlp = None\n        if self.fused_add_norm:\n            assert RMSNorm is not None, \"RMSNorm import fails\"\n            assert isinstance(\n                self.norm, (nn.LayerNorm, RMSNorm)\n            ), \"Only LayerNorm and RMSNorm are supported for fused_add_norm\"\n\n    def forward(\n            self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs\n    ):\n        r\"\"\"Pass the input through the encoder layer.\n\n        Args:\n            hidden_states: the sequence to the encoder layer (required).\n            residual: hidden_states = Mixer(LN(residual))\n        \"\"\"\n        if not 
self.fused_add_norm:\n            residual = (hidden_states + residual) if residual is not None else hidden_states\n            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))\n            if self.residual_in_fp32:\n                residual = residual.to(torch.float32)\n        else:\n            hidden_states, residual = layer_norm_fn(\n                hidden_states,\n                self.norm.weight,\n                self.norm.bias,\n                residual=residual,\n                prenorm=True,\n                residual_in_fp32=self.residual_in_fp32,\n                eps=self.norm.eps,\n                is_rms_norm=isinstance(self.norm, RMSNorm)\n            )\n        hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)\n\n        if self.mlp is not None:\n            if not self.fused_add_norm:\n                residual = hidden_states + residual\n                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))\n                if self.residual_in_fp32:\n                    residual = residual.to(torch.float32)\n            else:\n                hidden_states, residual = layer_norm_fn(\n                    hidden_states,\n                    self.norm2.weight,\n                    self.norm2.bias,\n                    residual=residual,\n                    prenorm=True,\n                    residual_in_fp32=self.residual_in_fp32,\n                    eps=self.norm2.eps,\n                    is_rms_norm=isinstance(self.norm2, RMSNorm)\n                )\n            hidden_states = self.mlp(hidden_states)\n\n        return hidden_states, residual\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n"
  },
  {
    "path": "mamba_ssm/modules/mamba2.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\n\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom einops import rearrange, repeat\n\ntry:\n    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update\nexcept ImportError:\n    causal_conv1d_fn, causal_conv1d_update = None, None\n\ntry:\n    from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states\nexcept ImportError:\n    causal_conv1d_varlen_states = None\n\ntry:\n    from mamba_ssm.ops.triton.selective_state_update import selective_state_update\nexcept ImportError:\n    selective_state_update = None\n\nfrom mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated\n\nfrom mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear\nfrom mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter\n\nfrom mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined\nfrom mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined\n\nfrom huggingface_hub import PyTorchModelHubMixin\n\n\nclass Mamba2(nn.Module, PyTorchModelHubMixin):\n    def __init__(\n        self,\n        d_model,\n        d_state=128,\n        d_conv=4,\n        conv_init=None,\n        expand=2,\n        headdim=64,\n        d_ssm=None,  # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP\n        ngroups=1,\n        A_init_range=(1, 16),\n        D_has_hdim=False,\n        rmsnorm=True,\n        norm_before_gate=False,\n        dt_min=0.001,\n        dt_max=0.1,\n        dt_init_floor=1e-4,\n        dt_limit=(0.0, float(\"inf\")),\n        bias=False,\n        conv_bias=True,\n        # Fused kernel and sharding options\n        chunk_size=256,\n        use_mem_eff_path=True,\n        layer_idx=None,  # Absorb kwarg for general module\n        process_group=None,\n        sequence_parallel=True,\n        device=None,\n        dtype=None,\n    ):\n        
factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.d_model = d_model\n        self.d_state = d_state\n        self.d_conv = d_conv\n        self.conv_init = conv_init\n        self.expand = expand\n        self.process_group = process_group\n        self.sequence_parallel = sequence_parallel\n        self.world_size = 1 if process_group is None else process_group.size()\n        self.local_rank = 0 if process_group is None else process_group.rank()\n        self.d_inner = (self.expand * self.d_model) // self.world_size\n        assert self.d_inner * self.world_size == self.expand * self.d_model\n        self.headdim = headdim\n        self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size\n        assert ngroups % self.world_size == 0\n        self.ngroups = ngroups // self.world_size\n        assert self.d_ssm % self.headdim == 0\n        self.nheads = self.d_ssm // self.headdim\n        self.D_has_hdim = D_has_hdim\n        self.rmsnorm = rmsnorm\n        self.norm_before_gate = norm_before_gate\n        self.dt_limit = dt_limit\n        self.activation = \"silu\"\n        self.chunk_size = chunk_size\n        self.use_mem_eff_path = use_mem_eff_path\n        self.layer_idx = layer_idx\n\n        # Order: [z, x, B, C, dt]\n        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads\n        if self.process_group is None:\n            self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)\n        else:\n            self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,\n                                                process_group=self.process_group, sequence_parallel=self.sequence_parallel,\n                                                **factory_kwargs)\n\n        conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state\n        self.conv1d = nn.Conv1d(\n            in_channels=conv_dim,\n            
out_channels=conv_dim,\n            bias=conv_bias,\n            kernel_size=d_conv,\n            groups=conv_dim,\n            padding=d_conv - 1,\n            **factory_kwargs,\n        )\n        if self.conv_init is not None:\n            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)\n\n        self.act = nn.SiLU()\n\n        # Initialize log dt bias\n        dt = torch.exp(\n            torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))\n            + math.log(dt_min)\n        )\n        dt = torch.clamp(dt, min=dt_init_floor)\n        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759\n        inv_dt = dt + torch.log(-torch.expm1(-dt))\n        self.dt_bias = nn.Parameter(inv_dt)\n        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check\n        # name.endswith(\"bias\") in param_grouping.py\n        self.dt_bias._no_weight_decay = True\n\n        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]\n        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)\n        A_log = torch.log(A).to(dtype=dtype)\n        self.A_log = nn.Parameter(A_log)\n        self.A_log._no_weight_decay = True\n\n        # D \"skip\" parameter\n        self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))\n        self.D._no_weight_decay = True\n\n        if self.rmsnorm:\n            assert RMSNormGated is not None\n            self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,\n                                     group_size=self.d_ssm // ngroups, **factory_kwargs)\n\n        if self.process_group is None:\n            self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)\n        else:\n            self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,\n                   
                           process_group=self.process_group, sequence_parallel=self.sequence_parallel,\n                                              **factory_kwargs)\n\n    def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):\n        \"\"\"\n        u: (batch, seqlen, hidden_dim) if seqlen=None.\n            If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we\n            split u during sequence parallel, we split the batch * seqlen dimension\n            (in case batch is small).\n        Returns: same shape as u\n        \"\"\"\n        seqlen_og = seqlen\n        if seqlen is None:\n            batch, seqlen, dim = u.shape\n        else:\n            batch_seqlen, dim = u.shape\n            batch = batch_seqlen // seqlen\n\n        conv_state, ssm_state = None, None\n        if inference_params is not None:\n            inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch\n            conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)\n            if inference_params.seqlen_offset > 0:\n                # The states are updated inplace\n                out, _, _ = self.step(u, conv_state, ssm_state)\n                return out\n\n        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj) or (B * L, d_in_proj)\n        if seqlen_og is not None:\n            zxbcdt = rearrange(zxbcdt, \"(b l) d -> b l d\", l=seqlen)\n        # If the model is loaded in fp16, without the .float() here, A might be -inf\n        A = -torch.exp(self.A_log.float())  # (nheads) or (d_inner, d_state)\n        dt_limit_kwargs = {} if self.dt_limit == (0.0, float(\"inf\")) else dict(dt_limit=self.dt_limit)\n        if self.use_mem_eff_path and inference_params is None:\n            out = mamba_split_conv1d_scan_combined(\n                zxbcdt,\n                rearrange(self.conv1d.weight, \"d 1 w -> d w\"),\n                self.conv1d.bias,\n                
self.dt_bias,\n                A,\n                D=rearrange(self.D, \"(h p) -> h p\", p=self.headdim) if self.D_has_hdim else self.D,\n                chunk_size=self.chunk_size,\n                seq_idx=seq_idx,\n                activation=self.activation,\n                rmsnorm_weight=self.norm.weight if self.rmsnorm else None,\n                rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,\n                outproj_weight=self.out_proj.weight,\n                outproj_bias=self.out_proj.bias,\n                headdim=None if self.D_has_hdim else self.headdim,\n                ngroups=self.ngroups,\n                norm_before_gate=self.norm_before_gate,\n                **dt_limit_kwargs,\n            )\n            if seqlen_og is not None:\n                out = rearrange(out, \"b l d -> (b l) d\")\n            if self.process_group is not None:\n                reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce\n                out = reduce_fn(out, self.process_group)\n        else:\n            d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2\n            z0, x0, z, xBC, dt = torch.split(\n                zxbcdt,\n                [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],\n                dim=-1\n            )\n            if conv_state is not None:\n                if cu_seqlens is None:\n                    # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv\n                    # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.\n                    xBC_t = rearrange(xBC, \"b l d -> b d l\")\n                    conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0)))  # Update state (B D W)\n                else:\n                    assert causal_conv1d_varlen_states is not None, \"varlen inference requires causal_conv1d package\"\n                    assert 
batch == 1, \"varlen inference only supports batch dimension 1\"\n                    conv_varlen_states = causal_conv1d_varlen_states(\n                        xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]\n                    )\n                    conv_state.copy_(conv_varlen_states)\n            assert self.activation in [\"silu\", \"swish\"]\n            if causal_conv1d_fn is None or self.activation not in [\"silu\", \"swish\"]:\n                assert seq_idx is None, \"varlen conv1d requires the causal_conv1d package\"\n                xBC = self.act(\n                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :-(self.d_conv - 1)]\n                )  # (B, L, self.d_ssm + 2 * ngroups * d_state)\n            else:\n                xBC = causal_conv1d_fn(\n                    xBC.transpose(1, 2),\n                    rearrange(self.conv1d.weight, \"d 1 w -> d w\"),\n                    bias=self.conv1d.bias,\n                    activation=self.activation,\n                    seq_idx=seq_idx,\n                ).transpose(1, 2)\n            x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)\n            y = mamba_chunk_scan_combined(\n                rearrange(x, \"b l (h p) -> b l h p\", p=self.headdim),\n                dt,\n                A,\n                rearrange(B, \"b l (g n) -> b l g n\", g=self.ngroups),\n                rearrange(C, \"b l (g n) -> b l g n\", g=self.ngroups),\n                chunk_size=self.chunk_size,\n                D=rearrange(self.D, \"(h p) -> h p\", p=self.headdim) if self.D_has_hdim else self.D,\n                z=rearrange(z, \"b l (h p) -> b l h p\", p=self.headdim) if not self.rmsnorm else None,\n                dt_bias=self.dt_bias,\n                dt_softplus=True,\n                seq_idx=seq_idx,\n                cu_seqlens=cu_seqlens,\n                **dt_limit_kwargs,\n                return_final_states=ssm_state is not None,\n  
              return_varlen_states=cu_seqlens is not None and inference_params is not None,\n            )\n            if ssm_state is not None:\n                y, last_state, *rest = y\n                if cu_seqlens is None:\n                    ssm_state.copy_(last_state)\n                else:\n                    varlen_states = rest[0]\n                    ssm_state.copy_(varlen_states)\n            y = rearrange(y, \"b l h p -> b l (h p)\")\n            if self.rmsnorm:\n                y = self.norm(y, z)\n            if d_mlp > 0:\n                y = torch.cat([F.silu(z0) * x0, y], dim=-1)\n            if seqlen_og is not None:\n                y = rearrange(y, \"b l d -> (b l) d\")\n            out = self.out_proj(y)\n        return out\n\n    def step(self, hidden_states, conv_state, ssm_state):\n        dtype = hidden_states.dtype\n        assert hidden_states.shape[1] == 1, \"Only support decoding with 1 token at a time for now\"\n        zxbcdt = self.in_proj(hidden_states.squeeze(1))  # (B 2D)\n        d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2\n        z0, x0, z, xBC, dt = torch.split(\n            zxbcdt,\n            [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads],\n            dim=-1\n        )\n\n        # Conv step\n        if causal_conv1d_update is None:\n            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)\n            conv_state[:, :, -1] = xBC\n            xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, \"d 1 w -> d w\"), dim=-1)  # (B D)\n            if self.conv1d.bias is not None:\n                xBC = xBC + self.conv1d.bias\n            xBC = self.act(xBC).to(dtype=dtype)\n        else:\n            xBC = causal_conv1d_update(\n                xBC,\n                conv_state,\n                rearrange(self.conv1d.weight, \"d 1 w -> d w\"),\n                self.conv1d.bias,\n            
    self.activation,\n            )\n\n        x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)\n        A = -torch.exp(self.A_log.float())  # (nheads,)\n\n        # SSM step\n        if selective_state_update is None:\n            assert self.ngroups == 1, \"Only support ngroups=1 for this inference code path\"\n            # Discretize A and B\n            dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype))  # (batch, nheads)\n            dA = torch.exp(dt * A)  # (batch, nheads)\n            x = rearrange(x, \"b (h p) -> b h p\", p=self.headdim)\n            dBx = torch.einsum(\"bh,bn,bhp->bhpn\", dt, B, x)\n            ssm_state.copy_(ssm_state * rearrange(dA, \"b h -> b h 1 1\") + dBx)\n            y = torch.einsum(\"bhpn,bn->bhp\", ssm_state.to(dtype), C)\n            y = y + rearrange(self.D.to(dtype), \"h -> h 1\") * x\n            y = rearrange(y, \"b h p -> b (h p)\")\n            if not self.rmsnorm:\n                y = y * self.act(z)  # (B D)\n        else:\n            A = repeat(A, \"h -> h p n\", p=self.headdim, n=self.d_state).to(dtype=torch.float32)\n            dt = repeat(dt, \"b h -> b h p\", p=self.headdim)\n            dt_bias = repeat(self.dt_bias, \"h -> h p\", p=self.headdim)\n            D = repeat(self.D, \"h -> h p\", p=self.headdim)\n            B = rearrange(B, \"b (g n) -> b g n\", g=self.ngroups)\n            C = rearrange(C, \"b (g n) -> b g n\", g=self.ngroups)\n            x_reshaped = rearrange(x, \"b (h p) -> b h p\", p=self.headdim)\n            if not self.rmsnorm:\n                z = rearrange(z, \"b (h p) -> b h p\", p=self.headdim)\n            y = selective_state_update(\n                ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None,\n                dt_bias=dt_bias, dt_softplus=True\n            )\n            y = rearrange(y, \"b h p -> b (h p)\")\n        if self.rmsnorm:\n            y = self.norm(y, z)\n        if d_mlp > 0:\n    
        y = torch.cat([F.silu(z0) * x0, y], dim=-1)\n        out = self.out_proj(y)\n        return out.unsqueeze(1), conv_state, ssm_state\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n        device = self.out_proj.weight.device\n        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype\n        conv_state = torch.zeros(\n            batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype\n        ).transpose(1, 2)\n        ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype\n        ssm_state = torch.zeros(\n            batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype\n        )\n        return conv_state, ssm_state\n\n    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):\n        assert self.layer_idx is not None\n        if self.layer_idx not in inference_params.key_value_memory_dict:\n            batch_shape = (batch_size,)\n            conv_state = torch.zeros(\n                batch_size,\n                self.d_conv,\n                self.conv1d.weight.shape[0],\n                device=self.conv1d.weight.device,\n                dtype=self.conv1d.weight.dtype,\n            ).transpose(1, 2)\n            ssm_state = torch.zeros(\n                batch_size,\n                self.nheads,\n                self.headdim,\n                self.d_state,\n                device=self.in_proj.weight.device,\n                dtype=self.in_proj.weight.dtype,\n            )\n            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)\n        else:\n            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]\n            # TODO: What if batch size changes between generation, and we reuse the same states?\n            if initialize_states:\n                conv_state.zero_()\n                ssm_state.zero_()\n        return 
conv_state, ssm_state\n"
  },
  {
    "path": "mamba_ssm/modules/mamba2_simple.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\n\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom einops import rearrange, repeat\n\ntry:\n    from causal_conv1d import causal_conv1d_fn\nexcept ImportError:\n    causal_conv1d_fn = None\n\ntry:\n    from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm\nexcept ImportError:\n    RMSNormGated, LayerNorm = None, None\n\nfrom mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined\nfrom mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined\n\n\nclass Mamba2Simple(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        d_state=64,\n        d_conv=4,\n        conv_init=None,\n        expand=2,\n        headdim=128,\n        ngroups=1,\n        A_init_range=(1, 16),\n        dt_min=0.001,\n        dt_max=0.1,\n        dt_init_floor=1e-4,\n        dt_limit=(0.0, float(\"inf\")),\n        learnable_init_states=False,\n        activation=\"swish\",\n        bias=False,\n        conv_bias=True,\n        # Fused kernel and sharding options\n        chunk_size=256,\n        use_mem_eff_path=True,\n        layer_idx=None,  # Absorb kwarg for general module\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.d_model = d_model\n        self.d_state = d_state\n        self.d_conv = d_conv\n        self.conv_init = conv_init\n        self.expand = expand\n        self.d_inner = self.expand * self.d_model\n        self.headdim = headdim\n        self.ngroups = ngroups\n        assert self.d_inner % self.headdim == 0\n        self.nheads = self.d_inner // self.headdim\n        self.dt_limit = dt_limit\n        self.learnable_init_states = learnable_init_states\n        self.activation = activation\n        self.chunk_size = chunk_size\n        self.use_mem_eff_path = use_mem_eff_path\n        
self.layer_idx = layer_idx\n\n        # Order: [z, x, B, C, dt]\n        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads\n        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)\n\n        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state\n        self.conv1d = nn.Conv1d(\n            in_channels=conv_dim,\n            out_channels=conv_dim,\n            bias=conv_bias,\n            kernel_size=d_conv,\n            groups=conv_dim,\n            padding=d_conv - 1,\n            **factory_kwargs,\n        )\n        if self.conv_init is not None:\n            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)\n        # self.conv1d.weight._no_weight_decay = True\n\n        if self.learnable_init_states:\n            self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs))\n            self.init_states._no_weight_decay = True\n\n        self.act = nn.SiLU()\n\n        # Initialize log dt bias\n        dt = torch.exp(\n            torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))\n            + math.log(dt_min)\n        )\n        dt = torch.clamp(dt, min=dt_init_floor)\n        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759\n        inv_dt = dt + torch.log(-torch.expm1(-dt))\n        self.dt_bias = nn.Parameter(inv_dt)\n        # Just to be explicit. 
Without this we already don't put wd on dt_bias because of the check\n        # name.endswith(\"bias\") in param_grouping.py\n        self.dt_bias._no_weight_decay = True\n\n        # A parameter\n        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]\n        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)\n        A_log = torch.log(A).to(dtype=dtype)\n        self.A_log = nn.Parameter(A_log)\n        # self.register_buffer(\"A_log\", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)\n        self.A_log._no_weight_decay = True\n\n        # D \"skip\" parameter\n        self.D = nn.Parameter(torch.ones(self.nheads, device=device))\n        self.D._no_weight_decay = True\n\n        # Extra normalization layer right before output projection\n        assert RMSNormGated is not None\n        self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs)\n\n        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)\n\n    def forward(self, u, seq_idx=None):\n        \"\"\"\n        u: (B, L, D)\n        Returns: same shape as u\n        \"\"\"\n        batch, seqlen, dim = u.shape\n\n        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)\n        A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)\n        initial_states=repeat(self.init_states, \"... 
-> b ...\", b=batch) if self.learnable_init_states else None\n        dt_limit_kwargs = {} if self.dt_limit == (0.0, float(\"inf\")) else dict(dt_limit=self.dt_limit)\n\n        if self.use_mem_eff_path:\n            # Fully fused path\n            out = mamba_split_conv1d_scan_combined(\n                zxbcdt,\n                rearrange(self.conv1d.weight, \"d 1 w -> d w\"),\n                self.conv1d.bias,\n                self.dt_bias,\n                A,\n                D=self.D,\n                chunk_size=self.chunk_size,\n                seq_idx=seq_idx,\n                activation=self.activation,\n                rmsnorm_weight=self.norm.weight,\n                rmsnorm_eps=self.norm.eps,\n                outproj_weight=self.out_proj.weight,\n                outproj_bias=self.out_proj.bias,\n                headdim=self.headdim,\n                ngroups=self.ngroups,\n                norm_before_gate=False,\n                initial_states=initial_states,\n                **dt_limit_kwargs,\n            )\n        else:\n            z, xBC, dt = torch.split(\n                zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1\n            )\n            dt = F.softplus(dt + self.dt_bias)  # (B, L, nheads)\n            assert self.activation in [\"silu\", \"swish\"]\n\n            # 1D Convolution\n            if causal_conv1d_fn is None or self.activation not in [\"silu\", \"swish\"]:\n                xBC = self.act(\n                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)\n                )  # (B, L, self.d_inner + 2 * ngroups * d_state)\n                xBC = xBC[:, :seqlen, :]\n            else:\n                xBC = causal_conv1d_fn(\n                    x=xBC.transpose(1, 2),\n                    weight=rearrange(self.conv1d.weight, \"d 1 w -> d w\"),\n                    bias=self.conv1d.bias,\n                    activation=self.activation,\n                ).transpose(1, 2)\n\n            # 
Split into 3 main branches: X, B, C\n            # These correspond to V, K, Q respectively in the SSM/attention duality\n            x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)\n            y = mamba_chunk_scan_combined(\n                rearrange(x, \"b l (h p) -> b l h p\", p=self.headdim),\n                dt,\n                A,\n                rearrange(B, \"b l (g n) -> b l g n\", g=self.ngroups),\n                rearrange(C, \"b l (g n) -> b l g n\", g=self.ngroups),\n                chunk_size=self.chunk_size,\n                D=self.D,\n                z=None,\n                seq_idx=seq_idx,\n                initial_states=initial_states,\n                **dt_limit_kwargs,\n            )\n            y = rearrange(y, \"b l h p -> b l (h p)\")\n\n            # Multiply \"gate\" branch and apply extra normalization layer\n            y = self.norm(y, z)\n            out = self.out_proj(y)\n        return out\n"
  },
  {
    "path": "mamba_ssm/modules/mamba3.py",
    "content": "# Copyright (c) 2026, Dao AI Lab, Goombalab.\n\nimport math\nfrom einops import rearrange, repeat\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated\n\nfrom mamba_ssm.ops.tilelang.mamba3.mamba3_mimo import mamba3_mimo as mamba3_mimo_combined\nfrom mamba_ssm.ops.triton.angle_cumsum import angle_dt\nfrom mamba_ssm.ops.triton.mamba3.mamba3_siso_combined import mamba3_siso_combined\n\nfrom mamba_ssm.ops.triton.mamba3.mamba3_mimo_rotary_step import apply_rotary_qk_inference_fwd\n\nfrom mamba_ssm.ops.cute.mamba3.mamba3_step_fn import mamba3_step_fn\n\nclass Mamba3(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        d_state=128,\n        expand=2,\n        headdim=64,\n        ngroups=1,\n        # ----------------------------------------\n        # Mamba-3 configs\n        rope_fraction=0.5,\n        dt_min=0.001,\n        dt_max=0.1,\n        dt_init_floor=1e-4,\n        A_floor=1e-4,\n        is_outproj_norm=False,\n        is_mimo=False,\n        mimo_rank=4,\n        #-------------------------------------------\n        # Fused kernel and sharding options\n        chunk_size=64, # Recommended: 64 for SISO, 64/mimo_rank for MIMO\n        dropout=0.0,  # Just to absorb the kwarg\n        layer_idx=None,  # Absorb kwarg for general module\n        n_layer=None,  # Absorb kwarg for general module\n        device=None,\n        dtype=None,\n        **kwargs,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.d_model = d_model\n        self.d_state = d_state\n        self.expand = expand\n        self.headdim = headdim\n        self.chunk_size = chunk_size\n        self.layer_idx = layer_idx\n        self.A_floor = A_floor\n        self.is_outproj_norm=is_outproj_norm\n        self.is_mimo = is_mimo\n        self.mimo_rank = mimo_rank\n        if not self.is_mimo:\n            
self.mimo_rank = 1\n\n        self.d_inner = int(self.expand * self.d_model)\n        assert self.d_inner % self.headdim == 0\n        self.nheads = self.d_inner // self.headdim\n        self.num_bc_heads = ngroups\n        \n        # RoPE flags\n        assert rope_fraction in [0.5, 1.0]\n        self.rotary_dim_divisor = int(2/rope_fraction)\n        self.split_tensor_size = int(d_state * rope_fraction)\n        if self.split_tensor_size % 2 != 0:\n            self.split_tensor_size -= 1\n        self.num_rope_angles = self.split_tensor_size // 2\n        assert self.num_rope_angles > 0\n\n        # Order: [z, x, B, C, dd_dt, dd_A, trap, angle]\n        d_in_proj = 2 * self.d_inner + 2 * self.d_state * self.num_bc_heads * self.mimo_rank + 3 * self.nheads + self.num_rope_angles\n        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=False, **factory_kwargs)\n\n        # dt_bias parameterization        \n        _dt = torch.exp(\n            torch.rand(self.nheads, device=device, dtype=torch.float32) * (math.log(dt_max) - math.log(dt_min))\n            + math.log(dt_min)\n        )\n        _dt = torch.clamp(_dt, min=dt_init_floor)\n        _dt_bias = _dt + torch.log(-torch.expm1(-_dt))\n        self.dt_bias = nn.Parameter(_dt_bias, requires_grad=True)\n        self.dt_bias._no_weight_decay = True\n        \n        # B and C biases\n        self.B_bias = nn.Parameter(1+torch.zeros((self.nheads, self.mimo_rank, self.d_state), dtype=torch.float32, device=device), requires_grad=True)\n        self.C_bias = nn.Parameter(1+torch.zeros((self.nheads, self.mimo_rank, self.d_state), dtype=torch.float32, device=device), requires_grad=True)\n                                                       \n        # RMS Norm for B and C\n        assert RMSNormGated is not None\n        self.B_norm = RMSNormGated(self.d_state, eps=1e-5, **factory_kwargs)\n        self.C_norm = RMSNormGated(self.d_state, eps=1e-5, **factory_kwargs)\n\n        if self.is_mimo:\n            # 
Initialize up/down MIMO projection (for x and z)\n            mimo_x_init_weights = torch.ones(self.nheads, self.mimo_rank, self.headdim, device=device) / self.mimo_rank\n            mimo_z_init_weights = torch.ones(self.nheads, self.mimo_rank, self.headdim, device=device)\n            mimo_o_init_weights = torch.ones(self.nheads, self.mimo_rank, self.headdim, device=device) / self.mimo_rank\n\n            self.mimo_x = nn.Parameter(mimo_x_init_weights, requires_grad=True)\n            self.mimo_z = nn.Parameter(mimo_z_init_weights, requires_grad=True)\n            self.mimo_o = nn.Parameter(mimo_o_init_weights, requires_grad=True)\n    \n        # D \"skip\" parameter\n        self.D = nn.Parameter(torch.ones(self.nheads, device=device))\n        self.D._no_weight_decay = True\n\n        if self.is_outproj_norm:\n            self.norm = RMSNormGated(\n                self.d_inner,\n                eps=1e-5,\n                norm_before_gate=True,\n                group_size=self.headdim,\n                **factory_kwargs\n            )\n\n        # Output projection\n        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=False, **factory_kwargs)\n\n\n    def forward(self, u, seq_idx=None, cu_seqlens=None, inference_params=None):\n        \"\"\"\n        u: (batch, seqlen, hidden_dim)\n        Returns: same shape as u\n        \"\"\"\n        batch, seqlen, dim = u.shape\n        if cu_seqlens is not None:\n            raise NotImplementedError(\"Currently does not support varlen in Mamba-3 (MIMO).\")\n\n        angle_dt_state, ssm_state, k_state, v_state  = None, None, None, None\n        if inference_params is not None:\n            inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch\n            angle_dt_state, ssm_state, k_state, v_state = self._get_states_from_cache(inference_params, inference_batch)\n            if inference_params.seqlen_offset > 0:\n                # The states are updated inplace here; however, due to 
the current implementation,\n                # setting inplace=True incurs significant overhead. So potentially\n                # faster to call step() directly with inplace=False:\n                out, _, _, _, _ = self.step(u, angle_dt_state, ssm_state, k_state, v_state)\n                return out\n\n        # Apply in_proj\n        zxBCdtAtrap = self.in_proj(u)\n        z, x, B, C, dd_dt, dd_A, trap, angles = torch.split(\n            zxBCdtAtrap,\n            [\n                self.d_inner, self.d_inner, \n                self.d_state * self.num_bc_heads * self.mimo_rank,\n                self.d_state * self.num_bc_heads * self.mimo_rank,\n                self.nheads, self.nheads, self.nheads, \n                self.num_rope_angles\n            ],\n            dim=-1)\n        z = rearrange(z, \"b l (h p) -> b l h p\", p=self.headdim)\n        x = rearrange(x, \"b l (h p) -> b l h p\", p=self.headdim)\n        B = rearrange(B, \"b l (r g n) -> b l r g n\", r=self.mimo_rank, g=self.num_bc_heads)\n        C = rearrange(C, \"b l (r g n) -> b l r g n\", r=self.mimo_rank, g=self.num_bc_heads)\n        trap = rearrange(trap, \"b l h -> b h l\")\n\n        # Compute ADT, DT\n        _A = -F.softplus(dd_A.to(torch.float32)) # (B, L, N)\n        _A = torch.clamp(_A, max=-self.A_floor)            \n        DT = F.softplus(dd_dt + self.dt_bias) # (B, L, N)\n        ADT = _A * DT\n        DT = rearrange(DT, \"b l n -> b n l\")\n        ADT = rearrange(ADT, \"b l n -> b n l\")\n\n        # Compute angle\n        angles = angles.unsqueeze(-2).expand(-1, -1, self.nheads, -1) # (B, L, N, S)\n\n        # Apply RMS Norm on B and C\n        B = self.B_norm(B)\n        C = self.C_norm(C)\n        \n        # Apply Mamba-3 kernel\n        if self.is_mimo:\n            angles = angle_dt(angles, DT.transpose(-1, -2)) # (B, L, N, S)\n            y = mamba3_mimo_combined(\n                Q=C,\n                K=B,\n                V=x,\n                ADT=ADT,\n                
DT=DT,\n                Trap=trap,\n                Q_bias=self.C_bias,\n                K_bias=self.B_bias,\n                MIMO_V=self.mimo_x,\n                MIMO_Z=self.mimo_z,\n                MIMO_Out=self.mimo_o if not self.is_outproj_norm else None,\n                Angles=angles,\n                D=self.D,\n                Z=z if not self.is_outproj_norm else None,\n                chunk_size=self.chunk_size,\n                rotary_dim_divisor=self.rotary_dim_divisor,\n                dtype=x.dtype,\n                return_state=ssm_state is not None,\n            )\n            if ssm_state is not None:\n                y, last_angle, last_state, last_k, last_v, *rest = y\n                angle_dt_state.copy_(last_angle)\n                ssm_state.copy_(last_state)\n                k_state.copy_(last_k)\n                v_state.copy_(last_v)\n            if self.is_outproj_norm:\n                z = torch.einsum(\"blhp,hrp->blrhp\", z.float(), self.mimo_z)\n                z = rearrange(z, \"b l r h p -> b l r (h p)\")\n                y = rearrange(y, \"b l r h p -> b l r (h p)\").float()\n                y = self.norm(y, z)\n                y = rearrange(y, \"b l r (h p) -> b l r h p\", p=self.headdim)\n                y = torch.einsum(\"blrhp,hrp->blhp\", y, self.mimo_o)\n            y = rearrange(y, \"b l h p -> b l (h p)\")\n        else:\n            y = mamba3_siso_combined(\n                Q=C.squeeze(2),\n                K=B.squeeze(2),\n                V=x,\n                ADT=ADT,\n                DT=DT,\n                Trap=trap,\n                Q_bias=self.C_bias.squeeze(1),\n                K_bias=self.B_bias.squeeze(1),\n                Angles=angles,\n                D=self.D,\n                Z=z if not self.is_outproj_norm else None,\n                chunk_size=self.chunk_size,\n                Input_States=None,\n                return_final_states=ssm_state is not None,\n            )\n            if ssm_state is not None:\n     
           y, last_angle, last_state, last_k, last_v, *rest = y\n                angle_dt_state.copy_(last_angle)\n                ssm_state.copy_(last_state)\n                k_state.copy_(last_k)\n                v_state.copy_(last_v)\n            y = rearrange(y, \"b l h p -> b l (h p)\")\n            if self.is_outproj_norm:\n                z = rearrange(z, \"b l h p -> b l (h p)\")\n                y = self.norm(y, z)\n        \n        out = self.out_proj(y.to(x.dtype))\n        return out\n    \n\n    def _preprocess(self, A_proj, dd_dt, B, C, x, z, trap_proj, angle_proj):\n        _A = -F.softplus(A_proj.to(torch.float32))\n        _A = torch.clamp(_A, max=-self.A_floor)\n        DT = F.softplus(dd_dt + self.dt_bias)\n        trap = torch.sigmoid(trap_proj)\n\n        rank = self.mimo_rank if self.is_mimo else 1\n        B = rearrange(B, \"b (r g s) -> b r g s\", g=self.num_bc_heads, r=rank)\n        C = rearrange(C, \"b (r g s) -> b r g s\", g=self.num_bc_heads, r=rank)\n\n        B = self.B_norm(B)\n        C = self.C_norm(C)\n\n        B = B.expand(-1, -1, self.nheads, -1) # (B, R, N, S)\n        C = C.expand(-1, -1, self.nheads, -1) # (B, R, N, S)\n    \n        x = rearrange(x, \"b (h p) -> b h p\", p=self.headdim)\n        z = rearrange(z, \"b (h p) -> b h p\", p=self.headdim)\n\n        angles = angle_proj.unsqueeze(-2).expand(-1, self.nheads, -1)\n\n        return DT, B, C, x, z, trap, _A, angles\n\n    def _postprocess(self, y, outpj, z, zpj, headdim):\n        # y: (batch, R, H, D) — apply mimo_z to z, then norm, then mimo_o\n        z_r = torch.einsum(\"bhp,rhp->brhp\", z.float(), zpj)  # (batch, R, H, D)\n        z_r = rearrange(z_r, \"b r h p -> b r (h p)\")\n        y = rearrange(y, \"b r h p -> b r (h p)\").float()\n        y = self.norm(y, z_r)\n        y = rearrange(y, \"b r (h p) -> b r h p\", p=headdim)\n        y = torch.einsum(\"brhp,rhp->bhp\", y, outpj)  # (batch, H, D)\n        return y\n\n    def step(self, u, angle_state, 
ssm_state, k_state, v_state, **kwargs):\n        \"\"\"\n        Decode function using CuteDSL kernel from mamba3_step_fn.py.\n        Also modifies the state variables in place for the next step.\n\n        NOTE: Only tested on H100. Compatibility with other hardware\n        will be made available in the future.\n\n        Args:\n            u: (batch, d_model)\n            angle_state: (batch, nheads, num_rope_angles)\n            ssm_state: (batch, nheads, headdim, d_state)\n            k_state: (batch, R, nheads, d_state), where R = mimo_rank (R=1 if not MIMO)\n            v_state: (batch, nheads, headdim)\n            **kwargs: ignored\n        Returns:\n            out: (batch, d_model)\n            nxt_angle_state: (batch, nheads, num_rope_angles)\n            state_out: (batch, nheads, headdim, d_state)\n            nxt_k_state: (batch, R, nheads, d_state), where R = mimo_rank (R=1 if not MIMO)\n            nxt_v_state: (batch, nheads, headdim)\n        \"\"\"\n\n        # in_proj\n        zxBCdt = self.in_proj(u)\n        z, x, B, C, dd_dt, dd_A, trap, angles = torch.split(\n            zxBCdt,\n            [\n                self.d_inner,\n                self.d_inner,\n                self.d_state * self.num_bc_heads * self.mimo_rank,\n                self.d_state * self.num_bc_heads * self.mimo_rank,\n                self.nheads,\n                self.nheads,\n                self.nheads,\n                self.num_rope_angles,\n            ],\n            dim=-1)\n\n        DT, B, C, x, z, trap, A, angles = self._preprocess(\n            dd_A, dd_dt, B, C, x, z, trap, angles)\n\n        bias_q = rearrange(self.C_bias, \"h r n -> r h n\")\n        bias_k = rearrange(self.B_bias, \"h r n -> r h n\")\n\n        # NOTE: MIMO calls the Tilelang kernel,\n        # which permutes the blockwise rotation matrix so that\n        # the i-th entry is paired with the i+N//2-th entry:\n        rotate_pairwise = not self.is_mimo\n        C, B, nxt_angle_state = 
apply_rotary_qk_inference_fwd(\n            q=C, k=B, angle_state=angle_state, \n            angle_proj=angles, dt=DT, bias_q=bias_q, bias_k=bias_k, \n            conjugate=False, inplace=False, # NOTE: inplace is incompatible with self.nheads != self.num_bc_heads\n            rotate_pairwise=rotate_pairwise)\n\n        nxt_v_state = x\n        nxt_k_state = B\n\n        if self.is_mimo:\n            xpj = rearrange(self.mimo_x, \"h r p -> r h p\", p=self.headdim).contiguous()\n            zpj = rearrange(self.mimo_z, \"h r p -> r h p\", p=self.headdim).contiguous()\n            outpj = rearrange(self.mimo_o, \"h r p -> r h p\", p=self.headdim).contiguous()\n        else:\n            xpj = torch.ones(self.mimo_rank, self.nheads, self.headdim, device=x.device, dtype=x.dtype)\n            zpj = torch.ones(self.mimo_rank, self.nheads, self.headdim, device=z.device, dtype=z.dtype)\n            outpj = torch.ones(self.mimo_rank, self.nheads, self.headdim, device=x.device, dtype=x.dtype)\n\n        if self.is_outproj_norm:\n            batch = x.shape[0]\n            y = torch.empty(batch, self.mimo_rank, self.nheads, self.headdim, device=x.device, dtype=x.dtype)\n            mamba3_step_fn(\n                ssm_state,\n                k_state,\n                v_state,\n                A,\n                B,\n                C,\n                self.D,\n                x,\n                DT,\n                trap,\n                xpj,\n                outproj=None,\n                state_out=None, # can be not in place if pass in state_out\n                out=y,\n                z=None,\n                zproj=None,\n                tile_D=64,\n                num_warps=4,\n            )\n            y = self._postprocess(y, outpj, z, zpj, self.headdim)\n        else:\n            y = torch.empty_like(x)\n            mamba3_step_fn(\n                ssm_state,\n                k_state,\n                v_state,\n                A,\n                B,\n                
C,\n                self.D,\n                x,\n                DT,\n                trap,\n                xpj,\n                outproj=outpj,\n                state_out=None, # can be not in place if pass in state_out\n                out=y,\n                z=z,\n                zproj=zpj,\n                tile_D=64,\n                num_warps=4,\n            )\n\n        # out_proj\n        out = rearrange(y, \"b h p -> b (h p)\")\n        out = self.out_proj(out.to(x.dtype))\n\n        angle_state.copy_(nxt_angle_state)\n        # Uncomment the following if mamba3_step_fn is not in place:\n        # state_out = torch.empty_like(ssm_state)\n        # ssm_state.copy_(state_out) \n        k_state.copy_(nxt_k_state)\n        v_state.copy_(nxt_v_state)\n\n        return out, nxt_angle_state, ssm_state, nxt_k_state, nxt_v_state\n    \n    def allocate_inference_cache(self, batch_size, max_seqlen, device=None, dtype=None, inplace_state=None, **kwargs):\n        device = self.in_proj.weight.device if device is None else device\n        dtype = self.in_proj.weight.dtype if dtype is None else dtype\n\n        # RoPE State\n        angle_dt_state = torch.zeros(\n            (batch_size, self.nheads, self.num_rope_angles),\n            device=device,\n            dtype=torch.float32,\n        )\n\n        # Mamba-3 Combined Kernel States\n        # SSM State\n        ssm_state = torch.zeros(\n            (batch_size, self.nheads, self.headdim, self.d_state),\n            device=device,\n            dtype=torch.float32,\n        )\n\n        # K (=B) State\n        if self.is_mimo:\n            k_state = torch.zeros(\n                (batch_size, self.mimo_rank, self.nheads, self.d_state),\n                device=device,\n                dtype=dtype,\n            )\n        else:\n            k_state = torch.zeros(\n                (batch_size, 1, self.nheads, self.d_state),\n                device=device,\n                dtype=dtype,\n            )\n\n        # V (=x) 
State\n        v_state = torch.zeros(\n            (batch_size, self.nheads, self.headdim),\n            device=device,\n            dtype=dtype,\n        )\n\n        return (angle_dt_state, ssm_state, k_state, v_state)\n    \n    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):\n        assert self.layer_idx is not None\n        device = self.in_proj.weight.device\n        dtype = self.in_proj.weight.dtype\n\n        if self.layer_idx not in inference_params.key_value_memory_dict:\n            angle_dt_state = torch.zeros(\n                (batch_size, self.nheads, self.num_rope_angles),\n                device=device,\n                dtype=torch.float32,\n            )\n            ssm_state = torch.zeros(\n                (batch_size, self.nheads, self.headdim, self.d_state),\n                device=device,\n                dtype=torch.float32,\n            )\n            if self.is_mimo:\n                k_state = torch.zeros(\n                    (batch_size, self.mimo_rank, self.nheads, self.d_state),\n                    device=device,\n                    dtype=dtype,\n                )\n            else:\n                k_state = torch.zeros(\n                    (batch_size, self.nheads, self.d_state),\n                    device=device,\n                    dtype=dtype,\n                )\n            v_state = torch.zeros(\n                (batch_size, self.nheads, self.headdim),\n                device=device,\n                dtype=dtype,\n            )\n            inference_params.key_value_memory_dict[self.layer_idx] = (angle_dt_state, ssm_state, k_state, v_state)\n        else:\n            angle_dt_state, ssm_state, k_state, v_state = inference_params.key_value_memory_dict[self.layer_idx]\n            # TODO: What if batch size changes between generation, and we reuse the same states?\n            if initialize_states:\n                angle_dt_state.zero_()\n                ssm_state.zero_()\n           
     k_state.zero_()\n                v_state.zero_()\n        return angle_dt_state, ssm_state, k_state, v_state\n"
  },
  {
    "path": "mamba_ssm/modules/mamba_simple.py",
    "content": "# Copyright (c) 2023, Tri Dao, Albert Gu.\n\nimport math\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom einops import rearrange, repeat\n\nfrom mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn\n\ntry:\n    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update\nexcept ImportError:\n    causal_conv1d_fn, causal_conv1d_update = None, None\n\ntry:\n    from mamba_ssm.ops.triton.selective_state_update import selective_state_update\nexcept ImportError:\n    selective_state_update = None\n\ntry:\n    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn\nexcept ImportError:\n    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None\n\n\nclass Mamba(nn.Module):\n    def __init__(\n        self,\n        d_model,\n        d_state=16,\n        d_conv=4,\n        expand=2,\n        dt_rank=\"auto\",\n        dt_min=0.001,\n        dt_max=0.1,\n        dt_init=\"random\",\n        dt_scale=1.0,\n        dt_init_floor=1e-4,\n        conv_bias=True,\n        bias=False,\n        use_fast_path=True,  # Fused kernel options\n        layer_idx=None,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.d_model = d_model\n        self.d_state = d_state\n        self.d_conv = d_conv\n        self.expand = expand\n        self.d_inner = int(self.expand * self.d_model)\n        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == \"auto\" else dt_rank\n        self.use_fast_path = use_fast_path\n        self.layer_idx = layer_idx\n\n        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)\n\n        self.conv1d = nn.Conv1d(\n            in_channels=self.d_inner,\n            out_channels=self.d_inner,\n            bias=conv_bias,\n            kernel_size=d_conv,\n          
  groups=self.d_inner,\n            padding=d_conv - 1,\n            **factory_kwargs,\n        )\n\n        self.activation = \"silu\"\n        self.act = nn.SiLU()\n\n        self.x_proj = nn.Linear(\n            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs\n        )\n        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)\n\n        # Initialize special dt projection to preserve variance at initialization\n        dt_init_std = self.dt_rank**-0.5 * dt_scale\n        if dt_init == \"constant\":\n            nn.init.constant_(self.dt_proj.weight, dt_init_std)\n        elif dt_init == \"random\":\n            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)\n        else:\n            raise NotImplementedError\n\n        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max\n        dt = torch.exp(\n            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))\n            + math.log(dt_min)\n        ).clamp(min=dt_init_floor)\n        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759\n        inv_dt = dt + torch.log(-torch.expm1(-dt))\n        with torch.no_grad():\n            self.dt_proj.bias.copy_(inv_dt)\n        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit\n        self.dt_proj.bias._no_reinit = True\n\n        # S4D real initialization\n        A = repeat(\n            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),\n            \"n -> d n\",\n            d=self.d_inner,\n        ).contiguous()\n        A_log = torch.log(A)  # Keep A_log in fp32\n        self.A_log = nn.Parameter(A_log)\n        self.A_log._no_weight_decay = True\n\n        # D \"skip\" parameter\n        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32\n        self.D._no_weight_decay = True\n\n        self.out_proj = 
nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)\n\n    def forward(self, hidden_states, inference_params=None):\n        \"\"\"\n        hidden_states: (B, L, D)\n        Returns: same shape as hidden_states\n        \"\"\"\n        batch, seqlen, dim = hidden_states.shape\n\n        conv_state, ssm_state = None, None\n        if inference_params is not None:\n            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)\n            if inference_params.seqlen_offset > 0:\n                # The states are updated inplace\n                out, _, _ = self.step(hidden_states, conv_state, ssm_state)\n                return out\n\n        # We do matmul and transpose BLH -> HBL at the same time\n        xz = rearrange(\n            self.in_proj.weight @ rearrange(hidden_states, \"b l d -> d (b l)\"),\n            \"d (b l) -> b d l\",\n            l=seqlen,\n        )\n        if self.in_proj.bias is not None:\n            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), \"d -> d 1\")\n\n        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)\n        # In the backward pass we write dx and dz next to each other to avoid torch.cat\n        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states\n            out = mamba_inner_fn(\n                xz,\n                self.conv1d.weight,\n                self.conv1d.bias,\n                self.x_proj.weight,\n                self.dt_proj.weight,\n                self.out_proj.weight,\n                self.out_proj.bias,\n                A,\n                None,  # input-dependent B\n                None,  # input-dependent C\n                self.D.float(),\n                delta_bias=self.dt_proj.bias.float(),\n                delta_softplus=True,\n            )\n        else:\n            x, z = xz.chunk(2, dim=1)\n            # Compute short convolution\n            if conv_state is not 
None:\n                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv\n                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.\n                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)\n            if causal_conv1d_fn is None:\n                x = self.act(self.conv1d(x)[..., :seqlen])\n            else:\n                assert self.activation in [\"silu\", \"swish\"]\n                x = causal_conv1d_fn(\n                    x=x,\n                    weight=rearrange(self.conv1d.weight, \"d 1 w -> d w\"),\n                    bias=self.conv1d.bias,\n                    activation=self.activation,\n                )\n\n            # We're careful here about the layout, to avoid extra transposes.\n            # We want dt to have d as the slowest moving dimension\n            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.\n            x_dbl = self.x_proj(rearrange(x, \"b d l -> (b l) d\"))  # (bl d)\n            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)\n            dt = self.dt_proj.weight @ dt.t()\n            dt = rearrange(dt, \"d (b l) -> b d l\", l=seqlen)\n            B = rearrange(B, \"(b l) dstate -> b dstate l\", l=seqlen).contiguous()\n            C = rearrange(C, \"(b l) dstate -> b dstate l\", l=seqlen).contiguous()\n            assert self.activation in [\"silu\", \"swish\"]\n            y = selective_scan_fn(\n                x,\n                dt,\n                A,\n                B,\n                C,\n                self.D.float(),\n                z=z,\n                delta_bias=self.dt_proj.bias.float(),\n                delta_softplus=True,\n                return_last_state=ssm_state is not None,\n            )\n            if ssm_state is not None:\n                y, last_state = y\n                ssm_state.copy_(last_state)\n       
     y = rearrange(y, \"b d l -> b l d\")\n            out = self.out_proj(y)\n        return out\n\n    def step(self, hidden_states, conv_state, ssm_state):\n        dtype = hidden_states.dtype\n        assert hidden_states.shape[1] == 1, \"Only support decoding with 1 token at a time for now\"\n        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)\n        x, z = xz.chunk(2, dim=-1)  # (B D)\n\n        # Conv step\n        if causal_conv1d_update is None:\n            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)\n            conv_state[:, :, -1] = x\n            x = torch.sum(conv_state * rearrange(self.conv1d.weight, \"d 1 w -> d w\"), dim=-1)  # (B D)\n            if self.conv1d.bias is not None:\n                x = x + self.conv1d.bias\n            x = self.act(x).to(dtype=dtype)\n        else:\n            x = causal_conv1d_update(\n                x,\n                conv_state,\n                rearrange(self.conv1d.weight, \"d 1 w -> d w\"),\n                self.conv1d.bias,\n                self.activation,\n            )\n\n        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)\n        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)\n        # Don't add dt_bias here\n        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)\n        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)\n\n        # SSM step\n        if selective_state_update is None:\n            # Discretize A and B\n            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))\n            dA = torch.exp(torch.einsum(\"bd,dn->bdn\", dt, A))\n            dB = torch.einsum(\"bd,bn->bdn\", dt, B)\n            ssm_state.copy_(ssm_state * dA + rearrange(x, \"b d -> b d 1\") * dB)\n            y = torch.einsum(\"bdn,bn->bd\", ssm_state.to(dtype), C)\n            y = y + self.D.to(dtype) * x\n            y = y * self.act(z)  # (B D)\n        else:\n            y = selective_state_update(\n     
           ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True\n            )\n\n        out = self.out_proj(y)\n        return out.unsqueeze(1), conv_state, ssm_state\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n        device = self.out_proj.weight.device\n        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype\n        conv_state = torch.zeros(\n            batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype\n        )\n        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype\n        # ssm_dtype = torch.float32\n        ssm_state = torch.zeros(\n            batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype\n        )\n        return conv_state, ssm_state\n\n    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):\n        assert self.layer_idx is not None\n        if self.layer_idx not in inference_params.key_value_memory_dict:\n            batch_shape = (batch_size,)\n            conv_state = torch.zeros(\n                batch_size,\n                self.d_model * self.expand,\n                self.d_conv,\n                device=self.conv1d.weight.device,\n                dtype=self.conv1d.weight.dtype,\n            )\n            ssm_state = torch.zeros(\n                batch_size,\n                self.d_model * self.expand,\n                self.d_state,\n                device=self.dt_proj.weight.device,\n                dtype=self.dt_proj.weight.dtype,\n                # dtype=torch.float32,\n            )\n            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)\n        else:\n            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]\n            # TODO: What if batch size changes between generation, and we reuse the same states?\n            if initialize_states:\n      
          conv_state.zero_()\n                ssm_state.zero_()\n        return conv_state, ssm_state\n"
  },
  {
    "path": "mamba_ssm/modules/mha.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\n\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\n\ntry:\n    from flash_attn import flash_attn_with_kvcache\nexcept ImportError:\n    flash_attn_with_kvcache = None\n\ntry:\n    from flash_attn.layers.rotary import RotaryEmbedding\nexcept ImportError:\n    RotaryEmbedding = None\n\ntry:\n    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update\nexcept ImportError:\n    causal_conv1d_fn, causal_conv1d_update = None, None\n\n\ndef _update_kv_cache(kv, inference_params, layer_idx):\n    \"\"\"kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)\"\"\"\n    # Pre-allocate memory for key-values for inference.\n    num_heads, head_dim = kv.shape[-2:]\n    assert layer_idx in inference_params.key_value_memory_dict\n    kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]\n    # Adjust key and value for inference\n    batch_start = inference_params.batch_size_offset\n    batch_end = batch_start + kv.shape[0]\n    sequence_start = inference_params.seqlen_offset\n    sequence_end = sequence_start + kv.shape[1]\n    assert batch_end <= kv_cache.shape[0]\n    assert sequence_end <= kv_cache.shape[1]\n    assert kv_cache is not None\n    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] 
= kv\n    return kv_cache[batch_start:batch_end, :sequence_end, ...]\n\n\nclass MHA(nn.Module):\n    \"\"\"Multi-head self-attention and cross-attention\"\"\"\n\n    def __init__(\n        self,\n        embed_dim,\n        num_heads,\n        num_heads_kv=None,\n        head_dim=None,  # If None, use embed_dim // num_heads\n        mlp_dim=0,\n        qkv_proj_bias=True,\n        out_proj_bias=True,\n        softmax_scale=None,\n        causal=False,\n        layer_idx=None,\n        d_conv=0,\n        rotary_emb_dim=0,\n        rotary_emb_base=10000.0,\n        rotary_emb_interleaved=False,\n        device=None,\n        dtype=None,\n    ) -> None:\n        \"\"\"\n        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.\n        \"\"\"\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.layer_idx = layer_idx\n        self.d_conv = d_conv\n        self.rotary_emb_dim = rotary_emb_dim\n        self.softmax_scale = softmax_scale\n        self.causal = causal\n\n        self.num_heads = num_heads\n        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads\n        assert (\n            self.num_heads % self.num_heads_kv == 0\n        ), \"num_heads must be divisible by num_heads_kv\"\n        if head_dim is None:\n            assert self.embed_dim % num_heads == 0, \"embed_dim must be divisible by num_heads\"\n        self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads\n        self.mlp_dim = math.ceil(mlp_dim / 256) * 256\n        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)\n        out_dim = self.head_dim * 
self.num_heads\n\n        if self.rotary_emb_dim > 0:\n            assert RotaryEmbedding is not None, \"rotary requires flash_attn to be installed\"\n            self.rotary_emb = RotaryEmbedding(\n                self.rotary_emb_dim,\n                base=rotary_emb_base,\n                interleaved=rotary_emb_interleaved,\n                device=device,\n            )\n\n        self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)\n        if self.d_conv > 0:\n            self.conv1d = nn.Conv1d(\n                qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,\n                **factory_kwargs\n            )\n        self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):\n        dtype = self.out_proj.weight.dtype if dtype is None else dtype\n        device = self.out_proj.weight.device\n        if self.d_conv > 0:\n            conv_state = torch.zeros(\n                batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype\n            )\n        else:\n            conv_state = None\n        kv_cache = torch.empty(\n            batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,\n        )\n        return kv_cache, conv_state\n\n    def _update_kv_cache(self, kv, inference_params):\n        \"\"\"kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)\"\"\"\n        assert self.layer_idx is not None, \"Generation requires layer_idx in the constructor\"\n        return _update_kv_cache(kv, inference_params, self.layer_idx)\n\n    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):\n        \"\"\"\n        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.\n        q: (batch_size, seqlen_q, nheads, 
head_dim)\n        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)\n        \"\"\"\n        assert inference_params is not None and inference_params.seqlen_offset > 0\n        if self.rotary_emb_dim > 0:\n            self.rotary_emb._update_cos_sin_cache(\n                inference_params.max_seqlen, device=q.device, dtype=q.dtype\n            )\n            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached\n        else:\n            rotary_cos, rotary_sin = None, None\n        batch = q.shape[0]\n        kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]\n        kv_cache = kv_cache[:batch]\n        cache_seqlens = (\n            inference_params.lengths_per_sample[:batch]\n            if inference_params.lengths_per_sample is not None\n            else inference_params.seqlen_offset\n        )\n        assert flash_attn_with_kvcache is not None, \"flash_attn must be installed\"\n        context = flash_attn_with_kvcache(\n            q,\n            kv_cache[:, :, 0],\n            kv_cache[:, :, 1],\n            kv[:, :, 0],\n            kv[:, :, 1],\n            rotary_cos=rotary_cos,\n            rotary_sin=rotary_sin,\n            cache_seqlens=cache_seqlens,\n            softmax_scale=self.softmax_scale,\n            causal=self.causal,\n            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,\n        )\n        return context\n\n    def _update_kvcache_attention(self, q, kv, inference_params):\n        \"\"\"Write kv to inference_params, then do attention\"\"\"\n        if (\n            inference_params.seqlen_offset == 0\n            or flash_attn_with_kvcache is None\n        ):\n            # TODO: this only uses seqlen_offset and not lengths_per_sample.\n            kv = self._update_kv_cache(kv, inference_params)\n            k, v = kv.unbind(dim=-3)\n            k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)\n            
v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)\n            return F.scaled_dot_product_attention(\n                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale\n            ).transpose(1, 2)\n        else:\n            batch = q.shape[0]\n            kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]\n            kv_cache = kv_cache[:batch]\n            cache_seqlens = (\n                inference_params.lengths_per_sample[:batch]\n                if inference_params.lengths_per_sample is not None\n                else inference_params.seqlen_offset\n            )\n            return flash_attn_with_kvcache(\n                q,\n                kv_cache[:, :, 0],\n                kv_cache[:, :, 1],\n                kv[:, :, 0],\n                kv[:, :, 1],\n                cache_seqlens=cache_seqlens,\n                softmax_scale=self.softmax_scale,\n                causal=self.causal,\n            )\n\n    def forward(self, x, inference_params=None):\n        \"\"\"\n        Arguments:\n            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if\n                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total\n                is the sum of the sequence lengths in the batch.\n            inference_params: for generation. 
Adapted from Megatron-LM (and Apex)\n            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470\n        \"\"\"\n        if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:\n            inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(\n                x.shape[0], inference_params.max_seqlen, dtype=x.dtype\n            )\n        seqlen_offset = (\n            0\n            if inference_params is None\n            else (\n                inference_params.lengths_per_sample\n                if inference_params.lengths_per_sample is not None\n                else inference_params.seqlen_offset\n            )\n        )\n        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None\n        qkv = self.in_proj(x)\n        if self.mlp_dim > 0:\n            qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)\n            x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)\n            x_mlp = x_mlp_up * F.silu(x_mlp_gate)\n        if self.d_conv > 0:\n            # The inference code for conv1d is pretty messy, should clean it up\n            if (inference_params is None or inference_params.seqlen_offset == 0):\n                if causal_conv1d_fn is None:\n                    qkv = rearrange(\n                        self.conv1d(rearrange(qkv, \"b s d -> b d s\"))[..., :-(self.d_conv - 1)], \"b d s -> b s d\"\n                    ).contiguous()\n                else:\n                    qkv = causal_conv1d_fn(\n                        qkv.transpose(1, 2),\n                        rearrange(self.conv1d.weight, \"d 1 w -> d w\"),\n                        self.conv1d.bias\n                    ).transpose(1, 2)\n                if inference_params is not None:\n                    _, conv_state = 
inference_params.key_value_memory_dict[self.layer_idx]\n                    # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv\n                    # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.\n                    qkv_t = rearrange(qkv, \"b l d -> b d l\")\n                    conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0)))  # Update state (B D W)\n            else:\n                _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]\n                assert qkv.shape[1] == 1, \"Only support decoding with 1 token at a time for now\"\n                qkv = qkv.squeeze(1)\n                # Conv step\n                if causal_conv1d_update is None:\n                    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)\n                    conv_state[:, :, -1] = qkv\n                    qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, \"d 1 w -> d w\"), dim=-1)  # (B D)\n                    if self.conv1d.bias is not None:\n                        qkv = qkv + self.conv1d.bias\n                else:\n                    qkv = causal_conv1d_update(\n                        qkv,\n                        conv_state,\n                        rearrange(self.conv1d.weight, \"d 1 w -> d w\"),\n                        self.conv1d.bias\n                    )\n                qkv = qkv.unsqueeze(1)\n        q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)\n        q = rearrange(q, \"... (h d) -> ... h d\", d=self.head_dim)\n        kv = rearrange(kv, \"... (two hkv d) -> ... 
two hkv d\", two=2, d=self.head_dim)\n        if (\n            inference_params is None\n            or inference_params.seqlen_offset == 0\n            or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)\n        ):\n            if self.rotary_emb_dim > 0:\n                q, kv = self.rotary_emb(\n                    q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen\n                )\n            if inference_params is None:\n                k, v = kv.unbind(dim=-3)\n                k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)\n                v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)\n                context = F.scaled_dot_product_attention(\n                    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale\n                ).transpose(1, 2)\n            else:\n                context = self._update_kvcache_attention(q, kv, inference_params)\n        else:\n            context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)\n        context = rearrange(context, \"... h d -> ... (h d)\")\n        if self.mlp_dim > 0:\n            context = torch.cat([context, x_mlp], dim=-1)\n        out = self.out_proj(context)\n        return out\n"
  },
  {
    "path": "mamba_ssm/modules/mlp.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass GatedMLP(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        activation=F.silu,\n        bias=False,\n        multiple_of=128,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        out_features = out_features if out_features is not None else in_features\n        hidden_features = (\n            hidden_features if hidden_features is not None else int(8 * in_features / 3)\n        )\n        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of\n        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)\n        self.activation = activation\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)\n\n    def forward(self, x):\n        y = self.fc1(x)\n        y, gate = y.chunk(2, dim=-1)\n        y = y * self.activation(gate)\n        y = self.fc2(y)\n        return y\n"
  },
  {
    "path": "mamba_ssm/modules/ssd_minimal.py",
    "content": "# Copyright (c) 2024, Albert Gu and Tri Dao.\n\"\"\"Minimal implementation of SSD.\n\nThis is the same as Listing 1 from the paper.\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\nfrom mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined\n\n\ndef segsum_unstable(x):\n    \"\"\"Naive segment sum calculation.\"\"\"\n    T = x.size(-1)\n    x_cumsum = torch.cumsum(x, dim=-1)\n    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]\n    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)\n    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)\n    return x_segsum\n\ndef segsum(x):\n    \"\"\"More stable segment sum calculation.\"\"\"\n    T = x.size(-1)\n    x = repeat(x, \"... d -> ... d e\", e=T)\n    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)\n    x = x.masked_fill(~mask, 0)\n    x_segsum = torch.cumsum(x, dim=-2)\n    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)\n    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)\n    return x_segsum\n\ndef ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):\n    \"\"\"\n    Arguments:\n        X: (batch, length, n_heads, d_head)\n        A: (batch, length, n_heads)\n        B: (batch, length, n_heads, d_state)\n        C: (batch, length, n_heads, d_state)\n    Return:\n        Y: (batch, length, n_heads, d_head)\n    \"\"\"\n    assert X.dtype == A.dtype == B.dtype == C.dtype\n    assert X.shape[1] % block_len == 0\n\n    # Rearrange into blocks/chunks\n    X, A, B, C = [rearrange(x, \"b (c l) ... -> b c l ...\", l=block_len) for x in (X, A, B, C)]\n\n    A = rearrange(A, \"b c l h -> b h c l\")\n    A_cumsum = torch.cumsum(A, dim=-1)\n\n    # 1. Compute the output for each intra-chunk (diagonal blocks)\n    L = torch.exp(segsum(A))\n    Y_diag  = torch.einsum(\"bclhn,bcshn,bhcls,bcshp->bclhp\", C, B, L, X)\n\n    # 2. 
Compute the state for each intra-chunk\n    # (right term of low-rank factorization of off-diagonal blocks; B terms)\n    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))\n    states = torch.einsum(\"bclhn,bhcl,bclhp->bchpn\", B, decay_states, X)\n\n    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries\n    # (middle term of factorization of off-diag blocks; A terms)\n    if initial_states is None:\n        initial_states = torch.zeros_like(states[:, :1])\n    states = torch.cat([initial_states, states], dim=1)\n    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))\n    new_states = torch.einsum(\"bhzc,bchpn->bzhpn\", decay_chunk, states)\n    states, final_state = new_states[:, :-1], new_states[:, -1]\n\n    # 4. Compute state -> output conversion per chunk\n    # (left term of low-rank factorization of off-diagonal blocks; C terms)\n    state_decay_out = torch.exp(A_cumsum)\n    Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)\n\n    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)\n    Y = rearrange(Y_diag+Y_off, \"b c l h p -> b (c l) h p\")\n    return Y, final_state\n\n\n# Simple test\ndef test_correctness():\n    torch.manual_seed(42)\n\n    ## Dimensions\n    # Denoted (B, T, Q, D, P) in the paper\n    batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64\n    nheads = dim // headdim  # (H) in the paper\n    ngroups = 1 # (G) in the paper\n    dstate = 64  # (N) in the paper\n    dtype = torch.float32\n    device = \"cuda\"\n\n    x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)\n    dt = F.softplus(torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4).requires_grad_()\n    A = (-torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))).requires_grad_()\n    B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)\n    C = 
torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)\n    D = torch.randn(nheads, dtype=dtype, device=device)\n\n    # Comparing fused version and minimal version\n    y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)\n    y_min, _ = ssd_minimal_discrete(x*dt.unsqueeze(-1), A*dt, B, C, chunk_size)\n"
  },
  {
    "path": "mamba_ssm/ops/__init__.py",
    "content": ""
  },
  {
    "path": "mamba_ssm/ops/cute/mamba3/mamba3_step_fn.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n# Modified to use tvm-ffi and fake tensors instead of dlpack.\n# Modified to optionally update state in place (state_out=None) or write to separate state_out.\n\nimport math\nfrom typing import Optional, Type, Literal, List\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nimport cuda.bindings.driver as cuda\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Int32, Float32, Float16, BFloat16, Boolean, const_expr\n\nfrom quack.compile_utils import make_fake_tensor\nfrom quack.cute_dsl_utils import torch2cute_dtype_map\n\n\ndef transpose_view(a: cute.Tensor) -> cute.Tensor:\n    \"\"\"Transpose the first two dimensions of a tensor on smem.\"\"\"\n    shape = (a.shape[1], a.shape[0], *a.shape[2:])\n    order = (1, 0, *range(2, cute.rank(a)))\n    return cute.composition(a, cute.make_ordered_layout(shape, order=order))\n\ndef select(a: cute.Tensor, mode: List[int]) -> cute.Tensor:\n    return cute.make_tensor(a.iterator, cute.select(a.layout, mode))\n\n\n\ndef get_gmem_tiled_copy(dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = True):\n    num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width\n    copy_elems = num_copy_bits // dtype.width\n    copy_op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()\n    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)\n    gmem_threads_per_row = major_mode_size // copy_elems\n    assert num_threads % gmem_threads_per_row == 0\n    thr_layout = cute.make_ordered_layout(\n        (num_threads // gmem_threads_per_row, gmem_threads_per_row),\n        order=(1, 0),\n    )\n    val_layout = cute.make_layout((1, copy_elems))\n    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)\n\n\nclass Mamba3Step():\n    def __init__(self, tile_D: int, dstate: int, mimo: int = 1, num_warps: int = 4, remove_gate: bool = False, 
remove_outproj: bool = False):\n        assert num_warps >= 2\n        assert dstate % 8 == 0, \"dstate must be multiple of 8\" # for vectorized load /store\n        self.tile_D = tile_D\n        self.dstate = dstate\n        self.mimo = mimo\n        self.num_warps = num_warps\n        self.remove_gate = remove_gate\n        self.remove_outproj = remove_outproj\n\n    def _setup_smem_layouts(self):\n        self.sState_layout = cute.make_ordered_layout((self.tile_D, self.dstate), order=(1, 0))\n        # We don't need any swizzling for Bstate, B, C\n        self.sBC_layout = cute.make_ordered_layout((self.mimo, self.dstate), order=(1, 0))\n        # We don't need any swizzling for Xproj, Zproj, Outproj\n        self.sProj_layout = cute.make_ordered_layout((self.mimo, self.tile_D), order=(1, 0))\n\n    def _setup_gmem_tiled_copy(self, ):\n        num_threads = self.num_warps * cute.arch.WARP_SIZE\n        self.gmem_tiled_copy_state = get_gmem_tiled_copy(self.dtype, self.dstate, num_threads)\n        self.gmem_tiled_copy_BC = get_gmem_tiled_copy(self.b_dtype, self.dstate, num_threads)\n        self.gmem_tiled_copy_Proj = get_gmem_tiled_copy(self.proj_dtype, self.tile_D, num_threads)\n        # Gmem tiled copy for X, Z\n        # e.g. 
for tile_D = 64, we only want each thread loading 2 values\n        copy_elems_x = const_expr(min(4, cute.ceil_div(self.tile_D, cute.arch.WARP_SIZE)))\n        num_copy_bits_x = copy_elems_x * self.x_dtype.width\n        copy_atom_load_x = cute.make_copy_atom(\n            cute.nvgpu.CopyUniversalOp(), self.x_dtype, num_bits_per_copy=num_copy_bits_x\n        )\n        gmem_threads_per_row = self.tile_D // copy_elems_x\n        assert cute.arch.WARP_SIZE >= gmem_threads_per_row   # Only 1 warp loads X, Z\n        self.gmem_tiled_copy_X = cute.make_tiled_copy_tv(\n            copy_atom_load_x, cute.make_layout(self.tile_D // copy_elems_x), cute.make_layout(copy_elems_x)\n        )\n\n\n    @cute.jit\n    def __call__(\n        # B: batch size, H: num heads, D: head dim, N: dstate, R: mimo\n        self,\n        mState: cute.Tensor,  # (B, H, D, N)\n        mBstate: cute.Tensor,  # (B, R, H, N)\n        mXstate: cute.Tensor,  # (B, H, D)\n        mA: cute.Tensor,  # (B, H)\n        mB: cute.Tensor,  # (B, R, H, N)\n        mC: cute.Tensor,  # (B, R, H, N)\n        mD: cute.Tensor,  # (H)\n        mX: cute.Tensor,  # (B, H, D)\n        mDt: cute.Tensor,  # (B, H)\n        mTrap: cute.Tensor,  # (B, H)\n        mXproj: cute.Tensor,  # (R, H, D)\n        mOutproj: Optional[cute.Tensor],  # (R, H, D), None if remove_outproj\n        mStateOut: cute.Tensor,  # (B, H, D, N) — same as mState for in-place, or separate\n        mOut: cute.Tensor,  # (B, H, D) or (B, R, H, D) if remove_outproj\n        mZ: Optional[cute.Tensor],  # (B, H, D), None if remove_gate\n        mZproj: Optional[cute.Tensor],  # (R, H, D), None if remove_gate\n        stream: cuda.CUstream,\n    ):\n        self.dtype = mState.element_type\n        self.b_dtype = mB.element_type\n        self.proj_dtype = mXproj.element_type\n        self.x_dtype = mX.element_type\n        assert mStateOut.element_type == self.dtype\n        assert mBstate.element_type == mB.element_type == mC.element_type\n        
if const_expr(mOutproj is not None):\n            assert mXproj.element_type == mOutproj.element_type\n        if const_expr(mZ is not None):\n            assert mXproj.element_type == mZproj.element_type\n            assert mZ.element_type == self.x_dtype\n\n        self._setup_smem_layouts()\n        self._setup_gmem_tiled_copy()\n\n        # TV layout, this is the most important step as it decides which elements in B, C, State\n        # each thread will load from smem\n        num_threads = self.num_warps * cute.arch.WARP_SIZE\n        # TODO: these need to be adjusted based on dstate and tile_D\n        assert self.dstate in [32, 64, 128]\n        # TODO: This is not optimal for dstate=32 and 64, just to get sth quick to run\n        vecsize_dstate = 4 if self.dstate == 128 else 2 if self.dstate == 64 else 1\n        threads_per_dstate = self.dstate // vecsize_dstate\n        assert cute.arch.WARP_SIZE % threads_per_dstate == 0\n        num_groups = num_threads // threads_per_dstate\n        assert self.tile_D % num_groups == 0\n        lanes_per_D = self.tile_D // num_groups\n        copy_atom_state_s2r = cute.make_copy_atom(\n            cute.nvgpu.CopyUniversalOp(), mState.element_type, num_bits_per_copy=vecsize_dstate * mState.element_type.width\n        )\n        tiled_copy_state_s2r = cute.make_tiled_copy_tv(\n            copy_atom_state_s2r,\n            cute.make_ordered_layout((num_groups, threads_per_dstate), order=(1, 0)),\n            cute.make_ordered_layout((lanes_per_D, vecsize_dstate), order=(1, 0)),\n        )\n        copy_atom_B_s2r = cute.make_copy_atom(\n            cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=vecsize_dstate * mB.element_type.width\n        )\n        tiled_copy_B_s2r = cute.make_tiled_copy_tv(\n            copy_atom_B_s2r,\n            cute.make_ordered_layout((1, threads_per_dstate), order=(1, 0)),\n            cute.make_ordered_layout((1, vecsize_dstate), order=(1, 0)),\n        )\n\n        
self.buffer_align_bytes = 1024\n\n        sZproj_size = cute.cosize(self.sProj_layout) if not self.remove_gate else 0\n        sOutproj_size = cute.cosize(self.sProj_layout) if not self.remove_outproj else 0\n\n        @cute.struct\n        class SharedStorage:\n            sX: cute.struct.Align[cute.struct.MemRange[Float32, self.tile_D], 128]\n            sXgamma: cute.struct.Align[cute.struct.MemRange[Float32, self.tile_D], 128]\n            sXstate: cute.struct.Align[cute.struct.MemRange[Float32, self.tile_D], 128]\n            sState: cute.struct.Align[\n                cute.struct.MemRange[self.dtype, cute.cosize(self.sState_layout)],\n                self.buffer_align_bytes,\n            ]\n            sBstate: cute.struct.Align[\n                cute.struct.MemRange[self.b_dtype, cute.cosize(self.sBC_layout)],\n                self.buffer_align_bytes,\n            ]\n            sB: cute.struct.Align[\n                cute.struct.MemRange[self.b_dtype, cute.cosize(self.sBC_layout)],\n                self.buffer_align_bytes,\n            ]\n            sC: cute.struct.Align[\n                cute.struct.MemRange[self.b_dtype, cute.cosize(self.sBC_layout)],\n                self.buffer_align_bytes,\n            ]\n            sXproj: cute.struct.Align[\n                cute.struct.MemRange[self.proj_dtype, cute.cosize(self.sProj_layout)],\n                self.buffer_align_bytes,\n            ]\n            sZproj: cute.struct.Align[\n                cute.struct.MemRange[self.proj_dtype, sZproj_size],\n                self.buffer_align_bytes,\n            ]\n            sOutproj: cute.struct.Align[\n                cute.struct.MemRange[self.proj_dtype, sOutproj_size],\n                self.buffer_align_bytes,\n            ]\n\n        self.shared_storage = SharedStorage\n\n        self.kernel(\n            mState,\n            mBstate,\n            mXstate,\n            mA,\n            mB,\n            mC,\n            mD,\n            mX,\n            mDt,\n 
           mTrap,\n            mXproj,\n            mOutproj,\n            mStateOut,\n            mOut,\n            mZ,\n            mZproj,\n            self.sState_layout,\n            self.sBC_layout,\n            self.sProj_layout,\n            self.gmem_tiled_copy_state,\n            self.gmem_tiled_copy_BC,\n            self.gmem_tiled_copy_Proj,\n            self.gmem_tiled_copy_X,\n            tiled_copy_state_s2r,\n            tiled_copy_B_s2r,\n            vecsize_dstate,\n        ).launch(\n            # grid: (d, h, b)\n            grid=[cute.ceil_div(mState.shape[2], self.tile_D), mState.shape[1], mState.shape[0]],\n            block=[num_threads, 1, 1],\n            stream=stream,\n        )\n\n    @cute.kernel\n    def kernel(\n        self,\n        mState: cute.Tensor,  # (B, H, D, N)\n        mBstate: cute.Tensor,  # (B, R, H, N)\n        mXstate: cute.Tensor,  # (B, H, D)\n        mA: cute.Tensor,  # (B, H)\n        mB: cute.Tensor,  # (B, R, H, N)\n        mC: cute.Tensor,  # (B, R, H, N)\n        mD: cute.Tensor,  # (H)\n        mX: cute.Tensor,  # (B, H, D)\n        mDt: cute.Tensor,  # (B, H)\n        mTrap: cute.Tensor,  # (B, H)\n        mXproj: cute.Tensor,  # (R, H, D)\n        mOutproj: Optional[cute.Tensor],  # (R, H, D), None if remove_outproj\n        mStateOut: cute.Tensor,  # (B, H, D, N)\n        mOut: cute.Tensor,  # (B, H, D) or (B, R, H, D) if remove_outproj\n        mZ: Optional[cute.Tensor],  # (B, H, D), None if remove_gate\n        mZproj: Optional[cute.Tensor],  # (R, H, D), None if remove_gate\n        sState_layout: cute.Layout | cute.ComposedLayout,\n        sBC_layout: cute.Layout | cute.ComposedLayout,\n        sProj_layout: cute.Layout | cute.ComposedLayout,\n        gmem_tiled_copy_state: cute.TiledCopy,\n        gmem_tiled_copy_BC: cute.TiledCopy,\n        gmem_tiled_copy_Proj: cute.TiledCopy,\n        gmem_tiled_copy_X: cute.TiledCopy,\n        tiled_copy_state_s2r: cute.TiledCopy,\n        tiled_copy_B_s2r: 
cute.TiledCopy,\n        vecsize_dstate: cutlass.Constexpr[int],\n    ):\n        tidx, _, _ = cute.arch.thread_idx()\n        bidd, bidh, bidb = cute.arch.block_idx()\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())\n        lane_idx = cute.arch.lane_idx()\n\n        limit_d = mState.shape[2]\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        #  Slice for CTA\n        # ///////////////////////////////////////////////////////////////////////////////\n        # (tile_D, N)\n        gState, gStateOut = [\n            cute.local_tile(t[bidb, bidh, None, None], (self.tile_D, self.dstate), (bidd, 0))\n            for t in (mState, mStateOut)\n        ]\n        # (R, N)\n        gBstate, gB, gC = [\n            cute.local_tile(t[bidb, None, bidh, None], (self.mimo, self.dstate), (0, 0))\n            for t in (mBstate, mB, mC)\n        ]\n        # (tile_D,)\n        gXstate, gX = [\n            cute.local_tile(t[bidb, bidh, None], (self.tile_D,), (bidd,))\n            for t in (mXstate, mX)\n        ]\n        if const_expr(mOutproj is not None):\n            # Output is (B, H, D), outproj reduces MIMO rank\n            gOut = cute.local_tile(mOut[bidb, bidh, None], (self.tile_D,), (bidd,))\n            gXproj = cute.local_tile(mXproj[None, bidh, None], (self.mimo, self.tile_D), (0, bidd))\n            gOutproj = cute.local_tile(mOutproj[None, bidh, None], (self.mimo, self.tile_D), (0, bidd))\n        else:\n            # Output is (B, R, H, D), no outproj reduction\n            gXproj = cute.local_tile(mXproj[None, bidh, None], (self.mimo, self.tile_D), (0, bidd))\n            gOutproj = None\n        if const_expr(mZ is not None):\n            gZ = cute.local_tile(mZ[bidb, bidh, None], (self.tile_D,), (bidd,))\n            gZproj = cute.local_tile(mZproj[None, bidh, None], (self.mimo, self.tile_D), (0, bidd))\n\n        # 
///////////////////////////////////////////////////////////////////////////////\n        #  Generate smem tensors\n        # ///////////////////////////////////////////////////////////////////////////////\n        smem = cutlass.utils.SmemAllocator()\n        storage = smem.allocate(self.shared_storage)\n\n        sState = storage.sState.get_tensor(sState_layout)\n        sBstate = storage.sBstate.get_tensor(sBC_layout)\n        sB = storage.sB.get_tensor(sBC_layout)\n        sC = storage.sC.get_tensor(sBC_layout)\n        sXproj = storage.sXproj.get_tensor(sProj_layout)\n        sZproj = storage.sZproj.get_tensor(sProj_layout) if const_expr(mZ is not None) else None\n        sOutproj = storage.sOutproj.get_tensor(sProj_layout) if const_expr(mOutproj is not None) else None\n        sXstate = storage.sXstate.get_tensor(cute.make_layout(self.tile_D))\n        sX = storage.sX.get_tensor(cute.make_layout(self.tile_D))\n        sXgamma = storage.sXgamma.get_tensor(cute.make_layout(self.tile_D))\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        #  Partitioning using copy atoms\n        # ///////////////////////////////////////////////////////////////////////////////\n        gmem_thr_copy_state = gmem_tiled_copy_state.get_slice(tidx)\n        # copying states from r2g uses the same tiled copy as s2r\n        gmem_thr_copy_StateOut = tiled_copy_state_s2r.get_slice(tidx)\n        gmem_thr_copy_BC = gmem_tiled_copy_BC.get_slice(tidx)\n        gmem_thr_copy_Proj = gmem_tiled_copy_Proj.get_slice(tidx)\n        gmem_thr_copy_X = gmem_tiled_copy_X.get_slice(lane_idx)  # Only 1 warp loads X, Z\n\n        tSgS = gmem_thr_copy_state.partition_S(gState)\n        tSsS_g2s = gmem_thr_copy_state.partition_D(sState)\n        tSgSOut = gmem_thr_copy_StateOut.partition_D(gStateOut)\n        tBCgBstate = gmem_thr_copy_BC.partition_S(gBstate)\n        tBCsBstate = gmem_thr_copy_BC.partition_D(sBstate)\n        tBCgB = 
gmem_thr_copy_BC.partition_S(gB)\n        tBCsB = gmem_thr_copy_BC.partition_D(sB)\n        tBCgC = gmem_thr_copy_BC.partition_S(gC)\n        tBCsC = gmem_thr_copy_BC.partition_D(sC)\n        tPgXproj = gmem_thr_copy_Proj.partition_S(gXproj)\n        tPsXproj = gmem_thr_copy_Proj.partition_D(sXproj)\n        if const_expr(mZ is not None):\n            tPgZproj = gmem_thr_copy_Proj.partition_S(gZproj)\n            tPsZproj = gmem_thr_copy_Proj.partition_D(sZproj)\n        if const_expr(mOutproj is not None):\n            tPgOutproj = gmem_thr_copy_Proj.partition_S(gOutproj)\n            tPsOutproj = gmem_thr_copy_Proj.partition_D(sOutproj)\n        tXgX = gmem_thr_copy_X.partition_S(gX)\n        tXsX = gmem_thr_copy_X.partition_D(sX)\n        tXsXgamma = gmem_thr_copy_X.partition_D(sXgamma)\n        tXgXstate = gmem_thr_copy_X.partition_S(gXstate)\n        tXsXstate = gmem_thr_copy_X.partition_D(sXstate)\n\n        # Idk why this order of threads_per_dstate and num_groups are reversed\n        threads_per_dstate, num_groups = tiled_copy_state_s2r.layout_tv_tiled[0].shape\n        lanes_per_D = self.tile_D // num_groups\n\n        # For bound checking\n        cS = cute.make_identity_tensor((self.tile_D, self.dstate))\n        tScS = gmem_thr_copy_state.partition_S(cS)\n        cBC = cute.make_identity_tensor((self.mimo, self.dstate))\n        tBCcBC = gmem_thr_copy_BC.partition_S(cBC)\n        cProj = cute.make_identity_tensor((self.mimo, self.tile_D))\n        tPcProj = gmem_thr_copy_Proj.partition_S(cProj)\n\n        A_val = Float32(mA[bidb, bidh])\n        dt_val = Float32(mDt[bidb, bidh])\n        trap_val = Float32(mTrap[bidb, bidh])\n\n        # Load X and Xstate, these are small so we want to kick them off first\n        tXrX = cute.make_fragment_like(tXgX)\n        tXrXstate = cute.make_fragment_like(tXgXstate)\n        copy_elems_x = cute.size(tXgX.shape[0][0])\n        assert cute.size(tXgX.shape) == copy_elems_x  # Only 1 load instruction\n        
num_loads_X = const_expr(self.tile_D // copy_elems_x)\n        need_bound_check_X = const_expr(cute.arch.WARP_SIZE > num_loads_X)\n        if warp_idx == 0:\n            if not need_bound_check_X or lane_idx < num_loads_X:\n                cute.copy(gmem_tiled_copy_X, tXgX, tXrX)\n        if warp_idx == 1:\n            if not need_bound_check_X or lane_idx < num_loads_X:\n                cute.copy(gmem_tiled_copy_X, tXgXstate, tXrXstate)\n\n        # Load Bstate, B, Xproj need bound checking\n        for m in cutlass.range(cute.size(tBCcBC.shape[1]), unroll_full=True):\n            if tBCcBC[0, m, 0][0] < self.mimo:\n                cute.copy(gmem_tiled_copy_BC, tBCgBstate[None, m, None], tBCsBstate[None, m, None])\n                cute.copy(gmem_tiled_copy_BC, tBCgB[None, m, None], tBCsB[None, m, None])\n        for m in cutlass.range(cute.size(tPcProj.shape[1]), unroll_full=True):\n            if tPcProj[0, m, 0][0] < self.mimo:\n                cute.copy(gmem_tiled_copy_Proj, tPgXproj[None, m, None], tPsXproj[None, m, None])\n        cute.arch.cp_async_commit_group()\n\n        # Load State, not doing any bound check for now\n        cute.copy(gmem_tiled_copy_state, tSgS, tSsS_g2s)\n        cute.arch.cp_async_commit_group()\n\n        alpha_val = cute.arch.exp(A_val * dt_val)\n        # Transform X and Xstate by multiplying with gamma and beta, then write to smem\n        if warp_idx == 0:\n            tXrX_f32 = cute.make_fragment_like(tXrX, Float32)\n            tXrX_f32.store(tXrX.load().to(Float32))\n            if not need_bound_check_X or lane_idx < num_loads_X:\n                cute.autovec_copy(tXrX_f32, tXsX)\n            gamma_val = trap_val * dt_val\n            tXrX_f32.store(tXrX_f32.load() * gamma_val)\n            if not need_bound_check_X or lane_idx < num_loads_X:\n                cute.autovec_copy(tXrX_f32, tXsXgamma)\n        if warp_idx == 1:\n            beta_val = (1.0 - trap_val) * dt_val * alpha_val\n            tXrXstate_f32 = 
cute.make_fragment_like(tXgXstate, Float32)\n            tXrXstate_f32.store(tXrXstate.load().to(Float32) * beta_val)\n            if not need_bound_check_X or lane_idx < num_loads_X:\n                cute.autovec_copy(tXrXstate_f32, tXsXstate)\n\n        # Load C, need bound checking\n        for m in cutlass.range(cute.size(tBCcBC.shape[1]), unroll_full=True):\n            if tBCcBC[0, m, 0][0] < self.mimo:\n                cute.copy(gmem_tiled_copy_BC, tBCgC[None, m, None], tBCsC[None, m, None])\n        cute.arch.cp_async_commit_group()\n\n        cute.arch.cp_async_wait_group(2)  # B, Bstate, Xproj are done loading\n        cute.arch.sync_threads()\n        # Load B, Bstate, Xproj from smem\n        smem_thr_copy_B = tiled_copy_B_s2r.get_slice(tidx % threads_per_dstate)\n        # ((vecsize_dstate, 1), mimo, 1) -> ((vecsize_dstate, 1), mimo)\n        tSsB = smem_thr_copy_B.partition_S(sB)[None, None, 0]\n        tSsBstate = smem_thr_copy_B.partition_S(sBstate)[None, None, 0]\n        tSrB = cute.make_fragment_like(tSsB)\n        tSrBstate = cute.make_fragment_like(tSsBstate)\n        cute.autovec_copy(tSsB, tSrB)\n        cute.autovec_copy(tSsBstate, tSrBstate)\n        tSrB_f32 = cute.make_fragment_like(tSrB, Float32)\n        tSrB_f32.store(tSrB.load().to(Float32))\n        tSrBstate_f32 = cute.make_fragment_like(tSrBstate, Float32)\n        tSrBstate_f32.store(tSrBstate.load().to(Float32))\n        # Loading x and xstate, at most 1 val per thread\n        x_val = Float32(0.0)\n        if lane_idx < lanes_per_D:\n            # TODO: should this be warp_idx or group_idx?\n            x_val = sXgamma[warp_idx * lanes_per_D + lane_idx]\n        x_state_val = Float32(0.0)\n        if lane_idx < lanes_per_D:\n            x_state_val = sXstate[warp_idx * lanes_per_D + lane_idx]\n\n        new_state = cute.make_fragment((vecsize_dstate, lanes_per_D), Float32)\n        for r in cutlass.range_constexpr(self.mimo):\n            x_proj_val = Float32(0.0)\n            
if lane_idx < lanes_per_D:\n                x_proj_val = Float32(sXproj[r, warp_idx * lanes_per_D + lane_idx])\n            x_gamma_x_proj_val = x_val * x_proj_val\n            x_state_x_proj_val = x_state_val * x_proj_val\n            for d in cutlass.range(lanes_per_D, unroll_full=True):\n                xg = cute.arch.shuffle_sync(x_gamma_x_proj_val, d)\n                xb = cute.arch.shuffle_sync(x_state_x_proj_val, d)\n                for v in cutlass.range(vecsize_dstate, unroll_full=True):\n                    if const_expr(r == 0):\n                        new_state[v, d] = xg * tSrB_f32[v, r]\n                    else:\n                        new_state[v, d] += xg * tSrB_f32[v, r]\n                    new_state[v, d] += xb * tSrBstate_f32[v, r]\n\n        cute.arch.cp_async_wait_group(1)  # state is done loading\n        cute.arch.sync_threads()\n        thr_copy_state_s2r = tiled_copy_state_s2r.get_slice(tidx)\n        # ((vecsize_state, lanes_per_D), 1, 1)\n        tSsS = thr_copy_state_s2r.partition_S(sState)\n        tSrS = cute.make_fragment_like(tSsS)\n        cute.autovec_copy(tSsS, tSrS)\n\n        # ((vecsize_state, lanes_per_D), 1, 1)\n        # tSrS_f32 = cute.make_fragment_like(tSrS, Float32)\n        tSrS_f32 = cute.make_fragment(((vecsize_dstate, 1), lanes_per_D, 1), Float32)\n        assert cute.size(tSrS.shape) == cute.size(tSrS_f32.shape)\n        tSrS_f32.store(tSrS.load().to(Float32))\n        for v in cutlass.range(cute.size(tSrS_f32), unroll_full=True):\n            tSrS_f32[v] = tSrS_f32[v] * alpha_val + new_state[v]\n        tSrS.store(tSrS_f32.load().to(self.dtype))\n\n        # Load Z from gmem -> rmem, it's small, at most 1 val per thread\n        if const_expr(mZ is not None):\n            z_val = Float32(0.0)\n            if lane_idx < lanes_per_D:\n                z_val = Float32(gZ[warp_idx * lanes_per_D + lane_idx])\n        # Load Zproj and Outproj, need bound checking\n        for m in 
cutlass.range(cute.size(tPcProj.shape[1]), unroll_full=True):\n            if tPcProj[0, m, 0][0] < self.mimo:\n                if const_expr(mZ is not None):\n                    cute.copy(gmem_tiled_copy_Proj, tPgZproj[None, m, None], tPsZproj[None, m, None])\n                if const_expr(mOutproj is not None):\n                    cute.copy(gmem_tiled_copy_Proj, tPgOutproj[None, m, None], tPsOutproj[None, m, None])\n        cute.arch.cp_async_commit_group()\n\n        # Write state back to StateOut (may be same memory as State for in-place)\n        cute.copy(tiled_copy_state_s2r, tSrS, tSgSOut)\n\n        # Do state @ C\n        cute.arch.cp_async_wait_group(1)  # C is done loading\n        cute.arch.sync_threads()\n        # ((vecsize_dstate, 1), mimo, 1) -> ((vecsize_dstate, 1), 1, mimo)\n        tSsC = select(smem_thr_copy_B.partition_S(sC), mode=[0, 2, 1])\n        tSrC = cute.make_fragment_like(tSsC)\n        cute.autovec_copy(tSsC, tSrC)\n        tSrC_f32 = cute.make_fragment_like(tSrC, Float32)\n        tSrC_f32.store(tSrC.load().to(Float32))\n        out_expanded = cute.make_fragment((lanes_per_D, self.mimo), Float32)\n        # tSrS_f32 has shape ((vecsize_dstate, 1), lanes_per_D, 1)\n        # tSrC has shape ((vecsize_dstate, 1), mimo)\n        out_expanded.store(\n            (tSrS_f32.load() * tSrC_f32.load()).reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, None))\n        )\n        assert lanes_per_D <= threads_per_dstate\n        for d in cutlass.range(lanes_per_D, unroll_full=True):\n            for r in cutlass.range(self.mimo, unroll_full=True):\n                out_expanded[d, r] += cute.arch.shuffle_sync_bfly(out_expanded[d, r], offset=16)\n        for i in cutlass.range_constexpr(int(math.log2(lanes_per_D))):\n            step = 1 << (int(math.log2(lanes_per_D)) - 1 - i)\n            should_swap = not Boolean(lane_idx & step)\n            for j in cutlass.range_constexpr(step):\n                for r in 
cutlass.range(self.mimo, unroll_full=True):\n                    lower, upper = out_expanded[j, r], out_expanded[j + step, r]\n                    out_expanded[j, r] = upper if should_swap else lower\n                    out_expanded[j + step, r] = lower if should_swap else upper\n                    shfl_val = cute.arch.shuffle_sync_bfly(out_expanded[j, r], offset=step)\n                    out_expanded[j, r] = shfl_val + out_expanded[j + step, r]\n        # After this, the out values are just out_expanded[0, None]\n        out = out_expanded[0, None]  # (mimo,)\n\n        # Add D * x * x_proj to out\n        D_val = Float32(mD[bidh])\n        x_val = Float32(0.0)\n        if lane_idx < lanes_per_D:\n            x_val = sX[warp_idx * lanes_per_D + lane_idx]\n        for r in cutlass.range_constexpr(self.mimo):\n            x_proj_val = Float32(0.0)\n            if lane_idx < lanes_per_D:\n                x_proj_val = Float32(sXproj[r, warp_idx * lanes_per_D + lane_idx])\n            out[r] += D_val * x_val * x_proj_val\n\n        cute.arch.cp_async_wait_group(0)  # Zproj and Outproj are done loading\n        cute.arch.sync_threads()\n\n        if const_expr(mOutproj is not None):\n            # Gate: z_r * sigmoid(z_r)\n            if const_expr(mZ is not None):\n                for r in cutlass.range_constexpr(self.mimo):\n                    z_proj_val = Float32(0.0)\n                    if lane_idx < lanes_per_D:\n                        z_proj_val = Float32(sZproj[r, warp_idx * lanes_per_D + lane_idx])\n                    z_r_half = 0.5 * (z_val * z_proj_val)\n                    z_r_silu = z_r_half * cute.math.tanh(z_r_half, fastmath=True) + z_r_half\n                    out[r] *= z_r_silu\n\n            # Final projection along mimo dim\n            out_val = Float32(0.0)\n            for r in cutlass.range_constexpr(self.mimo):\n                out_proj_val = Float32(0.0)\n                if lane_idx < lanes_per_D:\n                    out_proj_val = 
Float32(sOutproj[r, warp_idx * lanes_per_D + lane_idx])\n                if const_expr(r == 0):\n                    out_val = out[r] * out_proj_val\n                else:\n                    out_val += out[r] * out_proj_val\n\n            # Write final output (B, H, D)\n            if lane_idx < lanes_per_D:\n                gOut[warp_idx * lanes_per_D + lane_idx] = out_val.to(mOut.element_type)\n        else:\n            # No outproj: write per-rank output (B, R, H, D)\n            for r in cutlass.range_constexpr(self.mimo):\n                gOut_r = cute.local_tile(mOut[bidb, r, bidh, None], (self.tile_D,), (bidd,))\n                if lane_idx < lanes_per_D:\n                    gOut_r[warp_idx * lanes_per_D + lane_idx] = out[r].to(mOut.element_type)\n\n\ndef mamba3_step_fn(\n    # B: batch size, H: num heads, D: head dim, N: dstate, R: mimo\n    state: Tensor,  # (B, H, D, N) — updated in place if state_out is None\n    Bstate: Tensor,  # (B, R, H, N)\n    Xstate: Tensor,  # (B, H, D)\n    A: Tensor,  # (B, H)\n    B: Tensor,  # (B, R, H, N)\n    C: Tensor,  # (B, R, H, N)\n    D: Tensor,  # (H)\n    x: Tensor,  # (B, H, D)\n    dt: Tensor,  # (B, H)\n    trap: Tensor,  # (B, H)\n    xproj: Tensor,  # (R, H, D)\n    outproj: Optional[Tensor] = None,  # (R, H, D), None if remove_outproj\n    state_out: Optional[Tensor] = None,  # (B, H, D, N), None for in-place update\n    out: Tensor = None,  # (B, H, D) or (B, R, H, D) if remove_outproj\n    z: Optional[Tensor] = None,  # (B, H, D), None if remove_gate\n    zproj: Optional[Tensor] = None,  # (R, H, D), None if remove_gate\n    tile_D: int = 64,\n    num_warps: int = 2,\n) -> None:\n    has_z = z is not None\n    has_outproj = outproj is not None\n    inplace = state_out is None\n    batch, nheads, hdim, dstate = state.shape\n    mimo = Bstate.shape[1]\n    assert state.shape == (batch, nheads, hdim, dstate)\n    assert Bstate.shape == (batch, mimo, nheads, dstate)\n    assert Xstate.shape == (batch, 
nheads, hdim)\n    assert A.shape == (batch, nheads)\n    assert B.shape == (batch, mimo, nheads, dstate)\n    assert C.shape == (batch, mimo, nheads, dstate)\n    assert D.shape == (nheads,)\n    assert x.shape == (batch, nheads, hdim)\n    if has_z:\n        assert z.shape == (batch, nheads, hdim)\n        assert zproj is not None\n        assert zproj.shape == (mimo, nheads, hdim)\n    assert dt.shape == (batch, nheads)\n    assert trap.shape == (batch, nheads)\n    assert xproj.shape == (mimo, nheads, hdim)\n    xproj = xproj.contiguous()\n    if has_outproj:\n        assert outproj.shape == (mimo, nheads, hdim)\n        assert out.shape == (batch, nheads, hdim)\n    else:\n        assert out.shape == (batch, mimo, nheads, hdim)\n\n    # Use state itself as output target when in-place\n    if inplace:\n        state_out = state\n    else:\n        assert state_out.shape == (batch, nheads, hdim, dstate)\n\n    required_tensors = [state, Bstate, Xstate, A, B, C, D, x, dt, trap, xproj, state_out, out]\n    if has_outproj:\n        required_tensors.append(outproj)\n    if has_z:\n        required_tensors.extend([z, zproj])\n    assert all(t.is_cuda for t in required_tensors)\n    assert state.dtype in [torch.float16, torch.bfloat16, torch.float32], \"Unsupported input dtype\"\n\n    # Map torch dtypes to cutlass dtypes\n    state_cute_dtype = torch2cute_dtype_map[state.dtype]\n    b_cute_dtype = torch2cute_dtype_map[Bstate.dtype]\n    x_cute_dtype = torch2cute_dtype_map[x.dtype]\n    proj_cute_dtype = torch2cute_dtype_map[xproj.dtype]\n    a_cute_dtype = torch2cute_dtype_map[A.dtype]\n    d_cute_dtype = torch2cute_dtype_map[D.dtype]\n    dt_cute_dtype = torch2cute_dtype_map[dt.dtype]\n    trap_cute_dtype = torch2cute_dtype_map[trap.dtype]\n\n    compile_key = (\n        tile_D,\n        num_warps,\n        dstate,\n        hdim,\n        mimo,\n        state.dtype,\n        Bstate.dtype,\n        xproj.dtype,\n        A.dtype,\n        D.dtype,\n        dt.dtype,\n 
       trap.dtype,\n        has_z,\n        has_outproj,\n    )\n    if compile_key not in mamba3_step_fn.compile_cache:\n        mamba3_step_op = Mamba3Step(tile_D, dstate, mimo, num_warps, remove_gate=not has_z, remove_outproj=not has_outproj)\n\n        # Create symbolic dimensions for batch and nheads\n        batch_sym = cute.sym_int()\n        nheads_sym = cute.sym_int()\n\n        # Divisibility for strides (128-bit alignment)\n        div_state = 128 // state_cute_dtype.width\n        div_b = 128 // b_cute_dtype.width\n        div_x = 128 // x_cute_dtype.width\n        div_proj = 128 // proj_cute_dtype.width\n        div_a = 128 // a_cute_dtype.width\n        div_d = 128 // d_cute_dtype.width\n        div_dt = 128 // dt_cute_dtype.width\n        div_trap = 128 // trap_cute_dtype.width\n\n        # Create fake tensors with symbolic batch/nheads dimensions\n        state_fake = make_fake_tensor(state_cute_dtype, (batch_sym, nheads_sym, hdim, dstate), div_state)\n        Bstate_fake = make_fake_tensor(b_cute_dtype, (batch_sym, mimo, nheads_sym, dstate), div_b)\n        Xstate_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x)\n        A_fake = make_fake_tensor(a_cute_dtype, (batch_sym, nheads_sym), div_a)\n        B_fake = make_fake_tensor(b_cute_dtype, (batch_sym, mimo, nheads_sym, dstate), div_b)\n        C_fake = make_fake_tensor(b_cute_dtype, (batch_sym, mimo, nheads_sym, dstate), div_b)\n        D_fake = make_fake_tensor(d_cute_dtype, (nheads_sym,), div_d)\n        x_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x)\n        dt_fake = make_fake_tensor(dt_cute_dtype, (batch_sym, nheads_sym), div_dt)\n        trap_fake = make_fake_tensor(trap_cute_dtype, (batch_sym, nheads_sym), div_trap)\n        xproj_fake = make_fake_tensor(proj_cute_dtype, (mimo, nheads_sym, hdim), div_proj)\n        outproj_fake = make_fake_tensor(proj_cute_dtype, (mimo, nheads_sym, hdim), div_proj) if has_outproj else None\n        
state_out_fake = make_fake_tensor(state_cute_dtype, (batch_sym, nheads_sym, hdim, dstate), div_state)\n        if has_outproj:\n            out_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x)\n        else:\n            out_fake = make_fake_tensor(x_cute_dtype, (batch_sym, mimo, nheads_sym, hdim), div_x)\n        z_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x) if has_z else None\n        zproj_fake = make_fake_tensor(proj_cute_dtype, (mimo, nheads_sym, hdim), div_proj) if has_z else None\n\n        fake_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)\n\n        mamba3_step_fn.compile_cache[compile_key] = cute.compile(\n            mamba3_step_op,\n            state_fake,\n            Bstate_fake,\n            Xstate_fake,\n            A_fake,\n            B_fake,\n            C_fake,\n            D_fake,\n            x_fake,\n            dt_fake,\n            trap_fake,\n            xproj_fake,\n            outproj_fake,\n            state_out_fake,\n            out_fake,\n            z_fake,\n            zproj_fake,\n            fake_stream,\n            options=\"--enable-tvm-ffi\",\n        )\n\n    # Call with real PyTorch tensors directly (no dlpack conversion needed)\n    # When inplace, state_out is state (set above)\n    mamba3_step_fn.compile_cache[compile_key](\n        state,\n        Bstate,\n        Xstate,\n        A,\n        B,\n        C,\n        D,\n        x,\n        dt,\n        trap,\n        xproj,\n        outproj,\n        state_out,\n        out,\n        z,\n        zproj,\n    )\n\n\nmamba3_step_fn.compile_cache = {}\n\n\ndef selective_state_update_fused_ref_v2(\n    state, A, B, C, xproj, x, zproj, z, dt, B_state, x_state, trap, D, outproj,\n    compute_dtype=torch.float32\n):\n    \"\"\"\n    Reference to match the new fused kernel API.\n\n    Shapes:\n        state:   (B, N, H, S)\n        A:       (B, N)\n        B:       (B, R, N, S)\n        C:       (B, 
R, N, S)\n        xproj:   (R, N, H)\n        x:       (B, N, H)\n        zproj:   (R, N, H)\n        z:       (B, N, H)\n        dt:      (B, N)\n        B_state: (B, R, N, S)\n        x_state: (B, N, H)\n        trap:    (B, N)\n        D:       (N,)\n        outproj: (R, N, H)\n\n    Returns:\n        out:       (B, N, H)\n        new_state: (B, N, H, S)\n    \"\"\"\n    Bsz, N, H, S = state.shape\n    R = B.shape[1]\n\n    # Dtypes for numerics (match kernel's fp32 accum)\n    og_dtype = state.dtype\n    A_f    = A.to(compute_dtype)       # (B, N)\n    dt_f   = dt.to(compute_dtype)      # (B, N)\n    trap_f = trap.to(compute_dtype)    # (B, N)\n    D_f    = D.to(compute_dtype)       # (N,)\n    x_f    = x.to(compute_dtype)       # (B, N, H)\n    xst_f  = x_state.to(compute_dtype) # (B, N, H)\n    B_f    = B.to(compute_dtype)       # (B, R, N, S)\n    C_f    = C.to(compute_dtype)       # (B, R, N, S)\n    Bst_f  = B_state.to(compute_dtype) # (B, R, N, S)\n    Xp_f   = xproj.to(compute_dtype)   # (R, N, H)\n    st_f   = state.to(compute_dtype)   # (B, N, H, S)\n\n    alpha = torch.exp(A_f * dt_f)                 # (B, N)\n    beta  = (1.0 - trap_f) * dt_f * alpha        # (B, N)\n    gamma = trap_f * dt_f                         # (B, N)\n\n    x_vals   = (x_f[:, None, :, :] * Xp_f[None, :, :, :])    # (B, R, N, H)\n    xs_vals  = (xst_f[:, None, :, :] * Xp_f[None, :, :, :])  # (B, R, N, H)\n\n    xBt_state = torch.einsum('brnh,brns->bnhs', x_vals * gamma.unsqueeze(-1).unsqueeze(1),  B_f)\n    xBt_prev  = torch.einsum('brnh,brns->bnhs', xs_vals * beta.unsqueeze(-1).unsqueeze(1), Bst_f)\n\n    new_state = st_f * alpha[:, :, None, None] + xBt_state + xBt_prev   # (B, N, H, S)\n\n    out_r = torch.einsum('bnhs,brns->brnh', new_state, C_f)  # (B, R, N, H)\n\n    out_r = out_r + (x_vals * D_f[None, :, None])            # (B, R, N, H)\n\n    if z is not None:\n        z_f    = z.to(compute_dtype)       # (B, N, H)\n        Zp_f   = zproj.to(compute_dtype)   # (R, N, 
H)\n        z_vals = (z_f[:, None, :, :] * Zp_f[None, :, :, :])      # (B, R, N, H)\n        out_r  = out_r * z_vals * torch.sigmoid(z_vals)          # (B, R, N, H)\n\n    if outproj is not None:\n        Op_f   = outproj.to(compute_dtype) # (R, N, H)\n        out = torch.einsum('brnh,rnh->bnh', out_r, Op_f)         # (B, N, H)\n    else:\n        out = out_r                                               # (B, R, N, H)\n\n    return out.to(og_dtype), new_state.to(og_dtype)\n\n\ndef _bytes_of(t):\n    return t.numel() * t.element_size()\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(1357)\n    batch, nheads, hdim, dstate, mimo = 128, 64, 64, 128, 4\n    device = torch.device(\"cuda:0\")\n    dtype_state = torch.float32\n    dtype = torch.float32\n    state = torch.randn(batch, nheads, hdim, dstate, device=device, dtype=dtype_state)\n    Bstate = torch.randn(batch, mimo, nheads, dstate, device=device, dtype=dtype)\n    Xstate = torch.randn(batch, nheads, hdim, device=device, dtype=dtype)\n    A = -F.softplus(torch.randn(batch, nheads, device=device, dtype=torch.float32))\n    B = torch.randn(batch, mimo, nheads, dstate, device=device, dtype=dtype)\n    C = torch.randn(batch, mimo, nheads, dstate, device=device, dtype=dtype)\n    D = torch.randn(nheads, device=device, dtype=dtype)\n    x = torch.randn(batch, nheads, hdim, device=device, dtype=dtype)\n    z = torch.randn(batch, nheads, hdim, device=device, dtype=dtype)\n    dt = torch.exp(torch.rand(nheads, device=device) * (math.log(0.1) - math.log(0.001)) + math.log(0.001))\n    dt = torch.clamp(dt, min=1e-4)\n    dt_bias = dt + torch.log(-torch.expm1(-dt))\n    dt = F.softplus(torch.randn(batch, nheads, device=device) + dt_bias)  # (B, H)\n    trap = torch.sigmoid(torch.randn(batch, nheads, device=device, dtype=torch.float32))\n    xproj = torch.randn(mimo, nheads, hdim, device=device, dtype=dtype)\n    zproj = torch.randn(mimo, nheads, hdim, device=device, dtype=dtype)\n    outproj = torch.randn(mimo, 
nheads, hdim, device=device, dtype=dtype)\n    out = torch.zeros_like(x)\n\n    # =========================================================================\n    # Test 1: Out-of-place (explicit state_out)\n    # =========================================================================\n    print(\"=== Out-of-place test ===\")\n    state_out = torch.zeros_like(state)\n    fn_oop = lambda: mamba3_step_fn(\n        state,\n        Bstate,\n        Xstate,\n        A,\n        B,\n        C,\n        D,\n        x,\n        dt,\n        trap,\n        xproj,\n        outproj,\n        state_out,\n        out,\n        z=z,\n        zproj=zproj,\n        tile_D=64,\n        num_warps=4,\n    )\n\n    fn_oop()\n    out_ref, state_out_ref = selective_state_update_fused_ref_v2(state, A, B, C, xproj, x, zproj, z, dt, Bstate, Xstate, trap, D, outproj, compute_dtype=torch.float64)\n    out_pt, state_out_pt = selective_state_update_fused_ref_v2(state, A, B, C, xproj, x, zproj, z, dt, Bstate, Xstate, trap, D, outproj, compute_dtype=torch.float32)\n    print(f\"state_out vs ref (f64): {(state_out - state_out_ref).abs().max()}\")\n    print(f\"state_out_pt vs ref (f64): {(state_out_pt - state_out_ref).abs().max()}\")\n    print(f\"out vs ref (f64): {(out - out_ref).abs().max()}\")\n    print(f\"out_pt vs ref (f64): {(out_pt - out_ref).abs().max()}\")\n\n    # =========================================================================\n    # Test 2: In-place (state_out=None)\n    # =========================================================================\n    print(\"\\n=== In-place test ===\")\n    # Fresh state for in-place test\n    state_ip = state.clone()\n    out_ip = torch.zeros_like(x)\n    fn_ip = lambda: mamba3_step_fn(\n        state_ip,\n        Bstate,\n        Xstate,\n        A,\n        B,\n        C,\n        D,\n        x,\n        dt,\n        trap,\n        xproj,\n        outproj,\n        None,  # state_out=None -> in-place\n        out_ip,\n        z=z,\n      
  zproj=zproj,\n        tile_D=64,\n        num_warps=4,\n    )\n\n    fn_ip()\n    # state_ip was updated in place, compare against same reference\n    print(f\"state (in-place) vs ref (f64): {(state_ip - state_out_ref).abs().max()}\")\n    print(f\"out (in-place) vs ref (f64): {(out_ip - out_ref).abs().max()}\")\n    # Verify in-place and out-of-place produce identical results\n    print(f\"state in-place vs out-of-place: {(state_ip - state_out).abs().max()}\")\n    print(f\"out in-place vs out-of-place: {(out_ip - out).abs().max()}\")\n\n    # =========================================================================\n    # Benchmark (out-of-place)\n    # =========================================================================\n    read_bytes = (\n        _bytes_of(state) + _bytes_of(A) + _bytes_of(B)\n        + _bytes_of(C)\n        + _bytes_of(xproj) + _bytes_of(x)\n        + _bytes_of(zproj) + _bytes_of(z)\n        + _bytes_of(dt) + _bytes_of(Bstate) + _bytes_of(Xstate)\n        + _bytes_of(trap) + _bytes_of(D) + _bytes_of(outproj)\n    )\n    out_bytes       = _bytes_of(out)\n    new_state_bytes = _bytes_of(state)\n    total_bytes = read_bytes + out_bytes + new_state_bytes\n\n    from triton.testing import do_bench_cudagraph\n    ms = do_bench_cudagraph(fn_oop, rep=30)\n    bandwidth = (total_bytes) / ms * 1e-6\n    print(f\"\\nMamba3 step (out-of-place): {ms:.3f} ms, {bandwidth:.1f} GB/s\")"
  },
  {
    "path": "mamba_ssm/ops/selective_scan_interface.py",
    "content": "# Copyright (c) 2023, Tri Dao, Albert Gu.\n\nimport torch\nimport torch.nn.functional as F\nfrom mamba_ssm.utils.torch import custom_bwd, custom_fwd\n\nfrom einops import rearrange, repeat\n\ntry:\n    from causal_conv1d import causal_conv1d_fn\n    from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function\nexcept ImportError:\n    causal_conv1d_fn = None\n    causal_conv1d_fwd_function = None\n    causal_conv1d_bwd_function = None\n    causal_conv1d_update_function = None\n\nfrom mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd\n\nimport selective_scan_cuda\n\n\nclass SelectiveScanFn(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,\n                return_last_state=False):\n        if u.stride(-1) != 1:\n            u = u.contiguous()\n        if delta.stride(-1) != 1:\n            delta = delta.contiguous()\n        if D is not None:\n            D = D.contiguous()\n        if B.stride(-1) != 1:\n            B = B.contiguous()\n        if C.stride(-1) != 1:\n            C = C.contiguous()\n        if z is not None and z.stride(-1) != 1:\n            z = z.contiguous()\n        if B.dim() == 3:\n            B = rearrange(B, \"b dstate l -> b 1 dstate l\")\n            ctx.squeeze_B = True\n        if C.dim() == 3:\n            C = rearrange(C, \"b dstate l -> b 1 dstate l\")\n            ctx.squeeze_C = True\n        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)\n        ctx.delta_softplus = delta_softplus\n        ctx.has_z = z is not None\n        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)\n        if not ctx.has_z:\n            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)\n            return out if not return_last_state else (out, last_state)\n        else:\n            ctx.save_for_backward(u, delta, A, B, 
C, D, z, delta_bias, x, out)\n            out_z = rest[0]\n            return out_z if not return_last_state else (out_z, last_state)\n\n    @staticmethod\n    def backward(ctx, dout, *args):\n        if not ctx.has_z:\n            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors\n            z = None\n            out = None\n        else:\n            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors\n        if dout.stride(-1) != 1:\n            dout = dout.contiguous()\n        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the\n        # backward of selective_scan_cuda with the backward of chunk).\n        # Here we just pass in None and dz will be allocated in the C++ code.\n        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(\n            u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,\n            False  # option to recompute out_z, not used here\n        )\n        dz = rest[0] if ctx.has_z else None\n        dB = dB.squeeze(1) if getattr(ctx, \"squeeze_B\", False) else dB\n        dC = dC.squeeze(1) if getattr(ctx, \"squeeze_C\", False) else dC\n        return (du, ddelta, dA, dB, dC,\n                dD if D is not None else None,\n                dz,\n                ddelta_bias if delta_bias is not None else None,\n                None,\n                None)\n\n\ndef rms_norm_forward(\n    x,\n    weight,\n    bias,\n    eps=1e-6,\n    is_rms_norm=True,\n):\n    # x (b l) d\n    if x.stride(-1) != 1:\n        x = x.contiguous()\n    weight = weight.contiguous()\n    if bias is not None:\n        bias = bias.contiguous()\n    y = _layer_norm_fwd(\n        x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm\n    )[0]\n    # y (b l) d\n    return y\n\n\ndef selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,\n                     return_last_state=False):\n    \"\"\"if 
return_last_state is True, returns (out, last_state)\n    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is\n    not considered in the backward pass.\n    \"\"\"\n    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)\n\n\ndef selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,\n                      return_last_state=False):\n    \"\"\"\n    u: r(B D L)\n    delta: r(B D L)\n    A: c(D N) or r(D N)\n    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)\n    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)\n    D: r(D)\n    z: r(B D L)\n    delta_bias: r(D), fp32\n\n    out: r(B D L)\n    last_state (optional): r(B D dstate) or c(B D dstate)\n    \"\"\"\n    dtype_in = u.dtype\n    u = u.float()\n    delta = delta.float()\n    if delta_bias is not None:\n        delta = delta + delta_bias[..., None].float()\n    if delta_softplus:\n        delta = F.softplus(delta)\n    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]\n    is_variable_B = B.dim() >= 3\n    is_variable_C = C.dim() >= 3\n    if A.is_complex():\n        if is_variable_B:\n            B = torch.view_as_complex(rearrange(B.float(), \"... (L two) -> ... L two\", two=2))\n        if is_variable_C:\n            C = torch.view_as_complex(rearrange(C.float(), \"... (L two) -> ... 
L two\", two=2))\n    else:\n        B = B.float()\n        C = C.float()\n    x = A.new_zeros((batch, dim, dstate))\n    ys = []\n    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))\n    if not is_variable_B:\n        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)\n    else:\n        if B.dim() == 3:\n            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)\n        else:\n            B = repeat(B, \"B G N L -> B (G H) N L\", H=dim // B.shape[1])\n            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)\n    if is_variable_C and C.dim() == 4:\n        C = repeat(C, \"B G N L -> B (G H) N L\", H=dim // C.shape[1])\n    last_state = None\n    for i in range(u.shape[2]):\n        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]\n        if not is_variable_C:\n            y = torch.einsum('bdn,dn->bd', x, C)\n        else:\n            if C.dim() == 3:\n                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])\n            else:\n                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])\n        if i == u.shape[2] - 1:\n            last_state = x\n        if y.is_complex():\n            y = y.real * 2\n        ys.append(y)\n    y = torch.stack(ys, dim=2) # (batch dim L)\n    out = y if D is None else y + u * rearrange(D, \"d -> d 1\")\n    if z is not None:\n        out = out * F.silu(z)\n    out = out.to(dtype=dtype_in)\n    return out if not return_last_state else (out, last_state)\n\n\nclass MambaInnerFn(torch.autograd.Function):\n\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,\n                out_proj_weight, out_proj_bias,\n                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,\n                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight=None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6):\n        \"\"\"\n             xz: (batch, dim, seqlen)\n        \"\"\"\n        
assert causal_conv1d_fwd_function is not None, \"causal_conv1d_cuda is not available. Please install causal-conv1d.\"\n        assert checkpoint_lvl in [0, 1]\n        L = xz.shape[-1]\n        delta_rank = delta_proj_weight.shape[1]\n        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)\n        if torch.is_autocast_enabled():\n            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())\n            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())\n            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())\n            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())\n                             if out_proj_bias is not None else None)\n        if xz.stride(-1) != 1:\n            xz = xz.contiguous()\n        conv1d_weight = rearrange(conv1d_weight, \"d 1 w -> d w\")\n        x, z = xz.chunk(2, dim=1)\n        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None\n        conv1d_out = causal_conv1d_fwd_function(\n            x, conv1d_weight, conv1d_bias, None, None, None, True\n        )\n        # We're being very careful here about the layout, to avoid extra transposes.\n        # We want delta to have d as the slowest moving dimension\n        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.\n        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)\n        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), \"d (b l) -> b d l\", l = L)\n        ctx.is_variable_B = B is None\n        ctx.is_variable_C = C is None\n        ctx.B_proj_bias_is_None = B_proj_bias is None\n        ctx.C_proj_bias_is_None = C_proj_bias is None\n        if B is None:  # variable B\n            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)\n            if B_proj_bias is not None:\n                B = B + B_proj_bias.to(dtype=B.dtype)\n            
if not A.is_complex():\n                # B = rearrange(B, \"(b l) dstate -> b dstate l\", l=L).contiguous()\n                B = rearrange(B, \"(b l) dstate -> b 1 dstate l\", l=L).contiguous()\n            else:\n                B = rearrange(B, \"(b l) (dstate two) -> b 1 dstate (l two)\", l=L, two=2).contiguous()\n        else:\n            if B.stride(-1) != 1:\n                B = B.contiguous()\n        if C is None:  # variable C\n            C = x_dbl[:, -d_state:]  # (bl dstate)\n            if C_proj_bias is not None:\n                C = C + C_proj_bias.to(dtype=C.dtype)\n            if not A.is_complex():\n                # C = rearrange(C, \"(b l) dstate -> b dstate l\", l=L).contiguous()\n                C = rearrange(C, \"(b l) dstate -> b 1 dstate l\", l=L).contiguous()\n            else:\n                C = rearrange(C, \"(b l) (dstate two) -> b 1 dstate (l two)\", l=L, two=2).contiguous()\n        else:\n            if C.stride(-1) != 1:\n                C = C.contiguous()\n        if D is not None:\n            D = D.contiguous()\n            \n        if b_rms_weight is not None:\n            B = rearrange(B, \"b 1 dstate l -> (b l) dstate\", l=L).contiguous()\n            B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)\n            B = rearrange(B, \"(b l) dstate -> b 1 dstate l\", l=L).contiguous()\n        if c_rms_weight is not None:\n            C = rearrange(C, \"b 1 dstate l -> (b l) dstate\", l=L).contiguous()\n            C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)\n            C = rearrange(C, \"(b l) dstate -> b 1 dstate l\", l=L).contiguous()\n        if dt_rms_weight is not None:\n            delta = rearrange(delta, \"b d l -> (b l) d\", l=L).contiguous()\n            delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)\n            delta = rearrange(delta, \"(b l) d -> b d l\", l=L).contiguous()\n        \n        out, scan_intermediates, out_z = 
selective_scan_cuda.fwd(\n            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus\n        )\n        ctx.delta_softplus = delta_softplus\n        ctx.out_proj_bias_is_None = out_proj_bias is None\n        ctx.checkpoint_lvl = checkpoint_lvl\n        ctx.b_rms_weight = b_rms_weight\n        ctx.c_rms_weight = c_rms_weight\n        ctx.dt_rms_weight = dt_rms_weight\n        ctx.b_c_dt_rms_eps = b_c_dt_rms_eps\n        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass\n            conv1d_out, delta = None, None\n        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,\n                              delta_proj_weight, out_proj_weight, conv1d_out, delta,\n                              A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out)\n        return F.linear(rearrange(out_z, \"b d l -> b l d\"), out_proj_weight, out_proj_bias)\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, dout):\n        # dout: (batch, seqlen, dim)\n        assert causal_conv1d_fwd_function is not None, \"causal_conv1d_cuda is not available. 
Please install causal-conv1d.\"\n        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,\n         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out) = ctx.saved_tensors\n        L = xz.shape[-1]\n        delta_rank = delta_proj_weight.shape[1]\n        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)\n        x, z = xz.chunk(2, dim=1)\n        if dout.stride(-1) != 1:\n            dout = dout.contiguous()\n        if ctx.checkpoint_lvl == 1:\n            conv1d_out = causal_conv1d_fwd_function(\n                x, conv1d_weight, conv1d_bias, None, None, None, True\n            )\n            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),\n                              \"d (b l) -> b d l\", l = L)\n            if dt_rms_weight is not None:\n                delta = rearrange(delta, \"b d l -> (b l) d\", l=L).contiguous()\n                delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps)\n                delta = rearrange(delta, \"(b l) d -> b d l\", l=L).contiguous()\n            if b_rms_weight is not None:\n                # Recompute & RMSNorm B\n                B = rearrange(B, \"b 1 dstate l -> (b l) dstate\", l=L).contiguous()\n                B = rms_norm_forward(\n                    B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps\n                )\n                B = rearrange(B, \"(b l) dstate -> b 1 dstate l\", l=L).contiguous()\n            if c_rms_weight is not None:\n                # Recompute & RMSNorm C\n                C = rearrange(C, \"b 1 dstate l -> (b l) dstate\", l=L).contiguous()\n                C = rms_norm_forward(\n                    C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps\n                )\n                C = rearrange(C, \"(b l) dstate -> b 1 dstate l\", l=L).contiguous()\n            \n        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to 
fuse the\n        # backward of selective_scan_cuda with the backward of chunk).\n        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)\n        dx, dz = dxz.chunk(2, dim=1)\n        dout = rearrange(dout, \"b l e -> e (b l)\")\n        dout_y = rearrange(out_proj_weight.t() @ dout, \"d (b l) -> b d l\", l=L)\n        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(\n            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,\n            ctx.delta_softplus,\n            True  # option to recompute out_z\n        )\n        dout_proj_weight = torch.einsum(\"eB,dB->ed\", dout, rearrange(out_z, \"b d l -> d (b l)\"))\n        # dout is (e, b*l) at this point, so reduce over the token dimension only to get an (e,) bias gradient\n        dout_proj_bias = dout.sum(dim=1) if not ctx.out_proj_bias_is_None else None\n        dD = dD if D is not None else None\n        dx_dbl = torch.empty_like(x_dbl)\n        dB_proj_bias = None\n        if ctx.is_variable_B:\n            if not A.is_complex():\n                dB = rearrange(dB, \"b 1 dstate l -> (b l) dstate\").contiguous()\n            else:\n                dB = rearrange(dB, \"b 1 dstate (l two) -> (b l) (dstate two)\", two=2).contiguous()\n            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None\n            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)\n            dB = None\n        dC_proj_bias = None\n        if ctx.is_variable_C:\n            if not A.is_complex():\n                dC = rearrange(dC, \"b 1 dstate l -> (b l) dstate\").contiguous()\n            else:\n                dC = rearrange(dC, \"b 1 dstate (l two) -> (b l) (dstate two)\", two=2).contiguous()\n            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None\n            dx_dbl[:, -d_state:] = dC  # (bl d)\n            dC = None\n        ddelta = rearrange(ddelta, \"b d l -> d (b l)\")\n        ddelta_proj_weight = torch.einsum(\"dB,Br->dr\", ddelta, x_dbl[:, :delta_rank])\n        dx_dbl[:, :delta_rank] = 
torch.einsum(\"dB,dr->Br\", ddelta, delta_proj_weight)\n        dconv1d_out = rearrange(dconv1d_out, \"b d l -> d (b l)\")\n        dx_proj_weight = torch.einsum(\"Br,Bd->rd\", dx_dbl, rearrange(conv1d_out, \"b d l -> (b l) d\"))\n        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)\n        dconv1d_out = rearrange(dconv1d_out, \"d (b l) -> b d l\", b=x.shape[0], l=x.shape[-1])\n        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the\n        # backward of conv1d with the backward of chunk).\n        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_bwd_function(\n            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True\n        )\n        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None\n        dconv1d_weight = rearrange(dconv1d_weight, \"d w -> d 1 w\")\n        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,\n                dout_proj_weight, dout_proj_bias,\n                dA, dB, dC, dD,\n                ddelta_bias if delta_bias is not None else None,\n                # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps\n                dB_proj_bias, dC_proj_bias, None, None, None, None, None, None)\n\n\ndef mamba_inner_fn(\n    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,\n    out_proj_weight, out_proj_bias,\n    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,\n    C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight= None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6\n):\n    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,\n                              out_proj_weight, out_proj_bias,\n                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, 
dt_rms_weight, b_c_dt_rms_eps)\n\n\ndef mamba_inner_ref(\n    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,\n    out_proj_weight, out_proj_bias,\n    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,\n    C_proj_bias=None, delta_softplus=True\n):\n    assert causal_conv1d_fn is not None, \"causal_conv1d_fn is not available. Please install causal-conv1d.\"\n    L = xz.shape[-1]\n    delta_rank = delta_proj_weight.shape[1]\n    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)\n    x, z = xz.chunk(2, dim=1)\n    x = causal_conv1d_fn(x, rearrange(conv1d_weight, \"d 1 w -> d w\"), conv1d_bias, activation=\"silu\")\n    # We're being very careful here about the layout, to avoid extra transposes.\n    # We want delta to have d as the slowest moving dimension\n    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.\n    x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)\n    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()\n    delta = rearrange(delta, \"d (b l) -> b d l\", l=L)\n    if B is None:  # variable B\n        B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)\n        if B_proj_bias is not None:\n            B = B + B_proj_bias.to(dtype=B.dtype)\n        if not A.is_complex():\n            B = rearrange(B, \"(b l) dstate -> b dstate l\", l=L).contiguous()\n        else:\n            B = rearrange(B, \"(b l) (dstate two) -> b dstate (l two)\", l=L, two=2).contiguous()\n    if C is None:  # variable C\n        C = x_dbl[:, -d_state:]  # (bl d)\n        if C_proj_bias is not None:\n            C = C + C_proj_bias.to(dtype=C.dtype)\n        if not A.is_complex():\n            C = rearrange(C, \"(b l) dstate -> b dstate l\", l=L).contiguous()\n        else:\n            C = rearrange(C, \"(b l) (dstate two) -> b dstate (l two)\", l=L, two=2).contiguous()\n    # Forward the caller's delta_softplus flag instead of hard-coding True\n    y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus)\n    return F.linear(rearrange(y, \"b d l -> b l d\"), out_proj_weight, out_proj_bias)\n"
  },
  {
    "path": "mamba_ssm/ops/tilelang/mamba3/mamba3_mimo.py",
    "content": "\"\"\"Mamba-3 Tilelang Autograd Wrapper\n\nInterface for Mamba-3 Tilelang kernels with automatic differentiation\n\nCopyright (c) 2026, Dao AI Lab, Goombalab\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\n# Import kernels\nfrom mamba_ssm.ops.tilelang.mamba3.mamba3_mimo_fwd import mamba_mimo_forward\nfrom mamba_ssm.ops.triton.mamba3.mamba3_mimo_utils import compute_dacs_segsum_triton\nfrom mamba_ssm.ops.tilelang.mamba3.mamba3_mimo_bwd import mamba_mimo_bwd_combined\n\n# =============================================================================\n# Autograd Function\n# =============================================================================\n\nclass _Mamba3Function(torch.autograd.Function):\n    \"\"\"Custom autograd function for Mamba-3 with Triton/Tilelang kernels.\"\"\"\n    \n    @staticmethod\n    def forward(\n        ctx,\n        Q: Tensor,\n        K: Tensor,\n        V: Tensor,\n        ADT: Tensor,\n        DT: Tensor,\n        Trap: Tensor,\n        Q_bias: Tensor,\n        K_bias: Tensor,\n        MIMO_V: Tensor,\n        MIMO_Z: Tensor,\n        MIMO_Out: Union[Tensor, None],\n        Angles: Tensor,\n        D: Tensor,\n        Z: Tensor,\n        chunk_size: int,\n        rotary_dim_divisor: int,\n        dtype: torch.dtype,\n        return_state: bool,\n    ) -> Tensor | Tuple[Tensor, Tuple]:\n        \"\"\"Forward pass: call Triton/Tilelang kernel and save tensors for backward.\"\"\"\n        ctx.chunk_size = chunk_size\n        ctx.rotary_dim_divisor = rotary_dim_divisor\n        ctx.dtype = dtype\n        (Q, K, V, ADT, DT, Trap, Q_bias, K_bias, MIMO_V, MIMO_Z, MIMO_Out, Angles, D, Z) = tuple(\n            t.contiguous() if t is not None else None\n            for t in (\n                Q, K, V, ADT, DT, Trap, Q_bias, K_bias, MIMO_V, MIMO_Z, MIMO_Out, Angles, D, Z,\n            )\n        )\n\n        DA_CS, DA_CS_REV, Segsum = 
compute_dacs_segsum_triton(ADT, chunk_size)\n        Out, Final_SSM_State, Final_K = mamba_mimo_forward(\n            Q, K, V, Q_bias, K_bias, MIMO_V, MIMO_Out,\n            Z, D, MIMO_Z, Angles,\n            DA_CS, DA_CS_REV, DT, Trap, Segsum, \n            return_state=return_state,\n            chunk_size=chunk_size, rotary_dim_divisor=rotary_dim_divisor,\n            dtype=dtype,\n        )\n\n        ctx.chunk_size = chunk_size\n        ctx.save_for_backward(\n            Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles,\n            D, Z,\n            MIMO_V, MIMO_Out, MIMO_Z,\n        )\n\n        if not return_state:\n            return Out\n        else:\n            Final_Angle = torch.remainder(Angles[:, -1, :, :], 2 * torch.pi).contiguous().detach()\n            Final_SSM_State = Final_SSM_State.permute(0, 1, 3, 2).contiguous().detach()\n            Final_K = Final_K.contiguous().detach()\n            Final_V = V[:, -1, :, :].contiguous().detach()\n            ctx.mark_non_differentiable(Final_Angle, Final_SSM_State, Final_K, Final_V)\n            return Out, Final_Angle, Final_SSM_State, Final_K, Final_V\n    \n    @staticmethod\n    def backward(ctx, dout, *args) -> tuple:\n        \"\"\"Backward pass: compute gradients using Triton backward kernels.\"\"\"\n        \n        if len(ctx.saved_tensors) == 0:\n            raise RuntimeError(\n                \"Backward called but forward ran without gradient tracking. 
\"\n                \"Ensure inputs require grad or run under torch.enable_grad().\"\n            )\n        dout = dout.contiguous()\n\n        (Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles,\n            D, Z,\n            MIMO_V, MIMO_Out, MIMO_Z,\n            ) = ctx.saved_tensors\n    \n        DA_CS, DA_CS_REV, Segsum = compute_dacs_segsum_triton(ADT, ctx.chunk_size)\n\n        (dQ, dK, dV, \n            dADT, dDT, dTrap, dQ_bias, dK_bias,\n            dMIMO_V, dMIMO_Z, dMIMO_Out, dAngles, \n            dD, dZ) = mamba_mimo_bwd_combined(\n                dout,\n                Q, \n                K, \n                V, \n                Q_bias,\n                K_bias,\n                MIMO_V, \n                MIMO_Out,\n                Z,\n                MIMO_Z,\n                Angles,\n                DA_CS,\n                DA_CS_REV,\n                DT,\n                Trap,\n                D,\n                Segsum,\n                ctx.chunk_size,\n                ctx.rotary_dim_divisor,\n                ctx.dtype,\n            )\n\n        return (\n            dQ,\n            dK,\n            dV,\n            dADT,\n            dDT,\n            dTrap,\n            dQ_bias,\n            dK_bias,\n            dMIMO_V,\n            dMIMO_Z,\n            dMIMO_Out,\n            dAngles,\n            dD,\n            dZ,\n            None, None, None, None,\n        )\n\n\n# =============================================================================\n# Public API\n# =============================================================================\n\ndef mamba3_mimo(\n    Q: Tensor,\n    K: Tensor,\n    V: Tensor,\n    ADT: Tensor,\n    DT: Tensor,\n    Trap: Tensor,\n    Q_bias: Tensor,\n    K_bias: Tensor,\n    MIMO_V: Tensor,\n    MIMO_Z: Tensor,\n    MIMO_Out: Tensor,\n    Angles: Tensor,\n    D: Tensor,\n    Z: Tensor,\n    chunk_size: int,\n    rotary_dim_divisor: int,\n    dtype: torch.dtype,\n    return_state: bool = False,\n) -> Tensor | 
Tuple[Tensor, Tuple]:\n    \"\"\"Mamba-3 attention with Tilelang kernels and automatic differentiation.\n    \n    Args:\n        Q: Query tensor (batch, seqlen, mimo_rank, nheads_qk, headdim_qk)\n        K: Key tensor (batch, seqlen, mimo_rank, nheads_qk, headdim_qk)\n        V: Value tensor (batch, seqlen, nheads, headdim_v)\n        ADT: Decay factor A * dt (batch, nheads, seqlen)\n        DT: Time delta tensor dt (batch, nheads, seqlen)\n        Trap: Trapezoidal mixing factor, pre-sigmoid (batch, nheads, seqlen)\n        Q_bias: Query bias (nheads, mimo_rank, headdim_qk)\n        K_bias: Key bias (nheads, mimo_rank, headdim_qk)\n        MIMO_V: Mimo up projection for V (nheads, mimo_rank, headdim_v),\n        MIMO_Z: Mimo up projection for Z (nheads, mimo_rank, headdim_v),\n        MIMO_Out: Mimo down projection for output (nheads, mimo_rank, headdim_v). If None, does not reduce output with MIMO_Out,\n        Angles: Rotary position embeddings (batch, seqlen, nheads, headangles)\n        D: Optional skip connection weight (nheads,)\n        Z: Optional gating tensor (batch, seqlen, nheads, headdim_v)\n        chunk_size: Chunk size for state computation (default: 64//R)  \n        rotary_dim_divisor: Divisor for rotary embedding dimensions (default: 4, meaning angles have 1/4 of headdim_qk)\n    \n    Returns:\n        output: (batch, seqlen, nheads, headdim_v) if MIMO_Out is not None\n                (batch, seqlen, mimo_rank, nheads, headdim_v) if MIMO_Out is None\n        final_state: Tuple of tensors representing the running Angle sum, final SSM state, final K, and final V for autoregressive decoding. Only returned if return_state=True.\n\n    NOTE: The kernel is most optimized for seqlen: 2048, nheads_qk: 1, nheads: 32\n    headdim_qk: 128, headdim_v: 64, mimo_rank: 4, and chunk_size: 16. 
On H100.\n    NOTE: The code is still prone to smem over-allocation and Tilelang compilation errors\n    once headdim_qk, headdim_v, mimo_rank, chunk_size, or hardware type deviate from the combinations tested.\n    NOTE: Chunk size of 64/R is recommended, where R is the MIMO rank. However, it may be necessary to reduce chunk size\n    in case of smem over-allocation, which can occur with larger headdim_qk, headdim_v, or mimo_rank values.\n    NOTE: final_state is currently intended to be a non-differentiable side output. In particular,\n    loss = f(output) is fine, but loss = f(output, final_state) will not work properly since the backward does not compute gradients for final_state components.\n\n\n    \"\"\"\n    \n    batch, seqlen, mimo_rank, nheads_qk, headdim_qk = Q.shape\n    _, _, nheads, headdim_v = V.shape\n    \n    assert chunk_size >= 8, f\"chunk_size ({chunk_size}) must be at least 8\"\n    assert nheads % nheads_qk == 0, f\"nheads ({nheads}) must be divisible by nheads_qk ({nheads_qk})\"\n    assert headdim_qk % 2 == 0, f\"headdim_qk ({headdim_qk}) must be even for rotary embeddings\"\n    assert rotary_dim_divisor in [2, 4], f\"rotary_dim_divisor ({rotary_dim_divisor}) must be 2 or 4: rotary embedding is currently supported only on all or half of headdim_qk\"\n    # NOTE: the following (headdim_qk, headdim_v) values currently can result in compilation errors: (16, 32), (256, 128) \n    if headdim_qk not in [16, 32, 64, 128, 256]:\n        print(f\"WARNING: The value headdim_qk={headdim_qk} has not been tested. \" +\\\n              \"Proceed with caution and consider one of the tested headdim_qk: 16, 32, 64, 128, 256.\")\n    if headdim_v not in [32, 64, 128]:\n        print(f\"WARNING: The value headdim_v={headdim_v} has not been tested. \" +\\\n              \"Proceed with caution and consider one of the tested headdim_v: 32, 64, 128.\")\n    if mimo_rank not in [1, 2, 4, 8]:\n        print(f\"WARNING: The value mimo_rank={mimo_rank} has not been tested. 
\" +\\\n              \"Proceed with caution and consider one of the tested mimo_rank: 1, 2, 4, 8.\")\n\n    if chunk_size*mimo_rank > 64:\n        print(f\"WARNING: chunk_size * mimo_rank = {chunk_size*mimo_rank} exceeds 64, which may result in smem over-allocation. Consider decreasing chunk_size.\")\n\n    return _Mamba3Function.apply(\n        Q,\n        K,\n        V,\n        ADT,\n        DT,\n        Trap,\n        Q_bias,\n        K_bias,\n        MIMO_V,\n        MIMO_Z,\n        MIMO_Out,\n        Angles,\n        D,\n        Z,\n        chunk_size,\n        rotary_dim_divisor,\n        dtype,\n        return_state,\n    )"
  },
  {
    "path": "mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd.py",
    "content": "\"\"\"\nTilelang implementation of Mamba3 backward kernels,\nwith MIMO support.\n\nCopyright (c) 2026, Dao AI Lab, Goombalab\n\n\"\"\"\n\nimport torch\nimport tilelang\nimport tilelang.language as T\nfrom triton.testing import do_bench\nfrom tilelang.autotuner import autotune\n\n\nimport itertools\nimport argparse\nfrom einops import rearrange\nfrom typing import Optional, Tuple\n\nfrom mamba_ssm.ops.triton.mamba3.mamba3_mimo_utils import bwd_dadt_fused_triton, bwd_dtrap_ddt_triton\n\n\n# def get_configs():\n#     iter_params = dict(num_stages=[0, 1, 2, 3], threads=[128, 256, 512])\n#     # iter_params = dict(num_stages=[2], threads=[128])\n#     return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]\n\n# @autotune(\n#     configs=get_configs(),\n#     warmup=3,\n#     rep=20,\n# )\n@tilelang.jit(\n    out_idx=[],\n    pass_configs={\n        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,\n        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,\n        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,\n    })\ndef mamba_mimo_bwd_fwd(\n    B,\n    S,\n    H,\n    G,\n    N,\n    P,\n    R,\n    hasZ,\n    hasD,\n    reduceO,\n    chunk_size: int = 16,\n    rotary_dim_divisor: int = 4,\n    dtype: str = 'float16',\n    threads: int = 128,\n    num_stages: int = 0,\n) -> torch.Tensor:\n\n    accum_dtype = 'float32'\n\n    nchunks = tilelang.cdiv(S, chunk_size)\n    fused_chunk_size = chunk_size * R\n\n    if reduceO:\n        DOUT_shape = (B, S, H, P)\n    else:\n        DOUT_shape = (B, S, R, H, P)\n\n    @T.prim_func\n    def mamba_mimo_bwd_fwd_kernel(\n            DOUT: T.Tensor(DOUT_shape, dtype),  # type: ignore\n            Q: T.Tensor([B, S, R, G, N], dtype),  # type: ignore\n            K: T.Tensor([B, S, R, G, N], dtype),  # type: ignore\n            V: T.Tensor([B, S, H, P], dtype),  # type: ignore\n            Q_BIAS: T.Tensor([H, R, N], T.float32),  # type: ignore\n            K_BIAS: 
T.Tensor([H, R, N], T.float32),  # type: ignore\n            MIMO_V: T.Tensor([H, R, P], T.float32), # type: ignore\n            MIMO_O: T.Tensor([H, R, P], T.float32), # type: ignore\n            \n            DMIMO_O: T.Tensor([B, H, R, P], T.float32), # type: ignore\n            STATES: T.Tensor([B, H, nchunks, N, P], dtype), # type: ignore \n\n            Z: T.Tensor([B, S, H, P], dtype),  # type: ignore\n            MIMO_Z: T.Tensor([H, R, P], T.float32), # type: ignore\n            DZ: T.Tensor([B, S, H, P], dtype),  # type: ignore\n            DMIMO_Z: T.Tensor([B, H, R, P], T.float32), # type: ignore\n            ANGLES: T.Tensor([B, S, H, N//rotary_dim_divisor], T.float32), # type: ignore\n            DA_CS: T.Tensor([B, H, S], T.float32), # type: ignore\n            DA_CS_REV: T.Tensor([B, H, S], T.float32), # type: ignore\n            DT: T.Tensor([B, H, S], T.float32), # type: ignore\n            TRAP: T.Tensor([B, H, S], dtype), # type: ignore\n            D: T.Tensor([H], T.float32),  # type: ignore\n\n            QK_DOT: T.Tensor([B, H, S, R, R], dtype), # type: ignore\n            \n            SEGSUM: T.Tensor([B, H, nchunks, chunk_size, chunk_size], T.float32), # type: ignore\n            ):\n        \"\"\"\n        Overview:\n            Fused backward-forward pass over chunks. 
Recomputes local forward intermediates,\n            accumulates projection gradients (DMIMO_O and optional DMIMO_Z), emits optional DZ,\n            stores per-chunk recurrent STATES, and materializes QK_DOT for the second backward pass.\n\n        Inputs:\n            - Activations and upstream grad: DOUT, Q, K, V.\n            - Projection weights/biases: Q_BIAS, K_BIAS, MIMO_V (Psi), MIMO_O (Phi), optional MIMO_Z (Zeta).\n            - Optional forward modifiers: Z, D.\n            - Discretization tensors: DA_CS, DA_CS_REV, DT, TRAP, and SEGSUM.\n\n        Outputs:\n            - MIMO projection grads: DMIMO_O and optional DMIMO_Z.\n            - Optional activation grad: DZ.\n            - Cached intermediates for pass 2: STATES and QK_DOT.\n\n        Notation:\n            - Psi: MIMO X projection.\n            - Phi: MIMO O projection.\n            - Zeta: MIMO Z projection.\n            - Trap: convex-combination modulator used in exponential-trapezoidal discretization.\n        \"\"\"\n        \n        with T.Kernel(H, B, threads=threads) as (i_h, i_b):\n            # --- Kernel Setup ---\n            # GQA support: map V head to Q/K head\n            i_h_qk = i_h // (H // G)\n\n            # --- Buffer Allocation ---\n            q_shared = T.alloc_shared([fused_chunk_size, N], dtype)\n            k_shared = T.alloc_shared([fused_chunk_size, N], dtype)\n            PsiV_shared = T.alloc_shared([fused_chunk_size, P], dtype)\n            qs_shared = T.alloc_shared([fused_chunk_size, P], dtype)\n            o_shared = T.alloc_shared([chunk_size, P], dtype)\n            v_shared = T.alloc_shared([chunk_size, P], dtype)\n            states_accum_cast_shared = T.alloc_shared([N, P], dtype)\n\n            qk_dot_full_shared = T.alloc_shared([fused_chunk_size, fused_chunk_size], dtype)\n\n            # --- Output Accumulators ---\n            if reduceO:\n                dPhi_shared = T.alloc_shared([R, P], accum_dtype)\n                T.clear(dPhi_shared)\n\n 
           dout_shared = T.alloc_shared([chunk_size, P], dtype)\n\n            z_shared = T.alloc_shared([chunk_size, P], dtype)\n            dZeta_shared = T.alloc_shared([R, P], accum_dtype)\n            T.clear(dZeta_shared)\n\n            # --- Swizzling Annotation ---\n            T.annotate_layout({\n                q_shared: tilelang.layout.make_swizzled_layout(q_shared),\n                k_shared: tilelang.layout.make_swizzled_layout(k_shared),\n\n                PsiV_shared: tilelang.layout.make_swizzled_layout(PsiV_shared),\n                qs_shared: tilelang.layout.make_swizzled_layout(qs_shared),\n                o_shared: tilelang.layout.make_swizzled_layout(o_shared),\n                states_accum_cast_shared: tilelang.layout.make_swizzled_layout(states_accum_cast_shared),\n                qk_dot_full_shared: tilelang.layout.make_swizzled_layout(qk_dot_full_shared),\n                dout_shared: tilelang.layout.make_swizzled_layout(dout_shared),\n                z_shared: tilelang.layout.make_swizzled_layout(z_shared),\n\n            })\n            T.use_swizzle(10, \"row\")\n\n            T.no_set_max_nreg()\n\n            # --- Per-Head Constants / Running State ---\n            states_frag = T.alloc_fragment([N, P], accum_dtype)\n            T.clear(states_frag)\n\n            if reduceO:\n                phi_frag_intrachunk = T.alloc_fragment([R, P], dtype=dtype)\n                T.copy(MIMO_O[i_h, :, :], phi_frag_intrachunk)\n            Psi_frag = T.alloc_fragment([R, P], dtype)\n            T.copy(MIMO_V[i_h, :, :], Psi_frag)\n\n            q_bias_frag = T.alloc_fragment([R, N], dtype)\n            k_bias_frag = T.alloc_fragment([R, N], dtype)\n            T.copy(Q_BIAS[i_h, :, :], q_bias_frag)\n            T.copy(K_BIAS[i_h, :, :], k_bias_frag)\n\n            # --- Chunk Loop ---\n            for i in T.Pipelined(0, nchunks, num_stages=num_stages):\n                chunk_start = i * chunk_size\n                fused_chunk_start = chunk_start 
* R\n\n                # --- Discretization Factors (Shifted Gamma + Trap Scale) ---\n                trap_shifted_frag = T.alloc_fragment([chunk_size], T.float32)\n                dt_shifted_frag = T.alloc_fragment([chunk_size], dtype)\n                for cs in T.Parallel(chunk_size):\n                    trap_shifted_frag[cs] = T.if_then_else(\n                        chunk_start + cs + 1 < S,\n                        TRAP[i_b, i_h, chunk_start + cs + 1],\n                        0.0,\n                    )\n                    dt_shifted_frag[cs] = T.if_then_else(\n                        chunk_start + cs + 1 < S,\n                        DT[i_b, i_h, chunk_start + cs + 1],\n                        0.0,\n                    )\n                shifted_gamma_frag = T.alloc_fragment([chunk_size], dtype)\n                for cs in T.Parallel(chunk_size):\n                    shifted_gamma_frag[cs] = T.if_then_else(chunk_start + cs < (S - 1), \n                                                            dt_shifted_frag[cs] * (T.sigmoid(-trap_shifted_frag[cs])), \n                                                            0.0)\n\n                shifted_gamma_shared = T.alloc_shared([chunk_size], dtype)\n                T.copy(shifted_gamma_frag, shifted_gamma_shared)\n\n                trap_frag = T.alloc_fragment([chunk_size], T.float32)\n                T.copy(TRAP[i_b, i_h, chunk_start: chunk_start+chunk_size], trap_frag)\n                dt_frag = T.alloc_fragment([chunk_size], dtype)\n                T.copy(DT[i_b, i_h, chunk_start: chunk_start+chunk_size], dt_frag)\n                gamma_frag = T.alloc_fragment([chunk_size], T.float32)\n                for cs in T.Parallel(chunk_size):\n                    gamma_frag[cs] = dt_frag[cs] * T.sigmoid(trap_frag[cs])\n                trap_scale_frag = T.alloc_fragment([chunk_size], dtype)\n                for cs in T.Parallel(chunk_size):\n                    trap_scale_frag[cs] = gamma_frag[cs] + 
shifted_gamma_shared[cs]\n                trap_scale_shared = T.alloc_shared([chunk_size], dtype)\n                T.copy(trap_scale_frag, trap_scale_shared)\n\n                # --- Up-Project V and Prepare Biased Q/K ---\n                PsiV_frag = T.alloc_fragment([chunk_size, R, P], dtype)\n\n                T.copy(V[i_b, chunk_start:chunk_start+chunk_size, i_h, :], v_shared)\n                for cs, r, p in T.Parallel(chunk_size, R, P):\n                    PsiV_frag[cs, r, p] = v_shared[cs, p] * Psi_frag[r, p]\n                PsiV_reshaped_frag = T.view(PsiV_frag, shape=[fused_chunk_size, P])\n                T.copy(PsiV_reshaped_frag, PsiV_shared)\n\n                q_reshaped_shared = T.view(q_shared, shape=[chunk_size, R, N])\n                T.copy(Q[i_b, chunk_start:chunk_start+chunk_size, :, i_h_qk, :], q_reshaped_shared)\n                q_frag = T.alloc_fragment([chunk_size, R, N], dtype)\n                T.copy(q_reshaped_shared, q_frag)\n                for cs, r, n in T.Parallel(chunk_size, R, N):\n                    q_frag[cs, r, n] += q_bias_frag[r, n]\n                T.copy(q_frag, q_reshaped_shared)\n\n                k_reshaped_shared = T.view(k_shared, shape=[chunk_size, R, N])\n                T.copy(K[i_b, chunk_start:chunk_start+chunk_size, :, i_h_qk, :], k_reshaped_shared)\n                k_frag = T.alloc_fragment([chunk_size, R, N], dtype)\n                T.copy(k_reshaped_shared, k_frag)\n                for cs, r, n in T.Parallel(chunk_size, R, N):\n                    k_frag[cs, r, n] += k_bias_frag[r, n]\n                T.copy(k_frag, k_reshaped_shared)\n\n                # --- Cache Diagonal qk_dot Path ---\n                # Keep full qk_dot in shared memory to reuse per-step R x R blocks.\n                qk_dot_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype=accum_dtype)\n                T.gemm(q_shared, k_shared, qk_dot_frag, transpose_B=True, clear_accum=True)\n                T.copy(qk_dot_frag, 
qk_dot_full_shared)\n                # Output QK_DOT for the bwd_bwd kernel (per-time-step blocks only)\n                for cs, r_out, r_in in T.Parallel(chunk_size, R, R):\n                    QK_DOT[i_b, i_h, chunk_start + cs, r_out, r_in] = \\\n                        qk_dot_full_shared[cs * R + r_out, cs * R + r_in]\n\n                # --- Rotary Q/K + Trap Scaling ---\n                q_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                q_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    q_first_half_frag[cs, r, n] = q_shared[cs*R + r, n]\n                    q_second_half_frag[cs, r, n] = q_shared[cs*R + r, N//2 + n]\n\n                # NOTE: angles are casted to fp32 for numerical stability\n                angles_frag = T.alloc_fragment([chunk_size, N//rotary_dim_divisor], T.float32)\n                T.copy(ANGLES[i_b, chunk_start:chunk_start+chunk_size, i_h, :], angles_frag)\n\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    q_shared[cs*R + r, n] = T.cos(angles_frag[cs, n]) * q_first_half_frag[cs, r, n] - T.sin(angles_frag[cs, n]) * q_second_half_frag[cs, r, n]\n                    q_shared[cs*R + r, N//2 + n] = T.sin(angles_frag[cs, n]) * q_first_half_frag[cs, r, n] + T.cos(angles_frag[cs, n]) * q_second_half_frag[cs, r, n]\n\n                k_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                k_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    k_first_half_frag[cs, r, n] = k_shared[cs*R + r, n]\n                    k_second_half_frag[cs, r, n] = k_shared[cs*R + r, N//2 + n]\n                \n                for cs, r, n in 
T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    k_shared[cs*R + r, n] = T.cos(angles_frag[cs, n]) * k_first_half_frag[cs, r, n] - T.sin(angles_frag[cs, n]) * k_second_half_frag[cs, r, n]\n                    k_shared[cs*R + r, N//2 + n] = T.sin(angles_frag[cs, n]) * k_first_half_frag[cs, r, n] + T.cos(angles_frag[cs, n]) * k_second_half_frag[cs, r, n]\n\n                k_trap_scaled_frag = T.alloc_fragment([fused_chunk_size, N], dtype)\n                T.copy(k_shared, k_trap_scaled_frag)\n                for csr, n in T.Parallel(fused_chunk_size, N):\n                    k_trap_scaled_frag[csr, n] *= trap_scale_shared[csr//R]\n                T.copy(k_trap_scaled_frag, k_shared)\n\n                # --- Interchunk + Intrachunk Output Accumulation ---\n                q_state_out_frag = T.alloc_fragment([fused_chunk_size, P], dtype=accum_dtype)\n                # NOTE: Tilelang unable to infer correct layout when trying to cast\n                # states_frag to 16-bit within rmem\n                T.copy(states_frag, states_accum_cast_shared)\n                T.gemm(q_shared, states_accum_cast_shared, q_state_out_frag, clear_accum=True)\n\n                qk_intrachunk_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype=accum_dtype)\n                T.gemm(q_shared, k_shared, qk_intrachunk_frag, transpose_B=True, clear_accum=True)\n\n                # Strictly causal masking over chunk steps (exclude same-step diagonal).\n                da_cs__or__exp_da_cs_shared = T.alloc_shared([chunk_size], T.float32)\n                T.copy(DA_CS[i_b, i_h, chunk_start:chunk_start+chunk_size], da_cs__or__exp_da_cs_shared)       \n                for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size):\n                    qk_intrachunk_frag[csr_i, csr_j] = T.if_then_else(\n                                                csr_i//R > csr_j//R,\n                                                qk_intrachunk_frag[csr_i, csr_j] * 
T.exp(SEGSUM[i_b, i_h, i, csr_i//R, csr_j//R]),\n                                                0.0\n                                            )\n                qk_intrachunk_masked_shared = T.alloc_shared([fused_chunk_size, fused_chunk_size], dtype=dtype)\n                for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size):\n                    qk_intrachunk_masked_shared[csr_i, csr_j] = qk_intrachunk_frag[csr_i, csr_j]\n                \n                # Exponentiate da_cs__or__exp_da_cs_shared so that later usage does not have to:\n                for cs in T.Parallel(chunk_size):\n                    da_cs__or__exp_da_cs_shared[cs] = T.exp(da_cs__or__exp_da_cs_shared[cs])\n\n                exp_da_cs_frag = T.alloc_fragment([chunk_size], dtype=T.float32)\n                T.copy(da_cs__or__exp_da_cs_shared, exp_da_cs_frag)\n                for csr, p in T.Parallel(fused_chunk_size, P):\n                    q_state_out_frag[csr, p] *= exp_da_cs_frag[csr//R]\n\n                o_mimo_accum_frag = T.alloc_fragment([fused_chunk_size, P], dtype=accum_dtype)\n                T.gemm(qk_intrachunk_masked_shared, PsiV_shared, o_mimo_accum_frag, clear_accum=True)\n\n                # Merge interchunk and intrachunk contributions.\n                for cs, p in T.Parallel(fused_chunk_size, P):\n                    o_mimo_accum_frag[cs, p] += q_state_out_frag[cs, p]\n\n                # --- Add Diagonal Terms (qk_dot and optional D) ---\n                qkdot_psiv_frag = T.alloc_fragment([chunk_size, R, P], dtype=dtype)\n                T.clear(qkdot_psiv_frag)\n                for cs, r_out, p in T.Parallel(chunk_size, R, P):\n                    for r_in in T.serial(R):\n                        qkdot_psiv_frag[cs, r_out, p] += qk_dot_full_shared[cs * R + r_out, cs * R + r_in] * PsiV_shared[cs * R + r_in, p]\n                    qkdot_psiv_frag[cs, r_out, p] *= gamma_frag[cs] # Apply gamma\n                qkdot_psiv_reshaped_frag = 
T.view(qkdot_psiv_frag, shape=[fused_chunk_size, P])\n                for csr, p in T.Parallel(fused_chunk_size, P):\n                    o_mimo_accum_frag[csr, p] += qkdot_psiv_reshaped_frag[csr, p]\n\n                if hasD:\n                    D_var = T.alloc_var(T.float32)\n                    T.copy(D[i_h], D_var)\n                    PsiV_D_frag = T.alloc_fragment([fused_chunk_size, P], T.float32)\n                    T.copy(PsiV_shared, PsiV_D_frag)\n                    for csr, p in T.Parallel(fused_chunk_size, P):\n                        o_mimo_accum_frag[csr, p] += D_var * PsiV_D_frag[csr, p]\n\n                # --- Project to dMIMO_O and Optional Z Backward Path ---\n                if reduceO:\n                    out_prereduced_shared = T.alloc_shared([fused_chunk_size, P], dtype)\n                    T.copy(o_mimo_accum_frag, out_prereduced_shared)\n                    \n                    o_gated_frag = T.alloc_fragment([chunk_size, R, P], T.float32)\n                    if hasZ:\n                        # Apply Z gating to out:\n                        T.copy(Z[i_b, chunk_start:chunk_start+chunk_size, i_h, :], z_shared)\n                        z_o_frag = T.alloc_fragment([chunk_size, P], T.float32)\n                        T.copy(z_shared, z_o_frag)\n                        Zeta_o_frag = T.alloc_fragment([R, P], T.float32)\n                        T.copy(MIMO_Z[i_h, :, :], Zeta_o_frag)\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            # Apply SiLU to o_gated_frag:\n                            tmp = z_o_frag[cs, p] * Zeta_o_frag[r, p] * 0.5\n                            o_gated_frag[cs, r, p] = tmp * T.tanh(tmp) + tmp\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            o_gated_frag[cs, r, p] *= out_prereduced_shared[cs*R + r, p]\n                    else:\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                        
    o_gated_frag[cs, r, p] = out_prereduced_shared[cs*R + r, p]\n                    \n                    # NOTE: keeping dPhi_frag in fp32 for numerical reason\n                    dPhi_frag = T.alloc_fragment([R, P], T.float32)\n                    T.copy(dPhi_shared, dPhi_frag)\n                    dout_frag = T.alloc_fragment([chunk_size, P], dtype)\n                    T.copy(DOUT[i_b, chunk_start:chunk_start+chunk_size, i_h, :], dout_shared)\n                    T.copy(dout_shared, dout_frag)\n                    for r, p in T.Parallel(R, P):\n                        for cs in T.serial(chunk_size):\n                            dPhi_frag[r, p] += o_gated_frag[cs, r, p] * dout_frag[cs, p]\n                    T.copy(dPhi_frag, dPhi_shared)\n\n                    if hasZ:\n                        # Up-project DOUT from SISO to MIMO.\n                        Phi_frag = T.alloc_fragment([R, P], dtype)\n                        T.copy(MIMO_O[i_h, :, :], Phi_frag)\n                        dPhiO_frag = T.alloc_fragment([chunk_size, R, P], dtype)\n                        dout_preexpand_frag = T.alloc_fragment([chunk_size, P], dtype)\n                        T.copy(dout_shared, dout_preexpand_frag)\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            dPhiO_frag[cs, r, p] = dout_frag[cs, p] * Phi_frag[r, p]\n\n                        # NOTE: layout issue when trying to reuse o_mimo_accum_frag\n                        # NOTE: note that it uses out_prereduced_shared, which is the pre-Z-gate version\n                        # of out\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            dPhiO_frag[cs, r, p] *= out_prereduced_shared[cs*R + r, p]\n                        # Backward of SILU(z) is sigmoid(z) * (1 + z * (1 - sigmoid(z)))\n                        z_frag = T.alloc_fragment([chunk_size, P], T.float32)\n                        T.copy(z_shared, z_frag)\n                       
 Zeta_frag = T.alloc_fragment([R, P], T.float32)\n                        T.copy(MIMO_Z[i_h, :, :], Zeta_frag)\n                        dZetaZ_frag = T.alloc_fragment([chunk_size, R, P], T.float32)\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            dZetaZ_frag[cs, r, p] = z_frag[cs, p] * Zeta_frag[r, p]\n                            dZetaZ_frag[cs, r, p] = dPhiO_frag[cs, r, p]* T.sigmoid(dZetaZ_frag[cs, r, p]) * \\\n                                (1 + dZetaZ_frag[cs, r, p] * (T.sigmoid(-dZetaZ_frag[cs, r, p])))\n\n                        dZ_frag = T.alloc_fragment([chunk_size, P], dtype)\n                        T.clear(dZ_frag)\n                        for cs, p in T.Parallel(chunk_size, P):\n                            for r in T.serial(R):\n                                dZ_frag[cs, p] += dZetaZ_frag[cs, r, p] * Zeta_frag[r, p]\n                        T.copy(dZ_frag, DZ[i_b, chunk_start:chunk_start+chunk_size, i_h, :])\n\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            dZetaZ_frag[cs, r, p] *= z_frag[cs, p]\n                        dZeta_frag = T.alloc_fragment([R, P], T.float32)\n                        T.copy(dZeta_shared, dZeta_frag)\n                        T.reduce_sum(dZetaZ_frag, dZeta_frag, clear=False, dim=0)\n                        T.copy(dZeta_frag, dZeta_shared)\n                else:\n                    if hasZ:\n                        out_prereduced_shared = T.alloc_shared([fused_chunk_size, P], dtype)\n                        T.copy(o_mimo_accum_frag, out_prereduced_shared)\n                        T.copy(Z[i_b, chunk_start:chunk_start+chunk_size, i_h, :], z_shared)\n                        dPhiO_frag = T.alloc_fragment([chunk_size, R, P], dtype)\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            dPhiO_frag[cs, r, p] = DOUT[i_b, chunk_start + cs, r, i_h, p]\n                        for cs, r, p 
in T.Parallel(chunk_size, R, P):\n                            dPhiO_frag[cs, r, p] *= out_prereduced_shared[cs*R + r, p]\n                        # Backward of SILU(z) is sigmoid(z) * (1 + z * (1 - sigmoid(z)))\n                        z_frag = T.alloc_fragment([chunk_size, P], T.float32)\n                        T.copy(z_shared, z_frag)\n                        Zeta_frag = T.alloc_fragment([R, P], T.float32)\n                        T.copy(MIMO_Z[i_h, :, :], Zeta_frag)\n                        dZetaZ_frag = T.alloc_fragment([chunk_size, R, P], T.float32)\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            dZetaZ_frag[cs, r, p] = z_frag[cs, p] * Zeta_frag[r, p]\n                            dZetaZ_frag[cs, r, p] = dPhiO_frag[cs, r, p]* T.sigmoid(dZetaZ_frag[cs, r, p]) * \\\n                                (1 + dZetaZ_frag[cs, r, p] * (T.sigmoid(-dZetaZ_frag[cs, r, p])))\n                        ## Compute DZ\n                        dZ_frag = T.alloc_fragment([chunk_size, P], dtype)\n                        T.clear(dZ_frag)\n                        for cs, p in T.Parallel(chunk_size, P):\n                            for r in T.serial(R):\n                                dZ_frag[cs, p] += dZetaZ_frag[cs, r, p] * Zeta_frag[r, p]\n                        T.copy(dZ_frag, DZ[i_b, chunk_start:chunk_start+chunk_size, i_h, :])\n                        ## Compute DMIMO_Z\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            dZetaZ_frag[cs, r, p] *= z_frag[cs, p]\n                        dZeta_frag = T.alloc_fragment([R, P], T.float32)\n                        T.copy(dZeta_shared, dZeta_frag)\n                        T.reduce_sum(dZetaZ_frag, dZeta_frag, clear=False, dim=0)\n                        T.copy(dZeta_frag, dZeta_shared)\n\n                # --- Save and Update Recurrent State ---\n                T.copy(states_frag, STATES[i_b, i_h, i, :, :])\n\n                # DA_CS_REV 
scales stepwise K contribution into the new state.\n                dA_cs_rev_frag = T.alloc_fragment([chunk_size], T.float32)\n                T.copy(DA_CS_REV[i_b, i_h, chunk_start:chunk_start+chunk_size], dA_cs_rev_frag)\n                # NOTE: we can recycle k_trap_scaled_frag from earlier, however,\n                # that is slower, so choose to recopy from smem:\n                k_state_frag = T.alloc_fragment([fused_chunk_size, N], dtype)\n                T.copy(k_shared, k_state_frag)\n                for csr, n in T.Parallel(fused_chunk_size, N):\n                    k_state_frag[csr, n] *= T.exp(dA_cs_rev_frag[csr//R])\n\n                # DA_CS(last) applies chunk-level decay to the carried state.\n                da_cs_sum = T.alloc_var(T.float32)\n                T.copy(DA_CS[i_b, i_h, chunk_start+chunk_size-1], da_cs_sum)\n                for n, p in T.Parallel(N, P):\n                    states_frag[n, p] *= T.exp(da_cs_sum)\n                T.gemm(k_state_frag, PsiV_shared, states_frag, transpose_A=True, clear_accum=False)\n            \n            if reduceO:\n                T.copy(dPhi_shared, DMIMO_O[i_b, i_h, :, :])\n            if hasZ:\n                T.copy(dZeta_shared, DMIMO_Z[i_b, i_h, :, :])\n\n    return mamba_mimo_bwd_fwd_kernel\n\n# def get_configs():\n#     iter_params = dict(num_stages=[0], threads=[128, 256])\n#     return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]\n\n# @autotune(\n#     configs=get_configs(),\n#     warmup=3,\n#     rep=20,\n# )\n@tilelang.jit(\n    out_idx=[],\n    pass_configs={\n        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,\n        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,\n        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,\n    })\ndef mamba_mimo_bwd_bwd(\n    B,\n    S,\n    H,\n    G,\n    N,\n    P,\n    R,\n    hasZ,\n    hasD,\n    reduceO,\n    chunk_size: int = 16,\n    rotary_dim_divisor: int = 4,\n    dtype: str = 
'float16',\n    threads: int = 256,\n    num_stages: int = 0,\n) -> torch.Tensor:\n\n    accum_dtype = 'float32'\n\n    nchunks = tilelang.cdiv(S, chunk_size)\n    fused_chunk_size = chunk_size * R\n\n    if reduceO:\n        DOUT_shape = (B, S, H, P)\n    else:\n        DOUT_shape = (B, S, R, H, P)\n\n    @T.prim_func\n    def mamba_mimo_bwd_bwd_kernel(\n            DOUT: T.Tensor(DOUT_shape, dtype),  # type: ignore\n            Q: T.Tensor([B, S, R, G, N], dtype),  # type: ignore\n            K: T.Tensor([B, S, R, G, N], dtype),  # type: ignore\n            V: T.Tensor([B, S, H, P], dtype),  # type: ignore\n            Q_BIAS: T.Tensor([H, R, N], T.float32),  # type: ignore\n            K_BIAS: T.Tensor([H, R, N], T.float32),  # type: ignore\n            MIMO_V: T.Tensor([H, R, P], T.float32), # type: ignore\n            MIMO_O: T.Tensor([H, R, P], T.float32), # type: ignore\n            DK: T.Tensor([B, S*R, H, N], dtype),  # type: ignore\n            DV: T.Tensor([B, S, H, P], dtype),  # type: ignore\n            DMIMO_V: T.Tensor([B, H, R, P], T.float32), # type: ignore\n            STATES: T.Tensor([B, H, nchunks, N, P], dtype), # type: ignore \n            DQ: T.Tensor([B, S*R, H, N], dtype),  # type: ignore\n\n            Z: T.Tensor([B, S, H, P], dtype),  # type: ignore\n            MIMO_Z: T.Tensor([H, R, P], T.float32), # type: ignore\n            ANGLES: T.Tensor([B, S, H, N//rotary_dim_divisor], T.float32), # type: ignore\n            DA_CS: T.Tensor([B, H, S], T.float32), # type: ignore\n            DA_CS_REV: T.Tensor([B, H, S], T.float32), # type: ignore\n            DT: T.Tensor([B, H, S], T.float32), # type: ignore\n            TRAP: T.Tensor([B, H, S], dtype), # type: ignore\n            DFACTOR: T.Tensor([B, H, S], T.float32), # type: ignore\n            DGAMMA_DIAG: T.Tensor([B, H, S], T.float32), # type: ignore\n            DANGLES: T.Tensor([B, S, H, N//rotary_dim_divisor], T.float32), # type: ignore\n            D: T.Tensor([H], T.float32), 
# type: ignore\n            DD: T.Tensor([B, H], T.float32), # type: ignore\n\n            QK_DOT: T.Tensor([B, H, S, R, R], dtype), # type: ignore\n            # DQK_DOT: T.Tensor([B, H, S, R, R], dtype), # type: ignore\n            DDA: T.Tensor([B, H, S], T.float32), # type: ignore\n            DSSDA: T.Tensor([B, H, nchunks, chunk_size, chunk_size], T.float32), # type: ignore\n            DDA_CS_REV: T.Tensor([B, H, S], T.float32), # type: ignore\n            DDA_CS: T.Tensor([B, H, S], T.float32), # type: ignore\n\n            SEGSUM: T.Tensor([B, H, nchunks, chunk_size, chunk_size], T.float32), # type: ignore\n            ):\n        \"\"\"\n        Overview:\n            Reverse-chunk backward pass that consumes cached STATES and QK_DOT from the first pass\n            to produce gradients for the fused Mamba3 attention block.\n\n        Inputs:\n            - Forward activations/tensors: DOUT, Q, K, V, optional Z, optional D.\n            - Projection weights/biases: Q_BIAS, K_BIAS, MIMO_V (Psi), MIMO_O (Phi), optional MIMO_Z (Zeta).\n            - Cached intermediates: STATES and QK_DOT.\n            - Discretization grads and factors:\n              DA_CS, DA_CS_REV, DT, TRAP, DDA, DSSDA, DDA_CS_REV, DDA_CS, and SEGSUM.\n\n        Outputs:\n            - QKV grads: DQ, DK, DV.\n            - MIMO projection grads: DMIMO_V.\n            - Discretization/rotation grads: DANGLES, DFACTOR, DGAMMA_DIAG, DDA_CS_REV, DDA_CS, DDA.\n            - Additional grads: optional DD.\n\n        Notation:\n            - Psi: MIMO X projection.\n            - Phi: MIMO O projection.\n            - Zeta: MIMO Z projection.\n            - Trap: convex-combination modulator used in exponential-trapezoidal discretization.\n        \"\"\"\n        \n        with T.Kernel(H, B, threads=threads) as (i_h, i_b):\n            # --- Kernel Setup ---\n            # GQA support: map V head to Q/K head\n            i_h_qk = i_h // (H // G)\n\n            # --- Buffer Allocation ---\n    
        dstates_shared = T.alloc_shared([N, P], dtype)\n            dstates_frag = T.alloc_fragment([N, P], accum_dtype)\n\n            dout_shared = T.alloc_shared([chunk_size, P], dtype)\n            dPhiO_shared = T.alloc_shared([fused_chunk_size, P], dtype)\n\n            q_shared = T.alloc_shared([fused_chunk_size, N], dtype)\n\n            k_shared = T.alloc_shared([fused_chunk_size, N], dtype)\n            v_shared = T.alloc_shared([chunk_size, P], dtype)\n\n            states_shared = T.alloc_shared([N, P], dtype)\n            lkq_masked__or__dkq_masked_shared = T.alloc_shared([fused_chunk_size, fused_chunk_size], dtype)\n\n            dPsiV_combined_shared = T.alloc_shared([fused_chunk_size, P], dtype)\n\n            dqk_from_diag_shared = T.alloc_shared([fused_chunk_size, fused_chunk_size], accum_dtype)\n\n            q_pre_rot_shared = T.alloc_shared([fused_chunk_size, N], dtype)\n            k_pre_rot_shared = T.alloc_shared([fused_chunk_size, N], dtype)\n\n            dk_shared = T.alloc_shared([fused_chunk_size, N], dtype)\n            dq_shared = T.alloc_shared([fused_chunk_size, N], dtype)\n\n            qk_dot_shared = T.alloc_shared([chunk_size, R, R], dtype)\n\n            k_pre_trap_shared = T.alloc_shared([fused_chunk_size, N], dtype)\n\n            dangle_dk__or__dq_shared = T.alloc_shared([fused_chunk_size, N//rotary_dim_divisor], T.float32)\n\n            # --- Swizzling Annotation ---\n            noswizzle_annot = threads == 256 and (N <= 32 or P >= 128) # NOTE: heuristics for when swizzling annotation causes kernel hang, needs more investigation\n            if not noswizzle_annot:\n                T.annotate_layout({\n                    dstates_shared: tilelang.layout.make_swizzled_layout(dstates_shared),\n                    dout_shared: tilelang.layout.make_swizzled_layout(dout_shared),\n                    q_shared: tilelang.layout.make_swizzled_layout(q_shared),\n\n                    k_shared: 
tilelang.layout.make_swizzled_layout(k_shared),\n                    v_shared: tilelang.layout.make_swizzled_layout(v_shared),\n                    states_shared: tilelang.layout.make_swizzled_layout(states_shared),\n                    lkq_masked__or__dkq_masked_shared: tilelang.layout.make_swizzled_layout(lkq_masked__or__dkq_masked_shared),\n\n                    dPsiV_combined_shared: tilelang.layout.make_swizzled_layout(dPsiV_combined_shared),\n                    dqk_from_diag_shared: tilelang.layout.make_swizzled_layout(dqk_from_diag_shared),\n\n                    k_pre_rot_shared: tilelang.layout.make_swizzled_layout(k_pre_rot_shared),\n                    q_pre_rot_shared: tilelang.layout.make_swizzled_layout(q_pre_rot_shared),\n\n                    dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),\n                    dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),\n\n                    k_pre_trap_shared: tilelang.layout.make_swizzled_layout(k_pre_trap_shared),\n                    dangle_dk__or__dq_shared: tilelang.layout.make_swizzled_layout(dangle_dk__or__dq_shared),\n                })\n            T.use_swizzle(10, \"row\")\n            T.no_set_max_nreg()\n\n            # --- Per-Head Constants / Running State ---\n            T.clear(dstates_frag)\n            T.clear(dstates_shared)\n\n            if reduceO:\n                Phi_frag = T.alloc_fragment([R, P], dtype)\n                T.copy(MIMO_O[i_h, :, :], Phi_frag)\n            Psi_frag = T.alloc_fragment([R, P], dtype)\n            T.copy(MIMO_V[i_h, :, :], Psi_frag)\n\n            dPsi_acc = T.alloc_fragment([R, P], accum_dtype) # TODO\n            T.clear(dPsi_acc)\n\n            if hasD:\n                dD_frag = T.alloc_fragment([1], accum_dtype)\n                T.clear(dD_frag)\n\n            q_bias_frag = T.alloc_fragment([R, N], dtype)\n            k_bias_frag = T.alloc_fragment([R, N], dtype)\n            T.copy(Q_BIAS[i_h, :, :], q_bias_frag)\n            
T.copy(K_BIAS[i_h, :, :], k_bias_frag)\n\n            # --- Reverse Chunk Loop ---\n            for chunk_idx_rev in T.Pipelined(0, nchunks, num_stages=num_stages):\n                chunk_idx = nchunks - 1 - chunk_idx_rev\n                chunk_start = chunk_idx * chunk_size\n                fused_chunk_start = chunk_start * R\n\n                # --- Discretization Factors (Shifted Gamma + Trap Scale) ---\n                trap_shifted_frag = T.alloc_fragment([chunk_size], T.float32)\n                dt_shifted_frag = T.alloc_fragment([chunk_size], dtype)\n                for cs in T.Parallel(chunk_size):\n                    trap_shifted_frag[cs] = T.if_then_else(\n                        chunk_start + cs + 1 < S,\n                        TRAP[i_b, i_h, chunk_start + cs + 1],\n                        0.0,\n                    )\n                    dt_shifted_frag[cs] = T.if_then_else(\n                        chunk_start + cs + 1 < S,\n                        DT[i_b, i_h, chunk_start + cs + 1],\n                        0.0,\n                    )\n                shifted_gamma_frag = T.alloc_fragment([chunk_size], dtype)\n                for cs in T.Parallel(chunk_size):\n                    shifted_gamma_frag[cs] = T.if_then_else(chunk_start + cs < (S - 1), \n                                                            dt_shifted_frag[cs] * T.sigmoid(-trap_shifted_frag[cs]), \n                                                            0.0)\n\n                trap_frag = T.alloc_fragment([chunk_size], T.float32)\n                T.copy(TRAP[i_b, i_h, chunk_start: chunk_start+chunk_size], trap_frag)\n                dt_frag = T.alloc_fragment([chunk_size], dtype)\n                T.copy(DT[i_b, i_h, chunk_start: chunk_start+chunk_size], dt_frag)\n                gamma_frag = T.alloc_fragment([chunk_size], T.float32)\n                for cs in T.Parallel(chunk_size):\n                    gamma_frag[cs] = dt_frag[cs] * T.sigmoid(trap_frag[cs])\n                
gamma_cached_frag = T.alloc_fragment([chunk_size], T.float32)\n                T.copy(gamma_frag, gamma_cached_frag)\n                trap_scale_frag = T.alloc_fragment([chunk_size], dtype)\n                for cs in T.Parallel(chunk_size):\n                    trap_scale_frag[cs] = gamma_frag[cs] + shifted_gamma_frag[cs]\n                trap_scale_shared = T.alloc_shared([chunk_size], dtype)\n                T.copy(trap_scale_frag, trap_scale_shared)\n\n                # --- DOUT Projection and Optional Z / D Paths ---\n                dPhiO_frag = T.alloc_fragment([chunk_size, R, P], dtype)\n                if reduceO:\n                    for cs, p in T.Parallel(chunk_size, P):\n                        dout_shared[cs, p] = DOUT[i_b, chunk_start+cs, i_h, p]\n                    for cs, r, p in T.Parallel(chunk_size, R, P):\n                        dPhiO_frag[cs, r, p] = dout_shared[cs, p] * Phi_frag[r, p]\n                else:\n                    for cs, r, p in T.Parallel(chunk_size, R, P):\n                        dPhiO_frag[cs, r, p] = DOUT[i_b, chunk_start + cs, r, i_h, p]\n\n                if hasZ:\n                    ## Backpropagate via *SILU(Z)\n                    Zeta_frag = T.alloc_fragment([R, P], dtype)\n                    T.copy(MIMO_Z[i_h, :, :], Zeta_frag)\n                    z_frag = T.alloc_fragment([chunk_size, P], dtype)\n                    T.copy(Z[i_b, chunk_start:chunk_start+chunk_size, i_h, :], z_frag)\n                    for cs, r, p in T.Parallel(chunk_size, R, P):\n                        tmp = z_frag[cs, p] * Zeta_frag[r, p] * 0.5\n                        dPhiO_frag[cs, r, p] *= tmp * T.tanh(tmp) + tmp\n                T.copy(T.view(dPhiO_frag, shape=[fused_chunk_size, P]), dPhiO_shared)\n\n                T.copy(V[i_b, chunk_start:chunk_start+chunk_size, i_h, :], v_shared)\n                if hasD:\n                    # Compute dD via projected DOUT and V/Psi factors.\n                    v_dD_frag =  
T.alloc_fragment([chunk_size, P], accum_dtype)\n                    Psi_dD_frag = T.alloc_fragment([R, P], accum_dtype)\n                    T.copy(v_shared, v_dD_frag)\n                    T.copy(MIMO_V[i_h, :, :], Psi_dD_frag)\n                    for cs, r, p in T.Parallel(chunk_size, R, P):\n                        dPhiO_frag[cs, r, p] *= v_dD_frag[cs, p] * Psi_dD_frag[r, p]\n                    T.reduce_sum(T.view(dPhiO_frag, shape=[fused_chunk_size*P]), dD_frag, clear=False)\n\n                # --- Prepare Rotated/Scaled QK and Compute dPsiV ---\n                # Load q and apply q_bias to it:\n                for cs, r, n in T.Parallel(chunk_size, R, N):\n                    q_shared[cs*R + r, n] = Q[i_b, chunk_start+cs, r, i_h_qk, n]\n                \n                q_frag = T.alloc_fragment([chunk_size, R, N], dtype)\n                for cs, r, n in T.Parallel(chunk_size, R, N):\n                    q_frag[cs, r, n] = q_shared[cs*R + r, n]\n                for cs, r, n in T.Parallel(chunk_size, R, N):\n                    q_frag[cs, r, n] += q_bias_frag[r, n]\n                for cs, r, n in T.Parallel(chunk_size, R, N):\n                    q_shared[cs*R + r, n] = q_frag[cs, r, n]\n                T.copy(q_shared, q_pre_rot_shared) # Save pre-rotated q for later:\n                # Apply rotary to q:\n                q_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                q_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    q_first_half_frag[cs, r, n] = q_shared[cs*R + r, n]\n                    q_second_half_frag[cs, r, n] = q_shared[cs*R + r, N//2 + n]\n                angles_frag = T.alloc_fragment([chunk_size, N//rotary_dim_divisor], T.float32)\n                T.copy(ANGLES[i_b, chunk_start:chunk_start+chunk_size, i_h, :], angles_frag)\n                for cs, r, n in 
T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    q_shared[cs*R + r, n] = T.cos(angles_frag[cs, n]) * q_first_half_frag[cs, r, n] - T.sin(angles_frag[cs, n]) * q_second_half_frag[cs, r, n]\n                    q_shared[cs*R + r, N//2 + n] = T.sin(angles_frag[cs, n]) * q_first_half_frag[cs, r, n] + T.cos(angles_frag[cs, n]) * q_second_half_frag[cs, r, n]\n\n                # Load k and apply k_bias to it:\n                k_reshaped_shared = T.view(k_pre_trap_shared, shape=[chunk_size, R, N])\n                T.copy(K[i_b, chunk_start:chunk_start+chunk_size, :, i_h_qk, :], k_reshaped_shared)\n                k_frag = T.alloc_fragment([chunk_size, R, N], dtype)\n                T.copy(k_reshaped_shared, k_frag)\n                for cs, r, n in T.Parallel(chunk_size, R, N):\n                    k_frag[cs, r, n] += k_bias_frag[r, n]\n                T.copy(k_frag, k_reshaped_shared)\n                # Save pre-rotated k for later:\n                for csr, n in T.Parallel(fused_chunk_size, N):\n                    k_pre_rot_shared[csr, n] = k_pre_trap_shared[csr, n]\n                # Apply rotary to k:\n                k_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                k_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    k_first_half_frag[cs, r, n] = k_reshaped_shared[cs, r, n]\n                    k_second_half_frag[cs, r, n] = k_reshaped_shared[cs, r, N//2 + n]\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    k_reshaped_shared[cs, r, n] = T.cos(angles_frag[cs, n]) * k_first_half_frag[cs, r, n] - T.sin(angles_frag[cs, n]) * k_second_half_frag[cs, r, n]\n                    k_reshaped_shared[cs, r, N//2 + n] = T.sin(angles_frag[cs, n]) * k_first_half_frag[cs, r, n] + T.cos(angles_frag[cs, n]) * 
k_second_half_frag[cs, r, n]\n                # Apply Trap-specific scaling:\n                k_trap_scaled_frag = T.alloc_fragment([fused_chunk_size, N], dtype)\n                T.copy(k_pre_trap_shared, k_trap_scaled_frag)\n                for csr, n in T.Parallel(fused_chunk_size, N):\n                    k_trap_scaled_frag[csr, n] *= trap_scale_shared[csr//R]\n                T.copy(k_trap_scaled_frag, k_shared)\n\n                # Apply the effect of interchunk (state update):\n                dPsiV_frag = T.alloc_fragment([fused_chunk_size, P], accum_dtype)\n                T.gemm(k_shared, dstates_shared, dPsiV_frag, clear_accum=True)\n                dA_cs_rev_frag = T.alloc_fragment([chunk_size], T.float32)\n                dA_cs_rev_shared = T.alloc_shared([chunk_size], T.float32)\n                T.copy(DA_CS_REV[i_b, i_h, chunk_start:chunk_start+chunk_size], dA_cs_rev_shared)\n                T.copy(dA_cs_rev_shared, dA_cs_rev_frag)\n                for csr, p in T.Parallel(fused_chunk_size, P):\n                    # DA_CS_REV scales per-step state contribution into dPsiV.\n                    dPsiV_frag[csr, p] *= T.exp(dA_cs_rev_frag[csr//R])\n\n                # Apply the effect of intrachunk:\n                lkq_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], accum_dtype)\n                T.gemm(k_shared, q_shared, lkq_frag, transpose_B=True, clear_accum=True)\n                T.copy(lkq_frag, lkq_masked__or__dkq_masked_shared) # NOTE: Save later for the computation of DSSDA, using lkq_masked__or__dkq_masked_shared which has the same shape\n                if R == 1: # More smem efficient which is necessary for R=1, but slower due to the need for casting\n                    lkq_masked_dtype_buf = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype)\n                    T.copy(lkq_masked__or__dkq_masked_shared, lkq_masked_dtype_buf)\n                    for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size):\n   
                     # Reverse-causal mask for backward flow across chunk steps.\n                        lkq_masked_dtype_buf[csr_i, csr_j] = T.if_then_else(\n                            csr_i//R < csr_j//R,\n                            lkq_masked_dtype_buf[csr_i, csr_j]\n                            * T.exp(SEGSUM[i_b, i_h, chunk_idx, csr_j//R, csr_i//R]),\n                            0.0\n                        )\n                else:\n                    for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size):\n                        # Reverse-causal mask for backward flow across chunk steps.\n                        lkq_frag[csr_i, csr_j] = T.if_then_else(\n                            csr_i//R < csr_j//R,\n                            lkq_frag[csr_i, csr_j]\n                            * T.exp(SEGSUM[i_b, i_h, chunk_idx, csr_j//R, csr_i//R]),\n                            0.0\n                        )\n                    lkq_masked_dtype_buf = T.alloc_shared([fused_chunk_size, fused_chunk_size], dtype)\n                    T.copy(lkq_frag, lkq_masked_dtype_buf) # Convert to dtype\n                T.gemm(lkq_masked_dtype_buf, dPhiO_shared, dPsiV_frag, clear_accum=False)\n\n                # --- Add Diagonal Contributions to dPsiV (D and qk_dot) ---\n                dPsiV_D_fused_frag = T.alloc_fragment([fused_chunk_size, P], accum_dtype)\n                if hasD:\n                    D_frag = T.alloc_var(T.float32)\n                    T.copy(D[i_h], D_frag)\n                    for csr, p in T.Parallel(fused_chunk_size, P):\n                        dPsiV_D_fused_frag[csr, p] = dPsiV_frag[csr, p] + dPhiO_shared[csr, p]*D_frag\n                else:\n                    T.copy(dPsiV_frag, dPsiV_D_fused_frag)\n                # Compute the contribution from the qk_dot term:\n                # NOTE: recomputing qk_dot here is much slower than just loading from\n                # the result of the bwd_fwd kernel\n                qk_dot_frag = 
T.alloc_fragment([chunk_size, R, R], dtype)\n                T.copy(QK_DOT[i_b, i_h, chunk_start:chunk_start+chunk_size, :, :], qk_dot_shared)\n                T.copy(qk_dot_shared, qk_dot_frag)\n                gamma_dPsiV_frag = T.alloc_fragment([chunk_size], dtype)\n                T.copy(gamma_frag, gamma_dPsiV_frag)\n                for csr, p in T.Parallel(fused_chunk_size, P):\n                    cs = csr // R\n                    r_in = csr % R\n                    for r_out in T.serial(R):\n                        csr_out = cs * R + r_out\n                        dPsiV_D_fused_frag[csr, p] += dPhiO_shared[csr_out, p] * qk_dot_frag[cs, r_out, r_in] * gamma_dPsiV_frag[cs]\n                T.copy(dPsiV_D_fused_frag, dPsiV_combined_shared)\n\n                # --- Compute dV and dPsi from dPsiV ---\n                # Compute dV\n                dv_frag = T.alloc_fragment([chunk_size, P], dtype)\n                T.clear(dv_frag)\n                for cs, p in T.Parallel(chunk_size, P):\n                    for r in T.serial(R):\n                        dv_frag[cs, p] += dPsiV_combined_shared[cs*R + r, p] * Psi_frag[r, p]\n                T.copy(dv_frag, DV[i_b, chunk_start:chunk_start+chunk_size, i_h, :])\n\n                dPsi_frag = T.alloc_fragment([R, P], accum_dtype)\n                T.copy(dPsi_acc, dPsi_frag)\n                v_frag = T.alloc_fragment([chunk_size, P], accum_dtype)\n                T.copy(v_shared, v_frag)\n                for r, p in T.Parallel(R, P):\n                    for cs in T.serial(chunk_size):\n                        dPsi_frag[r, p] += dPsiV_combined_shared[cs*R + r, p] * v_frag[cs, p]\n                T.copy(dPsi_frag, dPsi_acc)\n\n                # Compute Psi_V\n                PsiV_frag = T.alloc_fragment([chunk_size, R, P], dtype)\n                T.clear(PsiV_frag)\n                for cs, p in T.Parallel(chunk_size, P):\n                    for r in T.serial(R):\n                        PsiV_frag[cs, r, p] += 
v_frag[cs, p] * Psi_frag[r, p]\n                # NOTE: TileLang is unable to perform gemm with the reshaped PsiV_frag,\n                # so we have to copy to smem\n                PsiV_shared = T.alloc_shared([fused_chunk_size, P], dtype)\n                for cs, r, p in T.Parallel(chunk_size, R, P):\n                    PsiV_shared[cs*R + r, p] = PsiV_frag[cs, r, p]\n\n                # Compute dqk_from_diag, which is the contribution to dQ/dK from qk_dot:\n                dqk_from_diag_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], accum_dtype)\n                T.gemm(dPhiO_shared, PsiV_shared, dqk_from_diag_frag, transpose_B=True, clear_accum=True) # (cs*r_out, cs*r_in)\n                # Compute dgamma_diag\n                dgamma_diag_prereduce_frag = T.alloc_fragment([chunk_size, R, R], accum_dtype)\n                T.copy(qk_dot_shared, dgamma_diag_prereduce_frag)\n                T.copy(dqk_from_diag_frag, dqk_from_diag_shared)\n                for cs, r_out, r_in in T.Parallel(chunk_size, R, R):\n                    dgamma_diag_prereduce_frag[cs, r_out, r_in] *= dqk_from_diag_shared[cs*R + r_out, cs*R + r_in]\n\n                dgamma_diag_reduced_frag = T.alloc_fragment([chunk_size], accum_dtype)\n                T.reduce_sum(\n                    T.view(dgamma_diag_prereduce_frag, shape=[chunk_size, R*R]),\n                    dgamma_diag_reduced_frag,\n                    dim=-1,\n                    clear=True\n                    )\n                T.copy(dgamma_diag_reduced_frag, DGAMMA_DIAG[i_b, i_h, chunk_start:chunk_start+chunk_size])\n                # Scale dqk_from_diag by the cached (unshifted) gamma:\n                gamma_qk_frag = T.alloc_fragment([chunk_size], accum_dtype)\n                T.copy(gamma_cached_frag, gamma_qk_frag) # gamma_cached_frag holds dt * sigmoid(trap), not shifted_gamma\n                for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size):\n                    dqk_from_diag_frag[csr_i, csr_j] *= gamma_qk_frag[csr_i//R]\n                
T.copy(dqk_from_diag_frag, dqk_from_diag_shared)\n\n                # --- dK Path + ddA Terms ---\n                dk_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype)\n                T.gemm(PsiV_shared, dstates_shared, dk_frag, transpose_B=True, clear_accum=True)\n\n                # Compute contribution to ddA from KV part of state update (part 1 of 4)\n                ddA_state_kv_prereduce_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype)\n                T.copy(k_shared, ddA_state_kv_prereduce_frag)\n                for csr, n in T.Parallel(fused_chunk_size, N):\n                    ddA_state_kv_prereduce_frag[csr, n] *= dk_frag[csr, n]\n                ddA_state_kv_prereduce_frag_reshaped = T.view(ddA_state_kv_prereduce_frag, shape=[chunk_size, R*N])\n                ddA_state_kv_frag = T.alloc_fragment([chunk_size], accum_dtype)\n                T.reduce_sum(ddA_state_kv_prereduce_frag_reshaped, ddA_state_kv_frag, dim=-1, clear=True)\n                T.copy(ddA_state_kv_frag, DDA_CS_REV[i_b, i_h, chunk_start:chunk_start+chunk_size])\n\n                # Interchunk path uses k_scaled * exp(dA_cs_rev) in forward,\n                # so apply exp(dA_cs_rev) to the interchunk dk term only.\n                dA_cs_rev_dk_frag = T.alloc_fragment([chunk_size], T.float32)\n                T.copy(dA_cs_rev_shared, dA_cs_rev_dk_frag)\n                for cs in T.Parallel(chunk_size):\n                    dA_cs_rev_dk_frag[cs] = T.exp(dA_cs_rev_dk_frag[cs])\n                for csr, n in T.Parallel(fused_chunk_size, N):\n                    dk_frag[csr, n] *= dA_cs_rev_dk_frag[csr//R]\n\n                dk_intrachunk_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], accum_dtype)\n                T.gemm(PsiV_shared, dPhiO_shared, dk_intrachunk_frag, transpose_B=True, clear_accum=True)\n\n                # Compute contribution to ddA from intrachunk (part 2 of 4)\n                kq_frag = T.alloc_fragment([fused_chunk_size, 
fused_chunk_size], dtype)\n                T.copy(lkq_masked__or__dkq_masked_shared, kq_frag)\n                for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size):\n                    kq_frag[csr_i, csr_j] *= dk_intrachunk_frag[csr_i, csr_j]\n                kq_frag_reshaped = T.view(kq_frag, shape=[fused_chunk_size, chunk_size, R])\n                interchunk_dda_prereduce_frag = T.alloc_fragment([fused_chunk_size, chunk_size], accum_dtype)\n                T.reduce_sum(kq_frag_reshaped, interchunk_dda_prereduce_frag, dim=-1, clear=True)\n                interchunk_dda_prereduce_frag_reshaped = T.view(interchunk_dda_prereduce_frag, shape=[chunk_size, R, chunk_size])\n                interchunk_dda_frag = T.alloc_fragment([chunk_size, chunk_size], accum_dtype)\n                T.reduce_sum(interchunk_dda_prereduce_frag_reshaped, interchunk_dda_frag, dim=1, clear=True)\n                T.copy(interchunk_dda_frag, DSSDA[i_b, i_h, chunk_idx, :, :])\n\n                for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size):\n                    # Reverse-causal mask for intrachunk gradient flow.\n                    dk_intrachunk_frag[csr_i, csr_j] = T.if_then_else(\n                        csr_i//R < csr_j//R,\n                        dk_intrachunk_frag[csr_i, csr_j]\n                        * T.exp(SEGSUM[i_b, i_h, chunk_idx, csr_j//R, csr_i//R]),\n                        0.0\n                    )\n\n                T.copy(dk_intrachunk_frag, lkq_masked__or__dkq_masked_shared) # denote lkq_masked__or__dkq_masked_shared as dkq_intrachunk\n                T.copy(dk_frag, dk_shared)\n                dk_nodiag_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype)\n                T.copy(dk_shared, dk_nodiag_frag)\n                T.gemm(lkq_masked__or__dkq_masked_shared, q_shared, dk_nodiag_frag, clear_accum=False) # Adding dk_interchunk to dkq_intrachunk @ q\n                # Compute dfactor, using dk_nodiag_frag:\n                
k_factor_frag = T.alloc_fragment([chunk_size, R, N], accum_dtype)\n                T.copy(k_pre_trap_shared, T.view(k_factor_frag, shape=[fused_chunk_size, N]))\n                dfactor_prereduce_frag = T.alloc_fragment([chunk_size, R, N], accum_dtype)\n                for cs, r, n in T.Parallel(chunk_size, R, N):\n                    dfactor_prereduce_frag[cs, r, n] = k_factor_frag[cs, r, n] * dk_nodiag_frag[cs*R + r, n]\n                dfactor_frag = T.alloc_fragment([chunk_size], accum_dtype)\n                T.reduce_sum(T.view(dfactor_prereduce_frag, shape=[chunk_size, R*N]), dfactor_frag, dim=-1, clear=True)\n                T.copy(dfactor_frag, DFACTOR[i_b, i_h, chunk_start:chunk_start+chunk_size])\n                # Account for the effect of trap_scale = gamma + shifted_gamma:\n                trap_scale_dk_frag = T.alloc_fragment([chunk_size], dtype)\n                T.copy(trap_scale_shared, trap_scale_dk_frag)\n                for csr, n in T.Parallel(fused_chunk_size, N):\n                    dk_nodiag_frag[csr, n] *= trap_scale_dk_frag[csr//R]\n                T.copy(dk_nodiag_frag, dk_shared)\n\n                # --- State-Passing ddA Terms + Interchunk dQ ---\n                T.copy(STATES[i_b, i_h, chunk_idx, :, :], states_shared) # Load cached states from bwd_fwd\n                # NOTE: Compute the contribution of state passing (part 3 of 4)\n                states_frag = T.alloc_fragment([N, P], T.float32)\n                T.copy(states_shared, states_frag)\n                ddA_state_passing = T.alloc_fragment([1], T.float32)\n                ddA_state_passing_prereduce_frag = T.alloc_fragment([N, P], T.float32)\n                da_cs_sum = T.alloc_var(T.float32)\n                T.copy(DA_CS[i_b, i_h, chunk_start+chunk_size-1], da_cs_sum)\n                for n, p in T.Parallel(N, P):\n                    ddA_state_passing_prereduce_frag[n, p] = (\n                        states_frag[n, p] \n                        * dstates_frag[n, p] \n       
                 * T.exp(da_cs_sum)\n                    )\n                T.reduce_sum(\n                    T.view(ddA_state_passing_prereduce_frag, shape=[N*P]),\n                    ddA_state_passing,\n                    dim=-1, clear=True,\n                )\n                dda_frag = T.alloc_fragment([chunk_size,], T.float32)\n                for cs in T.Parallel(chunk_size):\n                    dda_frag[cs] = ddA_state_passing[0]\n                T.copy(dda_frag, DDA[i_b, i_h, chunk_start:chunk_start+chunk_size])\n\n                dq_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype)\n                T.gemm(dPhiO_shared, states_shared, dq_frag, transpose_B=True, clear_accum=True)\n                # NOTE: Compute the contribution to ddA from applying it to q*state (part 4 of 4)\n                dda_cs_prereduce_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype)\n                T.copy(q_shared, dda_cs_prereduce_frag)\n                for csr, n in T.Parallel(fused_chunk_size, N):\n                    dda_cs_prereduce_frag[csr, n] *= dq_frag[csr, n]\n                dda_cs_frag = T.alloc_fragment([chunk_size], accum_dtype)\n                T.reduce_sum(T.view(dda_cs_prereduce_frag, shape=[chunk_size, R*N]), \n                             dda_cs_frag, dim=-1, clear=True)\n                T.copy(dda_cs_frag, DDA_CS[i_b, i_h, chunk_start:chunk_start+chunk_size])\n\n                dA_cs_dq_frag = T.alloc_fragment([chunk_size], T.float32)\n                dA_cs_shared = T.alloc_shared([chunk_size], T.float32)\n\n                T.copy(DA_CS[i_b, i_h, chunk_start:chunk_start+chunk_size], dA_cs_shared)\n                T.copy(dA_cs_shared, dA_cs_dq_frag)\n                for csr, n in T.Parallel(fused_chunk_size, N):\n                    # DA_CS scales interchunk q-state contribution in backward.\n                    dq_frag[csr, n] *= T.exp(dA_cs_dq_frag[csr//R])\n                # NOTE: Unable to reuse dk_intrachunk_frag_dtype due to layout 
issue\n                # (we do gemm with the transpose of dk_intrachunk_frag_dtype)\n                T.copy(dq_frag, dq_shared)\n                dq_combined_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype)\n                T.copy(dq_shared, dq_combined_frag)\n                T.gemm(lkq_masked__or__dkq_masked_shared, k_shared, dq_combined_frag, transpose_A=True, clear_accum=False)\n                T.copy(dq_combined_frag, dq_shared)\n\n                # --- Inverse Rotary for dK and dQ + dAngles ---\n                angles_dk_frag = T.alloc_fragment([chunk_size, N//rotary_dim_divisor], T.float32)\n                T.copy(ANGLES[i_b, chunk_start:chunk_start+chunk_size, i_h, :], angles_dk_frag)\n                dk_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                dk_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                k_prerot_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                k_prerot_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    dk_first_half_frag[cs, r, n] = dk_shared[cs*R + r, n]\n                    dk_second_half_frag[cs, r, n] = dk_shared[cs*R + r, N//2 + n]\n                    k_prerot_first_half_frag[cs, r, n] = k_pre_rot_shared[cs*R + r, n]\n                    k_prerot_second_half_frag[cs, r, n] = k_pre_rot_shared[cs*R + r, N//2 + n]\n                # Compute the contribution of dk to dangle:\n                dangle_dk_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], T.float32)\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    dangle_dk_frag[cs, r, n] = dk_first_half_frag[cs, r, n] * (-k_prerot_first_half_frag[cs, r, n] * T.sin(angles_dk_frag[cs, n]) - k_prerot_second_half_frag[cs, r, n] * 
T.cos(angles_dk_frag[cs, n])) +\\\n                                            dk_second_half_frag[cs, r, n] * (k_prerot_first_half_frag[cs, r, n] * T.cos(angles_dk_frag[cs, n]) - k_prerot_second_half_frag[cs, r, n] * T.sin(angles_dk_frag[cs, n]))\n                T.copy(T.view(dangle_dk_frag, shape=[fused_chunk_size, N//rotary_dim_divisor]), dangle_dk__or__dq_shared)\n                \n                # Rotate dk_shared:\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    dk_shared[cs*R + r, n] = T.cos(angles_dk_frag[cs, n]) * dk_first_half_frag[cs, r, n] + T.sin(angles_dk_frag[cs, n]) * dk_second_half_frag[cs, r, n]\n                    dk_shared[cs*R + r, N//2 + n] = -T.sin(angles_dk_frag[cs, n]) * dk_first_half_frag[cs, r, n] + T.cos(angles_dk_frag[cs, n]) * dk_second_half_frag[cs, r, n]\n\n                dk_combined_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype)\n                T.copy(dk_shared, dk_combined_frag)\n\n                # Compute the effect of dqk_from_diag\n                q_dk_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype) # Keeping q_dk_frag in accum_dtype to avoid casting instructions\n                T.copy(q_pre_rot_shared, q_dk_frag) # NOTE: we need to use the pre-rotated version of q\n                q_dk_frag_reshaped = T.view(q_dk_frag, [chunk_size, R, N])\n                for csr_in, n in T.Parallel(fused_chunk_size, N):\n                    cs = csr_in // R\n                    for r_out in T.serial(R):\n                        csr_out = cs*R + r_out\n                        dk_combined_frag[csr_in, n] += dqk_from_diag_shared[csr_out, csr_in] * q_dk_frag_reshaped[cs, r_out, n]  \n                # Copy to gmem:\n                T.copy(dk_combined_frag, DK[i_b, fused_chunk_start:fused_chunk_start+fused_chunk_size, i_h, :])\n\n                angles_dq_frag = T.alloc_fragment([chunk_size, N//rotary_dim_divisor], T.float32)\n                T.copy(ANGLES[i_b, 
chunk_start:chunk_start+chunk_size, i_h, :], angles_dq_frag)\n                dq_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                dq_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    dq_first_half_frag[cs, r, n] = dq_shared[cs*R + r, n]\n                    dq_second_half_frag[cs, r, n] = dq_shared[cs*R + r, N//2 + n]\n                \n                # Compute the contribution of dq to dangle:\n                q_prerot_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                q_prerot_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    q_prerot_first_half_frag[cs, r, n] = q_pre_rot_shared[cs*R + r, n]\n                    q_prerot_second_half_frag[cs, r, n] = q_pre_rot_shared[cs*R + r, N//2 + n]\n                dangle_dq_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], T.float32)\n                T.copy(dangle_dk__or__dq_shared, T.view(dangle_dq_frag, shape=[fused_chunk_size, N//rotary_dim_divisor]))\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    dangle_dq_frag[cs, r, n] += dq_first_half_frag[cs, r, n] * (-q_prerot_first_half_frag[cs, r, n] * T.sin(angles_dq_frag[cs, n]) - q_prerot_second_half_frag[cs, r, n] * T.cos(angles_dq_frag[cs, n])) +\\\n                                            dq_second_half_frag[cs, r, n] * (q_prerot_first_half_frag[cs, r, n] * T.cos(angles_dq_frag[cs, n]) - q_prerot_second_half_frag[cs, r, n] * T.sin(angles_dq_frag[cs, n]))\n                # Sum dangle across R, and copy to gmem\n                dangle_frag_reduced = T.alloc_fragment([chunk_size,  N//rotary_dim_divisor], T.float32)\n                
T.clear(dangle_frag_reduced)\n                for cs, n in T.Parallel(chunk_size, N//rotary_dim_divisor):\n                    for r in T.serial(R):\n                        dangle_frag_reduced[cs, n] += dangle_dq_frag[cs, r, n]\n                T.copy(dangle_frag_reduced, DANGLES[i_b, chunk_start:chunk_start+chunk_size, i_h, :])\n                # Rotate dq_shared:\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    dq_shared[cs*R + r, n] = T.cos(angles_dk_frag[cs, n]) * dq_first_half_frag[cs, r, n] + T.sin(angles_dk_frag[cs, n]) * dq_second_half_frag[cs, r, n]\n                    dq_shared[cs*R + r, N//2 + n] = -T.sin(angles_dk_frag[cs, n]) * dq_first_half_frag[cs, r, n] + T.cos(angles_dk_frag[cs, n]) * dq_second_half_frag[cs, r, n]\n                T.copy(dq_shared, dq_frag)\n\n                # Compute the effect of dqk_from_diag\n                for csr_out, n in T.Parallel(fused_chunk_size, N):\n                    cs = csr_out // R\n                    for r_in in T.serial(R):\n                        csr_in = cs*R + r_in\n                        dq_frag[csr_out, n] += dqk_from_diag_shared[csr_out, csr_in] * k_pre_rot_shared[csr_in, n]\n                # Copy to gmem:\n                T.copy(dq_frag, DQ[i_b, fused_chunk_start:fused_chunk_start+fused_chunk_size, i_h, :])\n\n                # --- Update Reverse-Passed State Gradient ---\n                da_cs_sum_dstates = T.alloc_var(T.float32)\n                T.copy(DA_CS[i_b, i_h, chunk_start+chunk_size-1], da_cs_sum_dstates)\n                for n, p in T.Parallel(N, P):\n                    dstates_frag[n, p] *= T.exp(da_cs_sum_dstates)\n                dPhiO_scaled_frag = T.alloc_fragment([fused_chunk_size, P], dtype)\n                T.copy(dPhiO_shared, dPhiO_scaled_frag)\n                dA_cs_dPhiO_frag = T.alloc_fragment([chunk_size], T.float32)\n                T.copy(dA_cs_shared, dA_cs_dPhiO_frag)\n                for csr, p in 
T.Parallel(fused_chunk_size, P):\n                    # DA_CS applies chunk-level decay to the passed gradient state.\n                    dPhiO_scaled_frag[csr, p] *= T.exp(dA_cs_dPhiO_frag[csr//R])\n                T.gemm(q_shared, dPhiO_scaled_frag, dstates_frag, transpose_A=True, clear_accum=False)\n                T.copy(dstates_frag, dstates_shared)\n\n            T.copy(dPsi_acc, DMIMO_V[i_b, i_h, :, :])\n            if hasD:\n                T.copy(dD_frag, DD[i_b, i_h])\n\n    return mamba_mimo_bwd_bwd_kernel\n\n\ndef mamba_mimo_bwd_combined(\n        dout,\n        q, \n        k, \n        v, \n        q_bias,\n        k_bias,\n        mimo_v, \n        mimo_o,\n        z,\n        mimo_z,\n        angles,\n        dA_cs,\n        dA_cs_rev,\n        dt,\n        trap,\n        D,\n        segsum,\n        chunk_size,\n        rotary_dim_divisor,\n        dtype,\n        bf_threads=128,\n        bf_num_stages=0,\n        bb_threads=256,\n        bb_num_stages=0,\n        ):\n    # TileLang kernel expects contiguous last-dim strides for DOUT.\n    B, S, R, G, N = q.shape\n    H, P = v.shape[-2], v.shape[-1]\n    reduceO = mimo_o is not None\n\n    dmimo_o = torch.empty([B, H, R, P], dtype=mimo_v.dtype, device=mimo_v.device) if reduceO else None\n    states = torch.empty([B, H, S//chunk_size, N, P], dtype=v.dtype, device=v.device) # NOTE: states dtype is set to v.dtype\n    \n    if z is not None:\n        dz_tilelang = torch.empty_like(v)\n        dmimo_z = torch.empty([B, H, R, P], dtype=mimo_v.dtype, device=mimo_v.device)\n    else:\n        dz_tilelang = None\n        dmimo_z = None\n    qk_dot = torch.zeros([B, H, S, R, R], dtype=q.dtype, device=q.device)\n\n\n    if isinstance(dtype, torch.dtype):\n        dtype_str = str(dtype).replace(\"torch.\", \"\")\n    else:\n        dtype_str = dtype\n    bwd_fwd_kernel = mamba_mimo_bwd_fwd(B, S, H, G, N, P, R, \n                                             z is not None,\n                                     
        D is not None,\n                                             reduceO,\n                                             chunk_size, \n                                             rotary_dim_divisor,\n                                             dtype_str,\n                                             bf_threads,\n                                             bf_num_stages)\n    bwd_fwd_kernel(\n                    dout,\n                    q, \n                    k, \n                    v, \n                    q_bias,\n                    k_bias,\n                    mimo_v, \n                    mimo_o,\n                    dmimo_o,\n                    states,\n                    z,\n                    mimo_z,\n                    dz_tilelang,\n                    dmimo_z,\n                    angles,\n                    dA_cs,\n                    dA_cs_rev,\n                    dt,\n                    trap,\n                    D,\n                    qk_dot,\n                    segsum,\n                    )\n    if reduceO:\n        dmimo_o = dmimo_o.sum(dim=0)\n\n    \n    dq_tilelang = torch.empty([B, S, R, H, N], dtype=q.dtype, device=q.device)\n    dk_tilelang = torch.empty([B, S, R, H, N], dtype=k.dtype, device=k.device)\n    dv_tilelang = torch.empty_like(v)\n    dmimo_v = torch.empty([B, H, R, P], dtype=mimo_v.dtype, device=mimo_v.device)\n    dD = torch.empty([B, H], dtype=D.dtype, device=D.device) if D is not None else None\n    dangles = torch.zeros([B, S, H, N//rotary_dim_divisor], dtype=angles.dtype, device=angles.device)\n    dfactor = torch.zeros([B, H, S], dtype=torch.float32, device=trap.device)\n    dgamma_diag = torch.zeros([B, H, S], dtype=torch.float32, device=trap.device)\n    ddA = torch.zeros([B, H, S], dtype=torch.float32, device=dt.device)\n    dSSdA = torch.zeros([B, H, S//chunk_size, chunk_size, chunk_size], dtype=torch.float32, device=dt.device)\n    ddA_cs_rev = torch.zeros([B, H, S], dtype=torch.float32, 
device=dt.device)\n    ddA_cs = torch.zeros([B, H, S], dtype=torch.float32, device=dt.device)\n    \n    \n    bwd_bwd_kernel = mamba_mimo_bwd_bwd(B, S, H, G, N, P, R, \n                                             z is not None,\n                                             D is not None,\n                                             reduceO,\n                                             chunk_size, \n                                             rotary_dim_divisor,\n                                             dtype_str,\n                                             bb_threads,\n                                             bb_num_stages)\n    bwd_bwd_kernel(\n            dout,\n            q, \n            k,\n            v,\n            q_bias,\n            k_bias,\n            mimo_v, \n            mimo_o,\n            dk_tilelang.view(B, S*R, H, N), \n            dv_tilelang, \n            dmimo_v,\n            states,\n            dq_tilelang.view(B, S*R, H, N),\n            z,\n            mimo_z,\n            angles,\n            dA_cs,\n            dA_cs_rev,\n            dt,\n            trap,\n            dfactor,\n            dgamma_diag,\n            dangles,\n            D,\n            dD,\n            qk_dot,\n            ddA,\n            dSSdA,\n            ddA_cs_rev,\n            ddA_cs,\n            segsum,\n            )\n    \n    if G == 1:\n        dq_bias_tilelang = dq_tilelang.sum(dim=(0, 1)).permute((1, 0, 2))\n        dk_bias_tilelang = dk_tilelang.sum(dim=(0, 1)).permute((1, 0, 2))\n        dq_tilelang = dq_tilelang.sum(dim=3, keepdim=True)\n        dk_tilelang = dk_tilelang.sum(dim=3, keepdim=True)\n        dmimo_v = dmimo_v.sum(dim=0)\n        dmimo_z = dmimo_z.sum(dim=0) if dmimo_z is not None else None\n        dD = dD.sum(dim=0) if dD is not None else None\n    elif G == H:\n        dq_bias_tilelang = dq_tilelang.sum(dim=(0, 1)).permute((1, 0, 2))\n        dk_bias_tilelang = dk_tilelang.sum(dim=(0, 1)).permute((1, 0, 2))\n        
dmimo_v = dmimo_v.sum(dim=0)\n        dmimo_z = dmimo_z.sum(dim=0) if dmimo_z is not None else None\n        dD = dD.sum(dim=0) if dD is not None else None\n    else:\n        raise ValueError(f\"G value of {G} is not currently supported!\")\n\n    ddt, dtrap = bwd_dtrap_ddt_triton(\n        trap, dt, dfactor, dgamma_diag, chunk_size\n    )\n\n    ddA += bwd_dadt_fused_triton(\n        dSSdA, segsum, ddA_cs, ddA_cs_rev, dA_cs, dA_cs_rev, chunk_size\n    )\n\n\n    return (dq_tilelang, dk_tilelang, dv_tilelang, \n            ddA, ddt, dtrap, dq_bias_tilelang, dk_bias_tilelang,\n            dmimo_v, dmimo_z, dmimo_o, dangles, \n            dD, dz_tilelang)\n"
  },
  {
    "path": "mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py",
    "content": "\"\"\"\nTilelang implementation of Mamba3 forward kernel,\nwith MIMO support.\n\nCopyright (c) 2026, Dao AI Lab, Goombalab\n\n\"\"\"\n\nimport torch\nimport tilelang\nimport tilelang.language as T\nfrom tilelang.profiler import do_bench\nfrom tilelang.autotuner import autotune\n\nimport itertools\nimport argparse\nfrom typing import Optional, Tuple\n\n\n# NOTE: Uncomment the following to autotune:\n# def get_configs():\n#     iter_params = dict(num_stages=[0, 1, 2, 3], threads=[128, 256, 512])\n#     # iter_params = dict(num_stages=[2], threads=[128])\n#     return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]\n\n# @autotune(\n#     configs=get_configs(),\n#     warmup=3,\n#     rep=20,\n# )\n@tilelang.jit(\n    out_idx=[],\n    pass_configs={\n        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,\n        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,\n        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,\n    })\ndef mamba_mimo_fwd(\n    B,\n    S,\n    H,\n    G,\n    N,\n    P,\n    R,\n    hasZ,\n    hasD,\n    reduceO,\n    return_final_state=False,\n    chunk_size: int = 16,\n    rotary_dim_divisor = 4,\n    dtype: str = 'bfloat16',\n    threads: int = 128,\n    num_stages: int = 0,\n) -> torch.Tensor:\n\n    accum_dtype = 'float32'\n\n    # Block sizes for K and V dimensions - use full dimensions (no tiling)\n    assert S % chunk_size == 0, \"Sequence length must be divisible by chunk_size\"\n\n    nchunks = tilelang.cdiv(S, chunk_size)\n    fused_chunk_size = chunk_size * R\n\n    if reduceO:\n        O_shape = (B, S, H, P)\n    else:\n        O_shape = (B, S, R, H, P)\n\n\n    @T.prim_func\n    def mamba_mimo_fwd_kernel(\n            Q: T.Tensor([B, S, R, G, N], dtype),  # type: ignore\n            K: T.Tensor([B, S, R, G, N], dtype),  # type: ignore\n            V: T.Tensor([B, S, H, P], dtype),  # type: ignore\n            O: T.Tensor(O_shape, dtype),  # type: ignore\n         
   Q_BIAS: T.Tensor([H, R, N], T.float32),  # type: ignore\n            K_BIAS: T.Tensor([H, R, N], T.float32),  # type: ignore\n            MIMO_V: T.Tensor([H, R, P], T.float32), # type: ignore\n            MIMO_O: T.Tensor([H, R, P], T.float32), # type: ignore\n            Z: T.Tensor([B, S, H, P], dtype),  # type: ignore\n            D: T.Tensor([H], T.float32),  # type: ignore\n            MIMO_Z: T.Tensor([H, R, P], T.float32), # type: ignore\n            ANGLES: T.Tensor([B, S, H, N//rotary_dim_divisor], T.float32), # type: ignore\n            DA_CS: T.Tensor([B, H, S], T.float32), # type: ignore\n            DA_CS_REV: T.Tensor([B, H, S], T.float32), # type: ignore\n            DT: T.Tensor([B, H, S], T.float32), # type: ignore\n            TRAP: T.Tensor([B, H, S], dtype), # type: ignore\n            SEGSUM: T.Tensor([B, H, nchunks, chunk_size, chunk_size], T.float32), # type: ignore\n\n            FINAL_STATE: T.Tensor([B, H, N, P], T.float32),  # type: ignore\n            FINAL_K: T.Tensor([B, R, H, N], dtype)  # type: ignore\n            ):\n        \"\"\"\n        Overview:\n            Fused chunked forward pass that combines MIMO projections with recurrent state updates.\n            Computes interchunk and intrachunk contributions with optional D and Z paths,\n            then writes output activations.\n\n        Inputs:\n            - Activations: Q, K, V.\n            - Projection parameters/biases: MIMO_V (Psi), MIMO_O (Phi), optional MIMO_Z (Zeta), ANGLES,\n              and Q_BIAS/K_BIAS.\n            - Optional modifiers: Z, and D.\n            - Discretization tensors: DA_CS, DA_CS_REV, DT, TRAP, and SEGSUM.\n\n        Outputs:\n            - O: fused forward output activations.\n            - FINAL_STATE: final recurrent states (if return_state is True).\n            - FINAL_K: final K tensor (if return_state is True, for use in decode)\n\n        Notation:\n            - Psi: MIMO X projection.\n            - Phi: MIMO O projection.\n      
      - Zeta: MIMO Z projection.\n            - Trap: convex-combination modulator used in exponential-trapezoidal discretization.\n        \"\"\"\n        \n        with T.Kernel(H, B, threads=threads) as (i_h, i_b):\n            # --- Kernel Setup ---\n            # GQA support: map V head to Q/K head\n            i_h_qk = i_h // (H // G)\n\n            # --- Buffer Allocation ---\n            q_shared = T.alloc_shared([fused_chunk_size, N], dtype)\n            k_shared = T.alloc_shared([fused_chunk_size, N], dtype)\n            q_bias_frag = T.alloc_fragment([R, N], dtype)\n            k_bias_frag = T.alloc_fragment([R, N], dtype)\n\n            angles_shared = T.alloc_shared([chunk_size, N], dtype)\n\n            PsiV_shared = T.alloc_shared([fused_chunk_size, P], dtype)\n            qs_shared = T.alloc_shared([fused_chunk_size, P], dtype)\n            o_shared = T.alloc_shared([chunk_size, P], dtype)\n            v_shared = T.alloc_shared([chunk_size, P], dtype)\n            states_accum_cast_shared = T.alloc_shared([N, P], dtype)\n            qk_intrachunk_shared = T.alloc_shared([fused_chunk_size, fused_chunk_size], dtype)\n            qk_dot_full_shared = T.alloc_shared([fused_chunk_size, fused_chunk_size], dtype)\n\n            # --- Swizzling Annotation ---\n            T.annotate_layout({\n                q_shared: tilelang.layout.make_swizzled_layout(q_shared),\n                k_shared: tilelang.layout.make_swizzled_layout(k_shared),\n                v_shared: tilelang.layout.make_swizzled_layout(v_shared),\n\n                angles_shared: tilelang.layout.make_swizzled_layout(angles_shared),\n\n                PsiV_shared: tilelang.layout.make_swizzled_layout(PsiV_shared),\n                qs_shared: tilelang.layout.make_swizzled_layout(qs_shared),\n                o_shared: tilelang.layout.make_swizzled_layout(o_shared),\n                states_accum_cast_shared: tilelang.layout.make_swizzled_layout(states_accum_cast_shared),\n                
qk_dot_full_shared: tilelang.layout.make_swizzled_layout(qk_dot_full_shared),\n                qk_intrachunk_shared: tilelang.layout.make_swizzled_layout(qk_intrachunk_shared),\n            })\n            T.use_swizzle(10, \"row\")\n\n            T.no_set_max_nreg()\n\n            # --- Per-Head Constants / Running State ---\n            states_frag = T.alloc_fragment([N, P], accum_dtype)\n            T.clear(states_frag)\n\n            phi_frag_intrachunk = T.alloc_fragment([R, P], dtype=dtype)\n            if reduceO:\n                T.copy(MIMO_O[i_h, :, :], phi_frag_intrachunk)\n            Psi_frag = T.alloc_fragment([R, P], dtype)\n            T.copy(MIMO_V[i_h, :, :], Psi_frag)\n\n            T.copy(Q_BIAS[i_h, :, :], q_bias_frag)\n            T.copy(K_BIAS[i_h, :, :], k_bias_frag)\n\n            # --- Chunk Loop ---\n            for i in T.Pipelined(0, nchunks, num_stages=num_stages):\n                chunk_start = i * chunk_size\n\n                # --- Discretization Factors (Shifted Gamma + Trap Scale) ---\n                trap_shifted_frag = T.alloc_fragment([chunk_size], T.float32)\n                T.copy(TRAP[i_b, i_h, chunk_start+1: chunk_start+chunk_size+1], trap_shifted_frag)\n                dt_shifted_frag = T.alloc_fragment([chunk_size], dtype)\n                T.copy(DT[i_b, i_h, chunk_start+1: chunk_start+chunk_size+1], dt_shifted_frag)\n                shifted_gamma_frag = T.alloc_fragment([chunk_size], dtype)\n                for cs in T.Parallel(chunk_size):\n                    shifted_gamma_frag[cs] = T.if_then_else(chunk_start + cs < (S - 1), \n                                                            dt_shifted_frag[cs] * T.sigmoid(-trap_shifted_frag[cs]), \n                                                            0.0)\n                shifted_gamma_shared = T.alloc_shared([chunk_size], dtype)\n                T.copy(shifted_gamma_frag, shifted_gamma_shared)\n\n                trap_frag = T.alloc_fragment([chunk_size], 
T.float32)\n                T.copy(TRAP[i_b, i_h, chunk_start: chunk_start+chunk_size], trap_frag)\n                dt_frag = T.alloc_fragment([chunk_size], dtype)\n                T.copy(DT[i_b, i_h, chunk_start: chunk_start+chunk_size], dt_frag)\n                gamma_frag = T.alloc_fragment([chunk_size], T.float32)\n                for cs in T.Parallel(chunk_size):\n                    gamma_frag[cs] = dt_frag[cs] * T.sigmoid(trap_frag[cs])\n                trap_scale_frag = T.alloc_fragment([chunk_size], dtype)\n                for cs in T.Parallel(chunk_size):\n                    trap_scale_frag[cs] = gamma_frag[cs] + shifted_gamma_shared[cs]\n                trap_scale_shared = T.alloc_shared([chunk_size], dtype)\n                T.copy(trap_scale_frag, trap_scale_shared)\n\n                # --- Up-Project V and Prepare Biased Q/K ---\n                PsiV_frag = T.alloc_fragment([chunk_size, R, P], dtype)\n                for cs, p in T.Parallel(chunk_size, P):\n                    v_shared[cs, p] = V[i_b, chunk_start+cs, i_h, p]\n                for cs, r, p in T.Parallel(chunk_size, R, P):\n                    PsiV_frag[cs, r, p] = v_shared[cs, p] * Psi_frag[r, p]\n                PsiV_reshaped_frag = T.view(PsiV_frag, shape=[fused_chunk_size, P])\n                T.copy(PsiV_reshaped_frag, PsiV_shared)\n\n                q_frag = T.alloc_fragment([chunk_size, R, N], dtype)\n                T.copy(Q[i_b, chunk_start:chunk_start+chunk_size, :, i_h_qk, :], q_frag)\n                for cs, r, n in T.Parallel(chunk_size, R, N):\n                    q_frag[cs, r, n] += q_bias_frag[r, n]\n                T.copy(T.view(q_frag, shape=[fused_chunk_size, N]), q_shared)\n\n                k_frag = T.alloc_fragment([chunk_size, R, N], dtype)\n                T.copy(K[i_b, chunk_start:chunk_start+chunk_size, :, i_h_qk, :], k_frag)\n                for cs, r, n in T.Parallel(chunk_size, R, N):\n                    k_frag[cs, r, n] += k_bias_frag[r, n]\n                
T.copy(T.view(k_frag, shape=[fused_chunk_size, N]), k_shared)\n\n                # --- Cache Diagonal qk_dot Path ---\n                # Keep full qk_dot in shared memory because we reuse same-step R x R blocks later.\n                qk_dot_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype=accum_dtype)\n                T.gemm(q_shared, k_shared, qk_dot_frag, transpose_B=True, clear_accum=True)\n                T.copy(qk_dot_frag, qk_dot_full_shared)\n                # Option B: extremely slow\n                # qk_dot_frag = T.alloc_fragment([chunk_size, R, R], dtype=accum_dtype)\n                # T.clear(qk_dot_frag)\n                # for cs, r_out, r_in in T.Parallel(chunk_size, R, R):\n                #     for n in T.serial(N):\n                #         qk_dot_frag[cs, r_out, r_in] += (\n                #             q_frag[cs, r_out, n] * k_frag[cs, r_in, n]\n                #         )\n                # T.copy(T.view(qk_dot_frag, shape=[fused_chunk_size, R]), qk_dot_shared)\n                # NOTE (\"option C\"): The following fails Tilelang compilation:\n                # qk_predot_frag = T.alloc_fragment([chunk_size, R, R, N], dtype)\n                # for cs, r_out, r_in, n in T.Parallel(chunk_size, R, R, N):\n                #     qk_predot_frag[cs, r_out, r_in, n] = q_frag[cs, r_out, n] * k_frag[cs, r_in, n]\n                # qk_dot_frag = T.alloc_fragment([chunk_size, R, R], dtype)\n                # T.reduce_sum(qk_predot_frag, qk_dot_frag, dim=-1, clear=True)\n                # T.copy(T.view(qk_dot_frag, shape=[fused_chunk_size, R]), qk_dot_shared)\n\n                # --- Rotary Q + Interchunk Contribution ---\n                q_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                q_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    
q_first_half_frag[cs, r, n] = q_shared[cs*R + r, n]\n                    q_second_half_frag[cs, r, n] = q_shared[cs*R + r, N//2 + n]\n\n                # NOTE: angles are cast to fp32 for numerical stability\n                angles_frag = T.alloc_fragment([chunk_size, N//rotary_dim_divisor], T.float32)\n                T.copy(ANGLES[i_b, chunk_start:chunk_start+chunk_size, i_h, :], angles_frag)\n\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    q_shared[cs*R + r, n] = T.cos(angles_frag[cs, n]) * q_first_half_frag[cs, r, n] - T.sin(angles_frag[cs, n]) * q_second_half_frag[cs, r, n]\n                    q_shared[cs*R + r, N//2 + n] = T.sin(angles_frag[cs, n]) * q_first_half_frag[cs, r, n] + T.cos(angles_frag[cs, n]) * q_second_half_frag[cs, r, n]\n\n                o_mimo_accum_frag = T.alloc_fragment([fused_chunk_size, P], dtype=accum_dtype)\n                T.copy(states_frag, states_accum_cast_shared)\n                T.gemm(q_shared, states_accum_cast_shared, o_mimo_accum_frag, clear_accum=True)\n\n                # --- Rotary K + Trap Scaling + Intrachunk Contribution ---\n                k_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n                k_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype)\n\n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    k_first_half_frag[cs, r, n] = k_shared[cs*R + r, n]\n                    k_second_half_frag[cs, r, n] = k_shared[cs*R + r, N//2 + n]\n                \n                for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor):\n                    k_shared[cs*R + r, n] = T.cos(angles_frag[cs, n]) * k_first_half_frag[cs, r, n] - T.sin(angles_frag[cs, n]) * k_second_half_frag[cs, r, n]\n                    k_shared[cs*R + r, N//2 + n] = T.sin(angles_frag[cs, n]) * k_first_half_frag[cs, r, n] + T.cos(angles_frag[cs, n]) * 
k_second_half_frag[cs, r, n]\n\n                if i == nchunks - 1 and return_final_state:\n                    seq_boundary = T.min(chunk_start + chunk_size, S) - chunk_start\n                    for csr, n in T.Parallel(fused_chunk_size, N):\n                        if csr >= (seq_boundary - 1) * R and csr < seq_boundary * R:  # Only copy the R rows of the last time step (in the final chunk) to FINAL_K\n                            FINAL_K[i_b, csr % R, i_h, n] = k_shared[csr, n]\n\n                k_trap_scaled_frag = T.alloc_fragment([fused_chunk_size, N], dtype)\n                T.copy(k_shared, k_trap_scaled_frag)\n                for csr, n in T.Parallel(fused_chunk_size, N):\n                    k_trap_scaled_frag[csr, n] *= trap_scale_shared[csr//R]\n                T.copy(k_trap_scaled_frag, k_shared)\n\n                qk_intrachunk_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype=accum_dtype)\n                T.gemm(q_shared, k_shared, qk_intrachunk_frag, transpose_B=True, clear_accum=True)\n\n                # Strictly causal mask over chunk steps (exclude same-step diagonal).\n                da_cs__or__exp_da_cs_shared = T.alloc_shared([chunk_size], T.float32)\n                T.copy(DA_CS[i_b, i_h, chunk_start:chunk_start+chunk_size], da_cs__or__exp_da_cs_shared)\n                qk_intrachunk_masked_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype=dtype)\n                for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size):\n                    qk_intrachunk_masked_frag[csr_i, csr_j] = T.if_then_else(\n                                                csr_i//R > csr_j//R, # NOTE: we do indeed want to exclude the diagonal\n                                                qk_intrachunk_frag[csr_i, csr_j] \n                                                * T.exp(SEGSUM[i_b, i_h, i, csr_i//R, csr_j//R]),\n                                                0.0\n                                            )\n\n                # 
Exponentiate da_cs__or__exp_da_cs_shared so that later usage does not have to:\n                for cs in T.Parallel(chunk_size):\n                    da_cs__or__exp_da_cs_shared[cs] = T.exp(da_cs__or__exp_da_cs_shared[cs])\n\n                exp_da_cs_frag = T.alloc_fragment([chunk_size], dtype=T.float32)\n                T.copy(da_cs__or__exp_da_cs_shared, exp_da_cs_frag)\n                for csr, p in T.Parallel(fused_chunk_size, P):\n                    o_mimo_accum_frag[csr, p] *= exp_da_cs_frag[csr//R]\n\n                # NOTE: if we gemm with qk_intrachunk_masked_frag the compiler will\n                # error with layout issue if threads != 128:\n                # Copy via shared memory to satisfy layout constraints before GEMM.\n                T.copy(qk_intrachunk_masked_frag, qk_intrachunk_shared)\n                # Adding the two intermediate outputs together (interchunk += intrachunk)\n                T.gemm(qk_intrachunk_shared, PsiV_shared, o_mimo_accum_frag, clear_accum=False)\n\n                # --- Add Diagonal Terms (qk_dot and optional D) ---\n                qkdot_psiv_frag = T.alloc_fragment([chunk_size, R, P], dtype=dtype)\n                T.clear(qkdot_psiv_frag)\n                for cs, r_out, p in T.Parallel(chunk_size, R, P):\n                    for r_in in T.serial(R):\n                        qkdot_psiv_frag[cs, r_out, p] += qk_dot_full_shared[cs * R + r_out, cs * R + r_in] * PsiV_shared[cs * R + r_in, p]                    \n                    qkdot_psiv_frag[cs, r_out, p] *= gamma_frag[cs] # Apply shifted gamma\n\n                if hasD:\n                    PsiV_D_frag = T.alloc_fragment([chunk_size, R, P], T.float32)\n                    for cs, r, p in T.Parallel(chunk_size, R, P):\n                        PsiV_D_frag[cs, r, p] = PsiV_shared[cs * R + r, p]\n                    D_var = T.alloc_var(T.float32)\n                    T.copy(D[i_h], D_var)\n                    for cs, r_out, p in T.Parallel(chunk_size, R, P):\n       
                 qkdot_psiv_frag[cs, r_out, p] += D_var * PsiV_D_frag[cs, r_out, p]\n                qkdot_psiv_reshaped_frag = T.view(qkdot_psiv_frag, shape=[fused_chunk_size, P])\n                for csr, p in T.Parallel(fused_chunk_size, P):\n                    o_mimo_accum_frag[csr, p] += qkdot_psiv_reshaped_frag[csr, p]\n\n                # --- Optional Z Gating + Down-Projection ---\n                if reduceO:\n                    if hasZ:\n                        z_frag = T.alloc_fragment([chunk_size, P], dtype)\n                        T.copy(Z[i_b, chunk_start:chunk_start+chunk_size, i_h, :], z_frag)\n                        z_expanded_frag = T.alloc_fragment([chunk_size, R, P], dtype)\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            # Apply SiLU to z_expanded_frag[cs, r, p]:\n                            o_gated = z_frag[cs, p] * MIMO_Z[i_h, r, p] * 0.5\n                            z_expanded_frag[cs, r, p] = o_gated * T.tanh(o_gated) + o_gated\n\n                    lqk_PsiV_reshaped_frag = T.view(o_mimo_accum_frag, shape=[chunk_size, R, P])\n                    if hasZ:\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            lqk_PsiV_reshaped_frag[cs, r, p] *= phi_frag_intrachunk[r, p] * z_expanded_frag[cs, r, p]\n                    else:\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            lqk_PsiV_reshaped_frag[cs, r, p] *= phi_frag_intrachunk[r, p]\n                    lqk_PsiV_reshaped_shared = T.alloc_shared([chunk_size, R, P], dtype)\n                    T.copy(lqk_PsiV_reshaped_frag, lqk_PsiV_reshaped_shared)\n                    o_frag = T.alloc_fragment([chunk_size, P], dtype)\n                    T.clear(o_frag)\n                    for r in T.serial(R):\n                        for cs, p in T.Parallel(chunk_size, P):\n                            o_frag[cs, p] += lqk_PsiV_reshaped_shared[cs, r, p]\n    
                T.copy(o_frag, O[i_b, chunk_start:chunk_start+chunk_size, i_h, :])\n                else:\n                    if hasZ:\n                        z_frag = T.alloc_fragment([chunk_size, P], dtype)\n                        T.copy(Z[i_b, chunk_start:chunk_start+chunk_size, i_h, :], z_frag)\n                        z_expanded_frag = T.alloc_fragment([chunk_size, R, P], dtype)\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            # Apply SiLU to z_expanded_frag[cs, r, p]:\n                            o_gated = z_frag[cs, p] * MIMO_Z[i_h, r, p] * 0.5\n                            z_expanded_frag[cs, r, p] = o_gated * T.tanh(o_gated) + o_gated\n                        lqk_PsiV_reshaped_shared = T.alloc_shared([chunk_size, R, P], dtype)\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            lqk_PsiV_reshaped_shared[cs, r, p] = o_mimo_accum_frag[cs* R + r, p] * z_expanded_frag[cs, r, p]\n                        # T.copy(lqk_PsiV_frag, lqk_PsiV_reshaped_shared)\n                        # for cs, r, p in T.Parallel(chunk_size, R, P):\n                        #     lqk_PsiV_reshaped_shared[cs, r, p] *= z_expanded_frag[cs, r, p]\n                    else:\n                        lqk_PsiV_reshaped_shared = T.alloc_shared([chunk_size, R, P], dtype)\n                        # T.copy(lqk_PsiV_reshaped_frag, lqk_PsiV_reshaped_shared)\n                        for cs, r, p in T.Parallel(chunk_size, R, P):\n                            lqk_PsiV_reshaped_shared[cs, r, p] = o_mimo_accum_frag[cs* R + r, p]\n                    T.copy(lqk_PsiV_reshaped_shared, O[i_b, chunk_start:chunk_start+chunk_size, :, i_h, :])\n\n                # --- Recurrent State Update ---\n                # DA_CS_REV scales per-step K contributions for state accumulation.\n                dA_cs_rev_frag = T.alloc_fragment([chunk_size], T.float32)\n                T.copy(DA_CS_REV[i_b, i_h, 
chunk_start:chunk_start+chunk_size], dA_cs_rev_frag)\n\n                k_state_frag = T.alloc_fragment([fused_chunk_size, N], dtype)\n                T.copy(k_shared, k_state_frag)\n                for csr, n in T.Parallel(fused_chunk_size, N):\n                    k_state_frag[csr, n] *= T.exp(dA_cs_rev_frag[csr//R])\n\n                # DA_CS(last) applies the chunk-level decay to the carried state.\n                da_cs_sum = T.alloc_var(T.float32)\n                T.copy(DA_CS[i_b, i_h, chunk_start+chunk_size-1], da_cs_sum)\n                for n, p in T.Parallel(N, P):\n                    states_frag[n, p] *= T.exp(da_cs_sum)\n                T.gemm(k_state_frag, PsiV_shared, states_frag, transpose_A=True, clear_accum=False)\n            \n            # --- Save Last State (if applicable) ---\n            if return_final_state:\n                T.copy(states_frag, FINAL_STATE[i_b, i_h, :, :])\n                \n\n    return mamba_mimo_fwd_kernel\n\n\ndef mamba_mimo_forward(q, k, v, \n                       q_bias, k_bias, \n                       mimo_v, mimo_o, \n                       z, D, \n                       mimo_z, \n                       angles, \n                       dA_cs,\n                       dA_cs_rev,\n                       dt,\n                       trap,\n                       segsum,\n                       chunk_size, rotary_dim_divisor, dtype, \n                       return_state=False,\n                       threads=128, \n                       num_stages=0):\n    B, S, R, G, N = q.shape\n    H, P = v.shape[-2], v.shape[-1]\n    if isinstance(dtype, torch.dtype):\n        tl_dtype = str(dtype).replace(\"torch.\", \"\")\n    else:\n        tl_dtype = dtype\n    reduceO = mimo_o is not None\n    kernel = mamba_mimo_fwd(B, S, H, G, N, P, R, \n                                       z is not None, \n                                       D is not None, \n                                       reduceO,\n                           
            return_final_state=return_state,\n                                       chunk_size=chunk_size, \n                                       rotary_dim_divisor=rotary_dim_divisor, \n                                       dtype=tl_dtype, \n                                       threads=threads, \n                                       num_stages=num_stages)\n    # print(kernel.get_kernel_source()) # NOTE: prints compiled CUDA code\n    if reduceO:\n        o = torch.empty((B, S, H, P), device='cuda', dtype=dtype)\n    else:\n        o = torch.empty((B, S, R, H, P), device='cuda', dtype=dtype)\n    # Kernel always declares all tensor parameters; pass dummies for None args\n    mimo_o_arg = mimo_o if reduceO else torch.empty((H, R, P), device=q.device, dtype=torch.float32)\n    z_arg = z if z is not None else torch.empty((B, S, H, P), device=q.device, dtype=dtype)\n    D_arg = D if D is not None else torch.empty((H,), device=q.device, dtype=torch.float32)\n    mimo_z_arg = mimo_z if mimo_z is not None else torch.empty((H, R, P), device=q.device, dtype=torch.float32)\n\n    h = torch.empty((B, H, N, P), device='cuda', dtype=torch.float32) if return_state else None\n    k_final = torch.empty((B, R, H, N), device='cuda', dtype=dtype) if return_state else None\n\n    kernel( q,\n            k,\n            v, o,\n            q_bias, k_bias,\n            mimo_v, mimo_o_arg,\n            z_arg, D_arg, mimo_z_arg,\n            angles,\n            dA_cs,\n            dA_cs_rev,\n            dt,\n            trap,\n            segsum,\n            h,\n            k_final\n            )\n    return o, h, k_final\n"
  },
  {
    "path": "mamba_ssm/ops/triton/__init__.py",
    "content": ""
  },
  {
    "path": "mamba_ssm/ops/triton/angle_cumsum.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n\nfrom typing import Optional\nimport math\n\nimport torch\n\nimport triton\nimport triton.language as tl\nfrom triton.language.extra import libdevice\n\nclass AngleDtFn(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx,\n                angle: torch.Tensor,   # (B, S, H, D)\n                dt: torch.Tensor,      # (B, S, H)\n                chunk_size: int = 128  # power of 2\n                ) -> torch.Tensor:\n        # run Triton fwd\n        out = apply_angle_dt_fwd(angle, dt, chunk_size=chunk_size)\n        # save for bwd\n        ctx.save_for_backward(angle, dt)\n        ctx.chunk_size = int(chunk_size)\n        return out\n\n    @staticmethod\n    def backward(ctx, grad_out: torch.Tensor):\n        angle, dt = ctx.saved_tensors\n        # run Triton bwd\n        grad_dt, grad_angle = apply_angle_dt_bwd(\n            grad_out=grad_out, angle=angle, dt=dt, chunk_size=ctx.chunk_size\n        )\n        # grads align with (angle, dt, chunk_size)\n        return grad_angle, grad_dt, None\n\n\ndef angle_dt(angle: torch.Tensor,\n             dt: torch.Tensor,\n             *,\n             chunk_size: int = 128) -> torch.Tensor:\n    return AngleDtFn.apply(angle, dt, chunk_size)\n\n\n@triton.jit\ndef cumsum_kernel(\n    OUT,        # Output tensor (batch, seqlen, nheads, dim)\n    X,          # Input tensor (batch, seqlen, nheads, dim)\n    seqlen,\n    dim,\n    stride_out, # (batch, seqlen, nheads, dim)\n    stride_x,   # (batch, seqlen, nheads, dim)\n    # Meta-parameters\n    BLOCK_S: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n):\n    # Program IDs\n    pid_h = tl.program_id(axis=0)  # Head index (one per head)\n    pid_d = tl.program_id(axis=1)  # Dim block\n    pid_b = tl.program_id(axis=2)  # Batch index (one per batch element)\n\n    # Offset pointers by batch and head\n    X = X + pid_b * stride_x[0] + pid_h * stride_x[2]\n    OUT = OUT + pid_b * stride_out[0] + pid_h * stride_out[2]\n\n    
# Compute ranges\n    dim_range = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)\n    dim_mask = dim_range < dim\n\n    # Load entire sequence for this batch, head, and dim block\n    seq_range = tl.arange(0, BLOCK_S)[:, None]  # (BLOCK_S, 1)\n\n    # Load input: (seqlen, dim) for this batch and head\n    x_ptrs = X + seq_range * stride_x[1] + dim_range[None, :] * stride_x[3]\n    x_mask = (seq_range < seqlen) & dim_mask[None, :]\n    x_vals = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)\n\n    # Compute cumulative sum along sequence dimension (axis 0)\n    cumsum_vals = tl.cumsum(x_vals, axis=0)\n\n    # Store output: (seqlen, dim) for this batch and head\n    out_ptrs = OUT + seq_range * stride_out[1] + dim_range[None, :] * stride_out[3]\n    out_mask = (seq_range < seqlen) & dim_mask[None, :]\n    tl.store(out_ptrs, cumsum_vals, mask=out_mask)\n\n\n@triton.jit\ndef angle_dt_fwd_kernel(\n    OUT,        # Output tensor (batch, seqlen, nheads, dim)\n    OUT_SUM,    # Output sum tensor (batch, seqlen // chunk_size, nheads, dim)\n    ANGLE,      # Angle tensor (batch, seqlen, nheads, dim)\n    DT,         # Delta time tensor (batch, seqlen, nheads)\n    PREFIX,     # Prefix tensor (batch, numchunks, nheads, dim) - optional\n    seqlen,\n    dim,\n    chunk_size,\n    stride_out,     # (batch, seqlen, nheads, dim)\n    stride_out_sum, # (batch, seqlen // chunk_size, nheads, dim)\n    stride_angle,   # (batch, seqlen, nheads, dim)\n    stride_dt,      # (batch, seqlen, nheads)\n    stride_prefix,  # (batch, numchunks, nheads, dim)\n    # Meta-parameters\n    BLOCK_S: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n    WRITE_OUTPUT: tl.constexpr,  # Whether to write the full output\n    WRITE_CHUNK_SUM: tl.constexpr,  # Whether to write the chunk sum\n    HAS_PREFIX: tl.constexpr,  # Whether prefix is provided\n):\n    # Program IDs\n    pid_b = tl.program_id(axis=2)  # Batch index (one per batch element)\n    pid_s = tl.program_id(axis=1)  # Sequence block (chunk 
index)\n    pid_h = tl.program_id(axis=0)  # Head index (one per head)\n\n    # Offset pointers by batch and head\n    ANGLE = ANGLE + pid_b * stride_angle[0] + pid_h * stride_angle[2]\n    DT = DT + pid_b * stride_dt[0] + pid_h * stride_dt[2]\n    if WRITE_OUTPUT:\n        OUT = OUT + pid_b * stride_out[0] + pid_h * stride_out[2]\n    if WRITE_CHUNK_SUM:\n        OUT_SUM = OUT_SUM + pid_b * stride_out_sum[0] + pid_h * stride_out_sum[2]\n    if HAS_PREFIX:\n        PREFIX = PREFIX + pid_b * stride_prefix[0] + pid_h * stride_prefix[2]\n\n    # Compute ranges - each block processes exactly chunk_size elements\n    seq_start = pid_s * chunk_size\n    seq_range = seq_start + tl.arange(0, BLOCK_S)\n    dim_range = tl.arange(0, BLOCK_D)\n\n    # Masks\n    seq_mask = seq_range < seqlen\n    dim_mask = dim_range < dim\n\n    # Load angle: (seqlen, dim) for this batch and head\n    angle_ptrs = ANGLE + seq_range[:, None] * stride_angle[1] + dim_range[None, :] * stride_angle[3]\n    angle_mask = (seq_mask[:, None] & dim_mask[None, :])\n    angle_vals = tl.load(angle_ptrs, mask=angle_mask, other=0.0).to(tl.float32)\n\n    # Load dt: (seqlen,) for this batch and head\n    dt_ptrs = DT + seq_range * stride_dt[1]\n    dt_mask = seq_mask\n    dt_vals = tl.load(dt_ptrs, mask=dt_mask, other=0.0).to(tl.float32)\n\n    # Multiply: angle (S, D) * dt (S, 1) -> output (S, D)\n    # angle_vals: (BLOCK_S, BLOCK_D)\n    # dt_vals: (BLOCK_S,)\n    #output_vals = angle_vals * dt_vals[:, None]  # (BLOCK_S, BLOCK_D)\n    output_vals = tl.sigmoid(2.0 * angle_vals) * 2.0 - 1.0\n    # output_vals = libdevice.tanh(output_vals)  # This is pretty slow\n    # This is still not super fast, idk how to enable fastmath\n    #output_vals = tl.sigmoid(2.0 * output_vals) * 2.0 - 1.0\n    output_vals = output_vals * dt_vals[:, None]\n    # This is the fastest, but with reduced accuracy. 
We probably don't need it\n    # output_vals = tl.inline_asm_elementwise(\n    #     \"tanh.approx.f32 $0, $1;\",\n    #     \"=f,f\",\n    #     [output_vals],\n    #     dtype=tl.float32,\n    #     is_pure=True,\n    #     pack=1,\n    # )\n    output_vals *= 3.141592653589793  # pi\n\n    # Conditionally compute and store chunk sum\n    if WRITE_CHUNK_SUM:\n        # Compute sum along sequence dimension (within this chunk)\n        # Sum over the sequence dimension (axis 0)\n        chunk_sum = tl.sum(output_vals, axis=0)  # (BLOCK_D,)\n        # Store chunk sum: (seqlen // chunk_size, dim) for this batch and head\n        sum_ptrs = OUT_SUM + pid_s * stride_out_sum[1] + dim_range * stride_out_sum[3]\n        sum_mask = dim_mask\n        tl.store(sum_ptrs, chunk_sum, mask=sum_mask)\n\n    # Conditionally store output: (seqlen, dim) for this batch and head\n    if WRITE_OUTPUT:\n        output_vals = tl.cumsum(output_vals, axis=0)  # Cumulative sum along sequence dimension (axis 0)\n        # Add prefix if provided\n        if HAS_PREFIX:\n            # If chunk idx is 0, prefix is 0. 
If chunk idx is i, read from prefix at location i-1\n            if pid_s > 0:\n                # Load prefix for this chunk from location pid_s - 1\n                prefix_ptrs = PREFIX + (pid_s - 1) * stride_prefix[1] + dim_range * stride_prefix[3]\n                prefix_mask = dim_mask\n                prefix_vals = tl.load(prefix_ptrs, mask=prefix_mask, other=0.0).to(tl.float32)\n                # Add prefix to all elements in this chunk\n                output_vals = output_vals + prefix_vals[None, :]  # Broadcast prefix across sequence dimension\n            # For pid_s == 0, prefix is implicitly 0, so no addition needed\n        out_ptrs = OUT + seq_range[:, None] * stride_out[1] + dim_range[None, :] * stride_out[3]\n        out_mask = (seq_mask[:, None] & dim_mask[None, :])\n        tl.store(out_ptrs, output_vals, mask=out_mask)\n\n\n# The kernel expects inputs to be flipped in the sequence dimension.\n# This is because it processes chunks in reverse order.\n@triton.jit\ndef angle_dt_bwd_kernel(\n    GRAD_DT,      # Grad dt tensor (batch, seqlen, nheads)\n    GRAD_ANGLE,   # Grad angle tensor (batch, seqlen, nheads, dim)\n    GRAD_SUM,     # Grad sum tensor (batch, seqlen // chunk_size, nheads, dim)\n    GRAD_OUT,     # Grad input tensor (batch, seqlen, nheads, dim)\n    ANGLE,        # Angle tensor (batch, seqlen, nheads, dim)\n    DT,           # Delta time tensor (batch, seqlen, nheads)\n    PREFIX,       # Prefix tensor (batch, numchunks, nheads, dim) - optional\n    seqlen,\n    dim,\n    chunk_size,\n    stride_grad_dt,     # (batch, seqlen, nheads)\n    stride_grad_angle,  # (batch, seqlen, nheads, dim)\n    stride_grad_sum,    # (batch, seqlen // chunk_size, nheads, dim)\n    stride_grad_out,    # (batch, seqlen, nheads, dim)\n    stride_angle,       # (batch, seqlen, nheads, dim)\n    stride_dt,          # (batch, seqlen, nheads)\n    stride_prefix,      # (batch, numchunks, nheads, dim)\n    # Meta-parameters\n    BLOCK_S: tl.constexpr,\n    
BLOCK_D: tl.constexpr,\n    WRITE_GRAD: tl.constexpr,       # Whether to write grad_dt and grad_angle\n    WRITE_CHUNK_SUM: tl.constexpr,  # Whether to write the chunk sum\n    HAS_PREFIX: tl.constexpr,       # Whether prefix is provided\n):\n    # Program IDs\n    pid_b = tl.program_id(axis=2)  # Batch index (one per batch element)\n    pid_s = tl.program_id(axis=1)  # Sequence block (chunk index)\n    pid_h = tl.program_id(axis=0)  # Head index (one per head)\n\n    # Offset pointers by batch and head\n    GRAD_OUT = GRAD_OUT + pid_b * stride_grad_out[0] + pid_h * stride_grad_out[2]\n    if WRITE_GRAD:\n        GRAD_DT = GRAD_DT + pid_b * stride_grad_dt[0] + pid_h * stride_grad_dt[2]\n        GRAD_ANGLE = GRAD_ANGLE + pid_b * stride_grad_angle[0] + pid_h * stride_grad_angle[2]\n        DT = DT + pid_b * stride_dt[0] + pid_h * stride_dt[2]\n        ANGLE = ANGLE + pid_b * stride_angle[0] + pid_h * stride_angle[2]\n    if WRITE_CHUNK_SUM:\n        GRAD_SUM = GRAD_SUM + pid_b * stride_grad_sum[0] + pid_h * stride_grad_sum[2]\n    if HAS_PREFIX:\n        PREFIX = PREFIX + pid_b * stride_prefix[0] + pid_h * stride_prefix[2]\n\n    # Compute ranges - each block processes exactly chunk_size elements\n    seq_start = pid_s * chunk_size\n    seq_range = seq_start + tl.arange(0, BLOCK_S)\n    dim_range = tl.arange(0, BLOCK_D)\n\n    # Masks\n    seq_mask = seq_range < seqlen\n    dim_mask = dim_range < dim\n\n    # Load grad_out: (seqlen, dim) for this batch and head\n    grad_out_ptrs = GRAD_OUT + seq_range[:, None] * stride_grad_out[1] + dim_range[None, :] * stride_grad_out[3]\n    grad_out_mask = (seq_mask[:, None] & dim_mask[None, :])\n    grad_out_vals = tl.load(grad_out_ptrs, mask=grad_out_mask, other=0.0).to(tl.float32)\n\n    # Conditionally compute and store chunk sum\n    if WRITE_CHUNK_SUM:\n        # Compute sum along sequence dimension (within this chunk)\n        # Sum over the sequence dimension (axis 0)\n        chunk_sum = tl.sum(grad_out_vals, axis=0)  # 
(BLOCK_D,)\n        # Store chunk sum: (seqlen // chunk_size, dim) for this batch and head\n        sum_ptrs = GRAD_SUM + pid_s * stride_grad_sum[1] + dim_range * stride_grad_sum[3]\n        sum_mask = dim_mask\n        tl.store(sum_ptrs, chunk_sum, mask=sum_mask)\n\n    # Conditionally store output: (seqlen, dim) for this batch and head\n    if WRITE_GRAD:\n        grad_out_vals = tl.cumsum(grad_out_vals, axis=0)  # Cumulative sum along sequence dimension (axis 0)\n\n        # Add prefix if provided\n        if HAS_PREFIX:\n            # If chunk idx is 0, prefix is 0. If chunk idx is i, read from prefix at location i-1\n            if pid_s > 0:\n                # Load prefix for this chunk from location pid_s - 1\n                prefix_ptrs = PREFIX + (pid_s - 1) * stride_prefix[1] + dim_range * stride_prefix[3]\n                prefix_mask = dim_mask\n                prefix_vals = tl.load(prefix_ptrs, mask=prefix_mask, other=0.0).to(tl.float32)\n                # Add prefix to all elements in this chunk\n                grad_out_vals = grad_out_vals + prefix_vals[None, :]  # Broadcast prefix across sequence dimension\n            # For pid_s == 0, prefix is implicitly 0, so no addition needed\n\n        # Load angle: (seqlen, dim) for this batch and head\n        angle_ptrs = ANGLE + seq_range[:, None] * stride_angle[1] + dim_range[None, :] * stride_angle[3]\n        angle_mask = (seq_mask[:, None] & dim_mask[None, :])\n        angle_vals = tl.load(angle_ptrs, mask=angle_mask, other=0.0).to(tl.float32)\n\n        # Load dt: (seqlen,) for this batch and head\n        dt_ptrs = DT + seq_range * stride_dt[1]\n        dt_mask = seq_mask\n        dt_vals = tl.load(dt_ptrs, mask=dt_mask, other=0.0).to(tl.float32)  # (BLOCK_S,)\n\n        # Compute dt gradients\n        tanh_angle_vals = tl.sigmoid(2.0 * angle_vals) * 2.0 - 1.0  # (BLOCK_S, BLOCK_D)\n        pi_tanh_angle_vals = tanh_angle_vals*3.141592653589793\n        dt_grad_vals = grad_out_vals * 
pi_tanh_angle_vals # (BLOCK_S, BLOCK_D)\n        dt_grad_vals = tl.sum(dt_grad_vals, axis=1)  # Sum over dim to get (BLOCK_S,)\n\n        # Store dt gradients\n        grad_dt_ptrs = GRAD_DT + seq_range * stride_grad_dt[1]\n        grad_dt_mask = seq_mask\n        tl.store(grad_dt_ptrs, dt_grad_vals, mask=grad_dt_mask)\n\n        # Compute angle gradients\n        d_tanh = 1.0 - tanh_angle_vals * tanh_angle_vals\n        grad_angle_vals = (3.141592653589793 * dt_vals[:, None]) * d_tanh * grad_out_vals\n\n        # Store angle gradients\n        grad_angle_ptrs = GRAD_ANGLE + seq_range[:, None] * stride_grad_angle[1] + dim_range[None, :] * stride_grad_angle[3]\n        grad_angle_mask = (seq_mask[:, None] & dim_mask[None, :])\n        tl.store(grad_angle_ptrs, grad_angle_vals, mask=grad_angle_mask)\n\n\ndef apply_angle_dt_fwd(\n    angle: torch.Tensor,  # (batch, seqlen, nheads, dim)\n    dt: torch.Tensor,     # (batch, seqlen, nheads)\n    chunk_size: int = 128,\n) -> torch.Tensor:\n    \"\"\"\n    Compute the cumulative sum along the sequence dimension of pi * tanh(angle) * dt.\n\n    The scan runs in two passes: per-chunk sums are computed first, their cumulative\n    sum gives a per-chunk prefix, and a second pass adds that prefix to the\n    within-chunk cumulative sum.\n\n    Arguments:\n        angle: (batch, seqlen, nheads, dim)\n        dt: (batch, seqlen, nheads)\n        chunk_size: Size of chunks for summing (must be power of 2)\n\n    Returns:\n        output: (batch, seqlen, nheads, dim), always fp32\n    \"\"\"\n    batch, seqlen, nheads, dim = angle.shape\n    assert angle.shape == (batch, seqlen, nheads, dim)\n    assert dt.shape == (batch, seqlen, nheads)\n    assert chunk_size > 0 and (chunk_size & (chunk_size - 1)) == 0, \"chunk_size must be power of 2\"\n\n    # 
Calculate output dimensions\n    num_chunks = math.ceil(seqlen / chunk_size)\n\n    # Create output tensors (always fp32)\n    output = torch.empty(batch, seqlen, nheads, dim, device=angle.device, dtype=torch.float32)\n    output_sum = torch.empty(batch, num_chunks, nheads, dim, device=angle.device, dtype=torch.float32)\n\n    # Launch kernel\n    BLOCK_S = chunk_size  # Use chunk_size as BLOCK_S\n    BLOCK_D = triton.next_power_of_2(dim)\n\n    # Step 1: compute the sum of each chunk. Don't write the output\n    grid = lambda META: (nheads, num_chunks, batch)\n    with torch.cuda.device(angle.device.index):\n        torch.library.wrap_triton(angle_dt_fwd_kernel)[grid](\n            None,  # output\n            output_sum,\n            angle,\n            dt,\n            None,  # prefix\n            seqlen,\n            dim,\n            chunk_size,\n            (0, 0, 0, 0),  # output_stride\n            output_sum.stride(),\n            angle.stride(),\n            dt.stride(),\n            (0, 0, 0, 0),   # prefix_stride\n            BLOCK_S=BLOCK_S,\n            BLOCK_D=BLOCK_D,\n            WRITE_OUTPUT=False,\n            WRITE_CHUNK_SUM=True,\n            HAS_PREFIX=False,\n        )\n\n    # Step 2: compute cumsum on output_sum to get prefix\n    prefix = apply_cumsum(output_sum)  # Shape: (batch, num_chunks, nheads, dim)\n\n    # Step 3: call angle_dt_kernel again with output and prefix, don't need to write output_sum\n    with torch.cuda.device(angle.device.index):\n        torch.library.wrap_triton(angle_dt_fwd_kernel)[grid](\n            output,  # output\n            None,    # output_sum (don't need to write)\n            angle,\n            dt,\n            prefix,  # prefix\n            seqlen,\n            dim,\n            chunk_size,\n            output.stride(),  # output_stride\n            (0, 0, 0, 0),     # output_sum_stride\n            angle.stride(),\n            dt.stride(),\n            prefix.stride(),  # prefix_stride\n            
BLOCK_S=BLOCK_S,\n            BLOCK_D=BLOCK_D,\n            WRITE_OUTPUT=True,\n            WRITE_CHUNK_SUM=False,\n            HAS_PREFIX=True,\n        )\n\n    return output\n\ndef apply_angle_dt_bwd(\n    grad_out: torch.Tensor,  # (batch, seqlen, nheads, dim)\n    angle: torch.Tensor,     # (batch, seqlen, nheads, dim)\n    dt: torch.Tensor,        # (batch, seqlen, nheads)\n    chunk_size: int = 128,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Backward pass of apply_angle_dt_fwd: compute the gradients with respect to dt and angle.\n\n    The inputs are flipped along the sequence dimension so that the reverse\n    cumulative sum of grad_out can be computed with the same two-pass chunked scan.\n\n    Arguments:\n        grad_out: (batch, seqlen, nheads, dim) - gradient of the output\n        angle: (batch, seqlen, nheads, dim) - stored angle tensor\n        dt: (batch, seqlen, nheads) - stored delta time tensor\n        chunk_size: Size of chunks for summing (must be power of 2)\n\n    Returns:\n        grad_dt: (batch, seqlen, nheads) - same dtype as dt\n        grad_angle: (batch, seqlen, nheads, dim) - same dtype as angle\n    \"\"\"\n    batch, seqlen, nheads, dim = grad_out.shape\n    assert grad_out.shape == (batch, seqlen, nheads, dim)\n    assert angle.shape == (batch, seqlen, nheads, dim)\n    assert dt.shape == (batch, seqlen, nheads)\n    assert chunk_size > 0 and (chunk_size & (chunk_size - 1)) == 0, \"chunk_size must be power of 2\"\n\n    # Calculate output dimensions\n    num_chunks = math.ceil(seqlen / chunk_size)\n\n    # Reverse the sequence dimension of grad_out, angle, dt\n    grad_out = grad_out.flip(dims=(1,))  # Reverse along sequence dimension\n    angle = angle.flip(dims=(1,))\n    dt = dt.flip(dims=(1,))\n\n    # Create output tensors (grad_dt and grad_angle match the input dtypes; grad_sum is fp32)\n    grad_dt = torch.empty_like(dt) # (batch, seqlen, 
nheads)\n    grad_angle = torch.empty_like(angle) # (batch, seqlen, nheads, dim)\n    grad_sum = torch.empty(batch, num_chunks, nheads, dim, device=angle.device, dtype=torch.float32)\n\n    # Launch kernel\n    BLOCK_S = chunk_size  # Use chunk_size as BLOCK_S\n    BLOCK_D = triton.next_power_of_2(dim)\n\n    # Step 1: compute the sum of each chunk. Don't write the output\n    grid = lambda META: (nheads, num_chunks, batch)\n    with torch.cuda.device(angle.device.index):\n        torch.library.wrap_triton(angle_dt_bwd_kernel)[grid](\n            None,  # GRAD_DT\n            None,  # GRAD_ANGLE\n            grad_sum,  # GRAD_SUM\n            grad_out,  # GRAD_OUT\n            angle,\n            dt,\n            None,  # PREFIX\n            seqlen,\n            dim,\n            chunk_size,\n            (0, 0, 0),             # stride_grad_dt\n            (0, 0, 0, 0),          # stride_grad_angle\n            grad_sum.stride(),     # stride_grad_sum\n            grad_out.stride(),     # stride_grad_out\n            angle.stride(),\n            dt.stride(),\n            (0, 0, 0, 0),          # stride_prefix\n            BLOCK_S=BLOCK_S,\n            BLOCK_D=BLOCK_D,\n            WRITE_GRAD=False,      # Don't write grad_dt and grad_angle yet\n            WRITE_CHUNK_SUM=True,  # Write chunk sums to grad_sum\n            HAS_PREFIX=False,      # No prefix provided\n        )\n\n    # Step 2: compute cumsum on grad_sum to get prefix\n    prefix = apply_cumsum(grad_sum)  # Shape: (batch, num_chunks, nheads, dim)\n\n    # Step 3: call angle_dt_bwd_kernel again with the prefix to write grad_dt and grad_angle; don't need to write grad_sum\n    with torch.cuda.device(angle.device.index):\n        torch.library.wrap_triton(angle_dt_bwd_kernel)[grid](\n            grad_dt,\n            grad_angle,\n            None,               # GRAD_SUM (don't need to write)\n            grad_out,\n            angle,\n            dt,\n            prefix,             # prefix\n            
seqlen,\n            dim,\n            chunk_size,\n            grad_dt.stride(),       # stride_grad_dt\n            grad_angle.stride(),    # stride_grad_angle\n            (0, 0, 0),              # stride_grad_sum\n            grad_out.stride(),      # stride_grad_out\n            angle.stride(),\n            dt.stride(),\n            prefix.stride(),        # stride_prefix\n            BLOCK_S=BLOCK_S,\n            BLOCK_D=BLOCK_D,\n            WRITE_GRAD=True,        # Write grad_dt and grad_angle\n            WRITE_CHUNK_SUM=False,  # Don't write chunk sums again\n            HAS_PREFIX=True,        # Use the computed prefix\n        )\n\n    grad_dt = grad_dt.flip(dims=(1,))\n    grad_angle = grad_angle.flip(dims=(1,)) \n\n    return grad_dt, grad_angle\n\n\ndef apply_cumsum(\n    x: torch.Tensor,  # (batch, seqlen, nheads, dim)\n) -> torch.Tensor:\n    \"\"\"\n    Compute cumulative sum along sequence dimension using Triton.\n\n    Arguments:\n        x: (batch, seqlen, nheads, dim)\n\n    Returns:\n        output: (batch, seqlen, nheads, dim) - cumulative sum along seqlen dimension\n    \"\"\"\n    batch, seqlen, nheads, dim = x.shape\n    assert seqlen <= 512, f\"seqlen must be <= 512, got {seqlen}\"\n    # Create output tensor (always fp32)\n    output = torch.empty_like(x, dtype=torch.float32)\n\n    # Launch kernel\n    BLOCK_S = triton.next_power_of_2(seqlen)\n    BLOCK_D = triton.next_power_of_2(min(dim, 16))\n\n    grid = lambda META: (nheads, triton.cdiv(dim, META[\"BLOCK_D\"]), batch)\n    with torch.cuda.device(x.device.index):\n        torch.library.wrap_triton(cumsum_kernel)[grid](\n            output,\n            x,\n            seqlen,\n            dim,\n            output.stride(),\n            x.stride(),\n            BLOCK_S=BLOCK_S,\n            BLOCK_D=BLOCK_D,\n        )\n\n    return output\n\n\ndef apply_angle_dt_reference(\n    angle: torch.Tensor,  # (batch, seqlen, nheads, dim)\n    dt: torch.Tensor,     # (batch, seqlen, 
nheads)\n    chunk_size: int = 64,\n) -> torch.Tensor:\n    \"\"\"Reference PyTorch implementation: cumsum over seqlen of pi * tanh(angle) * dt (chunk_size is unused).\"\"\"\n    batch, seqlen, nheads, dim = angle.shape\n\n    # Element-wise multiply: angle (B, S, H, D) * dt (B, S, H, 1) -> (B, S, H, D)\n    #base_vals = (angle * dt[..., None]).to(torch.float32)  # Always return fp32\n    base_vals = (angle).to(torch.float32)\n\n    # Apply tanh, then multiply by dt and pi\n    base_vals = torch.tanh(base_vals) * dt[..., None].to(torch.float32) * torch.pi\n\n    # Simple cumulative sum along seqlen dimension\n    output = torch.cumsum(base_vals, dim=1)\n    return output\n\n\ndef test_correctness():\n    \"\"\"Test correctness against reference implementation.\"\"\"\n    print(\"Testing angle_dt kernel correctness...\")\n\n    # Test parameters\n    batch, seqlen, nheads, dim = 2, 512, 4, 32\n    chunk_size = 64\n    device = \"cuda\"\n    dtype = torch.float32\n\n    # Create test tensors\n    #torch.manual_seed(42)\n    angle = torch.randn(batch, seqlen, nheads, dim, device=device, dtype=dtype)\n    dt = torch.randn(batch, seqlen, nheads, device=device, dtype=dtype)\n\n    # Test kernel vs reference\n    out_triton = apply_angle_dt_fwd(angle, dt, chunk_size)\n    out_ref = apply_angle_dt_reference(angle, dt, chunk_size)\n\n    max_diff = (out_triton - out_ref).abs().max().item()\n    print(f\"Output max difference: {max_diff:.6f}\")\n    assert max_diff < 1e-3, f\"Too large difference in output: {max_diff}\"\n    print(\"Test passed! ✓\")\n    print(\"All basic tests passed! 
✓\")\n\n\ndef test_cumsum_correctness():\n    \"\"\"Test cumsum kernel correctness against PyTorch.\"\"\"\n    print(\"Testing cumsum kernel correctness...\")\n\n    # Test parameters\n    batch, seqlen, nheads, dim = 4, 128, 8, 64\n    device = \"cuda\"\n    dtype = torch.float32\n\n    # Create test tensors\n    #torch.manual_seed(42)\n    x = torch.randn(batch, seqlen, nheads, dim, device=device, dtype=dtype)\n\n    # Test kernel vs PyTorch\n    out_triton = apply_cumsum(x)\n    out_ref = torch.cumsum(x, dim=1).to(torch.float32)\n\n    max_diff = (out_triton - out_ref).abs().max().item()\n    print(f\"Cumsum max difference: {max_diff:.6f}\")\n\n    assert max_diff < 1e-4, f\"Too large difference in cumsum: {max_diff}\"\n\n    print(\"Cumsum test passed! ✓\")\n\ndef test_backward_correctness():\n    \"\"\"Backward correctness vs PyTorch autograd on small cases.\"\"\"\n    print(\"Testing backward correctness...\")\n\n    device = \"cuda\"\n    tol = 5e-3  # fp32\n\n    cases = [\n        (2, 257, 3, 17, 64),  # odd S/D, non-power-of-two\n        (1, 129, 4, 33, 32),\n    ]\n\n    for (batch, seqlen, nheads, dim, chunk_size) in cases:\n        angle = torch.randn(batch, seqlen, nheads, dim, device=device, dtype=torch.float32)\n        dt    = torch.randn(batch, seqlen, nheads,      device=device, dtype=torch.float32)\n        grad_out = torch.randn(batch, seqlen, nheads, dim, device=device, dtype=torch.float32)\n\n        # Triton bwd\n        grad_dt_tri, grad_angle_tri = apply_angle_dt_bwd(grad_out, angle, dt, chunk_size)\n        # Reference bwd via autograd\n        angle_ref = angle.detach().clone().requires_grad_(True)\n        dt_ref    = dt.detach().clone().requires_grad_(True)\n        out_ref = apply_angle_dt_reference(angle_ref, dt_ref, chunk_size)\n        out_ref.backward(grad_out)\n        grad_angle_ref = angle_ref.grad.detach()\n        grad_dt_ref    = dt_ref.grad.detach()\n\n        max_da = (grad_angle_tri - grad_angle_ref).abs().max().item()\n  
      max_dd = (grad_dt_tri    - grad_dt_ref   ).abs().max().item()\n        print(f\"  Case B={batch} S={seqlen} H={nheads} D={dim} chunk={chunk_size} | \"\n              f\"max|Δ angle|={max_da:.3e}  max|Δ dt|={max_dd:.3e}\")\n        assert max_da < tol, f\"angle grad mismatch {max_da}\"\n        assert max_dd < tol, f\"dt grad mismatch {max_dd}\"\n\n    print(\"Backward correctness test passed! ✓\")\n\ndef benchmark_angle_dt():\n    \"\"\"Benchmark angle_dt kernel and measure memory bandwidth.\"\"\"\n    print(\"\\nBenchmarking angle_dt kernel...\")\n\n    # Benchmark parameters\n    batch, seqlen, nheads, dim = 8, 4096, 32, 32\n    # batch, seqlen, nheads, dim = 1, 128, 1, 1\n    chunk_size = 128\n    device = \"cuda\"\n    dtype = torch.bfloat16\n\n    # Create input tensors\n    #torch.manual_seed(42)\n    # Generate angle by expanding from (batch, seqlen, 1, dim) to (batch, seqlen, nheads, dim)\n    angle_base = torch.randn(batch, seqlen, 1, dim, device=device, dtype=dtype)\n    angle = angle_base.expand(batch, seqlen, nheads, dim)\n    dt = torch.randn(batch, seqlen, nheads, device=device, dtype=dtype)\n\n    fn = lambda: apply_angle_dt_fwd(angle, dt, chunk_size)\n    out = fn()\n    # Warmup\n    for _ in range(10):\n        fn()\n\n    # Benchmark\n    torch.cuda.synchronize()\n    import time\n    time.sleep(0.5)\n\n    # Run benchmark\n    time_ms = triton.testing.do_bench(fn, warmup=10, rep=100)\n\n    # Calculate memory bandwidth\n    # Read: angle_base (actual underlying data) + dt\n    # Write: output + output_sum (always fp32, so 4 bytes per element)\n    # Note: angle is expanded so actual memory read is only angle_base.numel()\n    bytes_read = angle_base.untyped_storage().nbytes() + dt.untyped_storage().nbytes()\n    bytes_write = out.untyped_storage().nbytes()  # Both output and output_sum (fp32 = 4 bytes)\n    total_bytes = bytes_read + bytes_write\n\n    # Convert to GB/s\n    time_s = time_ms / 1000.0\n    bandwidth_gb_s = (total_bytes / 
1e9) / time_s\n\n    print(f\"Angle base shape: {angle_base.shape}\")\n    print(f\"Angle expanded shape: {angle.shape}\")\n    print(f\"Angle stride: {angle.stride()}\")\n    print(f\"DT shape: {dt.shape}\")\n    print(f\"Output shape: {out.shape}\")\n    print(f\"Chunk size: {chunk_size}\")\n    print(f\"Time: {time_ms:.3f} ms\")\n    print(f\"Memory transferred: {total_bytes / 1e9:.3f} GB\")\n    print(f\"Memory bandwidth: {bandwidth_gb_s:.1f} GB/s\")\n\n    # from flash_attn.utils.benchmark import pytorch_profiler\n    # pytorch_profiler(fn)\n\n    return time_ms, bandwidth_gb_s\n\ndef benchmark_angle_dt_backward():\n    \"\"\"Benchmark backward pass and report rough memory bandwidth.\"\"\"\n    print(\"\\nBenchmarking angle_dt backward...\")\n\n    batch, seqlen, nheads, dim = 8, 4096, 32, 32\n    chunk_size = 128\n    device = \"cuda\"\n\n    # Use fp32 for bwd accumulations\n    angle = torch.randn(batch, seqlen, nheads, dim, device=device, dtype=torch.float32)\n    dt    = torch.randn(batch, seqlen, nheads,      device=device, dtype=torch.float32)\n    grad_out = torch.randn(batch, seqlen, nheads, dim, device=device, dtype=torch.float32)\n\n    fn = lambda: apply_angle_dt_bwd(grad_out, angle, dt, chunk_size)\n    _ = fn()\n    # Warmup\n    for _ in range(10):\n        fn()\n\n    torch.cuda.synchronize()\n    import time\n    time.sleep(0.5)\n    time_ms = triton.testing.do_bench(fn, warmup=10, rep=100)\n\n    # Rough traffic estimate (two-stage + prefixes), conservative:\n    num_chunks = (seqlen + chunk_size - 1) // chunk_size\n    bytes_read = (\n        grad_out.numel() * 4 +  # read grad_out\n        angle.numel()   * 4 +   # read angle\n        dt.numel()      * 4 +   # read dt\n        (batch * num_chunks * nheads * dim) * 4 +  # read grad_sum for prefix\n        (batch * num_chunks * nheads * dim) * 4    # read prefix in stage 2\n    )\n    bytes_write = (\n        (batch * num_chunks * nheads * dim) * 4 +  # write grad_sum (stage 1)\n        
(batch * seqlen * nheads) * 4 +            # write grad_dt\n        (batch * seqlen * nheads * dim) * 4        # write grad_angle\n    )\n    total_bytes = bytes_read + bytes_write\n    bandwidth_gb_s = (total_bytes / 1e9) / (time_ms / 1000.0)\n\n    print(f\"B={batch} S={seqlen} H={nheads} D={dim} chunk={chunk_size}\")\n    print(f\"Time: {time_ms:.3f} ms\")\n    print(f\"Memory transferred (est): {total_bytes / 1e9:.3f} GB\")\n    print(f\"Memory bandwidth (est): {bandwidth_gb_s:.1f} GB/s\")\n\n    return time_ms, bandwidth_gb_s\n\nif __name__ == \"__main__\":\n    test_correctness()\n    test_cumsum_correctness()\n    benchmark_angle_dt()"
  },
  {
    "path": "mamba_ssm/ops/triton/k_activations.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\n\nimport torch\n\nimport triton\nimport triton.language as tl\n\nfrom mamba_ssm.utils.determinism import autotune_configs\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_N': 32}),\n        triton.Config({'BLOCK_N': 64}),\n        triton.Config({'BLOCK_N': 128}),\n        triton.Config({'BLOCK_N': 256}),\n        triton.Config({'BLOCK_N': 512}),\n        triton.Config({'BLOCK_N': 1024}),\n    ]),\n    key=['ncols'],\n)\n@triton.jit\ndef _swiglu_fwd_kernel(\n    X,\n    Y,\n    OUT,\n    stride_x_row,  # how much to increase the pointer when moving by 1 row\n    stride_y_row,\n    stride_out_row,\n    ncols,\n    BLOCK_N: tl.constexpr,\n):\n    # Map the program id to the row of X and Y it should compute.\n    row = tl.program_id(0)\n    start_col = tl.program_id(1) * BLOCK_N\n    X += row * stride_x_row\n    Y += row * stride_y_row\n    OUT += row * stride_out_row\n    cols = start_col + tl.arange(0, BLOCK_N)\n    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n    out = x * tl.sigmoid(x) * y\n    tl.store(OUT + cols, out, mask=cols < ncols)\n\n\ndef _swiglu_fwd(xy, out=None):\n    if xy.stride(-1) != 1:\n        xy = xy.contiguous()\n    batch_shape = xy.shape[:-1]\n    xy = xy.reshape(-1, xy.shape[-1])\n    x, y = xy.chunk(2, dim=-1)\n    if out is None:\n        out = torch.empty_like(x)\n    else:\n        out = out.reshape(-1, out.shape[-1])\n        assert out.shape == x.shape\n    assert out.stride(-1) == 1\n    M, N = x.shape\n    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n    with torch.cuda.device(x.device.index):\n        _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)\n    return out.reshape(*batch_shape, out.shape[-1])\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_N': 32}),\n        
triton.Config({'BLOCK_N': 64}),\n        triton.Config({'BLOCK_N': 128}),\n        triton.Config({'BLOCK_N': 256}),\n        triton.Config({'BLOCK_N': 512}),\n        triton.Config({'BLOCK_N': 1024}),\n    ]),\n    key=['ncols'],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"OUT\"] is not None})\n@triton.jit\ndef _swiglu_bwd_kernel(\n    X,\n    Y,\n    DOUT,\n    OUT,\n    DX,\n    DY,\n    stride_x_row,  # how much to increase the pointer when moving by 1 row\n    stride_y_row,\n    stride_dout_row,\n    stride_out_row,\n    stride_dx_row,\n    stride_dy_row,\n    ncols,\n    BLOCK_N: tl.constexpr,\n    RECOMPUTE_OUTPUT: tl.constexpr,\n):\n    # Map the program id to the row of X and Y it should compute.\n    row = tl.program_id(0)\n    start_col = tl.program_id(1) * BLOCK_N\n    X += row * stride_x_row\n    Y += row * stride_y_row\n    DOUT += row * stride_dout_row\n    if RECOMPUTE_OUTPUT:\n        OUT += row * stride_out_row\n    DX += row * stride_dx_row\n    DY += row * stride_dy_row\n    cols = start_col + tl.arange(0, BLOCK_N)\n    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n    dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)\n    x_sigmoid = tl.sigmoid(x)\n    dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout\n    dy = x * x_sigmoid * dout\n    tl.store(DX + cols, dx, mask=cols < ncols)\n    tl.store(DY + cols, dy, mask=cols < ncols)\n    if RECOMPUTE_OUTPUT:\n        out = x * x_sigmoid * y\n        tl.store(OUT + cols, out, mask=cols < ncols)\n\n\ndef _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):\n    if xy.stride(-1) != 1:\n        xy = xy.contiguous()\n    if dout.stride(-1) != 1:\n        dout = dout.contiguous()\n    batch_shape = xy.shape[:-1]\n    xy = xy.reshape(-1, xy.shape[-1])\n    x, y = xy.chunk(2, dim=-1)\n    dout = dout.reshape(-1, dout.shape[-1])\n    assert dout.shape == 
x.shape\n    if dxy is None:\n        dxy = torch.empty_like(xy)\n    else:\n        dxy = dxy.reshape(-1, dxy.shape[-1])\n        assert dxy.shape == xy.shape\n    dx, dy = dxy.chunk(2, dim=-1)\n    assert dx.stride(-1) == 1\n    assert dy.stride(-1) == 1\n    if recompute_output:\n        if out is None:\n            out = torch.empty_like(x)\n        else:\n            out = out.reshape(-1, out.shape[-1])\n            assert out.shape == x.shape\n        assert out.stride(-1) == 1\n    M, N = x.shape\n    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n    with torch.cuda.device(x.device.index):\n        _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,\n                                 x.stride(0), y.stride(0), dout.stride(0),\n                                 out.stride(0) if recompute_output else 0,\n                                 dx.stride(0), dy.stride(0),\n                                 N)\n    if not recompute_output:\n        return dxy.reshape(*batch_shape, dxy.shape[-1])\n    else:\n        return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])\n\n\nclass SwiGLU(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, xy):\n        ctx.save_for_backward(xy)\n        return _swiglu_fwd(xy)\n\n    @staticmethod\n    def backward(ctx, dout):\n        xy, = ctx.saved_tensors\n        return _swiglu_bwd(xy, dout)\n\n\nswiglu = SwiGLU.apply\n"
  },
  {
    "path": "mamba_ssm/ops/triton/layer_norm.py",
    "content": "# Copyright (c) 2024, Tri Dao.\n# Implement dropout + residual + layer_norm / rms_norm.\n\n# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.\n# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.\n# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.\n\nimport math\nimport warnings\n\nimport torch\nimport torch.nn.functional as F\nfrom mamba_ssm.utils.torch import custom_bwd, custom_fwd\n\nimport triton\nimport triton.language as tl\n\nfrom mamba_ssm.utils.determinism import autotune_configs\n\n\ndef layer_norm_ref(\n    x,\n    weight,\n    bias,\n    residual=None,\n    x1=None,\n    weight1=None,\n    bias1=None,\n    eps=1e-6,\n    dropout_p=0.0,\n    rowscale=None,\n    prenorm=False,\n    dropout_mask=None,\n    dropout_mask1=None,\n    upcast=False,\n):\n    dtype = x.dtype\n    if upcast:\n        x = x.float()\n        weight = weight.float()\n        bias = bias.float() if bias is not None else None\n        residual = residual.float() if residual is not None else residual\n        x1 = x1.float() if x1 is not None else None\n        weight1 = weight1.float() if weight1 is not None else None\n        bias1 = bias1.float() if bias1 is not None else None\n    if x1 is not None:\n        assert rowscale is None, \"rowscale is not supported with parallel LayerNorm\"\n    if rowscale is not None:\n        x = x * rowscale[..., None]\n    if dropout_p > 0.0:\n        if dropout_mask is not None:\n            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)\n        else:\n            x = F.dropout(x, p=dropout_p)\n        if x1 is not None:\n            if dropout_mask1 is not None:\n                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)\n            else:\n                x1 = 
F.dropout(x1, p=dropout_p)\n    if x1 is not None:\n        x = x + x1\n    if residual is not None:\n        x = (x + residual).to(x.dtype)\n    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(\n        dtype\n    )\n    if weight1 is None:\n        return out if not prenorm else (out, x)\n    else:\n        out1 = F.layer_norm(\n            x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps\n        ).to(dtype)\n        return (out, out1) if not prenorm else (out, out1, x)\n\n\ndef rms_norm_ref(\n    x,\n    weight,\n    bias,\n    residual=None,\n    x1=None,\n    weight1=None,\n    bias1=None,\n    eps=1e-6,\n    dropout_p=0.0,\n    rowscale=None,\n    prenorm=False,\n    dropout_mask=None,\n    dropout_mask1=None,\n    upcast=False,\n):\n    dtype = x.dtype\n    if upcast:\n        x = x.float()\n        weight = weight.float()\n        bias = bias.float() if bias is not None else None\n        residual = residual.float() if residual is not None else residual\n        x1 = x1.float() if x1 is not None else None\n        weight1 = weight1.float() if weight1 is not None else None\n        bias1 = bias1.float() if bias1 is not None else None\n    if x1 is not None:\n        assert rowscale is None, \"rowscale is not supported with parallel LayerNorm\"\n    if rowscale is not None:\n        x = x * rowscale[..., None]\n    if dropout_p > 0.0:\n        if dropout_mask is not None:\n            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)\n        else:\n            x = F.dropout(x, p=dropout_p)\n        if x1 is not None:\n            if dropout_mask1 is not None:\n                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)\n            else:\n                x1 = F.dropout(x1, p=dropout_p)\n    if x1 is not None:\n        x = x + x1\n    if residual is not None:\n        x = (x + residual).to(x.dtype)\n    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + 
eps)\n    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)\n    if weight1 is None:\n        return out if not prenorm else (out, x)\n    else:\n        out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(\n            dtype\n        )\n        return (out, out1) if not prenorm else (out, out1, x)\n\ndef config_prune(configs):\n\n    if torch.version.hip:\n        try:\n            # set warp size based on gcn architecture\n            gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName\n            if \"gfx10\" in gcn_arch_name or \"gfx11\" in gcn_arch_name:\n                # radeon\n                warp_size = 32\n            else:\n                # instinct\n                warp_size = 64\n        except AttributeError as e:\n            # fall back to crude method to set warp size\n            device_name = torch.cuda.get_device_properties(0).name\n            if 'instinct' in device_name.lower():\n                warp_size = 64\n            else:\n                warp_size = 32\n            warnings.warn(f\"{e}, warp size set to {warp_size} based on device name: {device_name}\", UserWarning)\n\n    else:\n        # cuda\n        warp_size = 32\n\n    max_block_sz = 1024\n    max_num_warps = max_block_sz // warp_size\n    pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]\n    return pruned_configs\n\nconfigs_autotune = [\n        triton.Config({}, num_warps=1),\n        triton.Config({}, num_warps=2),\n        triton.Config({}, num_warps=4),\n        triton.Config({}, num_warps=8),\n        triton.Config({}, num_warps=16),\n        triton.Config({}, num_warps=32),\n        ]\n\npruned_configs_autotune = config_prune(configs_autotune)\n\n@triton.autotune(\n    configs=autotune_configs(pruned_configs_autotune),\n    key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n# 
@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n# @triton.heuristics({\"HAS_RESIDUAL\": lambda args: args[\"RESIDUAL\"] is not None})\n@triton.heuristics({\"HAS_X1\": lambda args: args[\"X1\"] is not None})\n@triton.heuristics({\"HAS_W1\": lambda args: args[\"W1\"] is not None})\n@triton.heuristics({\"HAS_B1\": lambda args: args[\"B1\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n    X,  # pointer to the input\n    Y,  # pointer to the output\n    W,  # pointer to the weights\n    B,  # pointer to the biases\n    RESIDUAL,  # pointer to the residual\n    X1,\n    W1,\n    B1,\n    Y1,\n    RESIDUAL_OUT,  # pointer to the residual\n    ROWSCALE,\n    SEEDS,  # Dropout seeds for each row\n    DROPOUT_MASK,\n    Mean,  # pointer to the mean\n    Rstd,  # pointer to the 1/std\n    stride_x_row,  # how much to increase the pointer when moving by 1 row\n    stride_y_row,\n    stride_res_row,\n    stride_res_out_row,\n    stride_x1_row,\n    stride_y1_row,\n    M,  # number of rows in X\n    N,  # number of columns in X\n    eps,  # epsilon to avoid division by zero\n    dropout_p,  # Dropout probability\n    IS_RMS_NORM: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    HAS_RESIDUAL: tl.constexpr,\n    STORE_RESIDUAL_OUT: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n    HAS_DROPOUT: tl.constexpr,\n    STORE_DROPOUT_MASK: tl.constexpr,\n    HAS_ROWSCALE: tl.constexpr,\n    HAS_X1: tl.constexpr,\n    HAS_W1: tl.constexpr,\n    HAS_B1: tl.constexpr,\n):\n    # Map the program id to the row of X and Y it should compute.\n    row = tl.program_id(0)\n    X += row * stride_x_row\n    Y += row * stride_y_row\n    if HAS_RESIDUAL:\n        RESIDUAL += row * stride_res_row\n    if STORE_RESIDUAL_OUT:\n        RESIDUAL_OUT += row * stride_res_out_row\n    if HAS_X1:\n        X1 += row * stride_x1_row\n    if HAS_W1:\n        Y1 += row * stride_y1_row\n    # Compute mean and variance\n    cols = tl.arange(0, BLOCK_N)\n    x = tl.load(X + 
cols, mask=cols < N, other=0.0).to(tl.float32)\n    if HAS_ROWSCALE:\n        rowscale = tl.load(ROWSCALE + row).to(tl.float32)\n        x *= rowscale\n    if HAS_DROPOUT:\n        # Compute dropout mask\n        # 7 rounds is good enough, and reduces register pressure\n        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)\n        if STORE_DROPOUT_MASK:\n            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)\n    if HAS_X1:\n        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)\n        if HAS_ROWSCALE:\n            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)\n            x1 *= rowscale\n        if HAS_DROPOUT:\n            # Compute dropout mask\n            # 7 rounds is good enough, and reduces register pressure\n            keep_mask = (\n                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n            )\n            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)\n            if STORE_DROPOUT_MASK:\n                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)\n        x += x1\n    if HAS_RESIDUAL:\n        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n        x += residual\n    if STORE_RESIDUAL_OUT:\n        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n    if not IS_RMS_NORM:\n        mean = tl.sum(x, axis=0) / N\n        tl.store(Mean + row, mean)\n        xbar = tl.where(cols < N, x - mean, 0.0)\n        var = tl.sum(xbar * xbar, axis=0) / N\n    else:\n        xbar = tl.where(cols < N, x, 0.0)\n        var = tl.sum(xbar * xbar, axis=0) / N\n    rstd = 1 / tl.sqrt(var + eps)\n    tl.store(Rstd + row, rstd)\n    # Normalize and apply linear transformation\n    mask = cols < N\n    w = tl.load(W + cols, mask=mask).to(tl.float32)\n    if HAS_BIAS:\n        b = tl.load(B + cols, 
mask=mask).to(tl.float32)\n    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n    y = x_hat * w + b if HAS_BIAS else x_hat * w\n    # Write output\n    tl.store(Y + cols, y, mask=mask)\n    if HAS_W1:\n        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)\n        if HAS_B1:\n            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)\n        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1\n        tl.store(Y1 + cols, y1, mask=mask)\n\n\ndef _layer_norm_fwd(\n    x,\n    weight,\n    bias,\n    eps,\n    residual=None,\n    x1=None,\n    weight1=None,\n    bias1=None,\n    dropout_p=0.0,\n    rowscale=None,\n    out_dtype=None,\n    residual_dtype=None,\n    is_rms_norm=False,\n    return_dropout_mask=False,\n):\n    if residual is not None:\n        residual_dtype = residual.dtype\n    M, N = x.shape\n    assert x.stride(-1) == 1\n    if residual is not None:\n        assert residual.stride(-1) == 1\n        assert residual.shape == (M, N)\n    assert weight.shape == (N,)\n    assert weight.stride(-1) == 1\n    if bias is not None:\n        assert bias.stride(-1) == 1\n        assert bias.shape == (N,)\n    if x1 is not None:\n        assert x1.shape == x.shape\n        assert rowscale is None\n        assert x1.stride(-1) == 1\n    if weight1 is not None:\n        assert weight1.shape == (N,)\n        assert weight1.stride(-1) == 1\n    if bias1 is not None:\n        assert bias1.shape == (N,)\n        assert bias1.stride(-1) == 1\n    if rowscale is not None:\n        assert rowscale.is_contiguous()\n        assert rowscale.shape == (M,)\n    # allocate output\n    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n    assert y.stride(-1) == 1\n    if weight1 is not None:\n        y1 = torch.empty_like(y)\n        assert y1.stride(-1) == 1\n    else:\n        y1 = None\n    if (\n        residual is not None\n        or (residual_dtype is not None and residual_dtype != x.dtype)\n        or dropout_p > 0.0\n        or 
rowscale is not None\n        or x1 is not None\n    ):\n        residual_out = torch.empty(\n            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype\n        )\n        assert residual_out.stride(-1) == 1\n    else:\n        residual_out = None\n    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n    if dropout_p > 0.0:\n        seeds = torch.randint(\n            2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64\n        )\n    else:\n        seeds = None\n    if return_dropout_mask and dropout_p > 0.0:\n        dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)\n    else:\n        dropout_mask = None\n    # Less than 64KB per feature: enqueue fused kernel\n    MAX_FUSED_SIZE = 65536 // x.element_size()\n    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n    if N > BLOCK_N:\n        raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n    with torch.cuda.device(x.device.index):\n        _layer_norm_fwd_1pass_kernel[(M,)](\n            x,\n            y,\n            weight,\n            bias,\n            residual,\n            x1,\n            weight1,\n            bias1,\n            y1,\n            residual_out,\n            rowscale,\n            seeds,\n            dropout_mask,\n            mean,\n            rstd,\n            x.stride(0),\n            y.stride(0),\n            residual.stride(0) if residual is not None else 0,\n            residual_out.stride(0) if residual_out is not None else 0,\n            x1.stride(0) if x1 is not None else 0,\n            y1.stride(0) if y1 is not None else 0,\n            M,\n            N,\n            eps,\n            dropout_p,\n            is_rms_norm,\n            BLOCK_N,\n            residual is not None,\n            residual_out is not None,\n          
  bias is not None,\n            dropout_p > 0.0,\n            dropout_mask is not None,\n            rowscale is not None,\n        )\n    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0\n    if dropout_mask is not None and x1 is not None:\n        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)\n    else:\n        dropout_mask1 = None\n    return (\n        y,\n        y1,\n        mean,\n        rstd,\n        residual_out if residual_out is not None else x,\n        seeds,\n        dropout_mask,\n        dropout_mask1,\n    )\n\n\n@triton.autotune(\n    configs=autotune_configs(pruned_configs_autotune),\n    key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\", \"HAS_DROPOUT\"],\n)\n# @triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n# @triton.heuristics({\"HAS_DRESIDUAL\": lambda args: args[\"DRESIDUAL\"] is not None})\n# @triton.heuristics({\"STORE_DRESIDUAL\": lambda args: args[\"DRESIDUAL_IN\"] is not None})\n@triton.heuristics({\"HAS_ROWSCALE\": lambda args: args[\"ROWSCALE\"] is not None})\n@triton.heuristics({\"HAS_DY1\": lambda args: args[\"DY1\"] is not None})\n@triton.heuristics({\"HAS_DX1\": lambda args: args[\"DX1\"] is not None})\n@triton.heuristics({\"HAS_B1\": lambda args: args[\"DB1\"] is not None})\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n    X,  # pointer to the input\n    W,  # pointer to the weights\n    B,  # pointer to the biases\n    Y,  # pointer to the output to be recomputed\n    DY,  # pointer to the output gradient\n    DX,  # pointer to the input gradient\n    DW,  # pointer to the partial sum of weights gradient\n    DB,  # pointer to the partial sum of biases gradient\n    DRESIDUAL,\n    W1,\n    DY1,\n    DX1,\n    DW1,\n    DB1,\n    DRESIDUAL_IN,\n    ROWSCALE,\n    SEEDS,\n    Mean,  # pointer to the mean\n    Rstd,  # 
pointer to the 1/std\n    stride_x_row,  # how much to increase the pointer when moving by 1 row\n    stride_y_row,\n    stride_dy_row,\n    stride_dx_row,\n    stride_dres_row,\n    stride_dy1_row,\n    stride_dx1_row,\n    stride_dres_in_row,\n    M,  # number of rows in X\n    N,  # number of columns in X\n    eps,  # epsilon to avoid division by zero\n    dropout_p,\n    rows_per_program,\n    IS_RMS_NORM: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    HAS_DRESIDUAL: tl.constexpr,\n    STORE_DRESIDUAL: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n    HAS_DROPOUT: tl.constexpr,\n    HAS_ROWSCALE: tl.constexpr,\n    HAS_DY1: tl.constexpr,\n    HAS_DX1: tl.constexpr,\n    HAS_B1: tl.constexpr,\n    RECOMPUTE_OUTPUT: tl.constexpr,\n):\n    # Map the program id to the elements of X, DX, and DY it should compute.\n    row_block_id = tl.program_id(0)\n    row_start = row_block_id * rows_per_program\n    # Do not early exit if row_start >= M, because we need to write DW and DB\n    cols = tl.arange(0, BLOCK_N)\n    mask = cols < N\n    X += row_start * stride_x_row\n    if HAS_DRESIDUAL:\n        DRESIDUAL += row_start * stride_dres_row\n    if STORE_DRESIDUAL:\n        DRESIDUAL_IN += row_start * stride_dres_in_row\n    DY += row_start * stride_dy_row\n    DX += row_start * stride_dx_row\n    if HAS_DY1:\n        DY1 += row_start * stride_dy1_row\n    if HAS_DX1:\n        DX1 += row_start * stride_dx1_row\n    if RECOMPUTE_OUTPUT:\n        Y += row_start * stride_y_row\n    w = tl.load(W + cols, mask=mask).to(tl.float32)\n    if RECOMPUTE_OUTPUT and HAS_BIAS:\n        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n    if HAS_DY1:\n        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)\n    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n    if HAS_BIAS:\n        db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n    if HAS_DY1:\n        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)\n        if HAS_B1:\n            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)\n    
row_end = min((row_block_id + 1) * rows_per_program, M)\n    for row in range(row_start, row_end):\n        # Load data to SRAM\n        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n        if HAS_DY1:\n            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)\n        if not IS_RMS_NORM:\n            mean = tl.load(Mean + row)\n        rstd = tl.load(Rstd + row)\n        # Compute dx\n        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n        xhat = tl.where(mask, xhat, 0.0)\n        if RECOMPUTE_OUTPUT:\n            y = xhat * w + b if HAS_BIAS else xhat * w\n            tl.store(Y + cols, y, mask=mask)\n        wdy = w * dy\n        dw += dy * xhat\n        if HAS_BIAS:\n            db += dy\n        if HAS_DY1:\n            wdy += w1 * dy1\n            dw1 += dy1 * xhat\n            if HAS_B1:\n                db1 += dy1\n        if not IS_RMS_NORM:\n            c1 = tl.sum(xhat * wdy, axis=0) / N\n            c2 = tl.sum(wdy, axis=0) / N\n            dx = (wdy - (xhat * c1 + c2)) * rstd\n        else:\n            c1 = tl.sum(xhat * wdy, axis=0) / N\n            dx = (wdy - xhat * c1) * rstd\n        if HAS_DRESIDUAL:\n            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)\n            dx += dres\n        # Write dx\n        if STORE_DRESIDUAL:\n            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)\n        if HAS_DX1:\n            if HAS_DROPOUT:\n                keep_mask = (\n                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n                )\n                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)\n            else:\n                dx1 = dx\n            tl.store(DX1 + cols, dx1, mask=mask)\n        if HAS_DROPOUT:\n            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n            dx = tl.where(keep_mask, dx / 
(1.0 - dropout_p), 0.0)\n        if HAS_ROWSCALE:\n            rowscale = tl.load(ROWSCALE + row).to(tl.float32)\n            dx *= rowscale\n        tl.store(DX + cols, dx, mask=mask)\n\n        X += stride_x_row\n        if HAS_DRESIDUAL:\n            DRESIDUAL += stride_dres_row\n        if STORE_DRESIDUAL:\n            DRESIDUAL_IN += stride_dres_in_row\n        if RECOMPUTE_OUTPUT:\n            Y += stride_y_row\n        DY += stride_dy_row\n        DX += stride_dx_row\n        if HAS_DY1:\n            DY1 += stride_dy1_row\n        if HAS_DX1:\n            DX1 += stride_dx1_row\n    tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n    if HAS_BIAS:\n        tl.store(DB + row_block_id * N + cols, db, mask=mask)\n    if HAS_DY1:\n        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)\n        if HAS_B1:\n            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)\n\n\ndef _layer_norm_bwd(\n    dy,\n    x,\n    weight,\n    bias,\n    eps,\n    mean,\n    rstd,\n    dresidual=None,\n    dy1=None,\n    weight1=None,\n    bias1=None,\n    seeds=None,\n    dropout_p=0.0,\n    rowscale=None,\n    has_residual=False,\n    has_x1=False,\n    is_rms_norm=False,\n    x_dtype=None,\n    recompute_output=False,\n):\n    M, N = x.shape\n    assert x.stride(-1) == 1\n    assert dy.stride(-1) == 1\n    assert dy.shape == (M, N)\n    if dresidual is not None:\n        assert dresidual.stride(-1) == 1\n        assert dresidual.shape == (M, N)\n    assert weight.shape == (N,)\n    assert weight.stride(-1) == 1\n    if bias is not None:\n        assert bias.stride(-1) == 1\n        assert bias.shape == (N,)\n    if dy1 is not None:\n        assert weight1 is not None\n        assert dy1.shape == dy.shape\n        assert dy1.stride(-1) == 1\n    if weight1 is not None:\n        assert weight1.shape == (N,)\n        assert weight1.stride(-1) == 1\n    if bias1 is not None:\n        assert bias1.shape == (N,)\n        assert bias1.stride(-1) == 1\n    if seeds 
is not None:\n        assert seeds.is_contiguous()\n        assert seeds.shape == (M if not has_x1 else M * 2,)\n    if rowscale is not None:\n        assert rowscale.is_contiguous()\n        assert rowscale.shape == (M,)\n    # allocate output\n    dx = (\n        torch.empty_like(x)\n        if x_dtype is None\n        else torch.empty(M, N, dtype=x_dtype, device=x.device)\n    )\n    dresidual_in = (\n        torch.empty_like(x)\n        if has_residual\n        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)\n        else None\n    )\n    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None\n    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None\n    if recompute_output:\n        assert weight1 is None, \"recompute_output is not supported with parallel LayerNorm\"\n\n    # Less than 64KB per feature: enqueue fused kernel\n    MAX_FUSED_SIZE = 65536 // x.element_size()\n    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n    if N > BLOCK_N:\n        raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)\n    _db = (\n        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)\n        if bias is not None\n        else None\n    )\n    _dw1 = torch.empty_like(_dw) if weight1 is not None else None\n    _db1 = torch.empty_like(_db) if bias1 is not None else None\n    rows_per_program = math.ceil(M / sm_count)\n    grid = (sm_count,)\n    with torch.cuda.device(x.device.index):\n        _layer_norm_bwd_kernel[grid](\n            x,\n            weight,\n            bias,\n            y,\n            dy,\n            dx,\n            _dw,\n            _db,\n            dresidual,\n            weight1,\n            dy1,\n            dx1,\n            _dw1,\n            _db1,\n      
      dresidual_in,\n            rowscale,\n            seeds,\n            mean,\n            rstd,\n            x.stride(0),\n            0 if not recompute_output else y.stride(0),\n            dy.stride(0),\n            dx.stride(0),\n            dresidual.stride(0) if dresidual is not None else 0,\n            dy1.stride(0) if dy1 is not None else 0,\n            dx1.stride(0) if dx1 is not None else 0,\n            dresidual_in.stride(0) if dresidual_in is not None else 0,\n            M,\n            N,\n            eps,\n            dropout_p,\n            rows_per_program,\n            is_rms_norm,\n            BLOCK_N,\n            dresidual is not None,\n            dresidual_in is not None,\n            bias is not None,\n            dropout_p > 0.0,\n        )\n    dw = _dw.sum(0).to(weight.dtype)\n    db = _db.sum(0).to(bias.dtype) if bias is not None else None\n    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None\n    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None\n    # Don't need to compute dresidual_in separately in this case\n    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:\n        dresidual_in = dx\n    if has_x1 and dropout_p == 0.0:\n        dx1 = dx\n    return (\n        (dx, dw, db, dresidual_in, dx1, dw1, db1)\n        if not recompute_output\n        else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)\n    )\n\n\nclass LayerNormFn(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        x,\n        weight,\n        bias,\n        residual=None,\n        x1=None,\n        weight1=None,\n        bias1=None,\n        eps=1e-6,\n        dropout_p=0.0,\n        rowscale=None,\n        prenorm=False,\n        residual_in_fp32=False,\n        is_rms_norm=False,\n        return_dropout_mask=False,\n    ):\n        x_shape_og = x.shape\n        # reshape input data into 2D tensor\n        x = x.reshape(-1, x.shape[-1])\n        if x.stride(-1) 
!= 1:\n            x = x.contiguous()\n        if residual is not None:\n            assert residual.shape == x_shape_og\n            residual = residual.reshape(-1, residual.shape[-1])\n            if residual.stride(-1) != 1:\n                residual = residual.contiguous()\n        if x1 is not None:\n            assert x1.shape == x_shape_og\n            assert rowscale is None, \"rowscale is not supported with parallel LayerNorm\"\n            x1 = x1.reshape(-1, x1.shape[-1])\n            if x1.stride(-1) != 1:\n                x1 = x1.contiguous()\n        weight = weight.contiguous()\n        if bias is not None:\n            bias = bias.contiguous()\n        if weight1 is not None:\n            weight1 = weight1.contiguous()\n        if bias1 is not None:\n            bias1 = bias1.contiguous()\n        if rowscale is not None:\n            rowscale = rowscale.reshape(-1).contiguous()\n        residual_dtype = (\n            residual.dtype\n            if residual is not None\n            else (torch.float32 if residual_in_fp32 else None)\n        )\n        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(\n            x,\n            weight,\n            bias,\n            eps,\n            residual,\n            x1,\n            weight1,\n            bias1,\n            dropout_p=dropout_p,\n            rowscale=rowscale,\n            residual_dtype=residual_dtype,\n            is_rms_norm=is_rms_norm,\n            return_dropout_mask=return_dropout_mask,\n        )\n        ctx.save_for_backward(\n            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd\n        )\n        ctx.x_shape_og = x_shape_og\n        ctx.eps = eps\n        ctx.dropout_p = dropout_p\n        ctx.is_rms_norm = is_rms_norm\n        ctx.has_residual = residual is not None\n        ctx.has_x1 = x1 is not None\n        ctx.prenorm = prenorm\n        ctx.x_dtype = x.dtype\n        y = y.reshape(x_shape_og)\n        y1 = 
y1.reshape(x_shape_og) if y1 is not None else None\n        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None\n        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None\n        dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None\n        if not return_dropout_mask:\n            if weight1 is None:\n                return y if not prenorm else (y, residual_out)\n            else:\n                return (y, y1) if not prenorm else (y, y1, residual_out)\n        else:\n            if weight1 is None:\n                return (\n                    (y, dropout_mask, dropout_mask1)\n                    if not prenorm\n                    else (y, residual_out, dropout_mask, dropout_mask1)\n                )\n            else:\n                return (\n                    (y, y1, dropout_mask, dropout_mask1)\n                    if not prenorm\n                    else (y, y1, residual_out, dropout_mask, dropout_mask1)\n                )\n\n    @staticmethod\n    def backward(ctx, dy, *args):\n        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors\n        dy = dy.reshape(-1, dy.shape[-1])\n        if dy.stride(-1) != 1:\n            dy = dy.contiguous()\n        assert dy.shape == x.shape\n        if weight1 is not None:\n            dy1, args = args[0], args[1:]\n            dy1 = dy1.reshape(-1, dy1.shape[-1])\n            if dy1.stride(-1) != 1:\n                dy1 = dy1.contiguous()\n            assert dy1.shape == x.shape\n        else:\n            dy1 = None\n        if ctx.prenorm:\n            dresidual = args[0]\n            dresidual = dresidual.reshape(-1, dresidual.shape[-1])\n            if dresidual.stride(-1) != 1:\n                dresidual = dresidual.contiguous()\n            assert dresidual.shape == x.shape\n        else:\n            dresidual = None\n        dx, dw, db, dresidual_in, dx1, dw1, db1 = 
_layer_norm_bwd(\n            dy,\n            x,\n            weight,\n            bias,\n            ctx.eps,\n            mean,\n            rstd,\n            dresidual,\n            dy1,\n            weight1,\n            bias1,\n            seeds,\n            ctx.dropout_p,\n            rowscale,\n            ctx.has_residual,\n            ctx.has_x1,\n            ctx.is_rms_norm,\n            x_dtype=ctx.x_dtype,\n        )\n        return (\n            dx.reshape(ctx.x_shape_og),\n            dw,\n            db,\n            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,\n            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,\n            dw1,\n            db1,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n        )\n\n\ndef layer_norm_fn(\n    x,\n    weight,\n    bias,\n    residual=None,\n    x1=None,\n    weight1=None,\n    bias1=None,\n    eps=1e-6,\n    dropout_p=0.0,\n    rowscale=None,\n    prenorm=False,\n    residual_in_fp32=False,\n    is_rms_norm=False,\n    return_dropout_mask=False,\n):\n    return LayerNormFn.apply(\n        x,\n        weight,\n        bias,\n        residual,\n        x1,\n        weight1,\n        bias1,\n        eps,\n        dropout_p,\n        rowscale,\n        prenorm,\n        residual_in_fp32,\n        is_rms_norm,\n        return_dropout_mask,\n    )\n\n\ndef rms_norm_fn(\n    x,\n    weight,\n    bias,\n    residual=None,\n    x1=None,\n    weight1=None,\n    bias1=None,\n    eps=1e-6,\n    dropout_p=0.0,\n    rowscale=None,\n    prenorm=False,\n    residual_in_fp32=False,\n    return_dropout_mask=False,\n):\n    return LayerNormFn.apply(\n        x,\n        weight,\n        bias,\n        residual,\n        x1,\n        weight1,\n        bias1,\n        eps,\n        dropout_p,\n        rowscale,\n        prenorm,\n        residual_in_fp32,\n        True,\n        return_dropout_mask,\n  
  )\n\n\nclass RMSNorm(torch.nn.Module):\n\n    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.eps = eps\n        if dropout_p > 0.0:\n            self.drop = torch.nn.Dropout(dropout_p)\n        else:\n            self.drop = None\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.register_parameter(\"bias\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.ones_(self.weight)\n\n    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):\n        return rms_norm_fn(\n            x,\n            self.weight,\n            self.bias,\n            residual=residual,\n            eps=self.eps,\n            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,\n            prenorm=prenorm,\n            residual_in_fp32=residual_in_fp32,\n        )\n\n\nclass LayerNormLinearFn(torch.autograd.Function):\n    @staticmethod\n    @custom_fwd\n    def forward(\n        ctx,\n        x,\n        norm_weight,\n        norm_bias,\n        linear_weight,\n        linear_bias,\n        residual=None,\n        eps=1e-6,\n        prenorm=False,\n        residual_in_fp32=False,\n        is_rms_norm=False,\n    ):\n        x_shape_og = x.shape\n        # reshape input data into 2D tensor\n        x = x.reshape(-1, x.shape[-1])\n        if x.stride(-1) != 1:\n            x = x.contiguous()\n        if residual is not None:\n            assert residual.shape == x_shape_og\n            residual = residual.reshape(-1, residual.shape[-1])\n            if residual.stride(-1) != 1:\n                residual = residual.contiguous()\n        norm_weight = norm_weight.contiguous()\n        if norm_bias is not None:\n            norm_bias = norm_bias.contiguous()\n        residual_dtype = (\n            residual.dtype\n 
           if residual is not None\n            else (torch.float32 if residual_in_fp32 else None)\n        )\n        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(\n            x,\n            norm_weight,\n            norm_bias,\n            eps,\n            residual,\n            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),\n            residual_dtype=residual_dtype,\n            is_rms_norm=is_rms_norm,\n        )\n        y = y.reshape(x_shape_og)\n        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype\n        linear_weight = linear_weight.to(dtype)\n        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None\n        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)\n        # We don't store y, will be recomputed in the backward pass to save memory\n        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)\n        ctx.x_shape_og = x_shape_og\n        ctx.eps = eps\n        ctx.is_rms_norm = is_rms_norm\n        ctx.has_residual = residual is not None\n        ctx.prenorm = prenorm\n        ctx.x_dtype = x.dtype\n        ctx.linear_bias_is_none = linear_bias is None\n        return out if not prenorm else (out, residual_out.reshape(x_shape_og))\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, dout, *args):\n        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors\n        dout = dout.reshape(-1, dout.shape[-1])\n        dy = F.linear(dout, linear_weight.t())\n        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)\n        if dy.stride(-1) != 1:\n            dy = dy.contiguous()\n        assert dy.shape == x.shape\n        if ctx.prenorm:\n            dresidual = args[0]\n            dresidual = dresidual.reshape(-1, dresidual.shape[-1])\n            if dresidual.stride(-1) != 1:\n                dresidual = dresidual.contiguous()\n           
 assert dresidual.shape == x.shape\n        else:\n            dresidual = None\n        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(\n            dy,\n            x,\n            norm_weight,\n            norm_bias,\n            ctx.eps,\n            mean,\n            rstd,\n            dresidual=dresidual,\n            has_residual=ctx.has_residual,\n            is_rms_norm=ctx.is_rms_norm,\n            x_dtype=ctx.x_dtype,\n            recompute_output=True,\n        )\n        dlinear_weight = torch.einsum(\"bo,bi->oi\", dout, y)\n        return (\n            dx.reshape(ctx.x_shape_og),\n            dnorm_weight,\n            dnorm_bias,\n            dlinear_weight,\n            dlinear_bias,\n            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,\n            None,\n            None,\n            None,\n            None,\n        )\n\n\ndef layer_norm_linear_fn(\n    x,\n    norm_weight,\n    norm_bias,\n    linear_weight,\n    linear_bias,\n    residual=None,\n    eps=1e-6,\n    prenorm=False,\n    residual_in_fp32=False,\n    is_rms_norm=False,\n):\n    return LayerNormLinearFn.apply(\n        x,\n        norm_weight,\n        norm_bias,\n        linear_weight,\n        linear_bias,\n        residual,\n        eps,\n        prenorm,\n        residual_in_fp32,\n        is_rms_norm,\n    )\n"
  },
  {
    "path": "mamba_ssm/ops/triton/layernorm_gated.py",
    "content": "# Copyright (c) 2024, Tri Dao.\n# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.\n# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.\n# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.\n\nimport math\n\nimport torch\nimport torch.nn.functional as F\n\nimport triton\nimport triton.language as tl\n\nfrom einops import rearrange\n\n\ndef rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):\n    dtype = x.dtype\n    N = x.shape[-1]\n    weight = weight.float()\n    bias = bias.float() if bias is not None else None\n    if upcast:\n        x = x.float()\n        z = z.float() if z is not None else z\n    if z is not None and not norm_before_gate:\n        x = x * F.silu(z)\n    if group_size is None:\n        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)\n        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)\n    else:\n        x_group = rearrange(x, \"... (g d) -> ... g d\", d=group_size)\n        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)\n        out = rearrange(x_group * rstd, \"... g d -> ... 
(g d)\") * weight\n        if bias is not None:\n            out = out + bias\n    if z is not None and norm_before_gate:\n        out *= F.silu(z)\n    return out.to(dtype)\n\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n    X,  # pointer to the input\n    Y,  # pointer to the output\n    W,  # pointer to the weights\n    B,  # pointer to the biases\n    Z,  # pointer to the other branch\n    Mean,  # pointer to the mean\n    Rstd,  # pointer to the 1/std\n    stride_x_row,  # how much to increase the pointer when moving by 1 row\n    stride_y_row,\n    stride_z_row,\n    M,  # number of rows in X\n    N,  # number of columns in X\n    eps,  # epsilon to avoid division by zero\n    BLOCK_N: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n    HAS_Z: tl.constexpr,\n    NORM_BEFORE_GATE: tl.constexpr,\n    IS_RMS_NORM: tl.constexpr,\n):\n    # Map the program id to the row of X and Y it should compute.\n    row = tl.program_id(0)\n    group = tl.program_id(1)\n    X += row * stride_x_row + group * N\n    Y += row * stride_y_row + group * N\n    if HAS_Z:\n        Z += row * stride_z_row + group * N\n    if not IS_RMS_NORM:\n        Mean += group * M\n    Rstd += group * M\n    W += group * N\n    if HAS_BIAS:\n        B += group * N\n    # Compute mean and variance\n    cols = tl.arange(0, BLOCK_N)\n    x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n    if HAS_Z and not NORM_BEFORE_GATE:\n        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)\n        x *= z * tl.sigmoid(z)\n    if not IS_RMS_NORM:\n        mean = tl.sum(x, axis=0) / N\n        tl.store(Mean + row, mean)\n        xbar = tl.where(cols < N, x - mean, 0.)\n        var = tl.sum(xbar * xbar, axis=0) / N\n    else:\n        xbar = tl.where(cols < N, x, 0.)\n        var = tl.sum(xbar * xbar, axis=0) / N\n    rstd = 1 / tl.sqrt(var + eps)\n    
tl.store(Rstd + row, rstd)\n    # Normalize and apply linear transformation\n    mask = cols < N\n    w = tl.load(W + cols, mask=mask).to(tl.float32)\n    if HAS_BIAS:\n        b = tl.load(B + cols, mask=mask).to(tl.float32)\n    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n    y = x_hat * w + b if HAS_BIAS else x_hat * w\n    if HAS_Z and NORM_BEFORE_GATE:\n        z = tl.load(Z + cols, mask=mask).to(tl.float32)\n        y *= z * tl.sigmoid(z)\n    # Write output\n    tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):\n    M, N = x.shape\n    if group_size is None:\n        group_size = N\n    assert N % group_size == 0\n    ngroups = N // group_size\n    assert x.stride(-1) == 1\n    if z is not None:\n        assert z.stride(-1) == 1\n        assert z.shape == (M, N)\n    assert weight.shape == (N,)\n    assert weight.stride(-1) == 1\n    if bias is not None:\n        assert bias.stride(-1) == 1\n        assert bias.shape == (N,)\n    # allocate output\n    if out is not None:\n        assert out.shape == x.shape\n    else:\n        out = torch.empty_like(x)\n    assert out.stride(-1) == 1\n    mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n    rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)\n    # Less than 64KB per feature: enqueue fused kernel\n    MAX_FUSED_SIZE = 65536 // x.element_size()\n    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n    if group_size > BLOCK_N:\n        raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n    # heuristics for number of warps\n    num_warps = min(max(BLOCK_N // 256, 1), 8)\n    grid = (M, ngroups)\n    with torch.cuda.device(x.device.index):\n        _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,\n                                           
x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,\n                                           M, group_size, eps,\n                                           BLOCK_N=BLOCK_N,\n                                           NORM_BEFORE_GATE=norm_before_gate,\n                                           IS_RMS_NORM=is_rms_norm,\n                                           num_warps=num_warps)\n    return out, mean, rstd\n\n\n\n@triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"Z\"] is not None})\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n    X,   # pointer to the input\n    W,   # pointer to the weights\n    B,   # pointer to the biases\n    Z,   # pointer to the other branch\n    Y,   # pointer to the output to be recomputed\n    DY,  # pointer to the output gradient\n    DX,  # pointer to the input gradient\n    DW,  # pointer to the partial sum of weights gradient\n    DB,  # pointer to the partial sum of biases gradient\n    DZ,  # pointer to the other branch\n    Mean,   # pointer to the mean\n    Rstd,   # pointer to the 1/std\n    stride_x_row,  # how much to increase the pointer when moving by 1 row\n    stride_z_row,\n    stride_y_row,\n    stride_dy_row,\n    stride_dx_row,\n    stride_dz_row,\n    stride_dw_row,\n    stride_db_row,\n    M,  # number of rows in X\n    N,  # number of columns in X\n    eps,  # epsilon to avoid division by zero\n    rows_per_program,\n    NORM_BEFORE_GATE: tl.constexpr,\n    IS_RMS_NORM: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n    HAS_Z: tl.constexpr,\n    RECOMPUTE_OUTPUT: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    # Map the program id to the elements of X, DX, and DY it should compute.\n    row_block_id = tl.program_id(0)\n    group = tl.program_id(1)\n    row_start = row_block_id * rows_per_program\n    cols = tl.arange(0, BLOCK_N)\n    mask = cols 
< N\n    X += row_start * stride_x_row + group * N\n    if HAS_Z:\n        Z += row_start * stride_z_row + group * N\n        DZ += row_start * stride_dz_row + group * N\n    DY += row_start * stride_dy_row + group * N\n    DX += row_start * stride_dx_row + group * N\n    if RECOMPUTE_OUTPUT:\n        Y += row_start * stride_y_row + group * N\n    if not IS_RMS_NORM:\n        Mean += group * M\n    Rstd += group * M\n    W += group * N\n    w = tl.load(W + cols, mask=mask).to(tl.float32)\n    if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:\n        B += group * N\n        b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)\n    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n    if HAS_BIAS:\n        db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n    row_end = min((row_block_id + 1) * rows_per_program, M)\n    for row in range(row_start, row_end):\n        # Load data to SRAM\n        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n        if not IS_RMS_NORM:\n            mean = tl.load(Mean + row)\n        if HAS_Z and not NORM_BEFORE_GATE:\n            z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)\n            x_og = x\n            x = x_og * z * tl.sigmoid(z)\n        rstd = tl.load(Rstd + row)\n        # Compute dx\n        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n        xhat = tl.where(mask, xhat, 0.)\n        if HAS_Z and NORM_BEFORE_GATE:\n            z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)\n            z_sigmoid = tl.sigmoid(z)\n            y = xhat * w + b if HAS_BIAS else xhat * w\n            if RECOMPUTE_OUTPUT:\n                tl.store(Y + cols, y * z * z_sigmoid, mask=mask)\n            dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))\n            tl.store(DZ + cols, dz, mask=mask)\n            dy *= z * z_sigmoid\n        else:\n            if RECOMPUTE_OUTPUT:\n                y = xhat * w + b if HAS_BIAS else xhat * w\n   
             tl.store(Y + cols, y, mask=mask)\n        wdy = w * dy\n        c1 = tl.sum(xhat * wdy, axis=0) / N\n        if not IS_RMS_NORM:\n            c2 = tl.sum(wdy, axis=0) / N\n            dx = (wdy - (xhat * c1 + c2)) * rstd\n        else:\n            dx = (wdy - xhat * c1) * rstd\n        dw += dy * xhat\n        if HAS_BIAS:\n            db += dy\n        if HAS_Z and not NORM_BEFORE_GATE:\n            z_sigmoid = tl.sigmoid(z)\n            dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))\n            tl.store(DZ + cols, dz, mask=mask)\n            dx *= z * z_sigmoid\n        # Write dx\n        tl.store(DX + cols, dx, mask=mask)\n\n        X += stride_x_row\n        if HAS_Z:\n            Z += stride_z_row\n            DZ += stride_dz_row\n        if RECOMPUTE_OUTPUT:\n            Y += stride_y_row\n        DY += stride_dy_row\n        DX += stride_dx_row\n    tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)\n    if HAS_BIAS:\n        tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)\n\n\ndef _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,\n                    norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):\n    M, N = x.shape\n    if group_size is None:\n        group_size = N\n    assert N % group_size == 0\n    ngroups = N // group_size\n    assert x.stride(-1) == 1\n    assert dy.stride(-1) == 1\n    assert dy.shape == (M, N)\n    if z is not None:\n        assert z.stride(-1) == 1\n        assert z.shape == (M, N)\n    assert weight.shape == (N,)\n    assert weight.stride(-1) == 1\n    if bias is not None:\n        assert bias.stride(-1) == 1\n        assert bias.shape == (N,)\n    # allocate output\n    dx = torch.empty_like(x)\n    if dz is not None:\n        assert z is not None\n        assert dz.shape == z.shape\n        assert dz.stride(-1) == 1\n    else:\n        dz = torch.empty_like(z) if z is not None else 
None\n    if recompute_output:\n        if out is None:\n            out = torch.empty_like(x)\n        assert out.shape == x.shape\n\n    # Less than 64KB per feature: enqueue fused kernel\n    MAX_FUSED_SIZE = 65536 // x.element_size()\n    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))\n    if group_size > BLOCK_N:\n        raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n    # heuristics for number of warps\n    num_warps = min(max(BLOCK_N // 256, 1), 8)\n    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n    # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs\n    # would limit the occupancy.\n    nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)\n    _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)\n    _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None\n    rows_per_program = math.ceil(M / nrow_groups)\n    grid = (nrow_groups, ngroups)\n    with torch.cuda.device(x.device.index):\n        _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,\n                                     dy, dx, _dw, _db, dz, mean, rstd,\n                                     x.stride(0),\n                                     z.stride(0) if z is not None else 0,\n                                     0 if not recompute_output else out.stride(0),\n                                     dy.stride(0), dx.stride(0),\n                                     dz.stride(0) if dz is not None else 0,\n                                     _dw.stride(0),\n                                     _db.stride(0) if _db is not None else 0,\n                                     M, group_size, eps,\n                                     rows_per_program,\n                                     BLOCK_N=BLOCK_N,\n                                     
NORM_BEFORE_GATE=norm_before_gate,\n                                     IS_RMS_NORM=is_rms_norm,\n                                     num_warps=num_warps)\n    dw = _dw.sum(0).to(weight.dtype)\n    db = _db.sum(0).to(bias.dtype) if bias is not None else None\n    return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)\n\n\nclass LayerNormFn(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,\n                is_rms_norm=False):\n        \"\"\"If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))\n        \"\"\"\n\n        x_shape_og = x.shape\n        # reshape input data into 2D tensor\n        x = x.reshape(-1, x.shape[-1])\n        if x.stride(-1) != 1:\n            x = x.contiguous()\n        if z is not None:\n            assert z.shape == x_shape_og\n            z = z.reshape(-1, z.shape[-1])\n            if z.stride(-1) != 1:\n                z = z.contiguous()\n        weight = weight.contiguous()\n        if bias is not None:\n            bias = bias.contiguous()\n        y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)\n        ctx.save_for_backward(x, weight, bias, mean, rstd, z)\n        ctx.x_shape_og = x_shape_og\n        ctx.eps = eps\n        ctx.group_size = group_size\n        ctx.norm_before_gate = norm_before_gate\n        ctx.is_rms_norm = is_rms_norm\n        return y.reshape(x_shape_og)\n\n    @staticmethod\n    def backward(ctx, dy):\n        x, weight, bias, mean, rstd, z = ctx.saved_tensors\n        dy = dy.reshape(-1, dy.shape[-1])\n        if dy.stride(-1) != 1:\n            dy = dy.contiguous()\n        assert dy.shape == x.shape\n        dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,\n                                         ctx.norm_before_gate, 
ctx.is_rms_norm)\n        return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None\n\n\ndef layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):\n    return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)\n\n\ndef rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):\n    return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)\n\n\nclass LayerNorm(torch.nn.Module):\n\n    def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):\n        \"\"\"If group_size is not None, we do GroupNorm with each group having group_size elements.\n        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).\n        \"\"\"\n\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.eps = eps\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.group_size = group_size\n        self.norm_before_gate = norm_before_gate\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.ones_(self.weight)\n        torch.nn.init.zeros_(self.bias)\n\n    def forward(self, x, z=None):\n        \"\"\"If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))\n        \"\"\"\n        return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,\n                            norm_before_gate=self.norm_before_gate)\n\n\nclass RMSNorm(torch.nn.Module):\n\n    def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):\n        \"\"\"If group_size is not None, we do GroupNorm with each group 
having group_size elements.\n        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).\n        \"\"\"\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.eps = eps\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.register_parameter(\"bias\", None)\n        self.group_size = group_size\n        self.norm_before_gate = norm_before_gate\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.ones_(self.weight)\n\n    def forward(self, x, z=None):\n        \"\"\"If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))\n        \"\"\"\n        return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,\n                          norm_before_gate=self.norm_before_gate)\n"
  },
  {
    "path": "mamba_ssm/ops/triton/mamba3/angle_dt.py",
    "content": "from typing import Tuple, Optional\n\nimport torch\nfrom torch import Tensor\n\nimport triton\nimport triton.language as tl\nfrom mamba_ssm.ops.triton.mamba3.utils import tanh_approx, sech2_approx\n\n\n# -----------------------------------------------------------------------------\n# Forward kernel\n# -----------------------------------------------------------------------------\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_stages=s, num_warps=w)\n        for s in [1, 2, 3]\n        for w in [2, 4, 8]\n    ],\n    key=[\"CHUNK_SIZE\", \"BLOCK_D\", \"HAS_INIT_STATE\", \"RETURN_OUTPUT_STATE\", \"IS_VARLEN\"],\n)\n@triton.jit\ndef angle_dt_fwd_kernel(\n    # Outputs\n    OUT, OUTPUT_STATE,\n    # Inputs\n    ANGLE, DT, INIT_STATE, CU_SEQLENS,\n    # Strides for OUT (batch, seqlen, nheads, dim)\n    stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim,\n    # Strides for OUTPUT_STATE (num_sequences, nheads, dim)\n    stride_output_state_seq, stride_output_state_head, stride_output_state_dim,\n    # Strides for ANGLE (batch, seqlen, nheads, dim)\n    stride_angle_batch, stride_angle_seq, stride_angle_head, stride_angle_dim,\n    # Strides for DT (batch, nheads, seqlen)\n    stride_dt_batch, stride_dt_head, stride_dt_seq,\n    # Strides for INIT_STATE (num_sequences, nheads, dim)\n    stride_init_seq, stride_init_head, stride_init_dim,\n    # Stride for CU_SEQLENS\n    stride_cu_seqlen,\n    # Dimensions\n    seqlen, dim,\n    # Meta-parameters\n    CHUNK_SIZE: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n    HAS_INIT_STATE: tl.constexpr,\n    RETURN_OUTPUT_STATE: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    pid_h = tl.program_id(0)\n    pid_b = tl.program_id(1)\n\n    # Handle varlen mode\n    if IS_VARLEN:\n        pid_seq = tl.program_id(2)\n        seq_idx = pid_seq\n        cu_seqlen_start = tl.load(CU_SEQLENS + pid_seq * stride_cu_seqlen).to(tl.int32)\n        cu_seqlen_end = tl.load(CU_SEQLENS + (pid_seq + 1) 
* stride_cu_seqlen).to(tl.int32)\n        seq_len = cu_seqlen_end - cu_seqlen_start\n        seq_offset = cu_seqlen_start\n    else:\n        seq_idx = pid_b\n        seq_len = seqlen\n        seq_offset = 0\n\n    nchunks = tl.cdiv(seq_len, CHUNK_SIZE)\n\n    # Offset base pointers by batch and head\n    ANGLE += pid_b * stride_angle_batch + pid_h * stride_angle_head + seq_offset * stride_angle_seq\n    DT += pid_b * stride_dt_batch + pid_h * stride_dt_head + seq_offset * stride_dt_seq\n    OUT += pid_b * stride_out_batch + pid_h * stride_out_head + seq_offset * stride_out_seq\n\n    dim_range = tl.arange(0, BLOCK_D)\n    dim_mask = dim_range < dim\n\n    # Initialize state from init_state or zeros\n    if HAS_INIT_STATE:\n        init_ptrs = INIT_STATE + seq_idx * stride_init_seq + pid_h * stride_init_head + dim_range * stride_init_dim\n        state = tl.load(init_ptrs, mask=dim_mask, other=0.0).to(tl.float32)\n    else:\n        state = tl.zeros((BLOCK_D,), dtype=tl.float32)\n\n    PI = 3.141592653589793\n    TWO_PI = 2 * PI\n\n    for chunk_idx in range(nchunks):\n        chunk_start = chunk_idx * CHUNK_SIZE\n        seq_range = tl.arange(0, CHUNK_SIZE)\n        seq_mask = (chunk_start + seq_range) < seq_len\n\n        # Load angle (CHUNK_SIZE, BLOCK_D)\n        angle_ptrs = ANGLE + (chunk_start + seq_range[:, None]) * stride_angle_seq + dim_range[None, :] * stride_angle_dim\n        angle_vals = tl.load(angle_ptrs, mask=seq_mask[:, None] & dim_mask[None, :], other=0.0).to(tl.float32)\n        angle_vals = tanh_approx(angle_vals) * PI\n\n        # Load dt (CHUNK_SIZE,)\n        dt_ptrs = DT + (chunk_start + seq_range) * stride_dt_seq\n        dt_vals = tl.load(dt_ptrs, mask=seq_mask, other=0.0).to(tl.float32)\n\n        # Compute vals = angle * dt\n        vals = angle_vals * dt_vals[:, None]\n\n        # Cumsum within chunk + add state from previous chunks\n        chunk_cumsum = tl.cumsum(vals, axis=0)\n        out_vals = chunk_cumsum + state[None, :]\n\n    
    # Apply mod 2*pi for rotary angle normalization\n        out_vals = out_vals - TWO_PI * tl.floor(out_vals / TWO_PI)\n\n        # Store output\n        out_ptrs = OUT + (chunk_start + seq_range[:, None]) * stride_out_seq + dim_range[None, :] * stride_out_dim\n        tl.store(out_ptrs, out_vals, mask=seq_mask[:, None] & dim_mask[None, :])\n\n        # Update state: add chunk sum and apply mod 2*pi\n        chunk_sum = tl.sum(vals, axis=0)\n        state = state + chunk_sum\n        state = state - TWO_PI * tl.floor(state / TWO_PI)\n\n    # Store final state if requested\n    if RETURN_OUTPUT_STATE:\n        output_state_ptrs = OUTPUT_STATE + seq_idx * stride_output_state_seq + pid_h * stride_output_state_head + dim_range * stride_output_state_dim\n        tl.store(output_state_ptrs, state, mask=dim_mask)\n\n\ndef angle_dt_fwd(\n    angle: Tensor,\n    dt: Tensor,\n    init_state: Optional[Tensor] = None,\n    chunk_size: int = 64,\n    return_output_state: bool = False,\n    cu_seqlens: Optional[Tensor] = None,\n) -> Tensor | Tuple[Tensor, Tensor]:\n    \"\"\"Forward pass for angle * dt cumsum.\n\n    Args:\n        angle: Angle tensor             (batch, seqlen, nheads, dim)\n        dt: Time delta tensor           (batch, nheads, seqlen)\n        init_state: Initial state       (num_sequences, nheads, dim) or None\n        chunk_size: Chunk size for chunked computation\n        return_output_state: Whether to return final state\n        cu_seqlens: Cumulative sequence lengths (num_sequences + 1,) for varlen mode\n\n    Returns:\n        If return_output_state=False:\n            out: Cumulative output      (batch, seqlen, nheads, dim)\n        If return_output_state=True:\n            Tuple of:\n                out: Cumulative output      (batch, seqlen, nheads, dim)\n                output_state: Final state   (num_sequences, nheads, dim)\n    \"\"\"\n    batch, seqlen, nheads, dim = angle.shape\n    is_varlen = cu_seqlens is not None\n    \n    # Determine 
number of sequences\n    if is_varlen:\n        assert batch == 1, \"Varlen mode requires batch=1\"\n        num_sequences = cu_seqlens.shape[0] - 1\n    else:\n        num_sequences = batch\n    \n    assert dt.shape == (batch, nheads, seqlen), f\"dt shape mismatch: {dt.shape}\"\n    if init_state is not None:\n        assert init_state.shape == (num_sequences, nheads, dim), f\"init_state shape mismatch: {init_state.shape}\"\n\n    out = torch.empty_like(angle)\n    BLOCK_D = triton.next_power_of_2(dim)\n\n    # Handle None init_state for kernel\n    HAS_INIT_STATE = init_state is not None\n    if not HAS_INIT_STATE:\n        init_state = angle  # dummy, won't be accessed\n        stride_init = (0, 0, 0)\n    else:\n        stride_init = init_state.stride()\n\n    # Handle output_state\n    if return_output_state:\n        output_state = torch.empty(num_sequences, nheads, dim, device=angle.device, dtype=angle.dtype)\n        stride_output_state = output_state.stride()\n    else:\n        output_state = out  # dummy, won't be accessed\n        stride_output_state = (0, 0, 0)\n\n    # Handle cu_seqlens\n    if cu_seqlens is not None:\n        stride_cu_seqlen = cu_seqlens.stride(0)\n    else:\n        cu_seqlens = angle  # dummy, won't be accessed\n        stride_cu_seqlen = 0\n\n    # Grid setup\n    if is_varlen:\n        grid = (nheads, batch, num_sequences)\n    else:\n        grid = (nheads, batch)\n\n    angle_dt_fwd_kernel[grid](\n        out, output_state,\n        angle, dt, init_state, cu_seqlens,\n        out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n        stride_output_state[0], stride_output_state[1], stride_output_state[2],\n        angle.stride(0), angle.stride(1), angle.stride(2), angle.stride(3),\n        dt.stride(0), dt.stride(1), dt.stride(2),\n        stride_init[0], stride_init[1], stride_init[2],\n        stride_cu_seqlen,\n        seqlen, dim,\n        CHUNK_SIZE=chunk_size,\n        BLOCK_D=BLOCK_D,\n        
HAS_INIT_STATE=HAS_INIT_STATE,\n        RETURN_OUTPUT_STATE=return_output_state,\n        IS_VARLEN=is_varlen,\n    )\n\n    if return_output_state:\n        return out, output_state\n    return out\n\n\n# -----------------------------------------------------------------------------\n# Backward kernel\n# -----------------------------------------------------------------------------\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_stages=s, num_warps=w)\n        for s in [1, 2, 3]\n        for w in [2, 4, 8]\n    ],\n    key=[\"CHUNK_SIZE\", \"BLOCK_D\", \"HAS_INIT_STATE\", \"HAS_GRAD_OUTPUT_STATE\", \"IS_VARLEN\"],\n)\n@triton.jit\ndef angle_dt_bwd_kernel(\n    # Outputs\n    GRAD_ANGLE, GRAD_DT, GRAD_INIT_STATE,\n    # Inputs\n    GRAD_OUT, GRAD_OUTPUT_STATE, ANGLE, DT, CU_SEQLENS,\n    # Strides for GRAD_ANGLE (batch, seqlen, nheads, dim)\n    stride_grad_angle_batch, stride_grad_angle_seq, stride_grad_angle_head, stride_grad_angle_dim,\n    # Strides for GRAD_DT (batch, nheads, seqlen)\n    stride_grad_dt_batch, stride_grad_dt_head, stride_grad_dt_seq,\n    # Strides for GRAD_INIT_STATE (num_sequences, nheads, dim)\n    stride_grad_init_seq, stride_grad_init_head, stride_grad_init_dim,\n    # Strides for GRAD_OUT (batch, seqlen, nheads, dim)\n    stride_grad_out_batch, stride_grad_out_seq, stride_grad_out_head, stride_grad_out_dim,\n    # Strides for GRAD_OUTPUT_STATE (num_sequences, nheads, dim)\n    stride_grad_output_state_seq, stride_grad_output_state_head, stride_grad_output_state_dim,\n    # Strides for ANGLE (batch, seqlen, nheads, dim)\n    stride_angle_batch, stride_angle_seq, stride_angle_head, stride_angle_dim,\n    # Strides for DT (batch, nheads, seqlen)\n    stride_dt_batch, stride_dt_head, stride_dt_seq,\n    # Stride for CU_SEQLENS\n    stride_cu_seqlen,\n    # Dimensions\n    seqlen, dim,\n    # Meta-parameters\n    CHUNK_SIZE: tl.constexpr,\n    BLOCK_D: tl.constexpr,\n    HAS_INIT_STATE: tl.constexpr,\n    
HAS_GRAD_OUTPUT_STATE: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    pid_h = tl.program_id(0)\n    pid_b = tl.program_id(1)\n\n    # Handle varlen mode\n    if IS_VARLEN:\n        pid_seq = tl.program_id(2)\n        seq_idx = pid_seq\n        cu_seqlen_start = tl.load(CU_SEQLENS + pid_seq * stride_cu_seqlen).to(tl.int32)\n        cu_seqlen_end = tl.load(CU_SEQLENS + (pid_seq + 1) * stride_cu_seqlen).to(tl.int32)\n        seq_len = cu_seqlen_end - cu_seqlen_start\n        seq_offset = cu_seqlen_start\n    else:\n        seq_idx = pid_b\n        seq_len = seqlen\n        seq_offset = 0\n\n    nchunks = tl.cdiv(seq_len, CHUNK_SIZE)\n\n    # Offset base pointers by batch and head\n    GRAD_ANGLE += pid_b * stride_grad_angle_batch + pid_h * stride_grad_angle_head + seq_offset * stride_grad_angle_seq\n    GRAD_DT += pid_b * stride_grad_dt_batch + pid_h * stride_grad_dt_head + seq_offset * stride_grad_dt_seq\n    GRAD_OUT += pid_b * stride_grad_out_batch + pid_h * stride_grad_out_head + seq_offset * stride_grad_out_seq\n    ANGLE += pid_b * stride_angle_batch + pid_h * stride_angle_head + seq_offset * stride_angle_seq\n    DT += pid_b * stride_dt_batch + pid_h * stride_dt_head + seq_offset * stride_dt_seq\n\n    dim_range = tl.arange(0, BLOCK_D)\n    dim_mask = dim_range < dim\n    PI = 3.141592653589793\n\n    # Initialize gradient state from grad_output_state or zeros\n    if HAS_GRAD_OUTPUT_STATE:\n        grad_output_state_ptrs = GRAD_OUTPUT_STATE + seq_idx * stride_grad_output_state_seq + pid_h * stride_grad_output_state_head + dim_range * stride_grad_output_state_dim\n        grad_state = tl.load(grad_output_state_ptrs, mask=dim_mask, other=0.0).to(tl.float32)\n    else:\n        grad_state = tl.zeros((BLOCK_D,), dtype=tl.float32)\n\n    # Loop in reverse: derivative of cumsum is reverse cumsum\n    for chunk_idx in range(nchunks - 1, -1, -1):\n        chunk_start = chunk_idx * CHUNK_SIZE\n        seq_range = tl.arange(0, CHUNK_SIZE)\n        seq_mask = 
(chunk_start + seq_range) < seq_len\n\n        # Load grad_out (CHUNK_SIZE, BLOCK_D)\n        grad_out_ptrs = GRAD_OUT + (chunk_start + seq_range[:, None]) * stride_grad_out_seq + dim_range[None, :] * stride_grad_out_dim\n        grad_out_vals = tl.load(grad_out_ptrs, mask=seq_mask[:, None] & dim_mask[None, :], other=0.0).to(tl.float32)\n\n        # Reverse cumsum within chunk: rev_cumsum = total - cumsum + x\n        # But we need to handle the mask properly for partial chunks\n        chunk_sum = tl.sum(grad_out_vals, axis=0)\n        fwd_cumsum = tl.cumsum(grad_out_vals, axis=0)\n        rev_cumsum = chunk_sum[None, :] - fwd_cumsum + grad_out_vals\n\n        # Add gradient from future chunks\n        grad_vals = rev_cumsum + grad_state[None, :]\n\n        # Load angle and dt\n        angle_ptrs = ANGLE + (chunk_start + seq_range[:, None]) * stride_angle_seq + dim_range[None, :] * stride_angle_dim\n        pretanh_angle_vals = tl.load(angle_ptrs, mask=seq_mask[:, None] & dim_mask[None, :], other=0.0).to(tl.float32)\n        angle_vals = tanh_approx(pretanh_angle_vals) * PI\n\n        dt_ptrs = DT + (chunk_start + seq_range) * stride_dt_seq\n        dt_vals = tl.load(dt_ptrs, mask=seq_mask, other=0.0).to(tl.float32)\n\n        # Compute gradients: out = angle * dt\n        grad_angle_vals = grad_vals * dt_vals[:, None] * PI * sech2_approx(pretanh_angle_vals)\n        grad_dt_vals = tl.sum(grad_vals * angle_vals, axis=1)\n\n        # Store gradients\n        grad_angle_ptrs = GRAD_ANGLE + (chunk_start + seq_range[:, None]) * stride_grad_angle_seq + dim_range[None, :] * stride_grad_angle_dim\n        tl.store(grad_angle_ptrs, grad_angle_vals, mask=seq_mask[:, None] & dim_mask[None, :])\n\n        grad_dt_ptrs = GRAD_DT + (chunk_start + seq_range) * stride_grad_dt_seq\n        tl.store(grad_dt_ptrs, grad_dt_vals, mask=seq_mask)\n\n        # Update state for previous chunk\n        grad_state = grad_state + chunk_sum\n\n    # Store gradient for init_state if 
provided\n    if HAS_INIT_STATE:\n        grad_init_ptrs = GRAD_INIT_STATE + seq_idx * stride_grad_init_seq + pid_h * stride_grad_init_head + dim_range * stride_grad_init_dim\n        tl.store(grad_init_ptrs, grad_state, mask=dim_mask)\n\n\ndef angle_dt_bwd(\n    grad_out: Tensor,\n    angle: Tensor,\n    dt: Tensor,\n    has_init_state: bool = False,\n    chunk_size: int = 64,\n    grad_output_state: Optional[Tensor] = None,\n    cu_seqlens: Optional[Tensor] = None,\n) -> Tuple[Tensor, Tensor, Optional[Tensor]]:\n    \"\"\"Backward pass for angle * dt cumsum.\n\n    Args:\n        grad_out: Gradient of output         (batch, seqlen, nheads, dim)\n        angle: Angle tensor                  (batch, seqlen, nheads, dim)\n        dt: Time delta tensor                (batch, nheads, seqlen)\n        has_init_state: Whether init_state was provided in forward\n        chunk_size: Chunk size for chunked computation\n        grad_output_state: Gradient of output state (num_sequences, nheads, dim) or None\n        cu_seqlens: Cumulative sequence lengths (num_sequences + 1,) for varlen mode\n\n    Returns:\n        grad_angle: Gradient for angle       (batch, seqlen, nheads, dim)\n        grad_dt: Gradient for dt             (batch, nheads, seqlen)\n        grad_init_state: Gradient for init_state (num_sequences, nheads, dim) or None\n    \"\"\"\n    batch, seqlen, nheads, dim = angle.shape\n    is_varlen = cu_seqlens is not None\n    \n    # Determine number of sequences\n    if is_varlen:\n        assert batch == 1, \"Varlen mode requires batch=1\"\n        num_sequences = cu_seqlens.shape[0] - 1\n    else:\n        num_sequences = batch\n    \n    grad_angle = torch.empty_like(angle)\n    grad_dt = torch.empty_like(dt)\n    BLOCK_D = triton.next_power_of_2(dim)\n\n    # Handle init_state gradient\n    if has_init_state:\n        grad_init_state = torch.empty(num_sequences, nheads, dim, device=angle.device, dtype=torch.float32)\n        stride_grad_init = 
grad_init_state.stride()\n    else:\n        grad_init_state = None\n        stride_grad_init = (0, 0, 0)\n        grad_init_dummy = grad_angle  # dummy pointer\n\n    # Handle grad_output_state\n    HAS_GRAD_OUTPUT_STATE = grad_output_state is not None\n    if not HAS_GRAD_OUTPUT_STATE:\n        grad_output_state = grad_angle  # dummy, won't be accessed\n        stride_grad_output_state = (0, 0, 0)\n    else:\n        stride_grad_output_state = grad_output_state.stride()\n\n    # Handle cu_seqlens\n    if cu_seqlens is not None:\n        stride_cu_seqlen = cu_seqlens.stride(0)\n    else:\n        cu_seqlens = angle  # dummy, won't be accessed\n        stride_cu_seqlen = 0\n\n    # Grid setup\n    if is_varlen:\n        grid = (nheads, batch, num_sequences)\n    else:\n        grid = (nheads, batch)\n\n    angle_dt_bwd_kernel[grid](\n        grad_angle, grad_dt, grad_init_state if has_init_state else grad_init_dummy,\n        grad_out, grad_output_state, angle, dt, cu_seqlens,\n        grad_angle.stride(0), grad_angle.stride(1), grad_angle.stride(2), grad_angle.stride(3),\n        grad_dt.stride(0), grad_dt.stride(1), grad_dt.stride(2),\n        stride_grad_init[0], stride_grad_init[1], stride_grad_init[2],\n        grad_out.stride(0), grad_out.stride(1), grad_out.stride(2), grad_out.stride(3),\n        stride_grad_output_state[0], stride_grad_output_state[1], stride_grad_output_state[2],\n        angle.stride(0), angle.stride(1), angle.stride(2), angle.stride(3),\n        dt.stride(0), dt.stride(1), dt.stride(2),\n        stride_cu_seqlen,\n        seqlen, dim,\n        CHUNK_SIZE=chunk_size,\n        BLOCK_D=BLOCK_D,\n        HAS_INIT_STATE=has_init_state,\n        HAS_GRAD_OUTPUT_STATE=HAS_GRAD_OUTPUT_STATE,\n        IS_VARLEN=is_varlen,\n    )\n    return grad_angle, grad_dt, grad_init_state"
  },
  {
    "path": "mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n# We need a pretty recent version of triton to support tuples. 3.3 definitely will work,\n# idk which is the minimum version.\n\nimport math\nfrom typing import Optional, Tuple\n\nimport torch\n\nimport triton\nimport triton.language as tl\nimport triton.testing\n#from flash_attn.cute.benchmark import pytorch_profiler\n\n@triton.jit\ndef rotary_qk_inference_kernel(\n    OUT_Q,  # Pointers to matrices\n    OUT_K,\n    OUT_ANGLE_STATE,\n    Q,\n    K,\n    ANGLE_STATE,\n    ANGLE_PROJ,\n    DT,\n    BIAS_Q,\n    BIAS_K,\n    nheads,\n    headdim,\n    stride_out_q,           # (batch, mimo_dim, nheads, headdim)\n    stride_out_k,           # (batch, mimo_dim, nheads, headdim)\n    stride_out_angle_state, # (batch, nheads, rotary_dim // 2)\n    stride_q,               # (batch, mimo_dim, nheads, headdim)\n    stride_k,               # (batch, mimo_dim, nheads, headdim)\n    stride_angle_state,     # (batch, nheads, rotary_dim // 2)\n    stride_angle_proj,      # (batch, nheads, rotary_dim // 2)\n    stride_dt,              # (batch, nheads)\n    stride_bias_q,          # (mimo_dim, nheads, headdim)\n    stride_bias_k,          # (mimo_dim, nheads, headdim)\n    # Meta-parameters\n    ROTARY_DIM: tl.constexpr,\n    CONJUGATE: tl.constexpr,\n    HAS_BIAS_Q: tl.constexpr,\n    HAS_BIAS_K: tl.constexpr,\n    MIMO_DIM: tl.constexpr,\n    BLOCK_D: tl.constexpr, # headdim, no chunking\n    ROTATE_PAIRWISE: tl.constexpr, # If true, rotate every pair of dimensions together. 
Otherwise, rotate the first half and second half separately (like in the original RoPE paper)\n):\n    pid_nheads = tl.program_id(axis=0) # heads\n    pid_batch = tl.program_id(axis=1)\n\n    Q = Q + pid_batch * stride_q[0] + pid_nheads * stride_q[2]\n    K = K + pid_batch * stride_k[0] + pid_nheads * stride_k[2]\n    # angle_state / angle_proj are (batch, nheads, rotary_dim // 2), so the head stride is index 1\n    ANGLE_STATE = ANGLE_STATE + pid_batch * stride_angle_state[0] + pid_nheads * stride_angle_state[1]\n    ANGLE_PROJ = ANGLE_PROJ + pid_batch * stride_angle_proj[0] + pid_nheads * stride_angle_proj[1]\n    DT = DT + pid_batch * stride_dt[0] + pid_nheads * stride_dt[1]\n\n    OUT_Q = OUT_Q + pid_batch * stride_out_q[0] + pid_nheads * stride_out_q[2]\n    OUT_K = OUT_K + pid_batch * stride_out_k[0] + pid_nheads * stride_out_k[2]\n    OUT_ANGLE_STATE = OUT_ANGLE_STATE + pid_batch * stride_out_angle_state[0] + pid_nheads * stride_out_angle_state[1]\n\n    rm = tl.arange(0, MIMO_DIM)\n    rd = tl.arange(0, BLOCK_D)\n    rd_half = tl.arange(0, BLOCK_D // 2)\n\n    # Load angle and compute cos/sin (same for both q and k).\n    # Lanes beyond rotary_dim // 2 are masked and load 0, so their angle is 0 (identity rotation).\n    ANGLE_STATE = ANGLE_STATE + rd_half * stride_angle_state[2]  # (rotary_dim // 2)\n    mask_angle = rd_half < ROTARY_DIM // 2\n    angle_state = tl.load(ANGLE_STATE, mask=mask_angle, other=0.0).to(tl.float32)\n\n    ANGLE_PROJ = ANGLE_PROJ + rd_half * stride_angle_proj[2]     # (rotary_dim // 2)\n    angle_proj = tl.load(ANGLE_PROJ, mask=mask_angle, other=0.0).to(tl.float32)\n\n    dt = tl.load(DT, mask=True, other=0.0).to(tl.float32)\n\n    # Match angle_dt: tanh(angle_proj) * dt * pi\n    angle_proj = tl.sigmoid(2.0 * angle_proj) * 2.0 - 1.0  # tanh(x) = 2 * sigmoid(2x) - 1\n    angle = angle_state + angle_proj * dt * 3.141592653589793  # (rotary_dim // 2)\n\n    OUT_ANGLE_STATE = OUT_ANGLE_STATE + rd_half * stride_out_angle_state[2]\n    tl.store(OUT_ANGLE_STATE, angle, mask=mask_angle)\n\n    angle = angle[None, :]  # (1, rotary_dim // 2) for mimo_dim broadcasting\n    cos = tl.cos(angle)\n    sin = tl.sin(angle)\n    if 
CONJUGATE:\n        sin = -sin\n\n    # Process Q tensor\n    Q = Q + (rm[:, None] * stride_q[1] + rd[None, :] * stride_q[3])\n    OUT_Q = OUT_Q + (rm[:, None] * stride_out_q[1] + rd[None, :] * stride_out_q[3])\n    mask = rd[None, :] < headdim\n    q = tl.load(Q, mask=mask, other=0.0).to(tl.float32)  # (mimo_dim, headdim)\n\n    # Add bias to Q if present\n    if HAS_BIAS_Q:\n        BIAS_Q = BIAS_Q + pid_nheads * stride_bias_q[1]                                                   \n        BIAS_Q = BIAS_Q + (rm[:, None] * stride_bias_q[0] + rd[None, :] * stride_bias_q[2])\n        bias_q = tl.load(BIAS_Q, mask=mask, other=0.0).to(tl.float32)\n        q = q + bias_q\n\n    if ROTATE_PAIRWISE:\n        # Apply rotary to Q\n        q0, q1 = tl.split(tl.reshape(q, [MIMO_DIM, BLOCK_D // 2, 2]))\n        qo0 = q0 * cos - q1 * sin\n        qo1 = q0 * sin + q1 * cos\n        qo = tl.reshape(tl.join(qo0, qo1), [MIMO_DIM, BLOCK_D])\n        tl.store(OUT_Q, qo, mask=mask)\n    else:\n        # Apply rotary to Q\n        q_reshaped = tl.reshape(q, [MIMO_DIM, 2, BLOCK_D // 2])\n        q_permuted = tl.permute(q_reshaped, (0, 2, 1))  # (mimo_dim, block_d // 2, 2)\n        q0, q1 = tl.split(q_permuted)\n        qo0 = q0 * cos - q1 * sin\n        qo1 = q0 * sin + q1 * cos\n        q_joined = tl.join(qo0, qo1)\n        q_final = tl.permute(q_joined, (0, 2, 1))  # (mimo_dim, 2, block_d // 2)\n        qo = tl.reshape(q_final, [MIMO_DIM, BLOCK_D])\n        tl.store(OUT_Q, qo, mask=mask)\n\n    # Process K tensor\n    K = K + (rm[:, None] * stride_k[1] + rd[None, :] * stride_k[3])\n    OUT_K = OUT_K + (rm[:, None] * stride_out_k[1] + rd[None, :] * stride_out_k[3])\n    k = tl.load(K, mask=mask, other=0.0).to(tl.float32)\n\n    # Add bias to K if present\n    if HAS_BIAS_K:\n        BIAS_K = BIAS_K + pid_nheads * stride_bias_k[1]                                                 \n        BIAS_K = BIAS_K + (rm[:, None] * stride_bias_k[0] + rd[None, :] * stride_bias_k[2])\n        bias_k 
= tl.load(BIAS_K, mask=mask, other=0.0).to(tl.float32)\n        k = k + bias_k\n\n    if ROTATE_PAIRWISE:\n        # Apply rotary to K\n        k0, k1 = tl.split(tl.reshape(k, [MIMO_DIM, BLOCK_D // 2, 2]))\n        ko0 = k0 * cos - k1 * sin\n        ko1 = k0 * sin + k1 * cos\n        ko = tl.reshape(tl.join(ko0, ko1), [MIMO_DIM, BLOCK_D])\n        tl.store(OUT_K, ko, mask=mask)\n    else:\n        # Apply rotary to K\n        k_reshaped = tl.reshape(k, [MIMO_DIM, 2, BLOCK_D // 2])\n        k_permuted = tl.permute(k_reshaped, (0, 2, 1))  # (mimo_dim, block_d // 2, 2)\n        k0, k1 = tl.split(k_permuted)\n        ko0 = k0 * cos - k1 * sin\n        ko1 = k0 * sin + k1 * cos\n        k_joined = tl.join(ko0, ko1)\n        k_final = tl.permute(k_joined, (0, 2, 1))  # (mimo_dim, 2, block_d // 2)\n        ko = tl.reshape(k_final, [MIMO_DIM, BLOCK_D])\n        tl.store(OUT_K, ko, mask=mask)\n\ndef apply_rotary_qk_inference_fwd(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    angle_state: torch.Tensor,\n    angle_proj: torch.Tensor,\n    dt: torch.Tensor,\n    bias_q: Optional[torch.Tensor] = None,\n    bias_k: Optional[torch.Tensor] = None,\n    inplace=False,\n    conjugate=False,\n    rotate_pairwise=True,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Apply rotary embedding to both q and k tensors using the same angle.\n    Also computes output angle state for next step.\n\n    Arguments:\n        q: (batch, mimo_dim, nheads, headdim)\n        k: (batch, mimo_dim, nheads, headdim)\n        angle_state: (batch, nheads, rotary_dim / 2)\n        angle_proj: (batch, nheads, rotary_dim / 2)\n        dt: (batch, nheads)\n        bias_q: Optional (mimo_dim, nheads, headdim) - bias to add to q before rotary\n        bias_k: Optional (mimo_dim, nheads, headdim) - bias to add to k before rotary\n    Returns:\n        (q_out, k_out, angle_state_out): q_out and k_out are (batch, mimo_dim, nheads, headdim),\n                               angle_state_out 
is (batch, nheads, rotary_dim / 2)\n    \"\"\"\n    batch, mimo_dim, nheads, headdim = q.shape\n    assert headdim % 2 == 0\n    assert k.shape == q.shape, f\"k shape {k.shape} != q shape {q.shape}\"\n\n    rotary_dim = angle_state.shape[-1] * 2\n    assert angle_state.shape == (batch, nheads, rotary_dim // 2)\n    assert angle_state.shape == angle_proj.shape\n    assert dt.shape == (batch, nheads)\n    assert rotary_dim <= headdim, \"rotary_dim must be <= headdim\"\n    assert headdim <= 256, \"Only support headdim <= 256\"\n\n    if bias_q is not None:\n        assert bias_q.shape == (mimo_dim, nheads, headdim), f\"bias_q shape {bias_q.shape} != (mimo_dim, nheads, headdim) {(mimo_dim, nheads, headdim)}\"\n        bias_q = bias_q.contiguous()\n\n    if bias_k is not None:\n        assert bias_k.shape == (mimo_dim, nheads, headdim), f\"bias_k shape {bias_k.shape} != (mimo_dim, nheads, headdim) {(mimo_dim, nheads, headdim)}\"\n        bias_k = bias_k.contiguous()\n\n    output_q = torch.empty_like(q) if not inplace else q\n    output_k = torch.empty_like(k) if not inplace else k\n    output_angle_state = torch.empty_like(angle_state) if not inplace else angle_state\n\n    grid = lambda META: (nheads, batch)  # noqa\n    with torch.cuda.device(q.device.index):\n        torch.library.wrap_triton(rotary_qk_inference_kernel)[grid](\n            output_q,  # data ptrs\n            output_k,\n            output_angle_state,\n            q,\n            k,\n            angle_state,\n            angle_proj,\n            dt,\n            bias_q,\n            bias_k,\n            nheads,\n            headdim,\n            output_q.stride(),  # output strides tuples\n            output_k.stride(),\n            output_angle_state.stride(),\n            q.stride(),  # input strides tuples\n            k.stride(),\n            angle_state.stride(),\n            angle_proj.stride(),\n            dt.stride(),\n            bias_q.stride() if bias_q is not None else (0, 0, 0),\n      
      bias_k.stride() if bias_k is not None else (0, 0, 0),\n            rotary_dim,\n            conjugate,\n            bias_q is not None,\n            bias_k is not None,\n            MIMO_DIM=mimo_dim,\n            BLOCK_D=triton.next_power_of_2(headdim),\n            num_warps=8,  # important, 4 warps is slower if we compute qk_sum\n            ROTATE_PAIRWISE=rotate_pairwise,\n        )\n    return output_q, output_k, output_angle_state\n\n\ndef apply_rotary_qk_inference_reference(\n    q: torch.Tensor, # (B, R, N, D)\n    k: torch.Tensor, # (B, R, N, D)\n    angle_state: torch.Tensor, # (B, N, S) S: num_rope_angles\n    angle_proj: torch.Tensor, # (B, N, S)\n    dt: torch.Tensor, # (B, N)\n    bias_q: Optional[torch.Tensor] = None, # (R, N, D)\n    bias_k: Optional[torch.Tensor] = None, # (R, N, D)\n    conjugate=False,\n    rotate_pairwise=True,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"Reference PyTorch implementation for QK rotary embedding with qk_sum.\"\"\"\n    batch, mimo_dim, nheads, headdim = q.shape\n    rotary_dim = angle_state.shape[-1] * 2\n\n    # Match angle_dt: tanh(angle_proj) * dt * pi\n    angle_proj = torch.tanh(angle_proj)\n    angle = angle_state + angle_proj * dt[:, :, None] * math.pi  # (B, N, S)\n    angle_state_new = angle\n    angle = angle.unsqueeze(1).expand(-1, mimo_dim, -1, -1)  # (B, R, N, S)\n\n    # Add biases if present\n    if bias_q is not None:\n        q = q + bias_q[None, :, :, :]  # Broadcast bias_q\n    if bias_k is not None:\n        k = k + bias_k[None, :, :, :]  # Broadcast bias_k\n\n    # Only apply rotary to the rotary dimensions\n    q_rot = q[..., :rotary_dim] # (B, R, N, rotary_dim)\n    q_pass = q[..., rotary_dim:]\n    k_rot = k[..., :rotary_dim]\n    k_pass = k[..., rotary_dim:]\n\n    # Compute cos and sin from angle (same for both q and k)\n    cos = torch.cos(angle) # (B, N, S)\n    sin = torch.sin(angle)\n    if conjugate:\n        sin = -sin\n\n    if rotate_pairwise:\n        
# Interleaved rotary: pairs are (x0,x1), (x2,x3), ...\n        q_rot = q_rot.reshape(batch, mimo_dim, nheads, rotary_dim // 2, 2)\n        q0, q1 = q_rot[..., 0], q_rot[..., 1]\n        k_rot = k_rot.reshape(batch, mimo_dim, nheads, rotary_dim // 2, 2)\n        k0, k1 = k_rot[..., 0], k_rot[..., 1]\n\n        qo0 = q0 * cos - q1 * sin\n        qo1 = q0 * sin + q1 * cos\n        ko0 = k0 * cos - k1 * sin\n        ko1 = k0 * sin + k1 * cos\n\n        qout_rot = torch.stack([qo0, qo1], dim=-1).reshape(batch, mimo_dim, nheads, rotary_dim)\n        kout_rot = torch.stack([ko0, ko1], dim=-1).reshape(batch, mimo_dim, nheads, rotary_dim)\n\n        # Concatenate rotated and pass-through dimensions\n        if rotary_dim < headdim:\n            q_out = torch.cat([qout_rot, q_pass], dim=-1)\n            k_out = torch.cat([kout_rot, k_pass], dim=-1)\n        else:\n            q_out = qout_rot\n            k_out = kout_rot\n    else:\n        # Halved rotary: split full headdim in half, pairs are (dim_i, dim_{i+D/2})\n        # Matches kernel which splits BLOCK_D in half; cos(0)=1/sin(0)=0 gives identity\n        # for pairs beyond rotary_dim//2\n        half = headdim // 2\n        q0, q1 = q[..., :half], q[..., half:]\n        k0, k1 = k[..., :half], k[..., half:]\n\n        # Pad cos/sin from rotary_dim//2 to headdim//2 with cos=1, sin=0\n        rdim_half = rotary_dim // 2\n        if half > rdim_half:\n            pad_shape = list(cos.shape)\n            pad_shape[-1] = half - rdim_half\n            cos = torch.cat([cos, torch.ones(pad_shape, device=cos.device, dtype=cos.dtype)], dim=-1)\n            sin = torch.cat([sin, torch.zeros(pad_shape, device=sin.device, dtype=sin.dtype)], dim=-1)\n\n        qo0 = q0 * cos - q1 * sin\n        qo1 = q0 * sin + q1 * cos\n        ko0 = k0 * cos - k1 * sin\n        ko1 = k0 * sin + k1 * cos\n\n        q_out = torch.cat([qo0, qo1], dim=-1)\n        k_out = torch.cat([ko0, ko1], dim=-1)\n\n    return q_out, k_out, 
angle_state_new\n\n\ndef test_correctness_qk_inference():\n    print(\"Testing QK Inference correctness...\")\n\n    device = \"cuda\"\n    torch.manual_seed(2025)\n    dtype_qk = torch.bfloat16  # common inference dtype\n    dtype_ang = torch.float32\n\n    def run_case(B, R, N, D, RD, with_bias, conjugate, expanded_heads, rotate_pairwise):\n        assert D % 2 == 0\n        # Build q,k with optional head broadcasting\n        q0 = torch.randn(B, R, 1 if expanded_heads else N, D, device=device, dtype=dtype_qk)\n        k0 = torch.randn(B, R, 1 if expanded_heads else N, D, device=device, dtype=dtype_qk)\n        q  = q0.expand(B, R, N, D) if expanded_heads else q0\n        k  = k0.expand(B, R, N, D) if expanded_heads else k0\n\n        angle_state = torch.randn(B, N, RD // 2, device=device, dtype=dtype_ang)\n        angle_proj  = torch.randn(B, N, RD // 2, device=device, dtype=dtype_ang)\n        dt          = torch.randn(B, N, device=device, dtype=dtype_ang)\n\n        bias_q = torch.randn(R, N, D, device=device, dtype=dtype_qk) if with_bias else None\n        bias_k = torch.randn(R, N, D, device=device, dtype=dtype_qk) if with_bias else None\n\n        # Reference\n        q_ref, k_ref, updated_angle_ref = apply_rotary_qk_inference_reference(\n            q, k, angle_state, angle_proj, dt,\n            bias_q=bias_q, bias_k=bias_k, conjugate=conjugate,\n            rotate_pairwise=rotate_pairwise,\n        )\n\n        # Kernel\n        q_out, k_out, updated_angle = apply_rotary_qk_inference_fwd(\n            q, k, angle_state, angle_proj, dt,\n            bias_q=bias_q, bias_k=bias_k, conjugate=conjugate, inplace=False,\n            rotate_pairwise=rotate_pairwise,\n        )\n\n        def _chk(name, a, b, atol=1e-1, rtol=1e-1):\n            diff = (a - b).abs().max().item()\n            if not torch.allclose(a, b, atol=atol, rtol=rtol):\n                raise AssertionError(f\"{name} mismatch: max|Δ|={diff:.3e}  got={tuple(a.shape)}  ref={tuple(b.shape)}\")\n 
           print(f\"  {name:18s} ok   max|Δ|={diff:.2e}\")\n\n        print(f\"\\nInference [{B=}, {R=}, {N=}, {D=}, {RD=} | bias={with_bias}, conj={conjugate}, expanded={expanded_heads}, pairwise={rotate_pairwise}]\")\n        _chk(\"q_out\", q_out.float(), q_ref.float(), atol=1e-1, rtol=1e-1)\n        _chk(\"k_out\", k_out.float(), k_ref.float(), atol=1e-1, rtol=1e-1)\n        _chk(\"updated_angle\", updated_angle, updated_angle_ref, atol=1e-1, rtol=1e-1)\n\n    # standard config\n    B, R, N, D, RD = 2, 4, 64, 128, 64\n    for with_bias in [False, True]:\n        for conjugate in [False, True]:\n            for expanded in [True, False]:\n                for pairwise in [True, False]:\n                    run_case(B, R, N, D, RD, with_bias, conjugate, expanded, pairwise)\n\n    # light shape sweep\n    for (BB, RR, NN, DD, RRd) in [\n        (1, 2, 64, 64,  32),\n        (3, 1, 32, 128, 64),\n        (2, 8, 32, 128, 64),\n    ]:\n        for pairwise in [True, False]:\n            run_case(BB, RR, NN, DD, RRd, with_bias=True,  conjugate=False, expanded_heads=True, rotate_pairwise=pairwise)\n            run_case(BB, RR, NN, DD, RRd, with_bias=False, conjugate=True,  expanded_heads=False, rotate_pairwise=pairwise)\n\n    print(\"\\nAll QK Inference tests passed! ✓\")\n\n\nif __name__ == \"__main__\":\n    test_correctness_qk_inference()\n"
  },
  {
    "path": "mamba_ssm/ops/triton/mamba3/mamba3_mimo_utils.py",
    "content": "\"\"\"\nFused Triton kernels for Mamba3 backward pass ddt computation.\n\nThis module implements fused kernels that combine three separate backward operations:\n1. bwd_segsum_ddt_from_dSSdA - Complex 2D segsum operation\n2. bwd_ddt_from_ddA_cs_rev - Forward exclusive cumsum operation\n3. bwd_ddt_from_ddA_cs - Reverse cumsum operation\n\nThe fusion reduces memory traffic and kernel launch overhead.\n\"\"\"\n\nimport torch\nimport triton\nimport triton.language as tl\nimport math\nfrom typing import Optional, Tuple\n\n# Constants\nLOG2 = math.log(2.0)\nNEG_LOG2E = -math.log2(math.e)\n\n\n# ============================================================================\n# Kernel 1: Fused Cumsum Operations (forward exclusive + reverse)\n# ============================================================================\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_stages=s, num_warps=w)\n        for s in [1, 2, 3]\n        for w in [4, 8]\n    ],\n    key=[\"CHUNK_SIZE\"],\n    restore_value=[\"ddt_out_ptr\"],\n)\n@triton.jit\ndef bwd_dadt_cumsum_fused_kernel(\n    ddA_cs_ptr,         # [B, H, S]\n    ddA_cs_rev_ptr,     # [B, H, S]\n    dA_cs_ptr,          # [B, H, S]\n    dA_cs_rev_ptr,      # [B, H, S]\n    ddt_out_ptr,        # [B, H, S] - output\n    stride_batch,\n    stride_head,\n    stride_seq,\n    B: tl.constexpr,\n    H: tl.constexpr,\n    S: tl.constexpr,\n    CHUNK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Fused kernel that computes contributions from:\n    - bwd_ddt_from_ddA_cs: reverse cumsum operation\n    - bwd_ddt_from_ddA_cs_rev: forward exclusive cumsum operation\n\n    Each program handles one chunk for one (batch, head) pair.\n    Grid: (B, H, nchunks)\n    \"\"\"\n    # Get program indices\n    pid_batch = tl.program_id(0)\n    pid_head = tl.program_id(1)\n    pid_chunk = tl.program_id(2)\n\n    # Calculate chunk boundaries\n    chunk_start = pid_chunk * CHUNK_SIZE\n    offs_seq = chunk_start + tl.arange(0, 
CHUNK_SIZE)\n    mask = offs_seq < S\n\n    # Compute base offset for this (batch, head) pair\n    base_offset = pid_batch * stride_batch + pid_head * stride_head\n\n    # Load chunk data for all four input tensors\n    ddA_cs = tl.load(ddA_cs_ptr + base_offset + offs_seq * stride_seq, mask=mask, other=0.0)\n    ddA_cs_rev = tl.load(ddA_cs_rev_ptr + base_offset + offs_seq * stride_seq, mask=mask, other=0.0)\n    dA_cs = tl.load(dA_cs_ptr + base_offset + offs_seq * stride_seq, mask=mask, other=0.0)\n    dA_cs_rev = tl.load(dA_cs_rev_ptr + base_offset + offs_seq * stride_seq, mask=mask, other=0.0)\n\n    # ========================================================================\n    # Operation 1: bwd_ddt_from_ddA_cs (reverse cumsum)\n    # ========================================================================\n    # Scale by exp(dA_cs). The -log2(e) factor used by the reference\n    # implementations is not applied here; it is left to the caller.\n    scaled_ddA_cs = tl.exp(dA_cs) * ddA_cs\n    # Apply reverse cumsum within chunk\n    ddt_cs = tl.cumsum(scaled_ddA_cs, axis=0, reverse=True)\n\n    # ========================================================================\n    # Operation 2: bwd_ddt_from_ddA_cs_rev (forward exclusive cumsum)\n    # ========================================================================\n    # Scale by exp(dA_cs_rev); as above, -log2(e) is left to the caller.\n    scaled_ddA_cs_rev = tl.exp(dA_cs_rev) * ddA_cs_rev\n    # Apply forward cumsum within chunk (inclusive)\n    ddt_cs_rev_inclusive = tl.cumsum(scaled_ddA_cs_rev, axis=0)\n\n    # Shift one to the right to get the exclusive cumsum:\n    # out[i] = inclusive[i-1], out[0] = 0\n    i = tl.arange(0, CHUNK_SIZE)[:, None]          # [N,1]\n    j = tl.arange(0, CHUNK_SIZE)[None, :]          # [1,N]\n    # First subdiagonal (one below main). Named shift_mask so it does not\n    # shadow the constexpr sequence length S.\n    shift_mask = (i == j + 1)\n    ddt_cs_rev_exclusive = tl.sum(tl.where(shift_mask, ddt_cs_rev_inclusive, 0), axis=1)\n\n    # # Convert to exclusive cumsum\n    # # Exclusive cumsum: output[i] = sum(input[0:i])\n    # # Inclusive 
cumsum: cumsum[i] = sum(input[0:i+1])\n    # # Therefore: exclusive[i] = inclusive[i] - input[i]\n    # # Which is: exclusive[i] = cumsum[i] - scaled_ddA_cs_rev[i]\n    # ddt_cs_rev_shifted = ddt_cs_rev_inclusive - scaled_ddA_cs_rev\n\n    # ========================================================================\n    # Combine contributions and apply final scaling\n    # ========================================================================\n    # Use literal constant instead of global\n    ddt_total = ddt_cs + ddt_cs_rev_exclusive \n\n    # Store result\n    tl.store(ddt_out_ptr + base_offset + offs_seq * stride_seq, ddt_total, mask=mask)\n\n\n# ============================================================================\n# Kernel 2: Segsum Operation with 2D Matrix Processing\n# ============================================================================\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_stages=s, num_warps=w)\n        for s in [2, 3]\n        for w in [4, 8]\n    ],\n    key=[\"CHUNK_SIZE\"],\n    restore_value=[\"ddt_out_ptr\"],\n\n)\n@triton.jit\ndef bwd_segsum_dadt_kernel(\n    dSSdA_ptr,          # [B, H, nchunks, C, C]\n    SSdA_cs_ptr,          # [B, H, S]\n    ddt_out_ptr,        # [B, H, S] - accumulated output\n    stride_dSSdA_batch,\n    stride_dSSdA_head,\n    stride_dSSdA_chunk,\n    stride_dSSdA_row,\n    stride_dSSdA_col,\n    stride_SSdA_batch,\n    stride_SSdA_head,\n    stride_SSdA_chunk,\n    stride_SSdA_row,\n    stride_SSdA_col,\n    stride_ddt_batch,\n    stride_ddt_head,\n    stride_ddt_seq,\n    B: tl.constexpr,\n    H: tl.constexpr,\n    nchunks: tl.constexpr,\n    C: tl.constexpr,\n    CHUNK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Kernel for bwd_segsum_ddt_from_dSSdA operation.\n    Matches the reference implementation:\n    1. Permute dSSdA last two dims\n    2. Compute seg = dA_cs[i] - dA_cs[j]\n    3. Scale by log(2) * exp2(seg)\n    4. Reverse cumsum along dim -2 (column-wise for each row)\n    5. 
Apply lower triangular mask (i > j)\n    6. Sum along dim -1 (sum over j for each i)\n\n    Each program handles one chunk for one (batch, head) pair.\n    Grid: (B, H, nchunks)\n    \"\"\"\n    # Get program indices\n    pid_batch = tl.program_id(0)\n    pid_head = tl.program_id(1)\n    pid_chunk = tl.program_id(2)\n\n    # Calculate chunk boundaries\n    chunk_start = pid_chunk * CHUNK_SIZE\n    offs_c = tl.arange(0, CHUNK_SIZE)\n    offs_seq = chunk_start + offs_c\n\n    # Load dA_cs for this chunk [C]\n    # dA_cs_offset = pid_batch * stride_dA_batch + pid_head * stride_dA_head\n    # dA_cs_chunk = tl.load(dA_cs_ptr + dA_cs_offset + offs_seq * stride_dA_seq)\n\n    # Base offset for dSSdA matrix [nchunks, C, C]\n    dSSdA_offset = dSSdA_ptr + (pid_batch * stride_dSSdA_batch +\n                    pid_head * stride_dSSdA_head +\n                    pid_chunk * stride_dSSdA_chunk)\n    SSdA_offset = SSdA_cs_ptr + (pid_batch * stride_SSdA_batch +\n                    pid_head * stride_SSdA_head +\n                    pid_chunk * stride_SSdA_chunk)\n    ddt_ptrs = ddt_out_ptr + (pid_batch * stride_ddt_batch +\n                    pid_head * stride_ddt_head +\n                    offs_seq * stride_ddt_seq)\n\n    # NOTE: dSSdA is stored transposed (seq_k x seq_q), so its row and\n    # column strides are swapped in the load below.\n    dSSdA_block = tl.load(dSSdA_offset + offs_c[:, None]*stride_dSSdA_col + offs_c[None, :]*stride_dSSdA_row)\n    SSdA_block = tl.load(SSdA_offset + offs_c[:, None]*stride_SSdA_row + offs_c[None, :]*stride_SSdA_col)\n\n    # Scale by exp(segsum), then reverse cumsum along the row axis\n    dSSdA_block = dSSdA_block * tl.exp(SSdA_block)\n    dSSdA_block = tl.cumsum(dSSdA_block, axis=0, reverse=True)\n\n    # Keep only the strictly lower triangle (i > j)\n    offs_i = tl.arange(0, CHUNK_SIZE)[:, None]\n    offs_j = tl.arange(0, CHUNK_SIZE)[None, :]\n    SS_mask = offs_i > offs_j\n    dSSdA = tl.where(SS_mask, dSSdA_block, 0.0)\n\n    # Accumulate onto the values already present in ddt_out\n    ddt_chunk = tl.load(ddt_ptrs)\n    ddt_chunk += tl.sum(dSSdA, axis=1)\n    tl.store(ddt_ptrs, ddt_chunk)\n\n\n# 
============================================================================\n# Kernel 3:  backwards from gamma terms to trap \n# ============================================================================\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_stages=s, num_warps=w)\n        for s in [2, 3]\n        for w in [4, 8]\n    ],\n    key=[\"CHUNK_SIZE\"],\n)\n@triton.jit\ndef bwd_dtrap_ddt_kernel(\n    trap_ptr, dt_ptr, dfactor_ptr, dgamma_diag_ptr,\n    ddt_ptr, dtrap_ptr, \n    stride_trap_batch, stride_trap_head, stride_trap_seq,\n    stride_dt_batch, stride_dt_head, stride_dt_seq,\n    stride_dfactor_batch, stride_dfactor_head, stride_dfactor_seq,\n    stride_dgamma_diag_batch, stride_dgamma_diag_head, stride_dgamma_diag_seq,\n    stride_ddt_batch, stride_ddt_head, stride_ddt_seq,\n    stride_dtrap_batch, stride_dtrap_head, stride_dtrap_seq,\n\n    SEQLEN: tl.constexpr,\n    CHUNK_SIZE: tl.constexpr,\n):\n    # Get program indices\n    pid_batch = tl.program_id(0)\n    pid_head = tl.program_id(1)\n    pid_chunk = tl.program_id(2)\n\n    # Calculate chunk boundaries\n    chunk_start = pid_chunk * CHUNK_SIZE\n    offs_c = tl.arange(0, CHUNK_SIZE)\n    offs_seq = chunk_start + offs_c\n\n    trap_offset = pid_batch*stride_trap_batch + pid_head*stride_trap_head\n    dt_offset = pid_batch*stride_dt_batch + pid_head*stride_dt_head\n    dfactor_offset = pid_batch*stride_dfactor_batch + pid_head*stride_dfactor_head\n    dgamma_diag_offset = pid_batch*stride_dgamma_diag_batch + pid_head*stride_dgamma_diag_head\n\n    strap_block = tl.load(\n        trap_ptr + trap_offset + (offs_seq + 1)*stride_trap_seq,\n        mask=(offs_seq + 1) < SEQLEN, other=0.0\n        )\n    sdt_block = tl.load(\n        dt_ptr + dt_offset + (offs_seq + 1)*stride_dt_seq,\n        mask=(offs_seq + 1) < SEQLEN, other=0.0\n    )\n    trap_block = tl.load(\n        trap_ptr + trap_offset + offs_seq * stride_trap_seq,\n        mask=offs_seq < SEQLEN, other=0.0\n    )\n    
dt_block = tl.load(\n        dt_ptr + dt_offset + offs_seq * stride_dt_seq,\n        mask=offs_seq < SEQLEN, other=0.0\n    )\n    dfactor_block = tl.load(\n        dfactor_ptr + dfactor_offset + offs_seq * stride_dfactor_seq,\n        mask=offs_seq < SEQLEN, other=0.0\n    )\n    dgamma_diag_input_block = tl.load(\n        dgamma_diag_ptr + dgamma_diag_offset + offs_seq * stride_dgamma_diag_seq,\n        mask=offs_seq < SEQLEN, other=0.0\n    )\n\n    # dgamma and dsgamma for current positions\n    dgamma_block = dfactor_block + dgamma_diag_input_block\n    dsgamma_block = dfactor_block #+ dsgamma_input_block\n\n    # dsdt and dstrap for current positions (using shifted strap/sdt)\n    dsdt_block = tl.sigmoid(-strap_block.to(tl.float32)) * dsgamma_block\n    dstrap_block = -sdt_block * dsgamma_block\n\n    # Compute dsdt/dstrap at previous position for cross-chunk shift\n    prev_seq = chunk_start - 1\n    prev_mask = prev_seq >= 0\n    prev_dgamma = tl.load(\n        dfactor_ptr + dfactor_offset + prev_seq * stride_dfactor_seq,\n        mask=prev_mask, other=0.0\n    )\n    # prev_dsgamma_input = tl.load(\n    #     dsgamma_ptr + dsgamma_offset + prev_seq * stride_dsgamma_seq,\n    #     mask=prev_mask, other=0.0\n    # )\n    prev_dsgamma = prev_dgamma  #+ prev_dsgamma_input\n    prev_strap = tl.load(\n        trap_ptr + trap_offset + chunk_start * stride_trap_seq,\n        mask=chunk_start < SEQLEN, other=0.0\n    )\n    prev_sdt = tl.load(\n        dt_ptr + dt_offset + chunk_start * stride_dt_seq,\n        mask=chunk_start < SEQLEN, other=0.0\n    )\n    prev_dsdt = tl.sigmoid(-prev_strap.to(tl.float32)) * prev_dsgamma\n    prev_dstrap = -prev_sdt * prev_dsgamma\n\n    # Shift right by one within chunk: out[i] = in[i-1], with cross-chunk value at i=0\n    offs_i = tl.arange(0, CHUNK_SIZE)[:, None]\n    offs_j = tl.arange(0, CHUNK_SIZE)[None, :]\n    shift_mask = offs_i == (offs_j + 1)\n    dsdt_shift = tl.sum(tl.where(shift_mask, dsdt_block[None, :], 0.0), 
axis=1)\n    dstrap_shift = tl.sum(tl.where(shift_mask, dstrap_block[None, :], 0.0), axis=1)\n\n    offs = tl.arange(0, CHUNK_SIZE)\n    dsdt_shift = tl.where(offs == 0, prev_dsdt, dsdt_shift)\n    dstrap_shift = tl.where(offs == 0, prev_dstrap, dstrap_shift)\n\n    # Add dgamma path\n    ddt_out = dsdt_shift + dgamma_block * tl.sigmoid(trap_block.to(tl.float32))\n    dtrap_out = dstrap_shift + dgamma_block * dt_block \n    dtrap_out *= tl.sigmoid(trap_block.to(tl.float32)) * tl.sigmoid(-trap_block.to(tl.float32)) \n\n    ddt_ptrs = ddt_ptr + (pid_batch * stride_ddt_batch +\n                          pid_head * stride_ddt_head +\n                          offs_seq * stride_ddt_seq)\n    dtrap_ptrs = dtrap_ptr + (pid_batch * stride_dtrap_batch +\n                              pid_head * stride_dtrap_head +\n                              offs_seq * stride_dtrap_seq)\n\n    tl.store(ddt_ptrs, ddt_out, mask=offs_seq < SEQLEN)\n    tl.store(dtrap_ptrs, dtrap_out, mask=offs_seq < SEQLEN)\n\n\n# ============================================================================\n# Kernel 4:  compute da_cs, da_cs_rev, segsum from da\n# ============================================================================\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_stages=s, num_warps=w)\n        for s in [2, 3]\n        for w in [4, 8]\n    ],\n    key=[\"CHUNK_SIZE\"],\n)\n@triton.jit\ndef dacs_segsum_kernel(\n    da_ptr,\n    da_cs_ptr,\n    da_cs_rev_ptr,\n    segsum_ptr,\n    stride_da_batch, stride_da_head, stride_da_seq,\n    stride_da_cs_batch, stride_da_cs_head, stride_da_cs_seq,\n    stride_da_cs_rev_batch, stride_da_cs_rev_head, stride_da_cs_rev_seq,\n    stride_segsum_batch, stride_segsum_head, stride_segsum_chunk,\n    stride_segsum_row, stride_segsum_col,\n    SEQLEN: tl.constexpr,\n    CHUNK_SIZE: tl.constexpr,\n):\n    pid_batch = tl.program_id(0)\n    pid_head = tl.program_id(1)\n    pid_chunk = tl.program_id(2)\n\n    chunk_start = pid_chunk * 
CHUNK_SIZE\n    offs = tl.arange(0, CHUNK_SIZE)\n    offs_seq = chunk_start + offs\n    mask = offs_seq < SEQLEN\n\n    base_da = pid_batch * stride_da_batch + pid_head * stride_da_head\n    da_chunk = tl.load(da_ptr + base_da + offs_seq * stride_da_seq, mask=mask, other=0.0)\n\n    da_cs = tl.cumsum(da_chunk, axis=0)\n    da_cs = tl.minimum(da_cs, 0.0)\n    \n    da_cs_rev = tl.cumsum(da_chunk, axis=0, reverse=True)\n    # Roll one to the left:\n    i = tl.arange(0, CHUNK_SIZE)[:, None]          # [N,1]\n    j = tl.arange(0, CHUNK_SIZE)[None, :]          # [1,N]\n    S = (i == j - 1)                      # strictly upper diagonal (one above main)\n    da_cs_rev = tl.sum(tl.where(S, da_cs_rev, 0), axis=1)\n    da_cs_rev = tl.minimum(da_cs_rev, 0.0)\n\n    base_da_cs = pid_batch * stride_da_cs_batch + pid_head * stride_da_cs_head\n    base_da_cs_rev = pid_batch * stride_da_cs_rev_batch + pid_head * stride_da_cs_rev_head\n    tl.store(da_cs_ptr + base_da_cs + offs_seq * stride_da_cs_seq, da_cs, mask=mask)\n    tl.store(da_cs_rev_ptr + base_da_cs_rev + offs_seq * stride_da_cs_rev_seq, da_cs_rev, mask=mask)\n\n    broadcasted_indices = tl.zeros_like(offs)\n    segsum = tl.load(da_ptr + base_da + offs_seq[:, None] * stride_da_seq + broadcasted_indices[None, :])\n    offs_i = offs[:, None]\n    offs_j = offs[None, :]\n    segsum = tl.where(offs_i > offs_j, segsum, 0.0)\n    segsum = tl.cumsum(segsum, axis=0)\n    segsum = tl.minimum(segsum, 0.0)\n\n    base_segsum = (pid_batch * stride_segsum_batch +\n                   pid_head * stride_segsum_head +\n                   pid_chunk * stride_segsum_chunk)\n    tl.store(segsum_ptr + base_segsum + offs_i * stride_segsum_row + offs_j * stride_segsum_col, segsum)\n\n# ============================================================================\n# Wrapper Function\n# ============================================================================\n\ndef bwd_dadt_fused_triton(\n    dSSdA: torch.Tensor,              # [B, H, nchunks, 
C, C]\n    SSdA: torch.Tensor,               # [B, H, nchunks, C, C]\n    ddA_cs: torch.Tensor,             # [B, H, S]\n    ddA_cs_rev: torch.Tensor,         # [B, H, S]\n    dA_cs: torch.Tensor,              # [B, H, S]\n    dA_cs_rev: torch.Tensor,          # [B, H, S]\n    chunk_size: int,\n) -> torch.Tensor:\n    # Validate inputs\n    B, H, S = ddA_cs.shape\n    nchunks = S // chunk_size\n    assert S % chunk_size == 0, f\"Sequence length {S} must be divisible by chunk_size {chunk_size}\"\n    assert dSSdA.shape == (B, H, nchunks, chunk_size, chunk_size), \\\n        f\"dSSdA shape mismatch: expected {(B, H, nchunks, chunk_size, chunk_size)}, got {dSSdA.shape}\"\n\n    # Initialize output tensor\n    dadt_out = torch.zeros(B, H, S, device=ddA_cs.device, dtype=torch.float32)\n\n    # Kernel 1: Fused ddA_cs and ddA_cs_rev contributions\n    grid1 = (B, H, nchunks)\n    bwd_dadt_cumsum_fused_kernel[grid1](\n        ddA_cs, ddA_cs_rev, dA_cs, dA_cs_rev, dadt_out,\n        ddA_cs.stride(0), ddA_cs.stride(1), ddA_cs.stride(2),\n        B, H, S,\n        CHUNK_SIZE=chunk_size,\n    )\n\n    # Kernel 2: dSSdA segsum contribution\n    grid2 = (B, H, nchunks)\n    bwd_segsum_dadt_kernel[grid2](\n        dSSdA, SSdA, dadt_out,\n        dSSdA.stride(0), dSSdA.stride(1), dSSdA.stride(2),\n        dSSdA.stride(3), dSSdA.stride(4),\n        SSdA.stride(0), SSdA.stride(1), SSdA.stride(2),\n        SSdA.stride(3), SSdA.stride(4),\n        dadt_out.stride(0), dadt_out.stride(1), dadt_out.stride(2),\n        B, H, nchunks, chunk_size,\n        CHUNK_SIZE=chunk_size,\n    )\n\n    return dadt_out\n\ndef bwd_dtrap_ddt_triton(\n    trap: torch.Tensor,      # [B, H, S]\n    dt: torch.Tensor,        # [B, H, S]\n    dfactor: torch.Tensor,   # [B, H, S]\n    dgamma_diag: torch.Tensor,   # [B, H, S]\n    chunk_size: int, # NOTE: the chunk_size does not have to be the same as the other kernels\n):\n    B, H, S = dt.shape\n    nchunks = S // chunk_size\n\n    ddt = 
torch.zeros_like(dt)\n    dtrap = torch.zeros_like(trap)\n\n    grid = (B, H, nchunks)\n    bwd_dtrap_ddt_kernel[grid](\n        trap, dt, dfactor, dgamma_diag,\n        ddt, dtrap,\n        trap.stride(0), trap.stride(1), trap.stride(2),\n        dt.stride(0), dt.stride(1), dt.stride(2),\n        dfactor.stride(0), dfactor.stride(1), dfactor.stride(2),\n        dgamma_diag.stride(0), dgamma_diag.stride(1), dgamma_diag.stride(2),\n        ddt.stride(0), ddt.stride(1), ddt.stride(2),\n        dtrap.stride(0), dtrap.stride(1), dtrap.stride(2),\n        S,\n        chunk_size,\n    )\n    return ddt, dtrap\n\ndef compute_dacs_segsum_triton(\n    da: torch.Tensor,  # (B, H, S)\n    chunk_size: int,\n):\n    B, H, S = da.shape\n    nchunks = (S + chunk_size - 1) // chunk_size\n\n    da_cs = torch.empty_like(da)\n    da_cs_rev = torch.empty_like(da)\n    segsum = torch.empty(B, H, nchunks, chunk_size, chunk_size, device=da.device, dtype=da.dtype)\n\n    grid = (B, H, nchunks)\n    dacs_segsum_kernel[grid](\n        da, da_cs, da_cs_rev, segsum,\n        da.stride(0), da.stride(1), da.stride(2),\n        da_cs.stride(0), da_cs.stride(1), da_cs.stride(2),\n        da_cs_rev.stride(0), da_cs_rev.stride(1), da_cs_rev.stride(2),\n        segsum.stride(0), segsum.stride(1), segsum.stride(2),\n        segsum.stride(3), segsum.stride(4),\n        S,\n        chunk_size,\n    )\n\n    return da_cs, da_cs_rev, segsum\n\n\n# ============================================================================\n# Reference Implementations (for testing)\n# ============================================================================\n\ndef bwd_segsum_ddt_from_dSSdA_ref(\n    dSSdA: torch.Tensor,\n    dA_cs: torch.Tensor,\n    chunk_size: int,\n):\n    \"\"\"Reference implementation of bwd_segsum_ddt_from_dSSdA.\"\"\"\n    B, H, nchunks, C, C_ = dSSdA.shape\n    assert C == chunk_size == C_\n    dA_cs_chunk = dA_cs.view(B, H, nchunks, C)\n    dSSdA = dSSdA.permute([0, 1, 2, 4, 3])\n    seg = 
dA_cs_chunk[..., :, None] - dA_cs_chunk[..., None, :]\n    dSSdA = dSSdA * torch.exp(seg)\n    ddA = torch.flip(torch.cumsum(torch.flip(dSSdA, dims=[-2]), dim=-2), dims=[-2])\n    mask = torch.tril(torch.ones(C, C, device=dSSdA.device, dtype=dSSdA.dtype), -1)\n    ddA = ddA * mask\n    ddA = ddA.sum(-1)\n    ddt = ddA * (-math.log2(math.e))\n    return ddt.reshape(B, H, nchunks*C)\n\n\ndef bwd_ddt_from_ddA_cs_rev_ref(\n    ddA_cs_rev: torch.Tensor,\n    dA_cs_rev: torch.Tensor,\n    chunk_size: int,\n):\n    \"\"\"Reference implementation of bwd_ddt_from_ddA_cs_rev.\"\"\"\n    B, H, S = ddA_cs_rev.shape\n    nchunks = S // chunk_size\n    ddA_cs_rev = torch.exp(dA_cs_rev) * ddA_cs_rev\n    dA_cs_rev = dA_cs_rev.view(B, H, nchunks, chunk_size)\n    ddA_cs_rev = ddA_cs_rev.view(B, H, nchunks, chunk_size)\n    ddA = torch.cumsum(ddA_cs_rev, dim=-1)\n    ddA = torch.cat([torch.zeros_like(ddA[..., :1]), ddA[..., :-1]], dim=-1)\n    ddt = ddA * (-math.log2(math.e))\n    return ddt.reshape(B, H, nchunks*chunk_size)\n\n \ndef bwd_ddt_from_ddA_cs_ref(\n    ddA_cs: torch.Tensor,\n    dA_cs: torch.Tensor,\n    chunk_size: int,\n):\n    \"\"\"Reference implementation of bwd_ddt_from_ddA_cs.\"\"\"\n    B, H, S = ddA_cs.shape\n    nchunks = S // chunk_size\n    ddA_cs =  torch.exp(dA_cs) * ddA_cs\n    dA_cs = dA_cs.view(B, H, nchunks, chunk_size)\n    ddA_cs = ddA_cs.view(B, H, nchunks, chunk_size)\n    ddA = torch.flip(torch.cumsum(torch.flip(ddA_cs, dims=[-1]), dim=-1), dims=[-1])\n    ddt = ddA * (-math.log2(math.e))\n    return ddt.reshape(B, H, nchunks*chunk_size)\n\ndef compute_dtrap_ddt_ref(dfactor: torch.Tensor,\n                          dgamma_diag_input: torch.Tensor,\n                          trap_presigmoid,\n                          dt,\n                          ) -> Tuple[torch.Tensor, torch.Tensor]:\n    trap = torch.nn.functional.sigmoid(trap_presigmoid)\n    strap = torch.nn.functional.pad(trap[:, :, 1:], (0, 1), value=0.0)\n    sdt = 
torch.nn.functional.pad(dt[:, :, 1:], (0, 1), value=0.0)\n    dgamma = dfactor.detach().clone() + dgamma_diag_input.detach().clone()\n    dsgamma = dfactor.detach().clone() # + dsgamma_input.detach().clone()\n    dsdt = (1 - strap) * dsgamma\n    dstrap = -sdt * dsgamma\n    # shift rightward:\n    ddt = torch.nn.functional.pad(dsdt[:, :, :-1], (1, 0), value=0.0)\n    dtrap = torch.nn.functional.pad(dstrap[:, :, :-1], (1, 0), value=0.0)\n    # Add the dgamma path:\n    dtrap += dgamma*dt\n    # grad of sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))\n    dtrap *= trap * torch.nn.functional.sigmoid(-trap_presigmoid)\n    ddt += dgamma*trap\n    return ddt, dtrap\n\ndef compute_dacs_segsum_ref(da: torch.Tensor, # (B, H, S)\n                        chunk_size: int,\n                        ):\n    B, H, S = da.shape\n    nchunks = S // chunk_size\n\n    da_reshaped = da.view(B, H, nchunks, chunk_size)\n    da_cs = torch.cumsum(da_reshaped, dim=-1)\n    da_cs_sum = torch.sum(da_reshaped, dim=-1)\n    da_cs_rev = da_cs_sum[..., None] - da_cs #torch.flip(torch.cumsum(torch.flip(da_reshaped, dims=[-1]), dim=-1), dims=[-1])\n\n    from einops import repeat\n    segsum = repeat(da_reshaped, \"... d -> ... 
d e\", e=chunk_size)\n    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=da_cs.device, dtype=bool), diagonal=-1)\n    segsum = segsum.masked_fill(~mask, 0)\n    segsum = torch.cumsum(segsum, dim=-2)\n\n    return da_cs.view(B, H, S), da_cs_rev.view(B, H, S), segsum\n\n\n# ============================================================================\n# Testing Functions\n# ============================================================================\n\ndef test_bwd_ddt_fused_correctness():\n    \"\"\"Test the fused kernel against reference implementation.\"\"\"\n    print(\"=\" * 70)\n    print(\"Test: basic_correctness\")\n    print(\"=\" * 70)\n\n    B, H, S = 16, 32, 2048\n    chunk_size = 16\n    nchunks = S // chunk_size\n    C = chunk_size\n\n    # Generate random inputs\n    torch.manual_seed(42)\n    dSSdA = torch.randn(B, H, nchunks, C, C, device='cuda', dtype=torch.float32)\n    ddA_cs = torch.randn(B, H, S, device='cuda', dtype=torch.float32)\n    ddA_cs_rev = torch.randn(B, H, S, device='cuda', dtype=torch.float32)\n    dA_cs = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1  # Scale to avoid overflow\n    dA_cs_rev = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1\n    \n    dA_cs_reshape = dA_cs.view(B, H, nchunks, chunk_size)\n    SSdA = dA_cs_reshape[:, :, :, :, None] - dA_cs_reshape[:, :, :, None, :]\n\n    # Reference implementation (separate functions)\n    ddt_ref1 = bwd_segsum_ddt_from_dSSdA_ref(dSSdA.clone(), dA_cs.clone(), chunk_size)\n    ddt_ref2 = bwd_ddt_from_ddA_cs_rev_ref(ddA_cs_rev.clone(), dA_cs_rev.clone(), chunk_size)\n    ddt_ref3 = bwd_ddt_from_ddA_cs_ref(ddA_cs.clone(), dA_cs.clone(), chunk_size)\n    ddt_ref = ddt_ref1 + ddt_ref2 + ddt_ref3 # TODO: \n\n    # Fused Triton implementation\n    ddt_triton = bwd_dadt_fused_triton(\n        dSSdA, SSdA, ddA_cs, ddA_cs_rev, dA_cs, dA_cs_rev, chunk_size\n    ) * -1.4426950408889634 # i.e., -log2(e)\n\n    # Compare\n    max_diff = (ddt_ref - 
ddt_triton).abs().max().item()\n    mean_diff = (ddt_ref - ddt_triton).abs().mean().item()\n\n    print(f\"  Max difference: {max_diff:.2e}\")\n    print(f\"  Mean difference: {mean_diff:.2e}\")\n    passed = max_diff < 1e-4\n    print(f\"  Status: {'PASS' if passed else 'FAIL'}\")\n    print()\n\n    return passed\n\ndef test_dtrap_ddt_correctness():\n    \"\"\"Test the fused kernel against reference implementation.\"\"\"\n    import torch.nn.functional as F\n\n    print(\"=\" * 70)\n    print(\"Test: basic_correctness\")\n    print(\"=\" * 70)\n\n    B, H, S = 16, 32, 2048\n    chunk_size = 16\n    nchunks = S // chunk_size\n    C = chunk_size\n\n    # Generate random inputs\n    torch.manual_seed(42)\n\n    trap = torch.rand(B, H, S, device='cuda', dtype=torch.float16)\n    dt = F.softplus(-3.0 + torch.randn(B, H, S, device='cuda', dtype=torch.float))\n    dfactor = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1\n    dgamma_diag = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1\n\n    # Reference implementation\n    ddt_ref, dtrap_ref = compute_dtrap_ddt_ref(dfactor, dgamma_diag, trap, dt)\n\n    # Triton implementation\n    ddt_triton, dtrap_triton = bwd_dtrap_ddt_triton(\n        trap, dt, dfactor, dgamma_diag, chunk_size\n    )\n\n    # Compare\n    max_diff_ddt = (ddt_ref - ddt_triton).abs().max().item()\n    mean_diff_ddt = (ddt_ref - ddt_triton).abs().mean().item()\n    max_diff_dtrap = (dtrap_ref - dtrap_triton).abs().max().item()\n    mean_diff_dtrap = (dtrap_ref - dtrap_triton).abs().mean().item()\n\n    print(f\"  ddt max difference:   {max_diff_ddt:.2e}\")\n    print(f\"  ddt mean difference:  {mean_diff_ddt:.2e}\")\n    print(f\"  dtrap max difference: {max_diff_dtrap:.2e}\")\n    print(f\"  dtrap mean difference:{mean_diff_dtrap:.2e}\")\n    passed = max(max_diff_ddt, max_diff_dtrap) < 1e-3\n    print(f\"  Status: {'PASS' if passed else 'FAIL'}\")\n    print()\n\n    return passed\n\ndef 
test_dacs_segsum_correctness():\n    import torch.nn.functional as F\n    B, H, S = 16, 32, 2048\n    chunk_size = 16\n    da = -F.softplus(-3.0 + torch.randn(B, H, S, device='cuda', dtype=torch.float))\n\n    da_cs_ref, da_cs_rev_ref, segsum_ref = compute_dacs_segsum_ref(da, chunk_size)\n    da_cs_triton, da_cs_rev_triton, segsum_triton = compute_dacs_segsum_triton(da, chunk_size)\n\n    max_diff_cs = (da_cs_ref - da_cs_triton).abs().max().item()\n    mean_diff_cs = (da_cs_ref - da_cs_triton).abs().mean().item()\n    max_diff_cs_rev = (da_cs_rev_ref - da_cs_rev_triton).abs().max().item()\n    mean_diff_cs_rev = (da_cs_rev_ref - da_cs_rev_triton).abs().mean().item()\n    max_diff_segsum = (segsum_ref - segsum_triton).abs().max().item()\n    mean_diff_segsum = (segsum_ref - segsum_triton).abs().mean().item()\n\n    print(f\"  da_cs max difference:     {max_diff_cs:.2e}\")\n    print(f\"  da_cs mean difference:    {mean_diff_cs:.2e}\")\n    print(f\"  da_cs_rev max difference: {max_diff_cs_rev:.2e}\")\n    print(f\"  da_cs_rev mean difference:{mean_diff_cs_rev:.2e}\")\n    print(f\"  segsum max difference:    {max_diff_segsum:.2e}\")\n    print(f\"  segsum mean difference:   {mean_diff_segsum:.2e}\")\n    passed = max(max_diff_cs, max_diff_cs_rev, max_diff_segsum) < 1e-4\n    print(f\"  Status: {'PASS' if passed else 'FAIL'}\")\n    print()\n\n    return passed\n\n\n# ============================================================================\n# Benchmarking Functions\n# ============================================================================\n\ndef benchmark_bwd_ddt():\n    \"\"\"Benchmark fused kernel against unfused baseline.\"\"\"\n    from triton.testing import do_bench\n\n\n    print(\"=\" * 70)\n    print(\"Benchmark: bwd_ddt_fused\")\n    print(\"=\" * 70)\n\n    B, H, S = 16, 32, 2048\n    chunk_size = 16\n    nchunks = S // chunk_size\n    C = chunk_size\n\n    print(f\"Configuration: B={B}, H={H}, S={S}, chunk_size={chunk_size}\")\n    print()\n\n    
# Setup inputs\n    torch.manual_seed(42)\n    dSSdA = torch.randn(B, H, nchunks, C, C, device='cuda', dtype=torch.float32)\n    ddA_cs = torch.randn(B, H, S, device='cuda', dtype=torch.float32)\n    ddA_cs_rev = torch.randn(B, H, S, device='cuda', dtype=torch.float32)\n    dA_cs = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1\n    dA_cs_rev = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1\n\n    dA_cs_reshape = dA_cs.view(B, H, nchunks, chunk_size)\n    SSdA = dA_cs_reshape[:, :, :, :, None] - dA_cs_reshape[:, :, :, None, :]\n\n    # Benchmark reference (unfused)\n    def ref_impl():\n        ddt1 = bwd_segsum_ddt_from_dSSdA_ref(dSSdA, dA_cs, chunk_size)\n        ddt2 = bwd_ddt_from_ddA_cs_rev_ref(ddA_cs_rev, dA_cs_rev, chunk_size)\n        ddt3 = bwd_ddt_from_ddA_cs_ref(ddA_cs, dA_cs, chunk_size)\n        return ddt1 + ddt2 + ddt3\n\n    # Benchmark individual components\n    ref1_time = do_bench(lambda: bwd_segsum_ddt_from_dSSdA_ref(dSSdA, dA_cs, chunk_size), warmup=25, rep=100)\n    ref2_time = do_bench(lambda: bwd_ddt_from_ddA_cs_rev_ref(ddA_cs_rev, dA_cs_rev, chunk_size), warmup=25, rep=100)\n    ref3_time = do_bench(lambda: bwd_ddt_from_ddA_cs_ref(ddA_cs, dA_cs, chunk_size), warmup=25, rep=100)\n    ref_time = do_bench(ref_impl, warmup=25, rep=100)\n\n    # Benchmark fused\n    def fused_impl():\n        return bwd_dadt_fused_triton(\n            dSSdA, SSdA, ddA_cs, ddA_cs_rev, dA_cs, dA_cs_rev, chunk_size\n        )\n\n    fused_time = do_bench(fused_impl, warmup=25, rep=100)\n\n    print(\"Reference (unfused):\")\n    print(f\"  Function 1 (segsum): {ref1_time:.3f} ms\")\n    print(f\"  Function 2 (cs_rev): {ref2_time:.3f} ms\")\n    print(f\"  Function 3 (cs):     {ref3_time:.3f} ms\")\n    print(f\"  Total:               {ref_time:.3f} ms\")\n    print()\n    print(\"Fused Triton:\")\n    print(f\"  Total:               {fused_time:.3f} ms\")\n    print(f\"  Speedup:             {ref_time / fused_time:.2f}x\")\n    
print()\n\n    return ref_time, fused_time\n\n\ndef benchmark_dacs_segsum():\n    \"\"\"Benchmark dacs+segsum Triton against reference implementation.\"\"\"\n    from triton.testing import do_bench\n    import torch.nn.functional as F\n\n    print(\"=\" * 70)\n    print(\"Benchmark: dacs_segsum\")\n    print(\"=\" * 70)\n\n    B, H, S = 16, 32, 2048\n    chunk_size = 16\n\n    print(f\"Configuration: B={B}, H={H}, S={S}, chunk_size={chunk_size}\")\n    print()\n\n    torch.manual_seed(42)\n    da = F.softplus(-3.0 + torch.randn(B, H, S, device='cuda', dtype=torch.float))\n\n    def ref_impl():\n        return compute_dacs_segsum_ref(da, chunk_size)\n\n    def triton_impl():\n        return compute_dacs_segsum_triton(da, chunk_size)\n\n    ref_time = do_bench(ref_impl, warmup=25, rep=100)\n    triton_time = do_bench(triton_impl, warmup=25, rep=100)\n\n    print(\"Reference:\")\n    print(f\"  Total: {ref_time:.3f} ms\")\n    print(\"Triton:\")\n    print(f\"  Total: {triton_time:.3f} ms\")\n    print(f\"  Speedup: {ref_time / triton_time:.2f}x\")\n    print()\n\n    return ref_time, triton_time\n\n\ndef benchmark_dtrap_ddt():\n    \"\"\"Benchmark dtrap/ddt kernel against reference implementation.\"\"\"\n    from triton.testing import do_bench\n    import torch.nn.functional as F\n\n    print(\"=\" * 70)\n    print(\"Benchmark: bwd_dtrap_ddt\")\n    print(\"=\" * 70)\n\n    B, H, S = 16, 32, 2048\n    chunk_size = 16\n\n    print(f\"Configuration: B={B}, H={H}, S={S}, chunk_size={chunk_size}\")\n    print()\n\n    torch.manual_seed(42)\n    trap = torch.ones(B, H, S, device='cuda', dtype=torch.float16) * 0.5\n    dt = F.softplus(-3.0 + torch.randn(B, H, S, device='cuda', dtype=torch.float))\n    dfactor = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1\n    dgamma_diag = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1\n\n    def ref_impl():\n        return compute_dtrap_ddt_ref(dfactor, dgamma_diag, trap, dt)\n\n    def triton_impl():\n  
      return bwd_dtrap_ddt_triton(trap, dt, dfactor, dgamma_diag, chunk_size)\n\n    ref_time = do_bench(ref_impl, warmup=25, rep=100)\n    triton_time = do_bench(triton_impl, warmup=25, rep=100)\n\n    print(\"Reference:\")\n    print(f\"  Total: {ref_time:.3f} ms\")\n    print(\"Triton:\")\n    print(f\"  Total: {triton_time:.3f} ms\")\n    print(f\"  Speedup: {ref_time / triton_time:.2f}x\")\n    print()\n\n    return ref_time, triton_time\n\n\n# ============================================================================\n# Main execution\n# ============================================================================\nif __name__ == \"__main__\":\n    test_bwd_ddt_fused_correctness()\n    # benchmark_bwd_ddt()\n    test_dtrap_ddt_correctness()\n    # benchmark_dtrap_ddt()\n    # benchmark_dacs_segsum()\n    test_dacs_segsum_correctness()\n    # benchmark_dacs_segsum()"
  },
  {
    "path": "mamba_ssm/ops/triton/mamba3/mamba3_siso_bwd.py",
    "content": "\"\"\"\nMamba-3 Backward Pass Triton Kernels.\n\nCopyright (c) 2026, Dao AI Lab, Goombalab\n\"\"\"\n\nfrom typing import Optional, Tuple\nimport math\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\nimport triton\nimport triton.language as tl\nfrom mamba_ssm.ops.triton.mamba3.utils import cos_approx, sin_approx, sigmoid_approx\n\n# =============================================================================\n# dZ Kernel\n# =============================================================================\n\n@triton.autotune(\n    configs=[\n        triton.Config({\"CHUNK_SIZE\": cs}, num_stages=s, num_warps=w)\n        for cs in [32, 64]\n        for s in [1, 2, 3]\n        for w in [2, 4, 8]\n    ],\n    key=[\"HEADDIM_V\"]\n)\n@triton.jit\ndef mamba3_siso_bwd_kernel_dzdo(\n    # Input tensors\n    DO, Z, O,\n    # Output tensors\n    Dz, DO_scaled,\n    # Strides for DO: (batch, seqlen, nheads, headdim_v)\n    stride_do_batch, stride_do_seqlen, stride_do_head, stride_do_vdim,\n    # Strides for Z: (batch, seqlen, nheads, headdim_v)\n    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_vdim,\n    # Strides for O: (batch, seqlen, nheads, headdim_v)\n    stride_o_batch, stride_o_seqlen, stride_o_head, stride_o_vdim,\n    # Strides for Dz: (batch, seqlen, nheads, headdim_v)\n    stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_vdim,\n    # Strides for DO_scaled: (batch, seqlen, nheads, headdim_v)\n    stride_do_scaled_batch, stride_do_scaled_seqlen, stride_do_scaled_head, stride_do_scaled_vdim,\n    # Dimensions\n    seqlen, headdim_v,\n    # Compile-time constants\n    CHUNK_SIZE: tl.constexpr,\n    HEADDIM_V: tl.constexpr,\n):\n    \"\"\"\n    Backward kernel for Z-gating: computes dZ and scales dO.\n    \n    In the forward pass, output is gated as: out = O * Z * sigmoid(Z) = O * silu(Z)\n    \n    This kernel computes:\n        - dZ = dO * O * sigmoid(Z) * (1 + Z * (1 - sigmoid(Z)))\n        
- dO_scaled = dO * sigmoid(Z) * Z  (for downstream gradient computation)\n    \n    Each program instance processes one (chunk, head, batch) triplet.\n    \"\"\"\n    pid_chunk = tl.program_id(0)\n    pid_head = tl.program_id(1)\n    pid_batch = tl.program_id(2)\n\n    # Compute offsets for this (batch, head) pair\n    do_offset = pid_batch * stride_do_batch + pid_head * stride_do_head\n    z_offset = pid_batch * stride_z_batch + pid_head * stride_z_head\n    o_offset = pid_batch * stride_o_batch + pid_head * stride_o_head\n    dz_offset = pid_batch * stride_dz_batch + pid_head * stride_dz_head\n    do_scaled_offset = pid_batch * stride_do_scaled_batch + pid_head * stride_do_scaled_head\n\n    chunk_start = pid_chunk * CHUNK_SIZE\n    offs_seq = chunk_start + tl.arange(0, CHUNK_SIZE)\n    offs_dim = tl.arange(0, HEADDIM_V)\n    mask = (offs_seq[:, None] < seqlen) & (offs_dim[None, :] < HEADDIM_V)\n\n    # Load dO block: (CHUNK_SIZE, headdim_v)\n    do_ptrs = DO + do_offset + offs_seq[:, None] * stride_do_seqlen + offs_dim[None, :] * stride_do_vdim\n    do_block = tl.load(do_ptrs, mask=mask, other=0.0)\n    # Load Z block: (CHUNK_SIZE, headdim_v)\n    z_ptrs = Z + z_offset + offs_seq[:, None] * stride_z_seqlen + offs_dim[None, :] * stride_z_vdim\n    z_block = tl.load(z_ptrs, mask=mask, other=0.0)\n    # Load O block (pre-gating output): (CHUNK_SIZE, headdim_v)\n    o_ptrs = O + o_offset + offs_seq[:, None] * stride_o_seqlen + offs_dim[None, :] * stride_o_vdim\n    o_block = tl.load(o_ptrs, mask=mask, other=0.0)\n\n    # Compute sigmoid(Z) for gating\n    sigmoid_z = tl.sigmoid(z_block.to(tl.float32))\n    \n    # Scale dO by sigmoid(Z)\n    do_block = do_block * sigmoid_z\n\n    # Compute dZ gradient\n    # d/dZ [O * Z * sigmoid(Z)] = O * sigmoid(Z) * (1 + Z * (1 - sigmoid(Z)))\n    #                           = O * sigmoid(Z) + O * Z * sigmoid(Z) * (1 - sigmoid(Z))\n    dz_block = do_block * o_block * (1 + z_block * (1 - sigmoid_z))\n    \n    # Store dZ\n    
dz_ptrs = Dz + dz_offset + offs_seq[:, None] * stride_dz_seqlen + offs_dim[None, :] * stride_dz_vdim\n    tl.store(dz_ptrs, dz_block, mask=mask)\n\n    # Complete scaling of dO: dO * sigmoid(Z) * Z\n    do_block = do_block * z_block\n    \n    # Store scaled dO for downstream gradient computation\n    do_scaled_ptrs = DO_scaled + do_scaled_offset + offs_seq[:, None] * stride_do_scaled_seqlen + offs_dim[None, :] * stride_do_scaled_vdim\n    tl.store(do_scaled_ptrs, do_block, mask=mask)\n\n\n\ndef compute_dzdo(\n    do: torch.Tensor,\n    z: torch.Tensor,\n    o: torch.Tensor,\n    chunk_size: int = 64,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute Z-gating gradients for Mamba-3 backward pass.\n    \n    When Z-gating is used in the forward pass (out = O * silu(Z)), this function\n    computes the gradient with respect to Z and scales dO for downstream\n    gradient computation.\n    \n    Args:\n        do: Output gradient tensor of shape (batch, seqlen, nheads, headdim_v)\n        z: Gating tensor from forward pass of shape (batch, seqlen, nheads, headdim_v)\n        o: Pre-gating output from forward pass of shape (batch, seqlen, nheads, headdim_v)\n        chunk_size: Chunk size used in forward pass (default: 64)\n    \n    Returns:\n        Tuple containing:\n            - dz: Gradient for Z tensor of shape (batch, seqlen, nheads, headdim_v)\n            - do_scaled: Scaled output gradient of shape (batch, seqlen, nheads, headdim_v)\n                        This should be used as input to subsequent gradient kernels.\n\n    \"\"\"\n    batch, seqlen, nheads, headdim_v = do.shape\n    \n    # Validate inputs\n    assert z is not None and o is not None and do is not None, \"Z, O, and DO tensors must be provided\"\n    assert z.is_cuda and o.is_cuda and do.is_cuda, \"All tensors must be on CUDA\"\n    assert z.shape == do.shape and o.shape == do.shape, f\"Shape mismatch: Z={z.shape}, O={o.shape}, DO={do.shape}\"\n\n    # Ensure contiguity for 
optimal memory access\n    if do.stride(-1) != 1:\n        do = do.contiguous()\n    if z.stride(-1) != 1:\n        z = z.contiguous()\n    if o.stride(-1) != 1:\n        o = o.contiguous()\n\n    # Allocate output tensors\n    dz = torch.empty_like(z, dtype=do.dtype)\n    do_scaled = torch.empty_like(do, dtype=do.dtype)\n\n    # Round up head dimension to power of 2 for efficient loading\n    HEADDIM_V = triton.next_power_of_2(headdim_v)\n\n    # Launch kernel: grid = (nchunks, nheads, batch)\n    # CHUNK_SIZE is autotuned, so nchunks is computed per config by the grid callable\n    def grid(META):\n        return (triton.cdiv(seqlen, META[\"CHUNK_SIZE\"]), nheads, batch)\n    \n    mamba3_siso_bwd_kernel_dzdo[grid](\n        do, z, o,\n        dz, do_scaled,\n        # DO strides\n        do.stride(0), do.stride(1), do.stride(2), do.stride(3),\n        # Z strides\n        z.stride(0), z.stride(1), z.stride(2), z.stride(3),\n        # O strides\n        o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n        # Dz strides\n        dz.stride(0), dz.stride(1), dz.stride(2), dz.stride(3),\n        # DO_scaled strides\n        do_scaled.stride(0), do_scaled.stride(1), do_scaled.stride(2), do_scaled.stride(3),\n        # Dimensions\n        seqlen, headdim_v,\n        # Compile-time constants\n        HEADDIM_V=HEADDIM_V,\n    )\n\n    return dz, do_scaled\n\n\n# =============================================================================\n# dQKV Kernel\n# =============================================================================\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_stages=s, num_warps=w)\n        for s in [1, 2, 3]\n        for w in [2, 4, 8]\n    ],\n    key=[\"CHUNK_SIZE\", \"HEADDIM_QK\", \"HEADDIM_V\", \"IS_VARLEN\"]\n)\n@triton.jit\ndef mamba3_siso_bwd_kernel_dqkv(\n    # Input tensors\n    Q, K, V, DA_CS, DA_CS_SUM, QK_Dot, D, SSM_States, dO, d_OSSM_State, Cu_Seqlens, # dO is scaled with Z\n    # Output tensors\n    dQ, dK, dV, 
dADT, dQK_Dot, dD, d_ISSM_State, # dQK_Dot is scaled with scale\n    # Strides for Inputs\n    # Strides for Q: (batch, seqlen, nheads_qk, HEADDIM_QK)\n    stride_q_batch, stride_q_seqlen, stride_q_head, stride_q_qkdim,\n    # Strides for K: (batch, seqlen, nheads_qk, HEADDIM_QK)\n    stride_k_batch, stride_k_seqlen, stride_k_head, stride_k_qkdim,\n    # Strides for V: (batch, seqlen, nheads, HEADDIM_V)\n    stride_v_batch, stride_v_seqlen, stride_v_head, stride_v_vdim,\n    # Strides for DA_CS: (batch, nheads, seqlen)\n    stride_da_cs_batch, stride_da_cs_head, stride_da_cs_seqlen,\n    # Strides for DA_CS_SUM: (batch, nheads, nchunks)\n    stride_da_cs_sum_batch, stride_da_cs_sum_head, stride_da_cs_sum_seqlen,\n    # Strides for QK (QK dot products): (batch, nheads, nchunks*CHUNK_SIZE)\n    stride_qk_dot_batch, stride_qk_dot_head, stride_qk_dot_seqlen,\n    # Strides for D: (nheads,)\n    stride_d_head,\n    # Strides for SSM_States: (batch, nheads, HEADDIM_V, nchunks*HEADDIM_QK)\n    stride_ssm_states_batch, stride_ssm_states_head, stride_ssm_states_vdim, stride_ssm_states_qkdim,\n    # Strides for dO: (batch, seqlen, nheads, HEADDIM_V)\n    stride_do_batch, stride_do_seqlen, stride_do_head, stride_do_vdim,\n    # Strides for d_OSSM_State: (num_sequences, nheads, HEADDIM_V, HEADDIM_QK)\n    stride_d_ossm_state_batch, stride_d_ossm_state_head, stride_d_ossm_state_vdim, stride_d_ossm_state_qkdim,\n    # Strides for Cu_Seqlens: (num_sequences + 1,)\n    stride_cu_seqlen,\n    # Strides for Outputs\n    # Strides for dQ: (batch, seqlen, nheads, HEADDIM_QK)\n    stride_dq_batch, stride_dq_seqlen, stride_dq_head, stride_dq_qkdim,\n    # Strides for dK: (batch, seqlen, nheads, HEADDIM_QK)\n    stride_dk_batch, stride_dk_seqlen, stride_dk_head, stride_dk_qkdim,\n    # Strides for dV: (batch, seqlen, nheads, HEADDIM_V)\n    stride_dv_batch, stride_dv_seqlen, stride_dv_head, stride_dv_vdim,\n    # Strides for dAdt: (batch, nheads, seqlen)\n    stride_dadt_batch, 
stride_dadt_head, stride_dadt_seqlen,\n    # Strides for dQK_dot: (batch, nheads, seqlen)\n    stride_dQK_dot_batch, stride_dQK_dot_head, stride_dQK_dot_seqlen,\n    # Strides for dD: (num_sequences, nheads)\n    stride_dd_batch, stride_dd_head,\n    # Strides for d_ISSM_State: (num_sequences, nheads, HEADDIM_V, HEADDIM_QK)\n    stride_d_issm_state_batch, stride_d_issm_state_head, stride_d_issm_state_vdim, stride_d_issm_state_qkdim,\n    # Dimensions\n    seqlen, nheads_qk, headdim_qk, headdim_v,\n    CHUNK_SIZE: tl.constexpr,\n    HEADDIM_QK: tl.constexpr,\n    HEADDIM_V: tl.constexpr,\n    RECOMPUTE_MASK: tl.constexpr,\n    HAS_D_OSSM_STATE: tl.constexpr,\n    RETURN_D_ISSM_STATE: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    \"\"\"\n    Backward kernel for Mamba-3 attention mechanism.\n    \n    Each program instance handles one (head, batch/seq) pair and iterates through\n    all chunks in reverse order. This reverse iteration is necessary because\n    state gradients flow backward through the sequence.\n    \n    The kernel computes:\n        - dQ, dK: Gradients for query/key from both intra-chunk attention and inter-chunk states\n        - dV: Gradient for values\n        - dADT: Gradient for the decay parameter (A * dt)\n        - dQK_Dot: Gradient for the QK dot product term\n        - dD: Gradient for the skip connection (if present)\n        - d_ISSM_State: Gradient for the input SSM state (if requested)\n\n    Grid:\n        - Normal mode: (nheads, batch)\n        - Varlen mode: (nheads, num_sequences)\n    \"\"\"\n    # ==================== Program Indexing ====================\n    pid_head = tl.program_id(0)\n    pid_batch = tl.program_id(1)\n\n    if IS_VARLEN:\n        pid_seq = pid_batch\n        pid_batch = 0\n        cu_seqlen = tl.load(Cu_Seqlens + pid_seq * stride_cu_seqlen).to(tl.int32)\n        cu_seqlen_next = tl.load(Cu_Seqlens + (pid_seq + 1) * stride_cu_seqlen).to(tl.int32)\n        seqlen = cu_seqlen_next - cu_seqlen\n        cu_chunks = pid_seq + 
cu_seqlen // CHUNK_SIZE\n    else:\n        cu_seqlen = 0\n        cu_chunks = 0\n        pid_seq = 0\n\n    # Compute Q/K head index for GQA (grouped query attention)\n    # Multiple output heads may share the same Q/K head\n    nheads = tl.num_programs(0)\n    head_idx_qk = pid_head // (nheads // nheads_qk)\n\n    # Input Pointer Offsets\n    q_offset = pid_batch * stride_q_batch + head_idx_qk * stride_q_head + IS_VARLEN * cu_seqlen * stride_q_seqlen\n    k_offset = pid_batch * stride_k_batch + head_idx_qk * stride_k_head + IS_VARLEN * cu_seqlen * stride_k_seqlen\n    v_offset = pid_batch * stride_v_batch + pid_head * stride_v_head + IS_VARLEN * cu_seqlen * stride_v_seqlen\n    da_cs_offset = pid_batch * stride_da_cs_batch + pid_head * stride_da_cs_head + IS_VARLEN * cu_seqlen * stride_da_cs_seqlen\n    da_cs_sum_offset = pid_batch * stride_da_cs_sum_batch + pid_head * stride_da_cs_sum_head + IS_VARLEN * cu_chunks * stride_da_cs_sum_seqlen\n    qk_dot_offset = pid_batch * stride_qk_dot_batch + pid_head * stride_qk_dot_head + IS_VARLEN * cu_seqlen * stride_qk_dot_seqlen\n    ssm_states_offset = pid_batch * stride_ssm_states_batch + pid_head * stride_ssm_states_head + IS_VARLEN * cu_chunks * HEADDIM_QK * stride_ssm_states_qkdim\n    do_offset = pid_batch * stride_do_batch + pid_head * stride_do_head + IS_VARLEN * cu_seqlen * stride_do_seqlen\n    if HAS_D_OSSM_STATE:\n        d_ossm_state_offset = (pid_batch + IS_VARLEN * pid_seq) * stride_d_ossm_state_batch + pid_head * stride_d_ossm_state_head\n\n    # Load skip connection value D if present\n    if D is not None:\n        D_offset = pid_head * stride_d_head\n        D_val = tl.load(D + D_offset)\n\n    # Output Pointer Offsets\n    dq_offset = pid_batch * stride_dq_batch + pid_head * stride_dq_head + IS_VARLEN * cu_seqlen * stride_dq_seqlen\n    dk_offset = pid_batch * stride_dk_batch + pid_head * stride_dk_head + IS_VARLEN * cu_seqlen * stride_dk_seqlen\n    dv_offset = pid_batch * stride_dv_batch + pid_head * 
stride_dv_head + IS_VARLEN * cu_seqlen * stride_dv_seqlen\n    dadt_offset = pid_batch * stride_dadt_batch + pid_head * stride_dadt_head + IS_VARLEN * cu_seqlen * stride_dadt_seqlen\n    dQK_dot_offset = pid_batch * stride_dQK_dot_batch + pid_head * stride_dQK_dot_head + IS_VARLEN * cu_seqlen * stride_dQK_dot_seqlen\n    \n    if D is not None:\n        dD_offset = pid_head * stride_dd_head + pid_batch * stride_dd_batch + IS_VARLEN * pid_seq * stride_dd_batch\n        dD_acc = tl.zeros([1], dtype=tl.float32)\n    \n    if RETURN_D_ISSM_STATE:\n        d_issm_state_offset = (pid_batch + IS_VARLEN * pid_seq) * stride_d_issm_state_batch + pid_head * stride_d_issm_state_head\n\n    # Accumulates gradients flowing backward through states across chunks\n    if HAS_D_OSSM_STATE:\n        d_ssm_ptrs =  d_OSSM_State + d_ossm_state_offset + tl.arange(0, HEADDIM_V)[:, None] * stride_d_ossm_state_vdim + tl.arange(0, HEADDIM_QK)[None, :] * stride_d_ossm_state_qkdim\n        d_ssm_states_mask = (tl.arange(0, HEADDIM_V)[:, None] < headdim_v) & (tl.arange(0, HEADDIM_QK)[None, :] < headdim_qk)\n        d_ssm_states_acc = tl.load(d_ssm_ptrs, mask=d_ssm_states_mask, other=0.0).to(tl.float32)\n    else:\n        d_ssm_states_acc = tl.zeros([HEADDIM_V, HEADDIM_QK], dtype=tl.float32)\n\n    num_chunks = tl.cdiv(seqlen, CHUNK_SIZE)\n\n    #  TMA Descriptors for Efficient Memory Access \n    q_desc = tl.make_tensor_descriptor(\n        Q + q_offset,\n        shape=[seqlen, headdim_qk],\n        strides=[stride_q_seqlen, stride_q_qkdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_QK],\n    )\n    k_desc = tl.make_tensor_descriptor(\n        K + k_offset,\n        shape=[seqlen, headdim_qk],\n        strides=[stride_k_seqlen, stride_k_qkdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_QK],\n    )\n    v_desc = tl.make_tensor_descriptor(\n        V + v_offset,\n        shape=[seqlen, headdim_v],\n        strides=[stride_v_seqlen, stride_v_vdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_V],\n  
  )\n    ssm_states_desc = tl.make_tensor_descriptor(\n        SSM_States + ssm_states_offset,\n        shape=[headdim_v, num_chunks * headdim_qk],\n        strides=[stride_ssm_states_vdim, stride_ssm_states_qkdim],\n        block_shape=[HEADDIM_V, HEADDIM_QK],\n    )\n    do_desc = tl.make_tensor_descriptor(\n        dO + do_offset,\n        shape=[seqlen, headdim_v],\n        strides=[stride_do_seqlen, stride_do_vdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_V],\n    )\n    dq_desc = tl.make_tensor_descriptor(\n        dQ + dq_offset,\n        shape=[seqlen, headdim_qk],\n        strides=[stride_dq_seqlen, stride_dq_qkdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_QK],\n    )\n    dk_desc = tl.make_tensor_descriptor(\n        dK + dk_offset,\n        shape=[seqlen, headdim_qk],\n        strides=[stride_dk_seqlen, stride_dk_qkdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_QK],\n    )\n    dv_desc = tl.make_tensor_descriptor(\n        dV + dv_offset,\n        shape=[seqlen, headdim_v],\n        strides=[stride_dv_seqlen, stride_dv_vdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_V],\n    )\n\n    for chunk_idx_loop in range(num_chunks):\n        chunk_idx = num_chunks - 1 - chunk_idx_loop  # Reverse order for backward pass\n        chunk_start = chunk_idx * CHUNK_SIZE\n\n        # Sequence-length mask for non-TMA loads/stores\n        offs_cs = chunk_start + tl.arange(0, CHUNK_SIZE)\n        seq_mask = offs_cs < seqlen\n\n        # ============================================================\n        # Load Decay Values\n        # We load these first to overlap computation with TMA loads\n        # ============================================================\n        da_cs_ptrs = DA_CS + da_cs_offset + offs_cs * stride_da_cs_seqlen\n        da_cs = tl.load(da_cs_ptrs, mask=seq_mask, other=0.0)  # Cumulative decay within chunk: (CHUNK_SIZE,)\n\n        da_cs_sum_ptrs = DA_CS_SUM + da_cs_sum_offset + chunk_idx * stride_da_cs_sum_seqlen\n        da_cs_chunk_sum = 
tl.load(da_cs_sum_ptrs)  # Total decay for this chunk: scalar\n\n        # ============================================================\n        # Load Q, K, V, dO, SSM_States via TMA\n        # ============================================================\n        do_block = do_desc.load([chunk_start, 0])  # (CHUNK_SIZE, HEADDIM_V)\n        v_block = v_desc.load([chunk_start, 0])    # (CHUNK_SIZE, HEADDIM_V)\n        q_block = q_desc.load([chunk_start, 0])    # (CHUNK_SIZE, HEADDIM_QK)\n        k_block = k_desc.load([chunk_start, 0])    # (CHUNK_SIZE, HEADDIM_QK)\n        ssm_states_block = ssm_states_desc.load([0, chunk_idx * headdim_qk])  # (HEADDIM_V, HEADDIM_QK)\n\n        # ============================================================\n        # Compute Decay Scaling Factors\n        # ============================================================\n        # Reverse cumsum: how much decay from position i to end of chunk\n        da_cs_rev = da_cs_chunk_sum - da_cs\n        exp_da_cs_rev = tl.math.exp2(da_cs_rev)  # For scaling inter-chunk contributions\n        exp_da_cs = tl.math.exp2(da_cs)          # For scaling intra-chunk contributions\n\n        # Compute strictly causal mask with exponential decay (this is L^T)\n        if not RECOMPUTE_MASK:\n            causal_decay_mask = tl.where(\n                tl.arange(0, CHUNK_SIZE)[None, :] > tl.arange(0, CHUNK_SIZE)[:, None],\n                tl.math.exp2(tl.minimum(da_cs[None, :] - da_cs[:, None], 0.0)),\n                0.0\n            )\n\n        # ============================================================\n        # Compute dADT Gradient (Part 1): From Intra-chunk Attention\n        # This is register-heavy so we compute it early before spilling\n        # ============================================================\n        # Gradient contribution from (QK^T ⊙ L) V term\n        dAinv = tl.dot(v_block, tl.trans(do_block))  # V @ dO^T\n        if RECOMPUTE_MASK:\n            dAinv *= 
tl.math.exp2(tl.minimum(da_cs[None, :] - da_cs[:, None], 0.0))\n            dAinv = tl.where(\n                tl.arange(0, CHUNK_SIZE)[None, :] > tl.arange(0, CHUNK_SIZE)[:, None],\n                dAinv,\n                0.0\n            )\n        else:\n            dAinv *= causal_decay_mask\n        dAinv *= tl.dot(k_block, tl.trans(q_block))  # Element-wise with K @ Q^T\n        dM_rev_vector = tl.sum(dAinv, axis=0) - tl.sum(dAinv, axis=1)  # (CHUNK_SIZE,)\n\n        # ============================================================\n        # Compute dK: Key Gradient\n        # dK = (V @ dO^T ⊙ mask) @ Q + V @ dStates * scale\n        # ============================================================\n        # Intra-chunk: dP^T @ Q where dP = dO @ V^T ⊙ mask\n        dp_t_block = tl.dot(v_block, tl.trans(do_block))  # V @ dO^T: (CHUNK_SIZE, CHUNK_SIZE)\n        if RECOMPUTE_MASK:\n            dp_t_block *= tl.math.exp2(tl.minimum(da_cs[None, :] - da_cs[:, None], 0.0))\n            dp_t_block = tl.where(\n                tl.arange(0, CHUNK_SIZE)[None, :] > tl.arange(0, CHUNK_SIZE)[:, None],\n                dp_t_block,\n                0.0\n            )\n        else:\n            dp_t_block *= causal_decay_mask\n\n        acc_dk = tl.dot(dp_t_block.to(q_block.dtype), q_block)  # (CHUNK_SIZE, HEADDIM_QK)\n\n        # Inter-chunk: gradient flowing through accumulated states\n        acc_dk += tl.dot(v_block, d_ssm_states_acc.to(v_block.dtype)) * exp_da_cs_rev[:, None]\n\n        dk_desc.store([chunk_start, 0], acc_dk)\n\n        # ============================================================\n        # Compute dQ: Query Gradient\n        # dQ = (V @ dO^T ⊙ mask)^T @ K + dO @ States * scale\n        # ============================================================\n        # Intra-chunk: S^T @ K where S = V @ dO^T ⊙ mask\n        s_block = tl.dot(v_block, tl.trans(do_block))  # (CHUNK_SIZE, CHUNK_SIZE)\n        if RECOMPUTE_MASK:\n            s_block *= 
tl.math.exp2(tl.minimum(da_cs[None, :] - da_cs[:, None], 0.0))\n            s_block = tl.where(\n                tl.arange(0, CHUNK_SIZE)[None, :] > tl.arange(0, CHUNK_SIZE)[:, None],\n                s_block,\n                0.0\n            )\n        else:\n            s_block *= causal_decay_mask\n\n        acc_dq = tl.dot(tl.trans(s_block).to(k_block.dtype), k_block)  # (CHUNK_SIZE, HEADDIM_QK)\n\n        # Inter-chunk: gradient through states from previous chunks\n        acc_dq += tl.dot(do_block, ssm_states_block) * exp_da_cs[:, None]\n\n        dq_desc.store([chunk_start, 0], acc_dq)\n\n        # ============================================================\n        # Compute dV: Value Gradient\n        # dV = (K @ Q^T ⊙ mask) @ dO + K @ dStates^T * scale + dO * (D + qk_dot)\n        # ============================================================\n        # Intra-chunk: P^T @ dO where P = Q @ K^T ⊙ mask\n        p_t_block = tl.dot(k_block, tl.trans(q_block))  # K @ Q^T: (CHUNK_SIZE, CHUNK_SIZE)\n        if RECOMPUTE_MASK:\n            p_t_block *= tl.math.exp2(tl.minimum(da_cs[None, :] - da_cs[:, None], 0.0))\n            p_t_block = tl.where(\n                tl.arange(0, CHUNK_SIZE)[None, :] > tl.arange(0, CHUNK_SIZE)[:, None],\n                p_t_block,\n                0.0\n            )\n        else:\n            p_t_block *= causal_decay_mask\n\n        acc_dv = tl.dot(p_t_block.to(do_block.dtype), do_block)  # (CHUNK_SIZE, HEADDIM_V)\n\n        # Inter-chunk: gradient through states\n        acc_dv += tl.dot(k_block, tl.trans(d_ssm_states_acc).to(k_block.dtype)) * exp_da_cs_rev[:, None]\n\n        # Skip connection gradient contribution\n        # Load dO again with volatile to avoid cache conflicts\n        dO_reloaded = tl.load(\n            dO + do_offset + offs_cs[:, None] * stride_do_seqlen +\n            tl.arange(0, HEADDIM_V)[None, :] * stride_do_vdim,\n            mask=seq_mask[:, None] & (tl.arange(0, HEADDIM_V)[None, :] < headdim_v),\n   
         other=0.0,\n            volatile=True\n        )\n\n        qk_dot = tl.load(QK_Dot + qk_dot_offset + offs_cs * stride_qk_dot_seqlen, mask=seq_mask, other=0.0)\n        if D is not None:\n            acc_dv += dO_reloaded * (D_val + qk_dot[:, None])\n        else:\n            acc_dv += dO_reloaded * qk_dot[:, None]\n\n        dv_desc.store([chunk_start, 0], acc_dv)\n\n        # ============================================================\n        # Compute dQK_Dot and dD: Skip Connection Gradients\n        # ============================================================\n        v_block_reloaded = tl.load(\n            V + v_offset + offs_cs[:, None] * stride_v_seqlen +\n            tl.arange(0, HEADDIM_V)[None, :] * stride_v_vdim,\n            mask=seq_mask[:, None] & (tl.arange(0, HEADDIM_V)[None, :] < headdim_v),\n            other=0.0,\n            volatile=True\n        )\n\n        # dQK_dot = sum_v(dO * V) for each position\n        dQK_dot_block = tl.dot(\n            dO_reloaded * v_block_reloaded,\n            tl.full([HEADDIM_V, 1], 1, dtype=dO_reloaded.dtype)\n        )\n\n        tl.store(\n            dQK_Dot + dQK_dot_offset + offs_cs * stride_dQK_dot_seqlen,\n            dQK_dot_block.reshape(CHUNK_SIZE),\n            mask=seq_mask\n        )\n\n        # Accumulate dD gradient\n        if D is not None:\n            dD_acc += tl.dot(\n                tl.full([1, CHUNK_SIZE], 1, dtype=tl.float32),\n                dQK_dot_block\n            ).reshape(1)\n\n        # ============================================================\n        # Compute dADT Gradient (Part 2): From Inter-chunk States\n        # ============================================================\n        # Gradient from Q @ States^T term\n        QS = tl.dot(q_block, tl.trans(ssm_states_block))  # (CHUNK_SIZE, HEADDIM_V)\n        dM_rev_vector += tl.sum(QS * dO_reloaded, axis=1) * exp_da_cs  # (CHUNK_SIZE,)\n\n        # 
============================================================\n        # Compute dADT Gradient (Part 3): From State Accumulation\n        # ============================================================\n        # Gradient flowing through d_ssm_states_acc @ SSM_States\n        SSM_States_ptrs = (SSM_States + ssm_states_offset +\n                tl.arange(0, HEADDIM_V)[:, None] * stride_ssm_states_vdim +\n                (chunk_idx * headdim_qk + tl.arange(0, HEADDIM_QK)[None, :]) * stride_ssm_states_qkdim)\n        SSM_States_mask = (tl.arange(0, HEADDIM_V)[:, None] < headdim_v) & ((chunk_idx * headdim_qk + tl.arange(0, HEADDIM_QK)[None, :]) < num_chunks * headdim_qk)\n        \n        SSM_States_reloaded = tl.load(SSM_States_ptrs, volatile=True, mask=SSM_States_mask)  # (HEADDIM_V, HEADDIM_QK)\n        dM_scalar = tl.sum(SSM_States_reloaded * d_ssm_states_acc) * tl.math.exp2(da_cs_chunk_sum)\n\n        # ============================================================\n        # Compute dADT Gradient (Part 4): From K @ dStates\n        # ============================================================\n        dSK = tl.dot(k_block, tl.trans(d_ssm_states_acc).to(k_block.dtype))  # (CHUNK_SIZE, HEADDIM_V)\n        dM_vector = tl.sum(dSK * v_block_reloaded, axis=1) * exp_da_cs_rev  # (CHUNK_SIZE,)\n\n        # ============================================================\n        # Combine dADT Gradient Components via Reverse Cumsum\n        # ============================================================\n        dM_rev_vector += (tl.sum(dM_rev_vector) + dM_scalar) + tl.cumsum(dM_vector - dM_rev_vector) - dM_vector\n\n        # Store dADT\n        dadt_ptrs = dADT + dadt_offset + offs_cs * stride_dadt_seqlen\n        tl.store(dadt_ptrs, dM_rev_vector, mask=seq_mask)\n\n        # ============================================================\n        # Accumulate State Gradients for Previous Chunks\n        # ============================================================\n        
dO_reloaded *= exp_da_cs[:, None]\n        d_ssm_states_acc = (tl.math.exp2(da_cs_chunk_sum) * d_ssm_states_acc +\n                       tl.dot(tl.trans(dO_reloaded).to(q_block.dtype), q_block))\n\n    # Store Final dD Gradient \n    if D is not None:\n        tl.store(dD + dD_offset + tl.arange(0, 1), dD_acc)\n\n    # Store d_ISSM_State \n    if RETURN_D_ISSM_STATE:\n        d_ISSM_State_ptrs = d_ISSM_State + d_issm_state_offset + tl.arange(0, HEADDIM_V)[:, None] * stride_d_issm_state_vdim + tl.arange(0, HEADDIM_QK)[None, :] * stride_d_issm_state_qkdim\n        d_ISSM_State_mask = (tl.arange(0, HEADDIM_V)[:, None] < headdim_v) & (tl.arange(0, HEADDIM_QK)[None, :] < headdim_qk)\n        tl.store(d_ISSM_State_ptrs, d_ssm_states_acc, mask=d_ISSM_State_mask)\n\n\ndef compute_dqkv(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    da_cs: torch.Tensor,\n    da_cs_sum: torch.Tensor,\n    qk_dot: torch.Tensor,\n    SSM_States: torch.Tensor,\n    do: torch.Tensor,\n    d_ossm_state: Optional[torch.Tensor] = None,\n    d_ov_state: Optional[torch.Tensor] = None,\n    D: Optional[torch.Tensor] = None,\n    chunk_size: int = 64,\n    has_input_state: bool = False,\n    Cu_Seqlens: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:\n    \"\"\"\n    Compute gradients dQ_mid, dK_mid, dV, dADT, dQK_dot, dD, d_issm_state for Mamba-3 backward pass.\n    \n    This kernel operates on the rotated/scaled Q and K tensors (Q_mid, K_mid from forward).\n    \n    Args:\n        q: Rotated query tensor Q_mid (batch, seqlen, nheads_qk, headdim_qk)\n        k: Rotated+scaled key tensor K_mid (batch, seqlen, nheads_qk, headdim_qk)\n        v: Value tensor (batch, seqlen, nheads, headdim_v)\n        da_cs: Cumulative decay per chunk (batch, nheads, seqlen)\n        da_cs_sum: Sum of decay per chunk (batch, nheads, nchunks)\n        qk_dot: QK dot products from 
forward (batch, nheads, seqlen)\n        SSM_States: SSM states from forward pass (batch, nheads, headdim_v, nchunks * headdim_qk)\n        do: Output gradient, possibly scaled by Z (batch, seqlen, nheads, headdim_v)\n        d_ossm_state: Gradient of output SSM states (num_sequences, nheads, headdim_v, headdim_qk)\n        d_ov_state: Gradient of output V state (num_sequences, nheads, headdim_v) - added to last token of dV\n        D: Optional skip connection weight (nheads,)\n        chunk_size: Chunk size (default: 64)\n        has_input_state: Whether to compute gradient for input states\n    \n    Returns:\n        Tuple of (dQ_mid, dK_mid, dV, dADT, dQK_dot, dD, d_issm_state)\n        where d_issm_state is None if has_input_state=False\n    \"\"\"\n    batch, seqlen, nheads_qk, headdim_qk = q.shape\n    _, _, nheads, headdim_v = v.shape\n    is_varlen = Cu_Seqlens is not None\n    \n    if is_varlen:\n        num_sequences = Cu_Seqlens.shape[0] - 1\n        assert batch == 1\n        nchunks = num_sequences + seqlen // chunk_size\n    else:\n        num_sequences = batch\n        nchunks = (seqlen + chunk_size - 1) // chunk_size\n\n    assert nheads % nheads_qk == 0, \"nheads must be divisible by nheads_qk (for GQA support)\"\n    assert q.is_cuda and k.is_cuda and v.is_cuda and da_cs.is_cuda and da_cs_sum.is_cuda and do.is_cuda, \"All tensors must be on CUDA\"\n\n    assert k.shape == q.shape\n    assert v.shape == (batch, seqlen, nheads, headdim_v)\n    assert da_cs.shape == (batch, nheads, seqlen)\n    assert da_cs_sum.shape == (batch, nheads, nchunks)\n    assert qk_dot.shape == (batch, nheads, seqlen)\n    assert SSM_States.shape == (batch, nheads, headdim_v, nchunks * headdim_qk)\n    assert do.shape == (batch, seqlen, nheads, headdim_v)\n    assert d_ossm_state is None or d_ossm_state.shape == (num_sequences, nheads, headdim_v, headdim_qk)\n    assert d_ov_state is None or d_ov_state.shape == (num_sequences, nheads, headdim_v)\n    if D is not None:\n  
      assert D.shape == (nheads,)\n    \n    # Ensure all tensors are contiguous for optimal memory access\n    # Check if tensors have expected strides (innermost dimension stride = 1)\n    if q.stride(-1) != 1:\n        q = q.contiguous()\n    if k.stride(-1) != 1:\n        k = k.contiguous()\n    if v.stride(-1) != 1:\n        v = v.contiguous()\n    if da_cs.stride(-1) != 1:\n        da_cs = da_cs.contiguous()\n    if da_cs_sum.stride(-1) != 1:\n        da_cs_sum = da_cs_sum.contiguous()\n    if qk_dot.stride(-1) != 1:\n        qk_dot = qk_dot.contiguous()\n    if SSM_States.stride(-1) != 1:\n        SSM_States = SSM_States.contiguous()\n    if do.stride(-1) != 1:\n        do = do.contiguous()\n    if D is not None and D.stride(-1) != 1:\n        D = D.contiguous()\n    if d_ossm_state is not None and d_ossm_state.stride(-1) != 1:\n        d_ossm_state = d_ossm_state.contiguous()\n    if d_ov_state is not None and d_ov_state.stride(-1) != 1:\n        d_ov_state = d_ov_state.contiguous()\n    \n    # Allocate output tensors\n    dq = torch.empty((batch, seqlen, nheads, headdim_qk), dtype=q.dtype, device=q.device)\n    dk = torch.empty((batch, seqlen, nheads, headdim_qk), dtype=k.dtype, device=k.device)\n    dv = torch.empty_like(v)\n    dAdt = torch.empty_like(da_cs)\n    dQK = torch.empty_like(da_cs)\n    dD = torch.empty((num_sequences, nheads), dtype=torch.float32, device=q.device) if D is not None else None\n    d_issm_state = torch.empty((num_sequences, nheads, headdim_v, headdim_qk), dtype=torch.float32, device=q.device) if has_input_state else None\n    \n    # Round up head dimensions to power of 2 for efficient loading\n    HEADDIM_QK = triton.next_power_of_2(headdim_qk)\n    HEADDIM_V = triton.next_power_of_2(headdim_v)\n    \n    # Grid: each program handles one (head, batch/num_sequences) pair\n    if is_varlen:\n        grid = (nheads, num_sequences)\n    else:\n        grid = (nheads, batch)\n    \n    # Launch kernel\n    
mamba3_siso_bwd_kernel_dqkv[grid](\n        q, k, v, da_cs, da_cs_sum, qk_dot, D, SSM_States, do, d_ossm_state, Cu_Seqlens,\n        dq, dk, dv, dAdt, dQK, dD, d_issm_state,\n        # Q strides\n        q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n        # K strides\n        k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n        # V strides\n        v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n        # DA_CS strides\n        da_cs.stride(0), da_cs.stride(1), da_cs.stride(2),\n        # DA_CS_SUM strides\n        da_cs_sum.stride(0), da_cs_sum.stride(1), da_cs_sum.stride(2),\n        # QK_Dot strides\n        qk_dot.stride(0), qk_dot.stride(1), qk_dot.stride(2),\n        # D stride\n        D.stride(0) if D is not None else 0,\n        # SSM_States strides: (batch, nheads, headdim_v, nchunks*headdim_qk)\n        SSM_States.stride(0), SSM_States.stride(1), SSM_States.stride(2),\n        SSM_States.stride(3),\n        # dO strides\n        do.stride(0), do.stride(1), do.stride(2), do.stride(3),\n        # d_ossm_state strides\n        d_ossm_state.stride(0) if d_ossm_state is not None else 0,\n        d_ossm_state.stride(1) if d_ossm_state is not None else 0,\n        d_ossm_state.stride(2) if d_ossm_state is not None else 0,\n        d_ossm_state.stride(3) if d_ossm_state is not None else 0,\n        # Cu_Seqlens strides\n        Cu_Seqlens.stride(0) if Cu_Seqlens is not None else 0,\n        # dQ strides\n        dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),\n        # dK strides\n        dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),\n        # dV strides\n        dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),\n        # dAdt strides\n        dAdt.stride(0), dAdt.stride(1), dAdt.stride(2),\n        # dQK strides\n        dQK.stride(0), dQK.stride(1), dQK.stride(2),\n        # dD strides\n        dD.stride(0) if D is not None else 0,\n        dD.stride(1) if D is not None else 0,\n        # d_issm_state strides\n  
      d_issm_state.stride(0) if d_issm_state is not None else 0,\n        d_issm_state.stride(1) if d_issm_state is not None else 0,\n        d_issm_state.stride(2) if d_issm_state is not None else 0,\n        d_issm_state.stride(3) if d_issm_state is not None else 0,\n        # Dimensions\n        seqlen, nheads_qk, headdim_qk, headdim_v,\n        # Compile-time constants\n        CHUNK_SIZE=chunk_size,\n        HEADDIM_QK=HEADDIM_QK,\n        HEADDIM_V=HEADDIM_V,\n        RECOMPUTE_MASK=False,\n        HAS_D_OSSM_STATE=d_ossm_state is not None,\n        RETURN_D_ISSM_STATE=has_input_state,\n        IS_VARLEN=is_varlen,\n    )\n\n    # Add output V state gradients to the last token\n    if d_ov_state is not None:\n        if is_varlen:\n            last_token_idx = Cu_Seqlens[1:] - 1\n            dv[0, last_token_idx] += d_ov_state\n        else:\n            dv[:, -1, :, :] += d_ov_state\n\n    dD = dD.sum(dim=0) if dD is not None else None\n    return dq, dk, dv, dAdt, dQK, dD, d_issm_state\n\n\n# =============================================================================\n#  d Rotary+Bias Kernel\n# =============================================================================\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_stages=s, num_warps=w)\n        for s in [1, 2, 3]\n        for w in [2, 4, 8]\n    ],\n    key=[\"CHUNK_SIZE\", \"BLOCK_HEADDIM_QK\", \"HEADDIM_QK\", \"GQA_RATIO\"]\n)\n@triton.jit\ndef mamba3_siso_bwd_kernel_rotary_bias_angles(\n    # Input tensors\n    Q, K, Scale, Gamma, Q_bias, K_bias, Angles, dQ_in, dK_in, dQK,\n    # Output tensors\n    dQ, dK, dAngles, dScale, dGamma, dQ_bias, dK_bias,\n    # Strides for inputs -------------------------------------------------------    \n    # Q: (batch, seqlen, nheads_qk, BLOCK_HEADDIM_QK)\n    stride_q_batch, stride_q_seqlen, stride_q_head, stride_q_qkdim,\n    # K: (batch, seqlen, nheads_qk, BLOCK_HEADDIM_QK)\n    stride_k_batch, stride_k_seqlen, stride_k_head, 
stride_k_qkdim,\n    # Scale: (batch, nheads, seqlen)\n    stride_scale_batch, stride_scale_head, stride_scale_seqlen,\n    # Gamma: (batch, nheads, seqlen)\n    stride_gamma_batch, stride_gamma_head, stride_gamma_seqlen,\n    # Q_bias: (nheads, BLOCK_HEADDIM_QK)\n    stride_q_bias_head, stride_q_bias_qkdim,\n    # K_bias: (nheads, BLOCK_HEADDIM_QK)\n    stride_k_bias_head, stride_k_bias_qkdim,\n    # Angles: (batch, seqlen, nheads, BLOCK_HEADDIM_QK/2)\n    stride_angles_batch, stride_angles_seqlen, stride_angles_head, stride_angles_qkdim,\n    # dQ_in: (batch, seqlen, nheads, BLOCK_HEADDIM_QK)\n    stride_dq_in_batch, stride_dq_in_seqlen, stride_dq_in_head, stride_dq_in_qkdim,\n    # dK_in: (batch, seqlen, nheads, BLOCK_HEADDIM_QK)\n    stride_dk_in_batch, stride_dk_in_seqlen, stride_dk_in_head, stride_dk_in_qkdim,\n    # dQK: (batch, nheads, seqlen)\n    stride_dqk_batch, stride_dqk_head, stride_dqk_seqlen,\n    # Strides for outputs ------------------------------------------------------\n    # dQ: (batch, seqlen, nheads_qk, BLOCK_HEADDIM_QK)\n    stride_dq_batch, stride_dq_seqlen, stride_dq_head, stride_dq_qkdim,\n    # dK: (batch, seqlen, nheads_qk, BLOCK_HEADDIM_QK)\n    stride_dk_batch, stride_dk_seqlen, stride_dk_head, stride_dk_qkdim,\n    # dAngles: (batch, seqlen, nheads, BLOCK_HEADDIM_QK/2)\n    stride_dangles_batch, stride_dangles_seqlen, stride_dangles_head, stride_dangles_qkdim,\n    # dScale: (batch, nheads, HEADDIM_QK // BLOCK_HEADDIM_QK, seqlen)\n    stride_dscale_batch, stride_dscale_head, stride_dscale_nqkchunks ,stride_dscale_seqlen,\n    # dGamma: (batch, nheads, HEADDIM_QK // BLOCK_HEADDIM_QK, seqlen)\n    stride_dgamma_batch, stride_dgamma_head, stride_dgamma_nqkchunks, stride_dgamma_seqlen,\n    # dQ_bias: (batch, nchunks, nheads, BLOCK_HEADDIM_QK)\n    stride_dq_bias_batch, stride_dq_bias_nchunks, stride_dq_bias_head, stride_dq_bias_qkdim,\n    # dK_bias: (batch, nchunks, nheads, BLOCK_HEADDIM_QK)\n    stride_dk_bias_batch, 
stride_dk_bias_nchunks, stride_dk_bias_head, stride_dk_bias_qkdim,\n    # ---- sizes ----\n    seqlen, nheads_qk, nheads, headdim_qk, headdim_angles,\n    CHUNK_SIZE: tl.constexpr,\n    HEADDIM_QK: tl.constexpr,\n    BLOCK_HEADDIM_QK: tl.constexpr,\n    GQA_RATIO: tl.constexpr,\n):\n    \"\"\"\n    Grid: (nchunks, batch)\n    Each program processes one (batch, chunk) pair.\n    \n    Loop structure:\n    - Outer loop: iterate over qk_heads (nheads_qk)\n    - Inner loop: iterate over GQA group (GQA_RATIO heads per qk_head)\n    \"\"\"\n    pid_nchunk = tl.program_id(0)\n    pid_batch = tl.program_id(1)\n    nchunks = tl.cdiv(seqlen, CHUNK_SIZE)\n\n    # Base offsets for inputs\n    q_offset_base = pid_batch * stride_q_batch\n    k_offset_base = pid_batch * stride_k_batch\n    scale_offset_base = pid_batch * stride_scale_batch\n    gamma_offset_base = pid_batch * stride_gamma_batch\n    angle_offset_base = pid_batch * stride_angles_batch\n    dq_in_offset_base = pid_batch * stride_dq_in_batch\n    dk_in_offset_base = pid_batch * stride_dk_in_batch\n    dqk_offset_base = pid_batch * stride_dqk_batch\n\n    # Base offsets for outputs\n    dq_offset_base = pid_batch * stride_dq_batch\n    dk_offset_base = pid_batch * stride_dk_batch\n    dangle_offset_base = pid_batch * stride_dangles_batch\n    dscale_offset_base = pid_batch * stride_dscale_batch\n    dgamma_offset_base = pid_batch * stride_dgamma_batch\n    dq_bias_offset_base = pid_batch * stride_dq_bias_batch + pid_nchunk * stride_dq_bias_nchunks\n    dk_bias_offset_base = pid_batch * stride_dk_bias_batch + pid_nchunk * stride_dk_bias_nchunks\n\n    num_nheads_qk = HEADDIM_QK // BLOCK_HEADDIM_QK\n    for nhead_qk_id in range(num_nheads_qk):\n        offs_s = tl.arange(0, CHUNK_SIZE) + pid_nchunk * CHUNK_SIZE\n        offs_d = tl.arange(0, BLOCK_HEADDIM_QK) + nhead_qk_id * BLOCK_HEADDIM_QK\n        offs_dr = tl.arange(0, BLOCK_HEADDIM_QK // 2) + nhead_qk_id * (BLOCK_HEADDIM_QK // 2)\n\n        # Outer loop: iterate 
over qk_heads\n        for qk_head_idx in range(nheads_qk):\n            # ============================================================\n            # Load Q, K for this qk_head (once per GQA group)\n            # ============================================================\n            q_offset = q_offset_base + qk_head_idx * stride_q_head\n            k_offset = k_offset_base + qk_head_idx * stride_k_head\n            q_ptrs = Q + q_offset + offs_s[:, None] * stride_q_seqlen + offs_d[None, :] * stride_q_qkdim\n            k_ptrs = K + k_offset + offs_s[:, None] * stride_k_seqlen + offs_d[None, :] * stride_k_qkdim\n            \n            # Zero accumulators for this qk_head\n            dq_acc = tl.zeros((CHUNK_SIZE, BLOCK_HEADDIM_QK), dtype=tl.float32)\n            dk_acc = tl.zeros((CHUNK_SIZE, BLOCK_HEADDIM_QK), dtype=tl.float32)\n            \n            # Inner loop: iterate over GQA group\n            for gqa_idx in range(GQA_RATIO):\n                nhead_idx = qk_head_idx * GQA_RATIO + gqa_idx\n                \n                # ============================================================\n                # Load per-head data\n                # ============================================================\n                # Bias for this head\n                q_bias = tl.load(\n                    Q_bias + nhead_idx * stride_q_bias_head + offs_d * stride_q_bias_qkdim,\n                    mask=offs_d < headdim_qk).to(tl.float32)\n                k_bias = tl.load(\n                    K_bias + nhead_idx * stride_k_bias_head + offs_d * stride_k_bias_qkdim, \n                    mask=offs_d < headdim_qk).to(tl.float32)\n                \n                # Q + bias, K + bias\n                q0 = tl.load(q_ptrs, mask=(offs_s[:, None] < seqlen) & (offs_d[None, :] < headdim_qk), other=0.0)  # [CHUNK_SIZE, BLOCK_HEADDIM_QK]\n                k0 = tl.load(k_ptrs, mask=(offs_s[:, None] < seqlen) & (offs_d[None, :] < headdim_qk), other=0.0)  # [CHUNK_SIZE, 
BLOCK_HEADDIM_QK]\n                Q_wbias = q0 + q_bias[None, :]\n                K_wbias = k0 + k_bias[None, :]\n                \n                # dQK for this head\n                dqk_offset = dqk_offset_base + nhead_idx * stride_dqk_head\n                dqk = tl.load(dQK + dqk_offset + offs_s * stride_dqk_seqlen, mask=offs_s < seqlen, other=0.0)\n                \n                # Scale, Gamma for this head\n                scale_offset = scale_offset_base + nhead_idx * stride_scale_head\n                gamma_offset = gamma_offset_base + nhead_idx * stride_gamma_head\n                scale = tl.load(Scale + scale_offset + offs_s * stride_scale_seqlen, mask=offs_s < seqlen, other=0.0).to(tl.float32)\n                gamma = tl.load(Gamma + gamma_offset + offs_s * stride_gamma_seqlen, mask=offs_s < seqlen, other=0.0).to(tl.float32)\n                \n                # Angles for this head\n                angle_offset = angle_offset_base + nhead_idx * stride_angles_head\n                theta = tl.load(\n                    Angles + angle_offset + offs_s[:, None] * stride_angles_seqlen + offs_dr[None, :] * stride_angles_qkdim,\n                    mask=(offs_dr[None, :] < headdim_angles) & (offs_s[:, None] < seqlen), \n                    other=0.0).to(tl.float32)\n                \n                # dQ_in, dK_in for this head\n                dq_in_offset = dq_in_offset_base + nhead_idx * stride_dq_in_head\n                dk_in_offset = dk_in_offset_base + nhead_idx * stride_dk_in_head\n                dQ_in_load = tl.load(dQ_in + dq_in_offset + offs_s[:, None] * stride_dq_in_seqlen + offs_d[None, :] * stride_dq_in_qkdim, \n                    mask=(offs_s[:, None] < seqlen) & (offs_d[None, :] < headdim_qk), other=0.0)\n                dK_in_load = tl.load(dK_in + dk_in_offset + offs_s[:, None] * stride_dk_in_seqlen + offs_d[None, :] * stride_dk_in_qkdim,\n                    mask=(offs_s[:, None] < seqlen) & (offs_d[None, :] < headdim_qk), other=0.0)\n   
             \n                # ============================================================\n                # Compute dGamma = dQK * (Q_wbias · K_wbias)\n                # ============================================================\n                QK_dot = tl.sum(Q_wbias * K_wbias, axis=1)\n                d_gamma = dqk * QK_dot\n                dgamma_store_offset = dgamma_offset_base + nhead_idx * stride_dgamma_head\n                tl.store(\n                    dGamma + dgamma_store_offset + offs_s * stride_dgamma_seqlen + nhead_qk_id * stride_dgamma_nqkchunks, \n                    d_gamma, mask=offs_s < seqlen)\n                \n                # ============================================================\n                # Compute cos/sin for rotary\n                # ============================================================\n                cos_angle = cos_approx(theta.to(tl.float32))\n                sin_angle = sin_approx(theta.to(tl.float32))\n                \n                # ============================================================\n                # Compute dScale = sum(dK_in * K_rot)\n                # ============================================================\n                K_r = tl.reshape(K_wbias, [CHUNK_SIZE, BLOCK_HEADDIM_QK // 2, 2])\n                K_r0, K_r1 = tl.split(K_r)\n                K_rot0 = K_r0 * cos_angle - K_r1 * sin_angle\n                K_rot1 = K_r0 * sin_angle + K_r1 * cos_angle\n                K_rot = tl.reshape(tl.join(K_rot0, K_rot1), [CHUNK_SIZE, BLOCK_HEADDIM_QK])\n                \n                dscale_val = tl.sum(dK_in_load * K_rot, axis=1)\n                dscale_store_offset = dscale_offset_base + nhead_idx * stride_dscale_head\n                tl.store(\n                    dScale + dscale_store_offset + offs_s * stride_dscale_seqlen + nhead_qk_id * stride_dscale_nqkchunks, \n                    dscale_val, mask=offs_s < seqlen)\n                \n                # 
============================================================\n                # Compute dQ_pre, dK_pre through inverse rotary\n                # ============================================================\n                dK_in_scaled = dK_in_load * scale[:, None] # shape: (CHUNK_SIZE, BLOCK_HEADDIM_QK)\n\n                Q_r = tl.reshape(Q_wbias, [CHUNK_SIZE, BLOCK_HEADDIM_QK // 2, 2])\n                Q_r0, Q_r1 = tl.split(Q_r)\n                \n                dQ_in_r = tl.reshape(dQ_in_load, [CHUNK_SIZE, BLOCK_HEADDIM_QK // 2, 2])\n                dK_in_r = tl.reshape(dK_in_scaled, [CHUNK_SIZE, BLOCK_HEADDIM_QK // 2, 2])\n                dQ_in_r0, dQ_in_r1 = tl.split(dQ_in_r)\n                dK_in_r0, dK_in_r1 = tl.split(dK_in_r)\n                \n                # Inverse rotary\n                dq0 = dQ_in_r0 * cos_angle + dQ_in_r1 * sin_angle\n                dq1 = -dQ_in_r0 * sin_angle + dQ_in_r1 * cos_angle\n                dk0 = dK_in_r0 * cos_angle + dK_in_r1 * sin_angle\n                dk1 = -dK_in_r0 * sin_angle + dK_in_r1 * cos_angle\n                \n                dQ_pre = tl.reshape(tl.join(dq0, dq1), [CHUNK_SIZE, BLOCK_HEADDIM_QK])\n                dK_pre = tl.reshape(tl.join(dk0, dk1), [CHUNK_SIZE, BLOCK_HEADDIM_QK])\n                \n                # Add dQK path\n                dqk_scaled = (dqk * gamma)[:, None]\n                dQ_pre = dQ_pre + dqk_scaled * K_wbias\n                dK_pre = dK_pre + dqk_scaled * Q_wbias\n                \n                # ============================================================\n                # Accumulate dQ, dK for GQA reduction\n                # ============================================================\n                dq_acc += dQ_pre\n                dk_acc += dK_pre\n                \n                # ============================================================\n                # Store dQ_bias, dK_bias for this head (sum over chunk)\n                # 
============================================================\n                dq_bias_out = tl.sum(dQ_pre, axis=0)\n                dk_bias_out = tl.sum(dK_pre, axis=0)\n                dq_bias_store_offset = dq_bias_offset_base + nhead_idx * stride_dq_bias_head\n                dk_bias_store_offset = dk_bias_offset_base + nhead_idx * stride_dk_bias_head\n                tl.store(dQ_bias + dq_bias_store_offset + offs_d * stride_dq_bias_qkdim, dq_bias_out, mask=offs_d < headdim_qk)\n                tl.store(dK_bias + dk_bias_store_offset + offs_d * stride_dk_bias_qkdim, dk_bias_out, mask=offs_d < headdim_qk)\n                \n                # ============================================================\n                # Compute and store dAngles for this head\n                # ============================================================\n                dtheta_q = dQ_in_r0 * (-Q_r0 * sin_angle - Q_r1 * cos_angle) + dQ_in_r1 * (Q_r0 * cos_angle - Q_r1 * sin_angle)\n                dtheta_k = dK_in_r0 * (-K_r0 * sin_angle - K_r1 * cos_angle) + dK_in_r1 * (K_r0 * cos_angle - K_r1 * sin_angle)\n                dtheta = dtheta_q + dtheta_k\n                \n                dangle_store_offset = dangle_offset_base + nhead_idx * stride_dangles_head\n                tl.store(\n                    dAngles + dangle_store_offset + offs_s[:, None] * stride_dangles_seqlen + offs_dr[None, :] * stride_dangles_qkdim, \n                    dtheta, mask=(offs_dr[None, :] < headdim_angles) & (offs_s[:, None] < seqlen))\n            \n            # ============================================================\n            # End of GQA group: store accumulated dQ, dK\n            # ============================================================\n            dq_offset = dq_offset_base + qk_head_idx * stride_dq_head\n            dk_offset = dk_offset_base + qk_head_idx * stride_dk_head\n            dq_ptrs = dQ + dq_offset + offs_s[:, None] * stride_dq_seqlen + offs_d[None, :] * 
stride_dq_qkdim\n            dk_ptrs = dK + dk_offset + offs_s[:, None] * stride_dk_seqlen + offs_d[None, :] * stride_dk_qkdim\n            tl.store(dq_ptrs, dq_acc, mask=(offs_s[:, None] < seqlen) & (offs_d[None, :] < headdim_qk))\n            tl.store(dk_ptrs, dk_acc, mask=(offs_s[:, None] < seqlen) & (offs_d[None, :] < headdim_qk))\n\n\n# NOTE: Do not autotune this kernel. It accumulates into dK, dK_bias, and dAngles via atomic adds, so the repeated launches made during autotuning would add the contribution multiple times.\n@triton.jit\ndef mamba3_siso_bwd_kernel_dk_state_post(\n    # Input tensors\n    dK_State, Angles, K, K_bias, Cu_Seqlens,\n    # Output tensors\n    dK, dK_bias, dAngles,\n    # Strides for dK_State: (num_sequences, nheads, headdim_qk)\n    stride_dk_state_batch, stride_dk_state_head, stride_dk_state_qkdim,\n    # Strides for Angles: (batch, seqlen, nheads, headdim_angles)\n    stride_angles_batch, stride_angles_seqlen, stride_angles_head, stride_angles_qkdim,\n    # Strides for K: (batch, seqlen, nheads_qk, headdim_qk)\n    stride_k_batch, stride_k_seqlen, stride_k_head, stride_k_qkdim,\n    # Strides for K_bias: (nheads, headdim_qk)\n    stride_k_bias_head, stride_k_bias_qkdim,\n    # Strides for Cu_Seqlens: (num_sequences + 1,)\n    stride_cu_seqlen,\n    # Strides for dK: (batch, seqlen, nheads_qk, headdim_qk)\n    stride_dk_batch, stride_dk_seqlen, stride_dk_head, stride_dk_qkdim,\n    # Strides for dK_bias: (nheads, headdim_qk)\n    stride_dk_bias_head, stride_dk_bias_qkdim,\n    # Strides for dAngles: (batch, seqlen, nheads, headdim_angles)\n    stride_dangles_batch, stride_dangles_seqlen, stride_dangles_head, stride_dangles_qkdim,\n    # Dimensions\n    seqlen, headdim_qk, headdim_angles,\n    HEADDIM_QK: tl.constexpr,\n    GQA_RATIO: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    \"\"\"\n    Post-kernel for d_ok_state contributions.\n    Grid: (nheads, num_sequences)\n    \n    Each program handles one (sequence, nhead) pair and computes:\n    1. 
dK via inverse rotary + GQA reduction (atomic add)\n    2. dK_bias via inverse rotary + batch reduction (atomic add)\n    3. dAngles via rotary gradient (atomic add)\n    \"\"\"\n    pid_head = tl.program_id(0)\n    pid_batch = tl.program_id(1)\n\n    if IS_VARLEN:\n        pid_seq = pid_batch\n        pid_batch = 0\n        cu_seqlen = tl.load(Cu_Seqlens + (pid_seq + 1) * stride_cu_seqlen).to(tl.int32)\n        last_pos = cu_seqlen - 1\n    else:\n        pid_seq = 0\n        last_pos = seqlen - 1\n    \n    qk_head_idx = pid_head // GQA_RATIO\n    offs_d = tl.arange(0, HEADDIM_QK)\n    offs_dr = tl.arange(0, HEADDIM_QK // 2)\n\n    # Load dK_State as interleaved pairs\n    dk_state_base = dK_State + (pid_batch + pid_seq) * stride_dk_state_batch + pid_head * stride_dk_state_head\n    dk_state = tl.load(dk_state_base + offs_d * stride_dk_state_qkdim, mask=offs_d < headdim_qk, other=0.0).to(tl.float32)\n    dk_state_r = tl.reshape(dk_state, [HEADDIM_QK // 2, 2])\n    dk_state_r0, dk_state_r1 = tl.split(dk_state_r)  # shape: (HEADDIM_QK // 2,)\n    \n    # Load angles at last position\n    angles_base = Angles + pid_batch * stride_angles_batch + last_pos * stride_angles_seqlen + pid_head * stride_angles_head\n    angles_val = tl.load(angles_base + offs_dr * stride_angles_qkdim, mask=offs_dr < headdim_angles, other=0.0).to(tl.float32)  # shape: (HEADDIM_QK // 2,)\n    \n    cos_ang = cos_approx(angles_val)\n    sin_ang = sin_approx(angles_val)\n    \n    # Inverse rotary: dk_rotated\n    dk0 = dk_state_r0 * cos_ang + dk_state_r1 * sin_ang\n    dk1 = -dk_state_r0 * sin_ang + dk_state_r1 * cos_ang\n    dk_rotated = tl.reshape(tl.join(dk0, dk1), [HEADDIM_QK])\n    \n    # 1. Accumulate to dK (GQA reduction via atomic)\n    dk_base = dK + pid_batch * stride_dk_batch + last_pos * stride_dk_seqlen + qk_head_idx * stride_dk_head\n    tl.atomic_add(dk_base + offs_d * stride_dk_qkdim, dk_rotated, mask=offs_d < headdim_qk)\n    \n    # 2. 
Accumulate to dK_bias (batch reduction via atomic)\n    dk_bias_base = dK_bias + pid_head * stride_dk_bias_head\n    tl.atomic_add(dk_bias_base + offs_d * stride_dk_bias_qkdim, dk_rotated, mask=offs_d < headdim_qk)\n    \n    # 3. Compute dAngles\n    # Load K at last position (using qk_head_idx for GQA)\n    k_base = K + pid_batch * stride_k_batch + last_pos * stride_k_seqlen + qk_head_idx * stride_k_head\n    k_val = tl.load(k_base + offs_d * stride_k_qkdim, mask=offs_d < headdim_qk, other=0.0).to(tl.float32)\n    kr = tl.reshape(k_val, [HEADDIM_QK // 2, 2])\n    k_r0, k_r1 = tl.split(kr)  # shape: (HEADDIM_QK // 2,)\n    \n    # Load K_bias\n    k_bias_base = K_bias + pid_head * stride_k_bias_head\n    k_bias_val = tl.load(k_bias_base + offs_d * stride_k_bias_qkdim, mask=offs_d < headdim_qk, other=0.0).to(tl.float32)\n    kbr = tl.reshape(k_bias_val, [HEADDIM_QK // 2, 2])\n    kb_r0, kb_r1 = tl.split(kbr)  # shape: (HEADDIM_QK // 2,)\n    \n    # K_wbias = K + K_bias\n    K_wbias_r0 = k_r0 + kb_r0\n    K_wbias_r1 = k_r1 + kb_r1\n    \n    # dtheta = dk_r0 * (-K0*sin - K1*cos) + dk_r1 * (K0*cos - K1*sin)\n    dtheta_k = (dk_state_r0 * (-K_wbias_r0 * sin_ang - K_wbias_r1 * cos_ang) + \n                dk_state_r1 * (K_wbias_r0 * cos_ang - K_wbias_r1 * sin_ang))\n    \n    # Accumulate to dAngles at last position\n    da_base = dAngles + pid_batch * stride_dangles_batch + last_pos * stride_dangles_seqlen + pid_head * stride_dangles_head\n    tl.atomic_add(da_base + offs_dr * stride_dangles_qkdim, dtheta_k, mask=offs_dr < headdim_angles)\n\n\ndef compute_dqktheta(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    scale: torch.Tensor,\n    gamma: torch.Tensor,\n    q_bias: torch.Tensor,\n    k_bias: torch.Tensor,\n    angles: torch.Tensor,\n    dq_in: torch.Tensor,\n    dk_in: torch.Tensor,\n    dqk: torch.Tensor,\n    d_ok_state: Optional[torch.Tensor] = None,\n    chunk_size: int = 64,\n    Cu_Seqlens: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Compute gradients through rotary embeddings and biases for Mamba-3 backward pass.\n    \n    This kernel undoes the rotary embedding and computes gradients for the original Q, K,\n    angles, scaling factors, and biases.\n    \n    Args:\n        q: Original query tensor before bias/rotary (batch, seqlen, nheads_qk, headdim_qk)\n        k: Original key tensor before bias/rotary (batch, seqlen, nheads_qk, headdim_qk)\n        scale: Combined scale factor gamma + gamma (batch, nheads, seqlen)\n        gamma: gamma factor (batch, nheads, seqlen)\n        q_bias: Query bias (nheads, headdim_qk)\n        k_bias: Key bias (nheads, headdim_qk)\n        angles: Rotary angles (batch, seqlen, nheads, headdim_angles)\n        dq_in: Gradient from downstream for Q_mid (batch, seqlen, nheads, headdim_qk)\n        dk_in: Gradient from downstream for K_mid (batch, seqlen, nheads, headdim_qk)\n        dqk: Gradient for QK dot products (batch, nheads, seqlen)\n        d_ok_state: Optional gradient of output K state (num_sequences, nheads, headdim_qk) - added to last token of dK (without scaling)\n        chunk_size: Chunk size (default: 64)\n        Cu_Seqlens: Optional cumulative sequence lengths (num_sequences + 1,) for variable-length inputs\n    \n    Returns:\n        Tuple of (dQ, dK, dQ_bias, dK_bias, dAngles, dScale, dGamma)\n        - dQ: (batch, seqlen, nheads_qk, headdim_qk)\n        - dK: (batch, seqlen, nheads_qk, headdim_qk)\n        - dQ_bias: (nheads, headdim_qk)\n        - dK_bias: (nheads, headdim_qk)\n        - dAngles: (batch, seqlen, nheads, headdim_angles)\n        - dScale: (batch, nheads, seqlen)\n        - dGamma: (batch, nheads, seqlen)\n    \"\"\"\n    batch, seqlen, nheads_qk, headdim_qk = q.shape\n    assert q.shape == k.shape\n\n    nheads = scale.shape[1]\n    nchunks = triton.cdiv(seqlen, chunk_size)\n    GQA_RATIO = nheads // nheads_qk\n    \n    assert scale.shape == (batch, nheads, seqlen)\n    assert gamma.shape == (batch, nheads, seqlen)\n    assert q_bias.shape == 
(nheads, headdim_qk)\n    assert k_bias.shape == (nheads, headdim_qk)\n    headdim_angles = angles.shape[-1]\n    assert angles.shape == (batch, seqlen, nheads, headdim_angles)\n    assert dq_in.shape == (batch, seqlen, nheads, headdim_qk)\n    assert dk_in.shape == (batch, seqlen, nheads, headdim_qk)\n    assert dqk.shape == (batch, nheads, seqlen)\n    if d_ok_state is not None:\n        num_sequences = Cu_Seqlens.shape[0] - 1 if Cu_Seqlens is not None else batch\n        assert d_ok_state.shape == (num_sequences, nheads, headdim_qk)\n    assert nheads % nheads_qk == 0, \"nheads must be multiple of nheads_qk for GQA support\"\n\n    # Ensure contiguity after reshaping\n    if not q.is_contiguous():\n        q = q.contiguous()\n    if not k.is_contiguous():\n        k = k.contiguous()\n    if not scale.is_contiguous():\n        scale = scale.contiguous()\n    if not gamma.is_contiguous():\n        gamma = gamma.contiguous()\n    if not dqk.is_contiguous():\n        dqk = dqk.contiguous()\n    if not angles.is_contiguous():\n        angles = angles.contiguous()\n    if not dq_in.is_contiguous():\n        dq_in = dq_in.contiguous()\n    if not dk_in.is_contiguous():\n        dk_in = dk_in.contiguous()\n    if q_bias.stride(-1) != 1:\n        q_bias = q_bias.contiguous()\n    if k_bias.stride(-1) != 1:\n        k_bias = k_bias.contiguous()\n    if d_ok_state is not None and (not d_ok_state.is_contiguous()):\n        d_ok_state = d_ok_state.contiguous()\n    \n    HEADDIM_QK = triton.next_power_of_2(headdim_qk)\n    BLOCK_HEADDIM_QK = min(HEADDIM_QK, 64)\n\n    # Allocate output tensors layout\n    dq = torch.empty((batch, seqlen, nheads_qk, headdim_qk), \n                              dtype=dq_in.dtype, device=q.device)\n    dk = torch.empty((batch, seqlen, nheads_qk, headdim_qk), \n                              dtype=dk_in.dtype, device=k.device)\n    dangles = torch.empty((batch, seqlen, nheads, headdim_angles),\n                                   
dtype=angles.dtype, device=angles.device)\n    dscale = torch.empty((batch, nheads, HEADDIM_QK // BLOCK_HEADDIM_QK, seqlen),\n                                  dtype=scale.dtype, device=scale.device)\n    dgamma = torch.empty((batch, nheads, HEADDIM_QK // BLOCK_HEADDIM_QK, seqlen),\n                                   dtype=gamma.dtype, device=gamma.device)\n    dq_bias_partial = torch.empty((batch, nchunks, nheads, headdim_qk),\n                                   dtype=torch.float32, device=q.device)\n    dk_bias_partial = torch.empty((batch, nchunks, nheads, headdim_qk),\n                                   dtype=torch.float32, device=k.device)\n\n    # Grid: (nchunks, batch)\n    grid = (nchunks, batch)\n\n    mamba3_siso_bwd_kernel_rotary_bias_angles[grid](\n        # Input tensors\n        q, k, scale, gamma, q_bias, k_bias, angles, dq_in, dk_in, dqk,\n        # Output tensors\n        dq, dk, dangles, dscale, dgamma, dq_bias_partial, dk_bias_partial,\n        # Q strides: (batch, seqlen, nheads_qk, headdim_qk)\n        q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n        # K strides\n        k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n        # Scale strides: (batch, nheads, seqlen)\n        scale.stride(0), scale.stride(1), scale.stride(2),\n        # SGamma strides\n        gamma.stride(0), gamma.stride(1), gamma.stride(2),\n        # Q_bias strides: (nheads, headdim_qk)\n        q_bias.stride(0), q_bias.stride(1),\n        # K_bias strides\n        k_bias.stride(0), k_bias.stride(1),\n        # Angles strides: (batch, seqlen, nheads, headdim_qk//2)\n        angles.stride(0), angles.stride(1), angles.stride(2), angles.stride(3),\n        # dQ_in strides: (batch, seqlen, nheads, headdim_qk)\n        dq_in.stride(0), dq_in.stride(1), dq_in.stride(2), dq_in.stride(3),\n        # dK_in strides\n        dk_in.stride(0), dk_in.stride(1), dk_in.stride(2), dk_in.stride(3),\n        # dQK strides: (batch, nheads, seqlen)\n        dqk.stride(0), 
dqk.stride(1), dqk.stride(2),\n        # Output tensors\n        # dQ strides: (batch, seqlen, nheads_qk, headdim_qk)\n        dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),\n        # dK strides\n        dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),\n        # dAngles strides: (batch, seqlen, nheads, headdim_qk//2)\n        dangles.stride(0), dangles.stride(1), dangles.stride(2), dangles.stride(3),\n        # dScale strides: (batch, nheads, HEADDIM_QK // BLOCK_HEADDIM_QK, seqlen)\n        dscale.stride(0), dscale.stride(1), dscale.stride(2), dscale.stride(3),\n        # dGamma strides: (batch, nheads, HEADDIM_QK // BLOCK_HEADDIM_QK, seqlen)\n        dgamma.stride(0), dgamma.stride(1), dgamma.stride(2), dgamma.stride(3),\n        # dQ_bias_partial strides: (batch, nchunks, nheads, headdim_qk)\n        dq_bias_partial.stride(0), dq_bias_partial.stride(1),\n        dq_bias_partial.stride(2), dq_bias_partial.stride(3),\n        # dK_bias_partial strides\n        dk_bias_partial.stride(0), dk_bias_partial.stride(1),\n        dk_bias_partial.stride(2), dk_bias_partial.stride(3),\n        # Sizes\n        seqlen, nheads_qk, nheads, headdim_qk, headdim_angles,\n        CHUNK_SIZE=chunk_size,\n        HEADDIM_QK=HEADDIM_QK,\n        BLOCK_HEADDIM_QK=BLOCK_HEADDIM_QK,\n        GQA_RATIO=GQA_RATIO,\n    )\n    \n    # Reduce dScale and dGamma over headdim blocks: (batch, nheads, nblocks, seqlen) -> (batch, nheads, seqlen)\n    dscale = torch.sum(dscale, dim=2)  # Sum over headdim blocks\n    dgamma = torch.sum(dgamma, dim=2)  # Sum over headdim blocks\n    \n    # Reduce bias gradients: (batch, nchunks, nheads, headdim_qk) -> (nheads, headdim_qk)\n    dq_bias = dq_bias_partial.sum(dim=(0, 1))\n    dk_bias = dk_bias_partial.sum(dim=(0, 1))\n\n    # NOTE: We handle d_ok_state contributions in a different kernel because merging it in \n    # causes a +800% increase in register spillage and a +200us increase in runtime. 
For now \n    # this new kernel only introduces +5us.\n    if d_ok_state is not None:\n        apply_dk_state_post(\n            d_ok_state, angles, k, k_bias, dk, dk_bias, dangles, Cu_Seqlens\n        )\n    return dq, dk, dq_bias, dk_bias, dangles, dscale, dgamma\n\ndef apply_dk_state_post(\n    d_ok_state: torch.Tensor,\n    angles: torch.Tensor,\n    k: torch.Tensor,\n    k_bias: torch.Tensor,\n    dk: torch.Tensor,\n    dk_bias: torch.Tensor,\n    dangles: torch.Tensor,\n    Cu_Seqlens: Optional[torch.Tensor] = None,\n):\n    batch, seqlen, nheads, headdim_angles = angles.shape\n    _, _, headdim_qk = d_ok_state.shape\n    nheads_qk = k.shape[2]\n    GQA_RATIO = nheads // nheads_qk\n\n    is_varlen = Cu_Seqlens is not None\n    if is_varlen:\n        num_sequences = Cu_Seqlens.shape[0] - 1\n        assert batch == 1\n    else:\n        num_sequences = batch\n    \n    # Ensure contiguity\n    if not d_ok_state.is_contiguous():\n        d_ok_state = d_ok_state.contiguous()\n    if not angles.is_contiguous():\n        angles = angles.contiguous()\n    if not k.is_contiguous():\n        k = k.contiguous()\n    if not k_bias.is_contiguous():\n        k_bias = k_bias.contiguous()\n    \n    HEADDIM_QK = triton.next_power_of_2(headdim_qk)\n    \n    grid = (nheads, num_sequences)\n    \n    mamba3_siso_bwd_kernel_dk_state_post[grid](\n        # Input tensors\n        d_ok_state, angles, k, k_bias, Cu_Seqlens,\n        # Output tensors\n        dk, dk_bias, dangles,\n        # dK_State strides: (batch, nheads, headdim_qk)\n        d_ok_state.stride(0), d_ok_state.stride(1), d_ok_state.stride(2),\n        # Angles strides: (batch, seqlen, nheads, headdim_angles)\n        angles.stride(0), angles.stride(1), angles.stride(2), angles.stride(3),\n        # K strides: (batch, seqlen, nheads_qk, headdim_qk)\n        k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n        # K_bias strides: (nheads, headdim_qk)\n        k_bias.stride(0), k_bias.stride(1),\n        # 
Cu_Seqlens strides: (num_sequences + 1,)\n        Cu_Seqlens.stride(0) if is_varlen else 0,\n        # dK strides: (batch, seqlen, nheads_qk, headdim_qk)\n        dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),\n        # dK_bias strides: (nheads, headdim_qk)\n        dk_bias.stride(0), dk_bias.stride(1),\n        # dAngles strides: (batch, seqlen, nheads, headdim_angles)\n        dangles.stride(0), dangles.stride(1), dangles.stride(2), dangles.stride(3),\n        # Dimensions\n        seqlen, headdim_qk, headdim_angles,\n        HEADDIM_QK=HEADDIM_QK,\n        GQA_RATIO=GQA_RATIO,\n        IS_VARLEN=is_varlen,\n        num_warps=2,\n        num_stages=3,\n    )\n\n\n# =============================================================================\n# dDT, dTrap, and dInput States Kernel\n# =============================================================================\n@triton.autotune(\n    configs=[\n        triton.Config({\"CHUNK_SIZE\": cs}, num_stages=s, num_warps=w)\n        for cs in [64, 128, 256]\n        for s in [1, 2, 3]\n        for w in [2, 4, 8]\n    ],\n    key=[\"HEADDIM_V\", \"HEADDIM_QK\", \"HAS_INPUT_STATE\", \"IS_VARLEN\"]\n)\n@triton.jit\ndef mamba3_siso_bwd_kernel_ddt_dtrap_dinput_states(\n    # Input tensors\n    dScale, dGamma, DT, Trap,\n    d_ISSM_State, Input_K_State, Input_V_State, Cu_Seqlens,\n    # Output tensors\n    dDT, dTrap,\n    dInput_SSM_State, dInput_K_State, dInput_V_State,\n    # Strides for dScale: (batch, nheads, seqlen)\n    stride_dscale_batch, stride_dscale_head, stride_dscale_seqlen,\n    # Strides for dGamma: (batch, nheads, seqlen)\n    stride_dgamma_batch, stride_dgamma_head, stride_dgamma_seqlen,\n    # Strides for DT: (batch, nheads, seqlen)\n    stride_dt_batch, stride_dt_head, stride_dt_seqlen,\n    # Strides for Trap: (batch, nheads, seqlen)\n    stride_trap_batch, stride_trap_head, stride_trap_seqlen,\n    # Strides for d_ISSM_State: (num_sequences, nheads, headdim_v, headdim_qk)\n    
stride_d_issm_state_batch, stride_d_issm_state_head, stride_d_issm_state_vdim, stride_d_issm_state_qkdim,\n    # Strides for Input_K_State: (num_sequences, nheads, headdim_qk)\n    stride_input_k_state_batch, stride_input_k_state_head, stride_input_k_state_qkdim,\n    # Strides for Input_V_State: (num_sequences, nheads, headdim_v)\n    stride_input_v_state_batch, stride_input_v_state_head, stride_input_v_state_vdim,\n    # Stride for Cu_Seqlens\n    stride_cu_seqlen,\n    # Strides for dDT: (batch, nheads, seqlen)\n    stride_ddt_batch, stride_ddt_head, stride_ddt_seqlen,\n    # Strides for dTrap: (batch, nheads, seqlen)\n    stride_dtrap_batch, stride_dtrap_head, stride_dtrap_seqlen,\n    # Strides for dInput_SSM_State: (num_sequences, nheads, headdim_v, headdim_qk)\n    stride_dinput_ssm_state_batch, stride_dinput_ssm_state_head, stride_dinput_ssm_state_vdim, stride_dinput_ssm_state_qkdim,\n    # Strides for dInput_K_State: (num_sequences, nheads, headdim_qk)\n    stride_dinput_k_state_batch, stride_dinput_k_state_head, stride_dinput_k_state_qkdim,\n    # Strides for dInput_V_State: (num_sequences, nheads, headdim_v)\n    stride_dinput_v_state_batch, stride_dinput_v_state_head, stride_dinput_v_state_vdim,\n    # Dimensions\n    seqlen, headdim_v, headdim_qk,\n    # Compile-time constants\n    CHUNK_SIZE: tl.constexpr,\n    HEADDIM_V: tl.constexpr,\n    HEADDIM_QK: tl.constexpr,\n    HAS_INPUT_STATE: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    \"\"\"\n    Backward kernel for computing dDT, dTrap, and input state gradients.\n    \n    Part 1 - dDT and dTrap from dScale and dGamma:\n        Forward: gamma_t = DT_t * Trap_t                    (used independently)\n                 shifted_gamma_t = DT_{t+1} * (1 - Trap_{t+1})  (used as scale for position t)\n        \n        Backward: DT[t] appears in gamma[t] and shifted_gamma[t-1]:\n                  dDT_t = dGamma_t * Trap_t + dScale_{t-1} * (1 - Trap_t)\n                  \n                  Trap[t] 
appears in gamma[t] and shifted_gamma[t-1]:\n                  dTrap_t = dGamma_t * DT_t - dScale_{t-1} * DT_t\n    \n    Part 2 - Input state gradients (first token only, if HAS_INPUT_STATE):\n        Forward: scalar = DT_0 * (1 - Trap_0)\n                 SSM_State = Input_SSM_State + outer(Input_V, Input_K) * scalar\n        Backward: dInput_SSM_State = d_ISSM_State\n                  dInput_V = einsum(d_ISSM_State, Input_K) * scalar\n                  dInput_K = einsum(d_ISSM_State, Input_V) * scalar\n                  dDT_0 += d_scalar * (1 - Trap_0)\n                  dTrap_0 += d_scalar * (-DT_0)\n    \n    Grid: \n        - Normal mode: (nheads, batch)\n        - Varlen mode: (nheads, num_sequences)\n    \"\"\"\n    pid_head = tl.program_id(0)\n    pid_batch = tl.program_id(1)\n\n    if IS_VARLEN:\n        pid_seq = pid_batch\n        pid_batch = 0\n        cu_seqlen = tl.load(Cu_Seqlens + pid_seq * stride_cu_seqlen).to(tl.int32)\n        cu_seqlen_next = tl.load(Cu_Seqlens + (pid_seq + 1) * stride_cu_seqlen).to(tl.int32)\n        seqlen = cu_seqlen_next - cu_seqlen\n    else:\n        pid_seq = 0\n        cu_seqlen = 0\n\n    # ==================== Pointer Offsets ====================\n    dscale_offset = pid_batch * stride_dscale_batch + pid_head * stride_dscale_head + IS_VARLEN * cu_seqlen * stride_dscale_seqlen\n    dgamma_offset = pid_batch * stride_dgamma_batch + pid_head * stride_dgamma_head + IS_VARLEN * cu_seqlen * stride_dgamma_seqlen\n    dt_offset = pid_batch * stride_dt_batch + pid_head * stride_dt_head + IS_VARLEN * cu_seqlen * stride_dt_seqlen\n    trap_offset = pid_batch * stride_trap_batch + pid_head * stride_trap_head + IS_VARLEN * cu_seqlen * stride_trap_seqlen\n    ddt_offset = pid_batch * stride_ddt_batch + pid_head * stride_ddt_head + IS_VARLEN * cu_seqlen * stride_ddt_seqlen\n    dtrap_offset = pid_batch * stride_dtrap_batch + pid_head * stride_dtrap_head + IS_VARLEN * cu_seqlen * stride_dtrap_seqlen\n\n    # ==================== Part 
1: dDT and dTrap ====================\n    num_chunks = tl.cdiv(seqlen, CHUNK_SIZE)\n    \n    for chunk_idx in range(num_chunks):\n        offs_s = chunk_idx * CHUNK_SIZE + tl.arange(0, CHUNK_SIZE)\n        mask = offs_s < seqlen\n\n        # Load dscale_t, dGamma_t, Trap_t, DT_t for current positions\n        dscale_t = tl.load(dScale + dscale_offset + offs_s * stride_dscale_seqlen, mask=mask, other=0.0)\n        dgamma_t = tl.load(dGamma + dgamma_offset + offs_s * stride_dgamma_seqlen, mask=mask, other=0.0)\n        trap_presig_t = tl.load(Trap + trap_offset + offs_s * stride_trap_seqlen, mask=mask, other=0.0).to(tl.float32)\n        trap_t = sigmoid_approx(trap_presig_t)\n        dt_t = tl.load(DT + dt_offset + offs_s * stride_dt_seqlen, mask=mask, other=0.0)\n\n        # Load dScale_{t-1} (shifted by 1, with 0 at t=0)\n        # shifted_gamma[t-1] = DT[t] * (1 - Trap[t]) feeds into scale[t-1]\n        offs_s_prev = offs_s - 1\n        mask_prev = (offs_s_prev >= 0) & (offs_s_prev < seqlen)\n        dscale_prev = tl.load(\n            dScale + dscale_offset + offs_s_prev * stride_dscale_seqlen,\n            mask=mask_prev,\n            other=0.0\n        )\n\n        # Compute gradients:\n        ddt_t = (dgamma_t + dscale_t) * trap_t + dscale_prev * (1.0 - trap_t)\n        dtrap_t = (dgamma_t + dscale_t) * dt_t - dscale_prev * dt_t\n        dtrap_presig_t = dtrap_t * trap_t * (1.0 - trap_t)\n\n        # Store results\n        tl.store(dDT + ddt_offset + offs_s * stride_ddt_seqlen, ddt_t, mask=mask)\n        tl.store(dTrap + dtrap_offset + offs_s * stride_dtrap_seqlen, dtrap_presig_t, mask=mask)\n\n    # ==================== Part 2: Input State Gradients ====================\n    if HAS_INPUT_STATE:\n        # Pointer offsets for input states\n        d_issm_offset = (pid_batch + pid_seq) * stride_d_issm_state_batch + pid_head * stride_d_issm_state_head\n        input_k_offset = (pid_batch + pid_seq) * stride_input_k_state_batch + pid_head * 
stride_input_k_state_head\n        input_v_offset = (pid_batch + pid_seq) * stride_input_v_state_batch + pid_head * stride_input_v_state_head\n        dinput_ssm_offset = (pid_batch + pid_seq) * stride_dinput_ssm_state_batch + pid_head * stride_dinput_ssm_state_head\n        dinput_k_offset = (pid_batch + pid_seq) * stride_dinput_k_state_batch + pid_head * stride_dinput_k_state_head\n        dinput_v_offset = (pid_batch + pid_seq) * stride_dinput_v_state_batch + pid_head * stride_dinput_v_state_head\n        # Load DT_0 and Trap_0 (first token)\n        dt_0 = tl.load(DT + dt_offset).to(tl.float32)\n        trap_presig_0 = tl.load(Trap + trap_offset).to(tl.float32)\n        trap_0 = sigmoid_approx(trap_presig_0)\n        scalar = dt_0 * (1.0 - trap_0)\n\n        # Dimension offsets\n        offs_v = tl.arange(0, HEADDIM_V)\n        offs_qk = tl.arange(0, HEADDIM_QK)\n\n        # Load Input_K_State and Input_V_State\n        input_k = tl.load(\n            Input_K_State + input_k_offset + offs_qk * stride_input_k_state_qkdim, \n            mask=offs_qk < headdim_qk, \n            other=0.0).to(tl.float32)\n        input_v = tl.load(\n            Input_V_State + input_v_offset + offs_v * stride_input_v_state_vdim,\n            mask=offs_v < headdim_v,\n            other=0.0\n        ).to(tl.float32)\n\n        # Load d_ISSM_State: (headdim_v, headdim_qk)\n        d_issm = tl.load(\n            d_ISSM_State + d_issm_offset + \n            offs_v[:, None] * stride_d_issm_state_vdim + \n            offs_qk[None, :] * stride_d_issm_state_qkdim,\n            mask=(offs_v[:, None] < headdim_v) & (offs_qk[None, :] < headdim_qk),\n            other=0.0\n        ).to(tl.float32)\n\n        # dInput_SSM_State = d_ISSM_State (direct copy)\n        tl.store(\n            dInput_SSM_State + dinput_ssm_offset + \n            offs_v[:, None] * stride_dinput_ssm_state_vdim + \n            offs_qk[None, :] * stride_dinput_ssm_state_qkdim,\n            d_issm,\n            
mask=(offs_v[:, None] < headdim_v) & (offs_qk[None, :] < headdim_qk),\n        )\n\n        # d_scalar = sum(d_ISSM_State * outer(Input_V, Input_K))\n        outer_product = input_v[:, None] * input_k[None, :]\n        d_scalar = tl.sum(d_issm * outer_product)\n\n        # dInput_V = sum_d(d_ISSM_State * Input_K) * scalar\n        # dInput_K = sum_D(d_ISSM_State * Input_V) * scalar\n        dinput_v = tl.sum(d_issm * input_k[None, :], axis=1) * scalar\n        dinput_k = tl.sum(d_issm * input_v[:, None], axis=0) * scalar\n\n        # Store dInput_V_State and dInput_K_State\n        tl.store(dInput_V_State + dinput_v_offset + offs_v * stride_dinput_v_state_vdim, dinput_v, mask=offs_v < headdim_v)\n        tl.store(dInput_K_State + dinput_k_offset + offs_qk * stride_dinput_k_state_qkdim, dinput_k, mask=offs_qk < headdim_qk)\n\n        # Add contributions to dDT_0 and dTrap_0 from input state gradient\n        ddt_0_contrib = d_scalar * (1.0 - trap_0)\n        dtrap_0_contrib = d_scalar * (-dt_0)\n        dtrap_0_presig_contrib = dtrap_0_contrib * trap_0 * (1.0 - trap_0)\n        \n        # Atomically add to the first position (already written in Part 1)\n        tl.atomic_add(dDT + ddt_offset, ddt_0_contrib)\n        tl.atomic_add(dTrap + dtrap_offset, dtrap_0_presig_contrib)\n\n\ndef compute_ddt_dtrap_dinput_states(\n    dscale: torch.Tensor,\n    dgamma: torch.Tensor,\n    dt: torch.Tensor,\n    trap: torch.Tensor,\n    d_issm_state: Optional[torch.Tensor] = None,\n    input_k_state: Optional[torch.Tensor] = None,\n    input_v_state: Optional[torch.Tensor] = None,\n    Cu_Seqlens: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:\n    \"\"\"\n    Compute dDT, dTrap from dScale/dGamma, and optionally input state gradients.\n    \n    Args:\n        dscale: Gradient of scale, shape (batch, nheads, seqlen)\n        dgamma: Gradient of gamma, shape (batch, nheads, seqlen)\n   
     dt: DT tensor from forward pass, shape (batch, nheads, seqlen)\n        trap: Trap tensor from forward pass, shape (batch, nheads, seqlen)\n        d_issm_state: Gradient of SSM_State_mid (optional), shape (batch, nheads, headdim_v, headdim_qk)\n        input_k_state: Input K state from forward pass (optional), shape (batch, nheads, headdim_qk)\n        input_v_state: Input V state from forward pass (optional), shape (batch, nheads, headdim_v)\n    \n    Returns:\n        Tuple containing:\n            - dDT: Gradient for DT, shape (batch, nheads, seqlen)\n            - dTrap: Gradient for Trap, shape (batch, nheads, seqlen)\n            - dInput_SSM_State: Gradient for Input_SSM_State (None if no input state)\n            - dInput_K_State: Gradient for Input_K_State (None if no input state)\n            - dInput_V_State: Gradient for Input_V_State (None if no input state)\n    \"\"\"\n    batch, nheads, seqlen = dscale.shape\n    has_input_state = d_issm_state is not None\n    is_varlen = Cu_Seqlens is not None\n    \n    if is_varlen:\n        num_sequences = Cu_Seqlens.shape[0] - 1\n        assert batch == 1, \"Batch size must be 1 when using variable-length sequences.\"\n    else:\n        num_sequences = batch\n    \n    # Validate inputs\n    assert dgamma.shape == (batch, nheads, seqlen), f\"dgamma shape mismatch: {dgamma.shape}\"\n    assert dt.shape == (batch, nheads, seqlen), f\"dt shape mismatch: {dt.shape}\"\n    assert trap.shape == (batch, nheads, seqlen), f\"trap shape mismatch: {trap.shape}\"\n    \n    if has_input_state:\n        assert input_k_state is not None and input_v_state is not None, \\\n            \"input_k_state and input_v_state must be provided with d_issm_state\"\n        headdim_v, headdim_qk = d_issm_state.shape[2], d_issm_state.shape[3]\n        assert d_issm_state.shape == (num_sequences, nheads, headdim_v, headdim_qk), \\\n            f\"d_issm_state shape mismatch: {d_issm_state.shape}\"\n        assert 
input_k_state.shape == (num_sequences, nheads, headdim_qk), \\\n            f\"input_k_state shape mismatch: {input_k_state.shape}\"\n        assert input_v_state.shape == (num_sequences, nheads, headdim_v), \\\n            f\"input_v_state shape mismatch: {input_v_state.shape}\"\n    else:\n        headdim_v, headdim_qk = 64, 128  # Dummy values for block size calculation\n\n    # Ensure contiguity\n    dscale = dscale.contiguous() if not dscale.is_contiguous() else dscale\n    dgamma = dgamma.contiguous() if not dgamma.is_contiguous() else dgamma\n    dt = dt.contiguous() if not dt.is_contiguous() else dt\n    trap = trap.contiguous() if not trap.is_contiguous() else trap\n    \n    if has_input_state:\n        d_issm_state = d_issm_state.contiguous() if not d_issm_state.is_contiguous() else d_issm_state\n        input_k_state = input_k_state.contiguous() if not input_k_state.is_contiguous() else input_k_state\n        input_v_state = input_v_state.contiguous() if not input_v_state.is_contiguous() else input_v_state\n\n    # Allocate outputs\n    dDT = torch.empty_like(dt, dtype=torch.float32)\n    dTrap = torch.empty_like(trap, dtype=torch.float32)\n    \n    if has_input_state:\n        d_Input_SSM_State = torch.empty_like(d_issm_state)\n        d_Input_K_State = torch.empty((num_sequences, nheads, headdim_qk), dtype=torch.float32, device=dt.device)\n        d_Input_V_State = torch.empty((num_sequences, nheads, headdim_v), dtype=torch.float32, device=dt.device)\n    else:\n        d_Input_SSM_State = None\n        d_Input_K_State = None\n        d_Input_V_State = None\n\n    # Launch kernel\n    HEADDIM_V = triton.next_power_of_2(headdim_v) if has_input_state else 64\n    HEADDIM_QK = triton.next_power_of_2(headdim_qk) if has_input_state else 128\n    \n    # Grid\n    if is_varlen:\n        grid = (nheads, num_sequences)\n    else:\n        grid = (nheads, batch)\n    \n    mamba3_siso_bwd_kernel_ddt_dtrap_dinput_states[grid](\n        # Inputs\n        
dscale, dgamma, dt, trap,\n        d_issm_state if has_input_state else dscale,  # Dummy pointer if not used\n        input_k_state if has_input_state else dscale,\n        input_v_state if has_input_state else dscale,\n        Cu_Seqlens,\n        # Outputs\n        dDT, dTrap,\n        d_Input_SSM_State if has_input_state else dDT,  # Dummy pointer if not used\n        d_Input_K_State if has_input_state else dDT,\n        d_Input_V_State if has_input_state else dDT,\n        # Strides for dScale\n        dscale.stride(0), dscale.stride(1), dscale.stride(2),\n        # Strides for dSGamma\n        dgamma.stride(0), dgamma.stride(1), dgamma.stride(2),\n        # Strides for DT\n        dt.stride(0), dt.stride(1), dt.stride(2),\n        # Strides for Trap\n        trap.stride(0), trap.stride(1), trap.stride(2),\n        # Strides for d_ISSM_State\n        d_issm_state.stride(0) if has_input_state else 0,\n        d_issm_state.stride(1) if has_input_state else 0,\n        d_issm_state.stride(2) if has_input_state else 0,\n        d_issm_state.stride(3) if has_input_state else 0,\n        # Strides for Input_K_State\n        input_k_state.stride(0) if has_input_state else 0,\n        input_k_state.stride(1) if has_input_state else 0,\n        input_k_state.stride(2) if has_input_state else 0,\n        # Strides for Input_V_State\n        input_v_state.stride(0) if has_input_state else 0,\n        input_v_state.stride(1) if has_input_state else 0,\n        input_v_state.stride(2) if has_input_state else 0,\n        # Stride for Cu_Seqlens\n        Cu_Seqlens.stride(0) if Cu_Seqlens is not None else 0,\n        # Strides for dDT\n        dDT.stride(0), dDT.stride(1), dDT.stride(2),\n        # Strides for dTrap\n        dTrap.stride(0), dTrap.stride(1), dTrap.stride(2),\n        # Strides for d_Input_SSM_State\n        d_Input_SSM_State.stride(0) if has_input_state else 0,\n        d_Input_SSM_State.stride(1) if has_input_state else 0,\n        
d_Input_SSM_State.stride(2) if has_input_state else 0,\n        d_Input_SSM_State.stride(3) if has_input_state else 0,\n        # Strides for d_Input_K_State\n        d_Input_K_State.stride(0) if has_input_state else 0,\n        d_Input_K_State.stride(1) if has_input_state else 0,\n        d_Input_K_State.stride(2) if has_input_state else 0,\n        # Strides for d_Input_V_State\n        d_Input_V_State.stride(0) if has_input_state else 0,\n        d_Input_V_State.stride(1) if has_input_state else 0,\n        d_Input_V_State.stride(2) if has_input_state else 0,\n        # Dimensions\n        seqlen, headdim_v, headdim_qk,\n        # Constants\n        HEADDIM_V=HEADDIM_V,\n        HEADDIM_QK=HEADDIM_QK,\n        HAS_INPUT_STATE=has_input_state,\n        IS_VARLEN=is_varlen,\n    )\n\n    return dDT, dTrap, d_Input_SSM_State, d_Input_K_State, d_Input_V_State\n\n\n# =============================================================================\n# Memory Allocator for TMA Descriptors\n# =============================================================================\n\ndef _alloc_fn(size: int, alignment: int, stream: Optional[int]):\n    \"\"\"Custom allocator for TMA descriptor global memory allocation.\"\"\"\n    return torch.empty(size, device=\"cuda\", dtype=torch.int8)\n\n\ntriton.set_allocator(_alloc_fn)\n\n"
  },
  {
    "path": "mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py",
    "content": "\"\"\"Mamba-3 Triton Autograd Wrapper\n\nCopyright (c) 2025, Dao AI Lab, Goombalab\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\nimport triton\n\n# Import kernels\nfrom mamba_ssm.ops.triton.mamba3.mamba3_siso_fwd import mamba3_siso_fwd\nfrom mamba_ssm.ops.triton.mamba3.mamba3_siso_bwd import compute_dzdo, compute_dqkv, compute_dqktheta, compute_ddt_dtrap_dinput_states\nfrom mamba_ssm.ops.triton.mamba3.angle_dt import angle_dt_fwd, angle_dt_bwd\n\n\ndef _triton_alloc_fn(size: int, alignment: int, stream: Optional[int]):\n    \"\"\"Allocator for Triton runtime memory (TMA descriptors, scratch).\"\"\"\n    return torch.empty(size, device=\"cuda\", dtype=torch.int8)\n\n\n# Set allocator immediately at import time.\ntry:\n    triton.set_allocator(_triton_alloc_fn)\nexcept Exception:\n    pass  # Allocator may already be set\n\n\n@dataclass(frozen=True)\nclass Mamba3Output:\n    \"\"\"Container for Mamba-3 outputs and optional intermediates.\n    \n    Attributes:\n        out: Main output tensor (batch, seqlen, nheads, headdim_v)\n        final_angle_state: Final angle state (num_sequences, nheads, headdim_angles)\n        final_ssm_state: Final SSM state (num_sequences, nheads, headdim_v, headdim_qk)\n        final_k_state: Final K state (num_sequences, nheads, headdim_qk)\n        final_v_state: Final V state (num_sequences, nheads, headdim_v)\n    \"\"\"\n    out: Tensor\n    final_angle_state: Optional[Tensor] = None\n    final_ssm_state: Optional[Tensor] = None\n    final_k_state: Optional[Tensor] = None\n    final_v_state: Optional[Tensor] = None\n\nclass _Mamba3Function(torch.autograd.Function):\n    \"\"\"Custom autograd function for Mamba-3 with Triton kernels.\"\"\"\n    \n    @staticmethod\n    def forward(\n        ctx,\n        Q: Tensor,\n        K: Tensor,\n        V: Tensor,\n        ADT: Tensor,\n        DT: Tensor,\n       
 Trap: Tensor,\n        Q_bias: Tensor,\n        K_bias: Tensor,\n        Angles: Tensor,\n        D: Optional[Tensor],\n        Z: Optional[Tensor],\n        Input_Angle_State: Optional[Tensor],\n        Input_SSM_State: Optional[Tensor],\n        Input_K_State: Optional[Tensor],\n        Input_V_State: Optional[Tensor],\n        cu_seqlens: Optional[Tensor],\n        chunk_size: int,\n        return_final_states: bool,\n    ) -> Tensor | Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n        \"\"\"Forward pass: call Triton kernel and save tensors for backward.\"\"\"\n        \n        try:\n            triton.set_allocator(_triton_alloc_fn)\n        except Exception:\n            pass\n        \n        needs_backward = any(ctx.needs_input_grad)\n        has_varlen = cu_seqlens is not None\n\n        all_states_present = (Input_SSM_State is not None) and (Input_K_State is not None) and (Input_V_State is not None) and (Input_Angle_State is not None)\n        all_states_absent = (Input_SSM_State is None) and (Input_K_State is None) and (Input_V_State is None) and (Input_Angle_State is None)\n\n        assert all_states_present or all_states_absent, \"Input states must be provided together or all be None.\"\n        \n        Angles_Cumsum, Final_Angle_State = angle_dt_fwd(\n            Angles, DT, \n            init_state=Input_Angle_State, \n            chunk_size=chunk_size, \n            return_output_state=True,\n            cu_seqlens=cu_seqlens,\n        )\n\n        Input_States = (\n            (Input_SSM_State, Input_K_State, Input_V_State)\n            if Input_SSM_State is not None\n            else None\n        )\n\n        Out, Out_v, SSM_States, DA_CS, DA_CS_SUM, Q_rot, K_scaled, QK_dot, Scale, Gamma, Final_States = mamba3_siso_fwd(\n            Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles_Cumsum, D, Z, Input_States,\n            chunk_size=chunk_size,\n            store_states_adt_outv=needs_backward,\n            
return_final_states=return_final_states,\n            cu_seqlens=cu_seqlens,\n        )\n\n        Final_SSM_State = Final_States[0] if Final_States is not None else None\n        Final_K_State = Final_States[1] if Final_States is not None else None\n        Final_V_State = Final_States[2] if Final_States is not None else None\n        \n        if needs_backward:\n            ctx.chunk_size = chunk_size\n            ctx.has_D = D is not None\n            ctx.has_Z = Z is not None\n            ctx.has_input_state = Input_SSM_State is not None\n            ctx.return_final_states = return_final_states\n            ctx.has_varlen = has_varlen\n            \n            # Save tensors - use empty tensor placeholders for None values\n            D_save = D if D is not None else torch.empty((), device=Q.device)\n            Z_save = Z if Z is not None else torch.empty((), device=Q.device)\n            Input_SSM_State_save = Input_SSM_State if Input_SSM_State is not None else torch.empty((), device=Q.device)\n            Input_K_State_save = Input_K_State if Input_K_State is not None else torch.empty((), device=Q.device)\n            Input_V_State_save = Input_V_State if Input_V_State is not None else torch.empty((), device=Q.device)\n            Final_SSM_State_save = Final_SSM_State if Final_SSM_State is not None else torch.empty((), device=Q.device)\n            cu_seqlens_save = cu_seqlens if cu_seqlens is not None else torch.empty((), device=Q.device, dtype=torch.int32)\n            \n            ctx.save_for_backward(\n                Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, Angles_Cumsum,\n                D_save, Z_save, Input_SSM_State_save, Input_K_State_save, Input_V_State_save,\n                Out, Out_v, SSM_States, DA_CS, DA_CS_SUM, Q_rot, K_scaled, QK_dot, Scale, Gamma,\n                Final_SSM_State_save, cu_seqlens_save\n            )\n        else:\n            ctx.chunk_size = chunk_size\n            ctx.has_D = D is not None\n            
ctx.has_Z = Z is not None\n            ctx.has_input_state = Input_SSM_State is not None\n            ctx.return_final_states = return_final_states\n            ctx.has_varlen = has_varlen\n            ctx.save_for_backward()\n        \n        if return_final_states:\n            return Out, Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State\n        return Out\n    \n    @staticmethod\n    def backward(\n        ctx, \n        grad_out: Optional[Tensor] = None, \n        grad_final_angle_state: Optional[Tensor] = None,\n        grad_final_ssm_state: Optional[Tensor] = None, \n        grad_final_k_state: Optional[Tensor] = None, \n        grad_final_v_state: Optional[Tensor] = None\n    ) -> tuple:\n        \"\"\"Backward pass: compute gradients using Triton backward kernels.\"\"\"\n        \n        try:\n            triton.set_allocator(_triton_alloc_fn)\n        except Exception:\n            pass\n        \n        if len(ctx.saved_tensors) == 0:\n            raise RuntimeError(\n                \"Backward called but forward ran without gradient tracking. 
\"\n                \"Ensure inputs require grad or run under torch.enable_grad().\"\n            )\n        if grad_out is None and grad_final_ssm_state is None and grad_final_k_state is None and grad_final_v_state is None and grad_final_angle_state is None:\n            raise RuntimeError(\"No gradients provided for backward pass.\")\n\n        (Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, Angles_Cumsum,\n        D_save, Z_save, Input_SSM_State_save, Input_K_State_save, Input_V_State_save,\n        Out, Out_v, SSM_States, DA_CS, DA_CS_SUM, Q_rot, K_scaled, QK_dot, Scale, Gamma,\n        Final_SSM_State_save, cu_seqlens_save) = ctx.saved_tensors\n        \n        D = D_save if ctx.has_D else None\n        Z = Z_save if ctx.has_Z else None\n        Input_SSM_State = Input_SSM_State_save if ctx.has_input_state else None\n        Input_K_State = Input_K_State_save if ctx.has_input_state else None\n        Input_V_State = Input_V_State_save if ctx.has_input_state else None\n        cu_seqlens = cu_seqlens_save if ctx.has_varlen else None\n        \n        if grad_out is None:\n            grad_out = torch.zeros_like(Out)\n        \n        # Step 1: Compute dZ and scale grad_out if Z gating is present\n        if Z is not None:\n            dZ, grad_out_scaled = compute_dzdo(\n                grad_out, Z, Out_v, chunk_size=ctx.chunk_size\n            )\n        else:\n            dZ = None\n            grad_out_scaled = grad_out\n\n        # Step 2: Compute main gradients (dQ_mid, dK_mid, dV, dADT, dQK_dot, dD, dInput_SSM_State)\n        dQ_mid, dK_mid, dV, dADT, dQK_dot, dD, dInput_SSM_State = compute_dqkv(\n            q=Q_rot,\n            k=K_scaled,\n            v=V,\n            da_cs=DA_CS,\n            da_cs_sum=DA_CS_SUM,\n            qk_dot=QK_dot,\n            SSM_States=SSM_States,\n            do=grad_out_scaled,\n            d_ossm_state=grad_final_ssm_state,\n            d_ov_state=grad_final_v_state,\n            D=D,\n            
chunk_size=ctx.chunk_size,\n            has_input_state=ctx.has_input_state,\n            Cu_Seqlens=cu_seqlens,\n        )\n        \n        # Step 3: Compute gradients through rotary embeddings and biases\n        dQ, dK, dQ_bias, dK_bias, dAngles_Cumsum, dScale, dGamma = compute_dqktheta(\n            q=Q,\n            k=K,\n            scale=Scale,\n            gamma=Gamma,\n            q_bias=Q_bias,\n            k_bias=K_bias,\n            angles=Angles_Cumsum,\n            dq_in=dQ_mid,\n            dk_in=dK_mid,\n            dqk=dQK_dot,\n            d_ok_state=grad_final_k_state,\n            chunk_size=ctx.chunk_size,\n            Cu_Seqlens=cu_seqlens,\n        )\n        \n        # Step 4: Compute dDT, dTrap, and input state gradients\n        dDT, dTrap, dInput_SSM_State_final, dInput_K_State, dInput_V_State = compute_ddt_dtrap_dinput_states(\n            dscale=dScale,\n            dgamma=dGamma,\n            dt=DT,\n            trap=Trap.float(),\n            d_issm_state=dInput_SSM_State if ctx.has_input_state else None,\n            input_k_state=Input_K_State,\n            input_v_state=Input_V_State,\n            Cu_Seqlens=cu_seqlens,\n        )\n        \n        # Step 5: Compute gradients through angle_dt cumsum\n        dAngles, dDT_angle, dInput_Angle_State = angle_dt_bwd(\n            grad_out=dAngles_Cumsum,\n            angle=Angles,\n            dt=DT,\n            has_init_state=ctx.has_input_state,\n            chunk_size=ctx.chunk_size,\n            grad_output_state=grad_final_angle_state if ctx.return_final_states else None,\n            cu_seqlens=cu_seqlens,\n        )\n        \n        # Accumulate DT gradients from angle_dt backward\n        dDT = dDT + dDT_angle\n        \n        if ctx.has_input_state:\n            dInput_SSM_State = dInput_SSM_State_final\n        else:\n            dInput_SSM_State = None\n            dInput_K_State = None\n            dInput_V_State = None\n            dInput_Angle_State = None\n       
 \n        return (\n            dQ,                     # Q\n            dK,                     # K\n            dV,                     # V\n            dADT,                   # ADT\n            dDT,                    # DT\n            dTrap,                  # Trap\n            dQ_bias,                # Q_bias\n            dK_bias,                # K_bias\n            dAngles,                # Angles\n            dD,                     # D\n            dZ,                     # Z\n            dInput_Angle_State,     # Input_Angle_State\n            dInput_SSM_State,       # Input_SSM_State\n            dInput_K_State,         # Input_K_State\n            dInput_V_State,         # Input_V_State\n            None,                   # cu_seqlens (not differentiable)\n            None,                   # chunk_size (not differentiable)\n            None,                   # return_final_states (not differentiable)\n        )\n\n\ndef mamba3_siso_combined(\n    Q: Tensor,\n    K: Tensor,\n    V: Tensor,\n    ADT: Tensor,\n    DT: Tensor,\n    Trap: Tensor,\n    Q_bias: Tensor,\n    K_bias: Tensor,\n    Angles: Tensor,\n    D: Optional[Tensor] = None,\n    Z: Optional[Tensor] = None,\n    Input_States: Optional[Tuple[Tensor, Tensor, Tensor, Tensor]] = None,\n    chunk_size: int = 64,\n    return_final_states: bool = False,\n    cu_seqlens: Optional[Tensor] = None,\n) -> Tensor | Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n    \"\"\"Mamba-3 attention with Triton kernels and automatic differentiation.\n\n    This is the main entry point for Mamba-3 forward and backward passes using\n    optimized Triton kernels. 
Supports GQA (grouped-query attention), rotary\n    position embeddings, optional gating, skip connections, state passing\n    for recurrent inference, and variable-length sequences.\n\n    Internally computes cumulative angles: Angles_Cumsum = cumsum(Angles * DT) mod 2π\n\n    Args:\n        Q: Query tensor             (batch, seqlen, nheads_qk, headdim_qk)\n        K: Key tensor               (batch, seqlen, nheads_qk, headdim_qk)\n        V: Value tensor             (batch, seqlen, nheads, headdim_v)\n        ADT: Decay factor A * dt    (batch, nheads, seqlen)\n        DT: Time delta tensor dt    (batch, nheads, seqlen)\n        Trap: Trapezoidal factor    (batch, nheads, seqlen)\n            Mixing factor in [0, 1] for trapezoidal discretization.\n        Q_bias: Query bias          (nheads, headdim_qk)\n        K_bias: Key bias            (nheads, headdim_qk)\n        Angles: Rotary angle rates  (batch, seqlen, nheads, headdim_angles)\n            Raw angle values that get accumulated via cumsum(Angles * DT).\n            If headdim_angles < headdim_qk // 2, remaining dims are unrotated.\n        D: Skip connection          (nheads,)\n            Optional per-head skip connection weight applied to V.\n        Z: Gating tensor            (batch, seqlen, nheads, headdim_v)\n            Optional gating applied as: out = out * silu(Z).\n        Input_States: Optional initial state tuple for recurrent inference.\n            Angle State:            (num_sequences, nheads, headdim_angles)\n            SSM State:              (num_sequences, nheads, headdim_v, headdim_qk)\n            K State:                (num_sequences, nheads, headdim_qk)\n            V State:                (num_sequences, nheads, headdim_v)\n        chunk_size: Chunk size for chunked state computation (default: 64).\n        return_final_states: If True, return final states for recurrent inference.\n        cu_seqlens: Cumulative sequence lengths for variable-length support.\n            
Shape: (num_sequences + 1,), dtype: torch.int32.\n            Example: [0, 128, 256, 512] for 3 sequences of lengths 128, 128, 256.\n            When using cu_seqlens, batch must be 1 and the seqlen dimension\n            contains all sequences concatenated.\n\n    Returns:\n        If return_final_states=False:\n            out: Output tensor      (batch, seqlen, nheads, headdim_v)\n        If return_final_states=True:\n            Tuple of:\n                out: Output tensor              (batch, seqlen, nheads, headdim_v)\n                final_angle_state: Angle state  (num_sequences, nheads, headdim_angles)\n                final_ssm_state: SSM state      (num_sequences, nheads, headdim_v, headdim_qk)\n                final_k_state: K state          (num_sequences, nheads, headdim_qk)\n                final_v_state: V state          (num_sequences, nheads, headdim_v)\n\n    Notes:\n        - For GQA: nheads must be divisible by nheads_qk.\n        - headdim_qk and headdim_v must be powers of two for TMA compatibility.\n        - Variable-length mode (cu_seqlens is not None) requires batch == 1.\n        - num_sequences = batch for batched mode, len(cu_seqlens)-1 for varlen mode.\n\n\n    Performance Notes:\n        The kernel is optimized for:\n            nheads_qk=1, nheads=32, headdim_qk=128, headdim_v=64, chunk_size=64.\n    \"\"\"\n    \n    batch, seqlen, nheads_qk, headdim_qk = Q.shape\n    _, _, nheads, headdim_v = V.shape\n    \n    assert nheads % nheads_qk == 0, f\"nheads ({nheads}) must be divisible by nheads_qk ({nheads_qk})\"\n    assert headdim_qk % 2 == 0, f\"headdim_qk ({headdim_qk}) must be even for rotary embeddings\"\n    \n    # Varlen mode checks\n    has_varlen = cu_seqlens is not None\n    if has_varlen:\n        if batch != 1:\n            raise ValueError(f\"Batch size must be 1 with variable-length sequences (cu_seqlens), got {batch}.\")\n    \n    Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State = (\n        
Input_States if Input_States is not None else (None, None, None, None)\n    )\n\n    all_states_present = (Input_SSM_State is not None) and (Input_K_State is not None) and (Input_V_State is not None) and (Input_Angle_State is not None)\n    all_states_absent = (Input_SSM_State is None) and (Input_K_State is None) and (Input_V_State is None) and (Input_Angle_State is None)\n    assert all_states_present or all_states_absent, \"Input states must be provided together or all be None.\"\n\n    return _Mamba3Function.apply(\n        Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z,\n        Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State, cu_seqlens, chunk_size, return_final_states\n    )"
  },
  {
    "path": "mamba_ssm/ops/triton/mamba3/mamba3_siso_fwd.py",
    "content": "\"\"\"\nMamba-3 SISO Forward Pass Triton Kernel.\n\nCopyright (c) 2025, Dao AI Lab, Goombalab\n\"\"\"\n\nfrom typing import Optional, Tuple\nimport math\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\nimport triton\nimport triton.language as tl\nfrom mamba_ssm.ops.triton.mamba3.utils import cos_approx, sin_approx, tanh_approx, silu, sigmoid_approx\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_stages=s, num_warps=w)\n        for s in [1, 2, 3]\n        for w in [2, 4, 8]\n    ],\n    key=[\n        \"CHUNK_SIZE\", \"HEADDIM_QK\", \"HEADDIM_V\", \"STORE_SSM_STATES_ADT_OUTV\", \"HAS_D\", \n        \"HAS_Z\", \"HAS_INITIAL_STATES\", \"RETURN_FINAL_STATES\", \"IS_VARLEN\"],\n)\n@triton.jit\ndef mamba3_siso_fwd_kernel(\n    # Inputs\n    Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z, \n    Initial_SSM_State, Initial_K_State, Initial_V_State, Cu_Seqlens,\n    # Outputs\n    Out, Out_v, SSM_States, DA_CS_Store, DA_CS_SUM_Store, Q_store, K_store, QK_store,\n    Scale_store, Gamma_store, Final_SSM_State, Final_K_State,\n    # Input Strides\n    stride_q_batch, stride_q_seqlen, stride_q_head, stride_q_qkdim,\n    stride_k_batch, stride_k_seqlen, stride_k_head, stride_k_qkdim,\n    stride_v_batch, stride_v_seqlen, stride_v_head, stride_v_vdim,\n    stride_adt_batch, stride_adt_head, stride_adt_seqlen,\n    stride_dt_batch, stride_dt_head, stride_dt_seqlen,\n    stride_trap_batch, stride_trap_head, stride_trap_seqlen,\n    stride_q_bias_head, stride_q_bias_qkdim,\n    stride_k_bias_head, stride_k_bias_qkdim,\n    stride_angles_batch, stride_angles_seqlen, stride_angles_head, stride_angles_qkdim,\n    stride_d_head,\n    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_vdim,\n    stride_init_ssm_state_seq, stride_init_ssm_state_head, stride_init_ssm_state_vdim, \n    stride_init_ssm_state_qkdim,\n    stride_init_k_state_seq, stride_init_k_state_head, stride_init_k_state_qkdim,\n    
stride_init_v_state_seq, stride_init_v_state_head, stride_init_v_state_vdim,\n    stride_cu_seqlen,\n    # Output Strides\n    stride_o_batch, stride_o_seqlen, stride_o_head, stride_o_vdim,\n    stride_o_v_batch, stride_o_v_seqlen, stride_o_v_head, stride_o_v_vdim,\n    stride_ssm_states_batch, stride_ssm_states_head, stride_ssm_states_vdim, stride_ssm_states_qkdim,\n    stride_da_cs_store_batch, stride_da_cs_store_head, stride_da_cs_store_seqlen,\n    stride_da_cs_sum_store_batch, stride_da_cs_sum_store_head, stride_da_cs_sum_store_seqlen,\n    stride_q_store_batch, stride_q_store_seqlen, stride_q_store_head, stride_q_store_qkdim,\n    stride_k_store_batch, stride_k_store_seqlen, stride_k_store_head, stride_k_store_qkdim,\n    stride_qk_store_batch, stride_qk_store_head, stride_qk_store_seqlen,\n    stride_scale_store_batch, stride_scale_store_head, stride_scale_store_seqlen,\n    stride_gamma_store_batch, stride_gamma_store_head, stride_gamma_store_seqlen,\n    stride_final_ssm_state_seq, stride_final_ssm_state_head, stride_final_ssm_state_vdim, \n    stride_final_ssm_state_qkdim,\n    stride_final_k_state_seq, stride_final_k_state_head, stride_final_k_state_chunk, \n    stride_final_k_state_qkdim,\n    # Dimensions\n    seqlen, nheads_qk, headdim_qk, headdim_v, headdim_angles,\n    CHUNK_SIZE: tl.constexpr,\n    HEADDIM_QK: tl.constexpr,\n    HEADDIM_V: tl.constexpr,\n    STORE_SSM_STATES_ADT_OUTV: tl.constexpr,\n    HAS_INITIAL_STATES: tl.constexpr,\n    RETURN_FINAL_STATES: tl.constexpr,\n    HAS_D: tl.constexpr,\n    HAS_Z: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n):\n    \"\"\"\n    Mamba-3 forward kernel.\n\n    Grid: (nheads, batch) for batched, (nheads, 1, num_sequences) for varlen\n\n    Inputs:\n        Q, K:                       (batch, seqlen, nheads_qk, headdim_qk)\n        V:                          (batch, seqlen, nheads, headdim_v)  \n        ADT, DT, Trap:              (batch, nheads, seqlen)\n        Q_bias, K_bias:             (nheads, 
headdim_qk)\n        Angles:                     (batch, seqlen, nheads, headdim_angles)\n        D:                          (nheads,)\n        Z:                          (batch, seqlen, nheads, headdim_v)\n        Initial SSM State:          (num_sequences, nheads, headdim_v, headdim_qk)\n        Initial K State:            (num_sequences, nheads, headdim_qk)\n        Initial V State:            (num_sequences, nheads, headdim_v)\n        Cu_Seqlens:                 (num_sequences + 1,)\n\n    NOTE: num_sequences = batch for batched mode, or len(cu_seqlens)-1 for varlen mode.\n\n    Compile-time constants:\n        CHUNK_SIZE:                 Chunk size for processing sequences\n        HEADDIM_QK:                 Head dimension for Q/K\n        HEADDIM_V:                  Head dimension for V\n        \n        STORE_SSM_STATES_ADT_OUTV:  Whether to store SSM states, ADT, and Out_v for backward pass\n                                    Set to FALSE for inference-only runs for efficiency\n        HAS_INITIAL_STATES:         Whether input SSM states are provided for state passing\n        RETURN_FINAL_STATES:        Whether to return final SSM states for state passing\n        HAS_D:                      Whether D-skip connection is used\n        HAS_Z:                      Whether Z-gating is used\n        IS_VARLEN:                  Whether the input is a variable-length sequence\n\n    NOTE:\n        1. nheads % nheads_qk == 0\n        2. 
Kernel is optimized for headdim_qk = 128 and headdim_v = 64\n\n    Outputs:\n        Out:                    (batch, seqlen, nheads, headdim_v)\n        Out_v:                  (batch, seqlen, nheads, headdim_v) (if STORE_SSM_STATES_ADT_OUTV)\n        SSM_States:             (batch, nheads, headdim_v, nchunks * headdim_qk) (if STORE_SSM_STATES_ADT_OUTV)\n        DA_CS_Store:            (batch, nheads, seqlen) (if STORE_SSM_STATES_ADT_OUTV)\n        DA_CS_SUM_Store:        (batch, nheads, nchunks) (if STORE_SSM_STATES_ADT_OUTV)\n        Q_store:                (batch, seqlen, nheads, headdim_qk)\n        K_store:                (batch, seqlen, nheads, headdim_qk)\n        QK_store:               (batch, nheads, seqlen)\n        Scale_store:            (batch, nheads, seqlen)\n        Gamma_store:            (batch, nheads, seqlen)\n        Final SSM State:        (num_sequences, nheads, headdim_v, headdim_qk) (if RETURN_FINAL_STATES)\n        Final K State:          (num_sequences, nheads, chunk_size, headdim_qk) (if RETURN_FINAL_STATES)\n    \n    NOTE: \n    1. For batched inputs, nchunks = ceil(seqlen / CHUNK_SIZE) and for varlen inputs, nchunks = num_sequences + \n    total_seqlen//CHUNK_SIZE.\n    2. Final K state has an additional chunk_size dimension since Triton does not allow indexing within a chunk. 
We\n    pick the correct index in the wrapper.\n    \"\"\"\n    pid_head = tl.program_id(0)\n    pid_batch = tl.program_id(1)\n\n    if IS_VARLEN:\n        pid_seq = tl.program_id(2)\n        seq_idx = pid_seq\n\n        cu_seqlen_start = tl.load(Cu_Seqlens + pid_seq * stride_cu_seqlen).to(tl.int32)\n        cu_seqlen_end = tl.load(Cu_Seqlens + (pid_seq + 1) * stride_cu_seqlen).to(tl.int32)\n        total_seqlen = seqlen\n        seqlen = cu_seqlen_end - cu_seqlen_start\n        seq_offset = cu_seqlen_start\n        chunk_offset = pid_seq + cu_seqlen_start // CHUNK_SIZE\n    else:\n        seq_idx = pid_batch\n        seq_offset = 0\n        chunk_offset = 0\n    \n    num_chunks = tl.cdiv(seqlen, CHUNK_SIZE)\n\n    # Compute head index for Q/K (supports Grouped Query Attention)\n    nheads = tl.num_programs(0)\n    head_idx_qk = pid_head // (nheads // nheads_qk)\n\n    # Setup input pointers\n    q_ptr = Q + pid_batch * stride_q_batch + head_idx_qk * stride_q_head + seq_offset * stride_q_seqlen\n    k_ptr = K + pid_batch * stride_k_batch + head_idx_qk * stride_k_head + seq_offset * stride_k_seqlen\n    v_ptr = V + pid_batch * stride_v_batch + pid_head * stride_v_head + seq_offset * stride_v_seqlen\n    adt_ptr = ADT + pid_batch * stride_adt_batch + pid_head * stride_adt_head + seq_offset * stride_adt_seqlen\n    dt_ptr = DT + pid_batch * stride_dt_batch + pid_head * stride_dt_head + seq_offset * stride_dt_seqlen\n    trap_ptr = Trap + pid_batch * stride_trap_batch + pid_head * stride_trap_head + seq_offset * stride_trap_seqlen\n    q_bias_ptr = Q_bias + pid_head * stride_q_bias_head\n    k_bias_ptr = K_bias + pid_head * stride_k_bias_head\n    angle_ptr = Angles + pid_batch * stride_angles_batch + pid_head * stride_angles_head + seq_offset * stride_angles_seqlen\n    \n    if HAS_D:\n        D_ptr = D + pid_head * stride_d_head\n        D_val = tl.load(D_ptr).to(tl.float32)\n    if HAS_Z:\n        z_ptr = Z + pid_batch * stride_z_batch + pid_head * stride_z_head + 
seq_offset * stride_z_seqlen\n    \n    # State pointers use seq_idx (unified for batched and varlen)\n    if HAS_INITIAL_STATES:\n        init_ssm_state_ptr = Initial_SSM_State + seq_idx * stride_init_ssm_state_seq + pid_head * stride_init_ssm_state_head\n        init_k_state_ptr = Initial_K_State + seq_idx * stride_init_k_state_seq + pid_head * stride_init_k_state_head\n        init_v_state_ptr = Initial_V_State + seq_idx * stride_init_v_state_seq + pid_head * stride_init_v_state_head\n\n    # Setup output pointers\n    o_ptr = Out + pid_batch * stride_o_batch + pid_head * stride_o_head + seq_offset * stride_o_seqlen\n    if STORE_SSM_STATES_ADT_OUTV:\n        out_v_ptr = Out_v + pid_batch * stride_o_v_batch + pid_head * stride_o_v_head + seq_offset * stride_o_v_seqlen\n        ssm_states_ptr = SSM_States + pid_batch * stride_ssm_states_batch + pid_head * stride_ssm_states_head + chunk_offset * HEADDIM_QK * stride_ssm_states_qkdim\n        da_cs_store_ptr = DA_CS_Store + pid_batch * stride_da_cs_store_batch + pid_head * stride_da_cs_store_head + seq_offset * stride_da_cs_store_seqlen\n        da_cs_sum_store_ptr = DA_CS_SUM_Store + pid_batch * stride_da_cs_sum_store_batch + pid_head * stride_da_cs_sum_store_head + chunk_offset * stride_da_cs_sum_store_seqlen\n\n    q_store_ptr = Q_store + pid_batch * stride_q_store_batch + pid_head * stride_q_store_head + seq_offset * stride_q_store_seqlen\n    k_store_ptr = K_store + pid_batch * stride_k_store_batch + pid_head * stride_k_store_head + seq_offset * stride_k_store_seqlen\n    qk_store_ptr = QK_store + pid_batch * stride_qk_store_batch + pid_head * stride_qk_store_head + seq_offset * stride_qk_store_seqlen\n    scale_store_ptr = Scale_store + pid_batch * stride_scale_store_batch + pid_head * stride_scale_store_head + seq_offset * stride_scale_store_seqlen\n    gamma_store_ptr = Gamma_store + pid_batch * stride_gamma_store_batch + pid_head * stride_gamma_store_head + seq_offset * stride_gamma_store_seqlen\n\n    if 
RETURN_FINAL_STATES:\n        final_ssm_state_ptr = Final_SSM_State + seq_idx * stride_final_ssm_state_seq + pid_head * stride_final_ssm_state_head\n        final_k_state_ptr = Final_K_State + seq_idx * stride_final_k_state_seq + pid_head * stride_final_k_state_head\n\n    # Create TMA tensor descriptors\n    q_desc = tl.make_tensor_descriptor(\n        q_ptr,\n        shape=[seqlen, headdim_qk],\n        strides=[stride_q_seqlen, stride_q_qkdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_QK],\n    )\n    k_desc = tl.make_tensor_descriptor(\n        k_ptr,\n        shape=[seqlen, headdim_qk],\n        strides=[stride_k_seqlen, stride_k_qkdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_QK],\n    )\n    v_desc = tl.make_tensor_descriptor(\n        v_ptr,\n        shape=[seqlen, headdim_v],\n        strides=[stride_v_seqlen, stride_v_vdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_V],\n    )\n    if HAS_Z:\n        z_desc = tl.make_tensor_descriptor(\n            z_ptr,\n            shape=[seqlen, headdim_v],\n            strides=[stride_z_seqlen, stride_z_vdim],\n            block_shape=[CHUNK_SIZE, HEADDIM_V],\n        )\n    \n    q_store_desc = tl.make_tensor_descriptor(\n        q_store_ptr,\n        shape=[seqlen, headdim_qk],\n        strides=[stride_q_store_seqlen, stride_q_store_qkdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_QK],\n    )\n    k_store_desc = tl.make_tensor_descriptor(\n        k_store_ptr,\n        shape=[seqlen, headdim_qk],\n        strides=[stride_k_store_seqlen, stride_k_store_qkdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_QK],\n    )\n    o_desc = tl.make_tensor_descriptor(\n        o_ptr,\n        shape=[seqlen, headdim_v],\n        strides=[stride_o_seqlen, stride_o_vdim],\n        block_shape=[CHUNK_SIZE, HEADDIM_V],\n    )\n    if STORE_SSM_STATES_ADT_OUTV:\n        ssm_states_desc = tl.make_tensor_descriptor(\n            ssm_states_ptr,\n            shape=[headdim_v, num_chunks * headdim_qk],\n            
strides=[stride_ssm_states_vdim, stride_ssm_states_qkdim],\n            block_shape=[HEADDIM_V, HEADDIM_QK],\n        )\n\n    # Phase 1: Preprocessing - Apply bias, rotary embeddings, compute QK dots.\n    for chunk_idx in range(num_chunks):\n        chunk_start = chunk_idx * CHUNK_SIZE\n        offs_seqlen = chunk_start + tl.arange(0, CHUNK_SIZE)\n        offs_hd = tl.arange(0, HEADDIM_QK)\n        offs_hdr = tl.arange(0, HEADDIM_QK // 2)\n\n        # Load Q and K blocks via TMA\n        q_pre_block = q_desc.load([chunk_start, 0])\n        k_pre_block = k_desc.load([chunk_start, 0])\n        \n        # Load rotary angles\n        angle_block = tl.load(\n            angle_ptr + offs_seqlen[:, None] * stride_angles_seqlen + offs_hdr[None, :] * stride_angles_qkdim,\n            mask=(offs_seqlen[:, None] < seqlen) & (offs_hdr[None, :] < headdim_angles), other=0.0\n        )\n        \n        # Compute shifted gamma and scale\n        dt = tl.load(dt_ptr + offs_seqlen * stride_dt_seqlen, mask=offs_seqlen < seqlen, other=0.0).to(tl.float32)\n        dt_shifted = tl.load(\n            dt_ptr + (offs_seqlen + 1) * stride_dt_seqlen, \n            mask=offs_seqlen + 1 < seqlen, other=0.0).to(tl.float32)\n        trap = tl.load(trap_ptr + offs_seqlen * stride_trap_seqlen, mask=offs_seqlen < seqlen, other=0.0).to(tl.float32)\n        trap = sigmoid_approx(trap)\n        trap_shifted = tl.load(\n            trap_ptr + (offs_seqlen + 1) * stride_trap_seqlen, \n            mask=offs_seqlen + 1 < seqlen, other=0.0).to(tl.float32)\n        trap_shifted = sigmoid_approx(trap_shifted)\n\n        shifted_gamma = dt_shifted * (1 - trap_shifted)\n        gamma = dt * trap\n        scale = shifted_gamma + gamma\n\n        # Store scale and shifted gamma for backward pass\n        tl.store(gamma_store_ptr + offs_seqlen * stride_gamma_store_seqlen, gamma, mask=offs_seqlen < seqlen)\n        tl.store(scale_store_ptr + offs_seqlen * stride_scale_store_seqlen, scale, mask=offs_seqlen < 
seqlen)\n\n        # Add biases to Q and K\n        q_bias_block = tl.load(q_bias_ptr + offs_hd * stride_q_bias_qkdim, offs_hd < headdim_qk)\n        q_pre_block += q_bias_block[None, :]\n        k_bias_block = tl.load(k_bias_ptr + offs_hd * stride_k_bias_qkdim, offs_hd < headdim_qk)\n        k_pre_block += k_bias_block[None, :]\n\n        # Compute QK dot products for skip connection\n        store_qk_dot = tl.dot(\n            q_pre_block * k_pre_block,\n            tl.full([HEADDIM_QK, 1], 1, dtype=q_pre_block.dtype)\n        ).to(q_pre_block.dtype)\n        store_qk_dot = store_qk_dot.reshape(CHUNK_SIZE)\n        store_qk_dot *= gamma\n        tl.store(qk_store_ptr + offs_seqlen * stride_qk_store_seqlen, store_qk_dot, mask=offs_seqlen < seqlen)\n        \n        # Compute rotary embedding cos/sin\n        cos_block = cos_approx(angle_block.to(tl.float32))\n        sin_block = sin_approx(angle_block.to(tl.float32))\n\n        # Apply rotary embeddings to K and scale\n        k0, k1 = tl.split(tl.reshape(k_pre_block, [CHUNK_SIZE, HEADDIM_QK // 2, 2]))\n        ko0 = k0 * cos_block - k1 * sin_block\n        ko1 = k0 * sin_block + k1 * cos_block\n        k_pre_block = tl.reshape(tl.join(ko0, ko1), [CHUNK_SIZE, HEADDIM_QK]).to(k_pre_block.dtype)\n\n        if chunk_idx == num_chunks - 1 and RETURN_FINAL_STATES:\n            tl.store(final_k_state_ptr + tl.arange(0, CHUNK_SIZE)[:, None] * stride_final_k_state_chunk \n                + offs_hd[None, :] * stride_final_k_state_qkdim, \n                k_pre_block,\n                mask=(offs_hd[None, :] < headdim_qk))\n            \n        k_pre_block *= scale[:, None]\n        k_store_desc.store([chunk_start, 0], k_pre_block)\n\n        # Apply rotary embeddings to Q\n        q0, q1 = tl.split(tl.reshape(q_pre_block, [CHUNK_SIZE, HEADDIM_QK // 2, 2]))\n        qo0 = q0 * cos_block - q1 * sin_block\n        qo1 = q0 * sin_block + q1 * cos_block\n        q_pre_block = tl.reshape(tl.join(qo0, qo1), [CHUNK_SIZE, 
HEADDIM_QK]).to(q_pre_block.dtype)\n        q_store_desc.store([chunk_start, 0], q_pre_block)\n\n    # Phase 2: Main computation and output generation.\n    if HAS_INITIAL_STATES:\n        acc_ssm_states = tl.load(\n            init_ssm_state_ptr + tl.arange(0, HEADDIM_V)[:, None] * stride_init_ssm_state_vdim \n            + tl.arange(0, HEADDIM_QK)[None, :] * stride_init_ssm_state_qkdim,\n            mask= (tl.arange(0, HEADDIM_V)[:, None] < headdim_v) & (tl.arange(0, HEADDIM_QK)[None, :] < headdim_qk),\n            other=0.0).to(tl.float32)\n        input_k_state = tl.load(\n            init_k_state_ptr + tl.arange(0, HEADDIM_QK) * stride_init_k_state_qkdim,\n            mask=tl.arange(0, HEADDIM_QK) < headdim_qk, other=0.0).to(tl.float32)\n        input_v_state = tl.load(\n            init_v_state_ptr + tl.arange(0, HEADDIM_V) * stride_init_v_state_vdim,\n            mask=tl.arange(0, HEADDIM_V) < headdim_v, other=0.0).to(tl.float32)\n\n        dt_scalar = tl.load(dt_ptr).to(tl.float32)\n        trap_scalar = tl.load(trap_ptr).to(tl.float32)\n        trap_scalar = sigmoid_approx(trap_scalar)\n        # Step on the SSM states with input K/V states to account for trapezoidal discretization\n        acc_ssm_states += input_v_state[:, None] * input_k_state[None, :] * dt_scalar * (1 - trap_scalar)\n    else:\n        acc_ssm_states = tl.zeros([HEADDIM_V, HEADDIM_QK], dtype=tl.float32)\n\n    if HAS_D:\n        D_val = tl.load(D_ptr).to(tl.float32)\n    else:\n        D_val = 0.0\n\n    for chunk_idx in range(num_chunks):\n        chunk_start = chunk_idx * CHUNK_SIZE\n        offs_seqlen = chunk_start + tl.arange(0, CHUNK_SIZE)\n\n        # Load decay factors (log2 scale for exp2 computation)\n        adt_ptrs = adt_ptr + offs_seqlen * stride_adt_seqlen\n        da = tl.load(adt_ptrs, mask=offs_seqlen < seqlen, other=0.0) * 1.44269504089  # log2(e)\n\n        # Load preprocessed Q, K, V blocks\n        q_block = q_store_desc.load([chunk_start, 0])\n        k_block = 
k_store_desc.load([chunk_start, 0])\n        v_block = v_desc.load([chunk_start, 0])\n        if HAS_Z:\n            z_block = z_desc.load([chunk_start, 0])\n\n        # Compute cumulative decay for this chunk\n        da_cs = tl.cumsum(da)\n        da_cs_last = tl.sum(da)\n        da_cs_rev = da_cs_last - da_cs\n\n        # Store decay info for backward pass\n        if STORE_SSM_STATES_ADT_OUTV:\n            tl.store(da_cs_store_ptr + offs_seqlen * stride_da_cs_store_seqlen, da_cs, mask=offs_seqlen < seqlen)\n            tl.store(da_cs_sum_store_ptr + chunk_idx * stride_da_cs_sum_store_seqlen, da_cs_last)\n\n        # Output contribution from previous state: Q @ SSM_States^T * exp(da_cs)\n        acc_o = tl.dot(q_block, tl.trans(acc_ssm_states).to(q_block.dtype))\n        acc_o *= tl.math.exp2(da_cs)[:, None]\n\n        # Output contribution from current chunk: causal(Q @ K^T * exp(decay)) @ V\n        # NOTE: We compute the (i,i) component using QK dot to prevent non-causal numerical leakage\n        s_block = tl.dot(q_block, tl.trans(k_block))\n        s_block *= tl.math.exp2(tl.minimum((da_cs[:, None] - da_cs[None, :]), 0.0))\n        s_block = tl.where(\n            tl.arange(0, CHUNK_SIZE)[:, None] > tl.arange(0, CHUNK_SIZE)[None, :], \n            s_block, \n            0.0\n        )\n        acc_o += tl.dot(s_block.to(v_block.dtype), v_block)\n\n        # Add D-skip connection and the diagonal (i, i) QK dot contribution\n        qk_dot = tl.load(qk_store_ptr + offs_seqlen * stride_qk_store_seqlen, mask=offs_seqlen < seqlen, other=0.0)\n        acc_o += (D_val + qk_dot)[:, None] * v_block\n\n        if STORE_SSM_STATES_ADT_OUTV:\n            tl.store(out_v_ptr + offs_seqlen[:, None] * stride_o_v_seqlen \n                + tl.arange(0, HEADDIM_V)[None, :] * stride_o_v_vdim, acc_o, \n                mask=(offs_seqlen[:, None] < seqlen) & (tl.arange(0, HEADDIM_V)[None, :] < headdim_v))\n\n        # Apply Z-gating if present\n        if HAS_Z:\n            acc_o = acc_o 
* silu(z_block.to(tl.float32))\n\n        # Store output\n        o_desc.store([chunk_start, 0], acc_o)\n\n        if STORE_SSM_STATES_ADT_OUTV:\n            ssm_states_desc.store([0, chunk_idx * headdim_qk], acc_ssm_states.to(ssm_states_desc.dtype))\n\n        # Update recurrent states\n        scale = tl.math.exp2(da_cs_rev)\n        v_block *= scale[:, None]\n        acc_ssm_states = acc_ssm_states * tl.math.exp2(da_cs_last) + tl.dot(\n            tl.trans(v_block).to(k_block.dtype), k_block\n        )\n\n    # Store final states if requested\n    if RETURN_FINAL_STATES:\n        tl.store(final_ssm_state_ptr + tl.arange(0, HEADDIM_V)[:, None] * stride_final_ssm_state_vdim \n            + tl.arange(0, HEADDIM_QK)[None, :] * stride_final_ssm_state_qkdim, \n            acc_ssm_states,\n            mask=(tl.arange(0, HEADDIM_V)[:, None] < headdim_v) & (tl.arange(0, HEADDIM_QK)[None, :] < headdim_qk))\n\n# Memory Allocator for TMA Descriptors\ndef _alloc_fn(size: int, alignment: int, stream: Optional[int]):\n    \"\"\"Custom allocator for TMA descriptor global memory allocation.\"\"\"\n    return torch.empty(size, device=\"cuda\", dtype=torch.int8)\ntriton.set_allocator(_alloc_fn)\n\ndef mamba3_siso_fwd(\n    Q: torch.Tensor,\n    K: torch.Tensor,\n    V: torch.Tensor,\n    ADT: torch.Tensor,\n    DT: torch.Tensor,\n    Trap: torch.Tensor,\n    Q_bias: torch.Tensor,\n    K_bias: torch.Tensor,\n    Angles: torch.Tensor,\n    D: Optional[torch.Tensor] = None,\n    Z: Optional[torch.Tensor] = None,\n    Initial_States: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,\n    chunk_size: int = 64,\n    store_states_adt_outv: bool = False,\n    return_final_states: bool = False,\n    cu_seqlens: Optional[torch.Tensor] = None,\n):\n    \"\"\"\n    Mamba-3 forward pass wrapper.\n    \n    Args:\n        Q: Query tensor                 (batch, seqlen, nheads_qk, headdim_qk)\n        K: Key tensor                   (batch, seqlen, nheads_qk, headdim_qk)\n        
V: Value tensor                 (batch, seqlen, nheads, headdim_v)\n        ADT: Decay tensor               (batch, nheads, seqlen)\n        DT: DT tensor                   (batch, nheads, seqlen)\n        Trap: Trap tensor               (batch, nheads, seqlen)\n        Q_bias: Query bias              (nheads, headdim_qk)\n        K_bias: Key bias                (nheads, headdim_qk)\n        Angles: Rotary angles           (batch, seqlen, nheads, headdim_angles)\n            - headdim_angles <= headdim_qk // 2 and headdim_angles % 2 == 0.\n        D: Skip connection weight       (nheads,)\n        Z: Gating tensor                (batch, seqlen, nheads, headdim_v)\n            - Applies SiLU gating: out = out * silu(Z).\n        Initial_States: Tuple of (SSM_State, K_State, V_State)\n            SSM State shape:        (num_sequences, nheads, headdim_v, headdim_qk).\n            K state shape:          (num_sequences, nheads, headdim_qk).\n            V state shape:          (num_sequences, nheads, headdim_v).\n                - K state is post bias and rotation and pre scaling\n        cu_seqlens: Cumulative sequence lengths (num_sequences + 1,) for varlen\n        chunk_size: Chunk size for processing\n        store_states_adt_outv: Store intermediate states for backward pass\n        return_final_states: Return final states\n        \n    Returns:\n        Out: Output tensor                      (batch, seqlen, nheads, headdim_v)\n        Out_v: Pre-gate output tensor           (batch, seqlen, nheads, headdim_v) (if store_states_adt_outv)\n        SSM_States: Per-chunk SSM States        (batch, nheads, headdim_v, nchunks * headdim_qk) (if store_states_adt_outv)\n        DA_CS_Store: Cumulative decay           (batch, nheads, seqlen) (if store_states_adt_outv)\n        DA_CS_SUM_Store: Chunk decay sum        (batch, nheads, nchunks) (if store_states_adt_outv)\n        Q_store: Rotated Q+bias                 (batch, seqlen, nheads, headdim_qk) (None if 
store_states_adt_outv=False)\n        K_store: Rotated K+bias                 (batch, seqlen, nheads, headdim_qk) (None if store_states_adt_outv=False)\n        QK_store: QK dot products               (batch, nheads, seqlen) (None if store_states_adt_outv=False)\n        Scale_store: Scale factors              (batch, nheads, seqlen) (None if store_states_adt_outv=False)\n        Gamma_store: Gamma factors              (batch, nheads, seqlen) (None if store_states_adt_outv=False)\n        Final States: Final output states (None if return_final_states=False)\n            Final SSM State                (num_sequences, nheads, headdim_v, headdim_qk)\n            Final K state                  (num_sequences, nheads, headdim_qk)\n            Final V state                  (num_sequences, nheads, headdim_v)\n    \n    Notes:\n        1. For varlen mode: batch must be 1, cu_seqlens required\n        2. num_sequences = batch for batched mode, len(cu_seqlens)-1 for varlen\n        3. nheads % nheads_qk == 0\n        4. 
nchunks = ceil(seqlen / chunk_size) for batched mode, num_sequences + total_seqlen//chunk_size for varlen mode.\n    \n    COMMENT:\n        Design choice to store: Q_store, K_store, QK_store, is primarily an artifact of Triton's\n        lack of programmatic access to shared memory---In the forward pass, we compute, store and then re-load\n        these tensors in shared memory (using TMA) to prevent register spilling.\n        \n    \"\"\"\n    batch, seqlen, nheads_qk, headdim_qk = Q.shape\n    _, _, nheads, headdim_v = V.shape\n    device = Q.device\n    is_varlen = cu_seqlens is not None\n    assert seqlen > 0, \"Sequence length must be greater than 0\"\n\n    # Determine number of sequences\n    if is_varlen:\n        assert batch == 1, \"Varlen mode requires batch=1\"\n        num_sequences = cu_seqlens.shape[0] - 1\n    else:\n        num_sequences = batch\n        cu_seqlens = None\n\n    # Validate shapes\n    assert Q.shape == K.shape, f\"Q and K shape mismatch: {Q.shape} vs {K.shape}\"\n    assert nheads % nheads_qk == 0, f\"nheads ({nheads}) must be divisible by nheads_qk ({nheads_qk})\"\n    assert ADT.shape == (batch, nheads, seqlen)\n    assert DT.shape == (batch, nheads, seqlen)\n    assert Trap.shape == (batch, nheads, seqlen)\n    assert Q_bias.shape == (nheads, headdim_qk)\n    assert K_bias.shape == (nheads, headdim_qk)\n    \n    headdim_angles = Angles.shape[-1]\n    assert headdim_angles <= headdim_qk // 2 and headdim_angles % 2 == 0\n    assert Angles.shape == (batch, seqlen, nheads, headdim_angles)\n    \n    if D is not None:\n        assert D.shape == (nheads,)\n    if Z is not None:\n        assert Z.shape == (batch, seqlen, nheads, headdim_v)\n    \n    if Initial_States is not None:\n        Init_SSM_State, Init_K_State, Init_V_State = Initial_States\n        assert Init_SSM_State.shape == (num_sequences, nheads, headdim_v, headdim_qk), \\\n            f\"Initial_States[0] shape mismatch: expected {(num_sequences, nheads, headdim_v, 
headdim_qk)}, got {Init_SSM_State.shape}\"\n        assert Init_K_State.shape == (num_sequences, nheads, headdim_qk), \\\n            f\"Initial_States[1] shape mismatch: expected {(num_sequences, nheads, headdim_qk)}, got {Init_K_State.shape}\"\n        assert Init_V_State.shape == (num_sequences, nheads, headdim_v), \\\n            f\"Initial_States[2] shape mismatch: expected {(num_sequences, nheads, headdim_v)}, got {Init_V_State.shape}\"\n    else:\n        Init_SSM_State, Init_K_State, Init_V_State = None, None, None\n\n    # Ensure contiguous\n    Q = Q.contiguous() if not Q.is_contiguous() else Q\n    K = K.contiguous() if not K.is_contiguous() else K\n    V = V.contiguous() if not V.is_contiguous() else V\n    ADT = ADT.contiguous() if not ADT.is_contiguous() else ADT\n    DT = DT.contiguous() if not DT.is_contiguous() else DT\n    Trap = Trap.contiguous() if not Trap.is_contiguous() else Trap\n    Q_bias = Q_bias.contiguous() if not Q_bias.is_contiguous() else Q_bias\n    K_bias = K_bias.contiguous() if not K_bias.is_contiguous() else K_bias\n    Angles = Angles.contiguous() if not Angles.is_contiguous() else Angles\n    \n    if D is not None:\n        D = D.contiguous() if not D.is_contiguous() else D\n    if Z is not None:\n        Z = Z.contiguous() if not Z.is_contiguous() else Z\n    if Initial_States is not None:\n        Init_SSM_State = Init_SSM_State.contiguous() if not Init_SSM_State.is_contiguous() else Init_SSM_State\n        Init_K_State = Init_K_State.contiguous() if not Init_K_State.is_contiguous() else Init_K_State\n        Init_V_State = Init_V_State.contiguous() if not Init_V_State.is_contiguous() else Init_V_State\n    \n    # Calculate nchunks\n    if is_varlen:\n        nchunks = num_sequences + seqlen // chunk_size\n    else:\n        nchunks = (seqlen + chunk_size - 1) // chunk_size\n\n    # Allocate output tensors\n    Out = torch.empty((batch, seqlen, nheads, headdim_v), device=device, dtype=V.dtype)\n    if 
store_states_adt_outv:\n        SSM_States = torch.zeros((batch, nheads, headdim_v, nchunks * headdim_qk), device=device, dtype=torch.bfloat16)\n        DA_CS_Store = torch.empty((batch, nheads, seqlen), device=device, dtype=torch.float32)\n        DA_CS_SUM_Store = torch.zeros((batch, nheads, nchunks), device=device, dtype=torch.float32)\n        Out_v = torch.empty((batch, seqlen, nheads, headdim_v), device=device, dtype=V.dtype)\n    else:\n        SSM_States, DA_CS_Store, DA_CS_SUM_Store, Out_v = None, None, None, None\n    \n    Q_store = torch.empty((batch, seqlen, nheads, headdim_qk), device=device, dtype=Q.dtype)\n    K_store = torch.empty((batch, seqlen, nheads, headdim_qk), device=device, dtype=K.dtype)\n    QK_store = torch.empty((batch, nheads, seqlen), device=device, dtype=torch.float32)\n    Scale_store = torch.empty((batch, nheads, seqlen), device=device, dtype=torch.float32)\n    Gamma_store = torch.empty((batch, nheads, seqlen), device=device, dtype=torch.float32)\n    \n    if return_final_states:\n        Final_SSM_State = torch.empty((num_sequences, nheads, headdim_v, headdim_qk), device=device, dtype=torch.float32)\n        Final_K_State = torch.empty((num_sequences, nheads, chunk_size, headdim_qk), device=device, dtype=torch.float32)\n    else:\n        Final_SSM_State, Final_K_State = None, None\n\n    HEADDIM_V = triton.next_power_of_2(headdim_v)\n    HEADDIM_QK = triton.next_power_of_2(headdim_qk)\n\n    # Grid setup\n    if is_varlen:\n        grid = (nheads, batch, num_sequences) # batch = 1\n    else:\n        grid = (nheads, batch)\n\n    mamba3_siso_fwd_kernel[grid](\n        # Inputs\n        Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z, \n        Init_SSM_State, Init_K_State, Init_V_State, cu_seqlens,\n        # Outputs\n        Out, Out_v, SSM_States, DA_CS_Store, DA_CS_SUM_Store, \n        Q_store, K_store, QK_store, Scale_store, Gamma_store,\n        Final_SSM_State, Final_K_State,\n        # Input strides\n        
Q.stride(0), Q.stride(1), Q.stride(2), Q.stride(3),\n        K.stride(0), K.stride(1), K.stride(2), K.stride(3),\n        V.stride(0), V.stride(1), V.stride(2), V.stride(3),\n        ADT.stride(0), ADT.stride(1), ADT.stride(2),\n        DT.stride(0), DT.stride(1), DT.stride(2),\n        Trap.stride(0), Trap.stride(1), Trap.stride(2),\n        Q_bias.stride(0), Q_bias.stride(1),\n        K_bias.stride(0), K_bias.stride(1),\n        Angles.stride(0), Angles.stride(1), Angles.stride(2), Angles.stride(3),\n        D.stride(0) if D is not None else 0,\n        Z.stride(0) if Z is not None else 0,\n        Z.stride(1) if Z is not None else 0,\n        Z.stride(2) if Z is not None else 0,\n        Z.stride(3) if Z is not None else 0,\n        Init_SSM_State.stride(0) if Init_SSM_State is not None else 0,\n        Init_SSM_State.stride(1) if Init_SSM_State is not None else 0,\n        Init_SSM_State.stride(2) if Init_SSM_State is not None else 0,\n        Init_SSM_State.stride(3) if Init_SSM_State is not None else 0,\n        Init_K_State.stride(0) if Init_K_State is not None else 0,\n        Init_K_State.stride(1) if Init_K_State is not None else 0,\n        Init_K_State.stride(2) if Init_K_State is not None else 0,\n        Init_V_State.stride(0) if Init_V_State is not None else 0,\n        Init_V_State.stride(1) if Init_V_State is not None else 0,\n        Init_V_State.stride(2) if Init_V_State is not None else 0,\n        cu_seqlens.stride(0) if cu_seqlens is not None else 0,\n        # Output strides\n        Out.stride(0), Out.stride(1), Out.stride(2), Out.stride(3),\n        Out_v.stride(0) if Out_v is not None else 0,\n        Out_v.stride(1) if Out_v is not None else 0,\n        Out_v.stride(2) if Out_v is not None else 0,\n        Out_v.stride(3) if Out_v is not None else 0,\n        SSM_States.stride(0) if SSM_States is not None else 0,\n        SSM_States.stride(1) if SSM_States is not None else 0,\n        SSM_States.stride(2) if SSM_States is not None else 
0,\n        SSM_States.stride(3) if SSM_States is not None else 0,\n        DA_CS_Store.stride(0) if DA_CS_Store is not None else 0,\n        DA_CS_Store.stride(1) if DA_CS_Store is not None else 0,\n        DA_CS_Store.stride(2) if DA_CS_Store is not None else 0,\n        DA_CS_SUM_Store.stride(0) if DA_CS_SUM_Store is not None else 0,\n        DA_CS_SUM_Store.stride(1) if DA_CS_SUM_Store is not None else 0,\n        DA_CS_SUM_Store.stride(2) if DA_CS_SUM_Store is not None else 0,\n        Q_store.stride(0), Q_store.stride(1), Q_store.stride(2), Q_store.stride(3),\n        K_store.stride(0), K_store.stride(1), K_store.stride(2), K_store.stride(3),\n        QK_store.stride(0), QK_store.stride(1), QK_store.stride(2),\n        Scale_store.stride(0), Scale_store.stride(1), Scale_store.stride(2),\n        Gamma_store.stride(0), Gamma_store.stride(1), Gamma_store.stride(2),\n        Final_SSM_State.stride(0) if Final_SSM_State is not None else 0,\n        Final_SSM_State.stride(1) if Final_SSM_State is not None else 0,\n        Final_SSM_State.stride(2) if Final_SSM_State is not None else 0,\n        Final_SSM_State.stride(3) if Final_SSM_State is not None else 0,\n        Final_K_State.stride(0) if Final_K_State is not None else 0,\n        Final_K_State.stride(1) if Final_K_State is not None else 0,\n        Final_K_State.stride(2) if Final_K_State is not None else 0,\n        Final_K_State.stride(3) if Final_K_State is not None else 0,\n        # Dimensions\n        seqlen, nheads_qk, headdim_qk, headdim_v, headdim_angles,\n        # Compile-time constants\n        chunk_size,\n        HEADDIM_QK,\n        HEADDIM_V,\n        STORE_SSM_STATES_ADT_OUTV=store_states_adt_outv,\n        HAS_INITIAL_STATES=Initial_States is not None,\n        RETURN_FINAL_STATES=return_final_states,\n        HAS_D=D is not None,\n        HAS_Z=Z is not None,\n        IS_VARLEN=is_varlen,\n    )\n\n    Final_States = None\n    if return_final_states:\n        if is_varlen:\n            
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]\n            last_chunk_pos = (seq_lens - 1) % chunk_size \n\n            final_k = Final_K_State[\n                torch.arange(num_sequences, device=device),\n                :, \n                last_chunk_pos,\n                :\n            ]\n            \n            last_token_idx = cu_seqlens[1:] - 1\n            final_v = V[0, last_token_idx]\n        else:\n            k_state_idx = (seqlen - 1) % chunk_size\n            final_k = Final_K_State[:, :, k_state_idx, :]\n            final_v = V[:, -1]\n        \n        Final_States = (Final_SSM_State, final_k, final_v)\n\n    return (Out, Out_v, SSM_States, DA_CS_Store, DA_CS_SUM_Store, \n            Q_store, K_store, QK_store, Scale_store, Gamma_store, Final_States)"
  },
  {
    "path": "mamba_ssm/ops/triton/mamba3/mamba3_siso_step.py",
    "content": "\"\"\"\nMamba-3 Step Kernel.\n\nCopyright (c) 2025, Dao AI Lab, Goombalab\n\"\"\"\n\nfrom typing import Optional, Tuple\nimport math\n\nimport torch\n\nimport triton\nimport triton.language as tl\nfrom mamba_ssm.ops.triton.mamba3.utils import cos_approx, sin_approx, silu, tanh_approx, sigmoid_approx\n\n\n@triton.autotune(\n    configs=[\n        triton.Config({}, num_stages=s, num_warps=w)\n        for s in [1, 2, 3]\n        for w in [2, 4, 8]\n    ],\n    key=[\n        \"HEADDIM_QK\", \"HEADDIM_V\", \"HAS_D\", \"HAS_Z\",],\n)\n@triton.jit\ndef mamba3_siso_step_kernel(\n    # Inputs\n    Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z, Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State,\n    # Outputs\n    Out, Output_Angle_State, Output_SSM_State, Output_K_State,\n    # Input Strides\n    stride_q_batch, stride_q_head, stride_q_qkdim,\n    stride_k_batch, stride_k_head, stride_k_qkdim,\n    stride_v_batch, stride_v_head, stride_v_vdim,\n    stride_adt_batch, stride_adt_head,\n    stride_dt_batch, stride_dt_head,\n    stride_trap_batch, stride_trap_head,\n    stride_q_bias_head, stride_q_bias_qkdim,\n    stride_k_bias_head, stride_k_bias_qkdim,\n    stride_angles_batch, stride_angles_head, stride_angles_qkdim,\n    stride_d_head,\n    stride_z_batch, stride_z_head, stride_z_vdim,\n    stride_angle_state_batch, stride_angle_state_head, stride_angle_state_anglesdim,\n    stride_input_ssm_state_batch, stride_input_ssm_state_head, stride_input_ssm_state_vdim, \n    stride_input_ssm_state_qkdim,\n    stride_input_k_state_batch, stride_input_k_state_head, stride_input_k_state_qkdim,\n    stride_input_v_state_batch, stride_input_v_state_head, stride_input_v_state_vdim,\n    # Output Strides\n    stride_o_batch, stride_o_head, stride_o_vdim,\n    stride_output_angle_state_batch, stride_output_angle_state_head, stride_output_angle_state_anglesdim,\n    stride_output_ssm_state_batch, stride_output_ssm_state_head, 
stride_output_ssm_state_vdim, \n    stride_output_ssm_state_qkdim,\n    stride_output_k_state_batch, stride_output_k_state_head, stride_output_k_state_qkdim,\n    # Dimensions\n    nheads_qk,\n    HEADDIM_QK: tl.constexpr,\n    HEADDIM_V: tl.constexpr,\n    HEADDIM_ANGLES: tl.constexpr,\n    HAS_D: tl.constexpr,\n    HAS_Z: tl.constexpr,\n):\n    \"\"\"\n    Mamba-3 Step kernel.\n\n    Inputs:\n        Q, K:                       (batch, nheads_qk, headdim_qk)\n        V:                          (batch, nheads, headdim_v)  \n        ADT, DT, Trap:              (batch, nheads)\n        Q_bias, K_bias:             (nheads, headdim_qk)\n        Angles:                     (batch, nheads, headdim_angles)\n        D:                          (nheads,)\n        Z:                          (batch, nheads, headdim_v)\n        Out:                        (batch, nheads, headdim_v)\n        SSM_States:                 (batch, nheads, headdim_v, headdim_qk)\n        Input/Output Angle State:   (batch, nheads, headdim_angles)\n        Input/Output SSM State:     (batch, nheads, headdim_v, headdim_qk)\n        Input/Output K State:       (batch, nheads, headdim_qk)\n        Input/Output V State:       (batch, nheads, headdim_v)\n\n    Compile-time constants:\n        HEADDIM_QK:                 Head dimension for Q/K\n        HEADDIM_V:                  Head dimension for V\n        HEADDIM_ANGLES:             Head dimension for Angles\n        HAS_D:                      Whether D-skip connection is used\n        HAS_Z:                      Whether Z-gating is used\n\n    Outputs:\n        Out:                    (batch, nheads, headdim_v)\n        Output_Angle_State:     (batch, nheads, headdim_angles)\n        Output_SSM_State:       (batch, nheads, headdim_v, headdim_qk)\n        Output_K_State:         (batch, nheads, headdim_qk)\n    \"\"\"\n    # Program ID determines which (head, batch) pair this instance processes\n    pid_head = tl.program_id(0)\n    pid_batch = 
tl.program_id(1)\n\n    # Compute head index for Q/K (supports Grouped Query Attention)\n    nheads = tl.num_programs(0)\n    head_idx_qk = pid_head // (nheads // nheads_qk)\n\n    # Setup input pointers\n    q_ptr = Q + pid_batch * stride_q_batch + head_idx_qk * stride_q_head\n    k_ptr = K + pid_batch * stride_k_batch + head_idx_qk * stride_k_head\n    v_ptr = V + pid_batch * stride_v_batch + pid_head * stride_v_head\n    adt_ptr = ADT + pid_batch * stride_adt_batch + pid_head * stride_adt_head\n    dt_ptr = DT + pid_batch * stride_dt_batch + pid_head * stride_dt_head\n    trap_ptr = Trap + pid_batch * stride_trap_batch + pid_head * stride_trap_head\n    q_bias_ptr = Q_bias + pid_head * stride_q_bias_head\n    k_bias_ptr = K_bias + pid_head * stride_k_bias_head\n    angle_ptr = Angles + pid_batch * stride_angles_batch + pid_head * stride_angles_head\n    if HAS_D:\n        D_ptr = D + pid_head * stride_d_head\n        D_val = tl.load(D_ptr).to(tl.float32)\n    if HAS_Z:\n        z_ptr = Z + pid_batch * stride_z_batch + pid_head * stride_z_head\n    input_angle_state_ptr = Input_Angle_State + pid_batch * stride_angle_state_batch + pid_head * stride_angle_state_head\n    input_ssm_state_ptr = Input_SSM_State + pid_batch * stride_input_ssm_state_batch + pid_head * stride_input_ssm_state_head\n    input_k_state_ptr = Input_K_State + pid_batch * stride_input_k_state_batch + pid_head * stride_input_k_state_head\n    input_v_state_ptr = Input_V_State + pid_batch * stride_input_v_state_batch + pid_head * stride_input_v_state_head\n\n    # Setup output pointers\n    o_ptr = Out + pid_batch * stride_o_batch + pid_head * stride_o_head\n    output_angle_state_ptr = Output_Angle_State + pid_batch * stride_output_angle_state_batch + pid_head * stride_output_angle_state_head\n    output_ssm_state_ptr = Output_SSM_State + pid_batch * stride_output_ssm_state_batch + pid_head * stride_output_ssm_state_head\n    output_k_state_ptr = Output_K_State + pid_batch * 
stride_output_k_state_batch + pid_head * stride_output_k_state_head\n\n    PI = 3.141592653589793\n    TWO_PI = 2 * PI\n    offs_qk = tl.arange(0, HEADDIM_QK)\n    offs_v = tl.arange(0, HEADDIM_V)\n    offs_qkr = tl.arange(0, HEADDIM_QK // 2)\n\n    # Load Q and K blocks\n    q_pre_block = tl.load(q_ptr + offs_qk * stride_q_qkdim) # (HEADDIM_QK)\n    k_pre_block = tl.load(k_ptr + offs_qk * stride_k_qkdim) # (HEADDIM_QK)\n\n    # Load Q and K biases\n    q_bias_block = tl.load(q_bias_ptr + offs_qk * stride_q_bias_qkdim) # (HEADDIM_QK)\n    k_bias_block = tl.load(k_bias_ptr + offs_qk * stride_k_bias_qkdim) # (HEADDIM_QK)\n\n    q_pre_block += q_bias_block\n    k_pre_block += k_bias_block\n\n    # Load rotary angles (smaller block, direct load is faster than TMA)\n    dt = tl.load(dt_ptr)\n    angle_block = tl.load(\n        angle_ptr + offs_qkr * stride_angles_qkdim, mask=offs_qkr < HEADDIM_ANGLES, other=0.0\n    ) # (HEADDIM_QK // 2), zero-padded beyond HEADDIM_ANGLES\n    angle_block = tanh_approx(angle_block.to(tl.float32)) * PI * dt\n    angle_state = tl.load(\n        input_angle_state_ptr + offs_qkr * stride_angle_state_anglesdim, mask=offs_qkr < HEADDIM_ANGLES, other=0.0\n    ) # (HEADDIM_QK // 2), zero-padded beyond HEADDIM_ANGLES\n\n    # Accumulate the angle state and wrap it into [0, 2*pi)\n    angle_block += angle_state\n    angle_block -= TWO_PI * tl.floor(angle_block / TWO_PI)\n\n    tl.store(output_angle_state_ptr + offs_qkr * stride_output_angle_state_anglesdim, angle_block, mask=offs_qkr < HEADDIM_ANGLES)\n\n    # Compute cos/sin of the accumulated angles\n    cos_block = cos_approx(angle_block.to(tl.float32))\n    sin_block = sin_approx(angle_block.to(tl.float32))\n\n    # Apply rotary embeddings to Q and K (rotate each adjacent pair of channels)\n    q0, q1 = tl.split(tl.reshape(q_pre_block, [HEADDIM_QK // 2, 2]))\n    qo0 = q0 * cos_block - q1 * sin_block\n    qo1 = q0 * sin_block + q1 * cos_block\n    q_block = tl.reshape(tl.join(qo0, qo1), [HEADDIM_QK]).to(q_pre_block.dtype)\n\n    k0, k1 = tl.split(tl.reshape(k_pre_block, [HEADDIM_QK // 2, 2]))\n    ko0 = k0 * cos_block - k1 * sin_block\n    ko1 = k0 * 
sin_block + k1 * cos_block\n    k_block = tl.reshape(tl.join(ko0, ko1), [HEADDIM_QK]).to(k_pre_block.dtype)\n\n    # Store K state\n    tl.store(output_k_state_ptr + offs_qk * stride_output_k_state_qkdim, k_block)\n\n    # Load previous K, V and current V\n    k_prev_state = tl.load(input_k_state_ptr + offs_qk * stride_input_k_state_qkdim) # (HEADDIM_QK)\n    v_prev_state = tl.load(input_v_state_ptr + offs_v * stride_input_v_state_vdim) # (HEADDIM_V)\n    v_block = tl.load(v_ptr + offs_v * stride_v_vdim) # (HEADDIM_V)\n        \n    # Load ADT, DT and Trap\n    adt = tl.load(adt_ptr) * 1.44269504089\n    trap = tl.load(trap_ptr)\n    trap = sigmoid_approx(trap.to(tl.float32))\n\n    alpha = tl.math.exp2(adt)\n    beta = alpha * dt * (1 - trap)\n    gamma = trap * dt\n\n    ssm_state_diff = (beta * v_prev_state)[:, None] * k_prev_state[None, :] + (gamma * v_block)[:, None] * k_block[None, :]\n\n    # Load previous SSM state\n    ssm_state = tl.load(\n        input_ssm_state_ptr + offs_v[:, None] * stride_input_ssm_state_vdim \n        + offs_qk[None, :] * stride_input_ssm_state_qkdim).to(tl.float32) # (HEADDIM_V, HEADDIM_QK)\n    \n    ssm_state = ssm_state * alpha + ssm_state_diff\n\n    # Store updated SSM state\n    tl.store(output_ssm_state_ptr + offs_v[:, None] * stride_output_ssm_state_vdim \n        + offs_qk[None, :] * stride_output_ssm_state_qkdim, ssm_state)\n\n    # Compute output\n    out = tl.dot(ssm_state.to(tl.bfloat16), q_block.reshape([HEADDIM_QK, 1]).to(tl.bfloat16)) # (HEADDIM_V, 1)\n    out = out.reshape([HEADDIM_V]).to(tl.float32)\n\n    # out = tl.sum(ssm_state * q_block[None, :], axis=1)  # (HEADDIM_V,)\n\n    # Add D-skip connection\n    if HAS_D:\n        out += D_val * v_block\n\n    # Apply Z-gating\n    if HAS_Z:\n        z_block = tl.load(z_ptr + offs_v * stride_z_vdim) # (HEADDIM_V)\n        out = out * silu(z_block.to(tl.float32))\n    \n    # Store output\n    tl.store(o_ptr + offs_v * stride_o_vdim, out)\n\n\n\n\n# Memory Allocator 
for TMA Descriptors\ndef _alloc_fn(size: int, alignment: int, stream: Optional[int]):\n    \"\"\"Custom allocator for TMA descriptor global memory allocation.\"\"\"\n    return torch.empty(size, device=\"cuda\", dtype=torch.int8)\ntriton.set_allocator(_alloc_fn)\n\ndef mamba3_siso_step(\n    Q: torch.Tensor,\n    K: torch.Tensor,\n    V: torch.Tensor,\n    ADT: torch.Tensor,\n    DT: torch.Tensor,\n    Trap: torch.Tensor,\n    Q_bias: torch.Tensor,\n    K_bias: torch.Tensor,\n    Angles: torch.Tensor,\n    D: Optional[torch.Tensor] = None,\n    Z: Optional[torch.Tensor] = None,\n    Input_States: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None,\n):\n    \"\"\"\n    Mamba-3 step wrapper.\n    \n    Inputs:\n        Q: Query tensor             (batch, nheads_qk, headdim_qk).\n        K: Key tensor               (batch, nheads_qk, headdim_qk).\n        V: Value tensor             (batch, nheads, headdim_v).\n        ADT: Decay tensor           (batch, nheads).\n        DT: DT tensor               (batch, nheads).\n        Trap: Trap tensor           (batch, nheads).\n        Q_bias: Query bias          (nheads, headdim_qk).\n        K_bias: Key bias            (nheads, headdim_qk).\n        Angles: Rotary angles       (batch, nheads, headdim_angles)\n            - headdim_angles <= headdim_qk // 2 and headdim_angles % 2 == 0.\n        D: Skip connection weight   (nheads,).\n        Z: Gating tensor of shape   (batch, nheads, headdim_v).\n            - Applies SiLU gating: out = out * silu(Z).\n        Input_States: Tuple of (Angle state, SSM state, K state, V state).\n            Required in practice: the tuple is unpacked unconditionally.\n            Angle state shape:      (batch, nheads, headdim_angles).\n            SSM state shape:        (batch, nheads, headdim_v, headdim_qk).\n            K state shape:          (batch, nheads, headdim_qk).\n            V state shape:          (batch, nheads, headdim_v).\n\n    NOTE: nheads % nheads_qk == 0\n    \n    Outputs:\n        Out: Output tensor                      (batch, 
nheads, headdim_v)\n        Output_States: List of updated states, in the same order as Input_States\n            - Output_Angle_State: Angle State   (batch, nheads, headdim_angles)\n            - Output_SSM_State: SSM State       (batch, nheads, headdim_v, headdim_qk)\n            - K_State: K state                  (batch, nheads, headdim_qk)\n            - V_State: V state (the current V)  (batch, nheads, headdim_v)\n    \"\"\"\n    # Get dimensions\n    batch, nheads_qk, headdim_qk = Q.shape\n    _, nheads, headdim_v = V.shape\n    device = Q.device\n\n    # Validate input shapes\n    assert Q.shape == K.shape, f\"Q and K shape mismatch: {Q.shape} vs {K.shape}\"\n    assert nheads % nheads_qk == 0, f\"nheads ({nheads}) must be divisible by nheads_qk ({nheads_qk})\"\n    assert ADT.shape == (batch, nheads), f\"ADT shape mismatch: expected {(batch, nheads)}, got {ADT.shape}\"\n    assert DT.shape == (batch, nheads), f\"DT shape mismatch: expected {(batch, nheads)}, got {DT.shape}\"\n    assert Trap.shape == (batch, nheads), f\"Trap shape mismatch: expected {(batch, nheads)}, got {Trap.shape}\"\n    assert Q_bias.shape == (nheads, headdim_qk), f\"Q_bias shape mismatch: expected {(nheads, headdim_qk)}, got {Q_bias.shape}\"\n    assert K_bias.shape == (nheads, headdim_qk), f\"K_bias shape mismatch: expected {(nheads, headdim_qk)}, got {K_bias.shape}\"\n    headdim_angles = Angles.shape[-1]\n    assert headdim_angles <= headdim_qk // 2 and headdim_angles % 2 == 0, f\"headdim_angles ({headdim_angles}) must be <= headdim_qk // 2 ({headdim_qk // 2}) and even.\"\n    assert Angles.shape == (batch, nheads, headdim_angles), f\"Angles shape mismatch: expected {(batch, nheads, headdim_angles)}, got {Angles.shape}\"\n    \n    if D is not None:\n        assert D.shape == (nheads,), f\"D shape mismatch: expected {(nheads,)}, got {D.shape}\"\n    if Z is not None:\n        assert Z.shape == (batch, nheads, headdim_v), f\"Z shape mismatch: expected {(batch, nheads, headdim_v)}, got 
{Z.shape}\"\n\n    Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State = Input_States\n    assert Input_Angle_State.shape == (batch, nheads, headdim_angles), f\"Input_Angle_State shape mismatch: expected {(batch, nheads, headdim_angles)}, got {Input_Angle_State.shape}\"\n    assert Input_SSM_State.shape == (batch, nheads, headdim_v, headdim_qk), f\"Input_SSM_State shape mismatch: expected {(batch, nheads, headdim_v, headdim_qk)}, got {Input_SSM_State.shape}\"\n    assert Input_K_State.shape == (batch, nheads, headdim_qk), f\"Input_K_State shape mismatch: expected {(batch, nheads, headdim_qk)}, got {Input_K_State.shape}\"\n    assert Input_V_State.shape == (batch, nheads, headdim_v), f\"Input_V_State shape mismatch: expected {(batch, nheads, headdim_v)}, got {Input_V_State.shape}\"\n        \n    # Ensure all tensors are contiguous\n    Q = Q.contiguous() if not Q.is_contiguous() else Q\n    K = K.contiguous() if not K.is_contiguous() else K\n    V = V.contiguous() if not V.is_contiguous() else V\n    ADT = ADT.contiguous() if not ADT.is_contiguous() else ADT\n    DT = DT.contiguous() if not DT.is_contiguous() else DT\n    Trap = Trap.contiguous() if not Trap.is_contiguous() else Trap\n    Q_bias = Q_bias.contiguous() if not Q_bias.is_contiguous() else Q_bias\n    K_bias = K_bias.contiguous() if not K_bias.is_contiguous() else K_bias\n    Angles = Angles.contiguous() if not Angles.is_contiguous() else Angles\n    \n    if D is not None:\n        D = D.contiguous() if not D.is_contiguous() else D\n    if Z is not None:\n        Z = Z.contiguous() if not Z.is_contiguous() else Z\n    if Input_States is not None:\n        Input_Angle_State = Input_Angle_State.contiguous() if not Input_Angle_State.is_contiguous() else Input_Angle_State\n        Input_SSM_State = Input_SSM_State.contiguous() if not Input_SSM_State.is_contiguous() else Input_SSM_State\n        Input_K_State = Input_K_State.contiguous() if not Input_K_State.is_contiguous() else Input_K_State\n 
       Input_V_State = Input_V_State.contiguous() if not Input_V_State.is_contiguous() else Input_V_State\n\n    # Allocate output tensors\n    Out = torch.empty((batch, nheads, headdim_v), device=device, dtype=V.dtype)\n    Output_Angle_State = torch.empty((batch, nheads, headdim_angles), device=device, dtype=torch.float32)\n    Output_SSM_State = torch.empty((batch, nheads, headdim_v, headdim_qk), device=device, dtype=torch.float32)\n    Output_K_State = torch.empty((batch, nheads, headdim_qk), device=device, dtype=torch.float32)\n    \n    grid = (nheads, batch)\n    mamba3_siso_step_kernel[grid](\n        # Inputs\n        Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z, Input_Angle_State, Input_SSM_State, \n        Input_K_State, Input_V_State,\n        # Outputs\n        Out, Output_Angle_State, Output_SSM_State, Output_K_State, \n        # Input strides\n        Q.stride(0), Q.stride(1), Q.stride(2),\n        K.stride(0), K.stride(1), K.stride(2),\n        V.stride(0), V.stride(1), V.stride(2),\n        ADT.stride(0), ADT.stride(1),\n        DT.stride(0), DT.stride(1),\n        Trap.stride(0), Trap.stride(1),\n        Q_bias.stride(0), Q_bias.stride(1),\n        K_bias.stride(0), K_bias.stride(1),\n        Angles.stride(0), Angles.stride(1), Angles.stride(2),\n        D.stride(0) if D is not None else 0,\n        Z.stride(0) if Z is not None else 0,\n        Z.stride(1) if Z is not None else 0,\n        Z.stride(2) if Z is not None else 0,\n        Input_Angle_State.stride(0), Input_Angle_State.stride(1), Input_Angle_State.stride(2),\n        Input_SSM_State.stride(0), Input_SSM_State.stride(1), Input_SSM_State.stride(2), Input_SSM_State.stride(3),\n        Input_K_State.stride(0), Input_K_State.stride(1), Input_K_State.stride(2),\n        Input_V_State.stride(0), Input_V_State.stride(1), Input_V_State.stride(2),\n        # Output strides\n        Out.stride(0), Out.stride(1), Out.stride(2),\n        Output_Angle_State.stride(0), 
Output_Angle_State.stride(1), Output_Angle_State.stride(2),\n        Output_SSM_State.stride(0), Output_SSM_State.stride(1), Output_SSM_State.stride(2), Output_SSM_State.stride(3),\n        Output_K_State.stride(0), Output_K_State.stride(1), Output_K_State.stride(2),\n        # Dimensions\n        nheads_qk,\n        # Compile-time constants\n        headdim_qk,\n        headdim_v,\n        headdim_angles,\n        HAS_D=D is not None,\n        HAS_Z=Z is not None,\n    )\n\n    Output_States = [Output_Angle_State, Output_SSM_State, Output_K_State, V]\n\n    return Out, Output_States"
  },
  {
    "path": "mamba_ssm/ops/triton/mamba3/utils.py",
    "content": "\"\"\"\nMamba-3 Util Functions.\n\nCopyright (c) 2025, Dao AI Lab, Goombalab\n\"\"\"\n\nimport triton\nimport triton.language as tl\n\n# We use PTX approximations instead of triton built-in functions\n# to trade off a bit of accuracy for much faster speed.\n\n@triton.jit\ndef cos_approx(x):\n    \"\"\"\n    (Fast) Cosine approximation using PTX inline assembly.\n\n    Args:\n        x: Input triton tensor (any shape) in float32\n    Returns:\n        Approximate cosine values in float32\n    \"\"\"\n    return tl.inline_asm_elementwise(\n        \"cos.approx.f32 $0, $1;\",\n        constraints=\"=f,f\",\n        args=[x],\n        dtype=tl.float32,\n        is_pure=True,\n        pack=1,\n    )\n\n\n@triton.jit\ndef sin_approx(x):\n    \"\"\"\n    (Fast) Sine approximation using PTX inline assembly.\n\n    Args:\n        x: Input triton tensor (any shape) in float32\n    Returns:\n        Approximate sine values in float32\n    \"\"\"\n    return tl.inline_asm_elementwise(\n        \"sin.approx.f32 $0, $1;\",\n        constraints=\"=f,f\",\n        args=[x],\n        dtype=tl.float32,\n        is_pure=True,\n        pack=1,\n    )\n\n@triton.jit\ndef tanh_approx(x):\n    \"\"\"\n    (Fast) hyperbolic tangent approximation using PTX inline assembly.\n\n    Args:\n        x: Input triton tensor (any shape) in float32\n    Returns:\n        Approximate tanh values in float32\n    \"\"\"\n    return tl.inline_asm_elementwise(\n        \"tanh.approx.f32 $0, $1;\",\n        constraints=\"=f,f\",\n        args=[x],\n        dtype=tl.float32,\n        is_pure=True,\n        pack=1,\n    )\n\n@triton.jit\ndef sech2_approx(x):\n    \"\"\"\n    (Fast) square of the hyperbolic secant approximation using PTX inline assembly.\n\n    Args:\n        x: Input triton tensor (any shape) in float32\n    Returns:\n        Approximate sech^2 values in float32\n    \"\"\"\n    tanh_x = tl.inline_asm_elementwise(\n        \"tanh.approx.f32 $0, $1;\",\n        
constraints=\"=f,f\",\n        args=[x],\n        dtype=tl.float32,\n        is_pure=True,\n        pack=1,\n    )\n    return 1.0 - tanh_x * tanh_x\n\n@triton.jit\ndef sigmoid_approx(x):\n    \"\"\"\n    Sigmoid, currently computed with the built-in tl.sigmoid.\n\n    Equivalent tanh form: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))\n    A PTX tanh.approx variant of this form is kept below, commented out, for reference.\n\n    Args:\n        x: Input triton tensor (any shape) in float32\n    Returns:\n        Sigmoid values in float32\n    \"\"\"\n    # tanh_half_x = tl.inline_asm_elementwise(\n    #     \"tanh.approx.f32 $0, $1;\",\n    #     constraints=\"=f,f\",\n    #     args=[0.5 * x],\n    #     dtype=tl.float32,\n    #     is_pure=True,\n    #     pack=1,\n    # )\n    # return 0.5 * (1.0 + tanh_half_x)\n    # NOTE: We ended up using the built-in sigmoid for better performance, as the PTX approximation was not faster in this case.\n    return tl.sigmoid(x)\n\n@triton.jit\ndef silu(x):\n    \"\"\"\n    SiLU (Swish) activation function: x * sigmoid(x).\n\n    Equivalent tanh form: silu(x) = 0.5*x * (1 + tanh(0.5*x)) = 0.5*x * tanh(0.5*x) + 0.5*x.\n    Currently computed with the built-in tl.sigmoid; the tanh_approx variant is kept below, commented out, for reference.\n    \n    Args:\n        x: Input triton tensor (any shape) in float32\n    \n    Returns:\n        SiLU activation output in float32\n    \"\"\"\n    # x_half = 0.5 * x\n    # return x_half * tanh_approx(x_half) + x_half\n    # NOTE: We ended up using the built-in sigmoid for better performance, as the PTX approximation was not faster in this case.\n    return x*tl.sigmoid(x)"
  },
  {
    "path": "mamba_ssm/ops/triton/selective_state_update.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\n\n\"\"\"We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this\n\"\"\"\n\nimport math\nimport torch\nimport torch.nn.functional as F\n\nimport triton\nimport triton.language as tl\n\nfrom einops import rearrange, repeat\n\nfrom mamba_ssm.ops.triton.softplus import softplus\n\n\n@triton.heuristics({\"HAS_DT_BIAS\": lambda args: args[\"dt_bias_ptr\"] is not None})\n@triton.heuristics({\"HAS_D\": lambda args: args[\"D_ptr\"] is not None})\n@triton.heuristics({\"HAS_Z\": lambda args: args[\"z_ptr\"] is not None})\n@triton.heuristics({\"HAS_STATE_BATCH_INDICES\": lambda args: args[\"state_batch_indices_ptr\"] is not None})\n@triton.heuristics({\"BLOCK_SIZE_DSTATE\": lambda args: triton.next_power_of_2(args[\"dstate\"])})\n@triton.jit\ndef _selective_scan_update_kernel(\n    # Pointers to matrices\n    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, state_batch_indices_ptr,\n    # Matrix dimensions\n    batch, nheads, dim, dstate, nheads_ngroups_ratio,\n    # Strides\n    stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,\n    stride_x_batch, stride_x_head, stride_x_dim,\n    stride_dt_batch, stride_dt_head, stride_dt_dim,\n    stride_dt_bias_head, stride_dt_bias_dim,\n    stride_A_head, stride_A_dim, stride_A_dstate,\n    stride_B_batch, stride_B_group, stride_B_dstate,\n    stride_C_batch, stride_C_group, stride_C_dstate,\n    stride_D_head, stride_D_dim,\n    stride_z_batch, stride_z_head, stride_z_dim,\n    stride_out_batch, stride_out_head, stride_out_dim,\n    # Meta-parameters\n    DT_SOFTPLUS: tl.constexpr,\n    TIE_HDIM: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr,\n    HAS_DT_BIAS: tl.constexpr,\n    HAS_D: tl.constexpr,\n    HAS_Z: tl.constexpr,\n    HAS_STATE_BATCH_INDICES: tl.constexpr,\n    BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n    pid_m = tl.program_id(axis=0)\n    pid_b = tl.program_id(axis=1)\n    pid_h = 
tl.program_id(axis=2)\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n    out_ptrs = out_ptr + offs_m * stride_out_dim\n\n    if HAS_STATE_BATCH_INDICES:\n        state_batch_indices_ptr += pid_b\n        state_batch_idx = tl.load(state_batch_indices_ptr)\n        # Skip padding tokens\n        if state_batch_idx < 0:\n            tl.store(out_ptrs, 0.0, mask=offs_m < dim)\n            return\n        state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head\n    else:\n        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head\n\n    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head\n    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head\n    if HAS_DT_BIAS:\n        dt_bias_ptr += pid_h * stride_dt_bias_head\n    A_ptr += pid_h * stride_A_head\n    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group\n    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group\n    if HAS_Z:\n        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head\n\n    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)\n    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)\n    x_ptrs = x_ptr + offs_m * stride_x_dim\n    dt_ptrs = dt_ptr + offs_m * stride_dt_dim\n    if HAS_DT_BIAS:\n        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim\n    if HAS_D:\n        D_ptr += pid_h * stride_D_head\n    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)\n    B_ptrs = B_ptr + offs_n * stride_B_dstate\n    C_ptrs = C_ptr + offs_n * stride_C_dstate\n    if HAS_D:\n        D_ptrs = D_ptr + offs_m * stride_D_dim\n    if HAS_Z:\n        z_ptrs = z_ptr + offs_m * stride_z_dim\n\n    state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)\n    x = tl.load(x_ptrs, mask=offs_m < dim, 
other=0.0).to(tl.float32)\n    if not TIE_HDIM:\n        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n        if HAS_DT_BIAS:\n            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n        if DT_SOFTPLUS:\n            dt = tl.where(dt <= 20.0, softplus(dt), dt)\n        A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n        dA = tl.exp(A * dt[:, None])\n    else:\n        dt = tl.load(dt_ptr).to(tl.float32)\n        if HAS_DT_BIAS:\n            dt += tl.load(dt_bias_ptr).to(tl.float32)\n        if DT_SOFTPLUS:\n            dt = tl.where(dt <= 20.0, softplus(dt), dt)\n        A = tl.load(A_ptr).to(tl.float32)\n        dA = tl.exp(A * dt)  # scalar, not a matrix\n\n    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)\n    if HAS_D:\n        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n    if HAS_Z:\n        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n\n    if not TIE_HDIM:\n        dB = B[None, :] * dt[:, None]\n    else:\n        dB = B * dt  # vector of size (dstate,)\n    state = state * dA + dB * x[:, None]\n    tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))\n    out = tl.sum(state * C[None, :], axis=1)\n    if HAS_D:\n        out += x * D\n    if HAS_Z:\n        out *= z * tl.sigmoid(z)\n    tl.store(out_ptrs, out, mask=offs_m < dim)\n\n\ndef selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False,\n                           state_batch_indices=None):\n    \"\"\"\n    Argument:\n        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)\n        x: (batch, dim) or (batch, nheads, dim)\n        dt: (batch, dim) or (batch, nheads, dim)\n        A: (dim, dstate) or (nheads, dim, dstate)\n        B: (batch, dstate) or (batch, ngroups, 
dstate)\n        C: (batch, dstate) or (batch, ngroups, dstate)\n        D: (dim,) or (nheads, dim)\n        z: (batch, dim) or (batch, nheads, dim)\n        dt_bias: (dim,) or (nheads, dim)\n    Return:\n        out: (batch, dim) or (batch, nheads, dim)\n    \"\"\"\n    has_heads = state.dim() > 3\n    if state.dim() == 3:\n        state = state.unsqueeze(1)\n    if x.dim() == 2:\n        x = x.unsqueeze(1)\n    if dt.dim() == 2:\n        dt = dt.unsqueeze(1)\n    if A.dim() == 2:\n        A = A.unsqueeze(0)\n    if B.dim() == 2:\n        B = B.unsqueeze(1)\n    if C.dim() == 2:\n        C = C.unsqueeze(1)\n    if D is not None and D.dim() == 1:\n        D = D.unsqueeze(0)\n    if z is not None and z.dim() == 2:\n        z = z.unsqueeze(1)\n    if dt_bias is not None and dt_bias.dim() == 1:\n        dt_bias = dt_bias.unsqueeze(0)\n    _, nheads, dim, dstate = state.shape\n    batch = x.shape[0]\n    if x.shape != (batch, nheads, dim):\n        print(f\"{state.shape} {x.shape} {batch} {nheads} {dim}\")\n    assert x.shape == (batch, nheads, dim)\n    assert dt.shape == x.shape\n    assert A.shape == (nheads, dim, dstate)\n    ngroups = B.shape[1]\n    assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n    assert B.shape == (batch, ngroups, dstate)\n    assert C.shape == B.shape\n    if D is not None:\n        assert D.shape == (nheads, dim)\n    if z is not None:\n        assert z.shape == x.shape\n    if dt_bias is not None:\n        assert dt_bias.shape == (nheads, dim)\n    if state_batch_indices is not None:\n        assert state_batch_indices.shape == (batch,)\n    out = torch.empty_like(x)\n    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)\n    z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))\n    # We don't want autotune since it will overwrite the state\n    # We instead tune by hand.\n    BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16\n                               
else ((16, 4) if dstate <= 32 else\n                                     ((8, 4) if dstate <= 64 else\n                                      ((4, 4) if dstate <= 128 else\n                                       ((4, 8))))))\n    # dt_bias is optional, so only inspect its stride when it is provided.\n    tie_hdim = (A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0\n                and (dt_bias is None or dt_bias.stride(-1) == 0))\n    with torch.cuda.device(x.device.index):\n        _selective_scan_update_kernel[grid](\n            state, x, dt, dt_bias, A, B, C, D, z, out, state_batch_indices,\n            batch, nheads, dim, dstate, nheads // ngroups,\n            state.stride(0), state.stride(1), state.stride(2), state.stride(3),\n            x.stride(0), x.stride(1), x.stride(2),\n            dt.stride(0), dt.stride(1), dt.stride(2),\n            # Pass two zero strides when dt_bias is absent so the positional arguments stay aligned.\n            *((dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0)),\n            A.stride(0), A.stride(1), A.stride(2),\n            B.stride(0), B.stride(1), B.stride(2),\n            C.stride(0), C.stride(1), C.stride(2),\n            # Same for D: unpacking a bare 0 would raise a TypeError.\n            *((D.stride(0), D.stride(1)) if D is not None else (0, 0)),\n            z_strides[0], z_strides[1], z_strides[2],\n            out.stride(0), out.stride(1), out.stride(2),\n            dt_softplus,\n            tie_hdim,\n            BLOCK_SIZE_M,\n            num_warps=num_warps,\n        )\n    if not has_heads:\n        out = out.squeeze(1)\n    return out\n\n\ndef selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):\n    \"\"\"\n    Argument:\n        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)\n        x: (batch, dim) or (batch, nheads, dim)\n        dt: (batch, dim) or (batch, nheads, dim)\n        A: (dim, dstate) or (nheads, dim, dstate)\n        B: (batch, dstate) or (batch, ngroups, dstate)\n        C: (batch, dstate) or (batch, ngroups, dstate)\n        D: (dim,) or (nheads, dim)\n        z: (batch, dim) or (batch, nheads, dim)\n        dt_bias: (dim,) or (nheads, dim)\n    Return:\n        out: (batch, dim) or 
(batch, nheads, dim)\n    \"\"\"\n    has_heads = state.dim() > 3\n    if state.dim() == 3:\n        state = state.unsqueeze(1)\n    if x.dim() == 2:\n        x = x.unsqueeze(1)\n    if dt.dim() == 2:\n        dt = dt.unsqueeze(1)\n    if A.dim() == 2:\n        A = A.unsqueeze(0)\n    if B.dim() == 2:\n        B = B.unsqueeze(1)\n    if C.dim() == 2:\n        C = C.unsqueeze(1)\n    if D is not None and D.dim() == 1:\n        D = D.unsqueeze(0)\n    if z is not None and z.dim() == 2:\n        z = z.unsqueeze(1)\n    if dt_bias is not None and dt_bias.dim() == 1:\n        dt_bias = dt_bias.unsqueeze(0)\n    batch, nheads, dim, dstate = state.shape\n    assert x.shape == (batch, nheads, dim)\n    assert dt.shape == x.shape\n    assert A.shape == (nheads, dim, dstate)\n    ngroups = B.shape[1]\n    assert nheads % ngroups == 0, \"nheads must be divisible by ngroups\"\n    assert B.shape == (batch, ngroups, dstate)\n    assert C.shape == B.shape\n    if D is not None:\n        assert D.shape == (nheads, dim)\n    if z is not None:\n        assert z.shape == x.shape\n    if dt_bias is not None:\n        assert dt_bias.shape == (nheads, dim)\n        dt = dt + dt_bias\n    dt = F.softplus(dt) if dt_softplus else dt\n    dA = torch.exp(rearrange(dt, \"b h d -> b h d 1\") * A)  # (batch, nheads, dim, dstate)\n    B = repeat(B, \"b g n -> b (g h) n\", h=nheads // ngroups)  # (batch, nheads, dstate)\n    C = repeat(C, \"b g n -> b (g h) n\", h=nheads // ngroups)  # (batch, nheads, dstate)\n    dB = rearrange(dt, \"b h d -> b h d 1\") * rearrange(B, \"b h n -> b h 1 n\")  # (batch, nheads, dim, dstate)\n    state.copy_(state * dA + dB * rearrange(x, \"b h d -> b h d 1\"))  # (batch, nheads, dim, dstate)\n    out = torch.einsum(\"bhdn,bhn->bhd\", state.to(C.dtype), C)\n    if D is not None:\n        out += (x * D).to(out.dtype)\n    out = (out if z is None else out * F.silu(z)).to(x.dtype)\n    if not has_heads:\n        out = out.squeeze(1)\n    return out\n"
  },
  {
    "path": "mamba_ssm/ops/triton/softplus.py",
    "content": "import triton\nimport triton.language as tl\nfrom packaging import version\n\nTRITON3 = version.parse(triton.__version__) >= version.parse(\"3.0.0\")\n\n\nif TRITON3:\n    @triton.jit\n    def softplus(dt):\n        return tl.math.log(tl.math.exp(dt) + 1)\nelse:\n    @triton.jit\n    def softplus(dt):\n        return tl.math.log1p(tl.exp(dt))"
  },
  {
    "path": "mamba_ssm/ops/triton/ssd_bmm.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\n\n\"\"\"We want triton==2.1.0 or 2.2.0 for this\n\"\"\"\n\nimport math\nimport torch\nimport torch.nn.functional as F\n\nimport triton\nimport triton.language as tl\n\nfrom einops import rearrange, repeat\n\nfrom mamba_ssm.utils.determinism import autotune_configs\n\n\ndef init_to_zero(names):\n    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n    ]),\n    key=['chunk_size', 'K', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _bmm_chunk_fwd_kernel(\n    # Pointers to matrices\n    a_ptr, b_ptr, out_ptr, seq_idx_ptr,\n    # Matrix dimensions\n    seqlen, chunk_size, K, ngroups,\n    stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n    stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,\n    stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,\n    stride_seq_idx_batch, 
stride_seq_idx_seqlen,\n    # Meta-parameters\n    IS_CAUSAL: tl.constexpr,\n    dot_dtype: tl.constexpr,\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n    pid_b = tl.program_id(axis=1)\n    pid_ch = tl.program_id(axis=2)\n    pid_c = pid_ch // ngroups\n    pid_h = pid_ch - pid_c * ngroups\n    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    if IS_CAUSAL:\n        if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:\n            return\n    a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)\n    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n        a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)\n        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)\n        acc += tl.dot(a, b)\n        a_ptrs += BLOCK_SIZE_K * stride_ak\n        b_ptrs += BLOCK_SIZE_K * stride_bk\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 
BLOCK_SIZE_N)\n    if HAS_SEQ_IDX:\n        chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n        seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)\n        acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)\n    out = acc.to(out_ptr.dtype.element_ty)\n\n    out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head\n    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)\n    tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),\n    ]),\n    key=['chunk_size', 'K'],\n)\n@triton.jit\ndef _bmm_chunk_bwd_kernel(\n    # Pointers to matrices\n    a_ptr, dout_ptr, db_ptr, res_ptr,\n    # Matrix 
dimensions\n    seqlen, chunk_size, K, ngroups,\n    stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n    stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,\n    stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,\n    stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,\n    # Meta-parameters\n    dot_dtype: tl.constexpr,\n    HAS_RESIDUAL: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,\n):\n    pid_b = tl.program_id(axis=1)\n    pid_ch = tl.program_id(axis=2)\n    pid_c = pid_ch // ngroups\n    pid_h = pid_ch - pid_c * ngroups\n    num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n\n    a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n    dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_cs = tl.arange(0, BLOCK_SIZE_CS)\n    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)\n    a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):\n        dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)\n        a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)\n        acc += tl.dot(dout, a)\n        dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m\n        a_ptrs 
+= BLOCK_SIZE_CS * stride_a_seqlen\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    if HAS_RESIDUAL:\n        res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head\n        res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)\n        res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)\n        acc += res\n    db = acc.to(db_ptr.dtype.element_ty)\n\n    db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head\n    db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)\n    tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))\n\n\ndef _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):\n    \"\"\"\n    Argument:\n        a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)\n        b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)\n        seq_idx: (batch, seqlen) or None. 
out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.\n        causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are\n            guaranteed to be correct.\n    Return:\n        out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)\n    \"\"\"\n    # Check constraints.\n    has_groups = a.dim() == 4\n    if not has_groups:\n        batch, seqlen, k = a.shape\n    else:\n        batch, seqlen, ngroups, k = a.shape\n    assert b.shape == a.shape\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    if a.stride(-1) != 1 and a.stride(1) != 1:\n        a = a.contiguous()\n    if b.stride(-1) != 1 and b.stride(1) != 1:\n        b = b.contiguous()\n    nchunks = math.ceil(seqlen / chunk_size)\n    # Allocates output.\n    out_dtype = a.dtype if output_dtype is None else output_dtype\n    out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),\n                      device=a.device, dtype=out_dtype)\n    dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else\n                 (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))\n    grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),\n                    batch, nchunks if not has_groups else nchunks * ngroups)\n    with torch.cuda.device(a.device.index):\n        _bmm_chunk_fwd_kernel[grid](\n            a, b, out, seq_idx,\n            seqlen, chunk_size, k, ngroups if has_groups else 1,\n            a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n            b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),\n            out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),\n            
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n            causal,\n            dot_dtype,\n            HAS_SEQ_IDX=seq_idx is not None,\n        )\n    return out\n\n\ndef _bmm_chunk_bwd(a, dout, residual=None, out=None):\n    \"\"\"\n    Argument:\n        a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)\n        dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)\n        residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)\n    Return:\n        out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)\n\n    If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be\n    zeroed out before calling this function.\n    \"\"\"\n    # Check constraints.\n    has_groups = a.dim() == 4\n    if not has_groups:\n        batch, seqlen, k = a.shape\n    else:\n        batch, seqlen, ngroups, k = a.shape\n    nchunks, chunk_size = dout.shape[1], dout.shape[-1]\n    if a.stride(-1) != 1 and a.stride(-2) != 1:\n        a = a.contiguous()\n    if dout.stride(-1) != 1 and dout.stride(-2) != 1:\n        dout = dout.contiguous()\n    if residual is not None:\n        assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)\n        if residual.stride(-1) != 1 and residual.stride(1) != 1:\n            residual = residual.contiguous()\n    # Allocates output.\n    if out is not None:\n        assert out.shape == a.shape\n        assert out.stride(-1) == 1 or out.stride(1) == 1\n    else:\n        out = torch.empty_like(a)\n    dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else\n                 (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))\n    grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,\n                    nchunks if not has_groups else nchunks * ngroups)\n    
residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),\n                         residual.stride(-1))\n                        if residual is not None else (0, 0, 0, 0))\n    with torch.cuda.device(a.device.index):\n        _bmm_chunk_bwd_kernel[grid](\n            a, dout, out, residual,\n            seqlen, chunk_size, k, ngroups if has_groups else 1,\n            a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n            dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),\n            out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),\n            residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],\n            dot_dtype,\n            HAS_RESIDUAL=residual is not None,\n        )\n    return out\n"
  },
  {
    "path": "mamba_ssm/ops/triton/ssd_chunk_scan.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\n\n\"\"\"We want triton==2.1.0 or 2.2.0 for this\n\"\"\"\n\nimport math\nfrom packaging import version\n\nimport torch\nimport torch.nn.functional as F\n\nimport triton\nimport triton.language as tl\n\nfrom einops import rearrange, repeat\n\nfrom mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd\nfrom mamba_ssm.utils.determinism import (\n    alloc_tile_workspace,\n    finalize_tile_workspace,\n    use_deterministic_mode,\n    autotune_configs,\n)\n\nTRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')\n\n\ndef init_to_zero(names):\n    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 
32}, num_stages=4, num_warps=2),\n    ]),\n    key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _chunk_scan_fwd_kernel(\n    # Pointers to matrices\n    cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,\n    # Matrix dimensions\n    chunk_size, hdim, dstate,\n    batch, seqlen, nheads_ngroups_ratio,\n    # Strides\n    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,\n    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,\n    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_seq_idx_batch, stride_seq_idx_seqlen,\n    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n    stride_D_head,\n    # Meta-parameters\n    IS_CAUSAL: tl.constexpr,\n    HAS_D: tl.constexpr,\n    D_HAS_HDIM: tl.constexpr,\n    HAS_Z: tl.constexpr,\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n    BLOCK_SIZE_DSTATE: tl.constexpr,\n    IS_TRITON_22: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + 
pid_h * stride_dt_head\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n    C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head\n    prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    if HAS_SEQ_IDX:\n        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)\n        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n    # Without the if (pid_c > -1), with Triton 2.1.0, I get\n    # Assertion `!(srcMmaLayout && dstMmaLayout) && \"Unexpected mma -> mm a layout conversion\"' failed.\n    # With Triton 2.2.0, this works\n    if IS_TRITON_22 or pid_c > -1:\n        # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128\n        offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)\n        C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)\n        prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)\n        if not HAS_SEQ_IDX:\n            scale_m = tl.exp(dA_cs_m)\n        else:\n            scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)\n        if BLOCK_SIZE_DSTATE <= 128:\n            C = 
tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)\n            prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)\n            prev_states = prev_states.to(C_ptr.dtype.element_ty)\n            acc = tl.dot(C, prev_states) * scale_m[:, None]\n        else:\n            for k in range(0, dstate, BLOCK_SIZE_K):\n                C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0)\n                # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)\n                prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)\n                prev_states = prev_states.to(C_ptr.dtype.element_ty)\n                acc += tl.dot(C, prev_states)\n                C_ptrs += BLOCK_SIZE_K\n                prev_states_ptrs += BLOCK_SIZE_K\n            acc *= scale_m[:, None]\n\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)\n    x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n    dt_ptrs = dt_ptr + offs_k * stride_dt_csize\n    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n    K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)\n    for k in range(0, K_MAX, BLOCK_SIZE_K):\n        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32)\n        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)\n        # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].\n        # So we don't need masking wrt seq_idx here.\n        # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))\n        cb *= tl.exp(tl.minimum((dA_cs_m[:, 
None] - dA_cs_k[None, :]), 0.0))\n        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)\n        cb *= dt_k\n        if IS_CAUSAL:\n            mask = offs_m[:, None] >= k + offs_k[None, :]\n            cb = tl.where(mask, cb, 0.0)\n        cb = cb.to(x_ptr.dtype.element_ty)\n        x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0)\n        acc += tl.dot(cb, x)\n        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k\n        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen\n        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize\n        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n    offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n    if HAS_D:\n        if D_HAS_HDIM:\n            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n        else:\n            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n        x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),\n                             mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n        acc += x_residual * D\n\n    if HAS_Z:\n        out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head\n        out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :])\n        tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))\n\n        z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head\n        z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])\n        z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)\n   
     acc *= z * tl.sigmoid(z)\n\n    out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head\n    out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim)\n    tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8),\n        triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8),\n    ]),\n    key=['chunk_size', 'hdim', 'dstate'],\n)\n@triton.jit\ndef _chunk_scan_fwd_kernel_wip(\n    # Pointers to matrices\n    cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, B_ptr, prev_states_ptr, D_ptr,\n    # Matrix dimensions\n    chunk_size, hdim, dstate,\n    batch, seqlen, nheads_ngroups_ratio,\n    # Strides\n    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,\n    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,\n    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_seq_idx_batch, stride_seq_idx_seqlen,\n    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n    stride_B_batch, stride_B_seqlen, stride_B_head, stride_B_dstate,\n    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n    stride_D_head,\n  
  # Meta-parameters\n    HAS_D: tl.constexpr,\n    D_HAS_HDIM: tl.constexpr,\n    HAS_Z: tl.constexpr,\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    pid_n = tl.program_id(axis=0)\n    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n    C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head\n    B_ptr += pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head\n    prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n    out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head\n\n    offs_m = tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE)\n\n    C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)\n    B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate)\n    prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)\n    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n    cb_ptrs = cb_ptr + (offs_m[:, None] * 
stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k)\n    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n    dt_ptrs = dt_ptr + offs_m * stride_dt_csize\n    out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)\n\n    prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)\n    # if pid_c == 0:\n    #     if pid_b == 0:\n    #         if pid_h == 0:\n    #             tl.device_print(\"\", prev_states)\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n    # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n    # scale_m = tl.exp(dA_cs_m)\n    # C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)\n    # acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None]\n    # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32)\n    # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :]))\n    # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n    # cb *= dt_m\n    # mask = offs_m[:, None] >= offs_m[None, :]\n    # cb = tl.where(mask, cb, 0.0)\n    # cb = cb.to(x_ptr.dtype.element_ty)\n    # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0)\n    # acc += tl.dot(cb, x)\n    # if HAS_D:\n    #     if D_HAS_HDIM:\n    #         D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n    #     else:\n    #         D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n    #     acc += x.to(tl.float32) * D\n    # tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))\n\n    for start_m in range(0, chunk_size_limit, BLOCK_SIZE_M):\n        start_m = 
tl.multiple_of(start_m, BLOCK_SIZE_M)\n        dA_cs_m = tl.load(dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32)\n        if HAS_SEQ_IDX:\n            seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)\n            seq_idx_m = tl.load(seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit - start_m, other=-1)\n        if not HAS_SEQ_IDX:\n            scale_m = tl.exp(dA_cs_m)\n        else:\n            scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)\n        C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0)\n        acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None]\n        # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32)\n        # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :]))\n        dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32)\n        # cb *= dt_m\n        # mask = offs_m[:, None] >= offs_m[None, :]\n        # cb = tl.where(mask, cb, 0.0)\n        # cb = cb.to(x_ptr.dtype.element_ty)\n        x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0)\n        # acc += tl.dot(cb, x)\n\n        if HAS_D:\n            if D_HAS_HDIM:\n                D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n            else:\n                D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n            acc += x.to(tl.float32) * D\n\n        # if HAS_Z:\n        #     out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head\n        #     out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, 
:])\n        #     tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))\n\n        #     z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head\n        #     z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])\n        #     z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)\n        #     acc *= z * tl.sigmoid(z)\n\n        tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim))\n\n        # TODO: this is not correct, and quite a bit slower\n        if start_m + BLOCK_SIZE_M < chunk_size_limit:\n            # B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32)\n            B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0)\n            dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32)\n            # TODO: seq_idx\n            scale = tl.exp((dA_cs_last - dA_cs_m)) * dt_m\n            # B *= scale\n            B = B.to(x_ptr.dtype.element_ty)\n            tmp = tl.dot(B, x)\n            prev_states += tmp.to(prev_states.dtype)\n\n        C_ptrs += BLOCK_SIZE_M * stride_C_seqlen\n        B_ptrs += BLOCK_SIZE_M * stride_B_seqlen\n        cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k\n        x_ptrs += BLOCK_SIZE_M * stride_x_seqlen\n        dt_ptrs += BLOCK_SIZE_M * stride_dt_csize\n        out_ptrs += BLOCK_SIZE_M * stride_out_seqlen\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 32}),\n        triton.Config({'BLOCK_SIZE_M': 64}),\n        triton.Config({'BLOCK_SIZE_M': 128}),\n        triton.Config({'BLOCK_SIZE_M': 256}),\n    ]),\n    
key=[\"chunk_size\", \"hdim\"],\n)\n@triton.jit\ndef _chunk_scan_bwd_dz_kernel(\n    # Pointers to matrices\n    dout_ptr, out_ptr, z_ptr, x_ptr, D_ptr, outz_ptr, dz_ptr, dout_x_ptr, dD_ptr, ddA_cumsum_ptr,\n    # Matrix dimensions\n    chunk_size, hdim,\n    batch, seqlen,\n    # Strides\n    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,\n    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_D_head,\n    stride_outz_batch, stride_outz_seqlen, stride_outz_head, stride_outz_hdim,\n    stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_hdim,\n    stride_doutx_batch, stride_doutx_seqlen, stride_doutx_head, stride_doutx_hdim,\n    stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,\n    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,\n    # Meta-parameters\n    HAS_D: tl.constexpr,\n    D_HAS_HDIM: tl.constexpr,\n    HAS_DDACS: tl.constexpr,\n    RECOMPUTE_OUTPUT: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    pid_m = tl.program_id(axis=0)\n\n    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n    dout_x_ptr += pid_b * stride_doutx_batch + pid_c * chunk_size * stride_doutx_seqlen + pid_h * stride_doutx_head\n    out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head\n    z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head\n    dz_ptr += pid_b * stride_dz_batch + pid_c * chunk_size * stride_dz_seqlen + pid_h * stride_dz_head\n    if RECOMPUTE_OUTPUT:\n        outz_ptr += pid_b * 
stride_outz_batch + pid_c * chunk_size * stride_outz_seqlen + pid_h * stride_outz_head\n    if HAS_DDACS:\n        ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head\n    if HAS_D:\n        x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n        dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = tl.arange(0, BLOCK_SIZE_N)\n    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n    dout_x_ptrs = dout_x_ptr + (offs_m[:, None] * stride_doutx_seqlen + offs_n[None, :] * stride_doutx_hdim)\n    out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)\n    z_ptrs = z_ptr + (offs_m[:, None] * stride_z_seqlen + offs_n[None, :] * stride_z_hdim)\n    dz_ptrs = dz_ptr + (offs_m[:, None] * stride_dz_seqlen + offs_n[None, :] * stride_dz_hdim)\n    if RECOMPUTE_OUTPUT:\n        outz_ptrs = outz_ptr + (offs_m[:, None] * stride_outz_seqlen + offs_n[None, :] * stride_outz_hdim)\n    if HAS_D:\n        x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n        if D_HAS_HDIM:\n            dD_ptrs = dD_ptr + offs_n * stride_dD_hdim\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n    out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n    z = tl.load(z_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n    z_sigmoid = tl.sigmoid(z)\n    if RECOMPUTE_OUTPUT:\n        outz = out * z * z_sigmoid\n        tl.store(outz_ptrs, outz, mask=(offs_m[:, None] < 
chunk_size_limit) & (offs_n[None, :] < hdim))\n    dz = dout * out * z_sigmoid * (1 + z * (1 - z_sigmoid))\n    tl.store(dz_ptrs, dz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))\n    dout *= z * z_sigmoid\n    tl.store(dout_x_ptrs, dout, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))\n    if HAS_D:\n        x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n        if D_HAS_HDIM:\n            dD = tl.sum(dout * x, axis=0)\n            tl.store(dD_ptrs, dD, mask=offs_n < hdim)\n            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n        else:\n            dD = tl.sum(dout * x)\n            tl.store(dD_ptr, dD)\n            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n        out -= x * D\n    if HAS_DDACS:\n        ddA_cs = tl.sum(dout * out, axis=1)\n        tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size)\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, 
num_warps=2),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n    ]),\n    key=['hdim', 'dstate', 'chunk_size'],\n)\n@triton.jit\ndef _chunk_scan_bwd_dstates_kernel(\n    # Pointers to matrices\n    dout_ptr, c_ptr, dprev_states_ptr, dA_cumsum_ptr, seq_idx_ptr,\n    # Matrix dimensions\n    hdim, dstate, chunk_size,\n    batch, seqlen, nchunks, nheads_ngroups_ratio,\n    # Strides\n    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n    stride_c_batch, stride_c_seqlen, stride_c_head, stride_c_dstate,\n    stride_dprev_states_batch, stride_dprev_states_chunk, stride_dprev_states_head, stride_dprev_states_hdim, stride_dprev_states_dstate,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_seq_idx_batch, stride_seq_idx_seqlen,\n    # Meta-parameters\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    c_ptr += pid_b * stride_c_batch + pid_c * chunk_size * stride_c_seqlen + (pid_h // nheads_ngroups_ratio) * stride_c_head\n    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + 
offs_k[None, :] * stride_dout_seqlen)\n    c_ptrs = c_ptr + (offs_n[None, :] * stride_c_dstate + offs_k[:, None] * stride_c_seqlen)\n    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n    if HAS_SEQ_IDX:\n        seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    if HAS_SEQ_IDX:\n        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)\n    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):\n        dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0).to(tl.float32)\n        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)\n        if not HAS_SEQ_IDX:\n            scale_k = tl.exp(dA_cs_k)\n        else:\n            seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)\n            scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0)\n        dout = (dout * scale_k).to(dout_ptr.dtype.element_ty)\n        c = tl.load(c_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(dout_ptr.dtype.element_ty)\n        acc += tl.dot(dout, c)\n        dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen\n        c_ptrs += BLOCK_SIZE_K * stride_c_seqlen\n        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n        if HAS_SEQ_IDX:\n            seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen\n    out = acc.to(dprev_states_ptr.dtype.element_ty)\n\n    dprev_states_ptr += pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    dprev_states_ptrs = dprev_states_ptr + (offs_m[:, None] * stride_dprev_states_hdim + 
offs_n[None, :] * stride_dprev_states_dstate)\n    tl.store(dprev_states_ptrs, out, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate))\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n    ]),\n    key=['chunk_size', 'dstate', 'hdim'],\n)\n@triton.jit\ndef _chunk_scan_bwd_dc_kernel(\n    # Pointers to matrices\n    dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr,\n    dc_ptr, ddA_cumsum_ptr,\n    # Matrix dimensions\n    chunk_size, dstate, hdim,\n    batch, seqlen, nheads, nheads_per_program, ngroups,\n    # Strides\n    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n    stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate,\n    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n    stride_dA_cs_batch, stride_dA_cs_chunk, 
stride_dA_cs_head, stride_dA_cs_csize,\n    stride_seq_idx_batch, stride_seq_idx_seqlen,\n    stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate,\n    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile,\n    # Meta-parameters\n    HAS_DDA_CS: tl.constexpr,\n    HAS_SEQ_IDX: tl.constexpr,\n    DETERMINISTIC_REDUCTION: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_sg = tl.program_id(axis=2)\n    pid_s = pid_sg // ngroups\n    pid_g = pid_sg - pid_s * ngroups\n    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head\n    dc_ptr += pid_b * stride_dc_batch + pid_c * chunk_size * stride_dc_seqlen + pid_g * stride_dc_group + pid_s * stride_dc_split\n    prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_prev_states_head\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head\n    if HAS_DDA_CS:\n        C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head\n        ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_n * stride_ddA_tile\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)\n    prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim)\n    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize\n    if HAS_DDA_CS:\n        C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate)\n        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    if HAS_DDA_CS:\n        c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n    if HAS_SEQ_IDX:\n        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)\n        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n    nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)\n    for h in range(nheads_iter):\n        dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)\n        prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)\n        prev_states = prev_states.to(dout_ptrs.dtype.element_ty)\n        dc = tl.dot(dout, prev_states)\n        dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n        if not HAS_SEQ_IDX:\n            scale = tl.exp(dA_cs_m)\n        else:\n            scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)\n        dc *= scale[:, None]\n        if HAS_DDA_CS:\n            ddA_cs = tl.sum(dc * c, axis=1)\n            if DETERMINISTIC_REDUCTION:\n                
tl.store(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)\n            else:\n                tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)\n        acc += dc\n        dout_ptrs += stride_dout_head\n        prev_states_ptrs += stride_prev_states_head\n        dA_cumsum_ptrs += stride_dA_cs_head\n        if HAS_DDA_CS:\n            ddA_cumsum_ptrs += stride_ddA_cs_head\n    # if HAS_SEQ_IDX:\n    #     seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)\n    #     seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n    #     acc = tl.where(seq_idx_m[:, None] == seq_idx_prev, acc, 0.0)\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    dc_ptrs = dc_ptr + (offs_m[:, None] * stride_dc_seqlen + offs_n[None, :] * stride_dc_dstate)\n    tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))\n\n\n_CHUNK_SCAN_BWD_DC_MIN_BLOCK_N = min(\n    cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_scan_bwd_dc_kernel.configs\n)\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero([\"ddt_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n        
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\"])),\n    ]),\n    key=['chunk_size', 'hdim'],\n)\n@triton.jit\ndef _chunk_scan_bwd_dx_kernel(\n    # Pointers to matrices\n    x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, D_ptr,\n    dx_ptr, ddt_ptr, # dD_ptr,\n    # Matrix dimensions\n    chunk_size, hdim,\n    batch, seqlen, nheads_ngroups_ratio,\n    # Strides\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_D_head,\n    stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,\n    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile,\n    # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize,\n    # Meta-parameters\n    HAS_D: tl.constexpr,\n    D_HAS_HDIM: tl.constexpr,\n    DETERMINISTIC_REDUCTION: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n    pid_m = 
tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n    ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n    # if HAS_D:\n    #     dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)\n    dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    # Idk why limiting K_MAX gives wrong results, is it a Triton bug?\n    # K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)\n    K_MAX = chunk_size_limit\n    for k in range(0, K_MAX, BLOCK_SIZE_K):\n        # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower\n        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)\n        
dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)\n        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)\n        # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])\n        cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))\n        # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,\n        # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.\n        # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.\n        # This will cause NaN in acc, and hence NaN in dx and ddt.\n        mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)\n        cb = tl.where(mask, cb, 0.0)\n        cb = cb.to(dout_ptr.dtype.element_ty)\n        acc += tl.dot(cb, dout)\n        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k\n        dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen\n        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    dt_ptrs = dt_ptr + offs_m * stride_dt_csize\n    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n    dx = acc * dt_m[:, None]\n    dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head\n    dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)\n    if HAS_D:\n        dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n        dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n        if D_HAS_HDIM:\n            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n        else:\n            D = 
tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n        dx += dout_res * D\n    tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))\n\n    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n    x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n    ddt = tl.sum(acc * x, axis=1)\n    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize\n    if DETERMINISTIC_REDUCTION:\n        tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n    else:\n        tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n\n    # if HAS_D:\n    #     dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim)\n    #     dout = tl.load(dout_new_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32)\n    #     dD = tl.sum(x * dout, axis=0)\n    #     tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N)\n\n\n_CHUNK_SCAN_BWD_DX_MIN_BLOCK_N = min(\n    cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_scan_bwd_dx_kernel.configs\n)\n\n\n# Disabling HAS_DDA_CS for now since it's much slower\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_M': 
64}, num_stages=3, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8),\n        # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8),\n        # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8),\n        # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8),\n    ]),\n    key=['chunk_size', 'hdim'],\n)\n# @triton.heuristics({\"BLOCK_SIZE_N\": lambda args: max(triton.next_power_of_2(args[\"chunk_size\"]), 16)})\n# @triton.heuristics({\"BLOCK_SIZE_N\": lambda args: 32})\n@triton.jit\ndef _chunk_scan_bwd_dcb_kernel(\n    # Pointers to matrices\n    x_ptr, dout_ptr, cb_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,\n    dcb_ptr, ddA_cumsum_ptr,\n    # Matrix dimensions\n    chunk_size, hdim,\n    batch, seqlen, nheads, nheads_per_program, ngroups,\n    # Strides\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,\n    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_seq_idx_batch, stride_seq_idx_seqlen,\n    stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n,\n    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n,\n    # Meta-parameters\n    HAS_DDA_CS: tl.constexpr,\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_sg = tl.program_id(axis=2)\n    pid_s = pid_sg // ngroups\n    pid_g = pid_sg - pid_s * ngroups\n    num_pid_n = 
tl.cdiv(chunk_size, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n\n    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head\n    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head\n    if HAS_DDA_CS:\n        cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + pid_g * stride_cb_head\n        ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)\n    x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)\n    dt_ptrs = dt_ptr + offs_n * stride_dt_csize\n    if HAS_DDA_CS:\n        cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)\n        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n\n\n    if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:\n        dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split\n        
dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n)\n        tl.store(dcb_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))\n        return\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    if HAS_DDA_CS:\n        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32)\n    nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)\n    for h in range(nheads_iter):\n        dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)\n        x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)\n        dcb = tl.dot(dout, x)\n        dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32)\n        dcb *= dt_n\n        dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n        dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32)\n        # dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])\n        dcb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))\n        if HAS_DDA_CS:\n            tl.static_assert(not HAS_SEQ_IDX, \"HAS_SEQ_IDX not supported with HAS_DDA_CS yet\")\n            ddA_cs = dcb * cb\n            mask = offs_m[:, None] >= offs_n[None, :] + 1\n            ddA_cs = tl.where(mask, ddA_cs, 0.0)\n            ddA_cs = tl.cumsum(ddA_cs, axis=1)\n            ddA_cs = tl.where(mask, ddA_cs, 0.0)\n            ddA_cs = tl.sum(ddA_cs, axis=0)\n            tl.store(ddA_cumsum_ptrs + 
stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1)\n            tl.store(ddA_cumsum_ptr, 0.0)\n        acc += dcb\n        dout_ptrs += stride_dout_head\n        x_ptrs += stride_x_head\n        dt_ptrs += stride_dt_head\n        dA_cumsum_ptr += stride_dA_cs_head\n        if HAS_DDA_CS:\n            ddA_cumsum_ptr += stride_ddA_cs_head\n            ddA_cumsum_ptrs += stride_ddA_cs_head\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    if HAS_SEQ_IDX:\n        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n        seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)\n        acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)\n    mask = offs_m[:, None] >= offs_n[None, :]\n    acc = tl.where(mask, acc, 0.0)\n    dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split\n    dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n)\n    tl.store(dcb_ptrs, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))\n\n\n# Not numerically stable and should not be used. 
Leaving here for reference.\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 32}),\n        triton.Config({'BLOCK_SIZE_M': 64}),\n        triton.Config({'BLOCK_SIZE_M': 128}),\n        triton.Config({'BLOCK_SIZE_M': 256}),\n    ]),\n    key=[\"chunk_size\", \"hdim\"],\n)\n@triton.jit\ndef _chunk_scan_bwd_ddAcs_unstable_kernel(\n    # Pointers to matrices\n    dout_ptr, out_ptr, dt_ptr, ddt_ptr, x_ptr, D_ptr,\n    ddA_cumsum_ptr, dD_ptr,\n    # Matrix dimensions\n    chunk_size, hdim,\n    batch, seqlen,\n    # Strides\n    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,\n    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_D_head,\n    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,\n    stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,\n    # Meta-parameters\n    HAS_D: tl.constexpr,\n    D_HAS_HDIM: tl.constexpr,\n    SUBTRACT_DDTDT: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    pid_m = tl.program_id(axis=0)\n\n    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n    out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n    ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head\n    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * 
stride_ddA_cs_head\n    if HAS_D:\n        x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n        dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = tl.arange(0, BLOCK_SIZE_N)\n    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n    out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)\n    if HAS_D:\n        x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n        if D_HAS_HDIM:\n            dD_ptrs = dD_ptr + offs_n * stride_dD_hdim\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n    out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n    if HAS_D:\n        x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n        if D_HAS_HDIM:\n            dD = tl.sum(dout * x, axis=0)\n            tl.store(dD_ptrs, dD, mask=offs_n < hdim)\n            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n        else:\n            dD = tl.sum(dout * x)\n            tl.store(dD_ptr, dD)\n            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n        out -= x * D\n    ddA_cs = tl.sum(dout * out, axis=1)\n    if SUBTRACT_DDTDT:\n        dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n        ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n        ddA_cs -= dt * ddt\n    tl.store(ddA_cumsum_ptr + offs_m * 
stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size)\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),\n        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),\n        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),\n        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),\n        triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8),\n        triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8),\n        triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8),\n        triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8),\n    ]),\n    key=['chunk_size', 'hdim'],\n)\n@triton.jit\ndef _chunk_scan_bwd_ddAcs_stable_kernel_old(\n    # Pointers to matrices\n    x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr,\n    ddAcs_ptr,\n    # Matrix dimensions\n    chunk_size, hdim,\n    batch, seqlen, nheads_ngroups_ratio,\n    # Strides\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,\n    stride_ddAcs_batch, stride_ddAcs_chunk, stride_ddAcs_head, stride_ddAcs_csize_m, stride_ddAcs_csize_n,\n    # Meta-parameters\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n\n    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)\n    x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)\n    dt_ptrs = dt_ptr + offs_n * stride_dt_csize\n    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)\n    # Doing a matmul loop with cumsum later on will cause Triton to crash\n    # Instead we do just one big matmul\n    # acc = 
tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    # for k in range(0, hdim, BLOCK_SIZE_K):\n    #     dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0)\n    #     x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0)\n    #     acc += tl.dot(dout, x)\n    #     dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim\n    #     x_ptrs += BLOCK_SIZE_K * stride_x_hdim\n    dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)\n    x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)\n    acc = tl.dot(dout, x)\n    cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32)\n    acc *= cb\n    dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32)\n    acc *= dt_n\n    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n    dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32)\n    # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])\n    acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))\n    mask = offs_m[:, None] >= offs_n[None, :] + 1\n    acc = tl.where(mask, acc, 0.0)\n    acc = tl.cumsum(acc, axis=1)\n    acc = tl.where(mask, acc, 0.0)\n    ddA_cs = tl.sum(acc, axis=0)\n    ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n\n    tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1)\n    tl.store(ddAcs_ptr, 0.0)\n\n    # offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 64)\n    # offs_k = tl.arange(0, 
BLOCK_SIZE_K)\n    # dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)\n    # x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)\n    # dt_ptrs = dt_ptr + offs_n * stride_dt_csize\n    # cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)\n\n    # chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    # chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)\n    # rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)\n    # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n    # ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m\n    # ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n\n    # for n in range(0, chunk_size_limit_n, 64):\n    #     x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n - n), other=0.0)\n    #     acc = tl.dot(dout, x)\n    #     cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - n), other=0.0).to(tl.float32)\n    #     acc *= cb\n    #     dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32)\n    #     acc *= dt_n\n    #     dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32)\n    #     acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])\n    #     mask = offs_m[:, None] >= offs_n[None, :] + 1 + n\n    #     acc = tl.where(mask, acc, 0.0)\n    #     acc = tl.cumsum(acc, axis=1)\n    #     acc = tl.where(mask, acc, 0.0)\n    #     ddA_cs = tl.sum(acc, axis=0)\n    #     tl.store(ddAcs_ptrs, ddA_cs, mask=offs_n < chunk_size - 1 - n)\n    # # tl.store(ddAcs_ptr, 
0.0)\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),\n        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),\n    ]),\n    key=['chunk_size', 'hdim'],\n)\n@triton.jit\ndef _chunk_scan_bwd_ddAcs_stable_kernel(\n    # Pointers to matrices\n    x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr,\n    ddA_cumsum_ptr,\n    # Matrix dimensions\n    chunk_size, hdim,\n    batch, seqlen, nheads_ngroups_ratio,\n    # Strides\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,\n    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n,\n    # Meta-parameters\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    pid_m = 
tl.program_id(axis=0)\n\n    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head\n    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)\n    x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)\n    dt_ptrs = dt_ptr + offs_n * stride_dt_csize\n    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)\n    ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n\n    tl.store(ddA_cumsum_ptr, 0.0)\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n    dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)\n    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n    # Actually hi is (pid_m + 1) * BLOCK_SIZE_M - 1 but subtracting 1 makes it slower\n    lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M\n    # lo, hi = 0, chunk_size\n    for start_n in range(lo, hi, BLOCK_SIZE_N):\n        start_n = tl.multiple_of(start_n, BLOCK_SIZE_N)\n        # Doing a matmul loop with cumsum later on will cause Triton to crash\n        
# Instead we do just one big matmul\n        # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        # for k in range(0, hdim, BLOCK_SIZE_K):\n        #     dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0)\n        #     x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0)\n        #     acc += tl.dot(dout, x)\n        #     dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim\n        #     x_ptrs += BLOCK_SIZE_K * stride_x_hdim\n        # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)\n        x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0)\n        acc = tl.dot(dout, x)\n        dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)\n        acc *= dt_n\n        # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j]\n        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32)\n        acc *= cb\n        dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)\n        # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])\n        acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))\n        mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1\n        acc = tl.where(mask, acc, 0.0)\n        rowsum_new = rowsum + tl.sum(acc, axis=1)\n        acc = rowsum[:, None] + tl.cumsum(acc, axis=1)\n        rowsum = rowsum_new\n        acc = tl.where(mask, acc, 0.0)\n        ddA_cs = tl.sum(acc, axis=0)\n        tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1)\n        x_ptrs += BLOCK_SIZE_N * stride_x_seqlen\n        dt_ptrs += BLOCK_SIZE_N * stride_dt_csize\n   
     cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n\n        ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n\n\n    # Need to zero out the rest, since we'll be summing the rows together\n    for start_n in range(hi, chunk_size, BLOCK_SIZE_N):\n        tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1)\n        ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n    ]),\n    key=['chunk_size', 'dstate', 'hdim'],\n)\n@triton.jit\ndef _chunk_scan_bwd_ddAcs_prev_kernel(\n    # Pointers to matrices\n    dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr,\n    ddA_cumsum_ptr,\n    # Matrix dimensions\n    chunk_size, dstate, hdim,\n    batch, seqlen, nchunks, nheads_ngroups_ratio,\n    # Strides\n    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n    stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate,\n    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,\n    stride_dA_cs_batch, 
stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_seq_idx_batch, stride_seq_idx_seqlen,\n    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,\n    # Meta-parameters\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n    prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head\n    C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head\n    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)\n    prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim)\n    C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate)\n    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    dout = tl.load(dout_ptrs, 
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)\n    prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)\n    prev_states = prev_states.to(dout_ptrs.dtype.element_ty)\n    acc = tl.dot(dout, prev_states)\n    c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n    ddA_cs = tl.sum(acc * c, axis=1)\n    dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n    if not HAS_SEQ_IDX:\n        scale = tl.exp(dA_cs_m)\n    if HAS_SEQ_IDX:\n        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)\n        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n        scale =  tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)\n    ddA_cs *= scale\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize\n    tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)\n\n\ndef _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    _, _, ngroups, dstate = C.shape\n    assert nheads % ngroups == 0\n    assert C.shape == (batch, seqlen, ngroups, dstate)\n    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n    if z is not None:\n        assert z.shape == x.shape\n    if D is not None:\n        assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n    assert states.shape == (batch, nchunks, nheads, headdim, dstate)\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    # Allocates output.\n    out = 
torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n    if z is not None:\n        out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n        assert out_x.stride() == out.stride()\n    else:\n        out_x = None\n    grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n                    batch * nchunks, nheads)\n    z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))\n                  if z is not None else (0, 0, 0, 0))\n    _chunk_scan_fwd_kernel[grid](\n        cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,\n        chunk_size, headdim, dstate,\n        batch, seqlen, nheads // ngroups,\n        cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),\n        x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n        z_strides[0], z_strides[1], z_strides[2], z_strides[3],\n        out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n        dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n        dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n        *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n        C.stride(0), C.stride(1), C.stride(2), C.stride(3),\n        states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),\n        D.stride(0) if D is not None else 0,\n        True,\n        D is not None,\n        D.dim() == 2 if D is not None else True,\n        BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n        HAS_Z=z is not None,\n        HAS_SEQ_IDX=seq_idx is not None,\n        IS_TRITON_22=TRITON_22,\n    )\n    return out, out_x\n\n\ndef _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    _, _, ngroups, dstate = C.shape\n    assert 
nheads % ngroups == 0\n    assert C.shape == (batch, seqlen, ngroups, dstate)\n    assert B.shape == C.shape\n    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n    if z is not None:\n        assert z.shape == x.shape\n    if D is not None:\n        assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n    assert states.shape == (batch, nchunks, nheads, headdim, dstate)\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    # Allocates output.\n    out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n    if z is not None:\n        out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)\n        assert out_x.stride() == out.stride()\n    else:\n        out_x = None\n    grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads)\n    z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))\n                  if z is not None else (0, 0, 0, 0))\n    _chunk_scan_fwd_kernel_wip[grid](\n        cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, B, states, D,\n        chunk_size, headdim, dstate,\n        batch, seqlen, nheads // ngroups,\n        cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),\n        x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n        z_strides[0], z_strides[1], z_strides[2], z_strides[3],\n        out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n        dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n        dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n        *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n        C.stride(0), C.stride(1), C.stride(2), C.stride(3),\n        B.stride(0), B.stride(1), B.stride(2), B.stride(3),\n        states.stride(0), 
states.stride(1), states.stride(2), states.stride(3), states.stride(4),\n        D.stride(0) if D is not None else 0,\n        D is not None,\n        D.dim() == 2 if D is not None else True,\n        BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n        BLOCK_SIZE_M=128,\n        HAS_Z=z is not None,\n        HAS_SEQ_IDX=seq_idx is not None,\n    )\n    return out, out_x\n\n\ndef _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False):\n    batch, seqlen, nheads, headdim = x.shape\n    assert z.shape == x.shape\n    assert out.shape == x.shape\n    assert dout.shape == out.shape\n    nchunks = math.ceil(seqlen / chunk_size)\n    if D is not None:\n        assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n        assert D.stride(-1) == 1\n    if has_ddAcs:\n        ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)\n    if D is not None:\n        BLOCK_SIZE_min = 32\n        dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,\n                         headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)\n    else:\n        dD = None\n    if dz is not None:\n        assert dz.shape == z.shape\n    else:\n        dz = torch.empty_like(z)\n    if recompute_output:\n        outz = torch.empty_like(x)\n    dout_x = torch.empty_like(dout)\n    dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))\n                    if D is not None else (0, 0, 0, 0, 0))\n    grid_dz = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)\n    with torch.cuda.device(x.device.index):\n        _chunk_scan_bwd_dz_kernel[grid_dz](\n            dout, out, z, x, D, outz if recompute_output else None,\n            dz, dout_x, dD, ddA_cumsum if has_ddAcs else None,\n            chunk_size, headdim,\n            batch, seqlen,\n            dout.stride(0), 
dout.stride(1), dout.stride(2), dout.stride(3),\n            out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n            z.stride(0), z.stride(1), z.stride(2), z.stride(3),\n            x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n            D.stride(0) if D is not None else 0,\n            *((outz.stride(0), outz.stride(1), outz.stride(2), outz.stride(3)) if recompute_output else (0, 0, 0, 0)),\n            dz.stride(0), dz.stride(1), dz.stride(2), dz.stride(3),\n            dout_x.stride(0), dout_x.stride(1), dout_x.stride(2), dout_x.stride(3),\n            dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],\n            *((ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3))\n              if has_ddAcs else (0, 0, 0, 0)),\n            D is not None,\n            D.dim() == 2 if D is not None else True,\n            has_ddAcs,\n            BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16),\n            RECOMPUTE_OUTPUT=recompute_output,\n        )\n    if D is not None:\n        BLOCK_SIZE_actual = _chunk_scan_bwd_dz_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n        n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n        dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)\n        if D.dim() == 1:\n            dD = rearrange(dD, \"h 1 -> h\")\n    return_vals = (dz, dout_x, dD, ddA_cumsum) if has_ddAcs else (dz, dout_x, dD)\n    return return_vals if not recompute_output else (*return_vals, outz)\n\n\ndef _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None):\n    batch, seqlen, nheads, headdim = dout.shape\n    _, _, nchunks, chunk_size = dA_cumsum.shape\n    _, _, ngroups, dstate = C.shape\n    assert nheads % ngroups == 0\n    assert C.shape == (batch, seqlen, ngroups, dstate)\n    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n 
   dtype = C.dtype if dtype is None else dtype\n    dprev_states = torch.empty(batch, nchunks, nheads, headdim, dstate, device=C.device, dtype=dtype)\n    grid_dstates = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),\n                            batch * nchunks, nheads)\n    with torch.cuda.device(C.device.index):\n        _chunk_scan_bwd_dstates_kernel[grid_dstates](\n            dout, C, dprev_states, dA_cumsum, seq_idx,\n            headdim, dstate, chunk_size,\n            batch, seqlen, nchunks, nheads // ngroups,\n            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n            C.stride(0), C.stride(1), C.stride(2), C.stride(3),\n            dprev_states.stride(0), dprev_states.stride(1), dprev_states.stride(2), dprev_states.stride(3), dprev_states.stride(4),\n            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n            HAS_SEQ_IDX=seq_idx is not None,\n        )\n    return dprev_states\n\n\ndef _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1):\n    batch, nchunks, nheads, headdim, dstate = prev_states.shape\n    _, seqlen, _, _ = dout.shape\n    _, _, _, chunk_size = dA_cumsum.shape\n    assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)\n    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n    assert dout.shape == (batch, seqlen, nheads, headdim)\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    deterministic = use_deterministic_mode()\n    if C is not None:\n        assert C.shape == (batch, seqlen, ngroups, dstate)\n        C_strides = (C.stride(0), C.stride(1), C.stride(2), C.stride(3))\n        tile_count = math.ceil(dstate / _CHUNK_SCAN_BWD_DC_MIN_BLOCK_N)\n        ddA_cumsum_prev, stride_ddA_tile = alloc_tile_workspace(\n            
(batch, nheads, nchunks, chunk_size),\n            tile_count,\n            torch.float32,\n            dout.device,\n            deterministic,\n            zero_init=True,\n        )\n        ddA_cumsum_prev_strides = (\n            ddA_cumsum_prev.stride(0),\n            ddA_cumsum_prev.stride(2),\n            ddA_cumsum_prev.stride(1),\n            ddA_cumsum_prev.stride(3),\n        )\n    else:\n        C_strides = (0, 0, 0, 0)\n        ddA_cumsum_prev = None\n        ddA_cumsum_prev_strides = (0, 0, 0, 0)\n        stride_ddA_tile = 0\n    nheads_ngroups_ratio = nheads // ngroups\n    sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count\n    nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)\n    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)\n    dC = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=dout.device, dtype=torch.float32)\n    grid_dc = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),\n                        batch * nchunks, nsplits * ngroups)\n    with torch.cuda.device(dout.device.index):\n        _chunk_scan_bwd_dc_kernel[grid_dc](\n            dout, prev_states, C, dA_cumsum, seq_idx, dC, ddA_cumsum_prev,\n            chunk_size, dstate, headdim,\n            batch, seqlen, nheads, nheads_per_program, ngroups,\n            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n            prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4),\n            *C_strides,\n            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n            dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), dC.stride(4),\n            *ddA_cumsum_prev_strides, stride_ddA_tile,\n            
HAS_DDA_CS=ddA_cumsum_prev is not None,\n            HAS_SEQ_IDX=seq_idx is not None,\n            DETERMINISTIC_REDUCTION=deterministic,\n            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),\n        )\n    dC = dC.sum(2)\n    if ddA_cumsum_prev is not None:\n        ddA_cumsum_prev = finalize_tile_workspace(ddA_cumsum_prev, deterministic)\n    return dC if C is None else (dC, ddA_cumsum_prev)\n\n\ndef _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == dt.shape\n    assert dout.shape == x.shape\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    if CB is not None:\n        assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n        CB_strides = (CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(3), CB.stride(4))\n        BLOCK_SIZE_M_min = 16\n        ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),\n                                chunk_size, device=x.device, dtype=torch.float32)\n        ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4))\n    else:\n        CB_strides = (0, 0, 0, 0, 0)\n        ddA_cumsum = None\n        ddA_cumsum_strides = (0, 0, 0, 0, 0)\n    nheads_ngroups_ratio = nheads // ngroups\n    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n    nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)\n    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)\n    dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32)\n    grid_dcb = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, 
META['BLOCK_SIZE_N']),\n                        batch * nchunks, nsplits * ngroups)\n    with torch.cuda.device(x.device.index):\n        _chunk_scan_bwd_dcb_kernel[grid_dcb](\n            x, dout, CB, dt, dA_cumsum, seq_idx, dcb, ddA_cumsum,\n            chunk_size, headdim,\n            batch, seqlen, nheads, nheads_per_program, ngroups,\n            x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n            *CB_strides,\n            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n            dcb.stride(0), dcb.stride(1), dcb.stride(2), dcb.stride(3), dcb.stride(4), dcb.stride(5),\n            *ddA_cumsum_strides,\n            HAS_DDA_CS=ddA_cumsum is not None,\n            HAS_SEQ_IDX=seq_idx is not None,\n            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),\n        )\n    dcb = dcb.sum(2)\n    if ddA_cumsum is not None:\n        BLOCK_SIZE_M_actual = _chunk_scan_bwd_dcb_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n        n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual\n        ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)\n    return dcb if CB is None else (dcb, ddA_cumsum)\n\n\ndef _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    ngroups = cb.shape[2]\n    assert nheads % ngroups == 0\n    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == dt.shape\n    assert dout.shape == x.shape\n    # if D is not None:\n    #     BLOCK_SIZE_M_min = 32\n    #     dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_M_min), batch, 
nchunks, nheads, headdim, device=D.device, dtype=torch.float32)\n    # else:\n    #     dD = None\n    dx = torch.empty_like(x)\n    deterministic = use_deterministic_mode()\n    tile_count = math.ceil(headdim / _CHUNK_SCAN_BWD_DX_MIN_BLOCK_N)\n    ddt, stride_ddt_tile = alloc_tile_workspace(\n        (batch, nheads, nchunks, chunk_size),\n        tile_count,\n        torch.float32,\n        dout.device,\n        deterministic,\n        zero_init=True,\n    )\n    grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n                        batch * nchunks, nheads)\n    with torch.cuda.device(x.device.index):\n        _chunk_scan_bwd_dx_kernel[grid_dx](\n            x, cb, dout, dt, dA_cumsum, D, dx, ddt, # dD,\n            chunk_size, headdim,\n            batch, seqlen, nheads // ngroups,\n            x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n            cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(-1), cb.stride(-2),\n            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n            D.stride(0) if D is not None else 0,\n            dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),\n            ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), stride_ddt_tile,\n            # dD.stride(1) if dD is not None else 0, dD.stride(2) if dD is not None else 0, dD.stride(3) if dD is not None else 0, dD.stride(4) if dD is not None else 0, dD.stride(0) if dD is not None else 0,\n            D is not None,\n            D.dim() == 2 if D is not None else True,\n            DETERMINISTIC_REDUCTION=deterministic,\n        )\n    # if D is not None:\n    #     BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n    #     n_valid_blocks = (chunk_size + 
BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n    #     dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)\n    ddt = finalize_tile_workspace(ddt, deterministic)\n    return dx, ddt\n\n\ndef _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True):\n    \"\"\"Not numerically stable and should not be used. Leaving here for reference.\n    \"\"\"\n\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert ddt.shape == dt.shape\n    assert out.shape == x.shape\n    assert dout.shape == x.shape\n    if D is not None:\n        assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n    ddA_cumsum = torch.empty_like(dt)\n    grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)\n    if D is not None:  # Triton gives wrong results if we write to the same location\n        BLOCK_SIZE_min = 32\n        dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,\n                         headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)\n    else:\n        dD = None\n    dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))\n                    if D is not None else (0, 0, 0, 0, 0))\n    with torch.cuda.device(x.device.index):\n        _chunk_scan_bwd_ddAcs_unstable_kernel[grid_ddtcs](\n            dout, out, dt, ddt, x, D, ddA_cumsum, dD,\n            chunk_size, headdim,\n            batch, seqlen,\n            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n            out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n            ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),\n            x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n            D.stride(0) if D is not None else 0,\n            
ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),\n            dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],\n            D is not None,\n            D.dim() == 2 if D is not None else True,\n            subtract_ddtdt,\n            BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16),\n        )\n    if D is not None:\n        BLOCK_SIZE_actual = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n        n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n        dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)\n        if D.dim() == 1:\n            dD = rearrange(dD, \"h 1 -> h\")\n    return ddA_cumsum, dD\n\n\ndef _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dout.shape == x.shape\n    assert dA_cumsum.shape == dt.shape\n    ngroups = cb.shape[2]\n    assert nheads % ngroups == 0\n    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n    BLOCK_SIZE_M_min = 16\n    ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),\n                             chunk_size, device=x.device, dtype=torch.float32)\n    grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)\n    with torch.cuda.device(x.device.index):\n        _chunk_scan_bwd_ddAcs_stable_kernel_old[grid_ddtcs](\n            x, dout, dt, dA_cumsum, cb, ddA_cumsum,\n            chunk_size, headdim,\n            batch, seqlen, nheads // ngroups,\n            x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n            dA_cumsum.stride(0), dA_cumsum.stride(2), 
dA_cumsum.stride(1), dA_cumsum.stride(3),\n            cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),\n            ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4),\n            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),\n            BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16),\n        )\n    BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs[\"BLOCK_SIZE_M\"]\n    n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual\n    ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)\n    return ddA_cumsum\n\n\ndef _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dout.shape == x.shape\n    assert dA_cumsum.shape == dt.shape\n    ngroups = cb.shape[2]\n    assert nheads % ngroups == 0\n    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n    BLOCK_SIZE_M_min = 32\n    ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),\n                             chunk_size, device=x.device, dtype=torch.float32)\n    grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)\n    with torch.cuda.device(x.device.index):\n        _chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs](\n            x, dout, dt, dA_cumsum, cb, ddA_cumsum,\n            chunk_size, headdim,\n            batch, seqlen, nheads // ngroups,\n            x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n            cb.stride(0), cb.stride(1), 
cb.stride(2), cb.stride(3), cb.stride(4),\n            ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4),\n            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),\n        )\n    BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n    n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual\n    ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)\n    return ddA_cumsum\n\n\ndef _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None):\n    batch, nchunks, nheads, headdim, dstate = prev_states.shape\n    _, seqlen, _, _ = dout.shape\n    _, _, _, chunk_size = dA_cumsum.shape\n    assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)\n    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n    assert dout.shape == (batch, seqlen, nheads, headdim)\n    ngroups = C.shape[2]\n    assert nheads % ngroups == 0\n    assert C.shape == (batch, seqlen, ngroups, dstate)\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)\n    grid_ddAcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),\n                          batch * nchunks, nheads)\n    with torch.cuda.device(dout.device.index):\n        _chunk_scan_bwd_ddAcs_prev_kernel[grid_ddAcs](\n            dout, prev_states, C, dA_cumsum, seq_idx, ddA_cumsum_prev,\n            chunk_size, dstate, headdim,\n            batch, seqlen, nchunks, nheads // ngroups,\n            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n            prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4),\n            C.stride(0), C.stride(1), C.stride(2), C.stride(3),\n            
dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n            ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3),\n            HAS_SEQ_IDX=seq_idx is not None,\n            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),\n        )\n    return ddA_cumsum_prev\n\n\nclass ChunkScanFn(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):\n        # Check constraints.\n        batch, seqlen, nheads, headdim = x.shape\n        _, _, ngroups, dstate = B.shape\n        assert B.shape == (batch, seqlen, ngroups, dstate)\n        _, _, nchunks, chunk_size = dt.shape\n        assert seqlen == nchunks * chunk_size\n        assert C.shape == B.shape\n        if z is not None:\n            assert z.shape == x.shape\n        if D is not None:\n            assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n        assert dt.shape == (batch, nheads, nchunks, chunk_size)\n        assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n        assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)\n        if B.stride(-1) != 1:\n            B = B.contiguous()\n        if C.stride(-1) != 1:\n            C = C.contiguous()\n        if x.stride(-1) != 1 and x.stride(1) != 1:  # Either M or K dimension should be contiguous\n            x = x.contiguous()\n        if z is not None and z.stride(-1) != 1 and z.stride(1) != 1:  # Either M or K dimension should be contiguous\n            z = z.contiguous()\n        if D is not None and D.stride(-1) != 1:\n            D = D.contiguous()\n        CB = _bmm_chunk_fwd(C, B, chunk_size)\n        out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z)\n        ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, 
prev_states, D, z)\n        return out\n\n    @staticmethod\n    def backward(ctx, dout):\n        if dout.stride(-1) != 1:\n            dout = dout.contiguous()\n        out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensors\n        batch, seqlen, nheads, headdim = x.shape\n        _, _, nchunks, chunk_size = dt.shape\n        _, _, ngroups, dstate = B.shape\n        assert dout.shape == (batch, seqlen, nheads, headdim)\n        if z is not None:\n            dz, dout, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, D=D)\n        else:\n            dz = None\n        dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, dtype=prev_states.dtype)\n        dC = _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, ngroups=ngroups)\n        dC = dC.to(C.dtype)\n        dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, ngroups=ngroups)\n        dCB = dCB.to(CB.dtype)\n        dB = _bmm_chunk_bwd(C, dCB)\n        dC = _bmm_chunk_bwd(B, rearrange(dCB, \"... l s -> ... 
s l\"), residual=dC)\n        dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, dout, D=D)\n        # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.\n        # ddA_cumsum = torch.einsum(\"bclhp,bclhp->bhcl\", out.float(), dout.float()) - ddt * dt\n        if z is not None:\n            ddA_cumsum -= ddt * dt\n        else: # If z is not None, we already calculated ddA_cumsum and dD when computing dz\n            ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=D)\n        ddA_cumsum = ddA_cumsum.to(dA_cumsum.dtype)\n        return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz\n\n\ndef chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):\n    \"\"\"\n    prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1.\n    Argument:\n        B: (batch, seqlen, ngroups, dstate)\n        C: (batch, seqlen, ngroups, dstate)\n        x: (batch, seqlen, nheads, headdim)\n        dt: (batch, nheads, nchunks, chunk_size)\n        dA_cumsum: (batch, nheads, nchunks, chunk_size)\n        prev_states: (batch, nchunks, nheads, headdim, dstate)\n        D: (nheads, headdim) or (nheads,)\n        z: (batch, seqlen, nheads, headdim)\n    Return:\n        out: (batch, seqlen, nheads, headdim)\n    \"\"\"\n    return ChunkScanFn.apply(B, C, x, dt, dA_cumsum, prev_states, D, z)\n\n\ndef chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):\n    \"\"\"\n    Argument:\n        B: (batch, seqlen, ngroups, dstate)\n        C: (batch, seqlen, ngroups, dstate)\n        x: (batch, seqlen, nheads, headdim)\n        dt: (batch, nheads, nchunks, chunk_size)\n        dA_cumsum: (batch, nheads, nchunks, chunk_size)\n        prev_states: (batch, nchunks, nheads, headdim, dstate)\n        D: (nheads, headdim) or (nheads,)\n        z: (batch, seqlen, nheads, headdim)\n    Return:\n        out: (batch, seqlen, nheads, headdim)\n    \"\"\"\n    batch, seqlen, 
nheads, headdim = x.shape\n    _, _, ngroups, dstate = B.shape\n    assert B.shape == (batch, seqlen, ngroups, dstate)\n    _, _, nchunks, chunk_size = dt.shape\n    assert seqlen == nchunks * chunk_size\n    assert C.shape == B.shape\n    B = repeat(B, \"b l g d -> b l (g h) d\", h=nheads // ngroups)\n    C = repeat(C, \"b l g d -> b l (g h) d\", h=nheads // ngroups)\n    CB = torch.einsum(\"bclhn,bcshn->bchls\", rearrange(C, \"b (c l) h n -> b c l h n\", c=nchunks),\n                      rearrange(B, \"b (c s) h n -> b c s h n\", c=nchunks))\n    # (batch, nheads, nchunks, chunksize, chunksize)\n    dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]\n    decay = torch.exp(dt_segment_sum)\n    scores_decay = CB * rearrange(decay, \"b h c l s -> b c h l s\")\n    causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)\n    scores_decay = scores_decay.masked_fill(~causal_mask, 0)\n    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),\n                       rearrange(x, \"b (c s) h p -> b c s h p\", c=nchunks))\n    state_decay_out = torch.exp(rearrange(dA_cumsum, \"b h c l -> b c l h 1\"))\n    out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, \"b (c l) h n -> b c l h n\", c=nchunks),\n                            prev_states.to(C.dtype)) * state_decay_out\n    out = out + out_prev\n    out = rearrange(out, \"b c l h p -> b (c l) h p\")\n    if D is not None:\n        if D.dim() == 1:\n            D = rearrange(D, \"h -> h 1\")\n        out = out + x * D\n    return out if z is None else out * F.silu(z)\n"
  },
  {
    "path": "mamba_ssm/ops/triton/ssd_chunk_state.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\n\n\"\"\"We want triton==2.1.0 or 2.2.0 for this\n\"\"\"\n\nimport math\nimport torch\nimport torch.nn.functional as F\n\nimport triton\nimport triton.language as tl\n\nfrom einops import rearrange, repeat\n\nfrom mamba_ssm.ops.triton.softplus import softplus\nfrom mamba_ssm.utils.determinism import (\n    alloc_tile_workspace,\n    finalize_tile_workspace,\n    use_deterministic_mode,\n    autotune_configs,\n)\n\n\ndef init_to_zero(names):\n    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_H': 1}),\n        triton.Config({'BLOCK_SIZE_H': 2}),\n        triton.Config({'BLOCK_SIZE_H': 4}),\n        triton.Config({'BLOCK_SIZE_H': 8}),\n        triton.Config({'BLOCK_SIZE_H': 16}),\n        triton.Config({'BLOCK_SIZE_H': 32}),\n        triton.Config({'BLOCK_SIZE_H': 64}),\n    ]),\n    key=['chunk_size', 'nheads'],\n)\n@triton.jit\ndef _chunk_cumsum_fwd_kernel(\n    # Pointers to matrices\n    dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,\n    # Matrix dimension\n    batch, seqlen, nheads, chunk_size,\n    dt_min, dt_max,\n    # Strides\n    stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n    stride_A_head,\n    stride_dt_bias_head,\n    stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    # Meta-parameters\n    DT_SOFTPLUS: tl.constexpr,\n    HAS_DT_BIAS: tl.constexpr,\n    BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n    pid_b = tl.program_id(axis=0)\n    pid_c = tl.program_id(axis=1)\n    pid_h = tl.program_id(axis=2)\n    dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n    dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * 
stride_dA_cs_chunk\n\n    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n    offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n    dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n    A_ptrs = A_ptr + offs_h * stride_A_head\n    dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)\n    dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n    dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n    if HAS_DT_BIAS:\n        dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n        dt += dt_bias[:, None]\n    if DT_SOFTPLUS:\n        dt = tl.where(dt <= 20.0, softplus(dt), dt)\n    # As of Triton 2.2.0, tl.clamp is not available yet\n    # dt = tl.clamp(dt, dt_min, dt_max)\n    dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n    dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n    tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n    A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n    dA = dt * A[:, None]\n    dA_cs = tl.cumsum(dA, axis=1)\n    tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n        triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n        triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n        triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n        
triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n        triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n        triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero([\"dA_ptr\", \"ddt_bias_ptr\"])),\n    ]),\n    key=['chunk_size', 'nheads'],\n)\n@triton.jit\ndef _chunk_cumsum_bwd_kernel(\n    # Pointers to matrices\n    ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr,\n    ddt_ptr, dA_ptr, ddt_bias_ptr,\n    # Matrix dimensions\n    batch, seqlen, nheads, chunk_size,\n    dt_min, dt_max,\n    # Strides\n    stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize,\n    stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize,\n    stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n    stride_A_head,\n    stride_dt_bias_head,\n    stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head,\n    stride_dA_batch, stride_dA_chunk, stride_dA_head,\n    stride_ddt_bias_batch, stride_ddt_bias_chunk, stride_ddt_bias_head,\n    # Meta-parameters\n    DT_SOFTPLUS: tl.constexpr,\n    HAS_DT_BIAS: tl.constexpr,\n    BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n    DETERMINISTIC_REDUCTION: tl.constexpr,\n):\n    pid_b = tl.program_id(axis=0)\n    pid_c = tl.program_id(axis=1)\n    pid_h = tl.program_id(axis=2)\n    ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk\n    ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk\n    dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n    ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen\n\n    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n    offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n    ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize)\n    ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * 
stride_ddA_csize)\n    dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n    ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen)\n    A_ptrs = A_ptr + offs_h * stride_A_head\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n    ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n    ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n    A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n    ddt = ddA * A[:, None] + ddt_out\n    dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n    if HAS_DT_BIAS:\n        dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n        dt += dt_bias[:, None]\n    if DT_SOFTPLUS:\n        dt_presoftplus = dt\n        dt = tl.where(dt <= 20.0, softplus(dt), dt)\n    clamp_mask = (dt < dt_min) | (dt > dt_max)\n    # As of Triton 2.2.0, tl.clamp is not available yet\n    # dt = tl.clamp(dt, dt_min, dt_max)\n    dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n    dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n    ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0)\n    ddt = tl.where(clamp_mask, 0.0, ddt)\n    if DT_SOFTPLUS:\n        ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)\n    tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit))\n    dA = tl.sum(ddA * dt, axis=1)\n    dA_ptr += pid_b * stride_dA_batch + pid_c * stride_dA_chunk\n    if DETERMINISTIC_REDUCTION:\n        tl.store(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)\n    else:\n        tl.atomic_add(dA_ptr + 
offs_h * stride_dA_head, dA, mask=offs_h < nheads)\n    if HAS_DT_BIAS:\n        ddt_bias = tl.sum(ddt, axis=1)\n        ddt_bias_ptr += pid_b * stride_ddt_bias_batch + pid_c * stride_ddt_bias_chunk\n        if DETERMINISTIC_REDUCTION:\n            tl.store(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads)\n        else:\n            tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads)\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n    ]),\n    key=['hdim', 'dstate', 'chunk_size'],\n)\n@triton.jit\ndef _chunk_state_fwd_kernel(\n    # Pointers to matrices\n    x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,\n    # Matrix dimensions\n    hdim, dstate, chunk_size,\n    batch, seqlen, nheads_ngroups_ratio,\n    # Strides\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n    
stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_seq_idx_batch, stride_seq_idx_seqlen,\n    # Meta-parameters\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head\n    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)\n    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)\n    dt_ptrs = dt_ptr + offs_k * stride_dt_csize\n    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)\n    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n    if HAS_SEQ_IDX:\n        seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * 
chunk_size)\n    if HAS_SEQ_IDX:\n        seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)\n\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):\n        x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0)\n        b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)\n        if HAS_SEQ_IDX:\n            seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)\n        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)\n        if not HAS_SEQ_IDX:\n            # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k\n            scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k\n        else:\n            # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)\n            scale = tl.where((seq_idx_last >= 0) & (seq_idx_k == seq_idx_last), tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)\n        b *= scale[:, None]\n        b = b.to(x_ptr.dtype.element_ty)\n        acc += tl.dot(x, b)\n        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen\n        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen\n        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize\n        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n        if HAS_SEQ_IDX:\n            seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen\n    states = acc.to(states_ptr.dtype.element_ty)\n\n    states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    states_ptrs = states_ptr + (offs_m[:, None] * 
stride_states_hdim + offs_n[None, :] * stride_states_dstate)\n    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)\n    tl.store(states_ptrs, states, mask=c_mask)\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero([\"ddt_ptr\", \"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"ddA_cumsum_ptr\"])),\n    ]),\n    key=['chunk_size', 'hdim', 'dstate'],\n)\n@triton.jit\ndef _chunk_state_bwd_dx_kernel(\n    # Pointers to matrices\n    x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr,\n    dx_ptr, ddt_ptr, 
ddA_cumsum_ptr,\n    # Matrix dimensions\n    chunk_size, hdim, dstate,\n    batch, seqlen, nheads_ngroups_ratio,\n    # Strides\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n    stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,\n    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile,\n    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile,\n    # Meta-parameters\n    DETERMINISTIC_REDUCTION: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n    BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head\n    dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n    ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile\n    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_n * stride_ddA_tile\n    dA_cumsum_ptr += pid_b * 
stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128\n    offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)\n    b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)\n    dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)\n    if BLOCK_SIZE_DSTATE <= 128:\n        b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)\n        dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)\n        dstates = dstates.to(b_ptr.dtype.element_ty)\n        acc = tl.dot(b, dstates)\n    else:\n        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for k in range(0, dstate, BLOCK_SIZE_K):\n            b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)\n            dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)\n            dstates = dstates.to(b_ptr.dtype.element_ty)\n            acc += tl.dot(b, dstates)\n            b_ptrs += BLOCK_SIZE_K * stride_b_dstate\n            dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)\n    dt_ptrs = dt_ptr + offs_m * stride_dt_csize\n    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize\n    dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < 
chunk_size, other=0.0).to(tl.float32)\n    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n    # acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]\n    acc *= tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))[:, None]\n\n    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n    x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n    ddt = tl.sum(acc * x, axis=1)\n    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize\n    if DETERMINISTIC_REDUCTION:\n        tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n    else:\n        tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n    ddA_cs = -(ddt * dt_m)\n    ddA_cs_last = -tl.sum(ddA_cs)\n    ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize\n    if DETERMINISTIC_REDUCTION:\n        tl.store(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)\n    else:\n        tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)\n        tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)\n\n    dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)\n    dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head\n    dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)\n    tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))\n\n\n_CHUNK_STATE_BWD_DX_MIN_BLOCK_N = min(\n    cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_state_bwd_dx_kernel.configs\n)\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 
128}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n    ]),\n    key=['chunk_size', 'dstate', 'hdim'],\n)\n@triton.jit\ndef _chunk_state_bwd_db_kernel(\n    # Pointers to matrices\n    x_ptr, dstates_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,\n    db_ptr, ddA_cumsum_ptr,\n    # Matrix dimensions\n    chunk_size, dstate, hdim,\n    batch, seqlen, nheads, nheads_per_program, ngroups,\n    # Strides\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n    stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_seq_idx_batch, stride_seq_idx_seqlen,\n    stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate,\n    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile,\n    # Meta-parameters\n    HAS_DDA_CS: tl.constexpr,\n    HAS_SEQ_IDX: tl.constexpr,\n    DETERMINISTIC_REDUCTION: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: 
tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_sg = tl.program_id(axis=2)\n    pid_s = pid_sg // ngroups\n    pid_g = pid_sg - pid_s * ngroups\n    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head\n    db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * stride_db_group + pid_s * stride_db_split\n    dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head\n    if HAS_DDA_CS:\n        b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head\n        ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_n * stride_ddA_tile\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim)\n    dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim)\n    dt_ptrs = dt_ptr + offs_m * stride_dt_csize\n    dA_cumsum_ptrs = 
dA_cumsum_ptr + offs_m * stride_dA_cs_csize\n    if HAS_DDA_CS:\n        b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate)\n        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    if HAS_DDA_CS:\n        b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n    if HAS_SEQ_IDX:\n        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n        seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)\n    nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)\n    for h in range(nheads_iter):\n        x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)\n        dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)\n        dstates = dstates.to(x_ptrs.dtype.element_ty)\n        db = tl.dot(x, dstates)\n        dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)\n        dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n        dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n        if not HAS_SEQ_IDX:\n            # scale = tl.exp(dA_cs_last - dA_cs_m)\n            scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))\n        else:\n            # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)\n            scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)\n        db *= (scale * dt_m)[:, None]\n        if HAS_DDA_CS:\n            # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. 
the exclusive reverse cumsum\n            ddA_cs = tl.sum(db * b, axis=1)\n            if DETERMINISTIC_REDUCTION:\n                tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)\n            else:\n                tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)\n        acc += db\n        x_ptrs += stride_x_head\n        dstates_ptrs += stride_states_head\n        dt_ptrs += stride_dt_head\n        dA_cumsum_ptr += stride_dA_cs_head\n        dA_cumsum_ptrs += stride_dA_cs_head\n        if HAS_DDA_CS:\n            ddA_cumsum_ptrs += stride_ddA_cs_head\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    # if HAS_SEQ_IDX:\n    #     seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)\n    #     seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n    #     acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)\n    db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate)\n    tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))\n\n\n_CHUNK_STATE_BWD_DB_MIN_BLOCK_N = min(\n    cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_state_bwd_db_kernel.configs\n)\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        # triton.Config({'BLOCK_SIZE_M': 128, 
'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n        triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, 
pre_hook=init_to_zero([\"ddA_cumsum_ptr\"])),\n    ]),\n    key=['chunk_size', 'hdim', 'dstate'],\n)\n@triton.jit\ndef _chunk_state_bwd_ddAcs_stable_kernel(\n    # Pointers to matrices\n    x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,\n    ddA_cumsum_ptr,\n    # Matrix dimensions\n    chunk_size, hdim, dstate,\n    batch, seqlen, nheads_ngroups_ratio,\n    # Strides\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n    stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,\n    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_seq_idx_batch, stride_seq_idx_seqlen,\n    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile,\n    # Meta-parameters\n    HAS_SEQ_IDX: tl.constexpr,\n    DETERMINISTIC_REDUCTION: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n    BLOCK_SIZE_DSTATE: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head\n    dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_n * 
stride_ddA_tile\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n    # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128\n    offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)\n    b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)\n    dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)\n    if BLOCK_SIZE_DSTATE <= 128:\n        b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)\n        dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)\n        dstates = dstates.to(b_ptr.dtype.element_ty)\n        acc = tl.dot(b, dstates)\n    else:\n        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for k in range(0, dstate, BLOCK_SIZE_K):\n            b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)\n            dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)\n            dstates = dstates.to(b_ptr.dtype.element_ty)\n            acc += tl.dot(b, dstates)\n            b_ptrs += BLOCK_SIZE_K * stride_b_dstate\n            dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, 
other=0.0).to(tl.float32)\n    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)\n    if not HAS_SEQ_IDX:\n        # scale = tl.exp(dA_cs_last - dA_cs_m)\n        scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))\n    else:\n        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n        seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)\n        # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)\n        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)\n    acc *= scale[:, None]\n\n    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n    x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n    dt_ptrs = dt_ptr + offs_m * stride_dt_csize\n    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)\n    ddt = tl.sum(acc * x, axis=1)\n    # ddA_cs = -(ddt * dt_m)\n    # Triton 2.2.0 errors if we have the cumsum here, so we just write it out\n    # then call torch.cumsum outside this kernel.\n    # ddA_cs = tl.cumsum(ddt * dt_m)\n    ddA_cs = ddt * dt_m\n    ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize\n    if DETERMINISTIC_REDUCTION:\n        tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)\n    else:\n        tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)\n\n\n_CHUNK_STATE_BWD_DDACS_MIN_BLOCK_N = min(\n    cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_state_bwd_ddAcs_stable_kernel.configs\n)\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 
'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n    ]),\n    key=['hdim', 'dstate', 'chunk_size'],\n)\n@triton.jit\ndef _chunk_state_varlen_kernel(\n    # Pointers to matrices\n    x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr,\n    # Matrix dimensions\n    hdim, dstate, chunk_size,\n    seqlen, nheads_ngroups_ratio,\n    # Strides\n    stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_b_seqlen, stride_b_head, stride_b_dstate,\n    stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate,\n    stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate,\n    # Meta-parameters\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n    pid_b = tl.program_id(axis=1)\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)\n    pid_c = (end_idx - 1) 
// chunk_size\n    b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head\n    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head\n    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n    chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)\n    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)\n    dt_ptrs = dt_ptr + offs_k * stride_dt_csize\n    dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)\n    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n\n    chunk_size_limit = end_idx - pid_c * chunk_size\n    start_idx = tl.load(cu_seqlens_ptr + pid_b)\n    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)\n\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):\n        x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0)\n        b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)\n        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)\n        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)\n        # scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),\n        #                  tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 
0.0)\n        scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),\n                         tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)\n        b *= scale[:, None]\n        b = b.to(x_ptr.dtype.element_ty)\n        acc += tl.dot(x, b)\n        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen\n        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen\n        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize\n        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n    # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk\n    if start_idx < pid_c * chunk_size:\n        chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate)\n        chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)\n        # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)\n        scale = tl.exp(dA_cs_last)\n        acc += chunk_states * scale\n\n    states = acc.to(states_ptr.dtype.element_ty)\n\n    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)\n    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)\n    tl.store(states_ptrs, states, mask=c_mask)\n\n\ndef _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))):\n    batch, seqlen, nheads = dt.shape\n    assert A.shape == (nheads,)\n    if dt_bias is not None:\n        assert dt_bias.shape == (nheads,)\n    nchunks = math.ceil(seqlen / chunk_size)\n    dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n    dA_cumsum = 
torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n    grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n    with torch.cuda.device(dt.device.index):\n        _chunk_cumsum_fwd_kernel[grid_chunk_cs](\n            dt, A, dt_bias, dt_out, dA_cumsum,\n            batch, seqlen, nheads, chunk_size,\n            dt_limit[0], dt_limit[1],\n            dt.stride(0), dt.stride(1), dt.stride(2),\n            A.stride(0),\n            dt_bias.stride(0) if dt_bias is not None else 0,\n            dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),\n            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n            dt_softplus,\n            HAS_DT_BIAS=dt_bias is not None,\n            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n        )\n    return dA_cumsum, dt_out\n\n\ndef _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\")), ddt=None):\n    batch, seqlen, nheads = dt.shape\n    _, _, nchunks, chunk_size = ddA.shape\n    assert ddA.shape == (batch, nheads, nchunks, chunk_size)\n    assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)\n    assert A.shape == (nheads,)\n    deterministic = use_deterministic_mode()\n    if dt_bias is not None:\n        assert dt_bias.shape == (nheads,)\n        if deterministic:\n            ddt_bias_workspace = torch.zeros(\n                batch, nchunks, nheads, device=dt.device, dtype=torch.float32\n            )\n            ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)\n            stride_ddt_bias_batch = ddt_bias_workspace.stride(0)\n            stride_ddt_bias_chunk = ddt_bias_workspace.stride(1)\n        else:\n            ddt_bias_workspace = ddt_bias = torch.empty_like(\n                dt_bias, dtype=torch.float32\n            )\n            stride_ddt_bias_batch = 0\n            stride_ddt_bias_chunk = 0\n    
else:\n        ddt_bias = None\n        ddt_bias_workspace = None\n        stride_ddt_bias_batch = 0\n        stride_ddt_bias_chunk = 0\n    if ddt is not None:\n        assert ddt.shape == dt.shape\n    else:\n        ddt = torch.empty_like(dt)\n    dA = torch.empty_like(A, dtype=torch.float32)\n    if deterministic:\n        dA_workspace = torch.zeros(\n            batch, nchunks, nheads, device=dt.device, dtype=torch.float32\n        )\n        stride_dA_batch = dA_workspace.stride(0)\n        stride_dA_chunk = dA_workspace.stride(1)\n    else:\n        dA_workspace = dA\n        stride_dA_batch = 0\n        stride_dA_chunk = 0\n    grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n    with torch.cuda.device(dt.device.index):\n        _chunk_cumsum_bwd_kernel[grid_chunk_cs](\n            ddA, ddt_out, dt, A, dt_bias, ddt, dA_workspace, ddt_bias_workspace if ddt_bias is not None else None,\n            batch, seqlen, nheads, chunk_size,\n            dt_limit[0], dt_limit[1],\n            ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3),\n            ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3),\n            dt.stride(0), dt.stride(1), dt.stride(2),\n            A.stride(0),\n            dt_bias.stride(0) if dt_bias is not None else 0,\n            ddt.stride(0), ddt.stride(1), ddt.stride(2),\n            stride_dA_batch, stride_dA_chunk, dA_workspace.stride(-1),\n            stride_ddt_bias_batch, stride_ddt_bias_chunk, (ddt_bias_workspace.stride(-1) if ddt_bias is not None else 0),\n            dt_softplus,\n            HAS_DT_BIAS=dt_bias is not None,\n            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n            DETERMINISTIC_REDUCTION=deterministic,\n        )\n    if deterministic:\n        dA.copy_(dA_workspace.sum(dim=(0, 1)))\n        if ddt_bias is not None:\n            ddt_bias.copy_(ddt_bias_workspace.sum(dim=(0, 1)))\n    return ddt, dA, ddt_bias\n\n\ndef 
_chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    _, _, ngroups, dstate = B.shape\n    assert nheads % ngroups == 0\n    assert B.shape == (batch, seqlen, ngroups, dstate)\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == dt.shape\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    if states is not None:\n        assert states.shape == (batch, nchunks, nheads, headdim, dstate)\n    else:\n        states_dtype = torch.float32 if states_in_fp32 else B.dtype\n        states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype)\n    grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),\n                    batch * nchunks, nheads)\n    with torch.cuda.device(x.device.index):\n        _chunk_state_fwd_kernel[grid](\n            x, B, states, dt, dA_cumsum, seq_idx,\n            headdim, dstate, chunk_size,\n            batch, seqlen, nheads // ngroups,\n            x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n            B.stride(0), B.stride(1), B.stride(2), B.stride(-1),\n            states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),\n            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n            HAS_SEQ_IDX=seq_idx is not None,\n        )\n    return states\n\n\ndef _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    _, _, ngroups, dstate = B.shape\n    assert nheads % ngroups == 0\n    assert B.shape == (batch, seqlen, 
ngroups, dstate)\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == dt.shape\n    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)\n    if dx is not None:\n        assert dx.shape == x.shape\n    else:\n        dx = torch.empty_like(x)\n    deterministic = use_deterministic_mode()\n    tile_count = math.ceil(headdim / _CHUNK_STATE_BWD_DX_MIN_BLOCK_N)\n    ddt, stride_ddt_tile = alloc_tile_workspace(\n        (batch, nheads, nchunks, chunk_size),\n        tile_count,\n        torch.float32,\n        dt.device,\n        deterministic,\n        zero_init=True,\n    )\n    ddA_cumsum, stride_ddA_tile = alloc_tile_workspace(\n        (batch, nheads, nchunks, chunk_size),\n        tile_count,\n        torch.float32,\n        dA_cumsum.device,\n        deterministic,\n        zero_init=True,\n    )\n    grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n                       batch * nchunks, nheads)\n    with torch.cuda.device(x.device.index):\n        _chunk_state_bwd_dx_kernel[grid_dx](\n            x, B, dstates, dt, dA_cumsum, dx, ddt, ddA_cumsum,\n            chunk_size, headdim, dstate,\n            batch, seqlen, nheads // ngroups,\n            x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n            B.stride(0), B.stride(1), B.stride(2), B.stride(-1),\n            dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),\n            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n            dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),\n            ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), stride_ddt_tile,\n            ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), stride_ddA_tile,\n            
DETERMINISTIC_REDUCTION=deterministic,\n            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n        )\n    ddt = finalize_tile_workspace(ddt, deterministic)\n    ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic)\n    if deterministic:\n        # Match `_chunk_state_bwd_dx_kernel` atomic path (`tl.atomic_add(..., ddA_cs_last)` into last element).\n        ddA_cumsum[..., -1] -= ddA_cumsum.sum(dim=-1)\n    return dx, ddt, ddA_cumsum\n\n\ndef _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    dstate = dstates.shape[-1]\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == dt.shape\n    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    deterministic = use_deterministic_mode()\n    if B is not None:\n        assert B.shape == (batch, seqlen, ngroups, dstate)\n        B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))\n        # Use torch.empty since the Triton kernel will call init_to_zero\n        tile_count = math.ceil(dstate / _CHUNK_STATE_BWD_DB_MIN_BLOCK_N)\n        ddA_cumsum, stride_ddA_tile = alloc_tile_workspace(\n            (batch, nheads, nchunks, chunk_size),\n            tile_count,\n            torch.float32,\n            x.device,\n            deterministic,\n            zero_init=True,\n        )\n        ddA_cumsum_strides = (\n            ddA_cumsum.stride(0),\n            ddA_cumsum.stride(2),\n            ddA_cumsum.stride(1),\n            ddA_cumsum.stride(3),\n        )\n    else:\n        B_strides = (0, 0, 0, 0)\n        ddA_cumsum = None\n        ddA_cumsum_strides = (0, 0, 0, 0)\n        stride_ddA_tile = 0\n    nheads_ngroups_ratio = nheads // ngroups\n    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n    
nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)\n    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)\n    dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32)\n    grid_db = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),\n                        batch * nchunks, nsplits * ngroups)\n    with torch.cuda.device(x.device.index):\n        _chunk_state_bwd_db_kernel[grid_db](\n            x, dstates, B, dt, dA_cumsum, seq_idx, dB, ddA_cumsum,\n            chunk_size, dstate, headdim,\n            batch, seqlen, nheads, nheads_per_program, ngroups,\n            x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n            dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),\n            *B_strides,\n            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n            dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4),\n            *ddA_cumsum_strides, stride_ddA_tile,\n            HAS_DDA_CS=ddA_cumsum is not None,\n            HAS_SEQ_IDX=seq_idx is not None,\n            DETERMINISTIC_REDUCTION=deterministic,\n            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),\n        )\n    dB = dB.sum(2)\n    if ddA_cumsum is not None:\n        ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic)\n        # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute\n        # to the state of the chunk.\n        # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])\n        # But it's easier to just do the cumsum for all elements, the result will be the same.\n        
torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)\n    return dB if B is None else (dB, ddA_cumsum)\n\n\ndef _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    _, _, ngroups, dstate = B.shape\n    assert nheads % ngroups == 0\n    assert B.shape == (batch, seqlen, ngroups, dstate)\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == dt.shape\n    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    # Use torch.empty since the Triton kernel will call init_to_zero\n    deterministic = use_deterministic_mode()\n    tile_count = math.ceil(headdim / _CHUNK_STATE_BWD_DDACS_MIN_BLOCK_N)\n    ddA_cumsum, stride_ddA_tile = alloc_tile_workspace(\n        (batch, nheads, nchunks, chunk_size),\n        tile_count,\n        torch.float32,\n        x.device,\n        deterministic,\n        zero_init=True,\n    )\n    grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n                          batch * nchunks, nheads)\n    with torch.cuda.device(x.device.index):\n        _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](\n            x, B, dstates, dt, dA_cumsum, seq_idx, ddA_cumsum,\n            chunk_size, headdim, dstate,\n            batch, seqlen, nheads // ngroups,\n            x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n            B.stride(0), B.stride(1), B.stride(2), B.stride(-1),\n            dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),\n            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n        
    ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), stride_ddA_tile,\n            HAS_SEQ_IDX=seq_idx is not None,\n            DETERMINISTIC_REDUCTION=deterministic,\n            BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),\n            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n        )\n    ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic)\n    torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])\n    return ddA_cumsum\n\n\ndef chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):\n    total_seqlen, nheads, headdim = x.shape\n    _, nchunks, chunk_size = dt.shape\n    _, ngroups, dstate = B.shape\n    batch = cu_seqlens.shape[0] - 1\n    cu_seqlens = cu_seqlens.contiguous()\n    assert nheads % ngroups == 0\n    assert B.shape == (total_seqlen, ngroups, dstate)\n    assert dt.shape == (nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == dt.shape\n    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)\n    states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device)\n    grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),\n                    batch, nheads)\n    with torch.cuda.device(x.device.index):\n        _chunk_state_varlen_kernel[grid](\n            x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states,\n            headdim, dstate, chunk_size,\n            total_seqlen, nheads // ngroups,\n            x.stride(0), x.stride(1), x.stride(2),\n            B.stride(0), B.stride(1), B.stride(2),\n            dt.stride(1), dt.stride(0), dt.stride(2),\n            dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2),\n            chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3),\n            states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n        )\n   
 return states\n\n\nclass ChunkStateFn(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):\n        batch, seqlen, nheads, headdim = x.shape\n        _, _, nchunks, chunk_size = dt.shape\n        assert seqlen <= nchunks * chunk_size\n        _, _, ngroups, dstate = B.shape\n        assert B.shape == (batch, seqlen, ngroups, dstate)\n        assert dt.shape == (batch, nheads, nchunks, chunk_size)\n        assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n        if B.stride(-1) != 1:\n            B = B.contiguous()\n        if x.stride(-1) != 1 and x.stride(1) != 1:  # Either M or K dimension should be contiguous\n            x = x.contiguous()\n        states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)\n        ctx.save_for_backward(B, x, dt, dA_cumsum)\n        return states\n\n    @staticmethod\n    def backward(ctx, dstates):\n        B, x, dt, dA_cumsum = ctx.saved_tensors\n        batch, seqlen, nheads, headdim = x.shape\n        _, _, nchunks, chunk_size = dt.shape\n        _, _, ngroups, dstate = B.shape\n        assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)\n        if dstates.stride(-1) != 1:\n            dstates = dstates.contiguous()\n        dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)\n        dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)\n        dB = dB.to(B.dtype)\n        return dB, dx, ddt, ddA_cumsum, None\n\n\ndef chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):\n    \"\"\"\n    Argument:\n        B: (batch, seqlen, ngroups, dstate)\n        x: (batch, seqlen, nheads, headdim)\n        dt: (batch, nheads, nchunks, chunk_size)\n        dA_cumsum: (batch, nheads, nchunks, chunk_size)\n    Return:\n        states: (batch, nchunks, nheads, headdim, dstate)\n    \"\"\"\n    return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)\n\n\ndef chunk_state_ref(B, x, dt, 
dA_cumsum):\n    \"\"\"\n    Argument:\n        B: (batch, seqlen, ngroups, dstate)\n        x: (batch, seqlen, nheads, headdim)\n        dt: (batch, nheads, nchunks, chunk_size)\n        dA_cumsum: (batch, nheads, nchunks, chunk_size)\n    Return:\n        states: (batch, nchunks, nheads, headdim, dstate)\n    \"\"\"\n    # Check constraints.\n    batch, seqlen, nheads, headdim = x.shape\n    dstate = B.shape[-1]\n    _, _, nchunks, chunk_size = dt.shape\n    assert seqlen <= nchunks * chunk_size\n    assert x.shape == (batch, seqlen, nheads, headdim)\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    ngroups = B.shape[2]\n    assert nheads % ngroups == 0\n    assert B.shape == (batch, seqlen, ngroups, dstate)\n    B = repeat(B, \"b l g d -> b l (g h) d\", h=nheads // ngroups)\n    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)\n    if seqlen < nchunks * chunk_size:\n        x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))\n        B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))\n    x = rearrange(x, \"b (c l) h p -> b c l h p\", l=chunk_size)\n    B = rearrange(B, \"b (c l) ... -> b c l ...\", l=chunk_size)\n    decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))\n    return torch.einsum(\"bclhn,bhcl,bhcl,bclhp->bchpn\", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x)\n"
  },
  {
    "path": "mamba_ssm/ops/triton/ssd_combined.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\n\n\"\"\"We want triton==2.1.0 or 2.2.0 for this\n\"\"\"\n\nfrom typing import Optional\n\nimport math\nfrom packaging import version\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom mamba_ssm.utils.torch import custom_bwd, custom_fwd\n\nimport triton\nimport triton.language as tl\n\nfrom einops import rearrange, repeat\n\ntry:\n    from causal_conv1d import causal_conv1d_fn\n    from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function\nexcept ImportError:\n    causal_conv1d_fn = None\n    causal_conv1d_fwd_function = None\n    causal_conv1d_bwd_function = None\n    causal_conv1d_update_function = None\n\nfrom mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd\nfrom mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd\nfrom mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db\nfrom mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable\nfrom mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref\nfrom mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen\nfrom mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd\nfrom mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref\nfrom mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates\nfrom mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb\nfrom mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable\nfrom mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref\nfrom mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev\nfrom mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd\nfrom mamba_ssm.ops.triton.k_activations import 
_swiglu_fwd, _swiglu_bwd\nfrom mamba_ssm.utils.determinism import (\n    alloc_tile_workspace,\n    autotune_configs,\n    finalize_tile_workspace,\n    use_deterministic_mode,\n)\n\nTRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')\n\n\ndef init_to_zero(names):\n    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]\n\n\ndef ensure_stride(inp):\n    \"\"\"\n    Return inp, while ensuring that stride(1) of the returned tensor is a multiple of 8.\n\n    The inp tensor is of shape [batch, length, channels], where channels is assumed, and tested, to be\n    a multiple of 8. If it is contiguous, inp will have strides [length*channels, channels, 1]. The\n    output of this function will be rearranged to shape [batch, channels, length] before being passed to\n    causal_conv1d. That rearranged tensor will have strides [length*channels, 1, channels].\n    causal_conv1d handles this stride configuration (which it calls channels_last) directly and\n    efficiently, after first recognizing it (when stride[1]==1 and stride[2]>1). causal_conv1d cannot\n    operate on a channels_last tensor for which stride[2] is not a multiple of 8, and in that case will\n    raise an exception. 
This function prevents the aforementioned exception by returning a tensor with\n    stride(1) equal to channels, by making the returned tensor contiguous, if inp.stride(1) is not\n    already a multiple of 8.\n    \"\"\"\n    assert inp.shape[2] % 8 == 0, \"Number of convolution channels is required to be a multiple of 8.\"\n    return inp if inp.stride(1) % 8 == 0 else inp.contiguous()\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero([\"ddt_ptr\", \"dD_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"dD_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"dD_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"dD_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"dD_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"dD_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"dD_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"dD_ptr\"])),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero([\"ddt_ptr\", \"dD_ptr\"])),\n    ]),\n    key=['chunk_size', 'hdim', 'dstate'],\n)\n@triton.jit\ndef 
_chunk_scan_chunk_state_bwd_dx_kernel(\n    # Pointers to matrices\n    x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,\n    b_ptr, dstates_ptr,\n    dx_ptr, ddt_ptr, dD_ptr,\n    # Matrix dimensions\n    chunk_size, hdim, dstate,\n    batch, seqlen, nheads_ngroups_ratio,\n    # Strides\n    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,\n    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,\n    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,\n    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n    stride_seq_idx_batch, stride_seq_idx_seqlen,\n    stride_D_head,\n    stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,\n    stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,\n    stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,\n    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile,\n    stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,\n    # Meta-parameters\n    HAS_D: tl.constexpr,\n    D_HAS_HDIM: tl.constexpr,\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n    BLOCK_SIZE_DSTATE: tl.constexpr,\n    IS_TRITON_22: tl.constexpr,\n    DETERMINISTIC_REDUCTION: tl.constexpr,\n):\n    pid_bc = tl.program_id(axis=1)\n    pid_c = pid_bc // batch\n    pid_b = pid_bc - pid_c * batch\n    pid_h = tl.program_id(axis=2)\n    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)\n    pid_m = tl.program_id(axis=0) // num_pid_n\n    pid_n = tl.program_id(axis=0) % num_pid_n\n    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head\n    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // 
nheads_ngroups_ratio) * stride_cb_head\n    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head\n    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head\n    ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile\n    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head\n    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head\n    dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n\n    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n\n    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)\n    if not HAS_SEQ_IDX:\n        # scale = tl.exp(dA_cs_last - dA_cs_m)\n        scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))\n    else:\n        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n        seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)\n        # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)\n        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)\n    # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128\n    # However, we're getting 
error with the Triton compiler 2.1.0 for that code path:\n    # Unexpected mma -> mma layout conversion\n    # Triton 2.2.0 fixes this\n    offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)\n    b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)\n    dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)\n    if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:\n        b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)\n        dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)\n        dstates = dstates.to(b_ptr.dtype.element_ty)\n        acc = tl.dot(b, dstates) * scale[:, None]\n    else:\n        for k in range(0, dstate, BLOCK_SIZE_K):\n            b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)\n            dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)\n            dstates = dstates.to(b_ptr.dtype.element_ty)\n            acc += tl.dot(b, dstates)\n            b_ptrs += BLOCK_SIZE_K * stride_b_dstate\n            dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate\n        acc *= scale[:, None]\n\n    # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n    # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n    # dt_ptrs = dt_ptr + offs_m * stride_dt_csize\n    # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n    # ddt = tl.sum(acc * x, axis=1) * dt_m\n    # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize\n    # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    cb_ptrs = 
cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)\n    dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize\n    K_MAX = chunk_size_limit\n    K_MIN = pid_m * BLOCK_SIZE_M\n    cb_ptrs += K_MIN * stride_cb_csize_k\n    dout_ptrs += K_MIN * stride_dout_seqlen\n    dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize\n    for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):\n        k = tl.multiple_of(k, BLOCK_SIZE_K)\n        # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower\n        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)\n        dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)\n        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)\n        # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])\n        cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))\n        # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,\n        # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.\n        # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.\n        # This will cause NaN in acc, and hence NaN in dx and ddt.\n        mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)\n        cb = tl.where(mask, cb, 0.0)\n        cb = cb.to(dout_ptr.dtype.element_ty)\n        acc += tl.dot(cb, dout)\n        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k\n        dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen\n        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize\n\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    dt_ptrs = dt_ptr + offs_m * 
stride_dt_csize\n    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)\n    dx = acc * dt_m[:, None]\n    dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head\n    dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)\n    if HAS_D:\n        dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)\n        dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n        if D_HAS_HDIM:\n            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)\n        else:\n            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)\n        dx += dout_res * D\n    tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))\n\n    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)\n    x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)\n    if HAS_D:\n        dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize\n        if D_HAS_HDIM:\n            dD_ptrs = dD_ptr + offs_n * stride_dD_hdim\n            dD = tl.sum(dout_res * x, axis=0)\n            tl.store(dD_ptrs, dD, mask=offs_n < hdim)\n        else:\n            dD = tl.sum(dout_res * x)\n            if DETERMINISTIC_REDUCTION:\n                tl.store(dD_ptr + pid_n * stride_dD_hdim, dD)\n            else:\n                tl.atomic_add(dD_ptr, dD)\n    ddt = tl.sum(acc * x, axis=1)\n    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize\n    if DETERMINISTIC_REDUCTION:\n        tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n    else:\n        tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)\n\n\n_CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N 
= min(\n    cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_scan_chunk_state_bwd_dx_kernel.configs\n)\n\n\ndef _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, nchunks, chunk_size = dt.shape\n    _, _, ngroups, dstate = B.shape\n    assert nheads % ngroups == 0\n    assert B.shape == (batch, seqlen, ngroups, dstate)\n    assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)\n    assert dt.shape == (batch, nheads, nchunks, chunk_size)\n    assert dA_cumsum.shape == dt.shape\n    assert dout.shape == x.shape\n    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    deterministic = use_deterministic_mode()\n    if D is not None:\n        assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n        assert D.stride(-1) == 1\n        BLOCK_SIZE_min = 32\n        pid_m_tiles = triton.cdiv(chunk_size, BLOCK_SIZE_min)\n        pid_n_tiles = math.ceil(headdim / _CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N)\n        if D.dim() == 2:\n            dD_hdim = headdim\n        elif deterministic:\n            dD_hdim = pid_n_tiles\n        else:\n            dD_hdim = 1\n        dD = torch.zeros(pid_m_tiles, batch, nchunks, nheads, dD_hdim, device=D.device, dtype=torch.float32)\n        dD_strides = (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))\n    else:\n        dD = None\n        dD_strides = (0, 0, 0, 0, 0)\n    if dx is None:\n        dx = torch.empty_like(x)\n    else:\n        assert dx.shape == x.shape\n    tile_count = math.ceil(headdim / _CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N)\n    ddt, stride_ddt_tile = alloc_tile_workspace(\n        (batch, nheads, nchunks, chunk_size),\n        tile_count,\n        torch.float32,\n        dout.device,\n        deterministic,\n        zero_init=True,\n    )\n    grid_dx = lambda META: 
(triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),\n                        batch * nchunks, nheads)\n    with torch.cuda.device(x.device.index):\n        _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](\n            x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,\n            chunk_size, headdim, dstate,\n            batch, seqlen, nheads // ngroups,\n            x.stride(0), x.stride(1), x.stride(2), x.stride(3),\n            CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),\n            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),\n            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n            D.stride(0) if D is not None else 0,\n            B.stride(0), B.stride(1), B.stride(2), B.stride(3),\n            dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),\n            dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),\n            ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), stride_ddt_tile,\n            dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],\n            D is not None,\n            D.dim() == 2 if D is not None else True,\n            HAS_SEQ_IDX=seq_idx is not None,\n            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),\n            IS_TRITON_22=TRITON_22,\n            DETERMINISTIC_REDUCTION=deterministic,\n        )\n    if D is not None:\n        BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[\"BLOCK_SIZE_M\"]\n        n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n        dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2))\n        if D.dim() == 1:\n            dD = dD.sum(dim=-1)\n        
dD = dD.to(dtype=D.dtype)\n    ddt = finalize_tile_workspace(ddt, deterministic)\n    return dx, ddt, dD\n\n\ndef _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))):\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, ngroups, dstate = B.shape\n    assert nheads % ngroups == 0\n    assert B.shape == (batch, seqlen, ngroups, dstate)\n    assert x.shape == (batch, seqlen, nheads, headdim)\n    assert dt.shape == (batch, seqlen, nheads)\n    assert A.shape == (nheads,)\n    assert C.shape == B.shape\n    if z is not None:\n        assert z.shape == x.shape\n    if D is not None:\n        assert D.shape == (nheads, headdim) or D.shape == (nheads,)\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    if B.stride(-1) != 1:\n        B = B.contiguous()\n    if C.stride(-1) != 1:\n        C = C.contiguous()\n    if x.stride(-1) != 1 and x.stride(1) != 1:  # Either M or K dimension should be contiguous\n        x = x.contiguous()\n    if z is not None and z.stride(-1) != 1 and z.stride(1) != 1:  # Either M or K dimension should be contiguous\n        z = z.contiguous()\n    if D is not None and D.stride(-1) != 1:\n        D = D.contiguous()\n    if initial_states is not None:\n        assert initial_states.shape == (batch, nheads, headdim, dstate)\n    # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)\n    # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)\n    # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)\n    # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)\n    dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, 
dt_softplus=dt_softplus, dt_limit=dt_limit)\n    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)\n    # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)\n    # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)\n    # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)\n    states, final_states = _state_passing_fwd(rearrange(states, \"... p n -> ... (p n)\"), dA_cumsum[:, :, :, -1],\n                                              initial_states=rearrange(initial_states, \"... p n -> ... (p n)\") if initial_states is not None else None,\n                                              seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype)\n    states, final_states = [rearrange(t, \"... (p n) -> ... p n\", n=dstate) for t in [states, final_states]]\n    # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, \"... p n -> ... (p n)\"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), \"... (p n) -> ... p n\", n=dstate)\n    # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, \"... p n -> ... (p n)\"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), \"... (p n) -> ... 
p n\", n=dstate)\n    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)\n    out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)\n    if cu_seqlens is None:\n        return out, out_x, dt, dA_cumsum, states, final_states\n    else:\n        assert batch == 1, \"passing cu_seqlens to get the varlen states is only supported if batch dimension is 1\"\n        varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0),\n                                           cu_seqlens, states.squeeze(0))\n        return out, out_x, dt, dA_cumsum, states, final_states, varlen_states\n\n\ndef _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None,\n                                   dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False,\n                                   dt_limit=(0.0, float(\"inf\")),\n                                   dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False):\n    if dout.stride(-1) != 1:\n        dout = dout.contiguous()\n    batch, seqlen, nheads, headdim = x.shape\n    nchunks = math.ceil(seqlen / chunk_size)\n    _, _, ngroups, dstate = B.shape\n    assert dout.shape == (batch, seqlen, nheads, headdim)\n    assert dt.shape == (batch, seqlen, nheads)\n    assert A.shape == (nheads,)\n    assert nheads % ngroups == 0\n    assert B.shape == (batch, seqlen, ngroups, dstate)\n    assert C.shape == B.shape\n    assert out.shape == x.shape\n    if initial_states is not None:\n        assert initial_states.shape == (batch, nheads, headdim, dstate)\n    if seq_idx is not None:\n        assert seq_idx.shape == (batch, seqlen)\n    if dx is not None:\n        assert dx.shape == x.shape\n    if dB is not None:\n        assert dB.shape == B.shape\n        dB_given = dB\n    else:\n        dB_given = torch.empty_like(B)\n    if dC is not None:\n        assert dC.shape == 
C.shape\n        dC_given = dC\n    else:\n        dC_given = torch.empty_like(C)\n    if dz is not None:\n        assert z is not None\n        assert dz.shape == z.shape\n    if ddt is not None:\n        assert ddt.shape == dt.shape\n        ddt_given = ddt\n    else:\n        ddt_given = torch.empty_like(dt)\n    # TD: For some reason Triton (2.1.0 and 2.2.0) errors with\n    # \"[CUDA]: invalid device context\" (e.g. during the varlen test), and cloning makes it work. The cause is unclear.\n    dt_in = dt.clone()\n    dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus,\n                                      dt_limit=dt_limit)\n    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)\n    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)\n    states, _ = _state_passing_fwd(rearrange(states, \"... p n -> ... (p n)\"), dA_cumsum[:, :, :, -1],\n                                   initial_states=rearrange(initial_states, \"... p n -> ... (p n)\") if initial_states is not None else None,\n                                   seq_idx=seq_idx, chunk_size=chunk_size)\n    states = rearrange(states, \"... (p n) -> ... 
p n\", n=dstate)\n    if z is not None:\n        dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output)\n        outz = rest[0] if recompute_output else out\n    else:\n        dz = None\n        outz = out\n    dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype)\n    # dstates has length nchunks, containing the gradient to initial states at index 0 and\n    # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)\n    # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states\n    # will be used in matmul in the next kernels.\n    dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(\n        rearrange(states, \"... p n -> ... (p n)\"),\n        dA_cumsum[:, :, :, -1],\n        rearrange(dstates, \"... p n -> ... (p n)\"),\n        dfinal_states=rearrange(dfinal_states, \"... p n -> ... (p n)\") if dfinal_states is not None else None,\n        seq_idx=seq_idx,\n        has_initial_states=initial_states is not None,\n        dstates_dtype=x.dtype,\n        states_dtype=x.dtype,\n        chunk_size=chunk_size,\n    )\n    # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and\n    # gradient to the final states at index (nchunks - 1)\n    # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)\n    # The final states is not stored.\n    states = rearrange(states, \"... (p n) -> ... p n\", n=dstate)\n    dstates = rearrange(dstates, \"... (p n) -> ... p n\", n=dstate)\n    dinitial_states = rearrange(dinitial_states, \"... (p n) -> ... 
p n\", n=dstate) if dinitial_states is not None else None\n    dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)\n    # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)\n    dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups)\n    # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)\n    dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups)\n    # Computing ddA with the dcb kernel is much slower, so we're not using it for now\n    dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)\n    # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)\n    dCB = dCB.to(CB.dtype)\n    _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)\n    _bmm_chunk_bwd(B, rearrange(dCB, \"... l s -> ... 
s l\"), residual=dC, out=dC_given)\n    # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate\n    # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16\n    if z is None:\n        dD = dD_from_x\n    # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.\n    # ddA_cumsum = torch.einsum(\"bclhp,bclhp->bhcl\", out.float(), dout.float()) - ddt * dt\n    # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might\n    # be a lot of underflow.\n\n    # This is already done as part of bwd_dC kernel\n    # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)\n    ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum\n    ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])\n    # This is already done as part of bwd_dB kernel\n    # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)\n    # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]\n    ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)\n    ddA += ddA_next + ddA_prev\n\n    ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given)\n\n    # These 2 lines are just to test ddt and dA being computed by old code\n    # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)\n    # ddt_given.copy_(ddt)\n\n    return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states)\n    return return_vals if not recompute_output else (*return_vals, outz)\n\n\ndef selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):\n    \"\"\"\n    Argument:\n        dout: (batch, seqlen, nheads, headdim)\n        x: (batch, seqlen, nheads, headdim)\n        dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)\n        
A: (nheads) or (dim, dstate)\n        B: (batch, seqlen, ngroups, dstate)\n        C: (batch, seqlen, ngroups, dstate)\n        D: (nheads, headdim) or (nheads,)\n        z: (batch, seqlen, nheads, headdim)\n    Return:\n        ddt: gradient w.r.t. dt, same shape as dt\n        dA: gradient w.r.t. A, same shape as A\n    \"\"\"\n    import selective_scan\n\n    batch, seqlen, nheads, headdim = x.shape\n    chunk_size = dt.shape[-1]\n    _, _, ngroups, dstate = B.shape\n    assert nheads % ngroups == 0\n    x = rearrange(x, \"b l h p -> b (h p) l\")\n    squeeze_dt = dt.dim() == 4\n    if dt.dim() == 4:\n        dt = repeat(dt, \"b h c l -> b h p c l\", p=headdim)\n    dt = rearrange(dt, \"b h p c l -> b (h p) (c l)\", p=headdim)\n    squeeze_A = A.dim() == 1\n    if A.dim() == 1:\n        A = repeat(A, \"h -> (h p) n\", p=headdim, n=dstate).to(dtype=torch.float32)\n    else:\n        A = A.to(dtype=torch.float32)\n    B = rearrange(B, \"b l g n -> b g n l\")\n    C = rearrange(C, \"b l g n -> b g n l\")\n    if D is not None:\n        if D.dim() == 2:\n            D = rearrange(D, \"h p -> (h p)\")\n        else:\n            D = repeat(D, \"h -> (h p)\", p=headdim)\n    if z is not None:\n        z = rearrange(z, \"b l h p -> b (h p) l\")\n\n    if x.stride(-1) != 1:\n        x = x.contiguous()\n    if dt.stride(-1) != 1:\n        dt = dt.contiguous()\n    if D is not None:\n        D = D.contiguous()\n    if B.stride(-1) != 1:\n        B = B.contiguous()\n    if C.stride(-1) != 1:\n        C = C.contiguous()\n    if z is not None and z.stride(-1) != 1:\n        z = z.contiguous()\n    _, intermediate, *rest = selective_scan.fwd(x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False)\n    if z is not None:\n        out = rest[0]\n    else:\n        out = None\n\n    dout = rearrange(dout, \"b l h p -> b (h p) l\")\n\n    if dout.stride(-1) != 1:\n        dout = dout.contiguous()\n    # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the\n    # backward of selective_scan with the 
backward of chunk).\n    # Here we just pass in None and dz will be allocated in the C++ code.\n    _, ddt, dA, *rest = selective_scan.bwd(\n        x, dt.to(dtype=x.dtype), A, B, C, D, z, None, dout, intermediate, out, None, False,\n        False  # option to recompute out_z, not used here\n    )\n    ddt = rearrange(ddt, \"b (h p) (c l) -> b h p c l\", p=headdim, l=chunk_size)\n    if squeeze_dt:\n        ddt = ddt.float().sum(dim=2)\n    if squeeze_A:\n        dA = rearrange(dA, \"(h p) n -> h p n\", p=headdim).sum(dim=(1, 2))\n    return ddt, dA\n\n\nclass MambaChunkScanCombinedFn(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\")), return_final_states=False, return_varlen_states=False):\n        ctx.dt_dtype = dt.dtype\n        if not return_varlen_states:\n            cu_seqlens = None\n        else:\n            assert cu_seqlens is not None, \"cu_seqlens must be provided if return_varlen_states is True\"\n        out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)\n        ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx)\n        ctx.dt_softplus = dt_softplus\n        ctx.chunk_size = chunk_size\n        ctx.dt_limit = dt_limit\n        ctx.return_final_states = return_final_states\n        ctx.return_varlen_states = return_varlen_states\n        if not return_varlen_states:\n            return out if not return_final_states else (out, final_states)\n        else:\n            varlen_states = rest[0]\n            return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states)\n\n 
   @staticmethod\n    def backward(ctx, dout, *args):\n        out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors\n        assert not ctx.return_varlen_states, \"return_varlen_states is not supported in backward\"\n        dfinal_states = args[0] if ctx.return_final_states else None\n        dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit)\n        return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None, None, None\n\n\ndef mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\")), return_final_states=False, return_varlen_states=False):\n    \"\"\"\n    Argument:\n        x: (batch, seqlen, nheads, headdim)\n        dt: (batch, seqlen, nheads)\n        A: (nheads)\n        B: (batch, seqlen, ngroups, dstate)\n        C: (batch, seqlen, ngroups, dstate)\n        chunk_size: int\n        D: (nheads, headdim) or (nheads,)\n        z: (batch, seqlen, nheads, headdim)\n        dt_bias: (nheads,)\n        initial_states: (batch, nheads, headdim, dstate)\n        seq_idx: (batch, seqlen)\n        cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True\n        dt_softplus: Whether to apply softplus to dt\n    Return:\n        out: (batch, seqlen, nheads, headdim)\n    \"\"\"\n    return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states)\n\n\ndef mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):\n    \"\"\"\n    Argument:\n        x: (batch, 
seqlen, nheads, headdim)\n        dt: (batch, seqlen, nheads)\n        A: (nheads)\n        B: (batch, seqlen, ngroups, dstate)\n        C: (batch, seqlen, ngroups, dstate)\n        D: (nheads, headdim) or (nheads,)\n        z: (batch, seqlen, nheads, headdim)\n        dt_bias: (nheads,)\n    Return:\n        out: (batch, seqlen, nheads, headdim)\n    \"\"\"\n    batch, seqlen, nheads, headdim = x.shape\n    dstate = B.shape[-1]\n    if seqlen % chunk_size != 0:\n        dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))\n    dt = rearrange(dt, \"b (c l) h -> b h c l\", l=chunk_size)\n    dt = dt.float()  # We want high precision for this before cumsum\n    if dt_bias is not None:\n        dt = dt + rearrange(dt_bias, \"h -> h 1 1\")\n    if dt_softplus:\n        dt = F.softplus(dt)\n    dA = dt * rearrange(A, \"h -> h 1 1\")\n    dA_cumsum = torch.cumsum(dA, dim=-1)\n    # 1. Compute the state for each chunk\n    states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)\n    # 2. Pass the state to all the chunks by weighted cumsum.\n    states = rearrange(state_passing(rearrange(states, \"... p n -> ... (p n)\"), dA_cumsum[:, :, :, -1])[0],\n                       \"... (p n) -> ... p n\", n=dstate)\n    # 3. 
Compute the output for each chunk\n    out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)\n    return out\n\n\ndef ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):\n    \"\"\"\n    Argument:\n        x: (batch, seqlen, nheads, headdim)\n        dt: (batch, seqlen, nheads)\n        A: (nheads)\n        B: (batch, seqlen, ngroups, dstate)\n        C: (batch, seqlen, ngroups, dstate)\n        D: (nheads, headdim) or (nheads,)\n        z: (batch, seqlen, nheads, headdim)\n        dt_bias: (nheads,)\n    Return:\n        out: (batch, seqlen, nheads, headdim)\n    \"\"\"\n    batch, seqlen, nheads, headdim = x.shape\n    dstate = B.shape[-1]\n    if seqlen % chunk_size != 0:\n        dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))\n    dt = rearrange(dt, \"b (c l) h -> b h c l\", l=chunk_size)\n    dt = dt.float()  # We want high precision for this before cumsum\n    if dt_bias is not None:\n        dt = dt + rearrange(dt_bias, \"h -> h 1 1\")\n    if dt_softplus:\n        dt = F.softplus(dt)\n    dA = dt * rearrange(A, \"h -> h 1 1\")\n    dA_cumsum = torch.cumsum(dA, dim=-1)\n    # 1. Compute the state for each chunk\n    states = chunk_state_ref(B, x, dt, dA_cumsum)\n    states_dtype = states.dtype\n    if states.dtype not in [torch.float32, torch.float64]:\n        states = states.to(torch.float32)\n    # 2. Pass the state to all the chunks by weighted cumsum.\n    # state_passing_ref is much less numerically stable\n    states = rearrange(state_passing_ref(rearrange(states, \"... p n -> ... (p n)\"), dA_cumsum[:, :, :, -1])[0],\n                       \"... (p n) -> ... p n\", n=dstate)\n    states = states.to(states_dtype)\n    # 3. 
Compute the output for each chunk\n    out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)\n    return out\n\n\ndef ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))):\n    \"\"\"\n    Argument:\n        x: (batch, seqlen, nheads, headdim)\n        dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)\n        A: (nheads) or (dim, dstate)\n        B: (batch, seqlen, ngroups, dstate)\n        C: (batch, seqlen, ngroups, dstate)\n        D: (nheads, headdim) or (nheads,)\n        z: (batch, seqlen, nheads, headdim)\n        dt_bias: (nheads,) or (nheads, headdim)\n    Return:\n        out: (batch, seqlen, nheads, headdim)\n    \"\"\"\n    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn\n\n    batch, seqlen, nheads, headdim = x.shape\n    _, _, ngroups, dstate = B.shape\n    x = rearrange(x, \"b l h p -> b (h p) l\")\n    if dt.dim() == 3:\n        dt = repeat(dt, \"b l h -> b l h p\", p=headdim)\n    dt = rearrange(dt, \"b l h p -> b (h p) l\")\n    if A.dim() == 1:\n        A = repeat(A, \"h -> (h p) n\", p=headdim, n=dstate).to(dtype=torch.float32)\n    else:\n        A = A.to(dtype=torch.float32)\n    B = rearrange(B, \"b l g n -> b g n l\")\n    C = rearrange(C, \"b l g n -> b g n l\")\n    if D is not None:\n        if D.dim() == 2:\n            D = rearrange(D, \"h p -> (h p)\")\n        else:\n            D = repeat(D, \"h -> (h p)\", p=headdim)\n    if z is not None:\n        z = rearrange(z, \"b l h p -> b (h p) l\")\n    if dt_bias is not None:\n        if dt_bias.dim() == 1:\n            dt_bias = repeat(dt_bias, \"h -> h p\", p=headdim)\n        dt_bias = rearrange(dt_bias, \"h p -> (h p)\")\n    if dt_limit != (0.0, float(\"inf\")):\n        if dt_bias is not None:\n            dt = dt + rearrange(dt_bias, \"d -> d 1\")\n        if dt_softplus:\n            dt = F.softplus(dt)\n        dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)\n    
    dt_bias = None\n        dt_softplus = None\n    out = selective_scan_fn(x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus)\n    return rearrange(out, \"b (h p) l -> b l h p\", p=headdim)\n\n\ndef mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=None, z=None,\n                          dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\")),\n                          activation=\"silu\", headdim=None, ngroups=1):\n    \"\"\"\n    Argument:\n        xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim\n        conv1d_weight: (dim + 2 * ngroups * dstate, width)\n        conv1d_bias: (dim + 2 * ngroups * dstate,)\n        dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)\n        A: (nheads)\n        D: (nheads, headdim) or (nheads,)\n        z: (batch, seqlen, dim)\n        dt_bias: (nheads) or (nheads, headdim)\n        headdim: if D is 1D and z is None, headdim must be passed in\n    Return:\n        out: (batch, seqlen, dim)\n    \"\"\"\n    batch, seqlen, nheads = dt.shape[:3]\n    assert nheads % ngroups == 0\n    if z is not None:\n        dim = z.shape[-1]\n        assert dim % nheads == 0\n        headdim = dim // nheads\n    else:\n        if D.dim() == 1:\n            assert headdim is not None\n        else:\n            headdim = D.shape[1]\n        dim = nheads * headdim\n    xBC = rearrange(causal_conv1d_fn(rearrange(xBC, \"b s d -> b d s\"), conv1d_weight, conv1d_bias, activation=activation),\n                    \"b d s -> b s d\")\n    dstate = (xBC.shape[-1] - dim) // ngroups // 2\n    x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)\n    x = rearrange(x, \"b l (h p) -> b l h p\", h=nheads)\n    B = rearrange(B, \"b l (g n) -> b l g n\", g=ngroups)\n    C = rearrange(C, \"b l (g n) -> b l g n\", g=ngroups)\n    z = rearrange(z, \"b l (h p) -> b l h p\", h=nheads) if z is not None else None\n    out = 
ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)\n    return rearrange(out, \"b s h p -> b s (h p)\")\n\n\nclass MambaSplitConv1dScanCombinedFn(torch.autograd.Function):\n\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float(\"inf\")), return_final_states=False, activation=\"silu\",\n                rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None,\n                ngroups=1, norm_before_gate=True):\n        assert activation in [None, \"silu\", \"swish\"]\n        if D.dim() == 1:\n            assert headdim is not None\n            nheads, = D.shape\n        else:\n            nheads, headdim = D.shape\n        batch, seqlen, _ = zxbcdt.shape\n        dim = nheads * headdim\n        assert nheads % ngroups == 0\n        dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2\n        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2\n        assert d_nonssm >= 0\n        assert zxbcdt.shape == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads)\n        assert dt_bias.shape == (nheads,)\n        assert A.shape == (nheads,)\n        zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)\n        seq_idx = seq_idx.contiguous() if seq_idx is not None else None\n        xBC_conv = rearrange(\n            causal_conv1d_fwd_function(rearrange(ensure_stride(xBC), \"b s d -> b d s\"),\n                conv1d_weight, conv1d_bias, seq_idx, None, None, activation in [\"silu\", \"swish\"]),\n            \"b d s -> b s d\"\n        )\n        x, B, C = torch.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1)\n        x = rearrange(x, \"b l (h p) -> b l h p\", h=nheads)\n        B = rearrange(B, \"b l (g n) -> b l g n\", 
g=ngroups)\n        C = rearrange(C, \"b l (g n) -> b l g n\", g=ngroups)\n        z = rearrange(z, \"b l (h p) -> b l h p\", h=nheads) if z is not None else None\n        if rmsnorm_weight is None:\n            out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)\n            out = rearrange(out, \"b s h p -> b s (h p)\")\n            rstd = None\n            if d_nonssm > 0:\n                out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)\n        else:\n            out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)\n            # reshape input data into 2D tensor\n            x_rms = rearrange(out_x, \"b s h p -> (b s) (h p)\")\n            z_rms = rearrange(z, \"b s h p -> (b s) (h p)\")\n            rmsnorm_weight = rmsnorm_weight.contiguous()\n            if d_nonssm == 0:\n                out = None\n            else:\n                out01 = torch.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype, device=x_rms.device)\n                out = rearrange(out01[..., d_nonssm:], \"b s d -> (b s) d\")\n                _swiglu_fwd(zx0, out=out01[..., :d_nonssm])\n            out, _, rstd = _layer_norm_fwd(x_rms, rmsnorm_weight, None, rmsnorm_eps, z_rms, out=out,\n                                           group_size=dim // ngroups,\n                                           norm_before_gate=norm_before_gate, is_rms_norm=True)\n            if d_nonssm == 0:\n                out = rearrange(out, \"(b s) d -> b s d\", b=batch)\n            else:\n                out = out01\n        ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None\n        if outproj_weight 
is not None:\n            if torch.is_autocast_enabled():\n                dtype = torch.get_autocast_gpu_dtype()\n                out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)\n                outproj_bias = outproj_bias.to(dtype) if outproj_bias is not None else None\n            out = F.linear(out, outproj_weight, outproj_bias)\n        else:\n            assert outproj_bias is None\n        ctx.save_for_backward(zxbcdt, conv1d_weight, conv1d_bias,\n                              out_x, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias)\n        ctx.dt_limit = dt_limit\n        ctx.return_final_states = return_final_states\n        ctx.activation = activation\n        ctx.rmsnorm_eps = rmsnorm_eps\n        ctx.norm_before_gate = norm_before_gate\n        ctx.chunk_size = chunk_size\n        ctx.headdim = headdim\n        ctx.ngroups = ngroups\n        return out if not return_final_states else (out, final_states)\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, dout, *args):\n        zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors\n        dfinal_states = args[0] if ctx.return_final_states else None\n        headdim = ctx.headdim\n        nheads = D.shape[0]\n        dim = nheads * headdim\n        assert nheads % ctx.ngroups == 0\n        dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2\n        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2\n        assert d_nonssm >= 0\n        recompute_output = outproj_weight is not None\n        if recompute_output:\n            out_recompute = torch.empty(*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype)\n            out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], dim=-1)\n        zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, 
nheads], dim=-1)\n        # Recompute x, B, C\n        xBC_conv = rearrange(\n            causal_conv1d_fwd_function(rearrange(ensure_stride(xBC), \"b s d -> b d s\"),\n                conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in [\"silu\", \"swish\"]),\n            \"b d s -> b s d\"\n        )\n        x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)\n        x = rearrange(x, \"b l (h p) -> b l h p\", h=nheads)\n        B = rearrange(B, \"b l (g n) -> b l g n\", g=ctx.ngroups)\n        C = rearrange(C, \"b l (g n) -> b l g n\", g=ctx.ngroups)\n        dzxbcdt = torch.empty_like(zxbcdt)\n        dzx0, dz, dxBC_given, ddt_given = torch.split(dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)\n        dxBC = torch.empty_like(xBC)\n        dx, dB, dC = torch.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)\n        z = rearrange(z, \"b l (h p) -> b l h p\", h=nheads)\n        dx = rearrange(dx, \"b l (h p) -> b l h p\", h=nheads)\n        dB = rearrange(dB, \"b l (g n) -> b l g n\", g=ctx.ngroups)\n        dC = rearrange(dC, \"b l (g n) -> b l g n\", g=ctx.ngroups)\n        if outproj_weight is not None:\n            dout_og = dout\n            dout = F.linear(dout, outproj_weight.t())\n        if d_nonssm > 0:\n            dout0, dout = dout.split([d_nonssm, dim], dim=-1)\n            _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)\n        dout = rearrange(dout, \"b s (h p) -> b s h p\", p=headdim)\n        if rmsnorm_weight is None:\n            dz = rearrange(dz, \"b l (h p) -> b l h p\", h=nheads)\n            dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = _mamba_chunk_scan_combined_bwd(\n                dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, 
ddt=ddt_given, dB=dB, dC=dC, dz=dz, recompute_output=recompute_output\n            )\n            out_for_linear = rearrange(rest[0], \"b s h p -> b s (h p)\") if recompute_output else None\n            drmsnorm_weight = None\n        else:\n            batch = dout.shape[0]\n            dy_rms = rearrange(dout, \"b s h p -> (b s) (h p)\")\n            dz = rearrange(dz, \"b l d -> (b l) d\")\n            x_rms = rearrange(out, \"b s h p -> (b s) (h p)\")\n            z_rms = rearrange(z, \"b s h p -> (b s) (h p)\")\n            out1_recompute = rearrange(out1_recompute, \"b s d -> (b s) d\") if recompute_output else None\n            dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, group_size=dim//ctx.ngroups, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None)\n            out_for_linear = out_recompute if recompute_output else None\n            dout = rearrange(dout, \"(b s) (h p) -> b s h p\", b=batch, p=headdim)\n            dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(\n                dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC\n            )\n\n        if outproj_weight is not None:\n            doutproj_weight = torch.einsum(\"bso,bsd->od\", dout_og, out_for_linear)\n            doutproj_bias = dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None\n        else:\n            doutproj_weight, doutproj_bias = None, None\n        dxBC_given_update, dweight, dbias, *_ = causal_conv1d_bwd_function(\n            rearrange(ensure_stride(xBC), \"b s d -> b d s\"), conv1d_weight, conv1d_bias,\n            # It might be okay to not run ensure_stride on dxBC, but 
we're not sure. So playing safe here.\n            rearrange(ensure_stride(dxBC), \"b s d -> b d s\"), seq_idx, None, None,\n            rearrange(ensure_stride(dxBC_given), \"b s d -> b d s\"), False, ctx.activation in [\"silu\", \"swish\"]\n        )\n        dxBC_given_update = rearrange(dxBC_given_update, \"b d s -> b s d\")\n        if dxBC_given.stride() != dxBC_given_update.stride():\n            dxBC_given.copy_(dxBC_given_update)\n        else:\n            dxBC_given = dxBC_given_update\n        return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None\n\n\ndef mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float(\"inf\")), return_final_states=False, activation=\"silu\", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):\n    \"\"\"\n    Argument:\n        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim\n        conv1d_weight: (dim + 2 * ngroups * dstate, width)\n        conv1d_bias: (dim + 2 * ngroups * dstate,)\n        dt_bias: (nheads,)\n        A: (nheads)\n        D: (nheads, headdim) or (nheads,)\n        initial_states: (batch, nheads, headdim, dstate)\n        seq_idx: (batch, seqlen), int32\n        rmsnorm_weight: (dim,)\n        outproj_weight: (out_dim, dim)\n        outproj_bias: (out_dim,)\n        headdim: if D is 1D, headdim must be passed in\n        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). 
If False, we do RMSNorm(x * F.silu(z))\n    Return:\n        out: (batch, seqlen, dim)\n    \"\"\"\n    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)\n\n\ndef mamba_split_conv1d_scan_ref(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, dt_limit=(0.0, float(\"inf\")), activation=\"silu\", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):\n    \"\"\"\n    Argument:\n        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim\n        conv1d_weight: (dim + 2 * ngroups * dstate, width)\n        conv1d_bias: (dim + 2 * ngroups * dstate,)\n        dt_bias: (nheads,)\n        A: (nheads)\n        D: (nheads, headdim) or (nheads,)\n        rmsnorm_weight: (dim,)\n        outproj_weight: (out_dim, dim)\n        outproj_bias: (out_dim,)\n        headdim: if D is 1D, headdim must be passed in\n        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). 
If False, we do RMSNorm(x * F.silu(z))\n    Return:\n        out: (batch, seqlen, dim)\n    \"\"\"\n    if D.dim() == 1:\n        assert headdim is not None\n        nheads, = D.shape\n    else:\n        nheads, headdim = D.shape\n    assert nheads % ngroups == 0\n    batch, seqlen, _ = zxbcdt.shape\n    dim = nheads * headdim\n    dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2\n    assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)\n    assert dt_bias.shape == (nheads,)\n    assert A.shape == (nheads,)\n    if rmsnorm_weight is not None:\n        assert rmsnorm_weight.shape == (dim,)\n    z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)\n    xBC = rearrange(causal_conv1d_fn(rearrange(xBC, \"b s d -> b d s\"), conv1d_weight, conv1d_bias, activation=activation),\n                    \"b d s -> b s d\")\n    x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)\n    x = rearrange(x, \"b l (h p) -> b l h p\", h=nheads)\n    B = rearrange(B, \"b l (g n) -> b l g n\", g=ngroups)\n    C = rearrange(C, \"b l (g n) -> b l g n\", g=ngroups)\n    z = rearrange(z, \"b l (h p) -> b l h p\", h=nheads)\n    out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(),\n                             z=z if rmsnorm_weight is None else None, dt_bias=dt_bias, dt_softplus=True, dt_limit=dt_limit)\n    out = rearrange(out, \"b s h p -> b s (h p)\")\n    if rmsnorm_weight is not None:\n        out = rmsnorm_fn(out, rmsnorm_weight, None, z=rearrange(z, \"b l h p -> b l (h p)\"), eps=rmsnorm_eps,\n                         norm_before_gate=norm_before_gate)\n    if outproj_weight is not None:\n        out = F.linear(out, outproj_weight, outproj_bias)\n    return out\n\n"
  },
  {
    "path": "mamba_ssm/ops/triton/ssd_state_passing.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\n\n\"\"\"We want triton==2.1.0 or 2.2.0 for this\n\"\"\"\n\nimport math\nimport torch\nimport torch.nn.functional as F\n\nimport triton\nimport triton.language as tl\n\nfrom einops import rearrange, repeat\n\nfrom mamba_ssm.utils.determinism import autotune_configs\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE': 64}),\n        triton.Config({'BLOCK_SIZE': 128}),\n        triton.Config({'BLOCK_SIZE': 256}),\n        triton.Config({'BLOCK_SIZE': 512}),\n        triton.Config({'BLOCK_SIZE': 1024}),\n        triton.Config({'BLOCK_SIZE': 2048}),\n    ]),\n    key=['dim'],\n)\n@triton.jit\ndef _state_passing_fwd_kernel(\n    # Pointers to matrices\n    states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,\n    # Matrix dimensions\n    dim, nchunks, seqlen, chunk_size,\n    # Strides\n    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,\n    stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n    stride_final_states_batch, stride_final_states_head, stride_final_states_dim,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n    stride_initstates_batch, stride_initstates_head, stride_initstates_dim,\n    stride_seq_idx_batch, stride_seq_idx_seqlen,\n    # Meta-parameters\n    HAS_INITSTATES: tl.constexpr,\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid_b = tl.program_id(axis=1)\n    pid_h = tl.program_id(axis=2)\n    pid_m = tl.program_id(axis=0)\n    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head\n    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head\n    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head\n    final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head\n    if HAS_INITSTATES:\n        initstates_ptr += pid_b * stride_initstates_batch + pid_h * 
stride_initstates_head\n    if HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    states_ptrs = states_ptr + offs_m * stride_states_dim\n    out_ptrs = out_ptr + offs_m * stride_out_dim\n    final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim\n\n    if not HAS_INITSTATES:\n        states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n    else:\n        initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim\n        states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n    tl.store(out_ptrs, states, mask=offs_m < dim)\n    out_ptrs += stride_out_chunk\n    seq_idx = 0\n    for c in range(nchunks):\n        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n        scale = tl.exp(dA_cs)\n        if HAS_SEQ_IDX:\n            seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)\n            scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n            seq_idx = seq_idx_new\n        states = scale * states + new_states\n        if c < nchunks - 1:\n            tl.store(out_ptrs, states, mask=offs_m < dim)\n        else:\n            tl.store(final_states_ptrs, states, mask=offs_m < dim)\n        states_ptrs += stride_states_chunk\n        dA_cs_ptr += stride_dA_cs_chunk\n        out_ptrs += stride_out_chunk\n\n\n@triton.autotune(\n    configs=autotune_configs([\n        triton.Config({'BLOCK_SIZE': 64}),\n        triton.Config({'BLOCK_SIZE': 128}),\n        triton.Config({'BLOCK_SIZE': 256}),\n        triton.Config({'BLOCK_SIZE': 512}),\n        triton.Config({'BLOCK_SIZE': 1024}),\n        triton.Config({'BLOCK_SIZE': 2048}),\n    ]),\n    key=['dim'],\n)\n@triton.jit\ndef _state_passing_bwd_kernel(\n    # Pointers to matrices\n    dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,\n    
dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,\n    # Matrix dimensions\n    dim, nchunks, seqlen, chunk_size,\n    # Strides\n    stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,\n    stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,\n    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,\n    stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,\n    stride_seq_idx_batch, stride_seq_idx_seqlen,\n    stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,\n    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,\n    stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,\n    # Meta-parameters\n    CONVERT_STATES: tl.constexpr,\n    HAS_DFINAL_STATES: tl.constexpr,\n    HAS_DINITSTATES: tl.constexpr,\n    HAS_SEQ_IDX: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    pid_b = tl.program_id(axis=1)\n    pid_h = tl.program_id(axis=2)\n    pid_m = tl.program_id(axis=0)\n    dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk\n    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk\n    ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m\n    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n    dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk\n    if CONVERT_STATES:\n        states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk\n    if HAS_DFINAL_STATES:\n        dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head\n    if HAS_DINITSTATES:\n        dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head\n    if 
HAS_SEQ_IDX:\n        seq_idx_ptr += pid_b * stride_seq_idx_batch\n\n    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim\n    out_ptrs = out_ptr + offs_m * stride_out_dim\n    dout_ptrs = dout_ptr + offs_m * stride_dout_dim\n    if CONVERT_STATES:\n        states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim\n\n    if HAS_DFINAL_STATES:\n        dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)\n    else:\n        dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)\n    tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n    if HAS_SEQ_IDX:\n        seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)\n    dstates_ptrs -= stride_dstates_chunk\n    for c in range(nchunks - 1):\n        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n        scale = tl.exp(dA_cs)\n        if HAS_SEQ_IDX:\n            seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))\n            scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)\n            seq_idx = seq_idx_new\n        out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n        if CONVERT_STATES:\n            tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n        ddA = tl.sum(out * dstates) * scale\n        tl.store(ddA_cs_ptr, ddA)\n        dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n        dstates = scale * dstates + dout\n        tl.store(dstates_ptrs, dstates, mask=offs_m < dim)\n        dout_ptrs -= stride_dout_chunk\n        dstates_ptrs -= stride_dstates_chunk\n        dA_cs_ptr -= stride_dA_cs_chunk\n        ddA_cs_ptr -= stride_ddA_cs_chunk\n        out_ptrs -= stride_out_chunk\n        if CONVERT_STATES:\n            states_converted_ptrs -= stride_out_chunk\n    if CONVERT_STATES:\n        out = tl.load(out_ptrs, mask=offs_m < dim, 
other=0.0).to(tl.float32)\n        tl.store(states_converted_ptrs, out, mask=offs_m < dim)\n    if not HAS_DINITSTATES:\n        tl.store(ddA_cs_ptr, 0.0)\n    else:\n        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)\n        scale = tl.exp(dA_cs)\n        if HAS_SEQ_IDX:\n            scale = tl.where(seq_idx == 0, scale, 0.0)\n        out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n        ddA = tl.sum(out * dstates) * scale\n        tl.store(ddA_cs_ptr, ddA)\n        dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)\n        dstates = scale * dstates + dout\n        tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)\n\n\ndef _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,\n                       out_dtype=None):\n    batch, nchunks, nheads, dim = states.shape\n    assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n    if initial_states is not None:\n        assert initial_states.shape == (batch, nheads, dim)\n    if seq_idx is not None:\n        assert chunk_size is not None\n        seqlen = seq_idx.shape[-1]\n        assert seq_idx.shape == (batch, seqlen)\n    out_dtype = states.dtype if out_dtype is None else out_dtype\n    out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)\n    final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)\n    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n    with torch.cuda.device(states.device.index):\n        _state_passing_fwd_kernel[grid](\n            states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,\n            dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,\n            states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n            out.stride(0), out.stride(1), out.stride(2), out.stride(3),\n            
final_states.stride(0), final_states.stride(1), final_states.stride(2),\n            dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n            *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))\n              if initial_states is not None else (0, 0, 0)),\n            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n            HAS_INITSTATES=initial_states is not None,\n            HAS_SEQ_IDX=seq_idx is not None,\n        )\n    return out, final_states\n\n\ndef _state_passing_bwd(\n        states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,\n        dstates_dtype=None, states_dtype=None, chunk_size=None\n):\n    \"\"\"\n    states contains the initial_states at index 0. The final states are not included in states.\n    \"\"\"\n    batch, nchunks, nheads, dim = states.shape\n    assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n    assert dout.shape == (batch, nchunks, nheads, dim)\n    if seq_idx is not None:\n        assert chunk_size is not None\n        seqlen = seq_idx.shape[-1]\n        assert seq_idx.shape == (batch, seqlen)\n    dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n    if states_dtype is not None and states_dtype != states.dtype:\n        states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)\n        assert states_converted.stride() == states.stride()\n    else:\n        states_converted = None\n    if has_initial_states:\n        dinitstates = torch.empty_like(dstates[:, 0])\n    else:\n        dinitstates = None\n    if dfinal_states is not None:\n        assert dfinal_states.shape == (batch, nheads, dim)\n    BLOCK_SIZE_min = 64\n    n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min\n    ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,\n                                    
dtype=torch.float32, device=dA_chunk_cumsum.device)\n    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)\n    with torch.cuda.device(dout.device.index):\n        _state_passing_bwd_kernel[grid](\n            dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,\n            dstates, ddA_chunk_cumsum, dinitstates, states_converted,\n            dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,\n            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),\n            states.stride(0), states.stride(1), states.stride(2), states.stride(3),\n            dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),\n            *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))\n                if dfinal_states is not None else (0, 0, 0)),\n            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n            dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),\n            ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),\n            *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))\n              if dinitstates is not None else (0, 0, 0)),\n            CONVERT_STATES=states_converted is not None,\n            HAS_DFINAL_STATES=dfinal_states is not None,\n            HAS_DINITSTATES=dinitstates is not None,\n            HAS_SEQ_IDX=seq_idx is not None,\n        )\n    BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs[\"BLOCK_SIZE\"]\n    n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual\n    ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)\n    if states_dtype is not None and states_dtype == states.dtype:\n        states_converted = states\n    return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, 
ddA_chunk_cumsum, dinitstates, states_converted)\n\n\nclass StatePassingFn(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, states, dA_chunk_cumsum, initial_states=None):\n        batch, nchunks, nheads, dim = states.shape\n        assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n        if states.stride(-1) != 1:\n            states = states.contiguous()\n        out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states)\n        ctx.save_for_backward(out, dA_chunk_cumsum)\n        ctx.has_initial_states = initial_states is not None\n        return out, final_states\n\n    @staticmethod\n    def backward(ctx, dout, dfinal_states):\n        out, dA_chunk_cumsum = ctx.saved_tensors\n        batch, nchunks, nheads, dim = out.shape\n        assert dout.shape == (batch, nchunks, nheads, dim)\n        assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)\n        assert dfinal_states.shape == (batch, nheads, dim)\n        if dout.stride(-1) != 1:\n            dout = dout.contiguous()\n        dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd(\n            out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states , has_initial_states=ctx.has_initial_states\n        )\n        return dstates, ddA_chunk_cumsum, dinitstates\n\n\ndef state_passing(states, dA_chunk_cumsum, initial_states=None):\n    \"\"\"\n    Argument:\n        states: (batch, nchunks, nheads, dim)\n        dA_chunk_cumsum: (batch, nheads, nchunks)\n        initial_states: (batch, nheads, dim)\n    Return:\n        out: (batch, nchunks, nheads, dim)\n        final_states: (batch, nheads, dim)\n    \"\"\"\n    return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states)\n\n\ndef state_passing_ref(states, dA_chunk_cumsum, initial_states=None):\n    \"\"\"\n    Argument:\n        states: (batch, nchunks, nheads, dim)\n        dA_chunk_cumsum: (batch, nheads, nchunks)\n        initial_states: (batch, nheads, dim)\n    Return:\n        out: 
(batch, nchunks, nheads, dim)\n        final_states: (batch, nheads, dim)\n    \"\"\"\n    if initial_states is None:\n        initial_states = torch.zeros_like(states[:, 0])\n    states = torch.cat([rearrange(initial_states, \"b h d -> b 1 h d\"), states], dim=1)\n    dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0))\n    dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1)\n    nchunks = dA_chunk_cumsum.shape[-1]\n    # (batch, nheads, nchunks, nchunks)\n    dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :]\n    # (batch, nheads, nchunks, nchunks)\n    decay_chunk = torch.exp(dt_chunk_segment_sum)\n    causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0)\n    decay_chunk = decay_chunk.masked_fill(~causal_mask, 0)\n    out = torch.einsum(\"bhzc,bchd->bzhd\", decay_chunk.to(dtype=states.dtype), states)\n    return out[:, :-1], out[:, -1]\n"
  },
  {
    "path": "mamba_ssm/utils/__init__.py",
    "content": ""
  },
  {
    "path": "mamba_ssm/utils/determinism.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\n\nimport os\nimport warnings\nfrom packaging import version\n\nimport torch\n\ntry:\n    import triton\n    TRITON_VERSION = version.parse(triton.__version__)\nexcept ImportError:\n    TRITON_VERSION = version.parse(\"0.0.0\")\n\nTRITON_HAS_CACHE_RESULTS = TRITON_VERSION >= version.parse(\"3.4.0\")\n_autotune_warning_issued = False\n\n_deterministic_override = None\n\n\ndef use_deterministic_mode():\n    if _deterministic_override is not None:\n        return _deterministic_override\n    env = os.environ.get('MAMBA_DETERMINISTIC')\n    if env:\n        return env[0] == '1'\n    return torch.are_deterministic_algorithms_enabled()\n\n\ndef set_deterministic_mode(value):\n    global _deterministic_override\n    _deterministic_override = value\n\n\ndef _estimate_config_cost(cfg):\n    \"\"\"Estimate shared memory cost of a config. Lower is cheaper.\"\"\"\n    block_product = 1\n    for key, val in cfg.kwargs.items():\n        if key.startswith('BLOCK_SIZE_'):\n            block_product *= val\n    return block_product * (getattr(cfg, 'num_stages', 1) or 1)\n\n\ndef _filter_configs_by_block_sizes(configs):\n    \"\"\"Filter configs by TRITON_AUTOTUNE_BLOCK_SIZE_* env vars.\"\"\"\n    env_filters = {}\n    for suffix in ('M', 'N', 'K', 'DSTATE'):\n        env_val = os.environ.get(f\"TRITON_AUTOTUNE_BLOCK_SIZE_{suffix}\")\n        if env_val is not None:\n            env_filters[f'BLOCK_SIZE_{suffix}'] = int(env_val)\n    if not env_filters:\n        return None\n    matching = configs\n    for key, target in env_filters.items():\n        matching = [c for c in matching if c.kwargs.get(key) == target]\n    return matching[:1] if matching else None\n\n\ndef autotune_configs(configs):\n    \"\"\"Select autotune configs for deterministic mode.\n    \n    Uses cached autotuning (TRITON_CACHE_AUTOTUNING=1) if Triton >= 3.4.0,\n    otherwise auto-selects the cheapest config by block size * stages.\n    \"\"\"\n    if 
not configs or not use_deterministic_mode():\n        return configs\n    if TRITON_HAS_CACHE_RESULTS and os.environ.get(\"TRITON_CACHE_AUTOTUNING\") == \"1\":\n        return configs\n    global _autotune_warning_issued\n    if not _autotune_warning_issued:\n        _autotune_warning_issued = True\n        msg = \"Deterministic mode: set TRITON_CACHE_AUTOTUNING=1 for cached autotuning.\" if TRITON_HAS_CACHE_RESULTS else \"Deterministic mode: upgrade to Triton >= 3.4.0 for cached autotuning.\"\n        warnings.warn(msg)\n    filtered = _filter_configs_by_block_sizes(configs)\n    if filtered:\n        return filtered\n    return [min(configs, key=_estimate_config_cost)]\n\n\ndef alloc_tile_workspace(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True):\n    \"\"\"Allocate buffer for deterministic per-program reductions.\"\"\"\n    if base_shape is None:\n        return None, 0\n    if deterministic:\n        factory = torch.zeros if zero_init else torch.empty\n        tensor = factory(*base_shape, tile_dim, device=device, dtype=dtype)\n        return tensor, tensor.stride(-1)\n    return torch.empty(*base_shape, device=device, dtype=dtype), 0\n\n\ndef finalize_tile_workspace(tensor, deterministic):\n    if tensor is None:\n        return None\n    if deterministic:\n        tensor = tensor.sum(dim=-1)\n    return tensor\n"
  },
  {
    "path": "mamba_ssm/utils/generation.py",
    "content": "# Copyright (c) 2023, Albert Gu, Tri Dao.\nimport gc\nimport time\nfrom collections import namedtuple\nfrom dataclasses import dataclass, field\nfrom functools import partial\nfrom typing import Callable, Optional, Sequence, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom torch import Tensor\nfrom torch.profiler import ProfilerActivity, profile, record_function\nfrom transformers.generation import GenerateDecoderOnlyOutput, TextStreamer\n\n\n@dataclass\nclass InferenceParams:\n    \"\"\"Inference parameters that are passed to the main model in order\n    to efficienly calculate and store the context during inference.\"\"\"\n\n    max_seqlen: int\n    max_batch_size: int\n    seqlen_offset: int = 0\n    batch_size_offset: int = 0\n    key_value_memory_dict: dict = field(default_factory=dict)\n    lengths_per_sample: Optional[Tensor] = None\n\n    def reset(self, max_seqlen, max_batch_size):\n        self.max_seqlen = max_seqlen\n        self.max_batch_size = max_batch_size\n        self.seqlen_offset = 0\n        if self.lengths_per_sample is not None:\n            self.lengths_per_sample.zero_()\n\n\ndef modify_logits_for_min_p_filtering(logits, min_p):\n    \"\"\"Set the logits for none min_p values to -inf. Done in-place.\"\"\"\n    if min_p <= 0.0 or min_p >= 1.0:\n        return\n    indices_to_remove = logits < min_p\n    logits.masked_fill_(indices_to_remove, float(\"-Inf\"))\n# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py\n# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231\ndef modify_logits_for_top_k_filtering(logits, top_k):\n    \"\"\"Set the logits for none top-k values to -inf. 
Done in-place.\"\"\"\n    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]\n    logits.masked_fill_(indices_to_remove, float(\"-Inf\"))\n\n\n# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py\n# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170\ndef modify_logits_for_top_p_filtering(logits, top_p):\n    \"\"\"Set the logits for none top-p values to -inf. Done in-place.\"\"\"\n    if top_p <= 0.0 or top_p >= 1.0:\n        return\n    # First sort and calculate cumulative sum of probabilities.\n    sorted_logits, sorted_indices = torch.sort(logits, descending=False)\n    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)\n    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)\n    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)\n    # scatter sorted tensors to original indexing\n    indices_to_remove = sorted_indices_to_remove.scatter(\n        1, sorted_indices, sorted_indices_to_remove\n    )\n    logits.masked_fill_(indices_to_remove, float(\"-inf\"))\n\n\ndef modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):\n    \"\"\"Apply repetition penalty. 
See https://arxiv.org/abs/1909.05858\n    logits: (batch_size, vocab_size)\n    prev_output_tokens: (batch_size, seq_len)\n    \"\"\"\n    if repetition_penalty == 1.0:\n        return logits\n    score = torch.gather(logits, 1, prev_output_tokens)\n    # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability\n    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)\n    logits.scatter_(1, prev_output_tokens, score)\n    return logits\n\n\ndef sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):\n    \"\"\"Sample from top-k logits.\n    Arguments:\n        logits: Tensor of shape (batch_size, vocab_size)\n    \"\"\"\n    if top_k == 1:  # Short-circuit for greedy decoding\n        return logits.argmax(dim=-1)\n    else:\n        if top_p > 0.0:\n            assert top_p <= 1.0, \"top-p should be in (0, 1].\"\n        if top_k > 0:\n            top_k = min(top_k, logits.size(-1))  # Safety check\n            logits_top, indices = torch.topk(logits, top_k, dim=-1)\n            if temperature != 1.0:\n                logits_top /= temperature\n            modify_logits_for_top_p_filtering(logits_top, top_p)\n            return indices[\n                torch.arange(indices.shape[0], device=indices.device),\n                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),\n            ]\n        else:\n            if min_p > 0.0:\n                logits_top = logits.clone()\n                max_prob = logits_top[..., 0].item()\n                min_prob = max_prob * min_p\n                modify_logits_for_min_p_filtering(logits_top, min_prob)\n                if temperature != 1.0:\n                    logits_top /= temperature\n                return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)\n            # Clone so that when we modify for top_p we don't change the original logits\n            logits_top 
= logits / temperature if temperature != 1.0 else logits.clone()\n            modify_logits_for_top_p_filtering(logits_top, top_p)\n            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(\n                dim=-1\n            )\n\n\n@torch.inference_mode()\ndef decode(\n    input_ids,\n    model,\n    max_length,\n    top_k=1,\n    top_p=0.0,\n    min_p=0.0,\n    temperature=1.0,\n    repetition_penalty=1.0,\n    eos_token_id=None,\n    teacher_outputs=None,\n    vocab_size=None,\n    cg=False,\n    enable_timing=False,\n    output_scores=False,\n    streamer: Optional[TextStreamer] = None\n):\n    \"\"\"Decoding, either greedy or with top-k or top-p sampling.\n    If top-k = 0, don't limit the number of candidates (pure sampling).\n    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,\n    then top-p.\n    We assume that all sequences in the same batch have the same length.\n\n    Arguments:\n        input_ids: (batch, seq_len)\n        max_length: int\n        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the\n            logits, the next token is taken from the teacher_outputs. 
Useful for testing.\n    Returns: GenerateDecoderOnlyOutput, with the following fields:\n        sequences: (batch, max_length)\n        scores: tuples of (batch, vocab_size)\n    \"\"\"\n    if streamer is not None:\n        streamer.put(input_ids.cpu())\n\n    batch_size, seqlen_og = input_ids.shape\n    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0\n    if cg:\n        if not hasattr(model, \"_decoding_cache\"):\n            model._decoding_cache = None\n        model._decoding_cache = update_graph_cache(\n            model,\n            model._decoding_cache,\n            batch_size,\n            seqlen_og,\n            max_length,\n        )\n        inference_params = model._decoding_cache.inference_params\n        inference_params.reset(max_length, batch_size)\n    else:\n        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)\n\n    def get_logits(input_ids, inference_params):\n        decoding = inference_params.seqlen_offset > 0\n        if decoding:\n            position_ids = torch.full(\n                (batch_size, 1),\n                inference_params.seqlen_offset,\n                dtype=torch.long,\n                device=input_ids.device,\n            )\n        else:\n            position_ids = None\n        if not cg or not decoding:\n            logits = model(\n                input_ids,\n                position_ids=position_ids,\n                inference_params=inference_params,\n                num_last_tokens=1,\n            ).logits.squeeze(dim=1)\n        else:\n            logits = model._decoding_cache.run(\n                input_ids, position_ids, inference_params.seqlen_offset\n            ).squeeze(dim=1)\n        return logits[..., :vocab_size] if vocab_size is not None else logits\n\n    def sample_tokens(logits, inference_params):\n        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:\n            token = 
sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)\n        else:\n            token = teacher_outputs[:, inference_params.seqlen_offset]\n        # return rearrange(token, \"b -> b 1\")\n        return token.unsqueeze(1)\n\n    def should_stop(current_token, inference_params):\n        if inference_params.seqlen_offset == 0:\n            return False\n        if eos_token_id is not None and (current_token == eos_token_id).all():\n            return True\n        if inference_params.seqlen_offset >= max_length - 1:\n            return True\n        return False\n\n    start = torch.cuda.Event(enable_timing=enable_timing)\n    end = torch.cuda.Event(enable_timing=enable_timing)\n\n    if enable_timing:\n        start.record()\n    scores, sequences = [], [input_ids]\n    sequences_cat = input_ids\n    while not should_stop(sequences[-1], inference_params):\n        logits = get_logits(sequences[-1], inference_params)\n        if output_scores:\n            scores.append(logits.clone())\n        inference_params.seqlen_offset += sequences[-1].shape[1]\n        if repetition_penalty == 1.0:\n            sampled_tokens = sample_tokens(logits, inference_params)\n        else:\n            logits = modify_logit_for_repetition_penalty(\n                logits, sequences_cat, repetition_penalty\n            )\n            sampled_tokens = sample_tokens(logits, inference_params)\n            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)\n        sequences.append(sampled_tokens)\n        if streamer is not None:\n            streamer.put(sampled_tokens.cpu())\n    if streamer is not None:\n        streamer.end()\n    if enable_timing:\n        end.record()\n        torch.cuda.synchronize()\n        print(f\"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms\")\n    return GenerateDecoderOnlyOutput(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))\n\n\nclass GenerationMixin:\n    def 
allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n        raise NotImplementedError\n\n    def generate(\n        self,\n        input_ids,\n        max_length,\n        top_k=1,\n        top_p=0.0,\n        min_p=0.0,\n        temperature=1.0,\n        return_dict_in_generate=False,\n        output_scores=False,\n        **kwargs,\n    ):\n        output = decode(\n            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, output_scores=output_scores, **kwargs\n        )\n        if not output_scores:\n            output.scores = None\n        return output if return_dict_in_generate else output.sequences\n\n\n@dataclass\nclass DecodingCGCache:\n    max_batch_size: int = 0\n    max_seqlen: int = 0\n    device = None\n    dtype = None\n    callables: dict = field(default_factory=dict)\n    mempool = None\n    inference_params: Optional[InferenceParams] = None\n    run: Optional[Callable] = None\n\n\n@torch.inference_mode()\ndef update_graph_cache(\n    model,\n    cache,\n    batch_size,\n    seqlen_og,\n    max_seqlen,\n    decoding_seqlens=(1,),\n    dtype=None,\n    n_warmups=2,\n):\n    if cache is None:\n        cache = DecodingCGCache()\n    param_example = next(iter(model.parameters()))\n    device = param_example.device\n    if dtype is None:\n        dtype = param_example.dtype\n    if (\n        (device, dtype) != (cache.device, cache.dtype)\n        or batch_size > cache.max_batch_size\n        or max_seqlen > cache.max_seqlen\n    ):  # Invalidate the cache\n        cache.callables = {}\n        cache.mempool = None\n        cache.inference_params = None\n        gc.collect()\n        cache.device, cache.dtype = device, dtype\n        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen\n        assert hasattr(model, \"allocate_inference_cache\"), \"CUDA graph decoding requires that the model has a method allocate_inference_cache\"\n        inf_cache = 
model.allocate_inference_cache(batch_size, max_seqlen, dtype)\n        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)\n        cache.inference_params = InferenceParams(\n            max_seqlen=max_seqlen,\n            max_batch_size=batch_size,\n            seqlen_offset=seqlen_og,\n            key_value_memory_dict=inf_cache,\n            lengths_per_sample=lengths_per_sample,\n        )\n        cache.mempool = torch.cuda.graphs.graph_pool_handle()\n    for decoding_seqlen in decoding_seqlens:\n        if (batch_size, decoding_seqlen) not in cache.callables:\n            cache.callables[batch_size, decoding_seqlen] = capture_graph(\n                model,\n                cache.inference_params,\n                batch_size,\n                max_seqlen,\n                decoding_seqlen=decoding_seqlen,\n                mempool=cache.mempool,\n                n_warmups=n_warmups,\n            )\n\n    def dispatch(input_ids, position_ids, seqlen):\n        batch_size, decoding_seqlen = input_ids.shape[:2]\n        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)\n\n    cache.run = dispatch\n    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing\n    return cache\n\n\ndef capture_graph(\n    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2\n):\n    device = next(iter(model.parameters())).device\n    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)\n    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)\n    seqlen_offset_og = inference_params.seqlen_offset\n    inference_params.seqlen_offset = max_seqlen - decoding_seqlen\n    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset\n\n    # Warmup before capture\n    s = torch.cuda.Stream()\n    s.wait_stream(torch.cuda.current_stream())\n    with torch.cuda.stream(s):\n     
   for _ in range(n_warmups):\n            logits = model(\n                input_ids,\n                position_ids=position_ids,\n                inference_params=inference_params,\n                num_last_tokens=decoding_seqlen,\n            ).logits\n        s.synchronize()\n        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,\n        # which requires that graph launch and non-captured launch to not overlap (I think,\n        # that's how I interpret the documentation). I'm not sure if this is required.\n        if torch.distributed.is_initialized():\n            torch.distributed.barrier()\n    torch.cuda.current_stream().wait_stream(s)\n    # Captures the graph\n    # To allow capture, automatically sets a side stream as the current stream in the context\n    graph = torch.cuda.CUDAGraph()\n    with torch.cuda.graph(graph, pool=mempool):\n        logits = model(\n            input_ids,\n            position_ids=position_ids,\n            inference_params=inference_params,\n            num_last_tokens=decoding_seqlen,\n        ).logits\n\n    def run(new_input_ids, new_position_ids, seqlen):\n        inference_params.lengths_per_sample[:] = seqlen\n        input_ids.copy_(new_input_ids)\n        position_ids.copy_(new_position_ids)\n        graph.replay()\n        return logits.clone()\n\n    inference_params.seqlen_offset = seqlen_offset_og\n    return run\n"
  },
  {
    "path": "mamba_ssm/utils/hf.py",
    "content": "import json\n\nimport torch\n\nfrom transformers.utils import WEIGHTS_NAME, CONFIG_NAME\nfrom transformers.utils.hub import cached_file\n\n\ndef load_config_hf(model_name):\n    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)\n    return json.load(open(resolved_archive_file))\n\n\ndef load_state_dict_hf(model_name, device=None, dtype=None):\n    # If not fp32, then we don't want to load directly to the GPU\n    mapped_device = \"cpu\" if dtype not in [torch.float32, None] else device\n    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)\n    return torch.load(resolved_archive_file, map_location=mapped_device)\n    # Convert dtype before moving to GPU to save memory\n    if dtype is not None:\n        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}\n    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}\n    return state_dict\n"
  },
  {
    "path": "mamba_ssm/utils/torch.py",
    "content": "import torch\nfrom functools import partial\nfrom typing import Callable\n\ndef custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):\n    def decorator(*args, **kwargs):\n        if cuda_amp_deprecated:\n            kwargs[\"device_type\"] = \"cuda\"\n        return dec(*args, **kwargs)\n    return decorator\n\n\nif hasattr(torch.amp, \"custom_fwd\"): # type: ignore[attr-defined]\n    deprecated = True\n    from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]\nelse:\n    deprecated = False\n    from torch.cuda.amp import custom_fwd, custom_bwd\n\ncustom_fwd = custom_amp_decorator(custom_fwd, deprecated)\ncustom_bwd = custom_amp_decorator(custom_bwd, deprecated)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"mamba_ssm\"\ndescription = \"Mamba state-space model\"\nreadme = \"README.md\"\nauthors = [\n    { name = \"Tri Dao\", email = \"tri@tridao.me\" },\n    { name = \"Albert Gu\", email = \"agu@cs.cmu.edu\" }\n]\nrequires-python = \">= 3.9\"\ndynamic = [\"version\"]\nlicense = { file = \"LICENSE\" }  # Include a LICENSE file in your repo\nkeywords = [\"cuda\", \"pytorch\", \"state-space model\"]\nclassifiers = [\n    \"Programming Language :: Python :: 3\",\n    \"License :: OSI Approved :: Apache Software License\",\n    \"Operating System :: Unix\"\n]\ndependencies = [\n    \"torch\",\n    \"triton\",\n    \"ninja\",\n    \"einops\",\n    \"transformers\",\n    \"packaging\",\n    \"setuptools>=61.0.0\",\n]\n[project.urls]\nRepository = \"https://github.com/state-spaces/mamba\"\n\n[project.optional-dependencies]\ncausal-conv1d = [\n    \"causal-conv1d>=1.2.0\"\n]\ndev = [\n    \"pytest\"\n]\n\n\n[build-system]\n# torch is intentionally excluded: pip's build isolation would install\n# torch-cpu from PyPI, ignoring the user's CUDA-enabled torch.\n# Users building from source should install torch first, then:\n#   pip install mamba-ssm --no-build-isolation\nrequires = [\n    \"setuptools>=61.0.0\",\n    \"wheel\",\n    \"packaging\",\n    \"ninja\",\n]\nbuild-backend = \"setuptools.build_meta\"\n"
  },
  {
    "path": "rocm_patch/rocm6_0.patch",
    "content": "--- /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h\t2023-12-12 20:11:48.000000000 +0000\n+++ rocm_update_files/amd_hip_bf16.h\t2024-05-20 17:40:26.983349079 +0000\n@@ -137,7 +137,7 @@\n  * \\ingroup HIP_INTRINSIC_BFLOAT16_CONV\n  * \\brief Converts float to bfloat16\n  */\n-__HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) {\n+__HOST_DEVICE__ static inline __hip_bfloat16 __float2bfloat16(float f) {\n   __hip_bfloat16 ret;\n   union {\n     float fp32;\n@@ -181,7 +181,7 @@\n  * \\ingroup HIP_INTRINSIC_BFLOAT162_CONV\n  * \\brief Converts and moves bfloat162 to float2\n  */\n-__HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) {\n+__HOST_DEVICE__ static inline float2 __bfloat1622float2(const __hip_bfloat162 a) {\n   return float2{__bfloat162float(a.x), __bfloat162float(a.y)};\n }\n \n@@ -209,7 +209,7 @@\n  * \\ingroup HIP_INTRINSIC_BFLOAT162_CONV\n  * \\brief Convert double to __hip_bfloat16\n  */\n-__HOST_DEVICE__ __hip_bfloat16 __double2bfloat16(const double a) {\n+__HOST_DEVICE__ static inline __hip_bfloat16 __double2bfloat16(const double a) {\n   return __float2bfloat16((float)a);\n }\n \n@@ -217,7 +217,7 @@\n  * \\ingroup HIP_INTRINSIC_BFLOAT162_CONV\n  * \\brief Convert float2 to __hip_bfloat162\n  */\n-__HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) {\n+__HOST_DEVICE__ static inline __hip_bfloat162 __float22bfloat162_rn(const float2 a) {\n   return __hip_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)};\n }\n \n@@ -247,7 +247,7 @@\n  * \\ingroup HIP_INTRINSIC_BFLOAT162_CONV\n  * \\brief Converts high 16 bits of __hip_bfloat162 to float and returns the result\n  */\n-__HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }\n+__HOST_DEVICE__ static inline float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); }\n \n /**\n  * \\ingroup HIP_INTRINSIC_BFLOAT162_CONV\n@@ -275,7 +275,7 @@\n  * \\ingroup HIP_INTRINSIC_BFLOAT162_CONV\n  * 
\\brief Converts low 16 bits of __hip_bfloat162 to float and returns the result\n  */\n-__HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }\n+__HOST_DEVICE__ static inline float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); }\n \n /**\n  * \\ingroup HIP_INTRINSIC_BFLOAT162_CONV\n"
  },
  {
    "path": "setup.py",
    "content": "# Copyright (c) 2023, Albert Gu, Tri Dao.\nimport sys\nimport warnings\nimport os\nimport re\nimport ast\nfrom pathlib import Path\nfrom packaging.version import parse, Version\nimport platform\nimport shutil\n\nfrom setuptools import setup, find_packages\nimport subprocess\n\nimport urllib.request\nimport urllib.error\nfrom wheel.bdist_wheel import bdist_wheel as _bdist_wheel\n\nimport torch\nfrom torch.utils.cpp_extension import (\n    BuildExtension,\n    CUDAExtension,\n    CUDA_HOME,\n    HIP_HOME\n)\n\n\nwith open(\"README.md\", \"r\", encoding=\"utf-8\") as fh:\n    long_description = fh.read()\n\n\n# ninja build does not work unless include_dirs are abs path\nthis_dir = os.path.dirname(os.path.abspath(__file__))\n\nPACKAGE_NAME = \"mamba_ssm\"\n\nBASE_WHEEL_URL = \"https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}\"\n\n# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels\n# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation\nFORCE_BUILD = os.getenv(\"MAMBA_FORCE_BUILD\", \"FALSE\") == \"TRUE\"\nSKIP_CUDA_BUILD = os.getenv(\"MAMBA_SKIP_CUDA_BUILD\", \"FALSE\") == \"TRUE\"\n# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI\nFORCE_CXX11_ABI = os.getenv(\"MAMBA_FORCE_CXX11_ABI\", \"FALSE\") == \"TRUE\"\n\n\ndef get_platform():\n    \"\"\"\n    Returns the platform name as used in wheel filenames.\n    \"\"\"\n    if sys.platform.startswith(\"linux\"):\n        return f\"linux_{platform.machine()}\"\n    elif sys.platform == \"darwin\":\n        mac_version = \".\".join(platform.mac_ver()[0].split(\".\")[:2])\n        return f\"macosx_{mac_version}_x86_64\"\n    elif sys.platform == \"win32\":\n        return \"win_amd64\"\n    else:\n        raise ValueError(\"Unsupported platform: {}\".format(sys.platform))\n\n\ndef get_cuda_bare_metal_version(cuda_dir):\n   
 raw_output = subprocess.check_output(\n        [cuda_dir + \"/bin/nvcc\", \"-V\"], universal_newlines=True\n    )\n    output = raw_output.split()\n    release_idx = output.index(\"release\") + 1\n    bare_metal_ver = parse(output[release_idx].split(\",\")[0])\n\n    return raw_output, bare_metal_ver\n\n\ndef get_hip_version(rocm_dir):\n\n    hipcc_bin = \"hipcc\" if rocm_dir is None else os.path.join(rocm_dir, \"bin\", \"hipcc\")\n    try:\n        raw_output = subprocess.check_output(\n            [hipcc_bin, \"--version\"], universal_newlines=True\n        )\n    except Exception as e:\n        print(\n            f\"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}\"\n        )\n        return None, None\n\n    for line in raw_output.split(\"\\n\"):\n        if \"HIP version\" in line:\n            rocm_version = parse(line.split()[-1].rstrip('-').replace('-', '+')) # local version is not parsed correctly\n            return line, rocm_version\n\n    return None, None\n\n\ndef get_torch_hip_version():\n\n    if torch.version.hip:\n        return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))\n    else:\n        return None\n\n\ndef check_if_hip_home_none(global_option: str) -> None:\n\n    if HIP_HOME is not None:\n        return\n    # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary\n    # in that case.\n    warnings.warn(\n        f\"{global_option} was requested, but hipcc was not found.  Are you sure your environment has hipcc available?\"\n    )\n\n\ndef check_if_cuda_home_none(global_option: str) -> None:\n    if CUDA_HOME is not None:\n        return\n    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary\n    # in that case.\n    warnings.warn(\n        f\"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  
\"\n        \"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, \"\n        \"only images whose names contain 'devel' will provide nvcc.\"\n    )\n\n\ndef append_nvcc_threads(nvcc_extra_args):\n    return nvcc_extra_args + [\"--threads\", \"4\"]\n\n\ncmdclass = {}\next_modules = []\n\n\nHIP_BUILD = bool(torch.version.hip)\n\nif not SKIP_CUDA_BUILD:\n    print(\"\\n\\ntorch.__version__  = {}\\n\\n\".format(torch.__version__))\n    TORCH_MAJOR = int(torch.__version__.split(\".\")[0])\n    TORCH_MINOR = int(torch.__version__.split(\".\")[1])\n\n    cc_flag = []\n\n    if HIP_BUILD:\n        check_if_hip_home_none(PACKAGE_NAME)\n\n        rocm_home = os.getenv(\"ROCM_PATH\")\n        _, hip_version = get_hip_version(rocm_home)\n\n        if HIP_HOME is not None:\n            if hip_version < Version(\"6.0\"):\n                raise RuntimeError(\n                    f\"{PACKAGE_NAME} is only supported on ROCm 6.0 and above.  \"\n                    \"Note: make sure HIP has a supported version by running hipcc --version.\"\n                )\n            if hip_version == Version(\"6.0\"):\n                warnings.warn(\n                    f\"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. \"\n                    \"Refer to the README.md for detailed instructions.\",\n                    UserWarning\n                )\n\n        cc_flag.append(\"-DBUILD_PYTHON_PACKAGE\")\n\n    else:\n        check_if_cuda_home_none(PACKAGE_NAME)\n        # Check, if CUDA11 is installed for compute capability 8.0\n\n        if CUDA_HOME is not None:\n            _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)\n            if bare_metal_version < Version(\"11.6\"):\n                raise RuntimeError(\n                    f\"{PACKAGE_NAME} is only supported on CUDA 11.6 and above.  
\"\n                    \"Note: make sure nvcc has a supported version by running nvcc -V.\"\n                )\n\n        # If system CUDA and PyTorch CUDA have different major versions,\n        # clear TORCH_CUDA_ARCH_LIST to prevent cpp_extension from erroring\n        torch_cuda_version = parse(torch.version.cuda)\n        if bare_metal_version.major != torch_cuda_version.major:\n            os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"\"\n\n        cc_flag.append(\"-gencode\")\n        cc_flag.append(\"arch=compute_75,code=sm_75\")\n        cc_flag.append(\"-gencode\")\n        cc_flag.append(\"arch=compute_80,code=sm_80\")\n        cc_flag.append(\"-gencode\")\n        cc_flag.append(\"arch=compute_87,code=sm_87\")\n        if bare_metal_version >= Version(\"11.8\"):\n            cc_flag.append(\"-gencode\")\n            cc_flag.append(\"arch=compute_90,code=sm_90\")\n        if bare_metal_version >= Version(\"12.8\"):\n            cc_flag.append(\"-gencode\")\n            cc_flag.append(\"arch=compute_100,code=sm_100\")\n            cc_flag.append(\"-gencode\")\n            cc_flag.append(\"arch=compute_120,code=sm_120\")\n        if bare_metal_version >= Version(\"13.0\"):\n            cc_flag.append(\"-gencode\")\n            cc_flag.append(\"arch=compute_103,code=sm_103\")\n            cc_flag.append(\"-gencode\")\n            cc_flag.append(\"arch=compute_110,code=sm_110\")\n            cc_flag.append(\"-gencode\")\n            cc_flag.append(\"arch=compute_121,code=sm_121\")\n\n\n    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as\n    # torch._C._GLIBCXX_USE_CXX11_ABI\n    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920\n    if FORCE_CXX11_ABI:\n        torch._C._GLIBCXX_USE_CXX11_ABI = True\n\n    if HIP_BUILD:\n\n        extra_compile_args = {\n            \"cxx\": [\"-O3\", \"-std=c++17\"],\n            \"nvcc\": [\n                \"-O3\",\n                
\"-std=c++17\",\n                f\"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}\",\n                \"-U__CUDA_NO_HALF_OPERATORS__\",\n                \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                \"-fgpu-flush-denormals-to-zero\",\n            ]\n            + cc_flag,\n        }\n    else:\n        extra_compile_args = {\n            \"cxx\": [\"-O3\", \"-std=c++17\"],\n            \"nvcc\": append_nvcc_threads(\n                [\n                    \"-O3\",\n                    \"-std=c++17\",\n                    \"-U__CUDA_NO_HALF_OPERATORS__\",\n                    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                    \"-U__CUDA_NO_BFLOAT16_OPERATORS__\",\n                    \"-U__CUDA_NO_BFLOAT16_CONVERSIONS__\",\n                    \"-U__CUDA_NO_BFLOAT162_OPERATORS__\",\n                    \"-U__CUDA_NO_BFLOAT162_CONVERSIONS__\",\n                    \"--expt-relaxed-constexpr\",\n                    \"--expt-extended-lambda\",\n                    \"--use_fast_math\",\n                    \"--ptxas-options=-v\",\n                    \"-lineinfo\",\n                ]\n                + cc_flag\n            ),\n        }\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"selective_scan_cuda\",\n            sources=[\n                \"csrc/selective_scan/selective_scan.cpp\",\n                \"csrc/selective_scan/selective_scan_fwd_fp32.cu\",\n                \"csrc/selective_scan/selective_scan_fwd_fp16.cu\",\n                \"csrc/selective_scan/selective_scan_fwd_bf16.cu\",\n                \"csrc/selective_scan/selective_scan_bwd_fp32_real.cu\",\n                \"csrc/selective_scan/selective_scan_bwd_fp32_complex.cu\",\n                \"csrc/selective_scan/selective_scan_bwd_fp16_real.cu\",\n                \"csrc/selective_scan/selective_scan_bwd_fp16_complex.cu\",\n                \"csrc/selective_scan/selective_scan_bwd_bf16_real.cu\",\n                
\"csrc/selective_scan/selective_scan_bwd_bf16_complex.cu\",\n            ],\n            extra_compile_args=extra_compile_args,\n            include_dirs=[Path(this_dir) / \"csrc\" / \"selective_scan\"],\n        )\n    )\n\n\ndef get_package_version():\n    with open(Path(this_dir) / PACKAGE_NAME / \"__init__.py\", \"r\") as f:\n        version_match = re.search(r\"^__version__\\s*=\\s*(.*)$\", f.read(), re.MULTILINE)\n    public_version = ast.literal_eval(version_match.group(1))\n    local_version = os.environ.get(\"MAMBA_LOCAL_VERSION\")\n    if local_version:\n        return f\"{public_version}+{local_version}\"\n    else:\n        return str(public_version)\n\n\ndef get_wheel_url():\n    # Determine the version numbers that will be used to determine the correct wheel\n    torch_version_raw = parse(torch.__version__)\n\n    if HIP_BUILD:\n        # We're using the HIP version used to build torch, not the one currently installed\n        torch_hip_version = get_torch_hip_version()\n        hip_ver = f\"{torch_hip_version.major}{torch_hip_version.minor}\"\n    else:\n        # We're using the CUDA version used to build torch, not the one currently installed\n        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)\n        torch_cuda_version = parse(torch.version.cuda)\n        # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3\n        # to save CI time. 
Minor versions should be compatible.\n        if torch_cuda_version.major == 11:\n            torch_cuda_version = parse(\"11.8\")\n        elif torch_cuda_version.major == 12:\n            torch_cuda_version = parse(\"12.3\")\n        elif torch_cuda_version.major == 13:\n            torch_cuda_version = parse(\"13.0\")\n        else:\n            raise ValueError(f\"CUDA version {torch_cuda_version} not supported\")\n        \n        cuda_version = f\"{torch_cuda_version.major}\"\n\n    cuda_or_hip = \"hip\" if HIP_BUILD else \"cu\"\n\n    python_version = f\"cp{sys.version_info.major}{sys.version_info.minor}\"\n    platform_name = get_platform()\n    mamba_ssm_version = get_package_version()\n    if os.environ.get(\"NVIDIA_PRODUCT_NAME\", \"\") == \"PyTorch\":\n        torch_version = str(os.environ.get(\"NVIDIA_PYTORCH_VERSION\"))\n        # On NGC images, use the container's CUDA version (matching how wheels are built)\n        ngc_cuda_version = os.environ.get(\"CUDA_VERSION\", \"\")\n        if ngc_cuda_version:\n            cuda_version = str(parse(ngc_cuda_version).major)\n    else:\n        torch_version = f\"{torch_version_raw.major}.{torch_version_raw.minor}\"\n    # Resolve the GPU toolkit tag only after the NGC override above, so that the override\n    # actually reaches the wheel filename.\n    gpu_compute_version = hip_ver if HIP_BUILD else cuda_version\n    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()\n\n    # Determine wheel URL based on CUDA version, torch version, python version and OS\n    wheel_filename = f\"{PACKAGE_NAME}-{mamba_ssm_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl\"\n    wheel_url = BASE_WHEEL_URL.format(\n        tag_name=f\"v{mamba_ssm_version}\", wheel_name=wheel_filename\n    )\n    return wheel_url, wheel_filename\n\n\nclass CachedWheelsCommand(_bdist_wheel):\n    \"\"\"\n    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot\n    find an existing wheel (which is currently the case for all installs). 
We use\n    the environment parameters to detect whether there is already a pre-built version of a compatible\n    wheel available and short-circuit the standard full build pipeline.\n    \"\"\"\n\n    def run(self):\n        if FORCE_BUILD:\n            return super().run()\n\n        wheel_url, wheel_filename = get_wheel_url()\n        print(\"Guessing wheel URL: \", wheel_url)\n        try:\n            urllib.request.urlretrieve(wheel_url, wheel_filename)\n\n            # Make the archive\n            # Lifted from the root wheel processing command\n            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85\n            if not os.path.exists(self.dist_dir):\n                os.makedirs(self.dist_dir)\n\n            impl_tag, abi_tag, plat_tag = self.get_tag()\n            archive_basename = f\"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}\"\n\n            wheel_path = os.path.join(self.dist_dir, archive_basename + \".whl\")\n            print(\"Raw wheel path\", wheel_path)\n            shutil.move(wheel_filename, wheel_path)\n        except urllib.error.HTTPError:\n            print(\"Precompiled wheel not found. 
Building from source...\")\n            # If the wheel could not be downloaded, build from source\n            super().run()\n\nsetup(\n    name=PACKAGE_NAME,\n    version=get_package_version(),\n    packages=find_packages(\n        exclude=(\n            \"build\",\n            \"csrc\",\n            \"include\",\n            \"tests\",\n            \"dist\",\n            \"docs\",\n            \"benchmarks\",\n            \"mamba_ssm.egg-info\",\n        )\n    ),\n    author=\"Tri Dao, Albert Gu\",\n    author_email=\"tri@tridao.me, agu@cs.cmu.edu\",\n    description=\"Mamba state-space model\",\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n    url=\"https://github.com/state-spaces/mamba\",\n    classifiers=[\n        \"Programming Language :: Python :: 3\",\n        \"License :: OSI Approved :: Apache Software License\",\n        \"Operating System :: Unix\",\n    ],\n    ext_modules=ext_modules,\n    cmdclass={\"bdist_wheel\": CachedWheelsCommand, \"build_ext\": BuildExtension}\n    if ext_modules\n    else {\n        \"bdist_wheel\": CachedWheelsCommand,\n    },\n    python_requires=\">=3.9\",\n    install_requires=[\n        \"torch\",\n        \"packaging\",\n        \"ninja\",\n        \"einops\",\n        \"triton>=3.5.0\",\n        \"transformers\",\n        \"tilelang>=0.1.7.post3\",\n        \"nvidia-cutlass-dsl==4.4.1\",\n        \"quack-kernels==0.3.1\",\n        # \"causal_conv1d>=1.4.0\",\n    ],\n)\n"
  },
  {
    "path": "tests/benchmark_determinism_kernels.py",
    "content": "#!/usr/bin/env python\n# Copyright (c) 2024, Tri Dao, Albert Gu.\n\nimport gc\nimport math\n\nimport torch\nfrom triton.testing import do_bench\n\nfrom mamba_ssm.utils.determinism import set_deterministic_mode\n\nMODEL_PRESETS = {\n    \"small\": {\"nheads\": 32, \"headdim\": 64, \"dstate\": 64, \"ngroups\": 1},\n    \"nemotronh-56b\": {\"nheads\": 256, \"headdim\": 64, \"dstate\": 256, \"ngroups\": 8},\n}\n\n\ndef _reset_peak_memory() -> None:\n    gc.collect()\n    torch.cuda.empty_cache()\n    torch.cuda.reset_peak_memory_stats()\n    torch.cuda.synchronize()\n\n\ndef _peak_memory_mb(fn, *, warmup: int = 3) -> float:\n    for _ in range(warmup):\n        fn()\n    torch.cuda.synchronize()\n    _reset_peak_memory()\n    fn()\n    torch.cuda.synchronize()\n    return torch.cuda.max_memory_allocated() / (1024 * 1024)\n\n\ndef make_tensors(*, batch: int, seqlen: int, nheads: int, headdim: int, dstate: int, ngroups: int, chunk_size: int,\n                 dtype: torch.dtype = torch.bfloat16) -> dict[str, torch.Tensor]:\n    device = \"cuda\"\n    nchunks = math.ceil(seqlen / chunk_size)\n    return {\n        \"x\": torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype),\n        \"B\": torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype),\n        \"C\": torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype),\n        \"dt\": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32),\n        \"dA_cumsum\": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32),\n        \"dstates\": torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32),\n        \"dout\": torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype),\n        \"ddA\": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32),\n        \"ddt_out\": torch.randn(batch, nheads, nchunks, chunk_size, 
device=device, dtype=torch.float32),\n        \"dt_raw\": torch.randn(batch, seqlen, nheads, device=device, dtype=dtype),\n        \"A\": torch.randn(nheads, device=device, dtype=torch.float32) * -1,\n        \"dt_bias\": torch.randn(nheads, device=device, dtype=torch.float32),\n        \"prev_states\": torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32),\n        \"cb\": torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=device, dtype=dtype),\n    }\n\n\ndef get_benchmarks(t: dict[str, torch.Tensor], *, ngroups: int):\n    from mamba_ssm.ops.triton.ssd_chunk_state import (\n        _chunk_cumsum_bwd,\n        _chunk_state_bwd_db,\n        _chunk_state_bwd_ddAcs_stable,\n        _chunk_state_bwd_dx,\n    )\n    from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dx\n    from mamba_ssm.ops.triton.ssd_combined import _chunk_scan_chunk_state_bwd_dx\n\n    x = t[\"x\"].contiguous()\n    B = t[\"B\"].contiguous()\n    C = t[\"C\"].contiguous()\n    dout = t[\"dout\"].contiguous()\n    dstates = t[\"dstates\"].contiguous()\n\n    return [\n        (\"chunk_cumsum_bwd\", lambda: _chunk_cumsum_bwd(t[\"ddA\"], t[\"ddt_out\"], t[\"dt_raw\"], t[\"A\"], dt_bias=t[\"dt_bias\"], dt_softplus=True)),\n        (\"chunk_state_bwd_dx\", lambda: _chunk_state_bwd_dx(B, x, t[\"dt\"], t[\"dA_cumsum\"], dstates)),\n        (\"chunk_state_bwd_db\", lambda: _chunk_state_bwd_db(x, t[\"dt\"], t[\"dA_cumsum\"], dstates, B=B, ngroups=ngroups)),\n        (\"chunk_state_bwd_ddAcs\", lambda: _chunk_state_bwd_ddAcs_stable(B, x, t[\"dt\"], t[\"dA_cumsum\"], dstates)),\n        (\"chunk_scan_bwd_dC\", lambda: _chunk_scan_bwd_dC(t[\"prev_states\"], t[\"dA_cumsum\"], dout, C=C, ngroups=ngroups)),\n        (\"chunk_scan_bwd_dx\", lambda: _chunk_scan_bwd_dx(t[\"cb\"], x, t[\"dt\"], t[\"dA_cumsum\"], dout)),\n        (\"combined_bwd_dx\", lambda: _chunk_scan_chunk_state_bwd_dx(x, t[\"dt\"], t[\"dA_cumsum\"], B, 
t[\"cb\"], dout, dstates)),\n    ]\n\n\ndef _run_one(fn, *, deterministic: bool, warmup: int, rep: int):\n    set_deterministic_mode(deterministic)\n    ms = do_bench(fn, warmup=warmup, rep=rep, return_mode=\"median\")\n    peak_mb = _peak_memory_mb(fn, warmup=1)\n    return ms, peak_mb\n\n\ndef main() -> None:\n    import argparse\n\n    parser = argparse.ArgumentParser(description=\"Benchmark determinism overhead for key Triton backward kernels\")\n    parser.add_argument(\"--preset\", choices=sorted(MODEL_PRESETS.keys()), default=\"small\")\n    parser.add_argument(\"--warmup\", type=int, default=25)\n    parser.add_argument(\"--rep\", type=int, default=100)\n    parser.add_argument(\"--batch\", type=int, default=4)\n    parser.add_argument(\"--seqlen\", type=int, default=2048)\n    parser.add_argument(\"--chunk-size\", type=int, default=256)\n    args = parser.parse_args()\n\n    if not torch.cuda.is_available():\n        raise SystemExit(\"CUDA not available\")\n\n    p = MODEL_PRESETS[args.preset]\n    tensors = make_tensors(\n        batch=args.batch,\n        seqlen=args.seqlen,\n        nheads=p[\"nheads\"],\n        headdim=p[\"headdim\"],\n        dstate=p[\"dstate\"],\n        ngroups=p[\"ngroups\"],\n        chunk_size=args.chunk_size,\n    )\n    benches = get_benchmarks(tensors, ngroups=p[\"ngroups\"])\n\n    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n    print(f\"preset={args.preset} batch={args.batch} seqlen={args.seqlen} chunk_size={args.chunk_size}\")\n    print(f\"{'kernel':<20} {'ms':>9} {'det_ms':>9} {'ms_%':>6} {'MB':>9} {'det_MB':>9} {'MB_%':>6}\")\n\n    rows = []\n    try:\n        for name, fn in benches:\n            ms, mb = _run_one(fn, deterministic=False, warmup=args.warmup, rep=args.rep)\n            det_ms, det_mb = _run_one(fn, deterministic=True, warmup=args.warmup, rep=args.rep)\n            ms_pct = (det_ms / ms - 1.0) * 100.0\n            mb_pct = (det_mb / mb - 1.0) * 100.0 if mb else 0.0\n            
rows.append((name, ms, det_ms, ms_pct, mb, det_mb, mb_pct))\n            print(f\"{name:<20} {ms:>9.3f} {det_ms:>9.3f} {ms_pct:>+6.0f}% {mb:>9.1f} {det_mb:>9.1f} {mb_pct:>+6.0f}%\")\n    finally:\n        set_deterministic_mode(None)\n\n    total_ms = sum(r[1] for r in rows)\n    total_det_ms = sum(r[2] for r in rows)\n    max_mb = max(r[4] for r in rows) if rows else 0.0\n    max_det_mb = max(r[5] for r in rows) if rows else 0.0\n    total_pct = (total_det_ms / total_ms - 1.0) * 100.0 if total_ms else 0.0\n    max_mb_pct = (max_det_mb / max_mb - 1.0) * 100.0 if max_mb else 0.0\n    print(f\"{'TOTAL/MAX':<20} {total_ms:>9.3f} {total_det_ms:>9.3f} {total_pct:>+6.0f}% {max_mb:>9.1f} {max_det_mb:>9.1f} {max_mb_pct:>+6.0f}%\")\n\n\nif __name__ == \"__main__\":\n    main()\n\n\n"
  },
  {
    "path": "tests/ops/cute/test_mamba3_mimo_step.py",
    "content": "\"\"\"\nMamba-3 MIMO Step Function Tests\n\nCopyright (c) 2026, Dao AI Lab, Goombalab\n\nPytest coverage for Mamba3.step() and mixed forward/step decoding.\n\nUsage:\npytest -q -s -p no:warnings tests/ops/cute/test_mamba3_mimo_step.py  # For correctness tests\npython tests/ops/cute/test_mamba3_mimo_step.py  # For benchmark\n\nRemove the -s flag for less verbose output.\n\"\"\"\nimport logging\nimport sys\nimport warnings\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Optional\n\nimport pytest\nimport torch\nfrom torch import Tensor\n\n\nwarnings.filterwarnings(\"ignore\")\nlogging.disable(logging.WARNING)\n\n\n\n\n\nBATCH = 128\nSEQLEN = 32\nNHEADS = 64\nHDIM = 64\nDSTATE = 128\nMIMO_DIM = 4\nUSE_TILELANG = True\nDTYPE = torch.bfloat16\nDEVICE = \"cuda\"\nRTOL = 0.1\nATOL = 0.1\n\n\ndef _require_cuda_and_kernel_deps() -> None:\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA is required for mamba3 step tests\")\n    pytest.importorskip(\"tilelang\")\n    pytest.importorskip(\"triton\")\n\n\ndef _mamba3_cls():\n    from mamba_ssm.modules.mamba3 import Mamba3\n\n    return Mamba3\n\n\n@pytest.fixture(scope=\"module\", autouse=True)\ndef _kernel_deps() -> None:\n    _require_cuda_and_kernel_deps()\n\n\n# Adapted from\n# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31\n@dataclass\nclass InferenceParams:\n    \"\"\"Inference parameters used to store context during inference.\"\"\"\n\n    max_seqlen: int\n    max_batch_size: int\n    seqlen_offset: int = 0\n    batch_size_offset: int = 0\n    key_value_memory_dict: dict = field(default_factory=dict)\n    new_key_value_memory_dict: dict = field(default_factory=dict)\n    lengths_per_sample: Optional[Tensor] = None\n\n    def reset(self, max_seqlen, max_batch_size):\n        self.max_seqlen = max_seqlen\n        self.max_batch_size = max_batch_size\n        
self.seqlen_offset = 0\n        if self.lengths_per_sample is not None:\n            self.lengths_per_sample.zero_()\n\n\n@dataclass\nclass RunOutputs:\n    config_label: str\n    split: int\n    out_fwd_fp32: Tensor\n    outputs_step: Tensor\n    prefix_out: Tensor\n    outputs_mixed: Tensor\n\n\ndef _case_config(*, is_outproj_norm: bool) -> dict:\n    d_model = NHEADS * HDIM // 2\n    return {\n        \"d_model\": d_model,\n        \"d_state\": DSTATE,\n        \"headdim\": HDIM,\n        \"is_mimo\": True,\n        \"mimo_rank\": MIMO_DIM,\n        \"chunk_size\": 64 // MIMO_DIM,\n        \"dtype\": DTYPE,\n        \"device\": DEVICE,\n        \"layer_idx\": 0,\n        \"use_tilelang\": USE_TILELANG,\n        \"is_outproj_norm\": is_outproj_norm,\n    }\n\n\ndef _diff_stats(actual: Tensor, expected: Tensor) -> str:\n    diff = (actual.float() - expected.float()).abs()\n    return f\"max_abs={diff.max().item():.6e}, mean_abs={diff.mean().item():.6e}\"\n\n\ndef _assert_close(\n    actual: Tensor,\n    expected: Tensor,\n    *,\n    label: str,\n    cfg: str,\n    step: Optional[int] = None,\n) -> None:\n    try:\n        torch.testing.assert_close(\n            actual.float(),\n            expected.float(),\n            rtol=RTOL,\n            atol=ATOL,\n        )\n    except AssertionError as err:\n        location = f\", step={step}\" if step is not None else \"\"\n        stats = _diff_stats(actual, expected)\n        raise AssertionError(\n            f\"{label} assertion failed for {cfg}{location} ({stats})\"\n        ) from err\n\n\ndef _run_case(*, is_outproj_norm: bool) -> RunOutputs:\n    Mamba3 = _mamba3_cls()\n    cfg = _case_config(is_outproj_norm=is_outproj_norm)\n    config_label = (\n        f\"use_tilelang={cfg['use_tilelang']}, \"\n        f\"is_outproj_norm={cfg['is_outproj_norm']}, \"\n        f\"batch={BATCH}, seqlen={SEQLEN}, \"\n        f\"nheads={NHEADS}, hdim={HDIM}, dstate={DSTATE}, mimo_dim={MIMO_DIM}\"\n    )\n\n    
torch.manual_seed(42)\n    torch.cuda.manual_seed_all(42)\n    model_fwd = Mamba3(**cfg)\n    model_fwd.eval()\n\n    cfg_fp32 = {**cfg, \"dtype\": torch.float32}\n    torch.manual_seed(42)\n    torch.cuda.manual_seed_all(42)\n    model_fwd_fp32 = Mamba3(**cfg_fp32)\n    model_fwd_fp32.eval()\n    model_fwd_fp32.load_state_dict(\n        {k: v.float() for k, v in model_fwd.state_dict().items()},\n        strict=False,\n    )\n\n    torch.manual_seed(42)\n    torch.cuda.manual_seed_all(42)\n    model_step = Mamba3(**cfg)\n    model_step.eval()\n    model_step.load_state_dict(model_fwd.state_dict(), strict=False)\n\n    torch.manual_seed(42)\n    torch.cuda.manual_seed_all(42)\n    model_mix = Mamba3(**cfg)\n    model_mix.eval()\n    model_mix.load_state_dict(model_fwd.state_dict(), strict=False)\n\n    u = torch.randn(BATCH, SEQLEN, cfg[\"d_model\"], device=DEVICE, dtype=DTYPE)\n\n    with torch.no_grad():\n        out_fwd_fp32 = model_fwd_fp32(u.float())\n\n        state = model_step.allocate_inference_cache(BATCH, 1, device=DEVICE, dtype=DTYPE)\n        outputs_step = []\n        for t in range(SEQLEN):\n            out_step, nxt_angle_state, state_out, nxt_k_state, nxt_v_state = model_step.step(\n                u[:, t], *state\n            )\n            state = (nxt_angle_state, state_out, nxt_k_state, nxt_v_state)\n            outputs_step.append(out_step)\n        outputs_step = torch.stack(outputs_step, dim=1)\n\n        split = SEQLEN // 2\n        assert 0 < split < SEQLEN\n        inference_params = InferenceParams(max_seqlen=SEQLEN, max_batch_size=BATCH)\n        prefix_out = model_mix(u[:, :split], inference_params=inference_params)\n        state = inference_params.key_value_memory_dict[model_mix.layer_idx]\n        mixed_suffix = []\n        for t in range(split, SEQLEN):\n            out_step, nxt_angle_state, state_out, nxt_k_state, nxt_v_state = model_mix.step(\n                u[:, t], *state\n            )\n            state = (nxt_angle_state, 
state_out, nxt_k_state, nxt_v_state)\n            mixed_suffix.append(out_step)\n        outputs_mixed = torch.cat([prefix_out, torch.stack(mixed_suffix, dim=1)], dim=1)\n\n    return RunOutputs(\n        config_label=config_label,\n        split=split,\n        out_fwd_fp32=out_fwd_fp32,\n        outputs_step=outputs_step,\n        prefix_out=prefix_out,\n        outputs_mixed=outputs_mixed,\n    )\n\n\n@pytest.mark.parametrize(\n    \"is_outproj_norm\",\n    [\n        pytest.param(False, id=\"outproj_norm_false\"),\n        pytest.param(True, id=\"outproj_norm_true\"),\n    ],\n)\ndef test_step_matches_forward_fp32(is_outproj_norm: bool) -> None:\n    outputs = _run_case(is_outproj_norm=is_outproj_norm)\n\n    for t in range(SEQLEN):\n        _assert_close(\n            outputs.outputs_step[:, t],\n            outputs.out_fwd_fp32[:, t],\n            label=\"pure-step\",\n            cfg=outputs.config_label,\n            step=t,\n        )\n\n    _assert_close(\n        outputs.prefix_out,\n        outputs.out_fwd_fp32[:, :outputs.split],\n        label=\"mixed-prefix\",\n        cfg=outputs.config_label,\n    )\n\n    for t in range(outputs.split, SEQLEN):\n        _assert_close(\n            outputs.outputs_mixed[:, t],\n            outputs.out_fwd_fp32[:, t],\n            label=\"mixed-suffix\",\n            cfg=outputs.config_label,\n            step=t,\n        )\n\n\ndef run_step_benchmark(*, is_outproj_norm: bool) -> None:\n    _require_cuda_and_kernel_deps()\n    from triton.testing import do_bench_cudagraph\n    Mamba3 = _mamba3_cls()\n\n    cfg = _case_config(is_outproj_norm=is_outproj_norm)\n    rotate_str = \"halved\" if USE_TILELANG else \"pairwise\"\n\n    torch.manual_seed(42)\n    torch.cuda.manual_seed_all(42)\n    model_step = Mamba3(**cfg)\n    model_step.eval()\n\n    state_bm = model_step.allocate_inference_cache(BATCH, 1, device=DEVICE, dtype=DTYPE)\n    u_step_bm = torch.randn(BATCH, cfg[\"d_model\"], device=DEVICE, dtype=DTYPE)\n\n    
with torch.no_grad():\n        model_step.step(u_step_bm, *state_bm)\n\n    def full_step_fn():\n        out, _, _, _, _ = model_step.step(u_step_bm, *state_bm)\n        return out\n\n    ms_full = do_bench_cudagraph(full_step_fn, rep=30)\n\n    dtype_size = torch.tensor([], dtype=DTYPE).element_size()\n    state_dtype_size = 4\n    num_rope_angles = model_step.num_rope_angles\n\n    bytes_read = (\n        BATCH * cfg[\"d_model\"] * dtype_size\n        + BATCH * NHEADS * HDIM * DSTATE * state_dtype_size\n        + BATCH * NHEADS * num_rope_angles * state_dtype_size\n        + BATCH * MIMO_DIM * NHEADS * DSTATE * dtype_size\n        + BATCH * NHEADS * HDIM * dtype_size\n    )\n    bytes_write = (\n        BATCH * NHEADS * HDIM * dtype_size\n        + BATCH * NHEADS * HDIM * DSTATE * state_dtype_size\n        + BATCH * NHEADS * num_rope_angles * state_dtype_size\n        + BATCH * MIMO_DIM * NHEADS * DSTATE * dtype_size\n        + BATCH * NHEADS * HDIM * dtype_size\n    )\n\n    total_bytes = bytes_read + bytes_write\n    bw = total_bytes / (ms_full * 1e-3) / 1e9\n\n    print(\"\\n\" + \"=\" * 70)\n    print(\n        \"Benchmark: Mamba3.step() \"\n        f\"(rotation={rotate_str}, is_outproj_norm={is_outproj_norm})\"\n    )\n    print(\"=\" * 70)\n    print(\n        f\"  batch={BATCH}, d_model={cfg['d_model']}, nheads={NHEADS}, \"\n        f\"hdim={HDIM}, dstate={DSTATE}, mimo_dim={MIMO_DIM}\"\n    )\n    print(f\"  Time per step: {ms_full:.4f} ms\")\n    print(\n        \"  Memory I/O:    \"\n        f\"{total_bytes / 1e6:.2f} MB \"\n        f\"(Read: {bytes_read / 1e6:.2f} MB, Write: {bytes_write / 1e6:.2f} MB)\"\n    )\n    print(f\"  Bandwidth:     {bw:.1f} GB/s\")\n\n\nif __name__ == \"__main__\":\n    run_step_benchmark(is_outproj_norm=False)\n    run_step_benchmark(is_outproj_norm=True)\n"
  },
  {
    "path": "tests/ops/test_selective_scan.py",
    "content": "# Copyright (C) 2023, Tri Dao.\n\nimport math\n\nimport torch\nimport torch.nn.functional as F\nimport pytest\n\nfrom einops import rearrange\n\nfrom mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref\nfrom mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref\n\n\n# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])\n@pytest.mark.parametrize('wtype', [torch.float32])\n# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])\n@pytest.mark.parametrize('itype', [torch.float32])\n# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])\n@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])\n# @pytest.mark.parametrize('seqlen', [128])\n# @pytest.mark.parametrize(\"return_last_state\", [False, True])\n@pytest.mark.parametrize(\"return_last_state\", [True])\n# @pytest.mark.parametrize('has_delta_bias', [False, True])\n@pytest.mark.parametrize('has_delta_bias', [True])\n# @pytest.mark.parametrize('delta_softplus', [False, True])\n@pytest.mark.parametrize('delta_softplus', [True])\n# @pytest.mark.parametrize('has_z', [False, True])\n@pytest.mark.parametrize('has_z', [True])\n# @pytest.mark.parametrize('has_D', [False, True])\n@pytest.mark.parametrize('has_D', [True])\n@pytest.mark.parametrize(\"varBC_groups\", [1, 2])\n# @pytest.mark.parametrize(\"varBC_groups\", [1])\n# @pytest.mark.parametrize(\"is_variable_C\", [False, True])\n@pytest.mark.parametrize(\"is_variable_C\", [True])\n# @pytest.mark.parametrize(\"is_variable_B\", [False, True])\n@pytest.mark.parametrize(\"is_variable_B\", [True])\ndef test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias,\n                        delta_softplus, return_last_state, seqlen, itype, wtype):\n    if varBC_groups > 1 and (not is_variable_B or not is_variable_C):\n        pytest.skip()  # This config is not 
applicable\n    device = 'cuda'\n    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)\n    if itype == torch.bfloat16:\n        rtol, atol = 3e-2, 5e-2\n    rtolw, atolw = (1e-3, 1e-3)\n    if has_z:  # If we have z, the errors on the weights seem higher\n        rtolw = max(rtolw, rtol)\n        atolw = max(atolw, atol)\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    dim = 4\n    dstate = 8\n    is_complex = wtype == torch.complex64\n    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()\n    if not is_variable_B:\n        B_shape = (dim, dstate)\n    elif varBC_groups == 1:\n        B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)\n    else:\n        B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)\n    B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype,\n                    requires_grad=True)\n    if not is_variable_C:\n        C_shape = (dim, dstate)\n    elif varBC_groups == 1:\n        C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)\n    else:\n        C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)\n    C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype,\n                    requires_grad=True)\n    if has_D:\n        D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)\n    else:\n        D = None\n    if has_z:\n        z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)\n    else:\n        z = None\n    if has_delta_bias:\n        delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()\n    else:\n        delta_bias = None\n    u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)\n    delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, 
dtype=itype)).requires_grad_()\n    A_ref = A.detach().clone().requires_grad_()\n    B_ref = B.detach().clone().requires_grad_()\n    C_ref = C.detach().clone().requires_grad_()\n    D_ref = D.detach().clone().requires_grad_() if D is not None else None\n    z_ref = z.detach().clone().requires_grad_() if z is not None else None\n    u_ref = u.detach().clone().requires_grad_()\n    delta_ref = delta.detach().clone().requires_grad_()\n    delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None\n    out, *rest = selective_scan_fn(\n        u, delta, A, B, C, D, z=z,\n        delta_bias=delta_bias, delta_softplus=delta_softplus,\n        return_last_state=return_last_state\n    )\n    if return_last_state:\n        state = rest[0]\n    out_ref, *rest = selective_scan_ref(\n        u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,\n        delta_bias=delta_bias_ref, delta_softplus=delta_softplus,\n        return_last_state=return_last_state\n    )\n    if return_last_state:\n        state_ref = rest[0]\n    # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))\n    # dt_u = delta * u\n\n    print(f'Output max diff: {(out - out_ref).abs().max().item()}')\n    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')\n    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)\n    if return_last_state:\n        print(f'State max diff: {(state - state_ref).abs().max().item()}')\n        assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)\n\n    g = torch.randn_like(out)\n    out_ref.backward(g)\n    out.backward(g)\n\n    print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}')\n    print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}')\n    print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')\n    print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')\n    print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')\n    if has_D:\n        
print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')\n    if has_z:\n        print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}')\n    if has_delta_bias:\n        print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')\n\n    assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)\n    assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)\n    assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)\n    assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,\n                          atol=atolw if not is_variable_B else atol)\n    assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,\n                          atol=atolw if not is_variable_C else atol)\n    if has_D:\n        assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)\n    if has_z:\n        assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw)\n    if has_delta_bias:\n        assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)\n\n\n@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])\n# @pytest.mark.parametrize('wtype', [torch.complex64])\n# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])\n@pytest.mark.parametrize('itype', [torch.float32])\n# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])\n@pytest.mark.parametrize('seqlen', [128])\n@pytest.mark.parametrize(\"is_variable_C\", [False, True])\n# @pytest.mark.parametrize(\"is_variable_C\", [False])\n@pytest.mark.parametrize(\"is_variable_B\", [False, True])\n# @pytest.mark.parametrize(\"is_variable_B\", [True])\ndef test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):\n    device = 'cuda'\n    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)\n    if 
itype == torch.bfloat16:\n        rtol, atol = 3e-2, 5e-2\n    rtolw, atolw = (1e-3, 1e-3)\n    # If we have z, the errors on the weights seem higher\n    rtolw = max(rtolw, rtol)\n    atolw = max(atolw, atol)\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    dim = 768\n    dstate = 8\n    dt_rank = 48\n    is_complex = wtype == torch.complex64\n    xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)\n    conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)\n    conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)\n    x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate\n                                * (1 if not is_complex else 2),\n                                dim, device=device, dtype=itype, requires_grad=True)\n    delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)\n    out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)\n    out_proj_bias = None\n    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()\n    B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)\n         if not is_variable_B else None)\n    C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)\n         if not is_variable_C else None)\n    D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)\n    delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()\n    B_proj_bias = None\n    C_proj_bias = None\n    xz_ref = xz.detach().clone().requires_grad_()\n    conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()\n    conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()\n    x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()\n    delta_proj_weight_ref = 
delta_proj_weight.detach().clone().requires_grad_()\n    out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()\n    out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()\n                         if out_proj_bias is not None else None)\n    A_ref = A.detach().clone().requires_grad_()\n    B_ref = B.detach().clone().requires_grad_() if B is not None else None\n    C_ref = C.detach().clone().requires_grad_() if C is not None else None\n    D_ref = D.detach().clone().requires_grad_()\n    delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None\n    out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,\n                         out_proj_weight, out_proj_bias,\n                         A, B, C, D, delta_bias=delta_bias, delta_softplus=True)\n    out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,\n                              delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,\n                              A_ref, B_ref, C_ref, D_ref,\n                              delta_bias=delta_bias_ref, delta_softplus=True)\n    # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))\n    # dt_u = delta * u\n\n    print(f'Output max diff: {(out - out_ref).abs().max().item()}')\n    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')\n    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)\n\n    g = torch.randn_like(out)\n    out_ref.backward(g)\n    out.backward(g)\n\n    print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}')\n    print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')\n    if not is_variable_B:\n        print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')\n    if not is_variable_C:\n        print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')\n    print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')\n    print(f'ddelta_bias 
max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')\n    print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}')\n    print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}')\n    print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}')\n    print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}')\n    print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}')\n\n    # assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)\n    # assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)\n    # assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)\n    # assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,\n    #                       atol=atolw if not is_variable_B else atol)\n    # assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,\n    #                       atol=atolw if not is_variable_C else atol)\n    # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)\n    # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)\n"
  },
  {
    "path": "tests/ops/tilelang/test_mamba3_mimo.py",
    "content": "\"\"\"\nMamba-3 MIMO Kernel Tests\n\nCopyright (c) 2026, Dao AI Lab, Goombalab\n\n\nUsage:\npytest -q -s -p no:warnings tests/ops/tilelang/test_mamba3_mimo.py -k bwd\npytest -q -s -p no:warnings tests/ops/tilelang/test_mamba3_mimo.py -k fwd\npytest -q -s -p no:warnings tests/ops/tilelang/test_mamba3_mimo.py -k smoke\npytest -q -s -p no:warnings tests/ops/tilelang/test_mamba3_mimo.py -k chunk_ref_matches_step_ref\n\nRemove the -s flag for less verbose output.\n\"\"\"\n\nimport sys\nfrom pathlib import Path\nfrom types import SimpleNamespace\nimport math\nfrom typing import Optional, Tuple\nfrom einops import rearrange, repeat\n\n\nimport pytest\nimport torch\nfrom torch import Tensor\nF = torch.nn.functional\n\n\nFIXED_B = 4\nFIXED_S = 2048\nFIXED_H = 16\nFIXED_G = 1\nFIXED_ROTARY_DIM_DIVISOR = 4\nFIXED_DTYPE = torch.bfloat16\nREL_TOL = 0.10\n\n\nCASE_GRID = [\n    pytest.param(16, 64, 4, 8, 128, id=\"N16_P64_R4_C8_BB128\"),\n    pytest.param(32, 64, 4, 16, 256, id=\"N32_P64_R4_C16_BB256\"),\n    pytest.param(64, 64, 4, 16, 256, id=\"N64_P64_R4_C16_BB256\"),\n    pytest.param(128, 64, 4, 16, 256, id=\"N128_P64_R4_C16_BB256\"),\n    pytest.param(256, 64, 4, 8, 256, id=\"N256_P64_R4_C8_BB256\"),\n    pytest.param(64, 128, 4, 16, 256, id=\"N64_P128_R4_C16_BB256\"),\n    pytest.param(128, 32, 4, 16, 256, id=\"N128_P32_R4_C16_BB256\"),\n    pytest.param(128, 128, 4, 8, 256, id=\"N128_P128_R4_C8_BB256\"),\n    pytest.param(128, 64, 8, 8, 256, id=\"N128_P64_R8_C8_BB256\"),\n    pytest.param(128, 64, 2, 32, 256, id=\"N128_P64_R2_C32_BB256\"),\n    pytest.param(128, 64, 1, 64, 256, id=\"N128_P64_R1_C64_BB256\"),\n]\n\n\ndef _require_cuda_and_kernel_deps() -> None:\n    if not torch.cuda.is_available():\n        pytest.skip(\"CUDA is required for mamba3 tilelang tests\")\n    pytest.importorskip(\"tilelang\")\n    pytest.importorskip(\"triton\")\n\n\n@pytest.fixture(scope=\"module\")\ndef mods() -> SimpleNamespace:\n    _require_cuda_and_kernel_deps()\n    
import mamba_ssm.ops.tilelang.mamba3.mamba3_mimo as mamba3_top\n    import mamba_ssm.ops.tilelang.mamba3.mamba3_mimo_bwd as mamba3_bwd\n    import mamba_ssm.ops.tilelang.mamba3.mamba3_mimo_fwd as mamba3_fwd\n    import mamba_ssm.ops.triton.mamba3.mamba3_mimo_utils as mamba3_mimo_utils\n\n    return SimpleNamespace(\n        top=mamba3_top,\n        bwd=mamba3_bwd,\n        fwd=mamba3_fwd,\n        utils=mamba3_mimo_utils,\n    )\n\n\ndef max_rel_err(ours: Tensor, ref: Tensor, eps: float = 1e-5) -> float:\n    ours_f = ours.float()\n    ref_f = ref.float()\n    num = (ours_f - ref_f).abs().max()\n    den = ref_f.abs().max().clamp_min(eps)\n    return float((num / den).item())\n\n\ndef assert_stable_rel(\n    ours: Tensor,\n    ref: Tensor,\n    *,\n    label: str,\n    cfg: str,\n    rel_tol: float = REL_TOL,\n) -> None:\n    ours_f = ours.float()\n    ref_f = ref.float()\n    rel = max_rel_err(ours_f, ref_f)\n    close_mask = torch.isclose(ours_f, ref_f, rtol=REL_TOL, atol=0.1)\n    bad_frac = float((~close_mask).float().mean().item())\n    max_abs = float((ours_f - ref_f).abs().max().item())\n    print(\n        f\"[debug] {label} ({cfg}) \"\n        f\"stable_max_rel={rel:.6f} max_abs={max_abs:.6e} \"\n        f\"bad_frac(rtol=0.1,atol=0.1)={bad_frac:.6f}\"\n    )\n    if rel < rel_tol:\n        return\n\n    raise AssertionError(\n        f\"{label} stable_max_rel >= {rel_tol} for {cfg}: \"\n        f\"stable_max_rel={rel:.6f}, max_abs={max_abs:.6e}, \"\n        f\"diag_bad_frac_at_rtol0.1_atol0.1={bad_frac:.6f}\"\n    )\n\n\ndef build_inputs(\n    *,\n    mods: SimpleNamespace,\n    n: int,\n    p: int,\n    r: int,\n    chunk_size: int,\n    seed: int,\n    b: int = FIXED_B,\n    s: int = FIXED_S,\n    h: int = FIXED_H,\n    g: int = FIXED_G,\n    dtype: torch.dtype = FIXED_DTYPE,\n    has_z: bool = True,\n    has_d: bool = True,\n    rotary_dim_divisor: int = FIXED_ROTARY_DIM_DIVISOR,\n) -> dict:\n    assert s % chunk_size == 0\n    torch.manual_seed(seed)\n  
  torch.cuda.manual_seed_all(seed)\n\n    q = torch.randn((b, s, r, g, n), device=\"cuda\", dtype=dtype)\n    k = torch.randn((b, s, r, g, n), device=\"cuda\", dtype=dtype)\n    v = torch.randn((b, s, h, p), device=\"cuda\", dtype=dtype)\n\n    q_bias = torch.randn((h, r, n), device=\"cuda\", dtype=torch.float32)\n    k_bias = torch.randn((h, r, n), device=\"cuda\", dtype=torch.float32)\n    mimo_v = torch.randn((h, r, p), device=\"cuda\", dtype=torch.float32) / r\n    mimo_o = torch.randn((h, r, p), device=\"cuda\", dtype=torch.float32) / r\n\n    z = torch.randn_like(v) if has_z else None\n    mimo_z = torch.randn_like(mimo_v) if has_z else None\n    d = torch.randn((h,), device=\"cuda\", dtype=torch.float32) if has_d else None\n\n    angles = torch.rand(\n        (b, s, h, n // rotary_dim_divisor), device=\"cuda\", dtype=torch.float32\n    )\n    dt = F.softplus(-3.0 + torch.randn((b, h, s), device=\"cuda\", dtype=torch.float32))\n    a = torch.rand((b, h, s), device=\"cuda\", dtype=torch.float32)\n    dA = (-dt * a).detach()\n    dA_cs, dA_cs_rev, segsum = mods.utils.compute_dacs_segsum_triton(dA, chunk_size)\n    trap = torch.rand((b, h, s), device=\"cuda\", dtype=dtype)\n    dout = torch.randn_like(v)\n\n    return {\n        \"q\": q,\n        \"k\": k,\n        \"v\": v,\n        \"q_bias\": q_bias,\n        \"k_bias\": k_bias,\n        \"mimo_v\": mimo_v,\n        \"mimo_o\": mimo_o,\n        \"z\": z,\n        \"mimo_z\": mimo_z,\n        \"D\": d,\n        \"angles\": angles,\n        \"dt\": dt,\n        \"dA\": dA,\n        \"dA_cs\": dA_cs,\n        \"dA_cs_rev\": dA_cs_rev,\n        \"segsum\": segsum,\n        \"trap\": trap,\n        \"dout\": dout,\n        \"chunk_size\": chunk_size,\n        \"rotary_dim_divisor\": rotary_dim_divisor,\n    }\n\ndef make_smoke_inputs(\n    *,\n    batch: int = 1,\n    seqlen: int = 64,\n    mimo_rank: int = 4,\n    nheads_qk: int = 1,\n    nheads: int = 8,\n    headdim_qk: int = 64,\n    headdim_v: int = 32,\n    
chunk_size: int = 16,\n    rotary_dim_divisor: int = 4,\n    device: str = \"cuda\",\n    dtype: torch.dtype = torch.bfloat16,\n    seed: int = 0,\n):\n    torch.manual_seed(seed)\n    if device == \"cuda\":\n        torch.cuda.manual_seed_all(seed)\n\n    Q = torch.randn(\n        (batch, seqlen, mimo_rank, nheads_qk, headdim_qk),\n        device=device,\n        dtype=dtype,\n        requires_grad=True,\n    )\n    K = torch.randn_like(Q, requires_grad=True)\n    V = torch.randn(\n        (batch, seqlen, nheads, headdim_v),\n        device=device,\n        dtype=dtype,\n        requires_grad=True,\n    )\n\n    import torch.nn.functional as F\n    import math\n    DT = F.softplus(\n        -3.0\n        + torch.randn(\n            batch,\n            nheads,\n            seqlen,\n            device=device,\n            dtype=torch.float,\n        )\n    ).detach().requires_grad_(True)\n    # Make ADT a leaf so .grad is populated without retain_grad().\n    ADT = (-DT.detach() * math.log2(math.e)).clone().detach().requires_grad_(True)\n    \n    Trap = (\n        torch.rand(\n            (batch, nheads, seqlen),\n            device=device,\n            dtype=dtype,\n        )\n        * 0.5\n    ).detach().requires_grad_(True)\n\n    Q_bias = torch.randn(\n        (nheads, mimo_rank, headdim_qk),\n        device=device,\n        dtype=torch.float32,\n        requires_grad=True,\n    )\n    K_bias = torch.randn_like(Q_bias, requires_grad=True)\n    MIMO_V = torch.randn(\n        (nheads, mimo_rank, headdim_v),\n        device=device,\n        dtype=torch.float32,\n        requires_grad=True,\n    )\n    MIMO_Z = (torch.randn_like(MIMO_V) / mimo_rank).detach().requires_grad_(True)\n    MIMO_Out = (torch.randn_like(MIMO_V) / mimo_rank).detach().requires_grad_(True)\n    Angles = torch.rand(\n        (batch, seqlen, nheads, headdim_qk // rotary_dim_divisor),\n        device=device,\n        dtype=torch.float32,\n        requires_grad=True,\n    )\n    D = 
torch.randn(\n        (nheads,),\n        device=device,\n        dtype=torch.float32,\n        requires_grad=True,\n    )\n    Z = torch.randn(\n        (batch, seqlen, nheads, headdim_v),\n        device=device,\n        dtype=dtype,\n        requires_grad=True,\n    )\n\n    return dict(\n        Q=Q,\n        K=K,\n        V=V,\n        ADT=ADT,\n        DT=DT,\n        Trap=Trap,\n        Q_bias=Q_bias,\n        K_bias=K_bias,\n        MIMO_V=MIMO_V,\n        MIMO_Z=MIMO_Z,\n        MIMO_Out=MIMO_Out,\n        Angles=Angles,\n        D=D,\n        Z=Z,\n        chunk_size=chunk_size,\n        rotary_dim_divisor=rotary_dim_divisor,\n        dtype=dtype,\n    )\n\ndef grads_to_dA(grad_dA_cs: Tensor, grad_dA_cs_rev: Tensor, chunk_size: int) -> Tensor:\n    b, h, s = grad_dA_cs.shape\n    assert s % chunk_size == 0\n    nchunks = s // chunk_size\n\n    g_f = grad_dA_cs.view(b, h, nchunks, chunk_size)\n    grad_from_f = torch.flip(torch.cumsum(torch.flip(g_f, dims=[-1]), dim=-1), dims=[-1])\n\n    g_r = grad_dA_cs_rev.view(b, h, nchunks, chunk_size)\n    prefix = torch.cumsum(g_r, dim=-1)\n    grad_from_r = torch.cat([torch.zeros_like(prefix[..., :1]), prefix[..., :-1]], dim=-1)\n    return (grad_from_f + grad_from_r).view(b, h, s)\n\n\n\ndef mamba3_MIMO_step_ref(\n    Q: torch.Tensor,\n    K: torch.Tensor,\n    V: torch.Tensor,\n    ADT: torch.Tensor,\n    DT: torch.Tensor,\n    Trap: torch.Tensor,\n    Q_bias: torch.Tensor,\n    K_bias: torch.Tensor,\n    Angles: torch.Tensor,\n    MIMO_V: torch.Tensor,\n    MIMO_O: torch.Tensor,\n    D: Optional[torch.Tensor] = None,\n    Z: Optional[torch.Tensor] = None,\n    MIMO_Z: Optional[torch.Tensor] = None,\n    Input_States: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None,\n) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:\n    \"\"\"Reference implementation of Mamba-3 MIMO in recurrent (step) mode.\n\n    Args:\n        Input_States: Optional tuple 
of (Angle_State, SSM_State, K_State, V_State)\n\n    Returns:\n        out: Output tensor (batch, seqlen, nheads, headdim_v)\n        Final_States: Tuple of (Angle_State, SSM_State, K_State, V_State)\n    \"\"\"\n    batch, seqlen, mimo_rank, nheads_qk, headdim_qk = Q.shape\n    _, _, nheads, headdim_v = V.shape\n    headdim_angles = Angles.shape[-1]\n    device = Q.device\n    assert seqlen > 0\n\n    # Expand Q/K for GQA\n    if Q.shape[3] != V.shape[2]:\n        Q = repeat(Q, \"b s r h_bc d -> b s r (h_bc g) d\", g=V.shape[2] // Q.shape[3])\n    if K.shape[3] != V.shape[2]:\n        K = repeat(K, \"b s r h_bc d -> b s r (h_bc g) d\", g=V.shape[2] // K.shape[3])\n\n    def apply_rotary_emb(tensor, cos, sin):\n        tensor_reshaped = tensor.view(*tensor.shape[:-1], -1, 2)\n        tensor_0 = tensor_reshaped[..., 0]\n        tensor_1 = tensor_reshaped[..., 1]\n        if cos.shape[-1] < tensor_0.shape[-1]:\n            pad_size = tensor_0.shape[-1] - cos.shape[-1]\n            cos = F.pad(cos, (0, pad_size), value=1.0)\n            sin = F.pad(sin, (0, pad_size), value=0.0)\n        rotated_0 = tensor_0 * cos - tensor_1 * sin\n        rotated_1 = tensor_0 * sin + tensor_1 * cos\n        rotated = torch.stack([rotated_0, rotated_1], dim=-1).view_as(tensor)\n        return rotated\n\n    q_bias = rearrange(Q_bias, \"h r d -> r h d\")\n    k_bias = rearrange(K_bias, \"h r d -> r h d\")\n\n    # Initialize states\n    if Input_States is not None:\n        Angle_State, SSM_State, K_State, V_State = Input_States\n        Angle_State = Angle_State.clone()\n        SSM_State = SSM_State.clone().to(torch.float32)\n        K_State = K_State.clone()\n        V_State = V_State.clone()\n    else:\n        Angle_State = torch.zeros((batch, nheads, headdim_angles), dtype=torch.float32, device=device)\n        SSM_State = torch.zeros((batch, nheads, headdim_v, headdim_qk), dtype=torch.float32, device=device)\n        K_State = torch.zeros((batch, nheads, mimo_rank, headdim_qk), 
dtype=Q.dtype, device=device)\n        V_State = torch.zeros((batch, nheads, mimo_rank, headdim_v), dtype=V.dtype, device=device)\n\n    # MIMO up project x and z:\n    v_proj = torch.einsum(\"bthd,hrd->btrhd\", V, MIMO_V)\n    if Z is not None:\n        z_proj = torch.einsum(\"bthd,hrd->btrhd\", Z, MIMO_Z)\n    else:\n        z_proj = None\n\n    TWO_PI = 2 * math.pi\n    out_arr = []\n\n    # Main SSM recurrence\n    for idx in range(seqlen):\n        q = Q[:, idx, :, :, :] + q_bias.unsqueeze(0)\n        k = K[:, idx, :, :, :] + k_bias.unsqueeze(0)\n        v = v_proj[:, idx, :, :, :] # (B R H P)\n        adt = ADT[:, :, idx]\n        dt = DT[:, :, idx]\n        trap = torch.nn.functional.sigmoid(Trap[:, :, idx])\n        z = z_proj[:, idx, :, :, :] if z_proj is not None else None\n        angles = Angles[:, idx, :, :] # (B H N)\n\n        q = q.permute(0, 2, 1, 3) # (B H R N)\n        k = k.permute(0, 2, 1, 3)\n        v = v.permute(0, 2, 1, 3)\n        if z is not None:\n            z = z.permute(0, 2, 1, 3)\n\n        # Update angle state with cumsum: Angle_State = (Angle_State + Angles * DT) mod 2π\n        # Angle_State = Angle_State + angles * dt.unsqueeze(-1)\n        # Angle_State = Angle_State - TWO_PI * torch.floor(Angle_State / TWO_PI)\n        Angle_State = Angle_State + torch.tanh(angles) * dt.unsqueeze(-1) * math.pi\n\n\n        # Apply rotary embeddings to Q and K using cumulative angles\n        cos_angles = torch.cos(Angle_State).unsqueeze(2) # (B H 1 N)\n        sin_angles = torch.sin(Angle_State).unsqueeze(2)\n        q_rot = apply_rotary_emb(q, cos_angles, sin_angles)\n        k_rot = apply_rotary_emb(k, cos_angles, sin_angles)\n\n        alpha = torch.exp(adt)\n        beta = (1 - trap) * dt * alpha\n        gamma = trap * dt\n\n        # Update SSM state using previous K_State and V_State\n        prev_kv = torch.einsum(\"bhrd,bhrp->bhpd\", K_State, V_State)\n        curr_kv = torch.einsum(\"bhrd,bhrp->bhpd\", k_rot, v)\n        SSM_State = 
alpha.unsqueeze(-1).unsqueeze(-1) * SSM_State\n        SSM_State = SSM_State + beta.unsqueeze(-1).unsqueeze(-1) * prev_kv\n        SSM_State = SSM_State + gamma.unsqueeze(-1).unsqueeze(-1) * curr_kv\n\n        # Compute output\n        out = torch.einsum(\"bhpd,bhrd->bhrp\", SSM_State, q_rot.to(SSM_State.dtype))\n\n        if D is not None:\n            out = out + D[None, :, None, None] * v\n\n        if z is not None:\n            out = out * z * torch.sigmoid(z)\n\n        out = torch.einsum(\"bhrp,hrp->bhp\", out, MIMO_O)\n        out_arr.append(out)\n\n        # Update K and V states for next step\n        K_State = k_rot\n        V_State = v\n\n    out = torch.stack(out_arr, dim=1)\n    Final_States = (Angle_State, SSM_State, K_State, V_State)\n    return out, Final_States\n\n\ndef apply_angle_dt_reference(\n    angle: Tensor,  # (batch, seqlen, nheads, dim)\n    dt: Tensor,     # (batch, seqlen, nheads)\n) -> Tensor:\n    # Match debug_mimo_step.py preprocessing for chunk reference path.\n    base_vals = angle.to(torch.float32)\n    base_vals = torch.tanh(base_vals) * dt[..., None].to(torch.float32) * torch.pi\n    return torch.cumsum(base_vals, dim=1)\n\n\ndef mamba3_MIMO_chunk_ref(\n    q: Tensor,\n    k: Tensor,\n    v: Tensor,\n    q_bias: Tensor,\n    k_bias: Tensor,\n    mimo_v: Tensor,\n    mimo_o: Optional[Tensor],\n    z: Optional[Tensor],\n    mimo_z: Optional[Tensor],\n    angles: Tensor,\n    dA_cs: Tensor,\n    dA_cs_rev: Tensor,\n    dt: Tensor,\n    trap: Tensor,\n    D: Optional[Tensor],\n    chunk_size: int = 64,\n    rotary_dim_divisor: int = 4,\n    return_final_state: bool = False,\n    dtype: torch.dtype = torch.float32,\n    rotate_pairwise: bool = False,\n    contract_mimo_out: bool = True,\n) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:\n    # Local copy of the reference program so tests remain valid even if module-level\n    # debug/reference helpers are removed from shipped kernels.\n    from einops import rearrange, 
repeat\n\n    nchunks = q.shape[1] // chunk_size\n    q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)\n    if z is not None:\n        z = z.to(dtype)\n        mimo_z = mimo_z.to(dtype)\n    if D is not None:\n        D = D.to(dtype)\n    q_bias, k_bias = q_bias.to(dtype), k_bias.to(dtype)\n    mimo_v = mimo_v.to(dtype)\n    if contract_mimo_out:\n        assert mimo_o is not None\n        mimo_o = mimo_o.to(dtype)\n    if dA_cs is not None:\n        dA_cs, dA_cs_rev = dA_cs.to(dtype), dA_cs_rev.to(dtype)\n        dA_cs = rearrange(dA_cs, \"b h (n c) -> b h n c\", c=chunk_size)\n        dA_cs_rev = rearrange(dA_cs_rev, \"b h (n c) -> b h n c\", c=chunk_size)\n\n    batch, seqlen, mimo_rank, nheads_qk, dstate = q.shape\n    nheads = v.shape[-2]\n    if nheads_qk != nheads:\n        q = repeat(q, \"b s r h_qk d -> b s r (h_qk g) d\", g=nheads // nheads_qk)\n        k = repeat(k, \"b s r h_qk d -> b s r (h_qk g) d\", g=nheads // nheads_qk)\n\n    angles = angles.to(dtype) if angles is not None else None\n    trap = trap.to(dtype) if trap is not None else None\n    dt = dt.to(dtype) if dt is not None else None\n\n    q_bias = rearrange(q_bias, \"h r d -> r h d\")\n    k_bias = rearrange(k_bias, \"h r d -> r h d\")\n    q = q + q_bias[None, None, :, :, :]\n    k = k + k_bias[None, None, :, :, :]\n\n    qk_dot = torch.einsum(\"bsRhd,bsrhd->bsRrh\", q, k)\n\n    if angles is not None:\n        angles = angles.unsqueeze(2)\n        cos_angles = torch.cos(angles)\n        sin_angles = torch.sin(angles)\n\n        def apply_rotary_emb(tensor: Tensor, cos: Tensor, sin: Tensor) -> Tensor:\n            if rotate_pairwise:\n                # Pairwise convention used by mamba3_MIMO_step_ref / debug_mimo_step.py.\n                tensor_reshaped = tensor.view(*tensor.shape[:-1], -1, 2)\n                tensor_0 = tensor_reshaped[..., 0]\n                tensor_1 = tensor_reshaped[..., 1]\n                rotated_0 = tensor_0 * cos - tensor_1 * sin\n                rotated_1 = 
tensor_0 * sin + tensor_1 * cos\n                return torch.stack([rotated_0, rotated_1], dim=-1).view_as(tensor)\n            # Kernel-aligned convention (kept as default for existing tests).\n            tensor_reshaped = tensor.view(*tensor.shape[:-1], 2, -1)\n            tensor_0 = tensor_reshaped[..., 0, :]\n            tensor_1 = tensor_reshaped[..., 1, :]\n            rotated_0 = tensor_0 * cos - tensor_1 * sin\n            rotated_1 = tensor_0 * sin + tensor_1 * cos\n            return torch.stack([rotated_0, rotated_1], dim=-2).view_as(tensor)\n\n        def apply_rotary_emb_rotate_half(tensor: Tensor, cos: Tensor, sin: Tensor) -> Tensor:\n            tensor_reshaped = tensor.view(*tensor.shape[:-1], 4, -1)\n            tensor_0 = tensor_reshaped[..., 0, :]\n            tensor_1 = tensor_reshaped[..., 2, :]\n            rotated_0 = tensor_0 * cos - tensor_1 * sin\n            rotated_1 = tensor_0 * sin + tensor_1 * cos\n            return torch.stack(\n                [\n                    rotated_0,\n                    tensor_reshaped[..., 1, :],\n                    rotated_1,\n                    tensor_reshaped[..., 3, :],\n                ],\n                dim=-2,\n            ).view_as(tensor)\n\n        if rotary_dim_divisor == 4:\n            q = apply_rotary_emb_rotate_half(q, cos_angles, sin_angles)\n            k = apply_rotary_emb_rotate_half(k, cos_angles, sin_angles)\n        elif rotary_dim_divisor == 2:\n            q = apply_rotary_emb(q, cos_angles, sin_angles)\n            k = apply_rotary_emb(k, cos_angles, sin_angles)\n        else:\n            raise ValueError(f\"Invalid rotary_dim_divisor: {rotary_dim_divisor}\")\n\n    if return_final_state:\n        final_k = k[:, -1].contiguous().clone()\n    else:\n        final_k = None\n\n    trap = torch.nn.functional.sigmoid(trap)\n    gamma = dt * trap\n    dt_shifted = torch.nn.functional.pad(dt[:, :, 1:], (0, 1), value=0.0)\n    trap_shifted = torch.nn.functional.pad(trap[:, :, 1:], 
(0, 1), value=0.0)\n    shifted_gamma = dt_shifted * (1 - trap_shifted)\n    factor = gamma + shifted_gamma\n    k = torch.einsum(\"bsrhn,bhs->bsrhn\", k, factor)\n    qk_dot = torch.einsum(\"bsrRh,bhs->bsrRh\", qk_dot, shifted_gamma)\n\n    v = torch.einsum(\"bthd,hrd->btrhd\", v, mimo_v)\n\n    def segsum_unstable(x: Tensor) -> Tensor:\n        x_segsum = x[..., :, None] - x[..., None, :]\n        mask = torch.tril(torch.ones(x.size(-1), x.size(-1), device=x.device, dtype=torch.bool), diagonal=0)\n        return x_segsum.masked_fill(~mask, -torch.inf)\n\n    mimo_mask_outer = segsum_unstable(dA_cs)\n    mimo_mask_inner = torch.ones(mimo_rank, mimo_rank, dtype=torch.bool, device=q.device)\n    mimo_mask = torch.kron(mimo_mask_outer, mimo_mask_inner[None, None, None, :, :])\n\n    q = rearrange(q, \"b (n c) r h d -> b h n (c r) d\", c=chunk_size)\n    k_scaled = rearrange(k, \"b (n c) r h d -> b h n c r d\", c=chunk_size)\n    k_scaled = torch.einsum(\"bhncrd,bhnc->bhncrd\", k_scaled, torch.exp(dA_cs_rev))\n    k_scaled = rearrange(k_scaled, \"b h n c r d -> b h n (c r) d\", c=chunk_size)\n    k = rearrange(k, \"b (n c) r h d -> b h n (c r) d\", c=chunk_size)\n    v = rearrange(v, \"b (n c) r h d -> b h n (c r) d\", c=chunk_size)\n    kv = k_scaled.transpose(-1, -2) @ v\n\n    curr_state = torch.zeros_like(kv[:, :, 0, :, :])\n    for n in range(nchunks):\n        curr_dA_sum = dA_cs[:, :, n, -1]\n        next_state = (torch.exp(curr_dA_sum[:, :, None, None]) * curr_state) + kv[:, :, n, :, :]\n        kv[:, :, n, :, :] = curr_state\n        curr_state = next_state\n\n    if return_final_state:\n        final_state = next_state.float()\n    else:\n        final_state = None\n\n    q_inter = q * torch.exp(repeat(dA_cs, \"b h n c -> b h n (c r)\", r=mimo_rank).unsqueeze(-1))\n    inter = q_inter @ kv\n    intra = ((q @ k.transpose(-1, -2)) * torch.exp(mimo_mask)) @ v\n    o = inter + intra\n    o = rearrange(o, \"b h n (c r) d -> b h n c r d\", r=mimo_rank)\n\n    v = 
rearrange(v, \"b h n (c r) d -> b h (n c) r d\", r=mimo_rank)\n    qk_dot = rearrange(qk_dot, \"b t R r h -> b h t R r\")\n    qkv = torch.einsum(\"bhtRr,bhtrp->bhtRp\", qk_dot, v)\n    qkv = rearrange(qkv, \"b h (n c) r d -> b h n c r d\", c=chunk_size)\n    o -= qkv\n\n    if D is not None:\n        vd = torch.einsum(\"bhtrp,h->bhtrp\", v, D)\n        vd = rearrange(vd, \"b h (n c) r d -> b h n c r d\", c=chunk_size)\n        o += vd\n\n    if z is not None:\n        z = torch.einsum(\"bthd,hrd->btrhd\", z, mimo_z)\n        z = rearrange(z, \"b (n c) r h d -> b h n c r d\", c=chunk_size)\n        o = o * torch.nn.functional.silu(z)\n\n    if contract_mimo_out:\n        assert mimo_o is not None\n        o = torch.einsum(\"bhncrd,hrd->bhncd\", o, mimo_o)\n        return rearrange(o, \"b h n c d -> b (n c) h d\"), final_state, final_k\n\n    return rearrange(o, \"b h n c r d -> b (n c) r h d\"), final_state, final_k\n\n\ndef run_ref_backward_fp32(\n    mods: SimpleNamespace,\n    inputs: dict,\n    *,\n    contract_mimo_out: bool = True,\n    grad_output: Optional[Tensor] = None,\n) -> dict:\n    ref_dtype = torch.float32\n    q = inputs[\"q\"].detach().to(ref_dtype).requires_grad_(True)\n    k = inputs[\"k\"].detach().to(ref_dtype).requires_grad_(True)\n    v = inputs[\"v\"].detach().to(ref_dtype).requires_grad_(True)\n    q_bias = inputs[\"q_bias\"].detach().to(ref_dtype).requires_grad_(True)\n    k_bias = inputs[\"k_bias\"].detach().to(ref_dtype).requires_grad_(True)\n    mimo_v = inputs[\"mimo_v\"].detach().to(ref_dtype).requires_grad_(True)\n    mimo_o = (\n        inputs[\"mimo_o\"].detach().to(ref_dtype).requires_grad_(True)\n        if contract_mimo_out\n        else None\n    )\n    z = (\n        inputs[\"z\"].detach().to(ref_dtype).requires_grad_(True)\n        if inputs[\"z\"] is not None\n        else None\n    )\n    mimo_z = (\n        inputs[\"mimo_z\"].detach().to(ref_dtype).requires_grad_(True)\n        if inputs[\"mimo_z\"] is not None\n        
else None\n    )\n    angles = inputs[\"angles\"].detach().to(ref_dtype).requires_grad_(True)\n    dt = inputs[\"dt\"].detach().to(ref_dtype).requires_grad_(True)\n    trap = inputs[\"trap\"].detach().to(ref_dtype).requires_grad_(True)\n    d = inputs[\"D\"].detach().to(ref_dtype).requires_grad_(True)\n\n    dA_cs_base, dA_cs_rev_base, _ = mods.utils.compute_dacs_segsum_triton(\n        inputs[\"dA\"].detach().to(torch.float32), inputs[\"chunk_size\"]\n    )\n    dA_cs = dA_cs_base.detach().to(ref_dtype).requires_grad_(True)\n    dA_cs_rev = dA_cs_rev_base.detach().to(ref_dtype).requires_grad_(True)\n\n    out, _, _ = mamba3_MIMO_chunk_ref(\n        q,\n        k,\n        v,\n        q_bias,\n        k_bias,\n        mimo_v,\n        mimo_o,\n        z,\n        mimo_z,\n        angles,\n        dA_cs,\n        dA_cs_rev,\n        dt,\n        trap,\n        d,\n        chunk_size=inputs[\"chunk_size\"],\n        rotary_dim_divisor=inputs[\"rotary_dim_divisor\"],\n        dtype=ref_dtype,\n        contract_mimo_out=contract_mimo_out,\n    )\n\n    grad_input_items = [\n        (\"q\", q),\n        (\"k\", k),\n        (\"v\", v),\n        (\"q_bias\", q_bias),\n        (\"k_bias\", k_bias),\n        (\"mimo_v\", mimo_v),\n        (\"angles\", angles),\n        (\"dA_cs\", dA_cs),\n        (\"dA_cs_rev\", dA_cs_rev),\n        (\"dt\", dt),\n        (\"trap\", trap),\n        (\"dD\", d),\n    ]\n    if z is not None:\n        grad_input_items.append((\"z\", z))\n    if mimo_z is not None:\n        grad_input_items.append((\"mimo_z\", mimo_z))\n    if contract_mimo_out:\n        grad_input_items.append((\"mimo_o\", mimo_o))\n\n    if grad_output is None:\n        grad_output = inputs[\"dout\"]\n    grads = torch.autograd.grad(\n        outputs=out,\n        inputs=tuple(t for _, t in grad_input_items),\n        grad_outputs=grad_output.detach().to(ref_dtype),\n        retain_graph=False,\n        allow_unused=True,\n    )\n    grad_map = {name: grad for (name, _), 
grad in zip(grad_input_items, grads)}\n    grad_map[\"dA\"] = grads_to_dA(grad_map[\"dA_cs\"], grad_map[\"dA_cs_rev\"], inputs[\"chunk_size\"])\n\n    return {\n        \"dq\": grad_map[\"q\"],\n        \"dk\": grad_map[\"k\"],\n        \"dv\": grad_map[\"v\"],\n        \"dA\": grad_map[\"dA\"],\n        \"ddt\": grad_map[\"dt\"],\n        \"dtrap\": grad_map[\"trap\"],\n        \"dq_bias\": grad_map[\"q_bias\"],\n        \"dk_bias\": grad_map[\"k_bias\"],\n        \"dmimo_v\": grad_map[\"mimo_v\"],\n        \"dmimo_z\": grad_map.get(\"mimo_z\"),\n        \"dmimo_o\": grad_map.get(\"mimo_o\"),\n        \"dangles\": grad_map[\"angles\"],\n        \"dD\": grad_map[\"dD\"],\n        \"dz\": grad_map.get(\"z\"),\n    }\n\n\ndef test_mamba3_MIMO_chunk_ref_matches_step_ref() -> None:\n    # Lightweight deterministic ref-vs-ref consistency test.\n    B, S, H, G, P, N, R, chunk_size = 1, 128, 8, 1, 32, 64, 4, 16\n    dtype = torch.float32\n    device = \"cpu\"\n    torch.manual_seed(0)\n\n    q = torch.randn((B, S, R, G, N), device=device, dtype=dtype)\n    k = torch.randn((B, S, R, G, N), device=device, dtype=dtype)\n    v = torch.randn((B, S, H, P), device=device, dtype=dtype)\n\n    q_bias = torch.randn((H, R, N), device=device, dtype=dtype)\n    k_bias = torch.randn((H, R, N), device=device, dtype=dtype)\n    mimo_v = torch.rand((H, R, P), device=device, dtype=dtype)\n    mimo_o = torch.rand((H, R, P), device=device, dtype=dtype)\n\n    z = torch.randn_like(v)\n    mimo_z = torch.rand_like(mimo_v)\n    D = torch.randn((H,), device=device, dtype=dtype)\n\n    angles = torch.rand((B, S, H, N // 2), device=device, dtype=dtype)\n    dt = F.softplus(-3.0 + torch.randn(B, H, S, device=device, dtype=torch.float32))\n    A_neg = -F.softplus(torch.randn((B, H, S), device=device, dtype=torch.float32))\n    A_neg = torch.clamp(A_neg, max=-1e-4)\n    ADT = A_neg * dt\n    trap = torch.rand(B, H, S, device=device, dtype=dtype) * 0.5\n\n    dA_cs = torch.cumsum(rearrange(ADT, \"b h 
(n c) -> b h n c\", c=chunk_size), dim=-1)\n    dA_cs_rev = dA_cs[..., -1:] - dA_cs\n    angles_prerotated = apply_angle_dt_reference(angles, dt.permute(0, 2, 1))\n\n    chunk_out, _, _ = mamba3_MIMO_chunk_ref(\n        q,\n        k,\n        v,\n        q_bias,\n        k_bias,\n        mimo_v,\n        mimo_o,\n        z,\n        mimo_z,\n        angles_prerotated,\n        dA_cs.view(B, H, S),\n        dA_cs_rev.view(B, H, S),\n        dt,\n        trap,\n        D,\n        chunk_size=chunk_size,\n        rotary_dim_divisor=2,\n        return_final_state=True,\n        dtype=dtype,\n        rotate_pairwise=True,\n    )\n\n    step_out, _ = mamba3_MIMO_step_ref(\n        q,\n        k,\n        v,\n        ADT,\n        dt,\n        trap,\n        q_bias,\n        k_bias,\n        angles,\n        mimo_v,\n        mimo_o,\n        D=D,\n        Z=z,\n        MIMO_Z=mimo_z,\n    )\n\n    assert chunk_out.shape == step_out.shape\n    assert_stable_rel(\n        chunk_out,\n        step_out,\n        label=\"chunk_ref_vs_step_ref\",\n        cfg=f\"B={B}, S={S}, H={H}, P={P}, N={N}, R={R}, C={chunk_size}\",\n        rel_tol=0.02,\n    )\n\n\n@pytest.mark.parametrize(\"n,p,r,chunk_size,bb_threads\", CASE_GRID)\ndef test_fused_chunk_linear_attn_fwd_relative_error_lt_10pct(\n    mods: SimpleNamespace, n: int, p: int, r: int, chunk_size: int, bb_threads: int\n) -> None:\n    del bb_threads\n    inputs = build_inputs(\n        mods=mods,\n        n=n,\n        p=p,\n        r=r,\n        chunk_size=chunk_size,\n        seed=1234 + n + p + r + chunk_size,\n    )\n\n    out_tilelang, _, _ = mods.fwd.mamba_mimo_forward(\n        inputs[\"q\"],\n        inputs[\"k\"],\n        inputs[\"v\"],\n        inputs[\"q_bias\"],\n        inputs[\"k_bias\"],\n        inputs[\"mimo_v\"],\n        inputs[\"mimo_o\"],\n        inputs[\"z\"],\n        inputs[\"D\"],\n        inputs[\"mimo_z\"],\n        inputs[\"angles\"],\n        inputs[\"dA_cs\"],\n        inputs[\"dA_cs_rev\"],\n   
     inputs[\"dt\"],\n        inputs[\"trap\"],\n        inputs[\"segsum\"],\n        chunk_size=chunk_size,\n        rotary_dim_divisor=inputs[\"rotary_dim_divisor\"],\n        dtype=FIXED_DTYPE,\n    )\n\n    out_ref_fp32, _, _ = mamba3_MIMO_chunk_ref(\n        inputs[\"q\"].clone(),\n        inputs[\"k\"].clone(),\n        inputs[\"v\"].clone(),\n        inputs[\"q_bias\"].clone(),\n        inputs[\"k_bias\"].clone(),\n        inputs[\"mimo_v\"].clone(),\n        inputs[\"mimo_o\"].clone(),\n        inputs[\"z\"].clone(),\n        inputs[\"mimo_z\"].clone(),\n        inputs[\"angles\"].clone(),\n        inputs[\"dA_cs\"].clone(),\n        inputs[\"dA_cs_rev\"].clone(),\n        inputs[\"dt\"].clone(),\n        inputs[\"trap\"].clone(),\n        inputs[\"D\"].clone(),\n        chunk_size=chunk_size,\n        rotary_dim_divisor=inputs[\"rotary_dim_divisor\"],\n        dtype=torch.float32,\n    )\n\n    assert_stable_rel(\n        out_tilelang,\n        out_ref_fp32,\n        label=\"forward\",\n        cfg=f\"N={n}, P={p}, R={r}, chunk={chunk_size}\",\n    )\n\n\n@pytest.mark.parametrize(\"n,p,r,chunk_size,bb_threads\", CASE_GRID)\ndef test_fused_chunk_linear_attn_fwd_return_state_relative_error_lt_10pct(\n    mods: SimpleNamespace, n: int, p: int, r: int, chunk_size: int, bb_threads: int\n) -> None:\n    del bb_threads\n    inputs = build_inputs(\n        mods=mods,\n        n=n,\n        p=p,\n        r=r,\n        chunk_size=chunk_size,\n        seed=3456 + n + p + r + chunk_size,\n    )\n\n    out_tilelang, final_state_tilelang, final_k_tilelang = mods.fwd.mamba_mimo_forward(\n        inputs[\"q\"],\n        inputs[\"k\"],\n        inputs[\"v\"],\n        inputs[\"q_bias\"],\n        inputs[\"k_bias\"],\n        inputs[\"mimo_v\"],\n        inputs[\"mimo_o\"],\n        inputs[\"z\"],\n        inputs[\"D\"],\n        inputs[\"mimo_z\"],\n        inputs[\"angles\"],\n        inputs[\"dA_cs\"],\n        inputs[\"dA_cs_rev\"],\n        inputs[\"dt\"],\n        
inputs[\"trap\"],\n        inputs[\"segsum\"],\n        return_state=True,\n        chunk_size=chunk_size,\n        rotary_dim_divisor=inputs[\"rotary_dim_divisor\"],\n        dtype=FIXED_DTYPE,\n    )\n\n    out_ref_fp32, final_state_ref, final_k_ref = mamba3_MIMO_chunk_ref(\n        inputs[\"q\"].clone(),\n        inputs[\"k\"].clone(),\n        inputs[\"v\"].clone(),\n        inputs[\"q_bias\"].clone(),\n        inputs[\"k_bias\"].clone(),\n        inputs[\"mimo_v\"].clone(),\n        inputs[\"mimo_o\"].clone(),\n        inputs[\"z\"].clone(),\n        inputs[\"mimo_z\"].clone(),\n        inputs[\"angles\"].clone(),\n        inputs[\"dA_cs\"].clone(),\n        inputs[\"dA_cs_rev\"].clone(),\n        inputs[\"dt\"].clone(),\n        inputs[\"trap\"].clone(),\n        inputs[\"D\"].clone(),\n        chunk_size=chunk_size,\n        rotary_dim_divisor=inputs[\"rotary_dim_divisor\"],\n        return_final_state=True,\n        dtype=torch.float32,\n    )\n\n    assert_stable_rel(\n        out_tilelang,\n        out_ref_fp32,\n        label=\"forward_return_state_out\",\n        cfg=f\"N={n}, P={p}, R={r}, chunk={chunk_size}\",\n    )\n    assert_stable_rel(\n        final_state_tilelang,\n        final_state_ref,\n        label=\"forward_return_state_final_state\",\n        cfg=f\"N={n}, P={p}, R={r}, chunk={chunk_size}\",\n    )\n    assert_stable_rel(\n        final_k_tilelang,\n        final_k_ref,\n        label=\"forward_return_state_final_k\",\n        cfg=f\"N={n}, P={p}, R={r}, chunk={chunk_size}\",\n    )\n\n\n@pytest.mark.parametrize(\"n,p,r,chunk_size,bb_threads\", CASE_GRID)\ndef test_fused_chunk_linear_attn_fwd_prereduce_relative_error_lt_10pct(\n    mods: SimpleNamespace, n: int, p: int, r: int, chunk_size: int, bb_threads: int\n) -> None:\n    del bb_threads\n    inputs = build_inputs(\n        mods=mods,\n        n=n,\n        p=p,\n        r=r,\n        chunk_size=chunk_size,\n        seed=2345 + n + p + r + chunk_size,\n        has_z=False,\n    
)\n\n    out_tilelang, _, _ = mods.fwd.mamba_mimo_forward(\n        inputs[\"q\"],\n        inputs[\"k\"],\n        inputs[\"v\"],\n        inputs[\"q_bias\"],\n        inputs[\"k_bias\"],\n        inputs[\"mimo_v\"],\n        None,\n        inputs[\"z\"],\n        inputs[\"D\"],\n        inputs[\"mimo_z\"],\n        inputs[\"angles\"],\n        inputs[\"dA_cs\"],\n        inputs[\"dA_cs_rev\"],\n        inputs[\"dt\"],\n        inputs[\"trap\"],\n        inputs[\"segsum\"],\n        chunk_size=chunk_size,\n        rotary_dim_divisor=inputs[\"rotary_dim_divisor\"],\n        dtype=FIXED_DTYPE,\n    )\n\n    out_ref_fp32, _, _ = mamba3_MIMO_chunk_ref(\n        inputs[\"q\"].clone(),\n        inputs[\"k\"].clone(),\n        inputs[\"v\"].clone(),\n        inputs[\"q_bias\"].clone(),\n        inputs[\"k_bias\"].clone(),\n        inputs[\"mimo_v\"].clone(),\n        None,\n        None,\n        None,\n        inputs[\"angles\"].clone(),\n        inputs[\"dA_cs\"].clone(),\n        inputs[\"dA_cs_rev\"].clone(),\n        inputs[\"dt\"].clone(),\n        inputs[\"trap\"].clone(),\n        inputs[\"D\"].clone(),\n        chunk_size=chunk_size,\n        rotary_dim_divisor=inputs[\"rotary_dim_divisor\"],\n        dtype=torch.float32,\n        contract_mimo_out=False,\n    )\n\n    assert_stable_rel(\n        out_tilelang,\n        out_ref_fp32,\n        label=\"forward_prereduce\",\n        cfg=f\"N={n}, P={p}, R={r}, chunk={chunk_size}\",\n    )\n\n\n@pytest.mark.parametrize(\"n,p,r,chunk_size,bb_threads\", CASE_GRID)\ndef test_mamba_mimo_bwd_combined_relative_errors_lt_10pct(\n    mods: SimpleNamespace, n: int, p: int, r: int, chunk_size: int, bb_threads: int\n) -> None:\n    inputs = build_inputs(\n        mods=mods,\n        n=n,\n        p=p,\n        r=r,\n        chunk_size=chunk_size,\n        seed=5678 + n + p + r + chunk_size,\n    )\n\n    ref_grads = run_ref_backward_fp32(mods, inputs)\n\n    (\n        dq,\n        dk,\n        dv,\n        dA,\n        ddt,\n  
      dtrap,\n        dq_bias,\n        dk_bias,\n        dmimo_v,\n        dmimo_z,\n        dmimo_o,\n        dangles,\n        dD,\n        dz,\n    ) = mods.bwd.mamba_mimo_bwd_combined(\n        inputs[\"dout\"],\n        inputs[\"q\"],\n        inputs[\"k\"],\n        inputs[\"v\"],\n        inputs[\"q_bias\"],\n        inputs[\"k_bias\"],\n        inputs[\"mimo_v\"],\n        inputs[\"mimo_o\"],\n        inputs[\"z\"],\n        inputs[\"mimo_z\"],\n        inputs[\"angles\"],\n        inputs[\"dA_cs\"],\n        inputs[\"dA_cs_rev\"],\n        inputs[\"dt\"],\n        inputs[\"trap\"],\n        inputs[\"D\"],\n        inputs[\"segsum\"],\n        chunk_size,\n        inputs[\"rotary_dim_divisor\"],\n        FIXED_DTYPE,\n        bb_threads=bb_threads,\n    )\n\n    comparisons = {\n        \"dq\": (dq, ref_grads[\"dq\"]),\n        \"dk\": (dk, ref_grads[\"dk\"]),\n        \"dv\": (dv, ref_grads[\"dv\"]),\n        \"dA\": (dA, ref_grads[\"dA\"]),\n        \"ddt\": (ddt, ref_grads[\"ddt\"]),\n        \"dtrap\": (dtrap, ref_grads[\"dtrap\"]),\n        \"dq_bias\": (dq_bias, ref_grads[\"dq_bias\"]),\n        \"dk_bias\": (dk_bias, ref_grads[\"dk_bias\"]),\n        \"dmimo_v\": (dmimo_v, ref_grads[\"dmimo_v\"]),\n        \"dmimo_z\": (dmimo_z, ref_grads[\"dmimo_z\"]),\n        \"dmimo_o\": (dmimo_o, ref_grads[\"dmimo_o\"]),\n        \"dangles\": (dangles, ref_grads[\"dangles\"]),\n        \"dD\": (dD, ref_grads[\"dD\"]),\n        \"dz\": (dz, ref_grads[\"dz\"]),\n    }\n\n    for name, (ours, ref) in comparisons.items():\n        assert_stable_rel(\n            ours,\n            ref,\n            label=name,\n            cfg=f\"N={n}, P={p}, R={r}, chunk={chunk_size}, bb_threads={bb_threads}\",\n        )\n\n\n@pytest.mark.parametrize(\"n,p,r,chunk_size,bb_threads\", CASE_GRID)\ndef test_mamba_mimo_bwd_combined_prereduce_relative_errors_lt_10pct(\n    mods: SimpleNamespace, n: int, p: int, r: int, chunk_size: int, bb_threads: int\n) -> None:\n    inputs = 
build_inputs(\n        mods=mods,\n        n=n,\n        p=p,\n        r=r,\n        chunk_size=chunk_size,\n        seed=6789 + n + p + r + chunk_size,\n        has_z=False,\n    )\n    b, s, h, p_dim = inputs[\"v\"].shape\n    dout_prereduce = torch.randn((b, s, r, h, p_dim), device=\"cuda\", dtype=FIXED_DTYPE)\n\n    ref_grads = run_ref_backward_fp32(\n        mods,\n        inputs,\n        contract_mimo_out=False,\n        grad_output=dout_prereduce,\n    )\n\n    (\n        dq,\n        dk,\n        dv,\n        dA,\n        ddt,\n        dtrap,\n        dq_bias,\n        dk_bias,\n        dmimo_v,\n        dmimo_z,\n        dmimo_o,\n        dangles,\n        dD,\n        dz,\n    ) = mods.bwd.mamba_mimo_bwd_combined(\n        dout_prereduce,\n        inputs[\"q\"],\n        inputs[\"k\"],\n        inputs[\"v\"],\n        inputs[\"q_bias\"],\n        inputs[\"k_bias\"],\n        inputs[\"mimo_v\"],\n        None,\n        None,\n        None,\n        inputs[\"angles\"],\n        inputs[\"dA_cs\"],\n        inputs[\"dA_cs_rev\"],\n        inputs[\"dt\"],\n        inputs[\"trap\"],\n        inputs[\"D\"],\n        inputs[\"segsum\"],\n        chunk_size,\n        inputs[\"rotary_dim_divisor\"],\n        FIXED_DTYPE,\n        bb_threads=bb_threads,\n    )\n    assert dmimo_o is None\n    assert dmimo_z is None\n    assert dz is None\n\n    comparisons = {\n        \"dq_prereduce\": (dq, ref_grads[\"dq\"]),\n        \"dk_prereduce\": (dk, ref_grads[\"dk\"]),\n        \"dv_prereduce\": (dv, ref_grads[\"dv\"]),\n        \"dA_prereduce\": (dA, ref_grads[\"dA\"]),\n        \"ddt_prereduce\": (ddt, ref_grads[\"ddt\"]),\n        \"dtrap_prereduce\": (dtrap, ref_grads[\"dtrap\"]),\n        \"dq_bias_prereduce\": (dq_bias, ref_grads[\"dq_bias\"]),\n        \"dk_bias_prereduce\": (dk_bias, ref_grads[\"dk_bias\"]),\n        \"dmimo_v_prereduce\": (dmimo_v, ref_grads[\"dmimo_v\"]),\n        \"dangles_prereduce\": (dangles, ref_grads[\"dangles\"]),\n        
\"dD_prereduce\": (dD, ref_grads[\"dD\"]),\n    }\n\n    for name, (ours, ref) in comparisons.items():\n        assert_stable_rel(\n            ours,\n            ref,\n            label=name,\n            cfg=f\"N={n}, P={p}, R={r}, chunk={chunk_size}, bb_threads={bb_threads}\",\n        )\n\n\ndef test_mamba_mimo_smoke_forward_backward(mods: SimpleNamespace) -> None:\n    inputs = make_smoke_inputs(\n        batch=FIXED_B,\n        seqlen=FIXED_S,\n        mimo_rank=4,\n        nheads_qk=FIXED_G,\n        nheads=FIXED_H,\n        headdim_qk=128,\n        headdim_v=64,\n        chunk_size=16,\n        rotary_dim_divisor=FIXED_ROTARY_DIM_DIVISOR,\n        device=\"cuda\",\n        dtype=FIXED_DTYPE,\n        seed=999,\n    )\n\n    out = mods.top.mamba3_mimo(**inputs)\n    assert out.shape == (FIXED_B, FIXED_S, FIXED_H, 64)\n\n    loss = out.float().sum()\n    loss.backward()\n\n    grad_names = [\n        \"Q\",\n        \"K\",\n        \"V\",\n        \"ADT\",\n        \"DT\",\n        \"Trap\",\n        \"Q_bias\",\n        \"K_bias\",\n        \"MIMO_V\",\n        \"MIMO_Z\",\n        \"MIMO_Out\",\n        \"Angles\",\n        \"D\",\n        \"Z\",\n    ]\n    for name in grad_names:\n        grad = inputs[name].grad\n        assert grad is not None, f\"Missing gradient for {name}\"\n        assert torch.isfinite(grad).all(), f\"Non-finite gradient detected for {name}\"\n"
  },
  {
    "path": "tests/ops/triton/test_layernorm_gated.py",
    "content": "import math\n\nimport torch\nimport torch.nn.functional as F\n\nimport pytest\n\nfrom einops import rearrange, repeat\n\nfrom mamba_ssm.ops.triton.layernorm_gated import layernorm_fn, rms_norm_ref\n\n\n@pytest.mark.parametrize(\"norm_before_gate\", [True, False])\n# @pytest.mark.parametrize(\"norm_before_gate\", [False])\n@pytest.mark.parametrize(\"has_group\", [False, True])\n# @pytest.mark.parametrize(\"has_group\", [False])\n@pytest.mark.parametrize(\"is_rms_norm\", [False, True])\n# @pytest.mark.parametrize(\"is_rms_norm\", [True])\n@pytest.mark.parametrize(\"has_z\", [False, True])\n# @pytest.mark.parametrize(\"has_z\", [True])\n@pytest.mark.parametrize(\"has_bias\", [False, True])\n# @pytest.mark.parametrize(\"has_bias\", [False])\n# @pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16])\n@pytest.mark.parametrize('dtype', [torch.float16])\n# @pytest.mark.parametrize(\"wtype\", [torch.float32, torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"wtype\", [torch.float32])\n@pytest.mark.parametrize('d', [2048, 4096])\n# @pytest.mark.parametrize('d', [4096])\ndef test_layer_norm_gated(d, dtype, wtype, has_bias, has_z, is_rms_norm, has_group, norm_before_gate):\n    if not has_z and not norm_before_gate:\n        pytest.skip()\n    if not norm_before_gate and not is_rms_norm:  # Reference LN isn't implemented for this case yet\n        pytest.skip()\n    device = 'cuda'\n    rtol, atol = (1e-5, 1e-5) if dtype == torch.float32 else (1e-2, 8e-3)\n    group_size = None if not has_group else 64\n    # set seed\n    torch.random.manual_seed(0)\n    batch = 16\n    seqlen = 1024\n    x = torch.randn(batch, seqlen, d, dtype=dtype, device=device, requires_grad=True)\n    if has_z:\n        z = torch.randn(batch, seqlen, d, dtype=dtype, device=device, requires_grad=True)\n    else:\n        z = None\n    weight = torch.randn(d, dtype=wtype, device=device, requires_grad=True)\n    if has_bias:\n        bias = torch.randn(d, 
dtype=wtype, device=device, requires_grad=True)\n    else:\n        bias = None\n    x_ref = x.detach().clone().requires_grad_()\n    x_pt = x.detach().clone().requires_grad_()\n    z_ref = z.detach().clone().requires_grad_() if z is not None else None\n    z_pt = z.detach().clone().requires_grad_() if z is not None else None\n    weight_ref = weight.detach().clone().requires_grad_()\n    weight_pt = weight.detach().clone().requires_grad_()\n    bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None\n    bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None\n    out = layernorm_fn(x, weight, bias, z=z, eps=1e-5, group_size=group_size, norm_before_gate=norm_before_gate,\n                       is_rms_norm=is_rms_norm)\n    if not is_rms_norm:\n        if not has_group:\n            out_ref = F.layer_norm(x_ref.float(), (d,), weight=weight_ref.float(), bias=bias_ref.float() if bias_ref is not None else None, eps=1e-5)\n            out_pt = F.layer_norm(x_pt.to(wtype), (d,), weight=weight_pt, bias=bias_pt, eps=1e-5)\n        else:\n            out_ref = rearrange(F.layer_norm(rearrange(x_ref, \"... (g d) -> ... g d\", d=group_size).float(), (group_size,), eps=1e-5), \"... g d -> ... (g d)\") * weight_ref.float()\n            if has_bias:\n                out_ref = out_ref + bias_ref.float()\n            out_pt = rearrange(F.layer_norm(rearrange(x_pt, \"... (g d) -> ... g d\", d=group_size), (group_size,), eps=1e-5), \"... g d -> ... 
(g d)\") * weight_pt\n            if has_bias:\n                out_pt = out_pt + bias_pt\n        if has_z and norm_before_gate:\n            out_ref = out_ref * F.silu(z_ref.float())\n            out_pt = out_pt * F.silu(z_pt)\n    else:\n        out_ref = rms_norm_ref(x_ref, weight_ref, bias_ref, z=z_ref, eps=1e-5, group_size=group_size,\n                               norm_before_gate=norm_before_gate)\n        out_pt = rms_norm_ref(x_pt, weight_pt, bias_pt, z=z_pt, eps=1e-5, group_size=group_size,\n                              norm_before_gate=norm_before_gate, upcast=False)\n    print(f\"Max diff = {(out - out_ref).abs().max().item()}\")\n    print(f\"Max diff Pytorch = {(out_pt - out_ref).abs().max().item()}\")\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + atol\n\n    g = torch.randn_like(out)\n    out.backward(g)\n    out_ref.backward(g)\n    out_pt.backward(g)\n    print(f\"Max dx diff = {(x.grad - x_ref.grad).abs().max().item()}\")\n    print(f\"Max dx diff Pytorch = {(x_pt.grad - x_ref.grad).abs().max().item()}\")\n    if has_z:\n        print(f\"Max dz diff = {(z.grad - z_ref.grad).abs().max().item()}\")\n        print(f\"Max dz diff Pytorch = {(z_pt.grad - z_ref.grad).abs().max().item()}\")\n    print(f\"Max dw diff = {(weight.grad - weight_ref.grad).abs().max().item()}\")\n    print(f\"Max dw diff Pytorch = {(weight_pt.grad - weight_ref.grad).abs().max().item()}\")\n    if has_bias:\n        print(f\"Max db diff = {(bias.grad - bias_ref.grad).abs().max().item()}\")\n        print(f\"Max db diff Pytorch = {(bias_pt.grad - bias_ref.grad).abs().max().item()}\")\n    assert (x.grad - x_ref.grad).abs().max().item() <= 2 * (x_pt.grad - x_ref.grad).abs().max().item() + atol\n    if has_z:\n        assert (z.grad - z_ref.grad).abs().max().item() <= 2 * (z_pt.grad - z_ref.grad).abs().max().item() + atol\n    assert (weight.grad - weight_ref.grad).abs().max().item() <= 2 * (weight_pt.grad - 
weight_ref.grad).abs().max().item() + atol\n    if has_bias:\n        assert (bias.grad - bias_ref.grad).abs().max().item() <= 2 * (bias_pt.grad - bias_ref.grad).abs().max().item() + atol\n"
  },
  {
    "path": "tests/ops/triton/test_mamba3_siso.py",
    "content": "\"\"\"\nMamba-3 SISO Kernel Tests\n\nCopyright (c) 2025, Dao AI Lab, Goombalab\n\"\"\"\n\nimport copy\nimport math\nfrom typing import Optional, Tuple\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\nfrom mamba_ssm.ops.triton.mamba3.mamba3_siso_combined import mamba3_siso_combined\nfrom mamba_ssm.ops.triton.mamba3.mamba3_siso_step import mamba3_siso_step\n\n\n# Reference Implementations\ndef _segsum(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"Segment sum helper for attention computation.\"\"\"\n    T = x.size(-1)\n    x = repeat(x, \"... d -> ... d e\", e=T)\n    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)\n    x = x.masked_fill(~mask, 0)\n    x_segsum = torch.cumsum(x, dim=-2)\n    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)\n    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)\n    return x_segsum\n\n\ndef mamba3_siso_step_ref(\n    Q: torch.Tensor,\n    K: torch.Tensor,\n    V: torch.Tensor,\n    ADT: torch.Tensor,\n    DT: torch.Tensor,\n    Trap: torch.Tensor,\n    Q_bias: torch.Tensor,\n    K_bias: torch.Tensor,\n    Angles: torch.Tensor,\n    D: Optional[torch.Tensor] = None,\n    Z: Optional[torch.Tensor] = None,\n    Input_States: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None,\n) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:\n    \"\"\"Reference implementation of Mamba-3 in recurrent (step) mode.\n    \n    Args:\n        Input_States: Optional tuple of (Angle_State, SSM_State, K_State, V_State)\n    \n    Returns:\n        out: Output tensor (batch, seqlen, nheads, headdim_v)\n        Final_States: Tuple of (Angle_State, SSM_State, K_State, V_State)\n    \"\"\"\n    batch, seqlen, nheads_qk, headdim_qk = Q.shape\n    _, _, nheads, headdim_v = V.shape\n    headdim_angles = Angles.shape[-1]\n    device = Q.device\n    assert seqlen > 0\n    
Angles = torch.tanh(Angles) * math.pi\n\n    # Expand Q/K for GQA\n    if Q.shape[2] != V.shape[2]:\n        Q = repeat(Q, \"b s h_bc d -> b s (h_bc g) d\", g=V.shape[2] // Q.shape[2])\n    if K.shape[2] != V.shape[2]:\n        K = repeat(K, \"b s h_bc d -> b s (h_bc g) d\", g=V.shape[2] // K.shape[2])\n\n    def apply_rotary_emb(tensor, cos, sin):\n        tensor_reshaped = tensor.view(*tensor.shape[:-1], -1, 2)\n        tensor_0 = tensor_reshaped[..., 0]\n        tensor_1 = tensor_reshaped[..., 1]\n        if cos.shape[-1] < tensor_0.shape[-1]:\n            pad_size = tensor_0.shape[-1] - cos.shape[-1]\n            cos = F.pad(cos, (0, pad_size), value=1.0)\n            sin = F.pad(sin, (0, pad_size), value=0.0)\n        rotated_0 = tensor_0 * cos - tensor_1 * sin\n        rotated_1 = tensor_0 * sin + tensor_1 * cos\n        rotated = torch.stack([rotated_0, rotated_1], dim=-1).view_as(tensor)\n        return rotated\n    \n    # Initialize states\n    if Input_States is not None:\n        Angle_State, SSM_State, K_State, V_State = Input_States\n        Angle_State = Angle_State.clone()\n        SSM_State = SSM_State.clone().to(torch.float32)\n        K_State = K_State.clone()\n        V_State = V_State.clone()\n    else:\n        Angle_State = torch.zeros((batch, nheads, headdim_angles), dtype=torch.float32, device=device)\n        SSM_State = torch.zeros((batch, nheads, headdim_v, headdim_qk), dtype=torch.float32, device=device)\n        K_State = torch.zeros((batch, nheads, headdim_qk), dtype=Q.dtype, device=device)\n        V_State = torch.zeros((batch, nheads, headdim_v), dtype=V.dtype, device=device)\n    \n    TWO_PI = 2 * math.pi\n    out_arr = []\n\n    for idx in range(seqlen):\n        q = Q[:, idx, :, :] + Q_bias.unsqueeze(0)\n        k = K[:, idx, :, :] + K_bias.unsqueeze(0)\n        v = V[:, idx, :, :]\n        adt = ADT[:, :, idx]\n        dt = DT[:, :, idx]\n        trap = Trap[:, :, idx]\n        z = Z[:, idx, :, :] if Z is not None else None\n   
     angles = Angles[:, idx, :, :]\n\n        # Update angle state with cumsum: Angle_State = (Angle_State + Angles * DT) mod 2π\n        Angle_State = Angle_State + angles * dt.unsqueeze(-1)\n        Angle_State = Angle_State - TWO_PI * torch.floor(Angle_State / TWO_PI)\n\n        # Apply rotary embeddings to Q and K using cumulative angles\n        cos_angles = torch.cos(Angle_State)\n        sin_angles = torch.sin(Angle_State)\n        q_rot = apply_rotary_emb(q, cos_angles, sin_angles)\n        k_rot = apply_rotary_emb(k, cos_angles, sin_angles)\n\n        trap = torch.sigmoid(trap)\n        alpha = torch.exp(adt)\n        beta = (1 - trap) * dt * alpha\n        gamma = trap * dt\n\n        # Update SSM state using previous K_State and V_State\n        SSM_State = alpha.unsqueeze(-1).unsqueeze(-1) * SSM_State \n        SSM_State = SSM_State + beta.unsqueeze(-1).unsqueeze(-1) * (K_State.unsqueeze(-2) * V_State.unsqueeze(-1))\n        SSM_State = SSM_State + gamma.unsqueeze(-1).unsqueeze(-1) * (k_rot.unsqueeze(-2) * v.unsqueeze(-1))\n\n        # Compute output\n        out = torch.einsum(\"bhdD, bhD -> bhd\", SSM_State, q_rot.to(SSM_State.dtype))\n        \n        if D is not None:\n            out = out + D[None, :, None] * v\n        \n        if Z is not None:\n            out = out * z * torch.sigmoid(z)\n        \n        out_arr.append(out)\n        \n        # Update K and V states for next step\n        K_State = k_rot\n        V_State = v\n    \n    out = torch.stack(out_arr, dim=1)\n    Final_States = (Angle_State, SSM_State, K_State, V_State)\n    return out, Final_States\n\n\ndef mamba3_siso_fwd_ref(\n    Q: torch.Tensor,\n    K: torch.Tensor,\n    V: torch.Tensor,\n    ADT: torch.Tensor,\n    DT: torch.Tensor,\n    Trap: torch.Tensor,\n    Q_bias: torch.Tensor,\n    K_bias: torch.Tensor,\n    Angles: torch.Tensor,\n    D: Optional[torch.Tensor] = None,\n    Z: Optional[torch.Tensor] = None,\n    Initial_States: Optional[Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]] = None,\n    chunk_size: int = 64,\n    dtype: torch.dtype = torch.float32,\n    cu_seqlens: Optional[torch.Tensor] = None,\n):\n    \"\"\"Reference implementation of Mamba-3 forward pass.\n    \n    Args:\n        Initial_States: Optional tuple of (Angle_State, SSM_State, K_State, V_State)\n    \n    Returns:\n        out_z: Output with Z gating applied\n        final_states: (Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State)\n    \"\"\"\n    batch, total_seqlen, nheads_qk, headdim_qk = Q.shape\n    _, _, nheads, headdim_v = V.shape\n    headdim_angles = Angles.shape[-1]\n    device = Q.device\n    \n    is_varlen = cu_seqlens is not None\n    if is_varlen:\n        assert batch == 1\n    \n    # Cast inputs\n    Q = Q.to(dtype)\n    K = K.to(dtype)\n    V = V.to(dtype)\n    ADT = ADT.to(torch.float32)\n    DT = DT.to(torch.float32)\n    Trap = Trap.to(dtype)\n    Q_bias = Q_bias.to(dtype)\n    K_bias = K_bias.to(dtype)\n    Angles = Angles.to(dtype)\n    if D is not None:\n        D = D.to(dtype)\n    if Z is not None:\n        Z = Z.to(dtype)\n    if Initial_States is not None:\n        Initial_Angle_State, Initial_SSM_State, Initial_K_State, Initial_V_State = Initial_States\n\n    Angles = torch.tanh(Angles) * math.pi\n    # Expand Q/K for GQA\n    if Q.shape[2] != V.shape[2]:\n        Q = repeat(Q, \"b s h_bc d -> b s (h_bc g) d\", g=V.shape[2] // Q.shape[2])\n    if K.shape[2] != V.shape[2]:\n        K = repeat(K, \"b s h_bc d -> b s (h_bc g) d\", g=V.shape[2] // K.shape[2])\n\n    out_zs = []\n    Final_Angle_States = []\n    Final_SSM_States = []\n    Final_K_States = []\n    Final_V_States = []\n\n    TWO_PI = 2 * math.pi\n\n    def _rotary(tensor, cos, sin):\n        tensor_reshaped = tensor.view(*tensor.shape[:-1], -1, 2)\n        tensor_0 = tensor_reshaped[..., 0]\n        tensor_1 = tensor_reshaped[..., 1]\n        if cos.shape[-1] < tensor_0.shape[-1]:\n            pad_size = 
tensor_0.shape[-1] - cos.shape[-1]\n            cos = F.pad(cos, (0, pad_size), value=1.0)\n            sin = F.pad(sin, (0, pad_size), value=0.0)\n        rotated_0 = tensor_0 * cos - tensor_1 * sin\n        rotated_1 = tensor_0 * sin + tensor_1 * cos\n        return torch.stack([rotated_0, rotated_1], dim=-1).view_as(tensor)\n\n    def compute_one_sequence(seq_idx):\n        if is_varlen:\n            start_idx, end_idx = cu_seqlens[seq_idx].item(), cu_seqlens[seq_idx + 1].item()\n            Q_curr = Q[0, start_idx:end_idx, :, :]\n            K_curr = K[0, start_idx:end_idx, :, :]\n            V_curr = V[0, start_idx:end_idx, :, :]\n            ADT_curr = ADT[0, :, start_idx:end_idx]\n            DT_curr = DT[0, :, start_idx:end_idx]\n            Trap_curr = Trap[0, :, start_idx:end_idx]\n            Angles_curr = Angles[0, start_idx:end_idx, :, :]\n            Z_curr = Z[0, start_idx:end_idx, :, :] if Z is not None else None\n        else:\n            Q_curr = Q[seq_idx]\n            K_curr = K[seq_idx]\n            V_curr = V[seq_idx]\n            ADT_curr = ADT[seq_idx]\n            DT_curr = DT[seq_idx]\n            Trap_curr = Trap[seq_idx]\n            Angles_curr = Angles[seq_idx]\n            Z_curr = Z[seq_idx] if Z is not None else None\n\n        Trap_curr = torch.sigmoid(Trap_curr)\n        seqlen_curr = Q_curr.shape[0]\n\n        Angles_scaled = Angles_curr.float() * DT_curr.transpose(0, 1).unsqueeze(-1)\n        Angles_Cumsum = torch.cumsum(Angles_scaled, dim=0)\n        if Initial_States is not None:\n            Initial_Angle_State_curr = Initial_Angle_State[seq_idx]\n            Angles_Cumsum = Angles_Cumsum + Initial_Angle_State_curr.unsqueeze(0)\n        Angles_Cumsum = Angles_Cumsum - TWO_PI * torch.floor(Angles_Cumsum / TWO_PI)\n        Final_Angle_States.append(Angles_Cumsum[-1])\n\n        # Initialize acc_states\n        if Initial_States is not None:\n            Initial_SSM_State_curr = Initial_SSM_State[seq_idx]\n            
Initial_K_State_curr = Initial_K_State[seq_idx]\n            Initial_V_State_curr = Initial_V_State[seq_idx]\n\n            scalar = DT_curr[:, 0] * (1 - Trap_curr[:, 0])\n            acc_states = Initial_SSM_State_curr + Initial_V_State_curr[:, :, None] * Initial_K_State_curr[:, None, :] * scalar[:, None, None]\n        else:\n            acc_states = torch.zeros((nheads, headdim_v, headdim_qk), device=device, dtype=torch.float32)\n\n        # Compute shifted gamma and scale\n        DT_shifted = F.pad(DT_curr[:, 1:], (0, 1))\n        Trap_shifted = F.pad(Trap_curr[:, 1:], (0, 1))\n        shifted_gamma = DT_shifted * (1 - Trap_shifted)\n        scale = DT_curr * Trap_curr + DT_shifted * (1 - Trap_shifted)\n\n        # Add biases\n        Q_curr = Q_curr + Q_bias.unsqueeze(0)\n        K_curr = K_curr + K_bias.unsqueeze(0)\n\n        # Compute QK dot for skip connection\n        QK_dot = torch.sum(K_curr * Q_curr, dim=-1) * shifted_gamma.transpose(0, 1)\n\n        # Rotary embeddings using Angles_Cumsum\n        cos_angles_curr = torch.cos(Angles_Cumsum).to(Q_curr.dtype)\n        sin_angles_curr = torch.sin(Angles_Cumsum).to(Q_curr.dtype)\n        Q_curr = _rotary(Q_curr, cos_angles_curr, sin_angles_curr)\n        K_curr = _rotary(K_curr, cos_angles_curr, sin_angles_curr)\n\n        Final_K_States.append(K_curr[-1])\n        Final_V_States.append(V_curr[-1])\n\n        K_curr_scaled = K_curr * scale.transpose(0, 1).unsqueeze(-1).to(K_curr.dtype)\n\n        # Compute output via quadratic attention\n        QK = torch.einsum(\"thd,shd->hts\", Q_curr, K_curr_scaled)\n        QK_causal = torch.tril(QK)\n        QK_causal = (QK_causal * torch.exp(_segsum(ADT_curr))).to(QK_causal.dtype)\n        out = torch.einsum(\"hts,shd->thd\", QK_causal, V_curr)\n\n        if Initial_States is not None:\n            da_cs = torch.cumsum(ADT_curr, dim=-1)\n            exp_da_cs = torch.exp(da_cs)\n            out = out + torch.einsum(\"hDd,thd,ht->thD\", acc_states.to(Q_curr.dtype), 
Q_curr, exp_da_cs.to(Q_curr.dtype))\n\n        if D is not None:\n            out = out + D[None, :, None] * V_curr\n\n        out = out - V_curr * QK_dot.unsqueeze(-1)\n\n        if Z_curr is not None:\n            out = out * Z_curr * torch.sigmoid(Z_curr)\n        out_zs.append(out)\n\n        # Compute final state\n        da_cs_last = torch.exp(torch.sum(ADT_curr, dim=-1))\n        da_cs_rev = torch.exp(torch.sum(ADT_curr, dim=-1, keepdim=True) - torch.cumsum(ADT_curr, dim=-1))\n        V_curr_scaled = V_curr * da_cs_rev.permute(1, 0).unsqueeze(-1).to(V_curr.dtype)\n        final_acc_states = acc_states * da_cs_last.unsqueeze(-1).unsqueeze(-1) + torch.einsum(\n            \"thd,thD->hDd\", K_curr_scaled, V_curr_scaled.to(K_curr_scaled.dtype))\n        Final_SSM_States.append(final_acc_states)\n\n    num_sequences = cu_seqlens.size(0) - 1 if is_varlen else batch\n    for seq_idx in range(num_sequences):\n        compute_one_sequence(seq_idx)\n\n    if not is_varlen:\n        out_zs = torch.stack(out_zs, dim=0)\n        Final_Angle_States = torch.stack(Final_Angle_States, dim=0)\n        Final_SSM_States = torch.stack(Final_SSM_States, dim=0)\n        Final_K_States = torch.stack(Final_K_States, dim=0)\n        Final_V_States = torch.stack(Final_V_States, dim=0)\n    else:\n        out_zs = torch.cat(out_zs, dim=0).unsqueeze(0)\n        Final_Angle_States = torch.stack(Final_Angle_States, dim=0)\n        Final_SSM_States = torch.stack(Final_SSM_States, dim=0)\n        Final_K_States = torch.stack(Final_K_States, dim=0)\n        Final_V_States = torch.stack(Final_V_States, dim=0)\n\n    return out_zs, (Final_Angle_States, Final_SSM_States, Final_K_States, Final_V_States)\n\n\n# ================================================================== \n# Test Utilities\n# ================================================================== \n\ndef detach_clone(*args):\n    \"\"\"Detach and clone tensors, preserving None values.\"\"\"\n    return 
tuple([arg.detach().clone().requires_grad_() if arg is not None else None for arg in args])\n\n@torch.no_grad()\ndef relative_error(\n    ker: torch.Tensor,\n    ref: torch.Tensor,\n    eps: float = 1e-6,\n    ref_mag_mask: float = 1e-2,\n    p: float = 0.95,\n    name: str = \"\",\n    print_top_errors: bool = True,\n    angle: bool = False,   # if True: use circular absolute error; else: relative error\n) -> float:\n    assert ker.shape == ref.shape\n\n    ker_xx = ker.detach().to(torch.float32)\n    ref_xx = ref.detach().to(torch.float32)\n\n    abs_ref = ref_xx.abs()\n\n    if angle:\n        delta = ker_xx - ref_xx\n        delta = torch.remainder(delta + math.pi, 2 * math.pi) - math.pi\n        abs_diff = delta.abs()\n    else:\n        abs_diff = (ker_xx - ref_xx).abs()\n\n    mask = abs_ref >= ref_mag_mask\n    if not mask.any():\n        return 0.0\n\n    vals = abs_diff[mask].flatten() if angle else (abs_diff[mask] / (abs_ref[mask] + eps)).flatten()\n\n    n = vals.numel()\n    k = max(1, min(n, int(math.ceil(p * n))))\n    err = vals.kthvalue(k).values.item()\n\n    if print_top_errors and err > 0.01:\n        print(f\"\\n  Top 10 errors for {name}:\")\n        diff_flat = abs_diff.flatten()\n        ref_flat = ref_xx.flatten()\n        ker_flat = ker_xx.flatten()\n        topk = diff_flat.topk(min(10, diff_flat.numel()))\n        for i, idx in enumerate(topk.indices):\n            idx = idx.item()\n            r = ref_flat[idx].item()\n            k_val = ker_flat[idx].item()\n            d = diff_flat[idx].item()\n            if angle:\n                # For angles, show absolute angular error (radians)\n                print(f\"    {i}: ref={r:.6e}, ker={k_val:.6e}, ang_err={d:.6e} rad\")\n            else:\n                rel_e = d / (abs(r) + eps) if abs(r) >= ref_mag_mask else float('nan')\n                print(f\"    {i}: ref={r:.6e}, ker={k_val:.6e}, diff={d:.6e}, rel={rel_e:.2%}\")\n\n    return err\n\n\ndef create_mamba3_siso_inputs(\n    
batch: int,\n    seqlen: int,\n    nheads: int,\n    nheads_qk: int,\n    headdim_qk: int,\n    headdim_v: int,\n    dtype: torch.dtype,\n    device: str,\n    has_D: bool,\n    has_Z: bool,\n    has_input_states: bool,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    requires_grad: bool = False,\n):\n    num_sequences = cu_seqlens.size(0) - 1 if cu_seqlens is not None else batch\n    \n    Q = torch.randn((batch, seqlen, nheads_qk, headdim_qk), device=device, dtype=dtype)\n    Q = F.rms_norm(Q, normalized_shape=(headdim_qk,)).clone()\n    K = torch.randn((batch, seqlen, nheads_qk, headdim_qk), device=device, dtype=dtype)\n    K = F.rms_norm(K, normalized_shape=(headdim_qk,)).clone()\n    V = torch.randn((batch, seqlen, nheads, headdim_v), device=device, dtype=dtype)\n\n    dt_max, dt_min = 0.1, 0.001\n    a_init = -torch.empty(batch, nheads, seqlen, device=device, dtype=torch.float32).uniform_(1.0, 16.0)\n    dt = torch.exp(\n        torch.rand(batch, nheads, seqlen, device=device, dtype=torch.float32) \n        * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)\n    )\n    ADT = (a_init * dt).contiguous()\n    DT = dt.contiguous()\n    Trap = torch.empty(batch, nheads, seqlen, dtype=dtype, device=device).uniform_(0.0, 1.0).clone()\n    Q_bias = torch.randn(nheads, headdim_qk, dtype=dtype, device=device)\n    K_bias = torch.randn(nheads, headdim_qk, dtype=dtype, device=device)\n    \n    # headdim_angles constraint: 2*headdim_angles <= headdim_qk\n    headdim_angles = headdim_qk // 4\n    Angles = torch.randn(batch, seqlen, nheads, headdim_angles, dtype=torch.float32, device=device)\n\n    D = torch.ones((nheads,), device=device, dtype=torch.float32) if has_D else None\n    Z = torch.randn((batch, seqlen, nheads, headdim_v), device=device, dtype=dtype) if has_Z else None\n    \n    if has_input_states:\n        Input_Angle_State = torch.randn((num_sequences, nheads, headdim_angles), device=device, dtype=torch.float32)\n        Input_SSM_State = 
torch.randn((num_sequences, nheads, headdim_v, headdim_qk), device=device, dtype=torch.float32)\n        Input_K_State = torch.randn((num_sequences, nheads, headdim_qk), device=device, dtype=torch.float32)\n        Input_V_State = torch.randn((num_sequences, nheads, headdim_v), device=device, dtype=torch.float32)\n        Input_States = (Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State)\n    else:\n        Input_States = None\n    \n    if requires_grad:\n        Q.requires_grad_(True)\n        K.requires_grad_(True)\n        V.requires_grad_(True)\n        ADT.requires_grad_(True)\n        DT.requires_grad_(True)\n        Trap.requires_grad_(True)\n        Q_bias.requires_grad_(True)\n        K_bias.requires_grad_(True)\n        Angles.requires_grad_(True)\n        if D is not None:\n            D.requires_grad_(True)\n        if Z is not None:\n            Z.requires_grad_(True)\n        if Input_States is not None:\n            for state in Input_States:\n                state.requires_grad_(True)\n    \n    return {\n        'Q': Q, 'K': K, 'V': V,\n        'ADT': ADT, 'DT': DT, 'Trap': Trap,\n        'Q_bias': Q_bias, 'K_bias': K_bias, 'Angles': Angles,\n        'D': D, 'Z': Z, 'Input_States': Input_States,\n    }\n\n\n# ================================================================== \n# Triton Step Kernel Test\n# ================================================================== \n\ndef test_mamba3_siso_step(nheads_qk=4, has_Z=True, has_D=True):\n    \"\"\"Test Mamba-3 step kernel against reference recurrent implementation.\"\"\"\n    device = 'cuda'\n    rtol = 5e-2\n    dtype = torch.bfloat16\n    torch.random.manual_seed(42)\n    \n    batch = 128\n    seqlen = 2345\n    nheads = 32\n    headdim_qk = 128\n    headdim_v = 64\n    headdim_angles = headdim_qk // 4\n    \n    inputs = create_mamba3_siso_inputs(\n        batch, seqlen, nheads, nheads_qk, headdim_qk, headdim_v,\n        dtype, device, has_D=has_D, has_Z=has_Z, 
has_input_states=True,\n        requires_grad=False\n    )\n    Q_full, K_full, V_full, ADT_full, DT_full, Trap_full, Q_bias, K_bias, Angles_full, D, Z_full, Input_States = inputs['Q'], inputs['K'], inputs['V'], inputs['ADT'], inputs['DT'], inputs['Trap'], inputs['Q_bias'], inputs['K_bias'], inputs['Angles'], inputs['D'], inputs['Z'], inputs['Input_States']\n\n    angle_state_triton, ssm_state_triton, k_state_triton, v_state_triton = Input_States\n    outputs_triton = []    \n    for step in range(seqlen):\n        Q_step = Q_full[:, step, :, :].contiguous()\n        K_step = K_full[:, step, :, :].contiguous()\n        V_step = V_full[:, step, :, :].contiguous()\n        ADT_step = ADT_full[:, :, step].contiguous()\n        DT_step = DT_full[:, :, step].contiguous()\n        Trap_step = Trap_full[:, :, step].contiguous()\n        Angles_step = Angles_full[:, step, :, :].contiguous()\n        Z_step = Z_full[:, step, :, :].contiguous() if Z_full is not None else None\n        \n        input_states_triton = (angle_state_triton, ssm_state_triton, k_state_triton, v_state_triton)\n        out_triton, output_states_triton = mamba3_siso_step(\n            Q_step, K_step, V_step, ADT_step, DT_step, Trap_step,\n            Q_bias, K_bias, Angles_step, D, Z_step, input_states_triton\n        )\n        angle_state_triton, ssm_state_triton, k_state_triton, v_state_triton = output_states_triton\n        outputs_triton.append(out_triton)\n    \n    outputs_triton = torch.stack(outputs_triton, dim=1)\n\n    # Reference implementation\n    outputs_ref, final_states_ref = mamba3_siso_step_ref(\n        Q_full, K_full, V_full, ADT_full, DT_full, Trap_full,\n        Q_bias, K_bias, Angles_full, D, Z_full, Input_States=Input_States\n    )\n    angle_state_ref, ssm_state_ref, k_state_ref, v_state_ref = final_states_ref\n    \n    out_rel_err = relative_error(outputs_triton, outputs_ref)\n    print(f\"Step output relative error: {out_rel_err:.2e}\")\n    assert out_rel_err < rtol, 
f\"Step output relative error {out_rel_err} exceeds tolerance {rtol}\"\n    \n    # Compare final states\n    angle_state_err = relative_error(angle_state_triton, angle_state_ref)\n    ssm_state_err = relative_error(ssm_state_triton, ssm_state_ref)\n    k_state_err = relative_error(k_state_triton, k_state_ref)\n    v_state_err = relative_error(v_state_triton, v_state_ref)\n    \n    print(f\"Final state errors - Angle: {angle_state_err:.2e}, SSM: {ssm_state_err:.2e}, K: {k_state_err:.2e}, V: {v_state_err:.2e}\")\n    assert angle_state_err < rtol, f\"Angle state error {angle_state_err} exceeds tolerance {rtol}\"\n    assert ssm_state_err < rtol, f\"SSM state error {ssm_state_err} exceeds tolerance {rtol}\"\n    assert k_state_err < rtol, f\"K state error {k_state_err} exceeds tolerance {rtol}\"\n    assert v_state_err < rtol, f\"V state error {v_state_err} exceeds tolerance {rtol}\"\n\n# ================================================================== \n# Triton Forward+Backward Batched Kernel Test\n# ================================================================== \n\n# Combined Forward+Backward batched mode test\n# NOTE: Relative errors for tensors are within 6-8% (especially when they are reduced). \n# The error for angle is ~20% because cumsum accumulates error over sequence length. 
This\n# error becomes ~3% when cumsum (angle-dt) kernel is removed\ndef test_mamba3_siso_combined_batched(nheads_qk=4, has_Z=True, has_D=True, headdim_qk=128):\n    \"\"\"Test Mamba-3 combined forward+backward against fwd reference.\n    \"\"\"\n    device = 'cuda'\n    rtol = 1e-1\n    dtype = torch.bfloat16\n    torch.random.manual_seed(42)\n    \n    batch = 16\n    seqlen = 2345\n    nheads = 32\n    headdim_v = 64\n    chunk_size = 64\n    half = seqlen // 2\n    \n    inputs = create_mamba3_siso_inputs(\n        batch, seqlen, nheads, nheads_qk, headdim_qk, headdim_v,\n        dtype, device, has_D=has_D, has_Z=has_Z, has_input_states=True,\n        requires_grad=True\n    )\n    inputs_ref = copy.deepcopy(inputs)\n    \n    # Reference: use mamba3_siso_fwd_ref to compute full sequence output.\n    Out_ref, Final_States_ref = mamba3_siso_fwd_ref(\n        inputs_ref['Q'], inputs_ref['K'], inputs_ref['V'],\n        inputs_ref['ADT'], inputs_ref['DT'], inputs_ref['Trap'],\n        inputs_ref['Q_bias'], inputs_ref['K_bias'], inputs_ref['Angles'],\n        inputs_ref['D'], inputs_ref['Z'], inputs_ref['Input_States'],\n    )\n    \n    # Kernel: two-pass forward via state passing.\n    Out_first, Angle_State_1, SSM_State_1, K_State_1, V_State_1 = mamba3_siso_combined(\n        inputs['Q'][:, :half], inputs['K'][:, :half], inputs['V'][:, :half],\n        inputs['ADT'][:, :, :half], inputs['DT'][:, :, :half], inputs['Trap'][:, :, :half],\n        inputs['Q_bias'], inputs['K_bias'], inputs['Angles'][:, :half],\n        inputs['D'], inputs['Z'][:, :half] if has_Z else None, \n        inputs['Input_States'],\n        chunk_size=chunk_size,\n        return_final_states=True,\n    )\n    Out_second, Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State = mamba3_siso_combined(\n        inputs['Q'][:, half:], inputs['K'][:, half:], inputs['V'][:, half:],\n        inputs['ADT'][:, :, half:], inputs['DT'][:, :, half:], inputs['Trap'][:, :, half:],\n        
inputs['Q_bias'], inputs['K_bias'], inputs['Angles'][:, half:],\n        inputs['D'], inputs['Z'][:, half:] if has_Z else None,\n        (Angle_State_1, SSM_State_1, K_State_1, V_State_1),\n        chunk_size=chunk_size,\n        return_final_states=True,\n    )\n    Out_kernel = torch.cat([Out_first, Out_second], dim=1)\n    \n    # Forward comparison\n    out_err = relative_error(Out_kernel, Out_ref, name=\"Output\")\n    print(f\"Forward output error: {out_err:.2e}\")\n    # assert out_err < rtol, f\"Forward output error {out_err:.2e} exceeds tolerance {rtol}\"\n    \n    # Compare final states\n    Final_Angle_State_ref, Final_SSM_State_ref, Final_K_State_ref, Final_V_State_ref = Final_States_ref\n    for state_name, ker_state, ref_state in [\n        ('Angle', Final_Angle_State, Final_Angle_State_ref),\n        ('SSM', Final_SSM_State, Final_SSM_State_ref),\n        ('K', Final_K_State, Final_K_State_ref),\n        ('V', Final_V_State, Final_V_State_ref),\n    ]:\n        err = relative_error(ker_state, ref_state, name=f\"Final_{state_name}_State\", angle=(state_name=='Angle'))\n        print(f\"Final_{state_name}_State error: {err:.2e}\")\n        # assert err < rtol, f\"Final_{state_name}_State error {err:.2e} exceeds tolerance\"\n    \n    # Backward \n    # Give gradients to both output and final states\n    dO = torch.randn_like(Out_ref)\n    dFinal_Angle_State = torch.randn_like(Final_Angle_State)\n    dFinal_SSM_State = torch.randn_like(Final_SSM_State)\n    dFinal_K_State = torch.randn_like(Final_K_State)\n    dFinal_V_State = torch.randn_like(Final_V_State)\n    \n    # Reference backward\n    torch.autograd.backward(\n        [Out_ref, Final_Angle_State_ref, Final_SSM_State_ref, Final_K_State_ref, Final_V_State_ref],\n        [dO, dFinal_Angle_State, dFinal_SSM_State, dFinal_K_State, dFinal_V_State],\n    )\n    # Kernel backward\n    torch.autograd.backward(\n        [Out_kernel, Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State],\n   
     [dO, dFinal_Angle_State, dFinal_SSM_State, dFinal_K_State, dFinal_V_State],\n    )\n    \n    # Compare gradients\n    for grad_name in ['Q', 'K', 'V', 'ADT', 'DT', 'Trap', 'Q_bias', 'K_bias', 'Angles']:\n        err = relative_error(inputs[grad_name].grad, inputs_ref[grad_name].grad, name=f\"d{grad_name}\")\n        print(f\"d{grad_name} error: {err:.2e}\")\n        # assert err < rtol, f\"d{grad_name} error {err:.2e} exceeds tolerance\"\n    \n    if has_D:\n        err = relative_error(inputs['D'].grad, inputs_ref['D'].grad, name=\"dD\")\n        print(f\"dD error: {err:.2e}\")\n\n    if has_Z:\n        err = relative_error(inputs['Z'].grad, inputs_ref['Z'].grad, name=\"dZ\")\n        print(f\"dZ error: {err:.2e}\")\n    \n    # Input state gradients\n    for i, state_name in enumerate(['Angle', 'SSM', 'K', 'V']):\n        err = relative_error(inputs['Input_States'][i].grad, inputs_ref['Input_States'][i].grad, name=f\"dInput_{state_name}_State\")\n        print(f\"dInput_{state_name}_State error: {err:.2e}\")\n\n# ================================================================== \n# Triton Forward+Backward Varlen Kernel Test\n# ================================================================== \n\n# Combined Forward+Backward varlen mode test\n# NOTE: Relative errors for tensors are within 6-8% (especially when they are reduced). \n# The error for angle is ~20% because cumsum accumulates error over sequence length. 
This\n# error becomes ~3% when cumsum (angle-dt) kernel is removed\ndef test_mamba3_siso_combined_varlen(nheads_qk=4, has_Z=True, has_D=True, headdim_qk=128):\n    \"\"\"Test Mamba-3 combined forward+backward with variable-length sequences against fwd reference.\n    \"\"\"\n    device = 'cuda'\n    rtol = 1e-1\n    dtype = torch.bfloat16\n    torch.random.manual_seed(42)\n    \n    num_sequences = 8\n    seq_lengths = [2345, 2346, 2347, 2348, 2349, 2350, 2351, 2352]\n    total_seqlen = sum(seq_lengths)\n    \n    # Create cu_seqlens\n    cu_seqlens = torch.tensor([0] + list(torch.cumsum(torch.tensor(seq_lengths), dim=0).tolist()), \n                               dtype=torch.int32, device=device)\n    \n    batch = 1  # Varlen requires batch=1\n    nheads = 32\n    headdim_v = 64\n    chunk_size = 64\n    headdim_angles = headdim_qk // 4\n    \n    # Create packed inputs (batch=1, total_seqlen, ...)\n    Q = torch.randn((batch, total_seqlen, nheads_qk, headdim_qk), device=device, dtype=dtype)\n    Q = F.rms_norm(Q, normalized_shape=(headdim_qk,)).clone()\n    K = torch.randn((batch, total_seqlen, nheads_qk, headdim_qk), device=device, dtype=dtype)\n    K = F.rms_norm(K, normalized_shape=(headdim_qk,)).clone()\n    V = torch.randn((batch, total_seqlen, nheads, headdim_v), device=device, dtype=dtype)\n    \n    dt_max, dt_min = 0.1, 0.001\n    a_init = -torch.empty(batch, nheads, total_seqlen, device=device, dtype=torch.float32).uniform_(1.0, 16.0)\n    dt = torch.exp(\n        torch.rand(batch, nheads, total_seqlen, device=device, dtype=torch.float32) \n        * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)\n    )\n    ADT = (a_init * dt).contiguous()\n    DT = dt.contiguous()\n    Trap = torch.empty(batch, nheads, total_seqlen, dtype=dtype, device=device).uniform_(0.0, 1.0).clone()\n    \n    Q_bias = torch.randn(nheads, headdim_qk, dtype=dtype, device=device)\n    K_bias = torch.randn(nheads, headdim_qk, dtype=dtype, device=device)\n    Angles = 
torch.randn(batch, total_seqlen, nheads, headdim_angles, dtype=dtype, device=device) * 0.1\n    \n    D = torch.ones((nheads,), device=device, dtype=torch.float32) if has_D else None\n    Z = torch.randn((batch, total_seqlen, nheads, headdim_v), device=device, dtype=dtype) if has_Z else None\n    \n    # Input states: one per sequence\n    Input_Angle_State = torch.randn((num_sequences, nheads, headdim_angles), device=device, dtype=torch.float32)\n    Input_SSM_State = torch.randn((num_sequences, nheads, headdim_v, headdim_qk), device=device, dtype=torch.float32)\n    Input_K_State = torch.randn((num_sequences, nheads, headdim_qk), device=device, dtype=torch.float32)\n    Input_V_State = torch.randn((num_sequences, nheads, headdim_v), device=device, dtype=torch.float32)\n    Input_States = (Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State)\n    \n    # Enable gradients\n    Q.requires_grad_(True)\n    K.requires_grad_(True)\n    V.requires_grad_(True)\n    ADT.requires_grad_(True)\n    DT.requires_grad_(True)\n    Trap.requires_grad_(True)\n    Q_bias.requires_grad_(True)\n    K_bias.requires_grad_(True)\n    Angles.requires_grad_(True)\n    if D is not None:\n        D.requires_grad_(True)\n    if Z is not None:\n        Z.requires_grad_(True)\n    for state in Input_States:\n        state.requires_grad_(True)\n    \n    # Create deep copies for reference\n    inputs_ref = {\n        'Q': Q.detach().clone().requires_grad_(True),\n        'K': K.detach().clone().requires_grad_(True),\n        'V': V.detach().clone().requires_grad_(True),\n        'ADT': ADT.detach().clone().requires_grad_(True),\n        'DT': DT.detach().clone().requires_grad_(True),\n        'Trap': Trap.detach().clone().requires_grad_(True),\n        'Q_bias': Q_bias.detach().clone().requires_grad_(True),\n        'K_bias': K_bias.detach().clone().requires_grad_(True),\n        'Angles': Angles.detach().clone().requires_grad_(True),\n        'D': 
D.detach().clone().requires_grad_(True) if D is not None else None,\n        'Z': Z.detach().clone().requires_grad_(True) if Z is not None else None,\n        'Input_States': tuple(s.detach().clone().requires_grad_(True) for s in Input_States),\n    }\n    \n    inputs_ker = {\n        'Q': Q, 'K': K, 'V': V,\n        'ADT': ADT, 'DT': DT, 'Trap': Trap,\n        'Q_bias': Q_bias, 'K_bias': K_bias, 'Angles': Angles,\n        'D': D, 'Z': Z, 'Input_States': Input_States,\n    }\n    \n    # Reference: use mamba3_siso_fwd_ref with cu_seqlens\n    Out_ref, Final_States_ref = mamba3_siso_fwd_ref(\n        inputs_ref['Q'], inputs_ref['K'], inputs_ref['V'],\n        inputs_ref['ADT'], inputs_ref['DT'], inputs_ref['Trap'],\n        inputs_ref['Q_bias'], inputs_ref['K_bias'], inputs_ref['Angles'],\n        inputs_ref['D'], inputs_ref['Z'], inputs_ref['Input_States'],\n        cu_seqlens=cu_seqlens,\n    )\n    \n    # Kernel: single call with cu_seqlens\n    Out_kernel, Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State = mamba3_siso_combined(\n        inputs_ker['Q'], inputs_ker['K'], inputs_ker['V'],\n        inputs_ker['ADT'], inputs_ker['DT'], inputs_ker['Trap'],\n        inputs_ker['Q_bias'], inputs_ker['K_bias'], inputs_ker['Angles'],\n        inputs_ker['D'], inputs_ker['Z'], inputs_ker['Input_States'],\n        chunk_size=chunk_size,\n        return_final_states=True,\n        cu_seqlens=cu_seqlens,\n    )\n    \n    # Forward comparison\n    out_err = relative_error(Out_kernel, Out_ref, name=\"Output\")\n    print(f\"Forward output error: {out_err:.2e}\")\n    \n    # Compare final states\n    Final_Angle_State_ref, Final_SSM_State_ref, Final_K_State_ref, Final_V_State_ref = Final_States_ref\n    for state_name, ker_state, ref_state in [\n        ('Angle', Final_Angle_State, Final_Angle_State_ref),\n        ('SSM', Final_SSM_State, Final_SSM_State_ref),\n        ('K', Final_K_State, Final_K_State_ref),\n        ('V', Final_V_State, 
Final_V_State_ref),\n    ]:\n        err = relative_error(ker_state, ref_state, name=f\"Final_{state_name}_State\", angle=(state_name=='Angle'))\n        print(f\"Final_{state_name}_State error: {err:.2e}\")\n    \n    # Backward\n    dO = torch.randn_like(Out_ref)\n    dFinal_Angle_State = torch.randn_like(Final_Angle_State)\n    dFinal_SSM_State = torch.randn_like(Final_SSM_State)\n    dFinal_K_State = torch.randn_like(Final_K_State)\n    dFinal_V_State = torch.randn_like(Final_V_State)\n    \n    # Reference backward\n    torch.autograd.backward(\n        [Out_ref, Final_Angle_State_ref, Final_SSM_State_ref, Final_K_State_ref, Final_V_State_ref],\n        [dO, dFinal_Angle_State, dFinal_SSM_State, dFinal_K_State, dFinal_V_State],\n    )\n    # Kernel backward\n    torch.autograd.backward(\n        [Out_kernel, Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State],\n        [dO, dFinal_Angle_State, dFinal_SSM_State, dFinal_K_State, dFinal_V_State],\n    )\n    \n    # Compare gradients\n    for grad_name in ['Q', 'K', 'V', 'ADT', 'DT', 'Trap', 'Q_bias', 'K_bias', 'Angles']:\n        err = relative_error(inputs_ker[grad_name].grad, inputs_ref[grad_name].grad, name=f\"d{grad_name}\")\n        print(f\"d{grad_name} error: {err:.2e}\")\n    \n    if has_D:\n        err = relative_error(inputs_ker['D'].grad, inputs_ref['D'].grad, name=\"dD\")\n        print(f\"dD error: {err:.2e}\")\n    if has_Z:\n        err = relative_error(inputs_ker['Z'].grad, inputs_ref['Z'].grad, name=\"dZ\")\n        print(f\"dZ error: {err:.2e}\")\n    \n    # Input state gradients\n    for i, state_name in enumerate(['Angle', 'SSM', 'K', 'V']):\n        err = relative_error(inputs_ker['Input_States'][i].grad, inputs_ref['Input_States'][i].grad, name=f\"dInput_{state_name}_State\")\n        print(f\"dInput_{state_name}_State error: {err:.2e}\")\n\n\n# ================================================================== \n# Sanity check test: Step reference and Forward reference 
match\n# ================================================================== \n\ndef test_mamba3_siso_step_ref_vs_fwd_ref(nheads_qk=4, has_Z=True, has_D=True):\n    \"\"\"Test that mamba3_siso_step_ref and mamba3_siso_fwd_ref produce identical outputs.\"\"\"\n    device = 'cuda'\n    rtol = 1e-4  # Both are pure Python/PyTorch, so should match very closely\n    dtype = torch.float32  # Use float32 for reference-vs-reference comparison\n    torch.random.manual_seed(42)\n\n    batch = 16\n    seqlen = 2048\n    nheads = 32\n    headdim_qk = 128\n    headdim_v = 64\n    headdim_angles = headdim_qk // 4\n\n    inputs = create_mamba3_siso_inputs(\n        batch, seqlen, nheads, nheads_qk, headdim_qk, headdim_v,\n        dtype, device, has_D=has_D, has_Z=has_Z, has_input_states=True,\n        requires_grad=False,\n    )\n\n    # --- Step ref ---\n    out_step, final_states_step = mamba3_siso_step_ref(\n        inputs['Q'], inputs['K'], inputs['V'],\n        inputs['ADT'], inputs['DT'], inputs['Trap'],\n        inputs['Q_bias'], inputs['K_bias'], inputs['Angles'],\n        inputs['D'], inputs['Z'],\n        Input_States=inputs['Input_States'],\n    )\n    angle_state_step, ssm_state_step, k_state_step, v_state_step = final_states_step\n\n    # --- Fwd ref ---\n    out_fwd, final_states_fwd = mamba3_siso_fwd_ref(\n        inputs['Q'], inputs['K'], inputs['V'],\n        inputs['ADT'], inputs['DT'], inputs['Trap'],\n        inputs['Q_bias'], inputs['K_bias'], inputs['Angles'],\n        inputs['D'], inputs['Z'],\n        Initial_States=inputs['Input_States'],\n        dtype=dtype,\n    )\n    angle_state_fwd, ssm_state_fwd, k_state_fwd, v_state_fwd = final_states_fwd\n\n    # --- Compare outputs ---\n    out_err = relative_error(out_step, out_fwd, name=\"Output\", ref_mag_mask=1e-3)\n    print(f\"Output error: {out_err:.2e}\")\n    # assert out_err < rtol, f\"Output error {out_err:.2e} exceeds tolerance {rtol}\"\n\n    # --- Compare final states ---\n    for state_name, 
step_state, fwd_state in [\n        ('Angle', angle_state_step, angle_state_fwd),\n        ('SSM',   ssm_state_step,   ssm_state_fwd),\n        ('K',     k_state_step,     k_state_fwd),\n        ('V',     v_state_step,     v_state_fwd),\n    ]:\n        err = relative_error(step_state, fwd_state, name=f\"Final_{state_name}_State\",\n                             angle=(state_name == 'Angle'), ref_mag_mask=1e-3)\n        print(f\"Final_{state_name}_State error: {err:.2e}\")\n\n\n# Main function\nif __name__ == \"__main__\":\n    print(\"Running Mamba-3 step reference vs forward reference test...\")\n    test_mamba3_siso_step_ref_vs_fwd_ref()\n    print(\"=\"*100)\n\n    print(\"\\nRunning Mamba-3 combined forward+backward batched test...\")\n    test_mamba3_siso_combined_batched()\n    print(\"=\"*100)\n\n    print(\"\\nRunning Mamba-3 combined forward+backward varlen test...\")\n    test_mamba3_siso_combined_varlen()\n    print(\"=\"*100)"
  },
  {
    "path": "tests/ops/triton/test_selective_state_update.py",
    "content": "# Copyright (C) 2023, Tri Dao.\n\nimport math\n\nimport torch\nimport torch.nn.functional as F\nimport pytest\n\nfrom einops import rearrange, repeat\n\nfrom mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref\n\n\n@pytest.mark.parametrize(\"itype\", [torch.float32, torch.float16, torch.bfloat16])\n# @pytest.mark.parametrize('itype', [torch.float16])\n@pytest.mark.parametrize(\"has_z\", [False, True])\n# @pytest.mark.parametrize('has_z', [True])\n@pytest.mark.parametrize(\"dstate\", [16, 32, 64])\n# @pytest.mark.parametrize(\"dstate\", [16])\n@pytest.mark.parametrize(\"dim\", [2048, 2048 + 16, 4096])\n# @pytest.mark.parametrize(\"dim\", [2048])\ndef test_selective_state_update(dim, dstate, has_z, itype):\n    device = \"cuda\"\n    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)\n    if itype == torch.bfloat16:\n        rtol, atol = 1e-2, 5e-2\n        if torch.version.hip:\n            atol *= 2\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)\n    x = torch.randn(batch_size, dim, device=device, dtype=itype)\n    dt = torch.randn(batch_size, dim, device=device, dtype=itype)\n    dt_bias = torch.rand(dim, device=device) - 4.0\n    A = -torch.rand(dim, dstate, device=device) - 1.0\n    B = torch.randn(batch_size, dstate, device=device)\n    C = torch.randn(batch_size, dstate, device=device)\n    D = torch.randn(dim, device=device)\n    if has_z:\n        z = torch.randn_like(x)\n    else:\n        z = None\n    state_ref = state.detach().clone()\n    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)\n    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - 
out_ref).abs().mean().item()}\")\n    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)\n    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"itype\", [torch.float32, torch.float16, torch.bfloat16])\n# @pytest.mark.parametrize('itype', [torch.float16])\n@pytest.mark.parametrize(\"has_z\", [False, True])\n# @pytest.mark.parametrize('has_z', [True])\n@pytest.mark.parametrize(\"tie_hdim\", [False, True])\n# @pytest.mark.parametrize('tie_hdim', [True])\n@pytest.mark.parametrize(\"ngroups\", [1, 2, 4])\n# @pytest.mark.parametrize(\"ngroups\", [2])\n@pytest.mark.parametrize(\"dstate\", [16, 32, 64])\n# @pytest.mark.parametrize(\"dstate\", [16])\n@pytest.mark.parametrize(\"dim\", [2048, 4096])\n# @pytest.mark.parametrize(\"dim\", [2048])\ndef test_selective_state_update_with_heads(dim, dstate, ngroups, has_z, tie_hdim, itype):\n    device = \"cuda\"\n    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)\n    if itype == torch.bfloat16:\n        rtol, atol = 1e-2, 1e-1\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    headdim = 64\n    nheads = dim // headdim\n    state = torch.randn(batch_size, nheads, headdim, dstate, dtype=itype, device=device)\n    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)\n    if not tie_hdim:\n        dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)\n        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0\n        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0\n        D = torch.randn(nheads, headdim, device=device)\n    else:\n        dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), \"b h -> b h p\", p=headdim)\n        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, \"h -> h p\", p=headdim)\n        A = repeat(-torch.rand(nheads, device=device) - 1.0, \"h -> h p n\", p=headdim, n=dstate)\n        D = repeat(torch.randn(nheads, 
device=device), \"h -> h p\", p=headdim)\n    B = torch.randn(batch_size, ngroups, dstate, device=device)\n    C = torch.randn(batch_size, ngroups, dstate, device=device)\n    if has_z:\n        z = torch.randn_like(x)\n    else:\n        z = None\n    state_ref = state.detach().clone()\n    state_og = state.detach().clone()\n    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)\n    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)\n    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)\n\n@pytest.mark.parametrize(\"itype\", [torch.float32, torch.float16, torch.bfloat16])\n# @pytest.mark.parametrize('itype', [torch.float16])\n@pytest.mark.parametrize(\"has_z\", [False, True])\n# @pytest.mark.parametrize('has_z', [True])\n@pytest.mark.parametrize(\"dstate\", [16, 32, 64])\n# @pytest.mark.parametrize(\"dstate\", [16])\n@pytest.mark.parametrize(\"dim\", [2048, 2048 + 16, 4096])\n# @pytest.mark.parametrize(\"dim\", [2048])\ndef test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):\n    device = \"cuda\"\n    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)\n    if itype == torch.bfloat16:\n        rtol, atol = 6e-2, 6e-2\n        if torch.version.hip:\n            atol *= 2\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 16\n\n    total_entries = 10 * batch_size\n    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)\n    state_indices = torch.randperm(total_entries)[:batch_size].to(dtype=torch.int32, device=device)\n\n    x = torch.randn(batch_size, dim, device=device, dtype=itype)\n    dt = torch.randn(batch_size, dim, device=device, dtype=itype)\n    dt_bias = 
torch.rand(dim, device=device) - 4.0\n    A = -torch.rand(dim, dstate, device=device) - 1.0\n    B = torch.randn(batch_size, dstate, device=device)\n    C = torch.randn(batch_size, dstate, device=device)\n    D = torch.randn(dim, device=device)\n    if has_z:\n        z = torch.randn_like(x)\n    else:\n        z = None\n    state_ref = state[state_indices,:].detach().clone()\n    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z,\n                                 dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices)\n    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    assert torch.allclose(state[state_indices,:], state_ref, rtol=rtol, atol=atol)\n    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"itype\", [torch.float32, torch.float16, torch.bfloat16])\n#@pytest.mark.parametrize('itype', [torch.float32])\n@pytest.mark.parametrize(\"has_z\", [False, True])\n# @pytest.mark.parametrize('has_z', [True])\n@pytest.mark.parametrize(\"tie_hdim\", [False, True])\n# @pytest.mark.parametrize('tie_hdim', [True])\n@pytest.mark.parametrize(\"ngroups\", [1, 2, 4])\n# @pytest.mark.parametrize(\"ngroups\", [2])\n@pytest.mark.parametrize(\"dstate\", [16, 32, 64])\n# @pytest.mark.parametrize(\"dstate\", [16])\n@pytest.mark.parametrize(\"dim\", [2048, 4096])\n# @pytest.mark.parametrize(\"dim\", [2048])\ndef test_selective_state_update_with_heads_with_batch_indices(dim, dstate, ngroups, has_z, tie_hdim, itype):\n    device = \"cuda\"\n    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)\n    if itype == torch.bfloat16:\n        rtol, atol = 1e-1, 1e-1\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 16\n    headdim = 64\n    nheads = dim // headdim\n\n    total_entries 
= 10 * batch_size\n    state = torch.randn(total_entries, nheads, headdim, dstate, dtype=itype, device=device)\n    state_indices = torch.randperm(total_entries)[:batch_size].to(dtype=torch.int32, device=device)\n\n    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)\n    if not tie_hdim:\n        dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)\n        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0\n        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0\n        D = torch.randn(nheads, headdim, device=device)\n    else:\n        dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), \"b h -> b h p\", p=headdim)\n        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, \"h -> h p\", p=headdim)\n        A = repeat(-torch.rand(nheads, device=device) - 1.0, \"h -> h p n\", p=headdim, n=dstate)\n        D = repeat(torch.randn(nheads, device=device), \"h -> h p\", p=headdim)\n    B = torch.randn(batch_size, ngroups, dstate, device=device)\n    C = torch.randn(batch_size, ngroups, dstate, device=device)\n    if has_z:\n        z = torch.randn_like(x)\n    else:\n        z = None\n    state_ref = state[state_indices,:].detach().clone()\n    state_og = state[state_indices,:].detach().clone()\n    out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices)\n    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    assert torch.allclose(state[state_indices,:], state_ref, rtol=rtol, atol=atol)\n    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)\n"
  },
  {
    "path": "tests/ops/triton/test_ssd.py",
    "content": "import math\n\nimport torch\nimport torch.nn.functional as F\n\nimport pytest\n\nfrom einops import rearrange, repeat\n\nfrom mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref\nfrom mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd\nfrom mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen\nfrom mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref\nfrom mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd\nfrom mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref\nfrom mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_chunk_scan, ssd_chunk_scan_combined_ref, ssd_selective_scan\nfrom mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined, mamba_split_conv1d_scan_ref\n\n\ndef detach_clone(*args):\n    return tuple([arg.detach().clone().requires_grad_() if arg is not None else None for arg in args])\n\n\n@pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16])\n# @pytest.mark.parametrize('dtype', [torch.bfloat16])\n@pytest.mark.parametrize('ngroups', [1, 2, 8, \"max\"])\n# @pytest.mark.parametrize('ngroups', [1])\n@pytest.mark.parametrize('chunk_size', [64, 128])\n# @pytest.mark.parametrize('chunk_size', [128])\ndef test_chunk_state_varlen(chunk_size, ngroups, dtype):\n    device = 'cuda'\n    rtol, atol = (1e-2, 3e-3)\n    # set seed\n    torch.random.manual_seed(chunk_size + (ngroups if ngroups != \"max\" else 64))\n    batch = 300\n    seqlens = torch.randint(1, 200, (batch,), device=device)\n    # batch = 3\n    # seqlens = torch.tensor([201, 56, 5], device=device)\n    cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))\n    total_seqlen = seqlens.sum().item()\n    seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(seqlens)], dim=0).unsqueeze(0)\n    dim = 4096\n    # dim = 64\n    headdim = 64\n    # dim = 32\n    dstate = 
32\n    assert dim % headdim == 0\n    nheads = dim // headdim\n    if ngroups == \"max\":\n        ngroups = nheads\n    assert nheads % ngroups == 0\n    B = torch.randn(total_seqlen, ngroups, dstate, dtype=dtype, device=device) / 5\n    x = torch.randn(total_seqlen, nheads, headdim, dtype=dtype, device=device)\n    A = -0.1 * (torch.rand(nheads, device=device))\n    dt = F.softplus(torch.randn(total_seqlen, nheads, device=device, dtype=torch.float32) - 4)\n    dA_cumsum, dt_rounded = _chunk_cumsum_fwd(dt.unsqueeze(0), A, chunk_size)\n    chunk_states = _chunk_state_fwd(B.unsqueeze(0), x.unsqueeze(0), dt_rounded, dA_cumsum, seq_idx=seq_idx)\n    chunk_states, _ = _state_passing_fwd(rearrange(chunk_states, \"... p n -> ... (p n)\"), dA_cumsum[:, :, :, -1],\n                                         seq_idx=seq_idx, chunk_size=chunk_size)\n    chunk_states = rearrange(chunk_states, \"... (p n) -> ... p n\", n=dstate)\n    chunk_states = chunk_states.squeeze(0)\n    dA_cumsum = dA_cumsum.squeeze(0)\n    dt_rounded = dt_rounded.squeeze(0)\n    out = chunk_state_varlen(B, x, dt_rounded, dA_cumsum, cu_seqlens, chunk_states)\n    out_ref = []\n    for b in range(batch):\n        x_s = x[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0)\n        B_s = B[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0)\n        dt_s = dt[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0)\n        dA_cumsum_s, dt_rounded_s = _chunk_cumsum_fwd(dt_s, A, chunk_size)\n        states = chunk_state(B_s, x_s, dt_rounded_s, dA_cumsum_s)\n        _, final_states = _state_passing_fwd(rearrange(states, \"... p n -> ... (p n)\"), dA_cumsum_s[:, :, :, -1],\n                                             chunk_size=chunk_size)\n        final_states = rearrange(final_states, \"... (p n) -> ... p n\", n=dstate)\n        out_ref.append(final_states)\n    out_ref = torch.cat(out_ref, dim=0)\n    print(f\"Max diff = {(out - out_ref).abs().max().item()}\")\n    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)\n"
  },
  {
    "path": "tests/test_determinism.py",
    "content": "# Copyright (c) 2024, Tri Dao, Albert Gu.\n\nimport os\n\nimport pytest\nimport torch\n\n\ndef _set_deterministic(enabled: bool) -> None:\n    torch.use_deterministic_algorithms(enabled)\n    if enabled:\n        os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n\ndef _set_seeds(seed: int) -> None:\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n\ndef _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:\n    return (a.float() - b.float()).abs().max().item()\n\n\ndef _make_inputs(\n    *,\n    seed: int,\n    headdim: int,\n    dstate: int,\n    chunk_size: int = 256,\n    ngroups: int = 1,\n    dtype: torch.dtype = torch.bfloat16,\n    d_has_hdim: bool = False,\n) -> dict[str, torch.Tensor]:\n    import math\n\n    _set_seeds(seed)\n    device = \"cuda\"\n\n    batch = 2\n    seqlen = 2048\n    nheads = 8\n    nchunks = math.ceil(seqlen / chunk_size)\n\n    x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)\n    dout = torch.randn_like(x)\n    dt = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32)\n    dA_cumsum = torch.randn_like(dt)\n    cb = torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=device, dtype=dtype)\n\n    B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype).contiguous()\n    C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype).contiguous()\n    dstates = torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32)\n    prev_states = torch.randn_like(dstates)\n\n    ddA = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32)\n    ddt_out = torch.randn_like(ddA)\n    dt_raw = torch.randn(batch, seqlen, nheads, device=device, dtype=dtype)\n    A = (torch.randn(nheads, device=device, dtype=torch.float32) * -1.0).contiguous()\n    dt_bias = torch.randn(nheads, device=device, dtype=torch.float32).contiguous()\n    # D shape: 
(nheads, headdim) when d_has_hdim=True, else (nheads,)\n    if d_has_hdim:\n        D = torch.randn(nheads, headdim, device=device, dtype=torch.float32)\n    else:\n        D = torch.randn(nheads, device=device, dtype=torch.float32)\n\n    return {\n        \"x\": x,\n        \"dout\": dout,\n        \"dt\": dt,\n        \"dA_cumsum\": dA_cumsum,\n        \"cb\": cb,\n        \"B\": B,\n        \"C\": C,\n        \"dstates\": dstates,\n        \"prev_states\": prev_states,\n        \"ddA\": ddA,\n        \"ddt_out\": ddt_out,\n        \"dt_raw\": dt_raw,\n        \"A\": A,\n        \"dt_bias\": dt_bias,\n        \"D\": D,\n    }\n\n\ndef _run_case_outputs(\n    *,\n    case: str,\n    deterministic: bool,\n    seed: int,\n    headdim: int = 64,\n    dstate: int = 64,\n    chunk_size: int = 256,\n    ngroups: int = 1,\n    dtype: torch.dtype = torch.bfloat16,\n    d_has_hdim: bool = False,\n) -> dict[str, torch.Tensor]:\n    _set_deterministic(deterministic)\n    t = _make_inputs(\n        seed=seed,\n        headdim=headdim,\n        dstate=dstate,\n        chunk_size=chunk_size,\n        ngroups=ngroups,\n        dtype=dtype,\n        d_has_hdim=d_has_hdim,\n    )\n\n    if case == \"chunk_scan_bwd_dx\":\n        from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dx\n        dx, ddt = _chunk_scan_bwd_dx(t[\"cb\"], t[\"x\"], t[\"dt\"], t[\"dA_cumsum\"], t[\"dout\"])\n        out = {\"dx\": dx, \"ddt\": ddt}\n    elif case == \"chunk_scan_bwd_dC\":\n        from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC\n        dC, ddA_prev = _chunk_scan_bwd_dC(t[\"prev_states\"], t[\"dA_cumsum\"], t[\"dout\"], C=t[\"C\"], ngroups=1)\n        out = {\"dC\": dC, \"ddA_cumsum_prev\": ddA_prev}\n    elif case == \"chunk_state_bwd_dx\":\n        from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_dx\n        dx, ddt, ddA = _chunk_state_bwd_dx(t[\"B\"], t[\"x\"], t[\"dt\"], t[\"dA_cumsum\"], t[\"dstates\"])\n        out = {\"dx\": dx, 
\"ddt\": ddt, \"ddA_cumsum\": ddA}\n    elif case == \"chunk_state_bwd_db\":\n        from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_db\n        dB, ddA = _chunk_state_bwd_db(t[\"x\"], t[\"dt\"], t[\"dA_cumsum\"], t[\"dstates\"], B=t[\"B\"], ngroups=1)\n        out = {\"dB\": dB, \"ddA_cumsum\": ddA}\n    elif case == \"chunk_state_bwd_ddAcs_stable\":\n        from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable\n        ddA = _chunk_state_bwd_ddAcs_stable(t[\"B\"], t[\"x\"], t[\"dt\"], t[\"dA_cumsum\"], t[\"dstates\"])\n        out = {\"ddA_cumsum\": ddA}\n    elif case == \"chunk_cumsum_bwd\":\n        from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_bwd\n        ddt, dA, ddt_bias = _chunk_cumsum_bwd(t[\"ddA\"], t[\"ddt_out\"], t[\"dt_raw\"], t[\"A\"], dt_bias=t[\"dt_bias\"], dt_softplus=True)\n        out = {\"ddt\": ddt, \"dA\": dA, \"ddt_bias\": ddt_bias}\n    elif case.startswith(\"combined_bwd_dx\"):\n        from mamba_ssm.ops.triton.ssd_combined import _chunk_scan_chunk_state_bwd_dx\n        dx, ddt, dD = _chunk_scan_chunk_state_bwd_dx(t[\"x\"], t[\"dt\"], t[\"dA_cumsum\"], t[\"B\"], t[\"cb\"], t[\"dout\"], t[\"dstates\"], D=t[\"D\"])\n        out = {\"dx\": dx, \"ddt\": ddt, \"dD\": dD}\n    else:\n        raise AssertionError(f\"Unknown case: {case}\")\n\n    torch.cuda.synchronize()\n    return {k: v.detach().clone().float() for k, v in out.items() if v is not None}\n\n\n_KERNEL_CASES = [\n    \"chunk_scan_bwd_dx\",\n    \"chunk_scan_bwd_dC\",\n    \"chunk_state_bwd_dx\",\n    \"chunk_state_bwd_db\",\n    \"chunk_state_bwd_ddAcs_stable\",\n    \"chunk_cumsum_bwd\",\n]\n\n_COMBINED_CASES = [\n    (\"combined_bwd_dx\", False),\n    (\"combined_bwd_dx_d_hdim\", True),\n]\n\n_HEADDIMS = [64, 128]\n_DSTATES = [64]\n\n\ndef _kernel_is_reproducible(case: str, headdim: int, dstate: int, d_has_hdim: bool = False):\n    runs = 5\n    outs = [\n        _run_case_outputs(case=case, deterministic=True, 
seed=123, headdim=headdim, dstate=dstate, d_has_hdim=d_has_hdim)\n        for _ in range(runs)\n    ]\n    ref = outs[0]\n    for i in range(1, runs):\n        for k in ref:\n            assert _max_abs_diff(ref[k], outs[i][k]) == 0.0, f\"{case} output {k} differs (headdim={headdim}, dstate={dstate})\"\n\n\ndef _kernel_close_to_default(case: str, headdim: int, dstate: int, d_has_hdim: bool = False):\n    atol = rtol = 1e-2\n    det = _run_case_outputs(case=case, deterministic=True, seed=123, headdim=headdim, dstate=dstate, d_has_hdim=d_has_hdim)\n    for _ in range(3):\n        default = _run_case_outputs(case=case, deterministic=False, seed=123, headdim=headdim, dstate=dstate, d_has_hdim=d_has_hdim)\n        for k in det:\n            assert torch.allclose(default[k], det[k], atol=atol, rtol=rtol), f\"{case} output {k} not close (headdim={headdim}, dstate={dstate})\"\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA required\")\n@pytest.mark.parametrize(\"dstate\", _DSTATES)\n@pytest.mark.parametrize(\"headdim\", _HEADDIMS)\n@pytest.mark.parametrize(\"case\", _KERNEL_CASES)\ndef test_kernel_reproducible(case: str, headdim: int, dstate: int):\n    _kernel_is_reproducible(case, headdim, dstate)\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA required\")\n@pytest.mark.parametrize(\"dstate\", _DSTATES)\n@pytest.mark.parametrize(\"headdim\", _HEADDIMS)\n@pytest.mark.parametrize(\"case,d_has_hdim\", _COMBINED_CASES)\ndef test_combined_kernel_reproducible(case: str, d_has_hdim: bool, headdim: int, dstate: int):\n    _kernel_is_reproducible(case, headdim, dstate, d_has_hdim)\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA required\")\n@pytest.mark.parametrize(\"dstate\", _DSTATES)\n@pytest.mark.parametrize(\"headdim\", _HEADDIMS)\n@pytest.mark.parametrize(\"case\", _KERNEL_CASES)\ndef test_kernel_close_to_default(case: str, headdim: int, dstate: int):\n    _kernel_close_to_default(case, headdim, 
dstate)\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA required\")\n@pytest.mark.parametrize(\"dstate\", _DSTATES)\n@pytest.mark.parametrize(\"headdim\", _HEADDIMS)\n@pytest.mark.parametrize(\"case,d_has_hdim\", _COMBINED_CASES)\ndef test_combined_kernel_close_to_default(case: str, d_has_hdim: bool, headdim: int, dstate: int):\n    _kernel_close_to_default(case, headdim, dstate, d_has_hdim)\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA required\")\ndef test_default_mode_is_not_reproducible():\n    from mamba_ssm.modules.mamba2 import Mamba2\n\n    device = \"cuda\"\n    dtype = torch.bfloat16\n    seed = 123\n    runs = 20\n    batch = 4\n    seqlen = 4096\n\n    _set_seeds(seed)\n    model = Mamba2(\n        d_model=256, d_state=64, headdim=64, expand=2, d_conv=4, chunk_size=256,\n        use_mem_eff_path=True, device=device, dtype=dtype,\n    ).train()\n    x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype)\n\n    def _run() -> dict[str, torch.Tensor]:\n        _set_deterministic(False)\n        model.zero_grad(set_to_none=True)\n        x = x_data.clone().requires_grad_(True)\n        y = model(x)\n        (y.float().square().mean()).backward()\n        torch.cuda.synchronize()\n        grads = {\"input\": x.grad.detach().float().clone()}\n        for name, p in model.named_parameters():\n            if p.grad is not None:\n                grads[name] = p.grad.detach().float().clone()\n        return grads\n\n    _run()  # warmup\n    ref = _run()\n    observed_diff = False\n    for _ in range(runs - 1):\n        g = _run()\n        for k in ref:\n            if _max_abs_diff(ref[k], g[k]) != 0.0:\n                observed_diff = True\n                break\n        if observed_diff:\n            break\n\n    if not observed_diff:\n        pytest.xfail(\n            f\"Did not observe nondeterminism in default mode after {runs} runs. 
\"\n            \"This GPU may have deterministic atomic behavior at these shapes.\"\n        )\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA required\")\ndef test_mamba2_fwd_bwd_deterministic_reproducible():\n    from mamba_ssm.modules.mamba2 import Mamba2\n\n    device = \"cuda\"\n    dtype = torch.bfloat16\n    seed = 123\n    runs = 5\n    batch = 2\n    seqlen = 2048\n    headdim = 64\n\n    _set_seeds(seed)\n    _set_deterministic(True)\n\n    model = Mamba2(\n        d_model=headdim, d_state=16, headdim=headdim, expand=2, d_conv=4, chunk_size=16,\n        use_mem_eff_path=True, device=device, dtype=dtype,\n    ).train()\n    x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype)\n\n    def _run() -> tuple[torch.Tensor, dict[str, torch.Tensor]]:\n        model.zero_grad(set_to_none=True)\n        x = x_data.clone().requires_grad_(True)\n        y = model(x)\n        (y.float().square().mean()).backward()\n        torch.cuda.synchronize()\n        grads: dict[str, torch.Tensor] = {\"input\": x.grad.detach().float().clone()}\n        for name, p in model.named_parameters():\n            if p.grad is not None:\n                grads[name] = p.grad.detach().float().clone()\n        return y.detach().float().clone(), grads\n\n    _run()  # warmup\n    y0, g0 = _run()\n    for _ in range(runs - 1):\n        y, g = _run()\n        assert _max_abs_diff(y0, y) == 0.0\n        assert g.keys() == g0.keys()\n        for k in g0:\n            assert _max_abs_diff(g0[k], g[k]) == 0.0, f\"Mamba2 grad {k} differs\"\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA required\")\ndef test_mamba2_fwd_bwd_deterministic_close_to_default():\n    from mamba_ssm.modules.mamba2 import Mamba2\n\n    device = \"cuda\"\n    dtype = torch.bfloat16\n    seed = 123\n    batch = 2\n    seqlen = 2048\n    headdim = 64\n    atol = rtol = 1e-2\n\n    def _run(deterministic: bool) -> tuple[torch.Tensor, dict[str, 
torch.Tensor]]:\n        torch.use_deterministic_algorithms(deterministic, warn_only=True)\n        _set_seeds(seed)\n        model = Mamba2(\n            d_model=headdim * 4, d_state=32, headdim=headdim, expand=2, d_conv=4, chunk_size=64,\n            use_mem_eff_path=True, device=device, dtype=dtype,\n        ).train()\n        x = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype).requires_grad_(True)\n        y = model(x)\n        (y.float().square().mean()).backward()\n        torch.cuda.synchronize()\n        grads: dict[str, torch.Tensor] = {\"input\": x.grad.detach().float().clone()}\n        for name, p in model.named_parameters():\n            if p.grad is not None:\n                grads[name] = p.grad.detach().float().clone()\n        return y.detach().float().clone(), grads\n\n    _run(False)  # warmup\n    y_default, g_default = _run(False)\n    y_det, g_det = _run(True)\n\n    assert torch.allclose(y_default, y_det, atol=atol, rtol=rtol), \"Mamba2 output differs\"\n    for k in g_default:\n        assert torch.allclose(g_default[k], g_det[k], atol=atol, rtol=rtol), f\"Mamba2 grad {k} not close\"\n"
  },
  {
    "path": "tests/test_generation.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel\nfrom mamba_ssm.models.config_mamba import MambaConfig\nfrom mamba_ssm.utils.generation import InferenceParams\n\nimport pytest\n\nfrom einops import rearrange, repeat\n\n\ndef test_generation():\n    batch = 3\n    seqlen = 20\n    device = \"cuda\"\n    dtype = torch.float16\n\n    config = MambaConfig(\n        d_model=1024,\n        n_layer=4,\n        vocab_size=50277,\n        ssm_cfg=dict(layer=\"Mamba2\"),\n        rms_norm=True,\n        residual_in_fp32=True,\n        fused_add_norm=True,\n        pad_vocab_size_multiple=16,\n    )\n    torch.manual_seed(2357)\n    model = MambaLMHeadModel(config, device=device, dtype=dtype)\n    x = torch.randint(0, 1000, (batch, seqlen), device=device, dtype=torch.long)\n    out_ref = model(x).logits\n    prompt_len = seqlen // 2\n    out = model.generate(\n        input_ids = x[:, :prompt_len], max_length=seqlen, output_scores=True, return_dict_in_generate=True,\n        cg=True,  # Can turn off CUDA graph for easier debugging\n        # instead of sampling, we take output tokens from x, to get logits for testing\n        # For actual generation, don't pass in teacher_outputs\n        teacher_outputs=x,\n    )\n    out_scores = torch.stack(out.scores, dim=1)\n    print(f\"Max diff: {(out_scores - out_ref[:, prompt_len - 1: -1]).abs().max()}\")\n    assert torch.allclose(out_scores, out_ref[:, prompt_len - 1: -1], rtol=1e-3, atol=1e-2)\n\n\ndef test_generation_varlen():\n    seqlens = [170, 65, 100]\n    genlen = 20\n    total_seqlen = sum(seqlens)\n    device = \"cuda\"\n    dtype = torch.float16\n\n    config = MambaConfig(\n        d_model=1024,\n        n_layer=4,\n        vocab_size=50277,\n        ssm_cfg=dict(layer=\"Mamba2\"),\n        rms_norm=True,\n        residual_in_fp32=True,\n        fused_add_norm=True,\n        pad_vocab_size_multiple=16,\n    )\n    torch.manual_seed(2357)\n   
 model = MambaLMHeadModel(config, device=device, dtype=dtype)\n    xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens]\n\n    # Reference 1: Forward pass with seq_idx\n    x = torch.cat(xs, dim=1)\n    seq_idx = torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device)\n                         for i, ids in enumerate(xs)], dim=0).unsqueeze(0)\n    cu_seqlens = F.pad(torch.tensor(seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))\n    out_ref = model(x, seq_idx=seq_idx).logits\n    # Only take the last @genlen logits of each sequence\n    out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1]\n                         for i in range(len(seqlens))], dim=0)\n\n    # Reference 2: Generate the last @genlen tokens of each sequence in a for loop\n    out_loop = []\n    for input_ids in xs:\n        out = model.generate(\n            input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True,\n            return_dict_in_generate=True, cg=True, teacher_outputs=input_ids,\n        ).scores\n        out_loop.append(torch.stack(out, dim=1))\n    out_loop = torch.cat(out_loop, dim=0)\n    print(f\"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}\")\n\n    # Varlen generation\n    input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1)\n    prompt_seqlens = [seqlen - genlen for seqlen in seqlens]\n    cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))\n    seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device)\n                         for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0)\n    inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens))\n\n    scores, sequences = [], []\n    # Both seq_idx and cu_seqlens must be passed in for varlen generation\n    logits = model(input_ids, 
inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits\n    logits = rearrange(logits[0, cu_seqlens[1:] - 1], \"b d -> b 1 d\")\n    scores.append(logits)\n    # In practice we should sample. In this case we take from the teacher_output for testing\n    sampled_tokens = rearrange(torch.stack([ids[0, -genlen] for ids in xs], dim=0), \"b -> b 1\")\n    sequences.append(sampled_tokens)\n    for i in range(1, genlen):\n        inference_params.seqlen_offset += 1\n        logits = model(sampled_tokens, inference_params=inference_params, num_last_tokens=1).logits\n        scores.append(logits)\n        # In practice we should sample. In this case we take from the teacher_output for testing\n        sampled_tokens = rearrange(torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), \"b -> b 1\")\n        sequences.append(sampled_tokens)\n    out_varlen = torch.cat(scores, dim=1)\n    print(f\"Max diff: {(out_varlen - out_ref).abs().max()}\")\n    assert (out_varlen - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()\n\ndef test_generation_varlen_with_padding():\n    seqlens = [170, 65, 100]\n    non_padded_seqlen = sum(seqlens)\n    padded_seqlen = 512\n    seqlens.append(padded_seqlen - non_padded_seqlen)\n    genlen = 20\n    total_seqlen = sum(seqlens)\n    assert total_seqlen == padded_seqlen\n    device = \"cuda\"\n    dtype = torch.float16\n\n    config = MambaConfig(\n        d_model=1024,\n        n_layer=4,\n        vocab_size=50277,\n        ssm_cfg=dict(layer=\"Mamba2\"),\n        rms_norm=True,\n        residual_in_fp32=True,\n        fused_add_norm=True,\n        pad_vocab_size_multiple=16,\n    )\n    torch.manual_seed(2357)\n    model = MambaLMHeadModel(config, device=device, dtype=dtype)\n    xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens]\n\n    # Reference 1: Forward pass with seq_idx\n    x = torch.cat(xs[:-1], dim=1)\n    seq_idx = 
torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device)\n                         for i, ids in enumerate(xs[:-1])], dim=0).unsqueeze(0)\n    cu_seqlens = F.pad(torch.tensor(seqlens[:-1], device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))\n\n    out_ref = model(x, seq_idx=seq_idx).logits\n    # Only take the last @genlen logits of each sequence\n    out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1]\n                         for i in range(len(seqlens) - 1)], dim=0)\n\n    # Reference 2: Generate the last @genlen tokens of each sequence in a for loop\n    out_loop = []\n    for input_ids in xs[:-1]:\n        out = model.generate(\n            input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True,\n            return_dict_in_generate=True, cg=True, teacher_outputs=input_ids,\n        ).scores\n        out_loop.append(torch.stack(out, dim=1))\n    out_loop = torch.cat(out_loop, dim=0)\n    print(f\"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}\")\n\n    # Varlen generation\n    input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1)\n    prompt_seqlens = [seqlen - genlen for seqlen in seqlens]\n    cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))\n    seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device)\n                         for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0)\n    inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens))\n\n    # Account for padding\n    offset = genlen * len(seqlens)\n    seq_idx[non_padded_seqlen - offset : padded_seqlen - offset] = -1\n    cu_seqlens[-1] = cu_seqlens[-2]\n\n    scores, sequences = [], []\n    # Both seq_idx and cu_seqlens must be passed in for varlen generation\n    logits = model(input_ids, inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits\n    
logits = rearrange(logits[0, cu_seqlens[1:] - 1], \"b d -> b 1 d\")\n    scores.append(logits)\n    # In practice we should sample. In this case we take from the teacher_output for testing\n    sampled_tokens = rearrange(torch.stack([ids[0, -genlen] for ids in xs], dim=0), \"b -> b 1\")\n    sequences.append(sampled_tokens)\n    for i in range(1, genlen):\n        inference_params.seqlen_offset += 1\n        logits = model(sampled_tokens, inference_params=inference_params, num_last_tokens=1).logits\n        scores.append(logits)\n        # In practice we should sample. In this case we take from the teacher_output for testing\n        sampled_tokens = rearrange(torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), \"b -> b 1\")\n        sequences.append(sampled_tokens)\n    out_varlen = torch.cat(scores, dim=1)\n\n    print(f\"Max diff: {(out_varlen[:-1] - out_ref).abs().max()}\")\n    assert (out_varlen[:-1] - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()\n"
  },
  {
    "path": "usage.md",
    "content": "# Mamba adoption\n\nWe've been very happy to see Mamba being adopted by many organizations\nand research labs to speed up their training / inference.\nThis page contains a partial list of places where Mamba is being used.\nIf you'd like to add links to your organization / product / codebase, please open a\nPR or email us. We'd very much like to hear from you!\n\n## Large language models and multi-modal models\n\n- [Tencent's Hunyuan-TurboS (560B)](https://arxiv.org/abs/2505.15431)\n\n- [Nvidia Nemotron-H (8B, 47B, 56B)](https://research.nvidia.com/labs/adlr/nemotronh/)\n\n- [AI21 Jamba (398B)](https://www.ai21.com/blog/announcing-jamba-model-family/)\n\n- [TII Falcon-H1 (34B)](https://falconllm.tii.ae/falcon-h1.html)\n\n- [IBM Bamba (9B)](https://research.ibm.com/blog/bamba-ssm-transformer-model)\n\n- [Mistral's Codestral (7B)](https://mistral.ai/news/codestral-mamba)\n\n- [Nvidia Mamba-2 Hybrid (8B)](https://arxiv.org/abs/2406.07887)\n\n- [Microsoft Samba (4B)](https://arxiv.org/abs/2406.07522v1)\n\n- [TII Falcon-Mamba (7B)](https://falconllm.tii.ae/tii-releases-first-sslm-with-falcon-mamba-7b.html)\n\n## Inference frameworks\n\n- vLLM\n\n- Nvidia's TensorRT-LLM\n\n## Hardware\n\n- Nvidia GPUs\n\n- [AMD GPUs](https://rocm.blogs.amd.com/artificial-intelligence/mamba/README.html)\n\n- [AWS Trainium 2](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/tutorials/fused_mamba.html)\n\n\n"
  }
]