Repository: state-spaces/mamba Branch: main Commit: a76afbd7dfdc Files: 94 Total size: 1.2 MB Directory structure: gitextract_vi30yky3/ ├── .github/ │ ├── scripts/ │ │ ├── build.sh │ │ ├── check_for_ngc_images.sh │ │ └── test.sh │ └── workflows/ │ ├── _build.yml │ ├── _build_in_container.yml │ ├── build.yml │ ├── build_in_container.yml │ └── publish.yaml ├── .gitignore ├── .gitmodules ├── AUTHORS ├── LICENSE ├── MANIFEST.in ├── README.md ├── benchmarks/ │ └── benchmark_generation_mamba_simple.py ├── csrc/ │ └── selective_scan/ │ ├── reverse_scan.cuh │ ├── selective_scan.cpp │ ├── selective_scan.h │ ├── selective_scan_bwd_bf16_complex.cu │ ├── selective_scan_bwd_bf16_real.cu │ ├── selective_scan_bwd_fp16_complex.cu │ ├── selective_scan_bwd_fp16_real.cu │ ├── selective_scan_bwd_fp32_complex.cu │ ├── selective_scan_bwd_fp32_real.cu │ ├── selective_scan_bwd_kernel.cuh │ ├── selective_scan_common.h │ ├── selective_scan_fwd_bf16.cu │ ├── selective_scan_fwd_fp16.cu │ ├── selective_scan_fwd_fp32.cu │ ├── selective_scan_fwd_kernel.cuh │ ├── static_switch.h │ └── uninitialized_copy.cuh ├── evals/ │ └── lm_harness_eval.py ├── mamba_ssm/ │ ├── __init__.py │ ├── distributed/ │ │ ├── __init__.py │ │ ├── distributed_utils.py │ │ └── tensor_parallel.py │ ├── models/ │ │ ├── __init__.py │ │ ├── config_mamba.py │ │ └── mixer_seq_simple.py │ ├── modules/ │ │ ├── __init__.py │ │ ├── block.py │ │ ├── mamba2.py │ │ ├── mamba2_simple.py │ │ ├── mamba3.py │ │ ├── mamba_simple.py │ │ ├── mha.py │ │ ├── mlp.py │ │ └── ssd_minimal.py │ ├── ops/ │ │ ├── __init__.py │ │ ├── cute/ │ │ │ └── mamba3/ │ │ │ └── mamba3_step_fn.py │ │ ├── selective_scan_interface.py │ │ ├── tilelang/ │ │ │ └── mamba3/ │ │ │ ├── mamba3_mimo.py │ │ │ ├── mamba3_mimo_bwd.py │ │ │ └── mamba3_mimo_fwd.py │ │ └── triton/ │ │ ├── __init__.py │ │ ├── angle_cumsum.py │ │ ├── k_activations.py │ │ ├── layer_norm.py │ │ ├── layernorm_gated.py │ │ ├── mamba3/ │ │ │ ├── angle_dt.py │ │ │ ├── mamba3_mimo_rotary_step.py │ │ │ ├── 
mamba3_mimo_utils.py │ │ │ ├── mamba3_siso_bwd.py │ │ │ ├── mamba3_siso_combined.py │ │ │ ├── mamba3_siso_fwd.py │ │ │ ├── mamba3_siso_step.py │ │ │ └── utils.py │ │ ├── selective_state_update.py │ │ ├── softplus.py │ │ ├── ssd_bmm.py │ │ ├── ssd_chunk_scan.py │ │ ├── ssd_chunk_state.py │ │ ├── ssd_combined.py │ │ └── ssd_state_passing.py │ └── utils/ │ ├── __init__.py │ ├── determinism.py │ ├── generation.py │ ├── hf.py │ └── torch.py ├── pyproject.toml ├── rocm_patch/ │ └── rocm6_0.patch ├── setup.py ├── tests/ │ ├── benchmark_determinism_kernels.py │ ├── ops/ │ │ ├── cute/ │ │ │ └── test_mamba3_mimo_step.py │ │ ├── test_selective_scan.py │ │ ├── tilelang/ │ │ │ └── test_mamba3_mimo.py │ │ └── triton/ │ │ ├── test_layernorm_gated.py │ │ ├── test_mamba3_siso.py │ │ ├── test_selective_state_update.py │ │ └── test_ssd.py │ ├── test_determinism.py │ └── test_generation.py └── usage.md ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/scripts/build.sh ================================================ #!/bin/bash set -eoxu pipefail # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 # However this still fails so I am using a newer version of setuptools pip install setuptools==68.0.0 pip install ninja packaging wheel export PATH=/usr/local/cuda/bin:/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH # Limit MAX_JOBS otherwise the github runner goes OOM export MAX_JOBS=2 export MAMBA_FORCE_BUILD="TRUE" export MAMBA_FORCE_CXX11_ABI=$CXX11_ABI # 5h timeout since GH allows max 6h and we want some buffer EXIT_CODE=0 timeout 5h python setup.py bdist_wheel --dist-dir=dist 
|| EXIT_CODE=$? if [ $EXIT_CODE -eq 0 ]; then tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi$CXX11_ABI wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} echo "wheel_name=${wheel_name}" >> $GITHUB_ENV fi echo $EXIT_CODE ================================================ FILE: .github/scripts/check_for_ngc_images.sh ================================================ #!/bin/bash # Configuration BASE_IMAGE="nvcr.io/nvidia/pytorch" TAG_SUFFIX="-py3" MONTHS_TO_CHECK=7 # Check current month and previous 6 months (total 7) # Initialize an array to store existing tags EXISTING_TAGS=() echo "Checking for existence of the last ${MONTHS_TO_CHECK} NGC PyTorch images: ${BASE_IMAGE}:YY.MM${TAG_SUFFIX}" echo "---------------------------------------------------------------------" # Loop through the last N months for i in $(seq 0 $((MONTHS_TO_CHECK - 1))); do # Calculate Year and Month for the tag CURRENT_YEAR=$(date +%Y) CURRENT_MONTH=$(date +%m) # Calculate target month and year TARGET_DATE=$(date -d "$CURRENT_YEAR-$CURRENT_MONTH-01 -$i months" +%y.%m) # Construct the full image tag and the tag-only string IMAGE_TAG="${TARGET_DATE}${TAG_SUFFIX}" FULL_IMAGE="${BASE_IMAGE}:${IMAGE_TAG}" echo "Checking: ${FULL_IMAGE}" # Use 'docker manifest inspect' to check for image existence without pulling. if docker manifest inspect "${FULL_IMAGE}" > /dev/null 2>&1; then echo "✅ EXISTS: Found." # Add the full image reference (registry/repository:tag) to the array EXISTING_TAGS+=("nvcr.io/nvidia/pytorch:${IMAGE_TAG}") else echo "❌ MISSING: Not found." fi done echo "---------------------------------------------------------------------" ## JSON Output Generation # This uses the collected array to build a JSON string. # 1. Convert the shell array to a newline-separated string. TAGS_NL_SEP=$(printf "%s\n" "${EXISTING_TAGS[@]}") # 2. Use jq to read the newline-separated list and format it into a JSON array. # .
| split("\n") | .[:-1] reads the input, splits it by newline, and removes the trailing empty element. if command -v jq &> /dev/null; then JSON_STRING=$(echo -e "${TAGS_NL_SEP}" | jq -R -s 'split("\n") | .[:-1]') echo "Generated JSON String of Existing Tags:" echo "${JSON_STRING}" # Optional: Save the JSON string to a variable for further use # echo "JSON_STRING is now available in the shell if you source this script." else echo "WARNING: 'jq' is not installed. Cannot format output as JSON." echo "Found Tags: ${EXISTING_TAGS[*]}" fi echo "---" echo "Check complete." echo "${JSON_STRING}" > ngc_images.json ================================================ FILE: .github/scripts/test.sh ================================================ #!/bin/bash set -exou pipefail pip install dist/*.whl python -c "import mamba_ssm; print(mamba_ssm.__version__)" ================================================ FILE: .github/workflows/_build.yml ================================================ name: ~Build wheel template on: workflow_call: inputs: runs-on: description: "The runner to use for the build" required: true type: string python-version: description: "The Python version to use for the build" required: true type: string cuda-version: description: "The CUDA version to use for the build" required: true type: string torch-version: description: "The PyTorch version to use for the build" required: true type: string cxx11_abi: description: "The C++11 ABI to use for the build" required: true type: string upload-to-release: description: "Whether to upload the built wheel to the release" required: false type: boolean default: false release-version: description: "Release version tag to checkout and upload to" required: false type: string defaults: run: shell: bash -x -e -u -o pipefail {0} jobs: build-wheel: runs-on: ${{ inputs.runs-on }} name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}) steps: - name: Checkout uses:
actions/checkout@v4 with: ref: ${{ inputs.release-version }} submodules: recursive - name: Checkout build scripts uses: actions/checkout@v4 with: path: build-scripts/ - name: Set up Python uses: actions/setup-python@v5 with: python-version: ${{ inputs.python-version }} - name: Set CUDA and PyTorch versions run: | echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - name: Free up disk space if: ${{ runner.os == 'Linux' }} # https://github.com/easimon/maximize-build-space/blob/master/action.yml # https://github.com/easimon/maximize-build-space/tree/test-report run: | sudo rm -rf /usr/share/dotnet sudo rm -rf /opt/ghc sudo rm -rf /opt/hostedtoolcache/CodeQL - name: Set up swap space if: runner.os == 'Linux' uses: pierotofy/set-swap-space@v1.0 with: swap-size-gb: 10 - name: Install CUDA ${{ inputs.cuda-version }} if: ${{ inputs.cuda-version != 'cpu' }} uses: Jimver/cuda-toolkit@v0.2.30 id: cuda-toolkit with: cuda: ${{ inputs.cuda-version }} linux-local-args: '["--toolkit"]' # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }} method: "network" sub-packages: '["nvcc"]' - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }} run: | pip install --upgrade pip # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools pip install setuptools==68.0.0 # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable 
pip install typing-extensions==4.12.2 # Pick the highest available PyTorch wheel CUDA version that doesn't exceed system CUDA export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ available = { \ '2.6': [118, 124, 126], \ '2.7': [118, 126, 128], \ '2.8': [126, 128, 129], \ '2.9': [126, 128, 130], \ '2.10': [126, 128, 130], \ }[env['MATRIX_TORCH_VERSION']]; \ sys_cuda = int(env['MATRIX_CUDA_VERSION']); \ print(max(v for v in available if v <= sys_cuda))" \ ) if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} # Can't use --no-deps because we need cudnn etc. # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 pip install jinja2 pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl else pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} fi nvcc --version python --version python -c "import torch; print('PyTorch:', torch.__version__)" python -c "import torch; print('CUDA:', torch.version.cuda)" python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" shell: bash - name: Build wheel id: build_wheel env: CXX11_ABI: ${{ inputs.cxx11_abi }} MATRIX_TORCH_VERSION: ${{ env.MATRIX_TORCH_VERSION}} WHEEL_CUDA_VERSION: ${{ env.WHEEL_CUDA_VERSION }} MATRIX_PYTHON_VERSION: ${{ env.MATRIX_PYTHON_VERSION }} run: | EXIT_CODE=$(bash build-scripts/.github/scripts/build.sh | tail -n 1) # Store exit code in GitHub env for later steps echo "build_exit_code=$EXIT_CODE" | tee 
-a "$GITHUB_OUTPUT" exit $EXIT_CODE - name: Log Built Wheels run: | ls dist - name: Get Release with tag id: get_current_release uses: joutvhu/get-release@v1 with: tag_name: ${{ inputs.release-version }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Upload Release Asset id: upload_release_asset if: inputs.upload-to-release uses: actions/upload-release-asset@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: upload_url: ${{ steps.get_current_release.outputs.upload_url }} asset_path: ./dist/${{env.wheel_name}} asset_name: ${{env.wheel_name}} asset_content_type: application/* ================================================ FILE: .github/workflows/_build_in_container.yml ================================================ name: ~Build wheel template on: workflow_call: inputs: runs-on: description: "The runner to use for the build" required: true type: string container-image: description: "Container image" required: true type: string upload-to-release: description: "Whether to upload the built wheel to the release" required: false type: boolean default: false release-version: description: "Release version tag to checkout and upload to" required: false type: string defaults: run: shell: bash -x -e -u -o pipefail {0} jobs: build-wheel: runs-on: ${{ inputs.runs-on }} name: Build wheel (${{ inputs.container-image }}) steps: - name: Move /var/lib/containerd/ run: | mkdir -p "${GITHUB_WORKSPACE}/docker/containerd" sudo mv /var/lib/containerd/ "${GITHUB_WORKSPACE}/docker/containerd" - name: Move /var/lib/docker/ run: | mkdir -p "${GITHUB_WORKSPACE}/docker/docker" sudo mv /var/lib/docker/ "${GITHUB_WORKSPACE}/docker/docker" - name: Maximize build space uses: easimon/maximize-build-space@master with: root-reserve-mb: 5120 temp-reserve-mb: 32 swap-size-mb: 10240 remove-dotnet: "true" remove-android: "true" remove-haskell: "true" remove-codeql: "true" build-mount-path: "/var/lib/" - name: Restore /var/lib/containerd/ run: sudo sh -c "mv ${GITHUB_WORKSPACE}/docker/containerd/* /var/lib/containerd" -
name: Restore /var/lib/docker/ run: sudo sh -c "mv ${GITHUB_WORKSPACE}/docker/docker/* /var/lib/docker" - name: Checkout source uses: actions/checkout@v4 with: ref: ${{ inputs.release-version }} submodules: recursive - name: Checkout build scripts uses: actions/checkout@v4 with: path: build-scripts/ - name: Build run: | echo "Free space:" df -h - name: Pull the container run: docker pull ${{ inputs.container-image }} - name: Set CUDA and PyTorch versions run: | cat <<'EOF' >> script.sh #!/bin/bash set -eoxu pipefail echo "MATRIX_CUDA_VERSION=$(echo $CUDA_VERSION | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV echo "MATRIX_TORCH_VERSION=$NVIDIA_PYTORCH_VERSION" >> $GITHUB_ENV echo "WHEEL_CUDA_VERSION=$(echo $CUDA_VERSION | awk -F \. {'print $1'})" >> $GITHUB_ENV echo "MATRIX_PYTHON_VERSION=$(python -c "import sys; print('{}.{}'.format(sys.version_info[0], sys.version_info[1]))" | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV echo "CXX11_ABI=$(python -c 'import torch; print(str(torch._C._GLIBCXX_USE_CXX11_ABI).upper())')" >> $GITHUB_ENV cat $GITHUB_ENV EOF docker run \ --rm \ --shm-size=64g \ --workdir /workspace \ --volume $(pwd):/workspace \ --volume $GITHUB_ENV:$GITHUB_ENV \ -e GITHUB_ENV=$GITHUB_ENV \ ${{ inputs.container-image }} bash /workspace/script.sh - name: Build wheel id: build_wheel env: CXX11_ABI: ${{ env.CXX11_ABI }} MATRIX_TORCH_VERSION: ${{ env.MATRIX_TORCH_VERSION}} WHEEL_CUDA_VERSION: ${{ env.WHEEL_CUDA_VERSION }} MATRIX_PYTHON_VERSION: ${{ env.MATRIX_PYTHON_VERSION }} run: | EXIT_CODE=$(docker run \ --rm \ --shm-size=64g \ --workdir /workspace \ --volume $(pwd):/workspace \ --volume $GITHUB_ENV:$GITHUB_ENV \ -e PIP_CONSTRAINT= \ -e GITHUB_ENV=$GITHUB_ENV \ -e CXX11_ABI=$CXX11_ABI \ -e MATRIX_TORCH_VERSION=$MATRIX_TORCH_VERSION \ -e WHEEL_CUDA_VERSION=$WHEEL_CUDA_VERSION \ -e MATRIX_PYTHON_VERSION=$MATRIX_PYTHON_VERSION \ ${{ inputs.container-image }} bash /workspace/build-scripts/.github/scripts/build.sh | tail -n 1) - name: Test wheels run: | docker 
run \ --rm \ --shm-size=64g \ --workdir /workspace \ --volume $(pwd):/workspace \ --volume $GITHUB_ENV:$GITHUB_ENV \ -e GITHUB_ENV=$GITHUB_ENV \ ${{ inputs.container-image }} bash /workspace/build-scripts/.github/scripts/test.sh - name: Log Built Wheels run: | ls dist - name: Get Release with tag id: get_current_release uses: joutvhu/get-release@v1 with: tag_name: ${{ inputs.release-version }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Upload Release Asset id: upload_release_asset if: inputs.upload-to-release uses: actions/upload-release-asset@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: upload_url: ${{ steps.get_current_release.outputs.upload_url }} asset_path: ./dist/${{env.wheel_name}} asset_name: ${{env.wheel_name}} asset_content_type: application/* ================================================ FILE: .github/workflows/build.yml ================================================ name: Build wheels on: workflow_dispatch: inputs: runs-on: description: "The runner to use for the build" required: true type: string default: ubuntu-22.04 python-version: description: "The Python version to use for the build" required: true type: string cuda-version: description: "The CUDA version to use for the build" required: true type: string torch-version: description: "The PyTorch version to use for the build" required: true type: string cxx11_abi: description: "Enable torch flag C++11 ABI (TRUE/FALSE)" required: true type: string upload-to-release: description: "Upload wheel to this release" required: false type: boolean default: false release-version: description: "Upload wheel to this release" required: false type: string jobs: build-wheels: uses: ./.github/workflows/_build.yml with: runs-on: ${{ inputs.runs-on || 'ubuntu-22.04' }} python-version: ${{ inputs.python-version || '3.12' }} cuda-version: ${{ inputs.cuda-version || '12.9.1' }} torch-version: ${{ inputs.torch-version || '2.10.0' }} cxx11_abi: ${{ inputs.cxx11_abi || 'TRUE' }} upload-to-release: 
${{ inputs.upload-to-release || false }} release-version: ${{ inputs.release-version || 'v2.2.6.post3' }} ================================================ FILE: .github/workflows/build_in_container.yml ================================================ name: Build wheels in a container on: workflow_dispatch: inputs: runs-on: description: "The runner to use for the build" required: true type: string default: ubuntu-22.04 container-image: description: "Container image" required: true type: string upload-to-release: description: "Upload wheel to this release" required: false type: boolean default: false release-version: description: "Release version tag to checkout and upload to" required: false type: string push: tags-ignore: - v* jobs: get_version: runs-on: ubuntu-latest outputs: version: ${{ steps.get_version.outputs.version }} steps: - name: Get version from input or git id: get_version run: | if [ -n "${{ inputs.release-version }}" ]; then echo "version=${{ inputs.release-version }}" >> $GITHUB_OUTPUT else # Get the latest tag from the repo git clone --filter=blob:none --no-checkout $GITHUB_SERVER_URL/$GITHUB_REPOSITORY.git repo cd repo echo "version=$(git describe --tags --abbrev=0)" >> $GITHUB_OUTPUT fi shell: bash check_for_ngc_images: runs-on: ubuntu-latest outputs: images: ${{ steps.check_for_ngc_images.outputs.IMAGES }} steps: - name: Checkout repository uses: actions/checkout@v4 - name: Check for NGC PyTorch images id: check_for_ngc_images run: | bash ./.github/scripts/check_for_ngc_images.sh echo "IMAGES=$(cat ngc_images.json| jq -cr)" >> $GITHUB_OUTPUT build-wheels: needs: [get_version, check_for_ngc_images] uses: ./.github/workflows/_build_in_container.yml strategy: fail-fast: false matrix: container-image: ${{ fromJson(needs.check_for_ngc_images.outputs.images) }} with: runs-on: ${{ inputs.runs-on || 'ubuntu-22.04' }} container-image: ${{ matrix.container-image }} upload-to-release: ${{ inputs.upload-to-release || false }} release-version: ${{ 
needs.get_version.outputs.version }} ================================================ FILE: .github/workflows/publish.yaml ================================================ # This workflow will: # - Create a new GitHub release # - Build wheels for supported architectures # - Deploy the wheels to the GitHub release # - Release the static code to PyPI # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries name: Build wheels and deploy on: push: tags: - v* jobs: setup_release: name: Create Release runs-on: ubuntu-latest outputs: release-version: ${{ steps.extract_branch.outputs.branch }} steps: - name: Get the tag version id: extract_branch run: echo "branch=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT shell: bash - name: Create Release env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: gh release create ${{ steps.extract_branch.outputs.branch }} --repo $GITHUB_REPOSITORY --title ${{ steps.extract_branch.outputs.branch }} --generate-notes shell: bash build_wheels: name: Build Wheel needs: setup_release strategy: fail-fast: false matrix: # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-22.04, ubuntu-22.04-arm] python-version: ["3.10", "3.11", "3.12", "3.13"] torch-version: ["2.6.0", "2.7.1", "2.8.0", "2.9.1", "2.10.0"] cuda-version: ["11.8.0", "12.9.1", "13.0.1"] # We need separate wheels that either use the C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # PyTorch pip wheels before 2.7 don't use it, but nvcr images have PyTorch compiled with the C++11 ABI. # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) # when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ["FALSE", "TRUE"] exclude: # CUDA 11.8 is not supported by PyTorch 2.8+ - torch-version: "2.8.0" cuda-version: "11.8.0" - torch-version: "2.9.1" cuda-version: "11.8.0" - torch-version: "2.10.0" cuda-version: "11.8.0" # CUDA 13.0 is only supported by PyTorch 2.9+ - torch-version: "2.6.0" cuda-version: "13.0.1" - torch-version: "2.7.1" cuda-version: "13.0.1" - torch-version: "2.8.0" cuda-version: "13.0.1" # No aarch64 PyTorch wheels for 2.6.0, or 2.7.1+cu118 - torch-version: "2.6.0" os: ubuntu-22.04-arm - torch-version: "2.7.1" cuda-version: "11.8.0" os: ubuntu-22.04-arm # PyTorch 2.7+ pip wheels use CXX11_ABI=1 by default, no need for FALSE - torch-version: "2.7.1" cxx11_abi: "FALSE" - torch-version: "2.8.0" cxx11_abi: "FALSE" - torch-version: "2.9.1" cxx11_abi: "FALSE" - torch-version: "2.10.0" cxx11_abi: "FALSE" uses: ./.github/workflows/_build.yml with: runs-on: ${{ matrix.os }} python-version: ${{ matrix.python-version }} cuda-version: ${{ matrix.cuda-version }} torch-version: ${{ matrix.torch-version }} cxx11_abi: ${{ matrix.cxx11_abi }} release-version: ${{ needs.setup_release.outputs.release-version }} upload-to-release: true check_for_ngc_images: runs-on: ubuntu-latest outputs: images: ${{ steps.check_for_ngc_images.outputs.IMAGES }} steps: - name: Checkout repository uses: actions/checkout@v4 - name: Check for NGC PyTorch images id: check_for_ngc_images run: | bash ./.github/scripts/check_for_ngc_images.sh echo "IMAGES=$(cat ngc_images.json| jq -cr)" | tee -a $GITHUB_OUTPUT build_ngc_wheels: name: Build Wheel for NGC PyTorch needs: [setup_release, check_for_ngc_images] strategy: fail-fast: false matrix: os: [ubuntu-22.04, ubuntu-22.04-arm] container-image: ${{ fromJson(needs.check_for_ngc_images.outputs.images) }} uses: ./.github/workflows/_build_in_container.yml with: runs-on: ${{ matrix.os }} container-image: ${{ matrix.container-image }} release-version: ${{ needs.setup_release.outputs.release-version }} upload-to-release: true 
publish_package: name: Publish package needs: [build_wheels] if: always() && !cancelled() runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install dependencies run: | pip install ninja packaging setuptools wheel twine # We don't want to download anything CUDA-related here pip install torch --index-url https://download.pytorch.org/whl/cpu - name: Build core package env: MAMBA_SKIP_CUDA_BUILD: "TRUE" run: | python setup.py sdist --dist-dir=dist - name: Deploy env: TWINE_USERNAME: "__token__" TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} run: | python -m twine upload dist/* ================================================ FILE: .gitignore ================================================ *__pycache__/ *.egg-info/ build/ **.so *.hip *_hip.* ================================================ FILE: .gitmodules ================================================ [submodule "3rdparty/lm-evaluation-harness"] path = 3rdparty/lm-evaluation-harness url = https://github.com/EleutherAI/lm-evaluation-harness/ ================================================ FILE: AUTHORS ================================================ Tri Dao, tri@tridao.me Albert Gu, agu@andrew.cmu.edu ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright 2023 Tri Dao, Albert Gu Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: MANIFEST.in ================================================ recursive-include csrc * recursive-include csrc * README.md ================================================ FILE: README.md ================================================ # Mamba ![Mamba](assets/selection.png "Selective State Space") > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\ > Albert Gu*, Tri Dao*\ > Paper: https://arxiv.org/abs/2312.00752 ![Mamba-2](assets/ssd_algorithm.png "State Space Dual Model") > **Transformers are SSMs: Generalized Models and Efficient Algorithms**\ > **Through Structured State Space Duality**\ > Tri Dao*, Albert Gu*\ > Paper: https://arxiv.org/abs/2405.21060 ![Mamba-3](assets/mamba3.png "Inference-first State Space Model") > **Mamba-3: Improved Sequence Modeling using State Space Principles**\ > Aakash Lahoti*, Kevin Y. Li*, Berlin Chen*, Caitlin Wang*, Aviv Bick, J. Zico Kolter, Tri Dao†, Albert Gu†\ > Paper: https://arxiv.org/abs/2603.15569 ## About Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4), with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention). ## Installation Install PyTorch first, then: - [Optional] `pip install "causal-conv1d>=1.4.0" --no-build-isolation`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. - `pip install mamba-ssm --no-build-isolation`: the core Mamba package. - `pip install "mamba-ssm[causal-conv1d]" --no-build-isolation`: to install the core Mamba package together with causal-conv1d. `--no-build-isolation` is required so that pip uses your existing CUDA-enabled PyTorch instead of installing torch-cpu in an isolated build environment. Other requirements: - Linux - NVIDIA GPU - PyTorch 1.12+ - CUDA 11.6+ For AMD cards, see additional prerequisites below. ## Usage We expose several levels of interface with the Mamba model. ### Selective SSM Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2). Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py). ### Mamba Block The main module of this repository is the Mamba architecture block wrapping the selective SSM. Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py). Usage: ``` python import torch from mamba_ssm import Mamba batch, length, dim = 2, 64, 16 x = torch.randn(batch, length, dim).to("cuda") model = Mamba( # This module uses roughly 3 * expand * d_model^2 parameters d_model=dim, # Model dimension d_model d_state=16, # SSM state expansion factor d_conv=4, # Local convolution width expand=2, # Block expansion factor ).to("cuda") y = model(x) assert y.shape == x.shape ``` ### Mamba-2 The Mamba-2 block is implemented at [modules/mamba2.py](mamba_ssm/modules/mamba2.py).
A simpler version is at [modules/mamba2_simple.py](mamba_ssm/modules/mamba2_simple.py) The usage is similar to Mamba(-1): ``` python from mamba_ssm import Mamba2 model = Mamba2( # This module uses roughly 3 * expand * d_model^2 parameters d_model=dim, # Model dimension d_model d_state=64, # SSM state expansion factor, typically 64 or 128 d_conv=4, # Local convolution width expand=2, # Block expansion factor ).to("cuda") y = model(x) assert y.shape == x.shape ``` #### SSD A minimal version of the inner SSD module (Listing 1 from the Mamba-2 paper) with conversion between "discrete" and "continuous" SSM versions is at [modules/ssd_minimal.py](mamba_ssm/modules/ssd_minimal.py). ### Mamba Language Model Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head. Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py). This is an example of how to integrate Mamba into an end-to-end neural network. This example is used in the generation scripts below. ## Pretrained Models Pretrained models are uploaded to [Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`, `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, `mamba2-130m`, `mamba2-370m`, `mamba2-780m`, `mamba2-1.3b`, `mamba2-2.7b`, `transformerpp-2.7b`, `mamba2attn-2.7b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj` (trained on 600B tokens on the SlimPajama dataset). The models will be autodownloaded by the generation script below. These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models: | Parameters | Layers | Model dim. 
| |------------|--------|------------| | 130M | 24 | 768 | | 370M | 48 | 1024 | | 790M | 48 | 1536 | | 1.4B | 48 | 2048 | | 2.8B | 64 | 2560 | (The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.) Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.). Performance is expected to be comparable to or better than other architectures trained on similar data, but not to match larger or fine-tuned models. ## Evaluations To run zero-shot evaluations of models (corresponding to Table 3 of the paper), we use the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) library. 1. Install `lm-evaluation-harness` by `pip install lm-eval==0.4.2`. 2. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo): ``` sh lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256 python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64 ``` To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blog posts: ``` sh lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256 lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256 ``` To run evaluations on Mamba-2 models, simply replace the model names: ``` sh lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks
lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256 lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256 lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256 ``` Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process. ## Inference The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py) 1. autoloads a model from the Hugging Face Hub, 2. generates completions of a user-specified prompt, 3. benchmarks the inference speed of this generation. Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature. ### Examples To test generation latency (e.g. batch size = 1) with different sampling strategies: ``` sh python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2 ``` To test generation throughput with random prompts (e.g. 
large batch size): ``` sh python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64 python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64 ``` With Mamba-2, you just need to change the model name: ``` sh python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2 ``` ## Troubleshooting ### Precision Our models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary. On the other hand, other frameworks like DeepSpeed store parameters in float16 and upcast when necessary (e.g. for optimizer accumulation). We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities, as a first step please try a framework storing parameters in fp32 (such as AMP). ### Initialization Some parts of the model have initializations inherited from prior work on S4 models. For [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\Delta$ parameter is given a targeted range by initializing the bias of its linear projection. However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in `nn.Linear` modules to zero). If this is the case, you may have to add custom logic (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initializing in our trainer, but would be a no-op in any other framework) that is specific to the training framework.
## Additional Prerequisites for AMD cards ### Patching ROCm If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards. 1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation. 2. Apply the Patch. Run with `sudo` in case you encounter permission issues. ```bash patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch ``` ## Citation If you use this codebase, or otherwise find our work valuable, please cite Mamba: ``` @article{mamba, title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces}, author={Gu, Albert and Dao, Tri}, journal={arXiv preprint arXiv:2312.00752}, year={2023} } @inproceedings{mamba2, title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality}, author={Dao, Tri and Gu, Albert}, booktitle={International Conference on Machine Learning (ICML)}, year={2024} } @misc{lahoti2026mamba3improvedsequencemodeling, title={Mamba-3: Improved Sequence Modeling using State Space Principles}, author={Aakash Lahoti and Kevin Y. Li and Berlin Chen and Caitlin Wang and Aviv Bick and J. Zico Kolter and Tri Dao and Albert Gu}, year={2026}, eprint={2603.15569}, archivePrefix={arXiv}, primaryClass={cs.LG}, url={https://arxiv.org/abs/2603.15569}, } ``` ================================================ FILE: benchmarks/benchmark_generation_mamba_simple.py ================================================ # Copyright (c) 2023, Tri Dao, Albert Gu. 
import argparse import time import json import torch import torch.nn.functional as F from einops import rearrange from transformers import AutoTokenizer, AutoModelForCausalLM from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel parser = argparse.ArgumentParser(description="Generation benchmarking") parser.add_argument("--model-name", type=str, default="state-spaces/mamba-130m") parser.add_argument("--prompt", type=str, default=None) parser.add_argument("--promptlen", type=int, default=100) parser.add_argument("--genlen", type=int, default=100) parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--topk", type=int, default=1) parser.add_argument("--topp", type=float, default=1.0) parser.add_argument("--minp", type=float, default=0.0) parser.add_argument("--repetition-penalty", type=float, default=1.0) parser.add_argument("--batch", type=int, default=1) args = parser.parse_args() repeats = 3 device = "cuda" dtype = torch.float16 print(f"Loading model {args.model_name}") is_mamba = args.model_name.startswith("state-spaces/mamba") or args.model_name.startswith("state-spaces/transformerpp") if is_mamba: tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype) else: tokenizer = AutoTokenizer.from_pretrained(args.model_name) model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map={"": device}, torch_dtype=dtype) model.eval() print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") torch.random.manual_seed(0) if args.prompt is None: input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda") attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda") else: tokens = tokenizer(args.prompt, return_tensors="pt") input_ids = tokens.input_ids.to(device=device) attn_mask = tokens.attention_mask.to(device=device) max_length = 
input_ids.shape[1] + args.genlen if is_mamba: fn = lambda: model.generate( input_ids=input_ids, max_length=max_length, cg=True, return_dict_in_generate=True, output_scores=True, enable_timing=False, temperature=args.temperature, top_k=args.topk, top_p=args.topp, min_p=args.minp, repetition_penalty=args.repetition_penalty, ) else: fn = lambda: model.generate( input_ids=input_ids, attention_mask=attn_mask, max_length=max_length, return_dict_in_generate=True, pad_token_id=tokenizer.eos_token_id, do_sample=True, temperature=args.temperature, top_k=args.topk, top_p=args.topp, repetition_penalty=args.repetition_penalty, ) out = fn() if args.prompt is not None: print(tokenizer.batch_decode(out.sequences.tolist())) torch.cuda.synchronize() start = time.time() for _ in range(repeats): fn() torch.cuda.synchronize() print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}") print(f"{args.model_name} prompt processing + decoding time: {(time.time() - start) / repeats * 1000:.0f}ms") ================================================ FILE: csrc/selective_scan/reverse_scan.cuh ================================================ /****************************************************************************** * Copyright (c) 2023, Tri Dao. ******************************************************************************/ #pragma once #ifndef USE_ROCM #include #include #include #include // #include #else #include namespace cub = hipcub; #endif #include "uninitialized_copy.cuh" /** * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned. 
*/
template <
    int LENGTH,
    typename T,
    typename ReductionOp>
__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
    static_assert(LENGTH > 0);
    // Begin at the last element and fold towards index 0; the running value
    // is always passed as the first operand of reduction_op.
    T acc = input[LENGTH - 1];
    #pragma unroll
    for (int idx = LENGTH - 2; idx >= 0; --idx) {
        acc = reduction_op(acc, input[idx]);
    }
    return acc;
}

/**
 * Sequential inclusive postfix (reverse) scan over the statically-sized \p input array.
 *
 * Walks from index LENGTH-1 down to 0 with a running value seeded by \p postfix;
 * output[idx] receives the running value after input[idx] has been folded in.
 * Returns the running value after index 0, i.e. the aggregate including \p postfix.
 */
template <
    int LENGTH,
    typename T,
    typename ScanOp>
__device__ __forceinline__ T ThreadReverseScanInclusive(
    const T (&input)[LENGTH],
    T (&output)[LENGTH],
    ScanOp scan_op,
    const T postfix)
{
    T running = postfix;
    #pragma unroll
    for (int idx = LENGTH - 1; idx >= 0; --idx) {
        running = scan_op(running, input[idx]);
        output[idx] = running;
    }
    return running;
}

/**
 * Sequential exclusive postfix (reverse) scan over the statically-sized \p input array.
 *
 * output[idx] receives the running value *before* input[idx] is folded in, so
 * output[LENGTH-1] == \p postfix. Returns the aggregate of all inputs seeded with
 * \p postfix.
 */
template <
    int LENGTH,
    typename T,
    typename ScanOp>
__device__ __forceinline__ T ThreadReverseScanExclusive(
    const T (&input)[LENGTH],
    T (&output)[LENGTH],
    ScanOp scan_op,
    const T postfix)
{
    // Careful: output may be aliased to input (e.g. called with the same array for
    // both), so input[idx] must be consumed before output[idx] is overwritten.
    T carry = postfix;
    T folded;
    #pragma unroll
    for (int idx = LENGTH - 1; idx >= 0; --idx) {
        folded = scan_op(carry, input[idx]);
        output[idx] = carry;
        carry = folded;
    }
    return folded;
}

/**
 * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
*
 * Each thread contributes one item; values are combined from higher lanes toward lane 0.
 *
 * LOGICAL_WARP_THREADS must be a power-of-two
 */
template <
    typename T,                 ///< Data type being scanned
    int LOGICAL_WARP_THREADS    ///< Number of threads per logical warp
    >
struct WarpReverseScan {
    //---------------------------------------------------------------------
    // Constants and type definitions
    //---------------------------------------------------------------------

    // Physical warp / wavefront width as a compile-time constant.
    // In hipcub, warp_threads is defined as HIPCUB_WARP_THREADS ::rocprim::warp_size()
    // While in cub, it's defined as a macro that takes a redundant unused argument.
    // NOTE: WARP_THREADS is a preprocessor macro, so it remains defined after this struct.
    #ifndef USE_ROCM
        #define WARP_THREADS CUB_WARP_THREADS(0)
    #else
        // ROCm 7.0+: HIPCUB_WARP_THREADS (rocprim::warp_size()) is no longer constexpr.
        // We need a compile-time constant for IS_ARCH_WARP below.
        // See: https://rocm.docs.amd.com/en/latest/about/release-notes.html
        #if defined(__AMDGCN_WAVEFRONT_SIZE)
            // Deprecated but still available and constexpr in ROCm 7.x
            #define WARP_THREADS __AMDGCN_WAVEFRONT_SIZE
        #elif defined(__gfx942__) || defined(__gfx941__) || defined(__gfx940__)
            // AMD Instinct MI300 series (CDNA3) - 64-wide wavefronts
            #define WARP_THREADS 64
        #elif defined(__gfx90a__)
            // AMD Instinct MI200 series (CDNA2) - 64-wide wavefronts
            #define WARP_THREADS 64
        #elif defined(__gfx908__)
            // AMD Instinct MI100 (CDNA1) - 64-wide wavefronts
            #define WARP_THREADS 64
        #elif defined(__gfx906__) || defined(__gfx900__)
            // AMD Instinct MI50/MI60 (Vega) - 64-wide wavefronts
            #define WARP_THREADS 64
        #elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
            // AMD Radeon RX 7000 series (RDNA3) - 32-wide wavefronts
            #define WARP_THREADS 32
        #elif defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1034__)
            // AMD Radeon RX 6000 series (RDNA2) - 32-wide wavefronts
            #define WARP_THREADS 32
        #elif defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
            // AMD Radeon RX 5000 series (RDNA1) - 32-wide wavefronts
            #define WARP_THREADS 32
        #else
            // Unknown architecture - default to 64 (CDNA/GCN)
            // This may not be optimal for RDNA GPUs
            #pragma message("Warning: Unknown AMD GPU architecture. Defaulting WARP_THREADS to 64. " \
                            "For RDNA GPUs (gfx10xx/gfx11xx), this should be 32.")
            #define WARP_THREADS 64
        #endif
    #endif

    /// Whether the logical warp size and the physical warp size coincide
    static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS);

    /// The number of warp scan steps
    static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
    static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);

    //---------------------------------------------------------------------
    // Thread fields
    //---------------------------------------------------------------------

    /// Lane index in logical warp
    unsigned int lane_id;

    /// Logical warp index within the physical warp (32 threads on CUDA, WARP_THREADS on ROCm)
    unsigned int warp_id;

    /// Member mask of this logical warp within the physical warp
    unsigned int member_mask;

    //---------------------------------------------------------------------
    // Construction
    //---------------------------------------------------------------------

    /// Constructor. Lane and warp ids are derived from threadIdx.x only.
    explicit __device__ __forceinline__
    WarpReverseScan()
    #ifndef USE_ROCM
        : lane_id(threadIdx.x & 0x1f)  // CUDA: 32-thread warps, mask = 31
    #else
        : lane_id(threadIdx.x & (WARP_THREADS - 1))  // ROCm: use actual wavefront size (64 or 32)
    #endif
        , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
        , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
    {
        if (!IS_ARCH_WARP) {
            // Re-express the physical lane as a lane within the logical warp.
            lane_id = lane_id % LOGICAL_WARP_THREADS;
        }
    }

    /// Broadcast \p input from lane \p src_lane to every lane of this logical warp
    __device__ __forceinline__ T Broadcast(
        T               input,              ///< [in] The value to broadcast
        int             src_lane)           ///< [in] Which warp lane is to do the broadcasting
    {
        return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
    }

    /// Inclusive scan: lane i receives the combination of the inputs of lanes
    /// i .. LOGICAL_WARP_THREADS-1 (values from higher lanes are the left operand).
    template <typename ScanOpT>
    __device__ __forceinline__ void InclusiveReverseScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &inclusive_output,  ///< [out] Calling thread's output item. May be aliased with \p input.
        ScanOpT         scan_op)            ///< [in] Binary scan operator
    {
        inclusive_output = input;
        #pragma unroll
        for (int STEP = 0; STEP < STEPS; STEP++) {
            int offset = 1 << STEP;
            // Fetch the running value held `offset` lanes higher.
            T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
                inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
            );
            // Perform scan op if from a valid peer
            inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
                ? inclusive_output : scan_op(temp, inclusive_output);
        }
    }

    /// Exclusive scan
    // Get exclusive from inclusive
    template <typename ScanOpT>
    __device__ __forceinline__ void ExclusiveReverseScan(
        T              input,              ///< [in] Calling thread's input item.
        T              &exclusive_output,  ///< [out] Calling thread's output item. May be aliased with \p input.
        ScanOpT        scan_op,            ///< [in] Binary scan operator
        T              &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
    {
        T inclusive_output;
        InclusiveReverseScan(input, inclusive_output, scan_op);
        // Lane 0's inclusive result covers every lane, i.e. the warp-wide aggregate.
        warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
        // initial value unknown: the last lane has no higher neighbour, so its
        // exclusive_output is not a valid exclusive value and must be supplied by the caller.
        exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
            inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
        );
    }

    /**
     * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last warp-lane is undefined.
     */
    template <typename ScanOpT>
    __device__ __forceinline__ void ReverseScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &inclusive_output,  ///< [out] Calling thread's inclusive-scan output item.
        T               &exclusive_output,  ///< [out] Calling thread's exclusive-scan output item.
        ScanOpT         scan_op)            ///< [in] Binary scan operator
    {
        InclusiveReverseScan(input, inclusive_output, scan_op);
        // initial value unknown
        exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
            inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
        );
    }

};

/**
 * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
*/
template <
    typename T,             ///< Data type being scanned
    int BLOCK_DIM_X,        ///< The thread block length in threads along the X dimension
    bool MEMOIZE=false      ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
    >
struct BlockReverseScan {
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    /// Constants
    /// The thread block size in threads
    static constexpr int BLOCK_THREADS = BLOCK_DIM_X;

    /// Layout type for padded thread block raking grid
    using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
    // Only the unguarded layout is supported for now: Upsweep/ExclusiveDownsweep
    // below read and write all SEGMENT_LENGTH elements of every raking segment
    // without bounds checks, so every segment must be full.
    static_assert(BlockRakingLayout::UNGUARDED);

    /// Number of raking threads
    static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
    /// Number of raking elements per warp synchronous raking thread
    static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
    /// Cooperative work can be entirely warp synchronous
    static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));

    /// WarpReverseScan utility type
    using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;

    /// Shared memory storage layout type
    struct _TempStorage {
        typename BlockRakingLayout::TempStorage raking_grid;     ///< Padded thread block raking grid
    };

    /// Alias wrapper allowing storage to be unioned
    struct TempStorage : cub::Uninitialized<_TempStorage> {};

    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    // Thread fields
    _TempStorage &temp_storage;
    unsigned int linear_tid;
    // This raking thread's segment; filled by Upsweep and reused by
    // ExclusiveDownsweep when MEMOIZE is true.
    T cached_segment[SEGMENT_LENGTH];

    //---------------------------------------------------------------------
    // Utility methods
    //---------------------------------------------------------------------

    /// Performs upsweep raking reduction, returning the aggregate.
    /// Loads this thread's raking segment into cached_segment and folds it
    /// from the last element towards the first.
    template <typename ScanOp>
    __device__ __forceinline__ T Upsweep(ScanOp scan_op) {
        T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
        // Read data into registers
        #pragma unroll
        for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
        T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
        #pragma unroll
        for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
            raking_partial = scan_op(raking_partial, cached_segment[i]);
        }
        return raking_partial;
    }

    /// Performs exclusive downsweep raking scan.
    /// Unless MEMOIZE is set, re-reads the segment from shared memory, then
    /// reverse-exclusive-scans it seeded with \p raking_partial and writes the
    /// result back to shared memory.
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveDownsweep(
        ScanOp scan_op,
        T raking_partial)
    {
        T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
        // Read data back into registers
        if (!MEMOIZE) {
            #pragma unroll
            for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
        }
        ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
        // Write data back to smem
        #pragma unroll
        for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
    }

    //---------------------------------------------------------------------
    // Constructors
    //---------------------------------------------------------------------

    /// Constructor
    __device__ __forceinline__ BlockReverseScan(
        TempStorage &temp_storage)
    :
        temp_storage(temp_storage.Alias()),
        linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
    {}

    /// Computes an exclusive thread block-wide postfix (reverse) scan using the
    /// specified binary \p scan_op functor. Each thread contributes one input
    /// element. The call-back functor \p block_postfix_callback_op is invoked
    /// with the block-wide aggregate by the first warp in the block, and the
    /// value returned by lane0 in that warp is broadcast and used as the "seed"
    /// value that logically postfixes the thread block's scan inputs.
    /// Requires TempStorage to be free for reuse on entry; contains __syncthreads().
    template <
        typename ScanOp,
        typename BlockPostfixCallbackOp>
    __device__ __forceinline__ void ExclusiveReverseScan(
        T                       input,                          ///< [in] Calling thread's input item
        T                       &exclusive_output,              ///< [out] Calling thread's output item (may be aliased to \p input)
        ScanOp                  scan_op,                        ///< [in] Binary scan operator
        BlockPostfixCallbackOp  &block_postfix_callback_op)     ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
    {
        if (WARP_SYNCHRONOUS) {
            // Short-circuit directly to warp-synchronous scan
            T block_aggregate;
            WarpReverseScan warp_scan;
            warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
            // Obtain warp-wide postfix in lane0, then broadcast to other lanes
            T block_postfix = block_postfix_callback_op(block_aggregate);
            block_postfix = warp_scan.Broadcast(block_postfix, 0);
            // Nothing lies after the last thread in a reverse scan, so its
            // exclusive result is the postfix itself.
            exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
        } else {
            // Place thread partial into shared memory raking grid
            T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
            detail::uninitialized_copy(placement_ptr, input);
            __syncthreads();
            // Reduce parallelism down to just raking threads
            if (linear_tid < RAKING_THREADS) {
                WarpReverseScan warp_scan;
                // Raking upsweep reduction across shared partials
                T upsweep_partial = Upsweep(scan_op);
                // Warp-synchronous scan
                T exclusive_partial, block_aggregate;
                warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
                // Obtain block-wide postfix in lane0, then broadcast to other lanes
                T block_postfix = block_postfix_callback_op(block_aggregate);
                block_postfix = warp_scan.Broadcast(block_postfix, 0);
                // Update postfix with warpscan exclusive partial
                T downsweep_postfix = linear_tid == RAKING_THREADS - 1
                    ? block_postfix : scan_op(block_postfix, exclusive_partial);
                // Exclusive raking downsweep scan
                ExclusiveDownsweep(scan_op, downsweep_postfix);
            }
            __syncthreads();
            // Grab thread postfix from shared memory
            exclusive_output = *placement_ptr;
        }
    }

    /**
     * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. The call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs.
     *
     * Implemented as: reverse-reduce each thread's items in registers, run
     * ExclusiveReverseScan over the per-thread partials, then do an inclusive
     * reverse scan in registers seeded with the resulting thread postfix.
     */
    template <
        int             ITEMS_PER_THREAD,
        typename        ScanOp,
        typename        BlockPostfixCallbackOp>
    __device__ __forceinline__ void InclusiveReverseScan(
        T                       (&input)[ITEMS_PER_THREAD],     ///< [in] Calling thread's input items
        T                       (&output)[ITEMS_PER_THREAD],    ///< [out] Calling thread's output items (may be aliased to \p input)
        ScanOp                  scan_op,                        ///< [in] Binary scan functor
        BlockPostfixCallbackOp  &block_postfix_callback_op)     ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
    {
        // Reduce consecutive thread items in registers
        T thread_postfix = ThreadReverseReduce(input, scan_op);
        // Exclusive thread block-scan
        ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
        // Inclusive scan in registers with postfix as seed
        ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
    }

};

================================================
FILE: csrc/selective_scan/selective_scan.cpp
================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <cstring>
#include <vector>

#include "selective_scan.h"

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)                    \
    if (ITYPE == at::ScalarType::Half) {                                            \
        using input_t = at::Half;                                                   \
        __VA_ARGS__();                                                              \
    } else if (ITYPE == at::ScalarType::BFloat16) {                                 \
        using input_t = at::BFloat16;                                               \
        __VA_ARGS__();                                                              \
    } else if (ITYPE == at::ScalarType::Float)  {                                   \
        using input_t = float;                                                      \
        __VA_ARGS__();                                                              \
    } else {                                                                        \
        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
    }

#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...)                     \
    if (WTYPE == at::ScalarType::Half) {                                             \
        using weight_t = at::Half;                                                   \
        __VA_ARGS__();                                                               \
    } else if (WTYPE == at::ScalarType::BFloat16) {                                  \
        using weight_t = at::BFloat16;                                               \
        __VA_ARGS__();                                                               \
    } else if (WTYPE == at::ScalarType::Float)  {                                    \
        using weight_t = float;                                                      \
        __VA_ARGS__();                                                               \
    } else {                                                                         \
        AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
    }

#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...)                           \
    if (WTYPE == at::ScalarType::Float) {                                            \
       using weight_t = float;                                                       \
        __VA_ARGS__();                                                               \
    } else if (WTYPE == at::ScalarType::ComplexFloat) {                              \
        using weight_t = c10::complex<float>;                                        \
        __VA_ARGS__();                                                               \
    } else {                                                                         \
        AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
    }

// Kernel launchers, explicitly instantiated per (input_t, weight_t) in the .cu files.
template<typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);

template <typename input_t, typename weight_t>
void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);

// Fill the forward parameter struct from already-validated tensors.
// Strides are copied in elements, not bytes. D_ptr / delta_bias_ptr / x_ptr may be
// nullptr; z and out_z are only dereferenced when has_z is true.
// Precondition: n_groups > 0 and dim % n_groups == 0 (checked by the callers below).
void set_ssm_params_fwd(SSMParamsBase &params,
                        // sizes
                        const size_t batch,
                        const size_t dim,
                        const size_t seqlen,
                        const size_t dstate,
                        const size_t n_groups,
                        const size_t n_chunks,
                        const bool is_variable_B,
                        const bool is_variable_C,
                        // device pointers
                        const at::Tensor u,
                        const at::Tensor delta,
                        const at::Tensor A,
                        const at::Tensor B,
                        const at::Tensor C,
                        const at::Tensor out,
                        const at::Tensor z,
                        const at::Tensor out_z,
                        void* D_ptr,
                        void* delta_bias_ptr,
                        void* x_ptr,
                        bool has_z,
                        bool delta_softplus) {

    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.dstate = dstate;
    params.n_groups = n_groups;
    params.n_chunks = n_chunks;
    // Channels per B/C group; the bwd kernel maps channel dim_id to group
    // dim_id / dim_ngroups_ratio, which is why dim must be divisible by n_groups.
    params.dim_ngroups_ratio = dim / n_groups;

    params.delta_softplus = delta_softplus;

    params.is_variable_B = is_variable_B;
    params.is_variable_C = is_variable_C;

    // Set the pointers and strides.
    params.u_ptr = u.data_ptr();
    params.delta_ptr = delta.data_ptr();
    params.A_ptr = A.data_ptr();
    params.B_ptr = B.data_ptr();
    params.C_ptr = C.data_ptr();
    params.D_ptr = D_ptr;
    params.delta_bias_ptr = delta_bias_ptr;
    params.out_ptr = out.data_ptr();
    params.x_ptr = x_ptr;
    params.z_ptr = has_z ? z.data_ptr() : nullptr;
    params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
    // All stride are in elements, not bytes.
    params.A_d_stride = A.stride(0);
    params.A_dstate_stride = A.stride(1);
    if (!is_variable_B) {
        params.B_d_stride = B.stride(0);
    } else {
        params.B_batch_stride = B.stride(0);
        params.B_group_stride = B.stride(1);
    }
    params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
    if (!is_variable_C) {
        params.C_d_stride = C.stride(0);
    } else {
        params.C_batch_stride = C.stride(0);
        params.C_group_stride = C.stride(1);
    }
    params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
    params.u_batch_stride = u.stride(0);
    params.u_d_stride = u.stride(1);
    params.delta_batch_stride = delta.stride(0);
    params.delta_d_stride = delta.stride(1);
    if (has_z) {
        params.z_batch_stride = z.stride(0);
        params.z_d_stride = z.stride(1);
        params.out_z_batch_stride = out_z.stride(0);
        params.out_z_d_stride = out_z.stride(1);
    }
    params.out_batch_stride = out.stride(0);
    params.out_d_stride = out.stride(1);
}

// Fill the backward parameter struct: the forward fields (via set_ssm_params_fwd)
// plus gradient pointers/strides. dD_ptr / ddelta_bias_ptr may be nullptr; dz is
// only dereferenced when has_z; out_z_ptr is cleared unless recompute_out_z.
void set_ssm_params_bwd(SSMParamsBwd &params,
                        // sizes
                        const size_t batch,
                        const size_t dim,
                        const size_t seqlen,
                        const size_t dstate,
                        const size_t n_groups,
                        const size_t n_chunks,
                        const bool is_variable_B,
                        const bool is_variable_C,
                        // device pointers
                        const at::Tensor u,
                        const at::Tensor delta,
                        const at::Tensor A,
                        const at::Tensor B,
                        const at::Tensor C,
                        const at::Tensor z,
                        const at::Tensor out,
                        const at::Tensor out_z,
                        void* D_ptr,
                        void* delta_bias_ptr,
                        void* x_ptr,
                        const at::Tensor dout,
                        const at::Tensor du,
                        const at::Tensor ddelta,
                        const at::Tensor dA,
                        const at::Tensor dB,
                        const at::Tensor dC,
                        const at::Tensor dz,
                        void* dD_ptr,
                        void* ddelta_bias_ptr,
                        bool has_z,
                        bool delta_softplus,
                        bool recompute_out_z) {
    // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
    set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
                       u, delta, A, B, C, has_z ? out : dout,
                       has_z ? z : dout,
                       // If not recompute_out_z, pass dout instead of out_z.
                       // This won't be used by the bwd kernel
                       recompute_out_z ? out_z : dout,
                       D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
    if (!recompute_out_z) { params.out_z_ptr = nullptr; }

    // Set the pointers and strides.
    params.dout_ptr = dout.data_ptr();
    params.du_ptr = du.data_ptr();
    params.dA_ptr = dA.data_ptr();
    params.dB_ptr = dB.data_ptr();
    params.dC_ptr = dC.data_ptr();
    params.dD_ptr = dD_ptr;
    params.ddelta_ptr = ddelta.data_ptr();
    params.ddelta_bias_ptr = ddelta_bias_ptr;
    params.dz_ptr = has_z ? dz.data_ptr() : nullptr;
    // All stride are in elements, not bytes.
    params.dout_batch_stride = dout.stride(0);
    params.dout_d_stride = dout.stride(1);
    params.dA_d_stride = dA.stride(0);
    params.dA_dstate_stride = dA.stride(1);
    if (!is_variable_B) {
        params.dB_d_stride = dB.stride(0);
    } else {
        params.dB_batch_stride = dB.stride(0);
        params.dB_group_stride = dB.stride(1);
    }
    params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2);
    if (!is_variable_C) {
        params.dC_d_stride = dC.stride(0);
    } else {
        params.dC_batch_stride = dC.stride(0);
        params.dC_group_stride = dC.stride(1);
    }
    params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2);
    params.du_batch_stride = du.stride(0);
    params.du_d_stride = du.stride(1);
    params.ddelta_batch_stride = ddelta.stride(0);
    params.ddelta_d_stride = ddelta.stride(1);
    if (has_z) {
        params.dz_batch_stride = dz.stride(0);
        params.dz_d_stride = dz.stride(1);
    }
}

// Forward selective scan.
// Returns {out, x} plus out_z when z is given. x has shape
// (batch, dim, n_chunks, 2 * dstate) in weight dtype and is the x_ input of selective_scan_bwd.
std::vector<at::Tensor>
selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
                   const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
                   const c10::optional<at::Tensor> &D_,
                   const c10::optional<at::Tensor> &z_,
                   const c10::optional<at::Tensor> &delta_bias_,
                   bool delta_softplus) {
    auto input_type = u.scalar_type();
    auto weight_type = A.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);

    const bool is_variable_B = B.dim() >= 3;
    const bool is_variable_C = C.dim() >= 3;
    const bool is_complex = weight_type == at::ScalarType::ComplexFloat;

    TORCH_CHECK(delta.scalar_type() == input_type);
    TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
    TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));

    TORCH_CHECK(u.is_cuda());
    TORCH_CHECK(delta.is_cuda());
    TORCH_CHECK(A.is_cuda());
    TORCH_CHECK(B.is_cuda());
    TORCH_CHECK(C.is_cuda());

    TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
    TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);

    const auto sizes = u.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int dstate = A.size(1);
    const int n_groups = is_variable_B ? B.size(1) : 1;

    TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
    // Guards the host-side dim / n_groups and keeps every channel's group index < n_groups.
    TORCH_CHECK(n_groups > 0 && dim % n_groups == 0,
                "selective_scan requires the number of groups to be positive and to divide dim");

    CHECK_SHAPE(u, batch_size, dim, seqlen);
    CHECK_SHAPE(delta, batch_size, dim, seqlen);
    CHECK_SHAPE(A, dim, dstate);
    if (!is_variable_B) {
        CHECK_SHAPE(B, dim, dstate);
    } else {
        CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
        TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
    }
    if (!is_variable_C) {
        CHECK_SHAPE(C, dim, dstate);
    } else {
        CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
        TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
    }

    if (D_.has_value()) {
        auto D = D_.value();
        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(D.is_cuda());
        TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
        CHECK_SHAPE(D, dim);
    }

    if (delta_bias_.has_value()) {
        auto delta_bias = delta_bias_.value();
        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(delta_bias.is_cuda());
        TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
        CHECK_SHAPE(delta_bias, dim);
    }

    at::Tensor z, out_z;
    const bool has_z = z_.has_value();
    if (has_z) {
        z = z_.value();
        TORCH_CHECK(z.scalar_type() == input_type);
        TORCH_CHECK(z.is_cuda());
        TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
        CHECK_SHAPE(z, batch_size, dim, seqlen);
        out_z = torch::empty_like(z);
    }

    // Sequence is processed in chunks of 2048 positions.
    const int n_chunks = (seqlen + 2048 - 1) / 2048;
    // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
    at::Tensor out = torch::empty_like(delta);
    at::Tensor x;
    x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));

    SSMParamsBase params;
    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
                       u, delta, A, B, C, out, z, out_z,
                       D_.has_value() ? D_.value().data_ptr() : nullptr,
                       delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
                       x.data_ptr(),
                       has_z,
                       delta_softplus);

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{u.device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
        DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
            selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
        });
    });
    std::vector<at::Tensor> result = {out, x};
    if (has_z) { result.push_back(out_z); }
    return result;
}

// Backward selective scan.
// Returns {du, ddelta, dA, dB, dC, dD, ddelta_bias}, then dz when z is given, then
// out_z when recompute_out_z. dD / ddelta_bias are undefined tensors when D / delta_bias
// are absent. NOTE(review): out_z is only allocated when z is given, so
// recompute_out_z without z appends an undefined tensor.
std::vector<at::Tensor>
selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
                   const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
                   const c10::optional<at::Tensor> &D_,
                   const c10::optional<at::Tensor> &z_,
                   const c10::optional<at::Tensor> &delta_bias_,
                   const at::Tensor &dout,
                   const c10::optional<at::Tensor> &x_,
                   const c10::optional<at::Tensor> &out_,
                   c10::optional<at::Tensor> &dz_,
                   bool delta_softplus,
                   bool recompute_out_z) {
    auto input_type = u.scalar_type();
    auto weight_type = A.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);

    const bool is_variable_B = B.dim() >= 3;
    const bool is_variable_C = C.dim() >= 3;
    const bool is_complex = weight_type == at::ScalarType::ComplexFloat;

    TORCH_CHECK(delta.scalar_type() == input_type);
    TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
    TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
    TORCH_CHECK(dout.scalar_type() == input_type);

    TORCH_CHECK(u.is_cuda());
    TORCH_CHECK(delta.is_cuda());
    TORCH_CHECK(A.is_cuda());
    TORCH_CHECK(B.is_cuda());
    TORCH_CHECK(C.is_cuda());
    TORCH_CHECK(dout.is_cuda());

    TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
    TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
    TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);

    const auto sizes = u.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int dstate = A.size(1);
    const int n_groups = is_variable_B ? B.size(1) : 1;

    TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
    // Guards the host-side dim / n_groups and keeps every channel's group index < n_groups.
    TORCH_CHECK(n_groups > 0 && dim % n_groups == 0,
                "selective_scan requires the number of groups to be positive and to divide dim");

    CHECK_SHAPE(u, batch_size, dim, seqlen);
    CHECK_SHAPE(delta, batch_size, dim, seqlen);
    CHECK_SHAPE(A, dim, dstate);
    if (!is_variable_B) {
        CHECK_SHAPE(B, dim, dstate);
    } else {
        CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
        TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
    }
    if (!is_variable_C) {
        CHECK_SHAPE(C, dim, dstate);
    } else {
        CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
        TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
    }
    CHECK_SHAPE(dout, batch_size, dim, seqlen);

    if (D_.has_value()) {
        auto D = D_.value();
        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(D.is_cuda());
        TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
        CHECK_SHAPE(D, dim);
    }

    if (delta_bias_.has_value()) {
        auto delta_bias = delta_bias_.value();
        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(delta_bias.is_cuda());
        TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
        CHECK_SHAPE(delta_bias, dim);
    }

    at::Tensor z, out, dz, out_z;
    const bool has_z = z_.has_value();
    if (has_z) {
        z = z_.value();
        TORCH_CHECK(z.scalar_type() == input_type);
        TORCH_CHECK(z.is_cuda());
        TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
        CHECK_SHAPE(z, batch_size, dim, seqlen);

        TORCH_CHECK(out_.has_value());
        out = out_.value();
        TORCH_CHECK(out.scalar_type() == input_type);
        TORCH_CHECK(out.is_cuda());
        TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1);
        CHECK_SHAPE(out, batch_size, dim, seqlen);

        if (dz_.has_value()) {
            dz = dz_.value();
            TORCH_CHECK(dz.scalar_type() == input_type);
            TORCH_CHECK(dz.is_cuda());
            TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1);
            CHECK_SHAPE(dz, batch_size, dim, seqlen);
        } else {
            dz = torch::empty_like(z);
        }
        if (recompute_out_z) {
            out_z = torch::empty_like(out);
        }
    }

    // Must match the chunking used by selective_scan_fwd (chunks of 2048 positions).
    const int n_chunks = (seqlen + 2048 - 1) / 2048;
    // The saved chunk states are required whenever there is more than one chunk.
    if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
    if (x_.has_value()) {
        auto x = x_.value();
        TORCH_CHECK(x.scalar_type() == weight_type);
        TORCH_CHECK(x.is_cuda());
        TORCH_CHECK(x.is_contiguous());
        CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
    }

    at::Tensor du = torch::empty_like(u);
    at::Tensor ddelta = torch::empty_like(delta);
    at::Tensor dA = torch::zeros_like(A);
    // Variable B/C gradients are accumulated in fp32 and cast back to B/C's dtype on return.
    at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32));
    at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32));
    at::Tensor dD;
    if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
    at::Tensor ddelta_bias;
    if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }

    SSMParamsBwd params;
    set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
                       u, delta, A, B, C, z, out, out_z,
                       D_.has_value() ? D_.value().data_ptr() : nullptr,
                       delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
                       x_.has_value() ? x_.value().data_ptr() : nullptr,
                       dout, du, ddelta, dA, dB, dC, dz,
                       D_.has_value() ? dD.data_ptr() : nullptr,
                       delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
                       has_z, delta_softplus, recompute_out_z);

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{u.device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
        DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] {
            selective_scan_bwd_cuda<input_t, weight_t>(params, stream);
        });
    });
    std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
    if (has_z) { result.push_back(dz); }
    if (recompute_out_z) { result.push_back(out_z); }
    return result;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fwd", &selective_scan_fwd, "Selective scan forward");
    m.def("bwd", &selective_scan_bwd, "Selective scan backward");
}

================================================
FILE: csrc/selective_scan/selective_scan.h
================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
******************************************************************************/

#pragma once

#include <cstdint>  // uint32_t; keeps this header self-contained

////////////////////////////////////////////////////////////////////////////////////////////////////

// Generic scan parameter block.
// NOTE(review): not referenced by the code visible next to this header — confirm it is still used.
struct SSMScanParamsBase {
    using index_t = uint32_t;

    int batch, seqlen, n_chunks;
    index_t a_batch_stride;
    index_t b_batch_stride;
    index_t out_batch_stride;

    // Common data pointers.
    void *__restrict__ a_ptr;
    void *__restrict__ b_ptr;
    void *__restrict__ out_ptr;
    void *__restrict__ x_ptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Host-to-kernel parameters for the selective scan forward pass (and base of the
// backward parameters). Filled by set_ssm_params_fwd in selective_scan.cpp.
// All strides are in elements, not bytes.
// Field order is shared between the .cpp and every .cu translation unit; do not reorder.
// NOTE(review): strides are 32-bit (index_t = uint32_t), and kernel offsets such as
// batch_id * u_batch_stride are then computed in 32-bit arithmetic, which can wrap for
// very large tensors — confirm the supported size range.
struct SSMParamsBase {
    using index_t = uint32_t;

    int batch, dim, seqlen, dstate, n_groups, n_chunks;
    int dim_ngroups_ratio;  // dim / n_groups: number of channels sharing one B/C group
    bool is_variable_B;     // B is (batch, n_groups, dstate, seqlen) instead of (dim, dstate)
    bool is_variable_C;     // C is (batch, n_groups, dstate, seqlen) instead of (dim, dstate)

    bool delta_softplus;    // apply softplus to (delta + delta_bias) inside the kernel

    index_t A_d_stride;
    index_t A_dstate_stride;
    index_t B_batch_stride;
    index_t B_d_stride;
    index_t B_dstate_stride;
    index_t B_group_stride;
    index_t C_batch_stride;
    index_t C_d_stride;
    index_t C_dstate_stride;
    index_t C_group_stride;
    index_t u_batch_stride;
    index_t u_d_stride;
    index_t delta_batch_stride;
    index_t delta_d_stride;
    index_t z_batch_stride;
    index_t z_d_stride;
    index_t out_batch_stride;
    index_t out_d_stride;
    index_t out_z_batch_stride;
    index_t out_z_d_stride;

    // Common data pointers.
    void *__restrict__ A_ptr;
    void *__restrict__ B_ptr;
    void *__restrict__ C_ptr;
    void *__restrict__ D_ptr;           // nullptr when D is not given
    void *__restrict__ u_ptr;
    void *__restrict__ delta_ptr;
    void *__restrict__ delta_bias_ptr;  // nullptr when delta_bias is not given
    void *__restrict__ out_ptr;
    void *__restrict__ x_ptr;           // per-chunk states, (batch, dim, n_chunks, 2 * dstate)
    void *__restrict__ z_ptr;           // nullptr unless z is given
    void *__restrict__ out_z_ptr;       // nullptr unless z is given (bwd: unless recompute_out_z)
};

// Backward-pass parameters: the forward fields plus gradient buffers.
// Filled by set_ssm_params_bwd in selective_scan.cpp.
struct SSMParamsBwd: public SSMParamsBase {
    index_t dout_batch_stride;
    index_t dout_d_stride;
    index_t dA_d_stride;
    index_t dA_dstate_stride;
    index_t dB_batch_stride;
    index_t dB_group_stride;
    index_t dB_d_stride;
    index_t dB_dstate_stride;
    index_t dC_batch_stride;
    index_t dC_group_stride;
    index_t dC_d_stride;
    index_t dC_dstate_stride;
    index_t du_batch_stride;
    index_t du_d_stride;
    index_t dz_batch_stride;
    index_t dz_d_stride;
    index_t ddelta_batch_stride;
    index_t ddelta_d_stride;

    // Common data pointers.
    void *__restrict__ dout_ptr;
    void *__restrict__ dA_ptr;
    void *__restrict__ dB_ptr;
    void *__restrict__ dC_ptr;
    void *__restrict__ dD_ptr;          // nullptr when D is not given
    void *__restrict__ du_ptr;
    void *__restrict__ dz_ptr;          // nullptr unless z is given
    void *__restrict__ ddelta_ptr;
    void *__restrict__ ddelta_bias_ptr; // nullptr when delta_bias is not given
};

================================================
FILE: csrc/selective_scan/selective_scan_bwd_bf16_complex.cu
================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);

================================================
FILE: csrc/selective_scan/selective_scan_bwd_bf16_real.cu
================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);

================================================
FILE: csrc/selective_scan/selective_scan_bwd_fp16_complex.cu
================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);

================================================
FILE: csrc/selective_scan/selective_scan_bwd_fp16_real.cu
================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);

================================================
FILE: csrc/selective_scan/selective_scan_bwd_fp32_complex.cu
================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);

================================================
FILE: csrc/selective_scan/selective_scan_bwd_fp32_real.cu
================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_bwd_kernel.cuh"

template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);

================================================
FILE: csrc/selective_scan/selective_scan_bwd_kernel.cuh
================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
******************************************************************************/ #pragma once #include #include #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK #include // For atomicAdd on complex #ifndef USE_ROCM #include #include #include #include #else #include namespace cub = hipcub; #endif #include "selective_scan.h" #include "selective_scan_common.h" #include "reverse_scan.cuh" #include "static_switch.h" template __device__ __forceinline__ scalar_t conj(scalar_t x); template<> __device__ __forceinline__ float conj(float x) { return x; } template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } template struct Selective_Scan_bwd_kernel_traits { static_assert(kNItems_ % 4 == 0); using input_t = input_t_; using weight_t = weight_t_; static constexpr int kNThreads = kNThreads_; static constexpr int kNItems = kNItems_; static constexpr int kNBytes = sizeof(input_t); static_assert(kNBytes == 2 || kNBytes == 4); static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); static_assert(kNItems % kNElts == 0); static constexpr int kNLoads = kNItems / kNElts; static constexpr bool kIsComplex = std::is_same_v; static constexpr bool kIsEvenLen = kIsEvenLen_; static constexpr bool kIsVariableB = kIsVariableB_; static constexpr bool kIsVariableC = kIsVariableC_; static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; static constexpr bool kHasZ = kHasZ_; // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. // For complex this would lead to massive register spilling, so we keep it at 2. static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 
3 : 2;
    using vec_t = typename BytesToType::Type;
    using scan_t = std::conditional_t;
    using BlockLoadT = cub::BlockLoad;
    using BlockLoadVecT = cub::BlockLoad;
    using BlockLoadWeightT = cub::BlockLoad;
    using BlockLoadWeightVecT = cub::BlockLoad;
    using BlockStoreT = cub::BlockStore;
    using BlockStoreVecT = cub::BlockStore;
    // using BlockScanT = cub::BlockScan;
    using BlockScanT = cub::BlockScan;
    // using BlockScanT = cub::BlockScan;
    using BlockReverseScanT = BlockReverseScan;
    using BlockReduceT = cub::BlockReduce;
    using BlockReduceFloatT = cub::BlockReduce;
    using BlockReduceComplexT = cub::BlockReduce;
    using BlockExchangeT = cub::BlockExchange;
    static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
                                                   sizeof(typename BlockLoadVecT::TempStorage),
                                                   (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
                                                   (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
                                                   sizeof(typename BlockStoreT::TempStorage),
                                                   sizeof(typename BlockStoreVecT::TempStorage)});
    static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage);
    static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
    static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
};

// NOTE(review): in this copy the template parameter/argument lists (after `template`,
// `reinterpret_cast`, `cub::Block*`, the `<<<...>>>` launch config, `#include <...>`)
// appear to have been stripped by the extraction; restore them from upstream before compiling.

// Backward pass of the selective scan.
//
// One thread block handles one (batch, dim) pair: the launcher below uses
// grid(params.batch, params.dim) and the kernel reads batch_id = blockIdx.x and
// dim_id = blockIdx.y. The sequence is processed in chunks of kNThreads * kNItems
// elements, walking from the LAST chunk to the first. For every chunk and state index:
//   * the forward recurrence is re-run with a block-wide inclusive scan, seeded from
//     the per-chunk states saved in params.x_ptr;
//   * a reverse inclusive scan produces the state gradients, with its running postfix
//     carried from chunk to chunk in shared memory (smem_running_postfix).
// du, ddelta (and dz / out_z when z is present) are written per chunk. dB / dC for
// variable B / C are atomically accumulated per chunk. dA, constant dB / dC, dD and
// ddelta_bias are accumulated over all chunks and added atomically at the end.
template __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
void selective_scan_bwd_kernel(SSMParamsBwd params) {
    constexpr bool kIsComplex = Ktraits::kIsComplex;
    constexpr bool kIsVariableB = Ktraits::kIsVariableB;
    constexpr bool kIsVariableC = Ktraits::kIsVariableC;
    constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
    constexpr bool kHasZ = Ktraits::kHasZ;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNItems = Ktraits::kNItems;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    using scan_t = typename Ktraits::scan_t;

    // Shared memory.
    extern __shared__ char smem_[];
    // cast to lvalue reference of expected type
    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
    // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
    // auto& smem_load = reinterpret_cast(smem_loadstorescan);
    // All loads and stores share the same buffer at the start of shared memory.
    auto& smem_load = reinterpret_cast(smem_);
    auto& smem_load_weight = reinterpret_cast(smem_);
    // Second weight-load buffer, used for C when B is also variable.
    auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
    auto& smem_store = reinterpret_cast(smem_);
    auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize);
    auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
    auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize);
    auto& smem_reduce_float = *reinterpret_cast(&smem_reduce);
    auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce);
    auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize);
    auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
    // Scratch arrays placed after the CUB temp storage (their sizes must match the launcher):
    //   smem_delta_a          2 * MAX_DSTATE + kNThreads weight_t
    //   smem_running_postfix  MAX_DSTATE scan_t
    //   smem_da, smem_dbc     MAX_DSTATE weight_t each
    weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize);
    scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 * MAX_DSTATE + kNThreads);
    weight_t *smem_da = reinterpret_cast(smem_running_postfix + MAX_DSTATE);
    weight_t *smem_dbc = reinterpret_cast(smem_da + MAX_DSTATE);

    const int batch_id = blockIdx.x;
    const int dim_id = blockIdx.y;
    const int group_id = dim_id / (params.dim_ngroups_ratio);
    input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride
        + dim_id * params.u_d_stride;
    input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride
        + dim_id * params.delta_d_stride;
    input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride
        + dim_id * params.dout_d_stride;
    weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride;
    weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * params.B_d_stride;
    input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
    weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * params.C_d_stride;
    input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
    weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * params.dA_d_stride;
    weight_t *dB = reinterpret_cast(params.dB_ptr)
        + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride);
    weight_t *dC = reinterpret_cast(params.dC_ptr)
        + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride);
    float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id;
    float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast(params.D_ptr)[dim_id];
    float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id;
    float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id];
    scan_t *x = params.x_ptr == nullptr
        ? nullptr
        : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
    float dD_val = 0;
    float ddelta_bias_val = 0;

    constexpr int kChunkSize = kNThreads * kNItems;
    // Point at the last chunk; each pointer is moved back one chunk once it has been consumed.
    u += (params.n_chunks - 1) * kChunkSize;
    delta += (params.n_chunks - 1) * kChunkSize;
    dout += (params.n_chunks - 1) * kChunkSize;
    Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
    Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
    for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
        input_t u_vals[kNItems];
        input_t delta_vals_load[kNItems];
        input_t dout_vals_load[kNItems];
        __syncthreads();
        load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
        u -= kChunkSize;
        __syncthreads();
        load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
        // Will reload delta at the same location if kDeltaSoftplus
        if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
        __syncthreads();
        load_input(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
        dout -= kChunkSize;

        float dout_vals[kNItems], delta_vals[kNItems];
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) {
            dout_vals[i] = float(dout_vals_load[i]);
            delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
            if constexpr (kDeltaSoftplus) {
                // softplus(x) = log1p(exp(x)); above 20 it equals x to float precision.
                delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
            }
        }

        if constexpr (kHasZ) {
            input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride
                + dim_id * params.z_d_stride + chunk * kChunkSize;
            input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride
                + dim_id * params.out_d_stride + chunk * kChunkSize;
            input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride
                + dim_id * params.dz_d_stride + chunk * kChunkSize;
            input_t z_vals[kNItems], out_vals[kNItems];
            __syncthreads();
            load_input(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
            __syncthreads();
            load_input(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize);
            float dz_vals[kNItems], z_silu_vals[kNItems];
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                float z_val = z_vals[i];
                float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val));
                z_silu_vals[i] = z_val * z_sigmoid_val;
                // d silu(z) / dz = sigmoid(z) * (1 + z * (1 - sigmoid(z))).
                dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val
                             * (1.0f + z_val * (1.0f - z_sigmoid_val));
                // From here on dout is the gradient w.r.t. `out` (before the silu(z) gating).
                dout_vals[i] *= z_silu_vals[i];
            }
            __syncthreads();
            store_output(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize);
            if (params.out_z_ptr != nullptr) {  // Recompute and store out_z
                float out_z_vals[kNItems];
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; }
                // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
                    // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]);
                // }
                input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride
                    + dim_id * params.out_z_d_stride + chunk * kChunkSize;
                __syncthreads();
                store_output(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize);
            }
        }

        float du_vals[kNItems];
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }

        float ddelta_vals[kNItems] = {0};
        __syncthreads();
        for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
            const weight_t A_val = A[state_idx * params.A_dstate_stride];
            // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
            weight_t A_scaled;
            constexpr float kLog2e = M_LOG2E;
            if constexpr (!kIsComplex) {
                A_scaled = A_val * kLog2e;
            } else {
                A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_);
            }
            weight_t B_val, C_val;
            weight_t B_vals[kNItems], C_vals[kNItems];
            if constexpr (!kIsVariableB) {
                B_val = B[state_idx * params.B_dstate_stride];
            } else {
                load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals,
                    smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
            }
            if constexpr (!kIsVariableC) {
                C_val = C[state_idx * params.C_dstate_stride];
            } else {
                auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
                load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals,
                    smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
            }
            // const weight_t A_val = smem_a[state_idx];
            scan_t thread_data[kNItems], thread_reverse_data[kNItems];
            if constexpr (!kIsComplex) {
                // thread_data[i] = (exp(delta * A), delta * u [* B when B varies]), one step of
                // h_t = a_t * h_{t-1} + b_t. thread_reverse_data[i].x is exp(delta * A) of the
                // NEXT timestep and .y is dout * C (times B when B is constant).
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
                    thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
                    if (i == 0) {
                        // Item 0's factor is the "next step" factor of the previous thread's last
                        // item. Thread 0 instead parks it (double-buffered by chunk parity) for
                        // the last thread of the preceding chunk, processed in the next iteration.
                        smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
                    } else {
                        thread_reverse_data[i - 1].x = delta_a_exp;
                    }
                    thread_reverse_data[i].y = dout_vals[i] *
                        (!kIsVariableC
                         ? (!kIsVariableB ? B_val * C_val : C_val)
                         : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
                }
                __syncthreads();
                thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
                    ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
                    : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
                // Initialize running total
                // Only lane 0 of each warp reads the saved chunk state; the other lanes use the
                // identity element (1, 0).
                scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
                SSMScanPrefixCallbackOp prefix_op(running_prefix);
                typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
                    thread_data, thread_data, SSMScanOp(), prefix_op
                );
                scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
                SSMScanPrefixCallbackOp postfix_op(running_postfix);
                typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
                    thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op
                );
                if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
                weight_t dA_val = 0, dBC_val = 0;
                weight_t dB_vals[kNItems], dC_vals[kNItems];
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    const float dx = thread_reverse_data[i].y;
                    const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i];
                    du_vals[i] += ddelta_u * delta_vals[i];
                    // Scanned state minus this step's input term = exp(delta * A) * previous state.
                    const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
                    ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
                    dA_val += dx * delta_vals[i] * a;
                    if constexpr (!kIsVariableB || !kIsVariableC) {
                        if constexpr (!kIsVariableB) {  // dBC_val is dB_val
                            dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]);
                        } else {  // dBC_val is dC_val
                            dBC_val += dout_vals[i] * thread_data[i].y;
                        }
                    }
                    if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
                    if constexpr (kIsVariableC) {
                        dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y);
                    }
                }
                // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
                if constexpr (kIsVariableB || kIsVariableC) {
                    if constexpr (kIsVariableB) {
                        typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
                    }
                    if constexpr (kIsVariableC) {
                        auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
                        typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
                    }
                    const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
                    weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
                    weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
                    #pragma unroll
                    for (int i = 0; i < kNItems; ++i) {
                        if (i * kNThreads < seqlen_remaining) {
                            if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
                            if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
                        }
                    }
                }
                if constexpr (!kIsVariableB || !kIsVariableC) {
                    float2 dA_dBC_val = make_float2(dA_val, dBC_val);
                    dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
                    dA_val = dA_dBC_val.x;
                    if (threadIdx.x == 0) {
                        smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
                    }
                } else {
                    dA_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
                }
                if (threadIdx.x == 0) {
                    smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
                }
            } else {
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    // Pytorch's implementation of complex exp (which calls thrust) is very slow
                    complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
                    weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
                    thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
                    if (i == 0) {
                        smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
                    } else {
                        // The reverse scan runs on conj(exp(delta * A)).
                        thread_reverse_data[i - 1].x = delta_a_exp.real_;
                        thread_reverse_data[i - 1].y = -delta_a_exp.imag_;
                    }
                    complex_t dout_BC = 2 * dout_vals[i] * conj(!kIsVariableC ? (!kIsVariableB ? B_val * C_val : C_val) : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
                    thread_reverse_data[i].z = dout_BC.real_;
                    thread_reverse_data[i].w = dout_BC.imag_;
                }
                __syncthreads();
                complex_t delta_a_exp = threadIdx.x == kNThreads - 1
                    ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
                    : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
                thread_reverse_data[kNItems - 1].x = delta_a_exp.real_;
                thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_;
                // Initialize running total
                scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
                SSMScanPrefixCallbackOp prefix_op(running_prefix);
                typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
                    thread_data, thread_data, SSMScanOp(), prefix_op
                );
                scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
                SSMScanPrefixCallbackOp postfix_op(running_postfix);
                typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
                    thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op
                );
                if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
                weight_t dA_val = 0, dBC_val = 0;
                weight_t dB_vals[kNItems], dC_vals[kNItems];
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    complex_t x = complex_t(thread_data[i].z, thread_data[i].w);
                    complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w);
                    float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_;
                    if constexpr (!kIsVariableB || !kIsVariableC) {
                        if constexpr (!kIsVariableB) {  // dBC_val is dB_val
                            dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]);
                        } else {  // dBC_val is dC_val
                            dBC_val += (2 * dout_vals[i]) * conj(x);
                        }
                    }
                    const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]));
                    du_vals[i] += ddelta_u * delta_vals[i];
                    ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_;
                    dA_val += delta_vals[i] * dx * a_conj;
                    if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
                    if constexpr (kIsVariableC) {
                        dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x);
                    }
                }
                // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
                if constexpr (kIsVariableB || kIsVariableC) {
                    float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2];
                    if constexpr (kIsVariableB) {
                        #pragma unroll
                        for (int i = 0; i < kNItems; ++i) {
                            dB_vals_f[i * 2] = dB_vals[i].real_;
                            dB_vals_f[i * 2 + 1] = dB_vals[i].imag_;
                        }
                        typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
                    }
                    if constexpr (kIsVariableC) {
                        #pragma unroll
                        for (int i = 0; i < kNItems; ++i) {
                            dC_vals_f[i * 2] = dC_vals[i].real_;
                            dC_vals_f[i * 2 + 1] = dC_vals[i].imag_;
                        }
                        auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
                        typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
                    }
                    const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x;
                    float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
                    float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
                    #pragma unroll
                    for (int i = 0; i < kNItems * 2; ++i) {
                        if (i * kNThreads < seqlen_remaining) {
                            if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); }
                            if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); }
                        }
                    }
                }
                if constexpr (!kIsVariableB || !kIsVariableC) {
                    float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_);
                    dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
                    dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
                    dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
                    if (threadIdx.x == 0) {
                        smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
                    }
                } else {
                    dA_val = typename Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
                }
                if (threadIdx.x == 0) {
                    smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
                }
            }
        }

        if constexpr (kDeltaSoftplus) {
            // Chain rule through softplus: d softplus(x) / dx = sigmoid(x) = 1 / (1 + exp(-x)),
            // evaluated on the pre-softplus delta reloaded from global memory.
            __syncthreads();
            input_t delta_vals_load[kNItems];
            load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
            delta -= kChunkSize;
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                float delta_val = float(delta_vals_load[i]) + delta_bias;
                float delta_val_neg_exp = expf(-delta_val);
                ddelta_vals[i] = delta_val <= 20.f
                    ? ddelta_vals[i] / (1.f + delta_val_neg_exp)
                    : ddelta_vals[i];
            }
        }
        for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }

        input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride
            + dim_id * params.du_d_stride + chunk * kChunkSize;
        input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
            + dim_id * params.ddelta_d_stride + chunk * kChunkSize;
        __syncthreads();
        store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
        __syncthreads();
        store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);

        Bvar -= kChunkSize * (!kIsComplex ? 1 : 2);
        Cvar -= kChunkSize * (!kIsComplex ? 1 : 2);
    }
    if (params.dD_ptr != nullptr) {
        dD_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
        if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
    }
    if (params.ddelta_bias_ptr != nullptr) {
        __syncthreads();
        ddelta_bias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
        if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
    }
    // Flush the per-state accumulators gathered over all chunks. For constant B, the
    // accumulated smem_dbc is multiplied by conj(C) when C is also constant (and
    // symmetrically for constant C).
    for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
        gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
        weight_t dBC_val;
        if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; }
        if constexpr (!kIsVariableB) {
            gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]),
                         !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val);
        }
        if constexpr (!kIsVariableC) {
            gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]),
                        !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val);
        }
    }
}

// Instantiates and launches selective_scan_bwd_kernel for one (kNThreads, kNItems) tile
// shape. BOOL_SWITCH turns the runtime flags (even seqlen, variable B / C, delta softplus,
// presence of z) into compile-time kernel-traits parameters.
template void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
            BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
                BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
                    BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
                        using Ktraits = Selective_Scan_bwd_kernel_traits;
                        // using Ktraits = Selective_Scan_bwd_kernel_traits;
                        // TODO: check this
                        // CUB temp storage plus the scratch arrays the kernel places after it:
                        // smem_running_postfix (MAX_DSTATE scan_t) and smem_delta_a + smem_da +
                        // smem_dbc (kNThreads + 4 * MAX_DSTATE weight_t in total).
                        constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
                        dim3 grid(params.batch, params.dim);
                        auto kernel = &selective_scan_bwd_kernel;
                        // Using 48 KB or more of dynamic shared memory needs an explicit opt-in.
                        if (kSmemSize >= 48 * 1024) {
                            #ifndef USE_ROCM
                            C10_CUDA_CHECK(cudaFuncSetAttribute(
                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
                            #else
                            C10_CUDA_CHECK(cudaFuncSetAttribute(
                                (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
                            std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a no-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
                            #endif
                        }
                        kernel<<>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });
}

// Host entry point of the backward pass: picks the tile shape (threads per block, items
// per thread) from params.seqlen. Each block iteration covers kNThreads * kNItems elements.
template void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
    #ifndef USE_ROCM
        if (params.seqlen <= 128) {
            selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 256) {
            selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 512) {
            selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 1024) {
            selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
        } else {
            selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
        }
    #else
        if (params.seqlen <= 256) {
            selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 512) {
            selective_scan_bwd_launch<64, 8, input_t, weight_t>(params, stream);
        } else if (params.seqlen <= 1024) {
            selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
        } else {
            selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
        }
    #endif
}
================================================ FILE: csrc/selective_scan/selective_scan_common.h ================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

// NOTE(review): the #include targets and template parameter lists in this copy appear to
// have been stripped by the extraction; restore them from upstream before compiling.
#ifndef USE_ROCM
    #include
#else
    #include
#endif
#include
#include // For scalar_value_type


#ifndef USE_ROCM

// Compile-time maximum of a list of sizes (used to size the shared load/store temp storage).
constexpr size_t custom_max(std::initializer_list ilist)
{
    return std::max(ilist);
}

template
constexpr T constexpr_min(T a, T b) {
    return std::min(a, b);
}

#else
// ROCm path: equivalent helpers written without std::max / std::min.
constexpr size_t custom_max(std::initializer_list ilist)
{
    return *std::max_element(ilist.begin(), ilist.end());
}

template
constexpr T constexpr_min(T a, T b) {
    return a < b ? a : b;
}
#endif


// Capacity of the per-state shared-memory scratch arrays of the scan kernels (e.g. smem_da /
// smem_dbc hold MAX_DSTATE entries indexed by state_idx), so params.dstate must not exceed it.
#define MAX_DSTATE 256

using complex_t = c10::complex;

// Component-wise additions for CUDA vector types (block reductions Sum float2 / float4 values).
inline __device__ float2 operator+(const float2 & a, const float2 & b){
    return {a.x + b.x, a.y + b.y};
}

inline __device__ float3 operator+(const float3 &a, const float3 &b) {
  return {a.x + b.x, a.y + b.y, a.z + b.z};
}

inline __device__ float4 operator+(const float4 & a, const float4 & b){
    return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Maps a byte count to an unsigned type of exactly that size; kernel traits use it as vec_t
// for vectorized loads and stores.
template struct BytesToType {};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Element-wise conversion of a register array to float. The generic version converts one
// element at a time; the half / bfloat16 specializations convert two at a time.
template
struct Converter{
    static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
        #pragma unroll
        for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
    }
};

template
struct Converter{
    static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        auto &src2 = reinterpret_cast(src);
        auto &dst2 = reinterpret_cast(dst);
        #pragma unroll
        for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
    }
};

// The paired bfloat16 conversion is only compiled for __CUDA_ARCH__ >= 800; other targets
// fall back to the generic element-wise Converter.
#if __CUDA_ARCH__ >= 800
template
struct Converter{
    static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        auto &src2 = reinterpret_cast(src);
        auto &dst2 = reinterpret_cast(dst);
        #pragma unroll
        for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
    }
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////

// NOTE(review): template parameter lists in this copy appear stripped by the extraction
// (e.g. `template struct SSMScanOp;`); restore them from upstream before compiling.

// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
//
// Note: this is NOT a full complex exp2. It returns 2^re(z) * (cos(im(z)) + i * sin(im(z))):
// only the real part is exponentiated in base 2, and the imaginary part is used directly as
// an angle in radians. The scan kernels scale only the real part of A by log2(e) (A_scaled),
// so the result equals exp() of the unscaled exponent.
__device__ __forceinline__ complex_t cexp2f(complex_t z) {
    float t = exp2f(z.real_);
    float c, s;
    sincosf(z.imag_, &s, &c);
    return complex_t(c * t, s * t);
}

// Complex exponential: e^re(z) * (cos(im(z)) + i * sin(im(z))).
__device__ __forceinline__ complex_t cexpf(complex_t z) {
    float t = expf(z.real_);
    float c, s;
    sincosf(z.imag_, &s, &c);
    return complex_t(c * t, s * t);
}

// Associative combine for the linear recurrence h_t = a_t * h_{t-1} + b_t with elements
// packed as (a, b): combining ab0 (earlier) with ab1 (later) yields (a1 * a0, a1 * b0 + b1).
template struct SSMScanOp;

template<>
struct SSMScanOp {
    __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
        return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
    }
};

// Complex version: each float4 packs (re(a), im(a), re(b), im(b)).
template<>
struct SSMScanOp {
    __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
        complex_t a0 = complex_t(ab0.x, ab0.y);
        complex_t b0 = complex_t(ab0.z, ab0.w);
        complex_t a1 = complex_t(ab1.x, ab1.y);
        complex_t b1 = complex_t(ab1.z, ab1.w);
        complex_t out_a = a1 * a0;
        complex_t out_b = a1 * b0 + b1;
        return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
    }
};

// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template struct SSMScanPrefixCallbackOp {
    using scan_t = std::conditional_t, float2, float4>;
    scan_t running_prefix;
    // Constructor
    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    // Returns the prefix accumulated so far and folds this scan's block aggregate into it,
    // so a following scan (next chunk) continues where this one stopped.
    __device__ scan_t operator()(scan_t block_aggregate) {
        scan_t old_prefix = running_prefix;
        running_prefix = SSMScanOp()(running_prefix, block_aggregate);
        return old_prefix;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Loads one chunk of input into kNItems registers per thread. With kIsEvenLen the load is
// vectorized (vec_t) and unguarded; otherwise it is a guarded load of `seqlen` valid items,
// with out-of-range items set to 0.
template
inline __device__ void load_input(typename Ktraits::input_t *u,
                                  typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
                                  typename Ktraits::BlockLoadT::TempStorage &smem_load,
                                  int seqlen) {
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_load_vec = reinterpret_cast(smem_load);
        using vec_t = typename Ktraits::vec_t;
        typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
            reinterpret_cast(u),
            reinterpret_cast(u_vals)
            #ifdef USE_ROCM
                , Ktraits::kNThreads * Ktraits::kNLoads
            #endif
        );
    } else {
        typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
    }
}

// Loads one chunk of a variable B or C into weight_t registers. Real case: kNItems input_t
// values converted to float. Complex case: 2 * kNItems input_t values paired as
// (real, imag); callers pass `seqlen` already doubled for that layout.
template
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
                                   typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
                                   typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
                                   int seqlen) {
    constexpr int kNItems = Ktraits::kNItems;
    if constexpr (!Ktraits::kIsComplex) {
        typename Ktraits::input_t B_vals_load[kNItems];
        if constexpr (Ktraits::kIsEvenLen) {
            auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight);
            using vec_t = typename Ktraits::vec_t;
            typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
                reinterpret_cast(Bvar),
                reinterpret_cast(B_vals_load)
            );
        } else {
            typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
        }
        // #pragma unroll
        // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
        Converter::to_float(B_vals_load, B_vals);
    } else {
        typename Ktraits::input_t B_vals_load[kNItems * 2];
        if constexpr (Ktraits::kIsEvenLen) {
            auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight);
            using vec_t = typename Ktraits::vec_t;
            typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
                reinterpret_cast(Bvar),
                reinterpret_cast(B_vals_load)
          );
        } else {
            typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
        }
        #pragma unroll
        for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
    }
}

// Converts one chunk of float results to input_t and stores it: vectorized and unguarded
// with kIsEvenLen, otherwise only the first `seqlen` items are written.
template
inline __device__ void store_output(typename Ktraits::input_t *out,
                                    const float (&out_vals)[Ktraits::kNItems],
                                    typename Ktraits::BlockStoreT::TempStorage &smem_store,
                                    int seqlen) {
    typename Ktraits::input_t write_vals[Ktraits::kNItems];
    #pragma unroll
    for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_store_vec = reinterpret_cast(smem_store);
        using vec_t = typename Ktraits::vec_t;
        typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
            reinterpret_cast(out),
            reinterpret_cast(write_vals)
        );
    } else {
        typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
    }
}
================================================ FILE: csrc/selective_scan/selective_scan_fwd_bf16.cu ================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Split into multiple files to compile in parallel

#include "selective_scan_fwd_kernel.cuh"

// NOTE(review): the template argument lists of these explicit instantiations appear stripped
// in this copy (both lines are now identical); restore them from upstream.
template void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
================================================ FILE: csrc/selective_scan/selective_scan_fwd_fp16.cu ================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
******************************************************************************/ // Split into multiple files to compile in paralell #include "selective_scan_fwd_kernel.cuh" template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); ================================================ FILE: csrc/selective_scan/selective_scan_fwd_fp32.cu ================================================ /****************************************************************************** * Copyright (c) 2023, Tri Dao. ******************************************************************************/ // Split into multiple files to compile in paralell #include "selective_scan_fwd_kernel.cuh" template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); ================================================ FILE: csrc/selective_scan/selective_scan_fwd_kernel.cuh ================================================ /****************************************************************************** * Copyright (c) 2023, Tri Dao. ******************************************************************************/ #pragma once #include #include #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK #ifndef USE_ROCM #include #include #include #else #include namespace cub = hipcub; #endif #include "selective_scan.h" #include "selective_scan_common.h" #include "static_switch.h" template struct Selective_Scan_fwd_kernel_traits { static_assert(kNItems_ % 4 == 0); using input_t = input_t_; using weight_t = weight_t_; static constexpr int kNThreads = kNThreads_; // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. static constexpr int kMinBlocks = kNThreads < 128 ? 
5 : 3; static constexpr int kNItems = kNItems_; static constexpr int kNRows = kNRows_; static constexpr int kNBytes = sizeof(input_t); static_assert(kNBytes == 2 || kNBytes == 4); static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); static_assert(kNItems % kNElts == 0); static constexpr int kNLoads = kNItems / kNElts; static constexpr bool kIsComplex = std::is_same_v; static constexpr bool kIsEvenLen = kIsEvenLen_; static constexpr bool kIsVariableB = kIsVariableB_; static constexpr bool kIsVariableC = kIsVariableC_; static constexpr bool kHasZ = kHasZ_; static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; using vec_t = typename BytesToType::Type; using scan_t = std::conditional_t; using BlockLoadT = cub::BlockLoad; using BlockLoadVecT = cub::BlockLoad; using BlockLoadWeightT = cub::BlockLoad; using BlockLoadWeightVecT = cub::BlockLoad; using BlockStoreT = cub::BlockStore; using BlockStoreVecT = cub::BlockStore; // using BlockScanT = cub::BlockScan; // using BlockScanT = cub::BlockScan; using BlockScanT = cub::BlockScan; static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockLoadVecT::TempStorage), (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), sizeof(typename BlockStoreT::TempStorage), sizeof(typename BlockStoreVecT::TempStorage)}); static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); }; /* Forward selective scan. blockIdx.x selects the batch element and blockIdx.y a group of kNRows channels. The sequence is processed in chunks of kNThreads * kNItems elements; for every state index a block-wide InclusiveScan with SSMScanOp evaluates the recurrence, its C-weighted result is accumulated into out_vals (initialised to D * u), and the chunk-final running prefix is saved to x so the next chunk can resume from it. When kHasZ, the output is additionally gated by z * sigmoid(z). */ template __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) void selective_scan_fwd_kernel(SSMParamsBase params) { constexpr bool kIsComplex = Ktraits::kIsComplex; constexpr bool kIsVariableB = Ktraits::kIsVariableB; constexpr bool kIsVariableC = Ktraits::kIsVariableC; constexpr bool kHasZ = Ktraits::kHasZ; constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNItems = Ktraits::kNItems; constexpr int
kNRows = Ktraits::kNRows; constexpr bool kDirectIO = Ktraits::kDirectIO; using input_t = typename Ktraits::input_t; using weight_t = typename Ktraits::weight_t; using scan_t = typename Ktraits::scan_t; // Shared memory. extern __shared__ char smem_[]; // cast to lvalue reference of expected type // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); // auto& smem_load = reinterpret_cast(smem_loadstorescan); auto& smem_load = reinterpret_cast(smem_); auto& smem_load_weight = reinterpret_cast(smem_); auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); auto& smem_store = reinterpret_cast(smem_); auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); const int batch_id = blockIdx.x; const int dim_id = blockIdx.y; const int group_id = dim_id / (params.dim_ngroups_ratio); input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + dim_id * kNRows * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + dim_id * kNRows * params.delta_d_stride; weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) *
params.n_chunks * params.dstate; float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { #pragma unroll for (int r = 0; r < kNRows; ++r) { D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; } } float delta_bias[kNRows] = {0}; if (params.delta_bias_ptr != nullptr) { #pragma unroll for (int r = 0; r < kNRows; ++r) { delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; } } // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; // } constexpr int kChunkSize = kNThreads * kNItems; for (int chunk = 0; chunk < params.n_chunks; ++chunk) { input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; __syncthreads(); #pragma unroll for (int r = 0; r < kNRows; ++r) { if constexpr (!kDirectIO) { if (r > 0) { __syncthreads(); } } load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); if constexpr (!kDirectIO) { __syncthreads(); } load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); } u += kChunkSize; delta += kChunkSize; float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; #pragma unroll for (int r = 0; r < kNRows; ++r) { #pragma unroll for (int i = 0; i < kNItems; ++i) { float u_val = float(u_vals[r][i]); delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; if (params.delta_softplus) { delta_vals[r][i] = delta_vals[r][i] <= 20.f ?
log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; } delta_u_vals[r][i] = delta_vals[r][i] * u_val; out_vals[r][i] = D_val[r] * u_val; } } __syncthreads(); for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { weight_t A_val[kNRows]; #pragma unroll for (int r = 0; r < kNRows; ++r) { A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. constexpr float kLog2e = M_LOG2E; if constexpr (!kIsComplex) { A_val[r] *= kLog2e; } else { A_val[r].real_ *= kLog2e; } } // This variable holds B * C if both B and C are constant across seqlen. If only B varies // across seqlen, this holds C. If only C varies across seqlen, this holds B. // If both B and C vary, this is unused. weight_t BC_val[kNRows]; weight_t B_vals[kNItems], C_vals[kNItems]; if constexpr (kIsVariableB) { load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); if constexpr (!kIsVariableC) { #pragma unroll for (int r = 0; r < kNRows; ++r) { BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; } } } if constexpr (kIsVariableC) { auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ?
1 : 2)); if constexpr (!kIsVariableB) { #pragma unroll for (int r = 0; r < kNRows; ++r) { BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; } } } if constexpr (!kIsVariableB && !kIsVariableC) { #pragma unroll for (int r = 0; r < kNRows; ++r) { BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; } } #pragma unroll for (int r = 0; r < kNRows; ++r) { if (r > 0) { __syncthreads(); } // Scan could be using the same smem scan_t thread_data[kNItems]; #pragma unroll for (int i = 0; i < kNItems; ++i) { if constexpr (!kIsComplex) { thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { thread_data[i] = make_float2(1.f, 0.f); } } } else { // Pytorch's implementation of complex exp (which calls thrust) is very slow complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]); weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); } } } } // Initialize running total scan_t running_prefix; if constexpr (!kIsComplex) { // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f); // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); } else { running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ?
smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f); // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); } SSMScanPrefixCallbackOp prefix_op(running_prefix); typename Ktraits::BlockScanT(smem_scan).InclusiveScan( thread_data, thread_data, SSMScanOp(), prefix_op ); /* Write the chunk-final prefix into the same per-row slot (state_idx + r * MAX_DSTATE) that is read above when seeding running_prefix for the next chunk; indexing by state_idx alone would overwrite row 0's slot and leave rows r > 0 stale when kNRows > 1. */ // There's a syncthreads in the scan op, so we don't need to sync here. // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. if (threadIdx.x == 0) { smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix; x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; } #pragma unroll for (int i = 0; i < kNItems; ++i) { const weight_t C_val = !kIsVariableC ? BC_val[r] : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); if constexpr (!kIsComplex) { out_vals[r][i] += thread_data[i].y * C_val; } else { out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2; } } } } input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; __syncthreads(); #pragma unroll for (int r = 0; r < kNRows; ++r) { if constexpr (!kDirectIO) { if (r > 0) { __syncthreads(); } } store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); } if constexpr (kHasZ) { input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; #pragma unroll for (int r = 0; r < kNRows; ++r) { input_t z_vals[kNItems]; __syncthreads(); load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); #pragma unroll for (int i = 0; i < kNItems; ++i) { float z_val = z_vals[i];
out_vals[r][i] *= z_val / (1 + expf(-z_val)); } __syncthreads(); store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); } } Bvar += kChunkSize * (!kIsComplex ? 1 : 2); Cvar += kChunkSize * (!kIsComplex ? 1 : 2); } } /* Host-side launcher: specialises the kernel on (even seqlen, variable B, variable C, has z) via BOOL_SWITCH, raises the dynamic shared-memory limit when kSmemSize >= 48 KiB, and launches a (batch, dim / kNRows) grid. Only kNRows == 1 is instantiated. */ template void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) { // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block // processing 1 row. constexpr int kNRows = 1; BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { using Ktraits = Selective_Scan_fwd_kernel_traits; constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); dim3 grid(params.batch, params.dim / kNRows); // Had to change this substantially since potentially the hip // interface for setting kernel launch attributes is slightly different from // cuda's. In particular, it seems to expect a plain const void * pointer. auto kernel = &selective_scan_fwd_kernel; if (kSmemSize >= 48 * 1024) { #ifndef USE_ROCM C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); #else C10_CUDA_CHECK(cudaFuncSetAttribute( (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior.
\n" << std::endl; #endif } kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); }); }); } /* Chooses (kNThreads, kNItems) from params.seqlen; ROCm builds use their own table. */ template void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) { #ifndef USE_ROCM if (params.seqlen <= 128) { selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); } else if (params.seqlen <= 256) { selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); } else if (params.seqlen <= 512) { selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); } else if (params.seqlen <= 1024) { selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); } else { selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); } #else if (params.seqlen <= 256) { selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); } else if (params.seqlen <= 512) { selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); } else if (params.seqlen <= 1024) { selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); } else { selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); } #endif } ================================================ FILE: csrc/selective_scan/static_switch.h ================================================ // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h #pragma once /// @param COND - a boolean expression to switch by /// @param CONST_NAME - a name given for the constexpr bool variable. /// @param ... - code to execute for true and false /// /// Usage: /// ``` /// BOOL_SWITCH(flag, BoolConst, [&] { /// some_function(...); /// }); /// ``` #define BOOL_SWITCH(COND, CONST_NAME, ...)
\ [&] { \ if (COND) { \ constexpr bool CONST_NAME = true; \ return __VA_ARGS__(); \ } else { \ constexpr bool CONST_NAME = false; \ return __VA_ARGS__(); \ } \ }() ================================================ FILE: csrc/selective_scan/uninitialized_copy.cuh ================================================ /****************************************************************************** * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* ******************************************************************************/ #pragma once #ifndef USE_ROCM #include #include #else #include // Map ::cuda::std to the standard std namespace namespace cuda { namespace std = ::std; } #endif namespace detail { #if defined(_NVHPC_CUDA) template __host__ __device__ void uninitialized_copy(T *ptr, U &&val) { // NVBug 3384810 new (ptr) T(::cuda::std::forward(val)); } #else template ::value, int >::type = 0> __host__ __device__ void uninitialized_copy(T *ptr, U &&val) { *ptr = ::cuda::std::forward(val); } template ::value, int >::type = 0> __host__ __device__ void uninitialized_copy(T *ptr, U &&val) { new (ptr) T(::cuda::std::forward(val)); } #endif } // namespace detail ================================================ FILE: evals/lm_harness_eval.py ================================================ import torch import transformers from transformers import AutoTokenizer from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel from lm_eval.api.model import LM from lm_eval.models.huggingface import HFLM from lm_eval.api.registry import register_model from lm_eval.__main__ import cli_evaluate @register_model("mamba") class MambaEvalWrapper(HFLM): AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=2048, batch_size=None, device="cuda", dtype=torch.float16): LM.__init__(self) self._model = MambaLMHeadModel.from_pretrained(pretrained, device=device, dtype=dtype) self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.vocab_size = self.tokenizer.vocab_size self._batch_size = int(batch_size) if batch_size is not None else 64 self._max_length = max_length self._device = torch.device(device) @property def batch_size(self): return self._batch_size def _model_generate(self, context, max_length, stop, **generation_kwargs): raise NotImplementedError() if __name__ == "__main__":
cli_evaluate() ================================================ FILE: mamba_ssm/__init__.py ================================================ __version__ = "2.3.1" from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from mamba_ssm.modules.mamba_simple import Mamba from mamba_ssm.modules.mamba2 import Mamba2 from mamba_ssm.modules.mamba3 import Mamba3 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel ================================================ FILE: mamba_ssm/distributed/__init__.py ================================================ ================================================ FILE: mamba_ssm/distributed/distributed_utils.py ================================================ from typing import Optional import torch from torch import Tensor from torch.distributed import ProcessGroup # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent # version of PyTorch. The following 4 lines are for backward compatibility with # older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed): torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base if "reduce_scatter_tensor" not in dir(torch.distributed): torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base # Raw operation, does not support autograd, but does support async def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) output = torch.empty( world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device ) handle = torch.distributed.all_gather_into_tensor( output, input_.contiguous(), group=process_group, async_op=async_op ) return output, handle # Raw operation, does not support autograd, but does support async def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 output = torch.empty( input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device ) handle = torch.distributed.reduce_scatter_tensor( output, input_.contiguous(), group=process_group, async_op=async_op ) return output, handle # Raw operation, does not support autograd, but does support async def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): input_ = input_.contiguous() handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) return input_, handle class AllGatherFunc(torch.autograd.Function): """Gather the input from sequence parallel region and concatenate.""" @staticmethod def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: ctx.process_group = process_group output, _ = all_gather_raw(input_, process_group) return output @staticmethod def backward(ctx, grad_output: Tensor): grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group) return 
grad_input, None # Supports autograd, but does not support async all_gather = AllGatherFunc.apply class ReduceScatterFunc(torch.autograd.Function): """Reduce scatter the input from the sequence parallel region and concatenate.""" @staticmethod def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: ctx.process_group = process_group output, _ = reduce_scatter_raw(input_, process_group) return output @staticmethod def backward(ctx, grad_output: Tensor): grad_input, _ = all_gather_raw(grad_output, ctx.process_group) return grad_input, None # Supports autograd, but does not support async reduce_scatter = ReduceScatterFunc.apply class AllReduceFunc(torch.autograd.Function): """Gather the input from sequence parallel region and concatenate.""" @staticmethod def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: ctx.process_group = process_group output, _ = all_reduce_raw(input_, process_group) return output @staticmethod def backward(ctx, grad_output: Tensor): return grad_output, None # Supports autograd, but does not support async all_reduce = AllReduceFunc.apply def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup): # We want to iterate over parameters with _shared_params=True in the same order, # as different ranks might have different number of parameters (e.g., only rank 0 has bias). 
pamams_shared = { name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False) } for _, p in sorted(pamams_shared.items()): with torch.no_grad(): # Broadcast needs src to be global rank, not group rank torch.distributed.broadcast( p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group ) # Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256 def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup): # We want to iterate over parameters with _sequence_parallel=True in the same order, # as different ranks might have different number of parameters (e.g., only rank 0 has bias). params_seqparallel = { name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False) } grads = [p.grad for _, p in sorted(params_seqparallel.items())] if grads: with torch.no_grad(): coalesced = torch._utils._flatten_dense_tensors(grads) torch.distributed.all_reduce(coalesced, group=process_group) for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)): buf.copy_(synced) def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int: """Get the dim for the local rank derived from splitting dim on world_size processes. The split may not be even across the world_size processes. """ multiple = dim // multiple_of div = multiple // world_size mod = multiple % world_size local_multiple = div + int(local_rank < mod) return local_multiple * multiple_of ================================================ FILE: mamba_ssm/distributed/tensor_parallel.py ================================================ # Copyright (c) 2024, Tri Dao. 
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torch.distributed import ProcessGroup from mamba_ssm.utils.torch import custom_bwd, custom_fwd from einops import rearrange from mamba_ssm.distributed.distributed_utils import ( all_gather_raw, all_reduce, all_reduce_raw, reduce_scatter, reduce_scatter_raw, ) class ParallelLinearFunc(torch.autograd.Function): @staticmethod @custom_fwd def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True): """ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel with sequence parallelism: we do an all_gather_raw of x before doing the matmul. """ ctx.compute_weight_gradient = weight.requires_grad ctx.process_group = process_group ctx.sequence_parallel = sequence_parallel if torch.is_autocast_enabled(): x = x.to(dtype=torch.get_autocast_gpu_dtype()) x = x.contiguous() if process_group is not None and sequence_parallel: # We want to kick off the all_gather early, before weight dtype conversion total_x, handle_x = all_gather_raw(x, process_group, async_op=True) else: total_x = x if torch.is_autocast_enabled(): weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None weight = weight.contiguous() if process_group is not None and sequence_parallel: handle_x.wait() batch_shape, n = total_x.shape[:-1], total_x.shape[-1] batch_dim = batch_shape.numel() # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 output = F.linear(total_x, weight, bias) if ctx.compute_weight_gradient: ctx.save_for_backward(x, weight) else: ctx.save_for_backward(weight) return output @staticmethod @custom_bwd def backward(ctx, grad_output): grad_output = 
grad_output.contiguous() process_group = ctx.process_group sequence_parallel = ctx.sequence_parallel if ctx.compute_weight_gradient: x, weight = ctx.saved_tensors if process_group is not None and sequence_parallel: total_x, handle_x = all_gather_raw(x, process_group, async_op=True) else: total_x = x else: (weight,) = ctx.saved_tensors total_x = None batch_shape = grad_output.shape[:-1] batch_dim = batch_shape.numel() grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) if ctx.needs_input_grad[0]: grad_input = F.linear(grad_output, weight.t()) grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) if process_group is not None: reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) else: grad_input = None if ctx.needs_input_grad[1]: assert ctx.compute_weight_gradient if process_group is not None and sequence_parallel: handle_x.wait() grad_weight = torch.einsum( "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1]) ) else: grad_weight = None grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None if process_group is not None and ctx.needs_input_grad[0]: handle_grad_input.wait() return grad_input, grad_weight, grad_bias, None, None def parallel_linear_func( x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, process_group: Optional[ProcessGroup] = None, sequence_parallel: bool = True, ): return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel) class ColumnParallelLinear(nn.Linear): def __init__( self, in_features: int, out_features: int, process_group: ProcessGroup, bias: bool = True, sequence_parallel=True, multiple_of=1, device=None, dtype=None, ) -> None: world_size = torch.distributed.get_world_size(process_group) if out_features % multiple_of: raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") multiple = out_features // multiple_of # We 
want to split @multiple across world_size, but it could be an uneven split div = multiple // world_size mod = multiple % world_size # The first @mod ranks get @div + 1 copies, the rest get @div copies local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) super().__init__( in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype ) self.process_group = process_group self.sequence_parallel = sequence_parallel def forward(self, x): # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # we do an all_gather of x before doing the matmul. # If not, then the input is already gathered. return parallel_linear_func( x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel, ) class RowParallelLinear(nn.Linear): def __init__( self, in_features: int, out_features: int, process_group: ProcessGroup, bias: bool = True, sequence_parallel=True, multiple_of=1, device=None, dtype=None, ) -> None: world_size = torch.distributed.get_world_size(process_group) rank = torch.distributed.get_rank(process_group) if in_features % multiple_of: raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") multiple = in_features // multiple_of # We want to split @multiple across world_size, but it could be an uneven split div = multiple // world_size mod = multiple % world_size # The first @mod ranks get @div + 1 copies, the rest get @div copies local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) # Only rank 0 will have bias super().__init__( local_multiple * multiple_of, out_features, bias=bias and rank == 0, device=device, dtype=dtype, ) self.process_group = process_group self.sequence_parallel = sequence_parallel def forward(self, x): """ We're doing Tensor Parallel with sequence parallelism: we do the matmul and then a reduce_scatter of the result. 
""" out = parallel_linear_func(x, self.weight, self.bias) reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce return reduce_fn(out, self.process_group) class VocabParallelEmbedding(nn.Embedding): def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs): self.process_group = process_group if process_group is not None: world_size = torch.distributed.get_world_size(process_group) if num_embeddings % world_size != 0: raise ValueError( f"num_embeddings ({num_embeddings}) must be divisible by " f"world_size ({world_size})" ) if world_size > 1 and padding_idx is not None: raise RuntimeError("ParallelEmbedding does not support padding_idx") else: world_size = 1 super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs) def forward(self, input: Tensor) -> Tensor: if self.process_group is None: return super().forward(input) else: rank = torch.distributed.get_rank(self.process_group) vocab_size = self.num_embeddings vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size # Create a mask of valid vocab ids (1 means it needs to be masked). 
input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index) input = input - vocab_start_index input[input_ids_mask] = 0 embeddings = super().forward(input) embeddings[input_ids_mask] = 0.0 return embeddings class ColumnParallelEmbedding(nn.Embedding): def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs): self.process_group = process_group if process_group is not None: world_size = torch.distributed.get_world_size(process_group) if embedding_dim % world_size != 0: raise ValueError( f"embedding_dim ({embedding_dim}) must be divisible by " f"world_size ({world_size})" ) else: world_size = 1 super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs) class ParallelEmbeddings(nn.Module): def __init__( self, embed_dim, vocab_size, max_position_embeddings, process_group, padding_idx=None, sequence_parallel=True, device=None, dtype=None, ): """ If max_position_embeddings <= 0, there's no position embeddings """ factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.process_group = process_group self.sequence_parallel = sequence_parallel self.word_embeddings = VocabParallelEmbedding( vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group, **factory_kwargs, ) self.max_position_embeddings = max_position_embeddings if self.max_position_embeddings > 0: self.position_embeddings = ColumnParallelEmbedding( max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs ) def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False): """ input_ids: (batch, seqlen) position_ids: (batch, seqlen) """ batch_size, seqlen = input_ids.shape world_size = torch.distributed.get_world_size(self.process_group) embeddings = self.word_embeddings(input_ids) if self.max_position_embeddings > 0: if position_ids is None: position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) position_embeddings = 
self.position_embeddings(position_ids) if world_size <= 1: embeddings = embeddings + position_embeddings else: partition_dim = self.position_embeddings.embedding_dim rank = torch.distributed.get_rank(self.process_group) embeddings[ ..., rank * partition_dim : (rank + 1) * partition_dim ] += position_embeddings if combine_batch_seqlen_dim: embeddings = rearrange(embeddings, "b s d -> (b s) d") reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group) ================================================ FILE: mamba_ssm/models/__init__.py ================================================ ================================================ FILE: mamba_ssm/models/config_mamba.py ================================================ from dataclasses import dataclass, field @dataclass class MambaConfig: d_model: int = 2560 d_intermediate: int = 0 n_layer: int = 64 vocab_size: int = 50277 ssm_cfg: dict = field(default_factory=dict) attn_layer_idx: list = field(default_factory=list) attn_cfg: dict = field(default_factory=dict) rms_norm: bool = True residual_in_fp32: bool = True fused_add_norm: bool = True pad_vocab_size_multiple: int = 8 tie_embeddings: bool = True ================================================ FILE: mamba_ssm/models/mixer_seq_simple.py ================================================ # Copyright (c) 2023, Albert Gu, Tri Dao. 
import math from functools import partial import json import os import copy from collections import namedtuple import torch import torch.nn as nn from mamba_ssm.models.config_mamba import MambaConfig from mamba_ssm.modules.mamba_simple import Mamba from mamba_ssm.modules.mamba2 import Mamba2 from mamba_ssm.modules.mha import MHA from mamba_ssm.modules.mlp import GatedMLP from mamba_ssm.modules.block import Block from mamba_ssm.utils.generation import GenerationMixin from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf try: from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None def create_block( d_model, d_intermediate, ssm_cfg=None, attn_layer_idx=None, attn_cfg=None, norm_epsilon=1e-5, rms_norm=False, residual_in_fp32=False, fused_add_norm=False, layer_idx=None, device=None, dtype=None, ): if ssm_cfg is None: ssm_cfg = {} if attn_layer_idx is None: attn_layer_idx = [] if attn_cfg is None: attn_cfg = {} factory_kwargs = {"device": device, "dtype": dtype} if layer_idx not in attn_layer_idx: # Create a copy of the config to modify ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {} ssm_layer = ssm_cfg.pop("layer", "Mamba1") if ssm_layer not in ["Mamba1", "Mamba2"]: raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2") mixer_cls = partial( Mamba2 if ssm_layer == "Mamba2" else Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs ) else: mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs) norm_cls = partial( nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs ) if d_intermediate == 0: mlp_cls = nn.Identity else: mlp_cls = partial( GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs ) block = Block( d_model, mixer_cls, mlp_cls, norm_cls=norm_cls, fused_add_norm=fused_add_norm, residual_in_fp32=residual_in_fp32, ) block.layer_idx = 
layer_idx return block # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 def _init_weights( module, n_layer, initializer_range=0.02, # Now only used for embedding layer. rescale_prenorm_residual=True, n_residuals_per_layer=1, # Change to 2 if we have MLP ): if isinstance(module, nn.Linear): if module.bias is not None: if not getattr(module.bias, "_no_reinit", False): nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=initializer_range) if rescale_prenorm_residual: # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. # > -- GPT-2 :: https://openai.com/blog/better-language-models/ # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py for name, p in module.named_parameters(): if name in ["out_proj.weight", "fc2.weight"]: # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) # We need to reinit p since this code could be called multiple times # Having just p *= scale would repeatedly scale it down nn.init.kaiming_uniform_(p, a=math.sqrt(5)) with torch.no_grad(): p /= math.sqrt(n_residuals_per_layer * n_layer) class MixerModel(nn.Module): def __init__( self, d_model: int, n_layer: int, d_intermediate: int, vocab_size: int, ssm_cfg=None, attn_layer_idx=None, attn_cfg=None, norm_epsilon: float = 1e-5, rms_norm: bool = False, initializer_cfg=None, fused_add_norm=False, residual_in_fp32=False, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.residual_in_fp32 = residual_in_fp32 self.embedding = 
nn.Embedding(vocab_size, d_model, **factory_kwargs) # We change the order of residual and layer norm: # Instead of LN -> Attn / MLP -> Add, we do: # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and # the main branch (output of MLP / Mixer). The model definition is unchanged. # This is for performance reason: we can fuse add + layer_norm. self.fused_add_norm = fused_add_norm if self.fused_add_norm: if layer_norm_fn is None or rms_norm_fn is None: raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") self.layers = nn.ModuleList( [ create_block( d_model, d_intermediate=d_intermediate, ssm_cfg=ssm_cfg, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg, norm_epsilon=norm_epsilon, rms_norm=rms_norm, residual_in_fp32=residual_in_fp32, fused_add_norm=fused_add_norm, layer_idx=i, **factory_kwargs, ) for i in range(n_layer) ] ) self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( d_model, eps=norm_epsilon, **factory_kwargs ) self.apply( partial( _init_weights, n_layer=n_layer, **(initializer_cfg if initializer_cfg is not None else {}), n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP ) ) def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return { i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) for i, layer in enumerate(self.layers) } def forward(self, input_ids, inference_params=None, **mixer_kwargs): hidden_states = self.embedding(input_ids) residual = None for layer in self.layers: hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params, **mixer_kwargs ) if not self.fused_add_norm: residual = (hidden_states + residual) if residual is not None else hidden_states hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) else: # Set prenorm=False here since we don't need the residual hidden_states = layer_norm_fn( hidden_states, self.norm_f.weight, self.norm_f.bias, 
eps=self.norm_f.eps, residual=residual, prenorm=False, residual_in_fp32=self.residual_in_fp32, is_rms_norm=isinstance(self.norm_f, RMSNorm) ) return hidden_states class MambaLMHeadModel(nn.Module, GenerationMixin): def __init__( self, config: MambaConfig, initializer_cfg=None, device=None, dtype=None, ) -> None: self.config = config d_model = config.d_model n_layer = config.n_layer d_intermediate = config.d_intermediate vocab_size = config.vocab_size ssm_cfg = config.ssm_cfg attn_layer_idx = config.attn_layer_idx attn_cfg = config.attn_cfg rms_norm = config.rms_norm residual_in_fp32 = config.residual_in_fp32 fused_add_norm = config.fused_add_norm pad_vocab_size_multiple = config.pad_vocab_size_multiple factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if vocab_size % pad_vocab_size_multiple != 0: vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) self.backbone = MixerModel( d_model=d_model, n_layer=n_layer, d_intermediate=d_intermediate, vocab_size=vocab_size, ssm_cfg=ssm_cfg, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg, rms_norm=rms_norm, initializer_cfg=initializer_cfg, fused_add_norm=fused_add_norm, residual_in_fp32=residual_in_fp32, **factory_kwargs, ) self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) # Initialize weights and apply final processing self.apply( partial( _init_weights, n_layer=n_layer, **(initializer_cfg if initializer_cfg is not None else {}), ) ) self.tie_weights() def tie_weights(self): if self.config.tie_embeddings: self.lm_head.weight = self.backbone.embedding.weight def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs): """ "position_ids" is just to be compatible with Transformer generation. We don't use it. 
num_last_tokens: if > 0, only return the logits for the last n tokens """ hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs) if num_last_tokens > 0: hidden_states = hidden_states[:, -num_last_tokens:] lm_logits = self.lm_head(hidden_states) CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) return CausalLMOutput(logits=lm_logits) @classmethod def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs): config_data = load_config_hf(pretrained_model_name) config = MambaConfig(**config_data) model = cls(config, device=device, dtype=dtype, **kwargs) model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)) return model def save_pretrained(self, save_directory): """ Minimal implementation of save_pretrained for MambaLMHeadModel. Save the model and its configuration file to a directory. """ # Ensure save_directory exists os.makedirs(save_directory, exist_ok=True) # Save the model's state_dict model_path = os.path.join(save_directory, 'pytorch_model.bin') torch.save(self.state_dict(), model_path) # Save the configuration of the model config_path = os.path.join(save_directory, 'config.json') with open(config_path, 'w') as f: json.dump(self.config.__dict__, f, indent=4) ================================================ FILE: mamba_ssm/modules/__init__.py ================================================ ================================================ FILE: mamba_ssm/modules/block.py ================================================ # Copyright (c) 2024, Tri Dao, Albert Gu. 
from typing import Optional import torch from torch import nn, Tensor from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn class Block(nn.Module): def __init__( self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False ): """ Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" This Block has a slightly different structure compared to a regular prenorm Transformer block. The standard block is: LN -> MHA/MLP -> Add. [Ref: https://arxiv.org/abs/2002.04745] Here we have: Add -> LN -> Mixer, returning both the hidden_states (output of the mixer) and the residual. This is purely for performance reasons, as we can fuse add and LayerNorm. The residual needs to be provided (except for the very first block). """ super().__init__() self.residual_in_fp32 = residual_in_fp32 self.fused_add_norm = fused_add_norm self.norm = norm_cls(dim) self.mixer = mixer_cls(dim) if mlp_cls is not nn.Identity: self.norm2 = norm_cls(dim) self.mlp = mlp_cls(dim) else: self.mlp = None if self.fused_add_norm: assert RMSNorm is not None, "RMSNorm import fails" assert isinstance( self.norm, (nn.LayerNorm, RMSNorm) ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" def forward( self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs ): r"""Pass the input through the encoder layer. Args: hidden_states: the sequence to the encoder layer (required). 
residual: hidden_states = Mixer(LN(residual)) """ if not self.fused_add_norm: residual = (hidden_states + residual) if residual is not None else hidden_states hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) if self.residual_in_fp32: residual = residual.to(torch.float32) else: hidden_states, residual = layer_norm_fn( hidden_states, self.norm.weight, self.norm.bias, residual=residual, prenorm=True, residual_in_fp32=self.residual_in_fp32, eps=self.norm.eps, is_rms_norm=isinstance(self.norm, RMSNorm) ) hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs) if self.mlp is not None: if not self.fused_add_norm: residual = hidden_states + residual hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) if self.residual_in_fp32: residual = residual.to(torch.float32) else: hidden_states, residual = layer_norm_fn( hidden_states, self.norm2.weight, self.norm2.bias, residual=residual, prenorm=True, residual_in_fp32=self.residual_in_fp32, eps=self.norm2.eps, is_rms_norm=isinstance(self.norm2, RMSNorm) ) hidden_states = self.mlp(hidden_states) return hidden_states, residual def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) ================================================ FILE: mamba_ssm/modules/mamba2.py ================================================ # Copyright (c) 2024, Tri Dao, Albert Gu. 
import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update except ImportError: causal_conv1d_fn, causal_conv1d_update = None, None try: from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states except ImportError: causal_conv1d_varlen_states = None try: from mamba_ssm.ops.triton.selective_state_update import selective_state_update except ImportError: selective_state_update = None from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined from huggingface_hub import PyTorchModelHubMixin class Mamba2(nn.Module, PyTorchModelHubMixin): def __init__( self, d_model, d_state=128, d_conv=4, conv_init=None, expand=2, headdim=64, d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP ngroups=1, A_init_range=(1, 16), D_has_hdim=False, rmsnorm=True, norm_before_gate=False, dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, dt_limit=(0.0, float("inf")), bias=False, conv_bias=True, # Fused kernel and sharding options chunk_size=256, use_mem_eff_path=True, layer_idx=None, # Absorb kwarg for general module process_group=None, sequence_parallel=True, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.d_model = d_model self.d_state = d_state self.d_conv = d_conv self.conv_init = conv_init self.expand = expand self.process_group = process_group self.sequence_parallel = sequence_parallel self.world_size = 1 if process_group is None else process_group.size() self.local_rank = 0 if process_group is None else process_group.rank() 
self.d_inner = (self.expand * self.d_model) // self.world_size assert self.d_inner * self.world_size == self.expand * self.d_model self.headdim = headdim self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size assert ngroups % self.world_size == 0 self.ngroups = ngroups // self.world_size assert self.d_ssm % self.headdim == 0 self.nheads = self.d_ssm // self.headdim self.D_has_hdim = D_has_hdim self.rmsnorm = rmsnorm self.norm_before_gate = norm_before_gate self.dt_limit = dt_limit self.activation = "silu" self.chunk_size = chunk_size self.use_mem_eff_path = use_mem_eff_path self.layer_idx = layer_idx # Order: [z, x, B, C, dt] d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads if self.process_group is None: self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs) else: self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel, **factory_kwargs) conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state self.conv1d = nn.Conv1d( in_channels=conv_dim, out_channels=conv_dim, bias=conv_bias, kernel_size=d_conv, groups=conv_dim, padding=d_conv - 1, **factory_kwargs, ) if self.conv_init is not None: nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) self.act = nn.SiLU() # Initialize log dt bias dt = torch.exp( torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ) dt = torch.clamp(dt, min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) self.dt_bias = nn.Parameter(inv_dt) # Just to be explicit. 
Without this we already don't put wd on dt_bias because of the check # name.endswith("bias") in param_grouping.py self.dt_bias._no_weight_decay = True assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range) A_log = torch.log(A).to(dtype=dtype) self.A_log = nn.Parameter(A_log) self.A_log._no_weight_decay = True # D "skip" parameter self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)) self.D._no_weight_decay = True if self.rmsnorm: assert RMSNormGated is not None self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate, group_size=self.d_ssm // ngroups, **factory_kwargs) if self.process_group is None: self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) else: self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel, **factory_kwargs) def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None): """ u: (batch, seqlen, hidden_dim) if seqlen=None. If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we split u during sequence parallel, we split the batch * seqlen dimension (in case batch is small). 
Returns: same shape as u """ seqlen_og = seqlen if seqlen is None: batch, seqlen, dim = u.shape else: batch_seqlen, dim = u.shape batch = batch_seqlen // seqlen conv_state, ssm_state = None, None if inference_params is not None: inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch) if inference_params.seqlen_offset > 0: # The states are updated inplace out, _, _ = self.step(u, conv_state, ssm_state) return out zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj) if seqlen_og is not None: zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen) # If the model is loaded in fp16, without the .float() here, A might be -inf A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state) dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit) if self.use_mem_eff_path and inference_params is None: out = mamba_split_conv1d_scan_combined( zxbcdt, rearrange(self.conv1d.weight, "d 1 w -> d w"), self.conv1d.bias, self.dt_bias, A, D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, chunk_size=self.chunk_size, seq_idx=seq_idx, activation=self.activation, rmsnorm_weight=self.norm.weight if self.rmsnorm else None, rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6, outproj_weight=self.out_proj.weight, outproj_bias=self.out_proj.bias, headdim=None if self.D_has_hdim else self.headdim, ngroups=self.ngroups, norm_before_gate=self.norm_before_gate, **dt_limit_kwargs, ) if seqlen_og is not None: out = rearrange(out, "b l d -> (b l) d") if self.process_group is not None: reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce out = reduce_fn(out, self.process_group) else: d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2 z0, x0, z, xBC, dt = torch.split( zxbcdt, [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * 
self.d_state, self.nheads], dim=-1 ) if conv_state is not None: if cu_seqlens is None: # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. xBC_t = rearrange(xBC, "b l d -> b d l") conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) else: assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package" assert batch == 1, "varlen inference only supports batch dimension 1" conv_varlen_states = causal_conv1d_varlen_states( xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1] ) conv_state.copy_(conv_varlen_states) assert self.activation in ["silu", "swish"] if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: assert seq_idx is None, "varlen conv1d requires the causal_conv1d package" xBC = self.act( self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :-(self.d_conv - 1)] ) # (B, L, self.d_ssm + 2 * ngroups * d_state) else: xBC = causal_conv1d_fn( xBC.transpose(1, 2), rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, seq_idx=seq_idx, ).transpose(1, 2) x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1) y = mamba_chunk_scan_combined( rearrange(x, "b l (h p) -> b l h p", p=self.headdim), dt, A, rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), chunk_size=self.chunk_size, D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, z=rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None, dt_bias=self.dt_bias, dt_softplus=True, seq_idx=seq_idx, cu_seqlens=cu_seqlens, **dt_limit_kwargs, return_final_states=ssm_state is not None, return_varlen_states=cu_seqlens is not None and inference_params is not None, ) if ssm_state is not None: y, last_state, *rest = y 
if cu_seqlens is None: ssm_state.copy_(last_state) else: varlen_states = rest[0] ssm_state.copy_(varlen_states) y = rearrange(y, "b l h p -> b l (h p)") if self.rmsnorm: y = self.norm(y, z) if d_mlp > 0: y = torch.cat([F.silu(z0) * x0, y], dim=-1) if seqlen_og is not None: y = rearrange(y, "b l d -> (b l) d") out = self.out_proj(y) return out def step(self, hidden_states, conv_state, ssm_state): dtype = hidden_states.dtype assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D) d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2 z0, x0, z, xBC, dt = torch.split( zxbcdt, [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads], dim=-1 ) # Conv step if causal_conv1d_update is None: conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) conv_state[:, :, -1] = xBC xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) if self.conv1d.bias is not None: xBC = xBC + self.conv1d.bias xBC = self.act(xBC).to(dtype=dtype) else: xBC = causal_conv1d_update( xBC, conv_state, rearrange(self.conv1d.weight, "d 1 w -> d w"), self.conv1d.bias, self.activation, ) x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1) A = -torch.exp(self.A_log.float()) # (nheads,) # SSM step if selective_state_update is None: assert self.ngroups == 1, "Only support ngroups=1 for this inference code path" # Discretize A and B dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads) dA = torch.exp(dt * A) # (batch, nheads) x = rearrange(x, "b (h p) -> b h p", p=self.headdim) dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x) ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C) y = y + rearrange(self.D.to(dtype), "h -> h 1") * x y = rearrange(y, 
"b h p -> b (h p)") if not self.rmsnorm: y = y * self.act(z) # (B D) else: A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) dt = repeat(dt, "b h -> b h p", p=self.headdim) dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) D = repeat(self.D, "h -> h p", p=self.headdim) B = rearrange(B, "b (g n) -> b g n", g=self.ngroups) C = rearrange(C, "b (g n) -> b g n", g=self.ngroups) x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) if not self.rmsnorm: z = rearrange(z, "b (h p) -> b h p", p=self.headdim) y = selective_state_update( ssm_state, x_reshaped, dt, A, B, C, D, z=z if not self.rmsnorm else None, dt_bias=dt_bias, dt_softplus=True ) y = rearrange(y, "b h p -> b (h p)") if self.rmsnorm: y = self.norm(y, z) if d_mlp > 0: y = torch.cat([F.silu(z0) * x0, y], dim=-1) out = self.out_proj(y) return out.unsqueeze(1), conv_state, ssm_state def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): device = self.out_proj.weight.device conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype conv_state = torch.zeros( batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype ).transpose(1, 2) ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype ssm_state = torch.zeros( batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype ) return conv_state, ssm_state def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): assert self.layer_idx is not None if self.layer_idx not in inference_params.key_value_memory_dict: batch_shape = (batch_size,) conv_state = torch.zeros( batch_size, self.d_conv, self.conv1d.weight.shape[0], device=self.conv1d.weight.device, dtype=self.conv1d.weight.dtype, ).transpose(1, 2) ssm_state = torch.zeros( batch_size, self.nheads, self.headdim, self.d_state, device=self.in_proj.weight.device, dtype=self.in_proj.weight.dtype, ) inference_params.key_value_memory_dict[self.layer_idx] 
= (conv_state, ssm_state) else: conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] # TODO: What if batch size changes between generation, and we reuse the same states? if initialize_states: conv_state.zero_() ssm_state.zero_() return conv_state, ssm_state ================================================ FILE: mamba_ssm/modules/mamba2_simple.py ================================================ # Copyright (c) 2024, Tri Dao, Albert Gu. import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat try: from causal_conv1d import causal_conv1d_fn except ImportError: causal_conv1d_fn = None try: from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm except ImportError: RMSNormGated, LayerNorm = None, None from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined class Mamba2Simple(nn.Module): def __init__( self, d_model, d_state=64, d_conv=4, conv_init=None, expand=2, headdim=128, ngroups=1, A_init_range=(1, 16), dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, dt_limit=(0.0, float("inf")), learnable_init_states=False, activation="swish", bias=False, conv_bias=True, # Fused kernel and sharding options chunk_size=256, use_mem_eff_path=True, layer_idx=None, # Absorb kwarg for general module device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.d_model = d_model self.d_state = d_state self.d_conv = d_conv self.conv_init = conv_init self.expand = expand self.d_inner = self.expand * self.d_model self.headdim = headdim self.ngroups = ngroups assert self.d_inner % self.headdim == 0 self.nheads = self.d_inner // self.headdim self.dt_limit = dt_limit self.learnable_init_states = learnable_init_states self.activation = activation self.chunk_size = chunk_size self.use_mem_eff_path = use_mem_eff_path self.layer_idx = layer_idx # Order: 
[z, x, B, C, dt] d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs) conv_dim = self.d_inner + 2 * self.ngroups * self.d_state self.conv1d = nn.Conv1d( in_channels=conv_dim, out_channels=conv_dim, bias=conv_bias, kernel_size=d_conv, groups=conv_dim, padding=d_conv - 1, **factory_kwargs, ) if self.conv_init is not None: nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) # self.conv1d.weight._no_weight_decay = True if self.learnable_init_states: self.init_states = nn.Parameter(torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)) self.init_states._no_weight_decay = True self.act = nn.SiLU() # Initialize log dt bias dt = torch.exp( torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ) dt = torch.clamp(dt, min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) self.dt_bias = nn.Parameter(inv_dt) # Just to be explicit. 
Without this we already don't put wd on dt_bias because of the check # name.endswith("bias") in param_grouping.py self.dt_bias._no_weight_decay = True # A parameter assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range) A_log = torch.log(A).to(dtype=dtype) self.A_log = nn.Parameter(A_log) # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True) self.A_log._no_weight_decay = True # D "skip" parameter self.D = nn.Parameter(torch.ones(self.nheads, device=device)) self.D._no_weight_decay = True # Extra normalization layer right before output projection assert RMSNormGated is not None self.norm = RMSNormGated(self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs) self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) def forward(self, u, seq_idx=None): """ u: (B, L, D) Returns: same shape as u """ batch, seqlen, dim = u.shape zxbcdt = self.in_proj(u) # (B, L, d_in_proj) A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state) initial_states=repeat(self.init_states, "... 
-> b ...", b=batch) if self.learnable_init_states else None dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit) if self.use_mem_eff_path: # Fully fused path out = mamba_split_conv1d_scan_combined( zxbcdt, rearrange(self.conv1d.weight, "d 1 w -> d w"), self.conv1d.bias, self.dt_bias, A, D=self.D, chunk_size=self.chunk_size, seq_idx=seq_idx, activation=self.activation, rmsnorm_weight=self.norm.weight, rmsnorm_eps=self.norm.eps, outproj_weight=self.out_proj.weight, outproj_bias=self.out_proj.bias, headdim=self.headdim, ngroups=self.ngroups, norm_before_gate=False, initial_states=initial_states, **dt_limit_kwargs, ) else: z, xBC, dt = torch.split( zxbcdt, [self.d_inner, self.d_inner + 2 * self.ngroups * self.d_state, self.nheads], dim=-1 ) dt = F.softplus(dt + self.dt_bias) # (B, L, nheads) assert self.activation in ["silu", "swish"] # 1D Convolution if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: xBC = self.act( self.conv1d(xBC.transpose(1, 2)).transpose(1, 2) ) # (B, L, self.d_inner + 2 * ngroups * d_state) xBC = xBC[:, :seqlen, :] else: xBC = causal_conv1d_fn( x=xBC.transpose(1, 2), weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, ).transpose(1, 2) # Split into 3 main branches: X, B, C # These correspond to V, K, Q respectively in the SSM/attention duality x, B, C = torch.split(xBC, [self.d_inner, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1) y = mamba_chunk_scan_combined( rearrange(x, "b l (h p) -> b l h p", p=self.headdim), dt, A, rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), chunk_size=self.chunk_size, D=self.D, z=None, seq_idx=seq_idx, initial_states=initial_states, **dt_limit_kwargs, ) y = rearrange(y, "b l h p -> b l (h p)") # Multiply "gate" branch and apply extra normalization layer y = self.norm(y, z) out = self.out_proj(y) return out 
================================================ FILE: mamba_ssm/modules/mamba3.py ================================================ # Copyright (c) 2026, Dao AI Lab, Goombalab. import math from einops import rearrange, repeat import torch import torch.nn as nn import torch.nn.functional as F from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated from mamba_ssm.ops.tilelang.mamba3.mamba3_mimo import mamba3_mimo as mamba3_mimo_combined from mamba_ssm.ops.triton.angle_cumsum import angle_dt from mamba_ssm.ops.triton.mamba3.mamba3_siso_combined import mamba3_siso_combined from mamba_ssm.ops.triton.mamba3.mamba3_mimo_rotary_step import apply_rotary_qk_inference_fwd from mamba_ssm.ops.cute.mamba3.mamba3_step_fn import mamba3_step_fn class Mamba3(nn.Module): def __init__( self, d_model, d_state=128, expand=2, headdim=64, ngroups=1, # ---------------------------------------- # Mamba-3 configs rope_fraction=0.5, dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, A_floor=1e-4, is_outproj_norm=False, is_mimo=False, mimo_rank=4, #------------------------------------------- # Fused kernel and sharding options chunk_size=64, # Recommended: 64 for SISO, 64/mimo_rank for MIMO dropout=0.0, # Just to absorb the kwarg layer_idx=None, # Absorb kwarg for general module n_layer=None, # Absorb kwarg for general module device=None, dtype=None, **kwargs, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.d_model = d_model self.d_state = d_state self.expand = expand self.headdim = headdim self.chunk_size = chunk_size self.layer_idx = layer_idx self.A_floor = A_floor self.is_outproj_norm=is_outproj_norm self.is_mimo = is_mimo self.mimo_rank = mimo_rank if not self.is_mimo: self.mimo_rank = 1 self.d_inner = int(self.expand * self.d_model) assert self.d_inner % self.headdim == 0 self.nheads = self.d_inner // self.headdim self.num_bc_heads = ngroups # RoPE flags assert rope_fraction in [0.5, 1.0] self.rotary_dim_divisor = int(2/rope_fraction) 
self.split_tensor_size = int(d_state * rope_fraction) if self.split_tensor_size % 2 != 0: self.split_tensor_size -= 1 self.num_rope_angles = self.split_tensor_size // 2 assert self.num_rope_angles > 0 # Order: [z, x, B, C, dd_dt, dd_A, trap, angle] d_in_proj = 2 * self.d_inner + 2 * self.d_state * self.num_bc_heads * self.mimo_rank + 3 * self.nheads + self.num_rope_angles self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=False, **factory_kwargs) # dt_bias parameterization _dt = torch.exp( torch.rand(self.nheads, device=device, dtype=torch.float32) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ) _dt = torch.clamp(_dt, min=dt_init_floor) _dt_bias = _dt + torch.log(-torch.expm1(-_dt)) self.dt_bias = nn.Parameter(_dt_bias, requires_grad=True) self.dt_bias._no_weight_decay = True # B and C biases self.B_bias = nn.Parameter(1+torch.zeros((self.nheads, self.mimo_rank, self.d_state), dtype=torch.float32, device=device), requires_grad=True) self.C_bias = nn.Parameter(1+torch.zeros((self.nheads, self.mimo_rank, self.d_state), dtype=torch.float32, device=device), requires_grad=True) # RMS Norm for B and C assert RMSNormGated is not None self.B_norm = RMSNormGated(self.d_state, eps=1e-5, **factory_kwargs) self.C_norm = RMSNormGated(self.d_state, eps=1e-5, **factory_kwargs) if self.is_mimo: # Initialize up/down MIMO projection (for x and z) mimo_x_init_weights = torch.ones(self.nheads, self.mimo_rank, self.headdim, device=device) / self.mimo_rank mimo_z_init_weights = torch.ones(self.nheads, self.mimo_rank, self.headdim, device=device) mimo_o_init_weights = torch.ones(self.nheads, self.mimo_rank, self.headdim, device=device) / self.mimo_rank self.mimo_x = nn.Parameter(mimo_x_init_weights, requires_grad=True) self.mimo_z = nn.Parameter(mimo_z_init_weights, requires_grad=True) self.mimo_o = nn.Parameter(mimo_o_init_weights, requires_grad=True) # D "skip" parameter self.D = nn.Parameter(torch.ones(self.nheads, device=device)) self.D._no_weight_decay = True if 
self.is_outproj_norm: self.norm = RMSNormGated( self.d_inner, eps=1e-5, norm_before_gate=True, group_size=self.headdim, **factory_kwargs ) # Output projection self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=False, **factory_kwargs) def forward(self, u, seq_idx=None, cu_seqlens=None, inference_params=None): """ u: (batch, seqlen, hidden_dim) Returns: same shape as u """ batch, seqlen, dim = u.shape if cu_seqlens is not None: raise NotImplementedError("Currently does not support varlen in Mamba-3 (MIMO).") angle_dt_state, ssm_state, k_state, v_state = None, None, None, None if inference_params is not None: inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch angle_dt_state, ssm_state, k_state, v_state = self._get_states_from_cache(inference_params, inference_batch) if inference_params.seqlen_offset > 0: # The states are updated inplace here; however, due to the current implementation, # setting inplace=True incurs significant overhead. So potentially # faster to call step() directly with inplace=False: out, _, _, _, _ = self.step(u, angle_dt_state, ssm_state, k_state, v_state) return out # Apply in_proj zxBCdtAtrap = self.in_proj(u) z, x, B, C, dd_dt, dd_A, trap, angles = torch.split( zxBCdtAtrap, [ self.d_inner, self.d_inner, self.d_state * self.num_bc_heads * self.mimo_rank, self.d_state * self.num_bc_heads * self.mimo_rank, self.nheads, self.nheads, self.nheads, self.num_rope_angles ], dim=-1) z = rearrange(z, "b l (h p) -> b l h p", p=self.headdim) x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim) B = rearrange(B, "b l (r g n) -> b l r g n", r=self.mimo_rank, g=self.num_bc_heads) C = rearrange(C, "b l (r g n) -> b l r g n", r=self.mimo_rank, g=self.num_bc_heads) trap = rearrange(trap, "b l h -> b h l") # Compute ADT, DT _A = -F.softplus(dd_A.to(torch.float32)) # (B, L, N) _A = torch.clamp(_A, max=-self.A_floor) DT = F.softplus(dd_dt + self.dt_bias) # (B, L, N) ADT = _A * DT DT = rearrange(DT, "b l n -> b n l") ADT = 
rearrange(ADT, "b l n -> b n l") # Compute angle angles = angles.unsqueeze(-2).expand(-1, -1, self.nheads, -1) # (B, L, N, S) # Apply RMS Norm on B and C B = self.B_norm(B) C = self.C_norm(C) # Apply Mamba-3 kernel if self.is_mimo: angles = angle_dt(angles, DT.transpose(-1, -2)) # (B, L, N, S) y = mamba3_mimo_combined( Q=C, K=B, V=x, ADT=ADT, DT=DT, Trap=trap, Q_bias=self.C_bias, K_bias=self.B_bias, MIMO_V=self.mimo_x, MIMO_Z=self.mimo_z, MIMO_Out=self.mimo_o if not self.is_outproj_norm else None, Angles=angles, D=self.D, Z=z if not self.is_outproj_norm else None, chunk_size=self.chunk_size, rotary_dim_divisor=self.rotary_dim_divisor, dtype=x.dtype, return_state=ssm_state is not None, ) if ssm_state is not None: y, last_angle, last_state, last_k, last_v, *rest = y angle_dt_state.copy_(last_angle) ssm_state.copy_(last_state) k_state.copy_(last_k) v_state.copy_(last_v) if self.is_outproj_norm: z = torch.einsum("blhp,hrp->blrhp", z.float(), self.mimo_z) z = rearrange(z, "b l r h p -> b l r (h p)") y = rearrange(y, "b l r h p -> b l r (h p)").float() y = self.norm(y, z) y = rearrange(y, "b l r (h p) -> b l r h p", p=self.headdim) y = torch.einsum("blrhp,hrp->blhp", y, self.mimo_o) y = rearrange(y, "b l h p -> b l (h p)") else: y = mamba3_siso_combined( Q=C.squeeze(2), K=B.squeeze(2), V=x, ADT=ADT, DT=DT, Trap=trap, Q_bias=self.C_bias.squeeze(1), K_bias=self.B_bias.squeeze(1), Angles=angles, D=self.D, Z=z if not self.is_outproj_norm else None, chunk_size=self.chunk_size, Input_States=None, return_final_states=ssm_state is not None, ) if ssm_state is not None: y, last_angle, last_state, last_k, last_v, *rest = y angle_dt_state.copy_(last_angle) ssm_state.copy_(last_state) k_state.copy_(last_k) v_state.copy_(last_v) y = rearrange(y, "b l h p -> b l (h p)") if self.is_outproj_norm: z = rearrange(z, "b l h p -> b l (h p)") y = self.norm(y, z) out = self.out_proj(y.to(x.dtype)) return out def _preprocess(self, A_proj, dd_dt, B, C, x, z, trap_proj, angle_proj): _A = 
-F.softplus(A_proj.to(torch.float32)) _A = torch.clamp(_A, max=-self.A_floor) DT = F.softplus(dd_dt + self.dt_bias) trap = torch.sigmoid(trap_proj) rank = self.mimo_rank if self.is_mimo else 1 B = rearrange(B, "b (r g s) -> b r g s", g=self.num_bc_heads, r=rank) C = rearrange(C, "b (r g s) -> b r g s", g=self.num_bc_heads, r=rank) B = self.B_norm(B) C = self.C_norm(C) B = B.expand(-1, -1, self.nheads, -1) # (B, R, N, S) C = C.expand(-1, -1, self.nheads, -1) # (B, R, N, S) x = rearrange(x, "b (h p) -> b h p", p=self.headdim) z = rearrange(z, "b (h p) -> b h p", p=self.headdim) angles = angle_proj.unsqueeze(-2).expand(-1, self.nheads, -1) return DT, B, C, x, z, trap, _A, angles def _postprocess(self, y, outpj, z, zpj, headdim): # y: (batch, R, H, D) — apply mimo_z to z, then norm, then mimo_o z_r = torch.einsum("bhp,rhp->brhp", z.float(), zpj) # (batch, R, H, D) z_r = rearrange(z_r, "b r h p -> b r (h p)") y = rearrange(y, "b r h p -> b r (h p)").float() y = self.norm(y, z_r) y = rearrange(y, "b r (h p) -> b r h p", p=headdim) y = torch.einsum("brhp,rhp->bhp", y, outpj) # (batch, H, D) return y def step(self, u, angle_state, ssm_state, k_state, v_state, **kwargs): """ Decode function using CuteDSL kernel from mamba3_step_fn.py. Also modify the state vars in-place for the next step. NOTE: Only tested on H100. Compatibility with other hardware will be made available in the future. 
Args: u: (batch, d_model) angle_state: (batch, nheads, num_rope_angles) ssm_state: (batch, nheads, headdim, d_state) k_state: (batch, R, nheads, d_state), where R = mimo_rank (R=1 if not MIMO) v_state: (batch, nheads, headdim) **kwargs: ignored Returns: out: (batch, d_model) nxt_angle_state: (batch, nheads, num_rope_angles) state_out: (batch, nheads, headdim, d_state) nxt_k_state: (batch, R, nheads, d_state), where R = mimo_rank (R=1 if not MIMO) nxt_v_state: (batch, nheads, headdim) """ # in_proj zxBCdt = self.in_proj(u) z, x, B, C, dd_dt, dd_A, trap, angles = torch.split( zxBCdt, [ self.d_inner, self.d_inner, self.d_state * self.num_bc_heads * self.mimo_rank, self.d_state * self.num_bc_heads * self.mimo_rank, self.nheads, self.nheads, self.nheads, self.num_rope_angles, ], dim=-1) DT, B, C, x, z, trap, A, angles = self._preprocess( dd_A, dd_dt, B, C, x, z, trap, angles) bias_q = rearrange(self.C_bias, "h r n -> r h n") bias_k = rearrange(self.B_bias, "h r n -> r h n") # NOTE: MIMO calls the Tilelang kernel, # which permute the blockwise rotation matrix so that # the i-th entry is paired with the i+N//2-th entry: rotate_pairwise = not self.is_mimo C, B, nxt_angle_state = apply_rotary_qk_inference_fwd( q=C, k=B, angle_state=angle_state, angle_proj=angles, dt=DT, bias_q=bias_q, bias_k=bias_k, conjugate=False, inplace=False, # NOTE: inplace is incompatible with self.nheads != self.num_bc_heads rotate_pairwise=rotate_pairwise) nxt_v_state = x nxt_k_state = B if self.is_mimo: xpj = rearrange(self.mimo_x, "h r p -> r h p", p=self.headdim).contiguous() zpj = rearrange(self.mimo_z, "h r p -> r h p", p=self.headdim).contiguous() outpj = rearrange(self.mimo_o, "h r p -> r h p", p=self.headdim).contiguous() else: xpj = torch.ones(self.mimo_rank, self.nheads, self.headdim, device=x.device, dtype=x.dtype) zpj = torch.ones(self.mimo_rank, self.nheads, self.headdim, device=z.device, dtype=z.dtype) outpj = torch.ones(self.mimo_rank, self.nheads, self.headdim, device=x.device, 
dtype=x.dtype) if self.is_outproj_norm: batch = x.shape[0] y = torch.empty(batch, self.mimo_rank, self.nheads, self.headdim, device=x.device, dtype=x.dtype) mamba3_step_fn( ssm_state, k_state, v_state, A, B, C, self.D, x, DT, trap, xpj, outproj=None, state_out=None, # can be not in place if pass in state_out out=y, z=None, zproj=None, tile_D=64, num_warps=4, ) y = self._postprocess(y, outpj, z, zpj, self.headdim) else: y = torch.empty_like(x) mamba3_step_fn( ssm_state, k_state, v_state, A, B, C, self.D, x, DT, trap, xpj, outproj=outpj, state_out=None, # can be not in place if pass in state_out out=y, z=z, zproj=zpj, tile_D=64, num_warps=4, ) # out_proj out = rearrange(y, "b h p -> b (h p)") out = self.out_proj(out.to(x.dtype)) angle_state.copy_(nxt_angle_state) # Uncomment the following if mamba3_step_fn is not in place: # state_out = torch.empty_like(ssm_state) # ssm_state.copy_(state_out) k_state.copy_(nxt_k_state) v_state.copy_(nxt_v_state) return out, nxt_angle_state, ssm_state, nxt_k_state, nxt_v_state def allocate_inference_cache(self, batch_size, max_seqlen, device=None, dtype=None, inplace_state=None, **kwargs): device = self.in_proj.weight.device if device is None else device dtype = self.in_proj.weight.dtype if dtype is None else dtype # RoPE State angle_dt_state = torch.zeros( (batch_size, self.nheads, self.num_rope_angles), device=device, dtype=torch.float32, ) # Mamba-3 Combined Kernel States # SSM State ssm_state = torch.zeros( (batch_size, self.nheads, self.headdim, self.d_state), device=device, dtype=torch.float32, ) # K (=B) State if self.is_mimo: k_state = torch.zeros( (batch_size, self.mimo_rank, self.nheads, self.d_state), device=device, dtype=dtype, ) else: k_state = torch.zeros( (batch_size, 1, self.nheads, self.d_state), device=device, dtype=dtype, ) # V (=x) State v_state = torch.zeros( (batch_size, self.nheads, self.headdim), device=device, dtype=dtype, ) return (angle_dt_state, ssm_state, k_state, v_state) def _get_states_from_cache(self, 
inference_params, batch_size, initialize_states=False): assert self.layer_idx is not None device = self.in_proj.weight.device dtype = self.in_proj.weight.dtype if self.layer_idx not in inference_params.key_value_memory_dict: angle_dt_state = torch.zeros( (batch_size, self.nheads, self.num_rope_angles), device=device, dtype=torch.float32, ) ssm_state = torch.zeros( (batch_size, self.nheads, self.headdim, self.d_state), device=device, dtype=torch.float32, ) if self.is_mimo: k_state = torch.zeros( (batch_size, self.mimo_rank, self.nheads, self.d_state), device=device, dtype=dtype, ) else: k_state = torch.zeros( (batch_size, self.nheads, self.d_state), device=device, dtype=dtype, ) v_state = torch.zeros( (batch_size, self.nheads, self.headdim), device=device, dtype=dtype, ) inference_params.key_value_memory_dict[self.layer_idx] = (angle_dt_state, ssm_state, k_state, v_state) else: angle_dt_state, ssm_state, k_state, v_state = inference_params.key_value_memory_dict[self.layer_idx] # TODO: What if batch size changes between generation, and we reuse the same states? if initialize_states: angle_dt_state.zero_() ssm_state.zero_() k_state.zero_() v_state.zero_() return angle_dt_state, ssm_state, k_state, v_state ================================================ FILE: mamba_ssm/modules/mamba_simple.py ================================================ # Copyright (c) 2023, Tri Dao, Albert Gu. 
import math from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from einops import rearrange, repeat from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update except ImportError: causal_conv1d_fn, causal_conv1d_update = None, None try: from mamba_ssm.ops.triton.selective_state_update import selective_state_update except ImportError: selective_state_update = None try: from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None class Mamba(nn.Module): def __init__( self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False, use_fast_path=True, # Fused kernel options layer_idx=None, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.d_model = d_model self.d_state = d_state self.d_conv = d_conv self.expand = expand self.d_inner = int(self.expand * self.d_model) self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank self.use_fast_path = use_fast_path self.layer_idx = layer_idx self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) self.conv1d = nn.Conv1d( in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias, kernel_size=d_conv, groups=self.d_inner, padding=d_conv - 1, **factory_kwargs, ) self.activation = "silu" self.act = nn.SiLU() self.x_proj = nn.Linear( self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs ) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization dt_init_std = self.dt_rank**-0.5 * dt_scale if dt_init == "constant": 
nn.init.constant_(self.dt_proj.weight, dt_init_std) elif dt_init == "random": nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max dt = torch.exp( torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): self.dt_proj.bias.copy_(inv_dt) # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit self.dt_proj.bias._no_reinit = True # S4D real initialization A = repeat( torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), "n -> d n", d=self.d_inner, ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 self.A_log = nn.Parameter(A_log) self.A_log._no_weight_decay = True # D "skip" parameter self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 self.D._no_weight_decay = True self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) def forward(self, hidden_states, inference_params=None): """ hidden_states: (B, L, D) Returns: same shape as hidden_states """ batch, seqlen, dim = hidden_states.shape conv_state, ssm_state = None, None if inference_params is not None: conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) if inference_params.seqlen_offset > 0: # The states are updated inplace out, _, _ = self.step(hidden_states, conv_state, ssm_state) return out # We do matmul and transpose BLH -> HBL at the same time xz = rearrange( self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), "d (b l) -> b d l", l=seqlen, ) if self.in_proj.bias is not None: xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz 
next to each other to avoid torch.cat if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states out = mamba_inner_fn( xz, self.conv1d.weight, self.conv1d.bias, self.x_proj.weight, self.dt_proj.weight, self.out_proj.weight, self.out_proj.bias, A, None, # input-dependent B None, # input-dependent C self.D.float(), delta_bias=self.dt_proj.bias.float(), delta_softplus=True, ) else: x, z = xz.chunk(2, dim=1) # Compute short convolution if conv_state is not None: # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) if causal_conv1d_fn is None: x = self.act(self.conv1d(x)[..., :seqlen]) else: assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( x=x, weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, ) # We're careful here about the layout, to avoid extra transposes. # We want dt to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt = self.dt_proj.weight @ dt.t() dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() assert self.activation in ["silu", "swish"] y = selective_scan_fn( x, dt, A, B, C, self.D.float(), z=z, delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=ssm_state is not None, ) if ssm_state is not None: y, last_state = y ssm_state.copy_(last_state) y = rearrange(y, "b d l -> b l d") out = self.out_proj(y) return out def step(self, hidden_states, conv_state, ssm_state): dtype = hidden_states.dtype assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) x, z = xz.chunk(2, dim=-1) # (B D) # Conv step if causal_conv1d_update is None: conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) conv_state[:, :, -1] = x x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) if self.conv1d.bias is not None: x = x + self.conv1d.bias x = self.act(x).to(dtype=dtype) else: x = causal_conv1d_update( x, conv_state, rearrange(self.conv1d.weight, "d 1 w -> d w"), self.conv1d.bias, self.activation, ) x_db = self.x_proj(x) # (B dt_rank+2*d_state) dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) # Don't add dt_bias here dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # SSM step if selective_state_update is None: # Discretize A and B dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) dB = torch.einsum("bd,bn->bdn", dt, B) ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) y = 
torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) y = y + self.D.to(dtype) * x y = y * self.act(z) # (B D) else: y = selective_state_update( ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True ) out = self.out_proj(y) return out.unsqueeze(1), conv_state, ssm_state def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): device = self.out_proj.weight.device conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype conv_state = torch.zeros( batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype ) ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype # ssm_dtype = torch.float32 ssm_state = torch.zeros( batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype ) return conv_state, ssm_state def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): assert self.layer_idx is not None if self.layer_idx not in inference_params.key_value_memory_dict: batch_shape = (batch_size,) conv_state = torch.zeros( batch_size, self.d_model * self.expand, self.d_conv, device=self.conv1d.weight.device, dtype=self.conv1d.weight.dtype, ) ssm_state = torch.zeros( batch_size, self.d_model * self.expand, self.d_state, device=self.dt_proj.weight.device, dtype=self.dt_proj.weight.dtype, # dtype=torch.float32, ) inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) else: conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] # TODO: What if batch size changes between generation, and we reuse the same states? if initialize_states: conv_state.zero_() ssm_state.zero_() return conv_state, ssm_state ================================================ FILE: mamba_ssm/modules/mha.py ================================================ # Copyright (c) 2024, Tri Dao, Albert Gu. 
import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange try: from flash_attn import flash_attn_with_kvcache except ImportError: flash_attn_with_kvcache = None try: from flash_attn.layers.rotary import RotaryEmbedding except ImportError: RotaryEmbedding = None try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update except ImportError: causal_conv1d_fn, causal_conv1d_update = None, None def _update_kv_cache(kv, inference_params, layer_idx): """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" # Pre-allocate memory for key-values for inference. num_heads, head_dim = kv.shape[-2:] assert layer_idx in inference_params.key_value_memory_dict kv_cache, _ = inference_params.key_value_memory_dict[layer_idx] # Adjust key and value for inference batch_start = inference_params.batch_size_offset batch_end = batch_start + kv.shape[0] sequence_start = inference_params.seqlen_offset sequence_end = sequence_start + kv.shape[1] assert batch_end <= kv_cache.shape[0] assert sequence_end <= kv_cache.shape[1] assert kv_cache is not None kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv return kv_cache[batch_start:batch_end, :sequence_end, ...] class MHA(nn.Module): """Multi-head self-attention and cross-attention""" def __init__( self, embed_dim, num_heads, num_heads_kv=None, head_dim=None, # If None, use embed_dim // num_heads mlp_dim=0, qkv_proj_bias=True, out_proj_bias=True, softmax_scale=None, causal=False, layer_idx=None, d_conv=0, rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_interleaved=False, device=None, dtype=None, ) -> None: """ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads. return_residual: whether to return the input x along with the output. This is for performance reason: for post-norm architecture, returning the input allows us to fuse the backward of nn.Linear with the residual connection. 
""" factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.embed_dim = embed_dim self.layer_idx = layer_idx self.d_conv = d_conv self.rotary_emb_dim = rotary_emb_dim self.softmax_scale = softmax_scale self.causal = causal self.num_heads = num_heads self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads assert ( self.num_heads % self.num_heads_kv == 0 ), "num_heads must be divisible by num_heads_kv" if head_dim is None: assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads self.mlp_dim = math.ceil(mlp_dim / 256) * 256 qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) out_dim = self.head_dim * self.num_heads if self.rotary_emb_dim > 0: assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed" self.rotary_emb = RotaryEmbedding( self.rotary_emb_dim, base=rotary_emb_base, interleaved=rotary_emb_interleaved, device=device, ) self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs) if self.d_conv > 0: self.conv1d = nn.Conv1d( qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim, **factory_kwargs ) self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs) def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): dtype = self.out_proj.weight.dtype if dtype is None else dtype device = self.out_proj.weight.device if self.d_conv > 0: conv_state = torch.zeros( batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype ) else: conv_state = None kv_cache = torch.empty( batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device, ) return kv_cache, conv_state def _update_kv_cache(self, kv, inference_params): """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" assert 
self.layer_idx is not None, "Generation requires layer_idx in the constructor" return _update_kv_cache(kv, inference_params, self.layer_idx) def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): """ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention. q: (batch_size, seqlen_q, nheads, head_dim) kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) """ assert inference_params is not None and inference_params.seqlen_offset > 0 if self.rotary_emb_dim > 0: self.rotary_emb._update_cos_sin_cache( inference_params.max_seqlen, device=q.device, dtype=q.dtype ) rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached else: rotary_cos, rotary_sin = None, None batch = q.shape[0] kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx] kv_cache = kv_cache[:batch] cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) assert flash_attn_with_kvcache is not None, "flash_attn must be installed" context = flash_attn_with_kvcache( q, kv_cache[:, :, 0], kv_cache[:, :, 1], kv[:, :, 0], kv[:, :, 1], rotary_cos=rotary_cos, rotary_sin=rotary_sin, cache_seqlens=cache_seqlens, softmax_scale=self.softmax_scale, causal=self.causal, rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, ) return context def _update_kvcache_attention(self, q, kv, inference_params): """Write kv to inference_params, then do attention""" if ( inference_params.seqlen_offset == 0 or flash_attn_with_kvcache is None ): # TODO: this only uses seqlen_offset and not lengths_per_sample. 
kv = self._update_kv_cache(kv, inference_params) k, v = kv.unbind(dim=-3) k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv) v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv) return F.scaled_dot_product_attention( q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale ).transpose(1, 2) else: batch = q.shape[0] kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx] kv_cache = kv_cache[:batch] cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) return flash_attn_with_kvcache( q, kv_cache[:, :, 0], kv_cache[:, :, 1], kv[:, :, 0], kv[:, :, 1], cache_seqlens=cache_seqlens, softmax_scale=self.softmax_scale, causal=self.causal, ) def forward(self, x, inference_params=None): """ Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total is the is the sum of the sequence lengths in the batch. inference_params: for generation. 
Adapted from Megatron-LM (and Apex) https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 """ if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict: inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( x.shape[0], inference_params.max_seqlen, dtype=x.dtype ) seqlen_offset = ( 0 if inference_params is None else ( inference_params.lengths_per_sample if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) ) rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None qkv = self.in_proj(x) if self.mlp_dim > 0: qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1) x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1) x_mlp = x_mlp_up * F.silu(x_mlp_gate) if self.d_conv > 0: # The inference code for conv1d is pretty messy, should clean it up if (inference_params is None or inference_params.seqlen_offset == 0): if causal_conv1d_fn is None: qkv = rearrange( self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d" ).contiguous() else: qkv = causal_conv1d_fn( qkv.transpose(1, 2), rearrange(self.conv1d.weight, "d 1 w -> d w"), self.conv1d.bias ).transpose(1, 2) if inference_params is not None: _, conv_state = inference_params.key_value_memory_dict[self.layer_idx] # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
qkv_t = rearrange(qkv, "b l d -> b d l") conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W) else: _, conv_state = inference_params.key_value_memory_dict[self.layer_idx] assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now" qkv = qkv.squeeze(1) # Conv step if causal_conv1d_update is None: conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) conv_state[:, :, -1] = qkv qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) if self.conv1d.bias is not None: qkv = qkv + self.conv1d.bias else: qkv = causal_conv1d_update( qkv, conv_state, rearrange(self.conv1d.weight, "d 1 w -> d w"), self.conv1d.bias ) qkv = qkv.unsqueeze(1) q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1) q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim) if ( inference_params is None or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) ): if self.rotary_emb_dim > 0: q, kv = self.rotary_emb( q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen ) if inference_params is None: k, v = kv.unbind(dim=-3) k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv) v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv) context = F.scaled_dot_product_attention( q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale ).transpose(1, 2) else: context = self._update_kvcache_attention(q, kv, inference_params) else: context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params) context = rearrange(context, "... h d -> ... 
(h d)") if self.mlp_dim > 0: context = torch.cat([context, x_mlp], dim=-1) out = self.out_proj(context) return out
================================================
FILE: mamba_ssm/modules/mlp.py
================================================
# Copyright (c) 2024, Tri Dao, Albert Gu.
from torch import nn
from torch.nn import functional as F


class GatedMLP(nn.Module):
    """Gated feed-forward block: ``fc2(y * activation(gate))`` with ``[y, gate] = fc1(x).chunk(2)``.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the gated hidden layer; defaults to ``int(8 * in_features / 3)``.
            It is always rounded up to a multiple of ``multiple_of``.
        out_features: output size; defaults to ``in_features``.
        activation: callable applied to the gate half (default ``F.silu``).
        bias: whether ``fc1`` / ``fc2`` have a bias term.
        multiple_of: rounding granularity for the hidden width.
        device, dtype: forwarded to both ``nn.Linear`` layers.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.silu,
        bias=False,
        multiple_of=128,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        # Round the hidden width up to the next multiple of `multiple_of`.
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        # A single matmul produces both the value half and the gate half.
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        """Apply ``fc2(y * activation(gate))`` to the last dimension of ``x``."""
        y = self.fc1(x)
        y, gate = y.chunk(2, dim=-1)
        y = y * self.activation(gate)
        y = self.fc2(y)
        return y

================================================
FILE: mamba_ssm/modules/ssd_minimal.py
================================================
# Copyright (c) 2024, Albert Gu and Tri Dao.
"""Minimal implementation of SSD.

This is the same as Listing 1 from the paper.
"""

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined


def segsum_unstable(x):
    """Naive segment sum calculation.

    Returns a tensor of shape ``(..., T, T)``, where ``T = x.size(-1)``.
    Entry ``[..., i, j]`` equals ``sum(x[..., j+1 : i+1])`` for ``j <= i`` and ``-inf`` for ``j > i``.
    It is computed as a difference of two cumulative sums, which can lose precision
    when the cumulative sums are large; see ``segsum`` for the stabler variant.
    """
    T = x.size(-1)
    x_cumsum = torch.cumsum(x, dim=-1)
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum


def segsum(x):
    """More stable segment sum calculation.

    Produces the same ``(..., T, T)`` result as ``segsum_unstable``:
    ``[..., i, j] = sum(x[..., j+1 : i+1])`` for ``j <= i`` and ``-inf`` above the diagonal.
    Instead of subtracting two cumulative sums, it sums only the needed terms directly.
    """
    T = x.size(-1)
    # x[..., d, e] = x[..., d]: copy the sequence along a new trailing axis.
    x = repeat(x, "... d -> ... d e", e=T)
    # Keep only the strictly-lower-triangular terms (d > e) ...
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    # ... so the cumulative sum over d gives sum_{e < d <= i} x[d] at position (i, e).
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum


def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """
    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
        block_len: chunk length; ``length`` must be divisible by it.
        initial_states: optional (batch, 1, n_heads, d_head, d_state) state entering the
            first chunk; zeros when None.
    Return:
        Y: (batch, length, n_heads, d_head)
        final_state: (batch, n_heads, d_head, d_state), the state after the last chunk.
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    # X: (b, c, l, h, p); A: (b, c, l, h); B, C: (b, c, l, h, n)
    X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]

    # A: (b, h, c, l) so that cumulative sums run along the within-chunk position.
    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    # L[b, h, c, l, s] = exp(sum of A over positions s+1..l) for s <= l, else 0.
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    # Decay from each position to the end of its chunk.
    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    # Chunk-to-chunk decays built from each chunk's total A sum (padded with a leading 0
    # for the initial state).
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state


# Simple test
def test_correctness():
    """Run the fused Triton ``mamba_chunk_scan_combined`` and ``ssd_minimal_discrete`` on CUDA.

    NOTE(review): ``y`` and ``y_min`` are computed but never compared, so this function
    cannot fail on a numerical mismatch. ``D`` is allocated but unused (``D=None`` is passed).
    B/C carry a size-1 group axis while ``ssd_minimal_discrete`` documents ``n_heads`` there;
    this presumably relies on einsum broadcasting — confirm.
    """
    torch.manual_seed(42)

    ## Dimensions
    # Denoted (B, T, Q, D, P) in the paper
    batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
    nheads = dim // headdim  # (H) in the paper
    ngroups = 1  # (G) in the paper
    dstate = 64  # (N) in the paper
    dtype = torch.float32
    device = "cuda"

    x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
    dt = F.softplus(torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4).requires_grad_()
    A = (-torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))).requires_grad_()
    B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
    C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
    D = torch.randn(nheads, dtype=dtype, device=device)

    # Comparing fused version and minimal version
    y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
    # The minimal version takes the discretized inputs x*dt and A*dt.
    y_min, _ = ssd_minimal_discrete(x*dt.unsqueeze(-1), A*dt, B, C, chunk_size)

================================================
FILE: mamba_ssm/ops/__init__.py
================================================

================================================
FILE: mamba_ssm/ops/cute/mamba3/mamba3_step_fn.py
================================================
# Copyright (c) 2025, Tri Dao.
# Modified to use tvm-ffi and fake tensors instead of dlpack.
# Modified to optionally update state in place (state_out=None) or write to separate state_out.
import math from typing import Optional, Type, Literal, List import torch import torch.nn.functional as F from torch import Tensor import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute from cutlass import Int32, Float32, Float16, BFloat16, Boolean, const_expr from quack.compile_utils import make_fake_tensor from quack.cute_dsl_utils import torch2cute_dtype_map def transpose_view(a: cute.Tensor) -> cute.Tensor: """Transpose the first two dimensions of a tensor on smem.""" shape = (a.shape[1], a.shape[0], *a.shape[2:]) order = (1, 0, *range(2, cute.rank(a))) return cute.composition(a, cute.make_ordered_layout(shape, order=order)) def select(a: cute.Tensor, mode: List[int]) -> cute.Tensor: return cute.make_tensor(a.iterator, cute.select(a.layout, mode)) def get_gmem_tiled_copy(dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = True): num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width copy_elems = num_copy_bits // dtype.width copy_op = cute.nvgpu.cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) gmem_threads_per_row = major_mode_size // copy_elems assert num_threads % gmem_threads_per_row == 0 thr_layout = cute.make_ordered_layout( (num_threads // gmem_threads_per_row, gmem_threads_per_row), order=(1, 0), ) val_layout = cute.make_layout((1, copy_elems)) return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) class Mamba3Step(): def __init__(self, tile_D: int, dstate: int, mimo: int = 1, num_warps: int = 4, remove_gate: bool = False, remove_outproj: bool = False): assert num_warps >= 2 assert dstate % 8 == 0, "dstate must be multiple of 8" # for vectorized load /store self.tile_D = tile_D self.dstate = dstate self.mimo = mimo self.num_warps = num_warps self.remove_gate = remove_gate self.remove_outproj = remove_outproj def _setup_smem_layouts(self): self.sState_layout = 
cute.make_ordered_layout((self.tile_D, self.dstate), order=(1, 0)) # We don't need any swizzling for Bstate, B, C self.sBC_layout = cute.make_ordered_layout((self.mimo, self.dstate), order=(1, 0)) # We don't need any swizzling for Xproj, Zproj, Outproj self.sProj_layout = cute.make_ordered_layout((self.mimo, self.tile_D), order=(1, 0)) def _setup_gmem_tiled_copy(self, ): num_threads = self.num_warps * cute.arch.WARP_SIZE self.gmem_tiled_copy_state = get_gmem_tiled_copy(self.dtype, self.dstate, num_threads) self.gmem_tiled_copy_BC = get_gmem_tiled_copy(self.b_dtype, self.dstate, num_threads) self.gmem_tiled_copy_Proj = get_gmem_tiled_copy(self.proj_dtype, self.tile_D, num_threads) # Gmem tiled copy for X, Z # e.g. for tile_D = 64, we only want each thread loading 2 values copy_elems_x = const_expr(min(4, cute.ceil_div(self.tile_D, cute.arch.WARP_SIZE))) num_copy_bits_x = copy_elems_x * self.x_dtype.width copy_atom_load_x = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.x_dtype, num_bits_per_copy=num_copy_bits_x ) gmem_threads_per_row = self.tile_D // copy_elems_x assert cute.arch.WARP_SIZE >= gmem_threads_per_row # Only 1 warp loads X, Z self.gmem_tiled_copy_X = cute.make_tiled_copy_tv( copy_atom_load_x, cute.make_layout(self.tile_D // copy_elems_x), cute.make_layout(copy_elems_x) ) @cute.jit def __call__( # B: batch size, H: num heads, D: head dim, N: dstate, R: mimo self, mState: cute.Tensor, # (B, H, D, N) mBstate: cute.Tensor, # (B, R, H, N) mXstate: cute.Tensor, # (B, H, D) mA: cute.Tensor, # (B, H) mB: cute.Tensor, # (B, R, H, N) mC: cute.Tensor, # (B, R, H, N) mD: cute.Tensor, # (H) mX: cute.Tensor, # (B, H, D) mDt: cute.Tensor, # (B, H) mTrap: cute.Tensor, # (B, H) mXproj: cute.Tensor, # (R, H, D) mOutproj: Optional[cute.Tensor], # (R, H, D), None if remove_outproj mStateOut: cute.Tensor, # (B, H, D, N) — same as mState for in-place, or separate mOut: cute.Tensor, # (B, H, D) or (B, R, H, D) if remove_outproj mZ: Optional[cute.Tensor], # (B, H, D), 
None if remove_gate mZproj: Optional[cute.Tensor], # (R, H, D), None if remove_gate stream: cuda.CUstream, ): self.dtype = mState.element_type self.b_dtype = mB.element_type self.proj_dtype = mXproj.element_type self.x_dtype = mX.element_type assert mStateOut.element_type == self.dtype assert mBstate.element_type == mB.element_type == mC.element_type if const_expr(mOutproj is not None): assert mXproj.element_type == mOutproj.element_type if const_expr(mZ is not None): assert mXproj.element_type == mZproj.element_type assert mZ.element_type == self.x_dtype self._setup_smem_layouts() self._setup_gmem_tiled_copy() # TV layout, this is the most important step as it decides which elements in B, C, State # each thread will load from smem num_threads = self.num_warps * cute.arch.WARP_SIZE # TODO: these need to be adjusted based on dstate and tile_D assert self.dstate in [32, 64, 128] # TODO: This is not optimal for dstate=32 and 64, just to get sth quick to run vecsize_dstate = 4 if self.dstate == 128 else 2 if self.dstate == 64 else 1 threads_per_dstate = self.dstate // vecsize_dstate assert cute.arch.WARP_SIZE % threads_per_dstate == 0 num_groups = num_threads // threads_per_dstate assert self.tile_D % num_groups == 0 lanes_per_D = self.tile_D // num_groups copy_atom_state_s2r = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), mState.element_type, num_bits_per_copy=vecsize_dstate * mState.element_type.width ) tiled_copy_state_s2r = cute.make_tiled_copy_tv( copy_atom_state_s2r, cute.make_ordered_layout((num_groups, threads_per_dstate), order=(1, 0)), cute.make_ordered_layout((lanes_per_D, vecsize_dstate), order=(1, 0)), ) copy_atom_B_s2r = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), mB.element_type, num_bits_per_copy=vecsize_dstate * mB.element_type.width ) tiled_copy_B_s2r = cute.make_tiled_copy_tv( copy_atom_B_s2r, cute.make_ordered_layout((1, threads_per_dstate), order=(1, 0)), cute.make_ordered_layout((1, vecsize_dstate), order=(1, 0)), ) 
self.buffer_align_bytes = 1024 sZproj_size = cute.cosize(self.sProj_layout) if not self.remove_gate else 0 sOutproj_size = cute.cosize(self.sProj_layout) if not self.remove_outproj else 0 @cute.struct class SharedStorage: sX: cute.struct.Align[cute.struct.MemRange[Float32, self.tile_D], 128] sXgamma: cute.struct.Align[cute.struct.MemRange[Float32, self.tile_D], 128] sXstate: cute.struct.Align[cute.struct.MemRange[Float32, self.tile_D], 128] sState: cute.struct.Align[ cute.struct.MemRange[self.dtype, cute.cosize(self.sState_layout)], self.buffer_align_bytes, ] sBstate: cute.struct.Align[ cute.struct.MemRange[self.b_dtype, cute.cosize(self.sBC_layout)], self.buffer_align_bytes, ] sB: cute.struct.Align[ cute.struct.MemRange[self.b_dtype, cute.cosize(self.sBC_layout)], self.buffer_align_bytes, ] sC: cute.struct.Align[ cute.struct.MemRange[self.b_dtype, cute.cosize(self.sBC_layout)], self.buffer_align_bytes, ] sXproj: cute.struct.Align[ cute.struct.MemRange[self.proj_dtype, cute.cosize(self.sProj_layout)], self.buffer_align_bytes, ] sZproj: cute.struct.Align[ cute.struct.MemRange[self.proj_dtype, sZproj_size], self.buffer_align_bytes, ] sOutproj: cute.struct.Align[ cute.struct.MemRange[self.proj_dtype, sOutproj_size], self.buffer_align_bytes, ] self.shared_storage = SharedStorage self.kernel( mState, mBstate, mXstate, mA, mB, mC, mD, mX, mDt, mTrap, mXproj, mOutproj, mStateOut, mOut, mZ, mZproj, self.sState_layout, self.sBC_layout, self.sProj_layout, self.gmem_tiled_copy_state, self.gmem_tiled_copy_BC, self.gmem_tiled_copy_Proj, self.gmem_tiled_copy_X, tiled_copy_state_s2r, tiled_copy_B_s2r, vecsize_dstate, ).launch( # grid: (d, h, b) grid=[cute.ceil_div(mState.shape[2], self.tile_D), mState.shape[1], mState.shape[0]], block=[num_threads, 1, 1], stream=stream, ) @cute.kernel def kernel( self, mState: cute.Tensor, # (B, H, D, N) mBstate: cute.Tensor, # (B, R, H, N) mXstate: cute.Tensor, # (B, H, D) mA: cute.Tensor, # (B, H) mB: cute.Tensor, # (B, R, H, N) mC: 
cute.Tensor, # (B, R, H, N) mD: cute.Tensor, # (H) mX: cute.Tensor, # (B, H, D) mDt: cute.Tensor, # (B, H) mTrap: cute.Tensor, # (B, H) mXproj: cute.Tensor, # (R, H, D) mOutproj: Optional[cute.Tensor], # (R, H, D), None if remove_outproj mStateOut: cute.Tensor, # (B, H, D, N) mOut: cute.Tensor, # (B, H, D) or (B, R, H, D) if remove_outproj mZ: Optional[cute.Tensor], # (B, H, D), None if remove_gate mZproj: Optional[cute.Tensor], # (R, H, D), None if remove_gate sState_layout: cute.Layout | cute.ComposedLayout, sBC_layout: cute.Layout | cute.ComposedLayout, sProj_layout: cute.Layout | cute.ComposedLayout, gmem_tiled_copy_state: cute.TiledCopy, gmem_tiled_copy_BC: cute.TiledCopy, gmem_tiled_copy_Proj: cute.TiledCopy, gmem_tiled_copy_X: cute.TiledCopy, tiled_copy_state_s2r: cute.TiledCopy, tiled_copy_B_s2r: cute.TiledCopy, vecsize_dstate: cutlass.Constexpr[int], ): tidx, _, _ = cute.arch.thread_idx() bidd, bidh, bidb = cute.arch.block_idx() warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) lane_idx = cute.arch.lane_idx() limit_d = mState.shape[2] # /////////////////////////////////////////////////////////////////////////////// # Slice for CTA # /////////////////////////////////////////////////////////////////////////////// # (tile_D, N) gState, gStateOut = [ cute.local_tile(t[bidb, bidh, None, None], (self.tile_D, self.dstate), (bidd, 0)) for t in (mState, mStateOut) ] # (R, N) gBstate, gB, gC = [ cute.local_tile(t[bidb, None, bidh, None], (self.mimo, self.dstate), (0, 0)) for t in (mBstate, mB, mC) ] # (tile_D,) gXstate, gX = [ cute.local_tile(t[bidb, bidh, None], (self.tile_D,), (bidd,)) for t in (mXstate, mX) ] if const_expr(mOutproj is not None): # Output is (B, H, D), outproj reduces MIMO rank gOut = cute.local_tile(mOut[bidb, bidh, None], (self.tile_D,), (bidd,)) gXproj = cute.local_tile(mXproj[None, bidh, None], (self.mimo, self.tile_D), (0, bidd)) gOutproj = cute.local_tile(mOutproj[None, bidh, None], (self.mimo, self.tile_D), (0, bidd)) else: # 
Output is (B, R, H, D), no outproj reduction gXproj = cute.local_tile(mXproj[None, bidh, None], (self.mimo, self.tile_D), (0, bidd)) gOutproj = None if const_expr(mZ is not None): gZ = cute.local_tile(mZ[bidb, bidh, None], (self.tile_D,), (bidd,)) gZproj = cute.local_tile(mZproj[None, bidh, None], (self.mimo, self.tile_D), (0, bidd)) # /////////////////////////////////////////////////////////////////////////////// # Generate smem tensors # /////////////////////////////////////////////////////////////////////////////// smem = cutlass.utils.SmemAllocator() storage = smem.allocate(self.shared_storage) sState = storage.sState.get_tensor(sState_layout) sBstate = storage.sBstate.get_tensor(sBC_layout) sB = storage.sB.get_tensor(sBC_layout) sC = storage.sC.get_tensor(sBC_layout) sXproj = storage.sXproj.get_tensor(sProj_layout) sZproj = storage.sZproj.get_tensor(sProj_layout) if const_expr(mZ is not None) else None sOutproj = storage.sOutproj.get_tensor(sProj_layout) if const_expr(mOutproj is not None) else None sXstate = storage.sXstate.get_tensor(cute.make_layout(self.tile_D)) sX = storage.sX.get_tensor(cute.make_layout(self.tile_D)) sXgamma = storage.sXgamma.get_tensor(cute.make_layout(self.tile_D)) # /////////////////////////////////////////////////////////////////////////////// # Partitioning using copy atoms # /////////////////////////////////////////////////////////////////////////////// gmem_thr_copy_state = gmem_tiled_copy_state.get_slice(tidx) # copying states from r2g uses the same tiled copy as s2r gmem_thr_copy_StateOut = tiled_copy_state_s2r.get_slice(tidx) gmem_thr_copy_BC = gmem_tiled_copy_BC.get_slice(tidx) gmem_thr_copy_Proj = gmem_tiled_copy_Proj.get_slice(tidx) gmem_thr_copy_X = gmem_tiled_copy_X.get_slice(lane_idx) # Only 1 warp loads X, Z tSgS = gmem_thr_copy_state.partition_S(gState) tSsS_g2s = gmem_thr_copy_state.partition_D(sState) tSgSOut = gmem_thr_copy_StateOut.partition_D(gStateOut) tBCgBstate = gmem_thr_copy_BC.partition_S(gBstate) tBCsBstate 
= gmem_thr_copy_BC.partition_D(sBstate) tBCgB = gmem_thr_copy_BC.partition_S(gB) tBCsB = gmem_thr_copy_BC.partition_D(sB) tBCgC = gmem_thr_copy_BC.partition_S(gC) tBCsC = gmem_thr_copy_BC.partition_D(sC) tPgXproj = gmem_thr_copy_Proj.partition_S(gXproj) tPsXproj = gmem_thr_copy_Proj.partition_D(sXproj) if const_expr(mZ is not None): tPgZproj = gmem_thr_copy_Proj.partition_S(gZproj) tPsZproj = gmem_thr_copy_Proj.partition_D(sZproj) if const_expr(mOutproj is not None): tPgOutproj = gmem_thr_copy_Proj.partition_S(gOutproj) tPsOutproj = gmem_thr_copy_Proj.partition_D(sOutproj) tXgX = gmem_thr_copy_X.partition_S(gX) tXsX = gmem_thr_copy_X.partition_D(sX) tXsXgamma = gmem_thr_copy_X.partition_D(sXgamma) tXgXstate = gmem_thr_copy_X.partition_S(gXstate) tXsXstate = gmem_thr_copy_X.partition_D(sXstate) # Idk why this order of threads_per_dstate and num_groups are reversed threads_per_dstate, num_groups = tiled_copy_state_s2r.layout_tv_tiled[0].shape lanes_per_D = self.tile_D // num_groups # For bound checking cS = cute.make_identity_tensor((self.tile_D, self.dstate)) tScS = gmem_thr_copy_state.partition_S(cS) cBC = cute.make_identity_tensor((self.mimo, self.dstate)) tBCcBC = gmem_thr_copy_BC.partition_S(cBC) cProj = cute.make_identity_tensor((self.mimo, self.tile_D)) tPcProj = gmem_thr_copy_Proj.partition_S(cProj) A_val = Float32(mA[bidb, bidh]) dt_val = Float32(mDt[bidb, bidh]) trap_val = Float32(mTrap[bidb, bidh]) # Load X and Xstate, these are small so we want to kick them off first tXrX = cute.make_fragment_like(tXgX) tXrXstate = cute.make_fragment_like(tXgXstate) copy_elems_x = cute.size(tXgX.shape[0][0]) assert cute.size(tXgX.shape) == copy_elems_x # Only 1 load instruction num_loads_X = const_expr(self.tile_D // copy_elems_x) need_bound_check_X = const_expr(cute.arch.WARP_SIZE > num_loads_X) if warp_idx == 0: if not need_bound_check_X or lane_idx < num_loads_X: cute.copy(gmem_tiled_copy_X, tXgX, tXrX) if warp_idx == 1: if not need_bound_check_X or lane_idx < 
num_loads_X: cute.copy(gmem_tiled_copy_X, tXgXstate, tXrXstate) # Load Bstate, B, Xproj need bound checking for m in cutlass.range(cute.size(tBCcBC.shape[1]), unroll_full=True): if tBCcBC[0, m, 0][0] < self.mimo: cute.copy(gmem_tiled_copy_BC, tBCgBstate[None, m, None], tBCsBstate[None, m, None]) cute.copy(gmem_tiled_copy_BC, tBCgB[None, m, None], tBCsB[None, m, None]) for m in cutlass.range(cute.size(tPcProj.shape[1]), unroll_full=True): if tPcProj[0, m, 0][0] < self.mimo: cute.copy(gmem_tiled_copy_Proj, tPgXproj[None, m, None], tPsXproj[None, m, None]) cute.arch.cp_async_commit_group() # Load State, not doing any bound check for now cute.copy(gmem_tiled_copy_state, tSgS, tSsS_g2s) cute.arch.cp_async_commit_group() alpha_val = cute.arch.exp(A_val * dt_val) # Transform X and Xstate by multiplying with gamma and beta, then write to smem if warp_idx == 0: tXrX_f32 = cute.make_fragment_like(tXrX, Float32) tXrX_f32.store(tXrX.load().to(Float32)) if not need_bound_check_X or lane_idx < num_loads_X: cute.autovec_copy(tXrX_f32, tXsX) gamma_val = trap_val * dt_val tXrX_f32.store(tXrX_f32.load() * gamma_val) if not need_bound_check_X or lane_idx < num_loads_X: cute.autovec_copy(tXrX_f32, tXsXgamma) if warp_idx == 1: beta_val = (1.0 - trap_val) * dt_val * alpha_val tXrXstate_f32 = cute.make_fragment_like(tXgXstate, Float32) tXrXstate_f32.store(tXrXstate.load().to(Float32) * beta_val) if not need_bound_check_X or lane_idx < num_loads_X: cute.autovec_copy(tXrXstate_f32, tXsXstate) # Load C, need bound checking for m in cutlass.range(cute.size(tBCcBC.shape[1]), unroll_full=True): if tBCcBC[0, m, 0][0] < self.mimo: cute.copy(gmem_tiled_copy_BC, tBCgC[None, m, None], tBCsC[None, m, None]) cute.arch.cp_async_commit_group() cute.arch.cp_async_wait_group(2) # B, Bstate, Xproj are done loading cute.arch.sync_threads() # Load B, Bstate, Xproj from smem smem_thr_copy_B = tiled_copy_B_s2r.get_slice(tidx % threads_per_dstate) # ((vecsize_dstate, 1), mimo, 1) -> ((vecsize_dstate, 1), mimo) 
tSsB = smem_thr_copy_B.partition_S(sB)[None, None, 0] tSsBstate = smem_thr_copy_B.partition_S(sBstate)[None, None, 0] tSrB = cute.make_fragment_like(tSsB) tSrBstate = cute.make_fragment_like(tSsBstate) cute.autovec_copy(tSsB, tSrB) cute.autovec_copy(tSsBstate, tSrBstate) tSrB_f32 = cute.make_fragment_like(tSrB, Float32) tSrB_f32.store(tSrB.load().to(Float32)) tSrBstate_f32 = cute.make_fragment_like(tSrBstate, Float32) tSrBstate_f32.store(tSrBstate.load().to(Float32)) # Loading x and xstate, at most 1 val per thread x_val = Float32(0.0) if lane_idx < lanes_per_D: # TODO: should this be warp_idx or group_idx? x_val = sXgamma[warp_idx * lanes_per_D + lane_idx] x_state_val = Float32(0.0) if lane_idx < lanes_per_D: x_state_val = sXstate[warp_idx * lanes_per_D + lane_idx] new_state = cute.make_fragment((vecsize_dstate, lanes_per_D), Float32) for r in cutlass.range_constexpr(self.mimo): x_proj_val = Float32(0.0) if lane_idx < lanes_per_D: x_proj_val = Float32(sXproj[r, warp_idx * lanes_per_D + lane_idx]) x_gamma_x_proj_val = x_val * x_proj_val x_state_x_proj_val = x_state_val * x_proj_val for d in cutlass.range(lanes_per_D, unroll_full=True): xg = cute.arch.shuffle_sync(x_gamma_x_proj_val, d) xb = cute.arch.shuffle_sync(x_state_x_proj_val, d) for v in cutlass.range(vecsize_dstate, unroll_full=True): if const_expr(r == 0): new_state[v, d] = xg * tSrB_f32[v, r] else: new_state[v, d] += xg * tSrB_f32[v, r] new_state[v, d] += xb * tSrBstate_f32[v, r] cute.arch.cp_async_wait_group(1) # state is done loading cute.arch.sync_threads() thr_copy_state_s2r = tiled_copy_state_s2r.get_slice(tidx) # ((vecsize_state, lanes_per_D), 1, 1) tSsS = thr_copy_state_s2r.partition_S(sState) tSrS = cute.make_fragment_like(tSsS) cute.autovec_copy(tSsS, tSrS) # ((vecsize_state, lanes_per_D), 1, 1) # tSrS_f32 = cute.make_fragment_like(tSrS, Float32) tSrS_f32 = cute.make_fragment(((vecsize_dstate, 1), lanes_per_D, 1), Float32) assert cute.size(tSrS.shape) == cute.size(tSrS_f32.shape) 
tSrS_f32.store(tSrS.load().to(Float32)) for v in cutlass.range(cute.size(tSrS_f32), unroll_full=True): tSrS_f32[v] = tSrS_f32[v] * alpha_val + new_state[v] tSrS.store(tSrS_f32.load().to(self.dtype)) # Load Z from gmem -> rmem, it's small, at most 1 val per thread if const_expr(mZ is not None): z_val = Float32(0.0) if lane_idx < lanes_per_D: z_val = Float32(gZ[warp_idx * lanes_per_D + lane_idx]) # Load Zproj and Outproj, need bound checking for m in cutlass.range(cute.size(tPcProj.shape[1]), unroll_full=True): if tPcProj[0, m, 0][0] < self.mimo: if const_expr(mZ is not None): cute.copy(gmem_tiled_copy_Proj, tPgZproj[None, m, None], tPsZproj[None, m, None]) if const_expr(mOutproj is not None): cute.copy(gmem_tiled_copy_Proj, tPgOutproj[None, m, None], tPsOutproj[None, m, None]) cute.arch.cp_async_commit_group() # Write state back to StateOut (may be same memory as State for in-place) cute.copy(tiled_copy_state_s2r, tSrS, tSgSOut) # Do state @ C cute.arch.cp_async_wait_group(1) # C is done loading cute.arch.sync_threads() # ((vecsize_dstate, 1), mimo, 1) -> ((vecsize_dstate, 1), 1, mimo) tSsC = select(smem_thr_copy_B.partition_S(sC), mode=[0, 2, 1]) tSrC = cute.make_fragment_like(tSsC) cute.autovec_copy(tSsC, tSrC) tSrC_f32 = cute.make_fragment_like(tSrC, Float32) tSrC_f32.store(tSrC.load().to(Float32)) out_expanded = cute.make_fragment((lanes_per_D, self.mimo), Float32) # tSrS_f32 has shape ((vecsize_dstate, 1), lanes_per_D, 1) # tSrC has shape ((vecsize_dstate, 1), mimo) out_expanded.store( (tSrS_f32.load() * tSrC_f32.load()).reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, None)) ) assert lanes_per_D <= threads_per_dstate for d in cutlass.range(lanes_per_D, unroll_full=True): for r in cutlass.range(self.mimo, unroll_full=True): out_expanded[d, r] += cute.arch.shuffle_sync_bfly(out_expanded[d, r], offset=16) for i in cutlass.range_constexpr(int(math.log2(lanes_per_D))): step = 1 << (int(math.log2(lanes_per_D)) - 1 - i) should_swap = not 
Boolean(lane_idx & step) for j in cutlass.range_constexpr(step): for r in cutlass.range(self.mimo, unroll_full=True): lower, upper = out_expanded[j, r], out_expanded[j + step, r] out_expanded[j, r] = upper if should_swap else lower out_expanded[j + step, r] = lower if should_swap else upper shfl_val = cute.arch.shuffle_sync_bfly(out_expanded[j, r], offset=step) out_expanded[j, r] = shfl_val + out_expanded[j + step, r] # After this, the out values are just out_expanded[0, None] out = out_expanded[0, None] # (mimo,) # Add D * x * x_proj to out D_val = Float32(mD[bidh]) x_val = Float32(0.0) if lane_idx < lanes_per_D: x_val = sX[warp_idx * lanes_per_D + lane_idx] for r in cutlass.range_constexpr(self.mimo): x_proj_val = Float32(0.0) if lane_idx < lanes_per_D: x_proj_val = Float32(sXproj[r, warp_idx * lanes_per_D + lane_idx]) out[r] += D_val * x_val * x_proj_val cute.arch.cp_async_wait_group(0) # Zproj and Outproj are done loading cute.arch.sync_threads() if const_expr(mOutproj is not None): # Gate: z_r * sigmoid(z_r) if const_expr(mZ is not None): for r in cutlass.range_constexpr(self.mimo): z_proj_val = Float32(0.0) if lane_idx < lanes_per_D: z_proj_val = Float32(sZproj[r, warp_idx * lanes_per_D + lane_idx]) z_r_half = 0.5 * (z_val * z_proj_val) z_r_silu = z_r_half * cute.math.tanh(z_r_half, fastmath=True) + z_r_half out[r] *= z_r_silu # Final projection along mimo dim out_val = Float32(0.0) for r in cutlass.range_constexpr(self.mimo): out_proj_val = Float32(0.0) if lane_idx < lanes_per_D: out_proj_val = Float32(sOutproj[r, warp_idx * lanes_per_D + lane_idx]) if const_expr(r == 0): out_val = out[r] * out_proj_val else: out_val += out[r] * out_proj_val # Write final output (B, H, D) if lane_idx < lanes_per_D: gOut[warp_idx * lanes_per_D + lane_idx] = out_val.to(mOut.element_type) else: # No outproj: write per-rank output (B, R, H, D) for r in cutlass.range_constexpr(self.mimo): gOut_r = cute.local_tile(mOut[bidb, r, bidh, None], (self.tile_D,), (bidd,)) if lane_idx < 
lanes_per_D: gOut_r[warp_idx * lanes_per_D + lane_idx] = out[r].to(mOut.element_type) def mamba3_step_fn( # B: batch size, H: num heads, D: head dim, N: dstate, R: mimo state: Tensor, # (B, H, D, N) — updated in place if state_out is None Bstate: Tensor, # (B, R, H, N) Xstate: Tensor, # (B, H, D) A: Tensor, # (B, H) B: Tensor, # (B, R, H, N) C: Tensor, # (B, R, H, N) D: Tensor, # (H) x: Tensor, # (B, H, D) dt: Tensor, # (B, H) trap: Tensor, # (B, H) xproj: Tensor, # (R, H, D) outproj: Optional[Tensor] = None, # (R, H, D), None if remove_outproj state_out: Optional[Tensor] = None, # (B, H, D, N), None for in-place update out: Tensor = None, # (B, H, D) or (B, R, H, D) if remove_outproj z: Optional[Tensor] = None, # (B, H, D), None if remove_gate zproj: Optional[Tensor] = None, # (R, H, D), None if remove_gate tile_D: int = 64, num_warps: int = 2, ) -> None: has_z = z is not None has_outproj = outproj is not None inplace = state_out is None batch, nheads, hdim, dstate = state.shape mimo = Bstate.shape[1] assert state.shape == (batch, nheads, hdim, dstate) assert Bstate.shape == (batch, mimo, nheads, dstate) assert Xstate.shape == (batch, nheads, hdim) assert A.shape == (batch, nheads) assert B.shape == (batch, mimo, nheads, dstate) assert C.shape == (batch, mimo, nheads, dstate) assert D.shape == (nheads,) assert x.shape == (batch, nheads, hdim) if has_z: assert z.shape == (batch, nheads, hdim) assert zproj is not None assert zproj.shape == (mimo, nheads, hdim) assert dt.shape == (batch, nheads) assert trap.shape == (batch, nheads) assert xproj.shape == (mimo, nheads, hdim) xproj = xproj.contiguous() if has_outproj: assert outproj.shape == (mimo, nheads, hdim) assert out.shape == (batch, nheads, hdim) else: assert out.shape == (batch, mimo, nheads, hdim) # Use state itself as output target when in-place if inplace: state_out = state else: assert state_out.shape == (batch, nheads, hdim, dstate) required_tensors = [state, Bstate, Xstate, A, B, C, D, x, dt, trap, xproj, 
state_out, out] if has_outproj: required_tensors.append(outproj) if has_z: required_tensors.extend([z, zproj]) assert all(t.is_cuda for t in required_tensors) assert state.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported input dtype" # Map torch dtypes to cutlass dtypes state_cute_dtype = torch2cute_dtype_map[state.dtype] b_cute_dtype = torch2cute_dtype_map[Bstate.dtype] x_cute_dtype = torch2cute_dtype_map[x.dtype] proj_cute_dtype = torch2cute_dtype_map[xproj.dtype] a_cute_dtype = torch2cute_dtype_map[A.dtype] d_cute_dtype = torch2cute_dtype_map[D.dtype] dt_cute_dtype = torch2cute_dtype_map[dt.dtype] trap_cute_dtype = torch2cute_dtype_map[trap.dtype] compile_key = ( tile_D, num_warps, dstate, hdim, mimo, state.dtype, Bstate.dtype, xproj.dtype, A.dtype, D.dtype, dt.dtype, trap.dtype, has_z, has_outproj, ) if compile_key not in mamba3_step_fn.compile_cache: mamba3_step_op = Mamba3Step(tile_D, dstate, mimo, num_warps, remove_gate=not has_z, remove_outproj=not has_outproj) # Create symbolic dimensions for batch and nheads batch_sym = cute.sym_int() nheads_sym = cute.sym_int() # Divisibility for strides (128-bit alignment) div_state = 128 // state_cute_dtype.width div_b = 128 // b_cute_dtype.width div_x = 128 // x_cute_dtype.width div_proj = 128 // proj_cute_dtype.width div_a = 128 // a_cute_dtype.width div_d = 128 // d_cute_dtype.width div_dt = 128 // dt_cute_dtype.width div_trap = 128 // trap_cute_dtype.width # Create fake tensors with symbolic batch/nheads dimensions state_fake = make_fake_tensor(state_cute_dtype, (batch_sym, nheads_sym, hdim, dstate), div_state) Bstate_fake = make_fake_tensor(b_cute_dtype, (batch_sym, mimo, nheads_sym, dstate), div_b) Xstate_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x) A_fake = make_fake_tensor(a_cute_dtype, (batch_sym, nheads_sym), div_a) B_fake = make_fake_tensor(b_cute_dtype, (batch_sym, mimo, nheads_sym, dstate), div_b) C_fake = make_fake_tensor(b_cute_dtype, (batch_sym, mimo, 
nheads_sym, dstate), div_b) D_fake = make_fake_tensor(d_cute_dtype, (nheads_sym,), div_d) x_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x) dt_fake = make_fake_tensor(dt_cute_dtype, (batch_sym, nheads_sym), div_dt) trap_fake = make_fake_tensor(trap_cute_dtype, (batch_sym, nheads_sym), div_trap) xproj_fake = make_fake_tensor(proj_cute_dtype, (mimo, nheads_sym, hdim), div_proj) outproj_fake = make_fake_tensor(proj_cute_dtype, (mimo, nheads_sym, hdim), div_proj) if has_outproj else None state_out_fake = make_fake_tensor(state_cute_dtype, (batch_sym, nheads_sym, hdim, dstate), div_state) if has_outproj: out_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x) else: out_fake = make_fake_tensor(x_cute_dtype, (batch_sym, mimo, nheads_sym, hdim), div_x) z_fake = make_fake_tensor(x_cute_dtype, (batch_sym, nheads_sym, hdim), div_x) if has_z else None zproj_fake = make_fake_tensor(proj_cute_dtype, (mimo, nheads_sym, hdim), div_proj) if has_z else None fake_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) mamba3_step_fn.compile_cache[compile_key] = cute.compile( mamba3_step_op, state_fake, Bstate_fake, Xstate_fake, A_fake, B_fake, C_fake, D_fake, x_fake, dt_fake, trap_fake, xproj_fake, outproj_fake, state_out_fake, out_fake, z_fake, zproj_fake, fake_stream, options="--enable-tvm-ffi", ) # Call with real PyTorch tensors directly (no dlpack conversion needed) # When inplace, state_out is state (set above) mamba3_step_fn.compile_cache[compile_key]( state, Bstate, Xstate, A, B, C, D, x, dt, trap, xproj, outproj, state_out, out, z, zproj, ) mamba3_step_fn.compile_cache = {} def selective_state_update_fused_ref_v2( state, A, B, C, xproj, x, zproj, z, dt, B_state, x_state, trap, D, outproj, compute_dtype=torch.float32 ): """ Reference to match the new fused kernel API. 
Shapes: state: (B, N, H, S) A: (B, N) B: (B, R, N, S) C: (B, R, N, S) xproj: (R, N, H) x: (B, N, H) zproj: (R, N, H) z: (B, N, H) dt: (B, N) B_state: (B, R, N, S) x_state: (B, N, H) trap: (B, N) D: (N,) outproj: (R, N, H) Returns: out: (B, N, H) new_state: (B, N, H, S) """ Bsz, N, H, S = state.shape R = B.shape[1] # Dtypes for numerics (match kernel's fp32 accum) og_dtype = state.dtype A_f = A.to(compute_dtype) # (B, N) dt_f = dt.to(compute_dtype) # (B, N) trap_f = trap.to(compute_dtype) # (B, N) D_f = D.to(compute_dtype) # (N,) x_f = x.to(compute_dtype) # (B, N, H) xst_f = x_state.to(compute_dtype) # (B, N, H) B_f = B.to(compute_dtype) # (B, R, N, S) C_f = C.to(compute_dtype) # (B, R, N, S) Bst_f = B_state.to(compute_dtype) # (B, R, N, S) Xp_f = xproj.to(compute_dtype) # (R, N, H) st_f = state.to(compute_dtype) # (B, N, H, S) alpha = torch.exp(A_f * dt_f) # (B, N) beta = (1.0 - trap_f) * dt_f * alpha # (B, N) gamma = trap_f * dt_f # (B, N) x_vals = (x_f[:, None, :, :] * Xp_f[None, :, :, :]) # (B, R, N, H) xs_vals = (xst_f[:, None, :, :] * Xp_f[None, :, :, :]) # (B, R, N, H) xBt_state = torch.einsum('brnh,brns->bnhs', x_vals * gamma.unsqueeze(-1).unsqueeze(1), B_f) xBt_prev = torch.einsum('brnh,brns->bnhs', xs_vals * beta.unsqueeze(-1).unsqueeze(1), Bst_f) new_state = st_f * alpha[:, :, None, None] + xBt_state + xBt_prev # (B, N, H, S) out_r = torch.einsum('bnhs,brns->brnh', new_state, C_f) # (B, R, N, H) out_r = out_r + (x_vals * D_f[None, :, None]) # (B, R, N, H) if z is not None: z_f = z.to(compute_dtype) # (B, N, H) Zp_f = zproj.to(compute_dtype) # (R, N, H) z_vals = (z_f[:, None, :, :] * Zp_f[None, :, :, :]) # (B, R, N, H) out_r = out_r * z_vals * torch.sigmoid(z_vals) # (B, R, N, H) if outproj is not None: Op_f = outproj.to(compute_dtype) # (R, N, H) out = torch.einsum('brnh,rnh->bnh', out_r, Op_f) # (B, N, H) else: out = out_r # (B, R, N, H) return out.to(og_dtype), new_state.to(og_dtype) def _bytes_of(t): return t.numel() * t.element_size() if __name__ == 
"__main__": torch.manual_seed(1357) batch, nheads, hdim, dstate, mimo = 128, 64, 64, 128, 4 device = torch.device("cuda:0") dtype_state = torch.float32 dtype = torch.float32 state = torch.randn(batch, nheads, hdim, dstate, device=device, dtype=dtype_state) Bstate = torch.randn(batch, mimo, nheads, dstate, device=device, dtype=dtype) Xstate = torch.randn(batch, nheads, hdim, device=device, dtype=dtype) A = -F.softplus(torch.randn(batch, nheads, device=device, dtype=torch.float32)) B = torch.randn(batch, mimo, nheads, dstate, device=device, dtype=dtype) C = torch.randn(batch, mimo, nheads, dstate, device=device, dtype=dtype) D = torch.randn(nheads, device=device, dtype=dtype) x = torch.randn(batch, nheads, hdim, device=device, dtype=dtype) z = torch.randn(batch, nheads, hdim, device=device, dtype=dtype) dt = torch.exp(torch.rand(nheads, device=device) * (math.log(0.1) - math.log(0.001)) + math.log(0.001)) dt = torch.clamp(dt, min=1e-4) dt_bias = dt + torch.log(-torch.expm1(-dt)) dt = F.softplus(torch.randn(batch, nheads, device=device) + dt_bias) # (B, H) trap = torch.sigmoid(torch.randn(batch, nheads, device=device, dtype=torch.float32)) xproj = torch.randn(mimo, nheads, hdim, device=device, dtype=dtype) zproj = torch.randn(mimo, nheads, hdim, device=device, dtype=dtype) outproj = torch.randn(mimo, nheads, hdim, device=device, dtype=dtype) out = torch.zeros_like(x) # ========================================================================= # Test 1: Out-of-place (explicit state_out) # ========================================================================= print("=== Out-of-place test ===") state_out = torch.zeros_like(state) fn_oop = lambda: mamba3_step_fn( state, Bstate, Xstate, A, B, C, D, x, dt, trap, xproj, outproj, state_out, out, z=z, zproj=zproj, tile_D=64, num_warps=4, ) fn_oop() out_ref, state_out_ref = selective_state_update_fused_ref_v2(state, A, B, C, xproj, x, zproj, z, dt, Bstate, Xstate, trap, D, outproj, compute_dtype=torch.float64) out_pt, 
state_out_pt = selective_state_update_fused_ref_v2(state, A, B, C, xproj, x, zproj, z, dt, Bstate, Xstate, trap, D, outproj, compute_dtype=torch.float32) print(f"state_out vs ref (f64): {(state_out - state_out_ref).abs().max()}") print(f"state_out_pt vs ref (f64): {(state_out_pt - state_out_ref).abs().max()}") print(f"out vs ref (f64): {(out - out_ref).abs().max()}") print(f"out_pt vs ref (f64): {(out_pt - out_ref).abs().max()}") # ========================================================================= # Test 2: In-place (state_out=None) # ========================================================================= print("\n=== In-place test ===") # Fresh state for in-place test state_ip = state.clone() out_ip = torch.zeros_like(x) fn_ip = lambda: mamba3_step_fn( state_ip, Bstate, Xstate, A, B, C, D, x, dt, trap, xproj, outproj, None, # state_out=None -> in-place out_ip, z=z, zproj=zproj, tile_D=64, num_warps=4, ) fn_ip() # state_ip was updated in place, compare against same reference print(f"state (in-place) vs ref (f64): {(state_ip - state_out_ref).abs().max()}") print(f"out (in-place) vs ref (f64): {(out_ip - out_ref).abs().max()}") # Verify in-place and out-of-place produce identical results print(f"state in-place vs out-of-place: {(state_ip - state_out).abs().max()}") print(f"out in-place vs out-of-place: {(out_ip - out).abs().max()}") # ========================================================================= # Benchmark (out-of-place) # ========================================================================= read_bytes = ( _bytes_of(state) + _bytes_of(A) + _bytes_of(B) + _bytes_of(C) + _bytes_of(xproj) + _bytes_of(x) + _bytes_of(zproj) + _bytes_of(z) + _bytes_of(dt) + _bytes_of(Bstate) + _bytes_of(Xstate) + _bytes_of(trap) + _bytes_of(D) + _bytes_of(outproj) ) out_bytes = _bytes_of(out) new_state_bytes = _bytes_of(state) total_bytes = read_bytes + out_bytes + new_state_bytes from triton.testing import do_bench_cudagraph ms = do_bench_cudagraph(fn_oop, 
rep=30) bandwidth = (total_bytes) / ms * 1e-6 print(f"\nMamba3 step (out-of-place): {ms:.3f} ms, {bandwidth:.1f} GB/s") ================================================ FILE: mamba_ssm/ops/selective_scan_interface.py ================================================ # Copyright (c) 2023, Tri Dao, Albert Gu. import torch import torch.nn.functional as F from mamba_ssm.utils.torch import custom_bwd, custom_fwd from einops import rearrange, repeat try: from causal_conv1d import causal_conv1d_fn from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function except ImportError: causal_conv1d_fn = None causal_conv1d_fwd_function = None causal_conv1d_bwd_function = None causal_conv1d_update_function = None from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd import selective_scan_cuda class SelectiveScanFn(torch.autograd.Function): @staticmethod def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False): if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: delta = delta.contiguous() if D is not None: D = D.contiguous() if B.stride(-1) != 1: B = B.contiguous() if C.stride(-1) != 1: C = C.contiguous() if z is not None and z.stride(-1) != 1: z = z.contiguous() if B.dim() == 3: B = rearrange(B, "b dstate l -> b 1 dstate l") ctx.squeeze_B = True if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") ctx.squeeze_C = True out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) ctx.delta_softplus = delta_softplus ctx.has_z = z is not None last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if not ctx.has_z: ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) return out if not return_last_state else (out, last_state) else: ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) out_z = rest[0] return out_z if not return_last_state else (out_z, last_state) @staticmethod def 
backward(ctx, dout, *args): if not ctx.has_z: u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors z = None out = None else: u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). # Here we just pass in None and dz will be allocated in the C++ code. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, False # option to recompute out_z, not used here ) dz = rest[0] if ctx.has_z else None dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC return (du, ddelta, dA, dB, dC, dD if D is not None else None, dz, ddelta_bias if delta_bias is not None else None, None, None) def rms_norm_forward( x, weight, bias, eps=1e-6, is_rms_norm=True, ): # x (b l) d if x.stride(-1) != 1: x = x.contiguous() weight = weight.contiguous() if bias is not None: bias = bias.contiguous() y = _layer_norm_fwd( x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm )[0] # y (b l) d return y def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered in the backward pass. 
""" return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False): """ u: r(B D L) delta: r(B D L) A: c(D N) or r(D N) B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) D: r(D) z: r(B D L) delta_bias: r(D), fp32 out: r(B D L) last_state (optional): r(B D dstate) or c(B D dstate) """ dtype_in = u.dtype u = u.float() delta = delta.float() if delta_bias is not None: delta = delta + delta_bias[..., None].float() if delta_softplus: delta = F.softplus(delta) batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] is_variable_B = B.dim() >= 3 is_variable_C = C.dim() >= 3 if A.is_complex(): if is_variable_B: B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) if is_variable_C: C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... 
L two", two=2)) else: B = B.float() C = C.float() x = A.new_zeros((batch, dim, dstate)) ys = [] deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) if not is_variable_B: deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) else: if B.dim() == 3: deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) else: B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) if is_variable_C and C.dim() == 4: C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) last_state = None for i in range(u.shape[2]): x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: y = torch.einsum('bdn,dn->bd', x, C) else: if C.dim() == 3: y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) else: y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) if i == u.shape[2] - 1: last_state = x if y.is_complex(): y = y.real * 2 ys.append(y) y = torch.stack(ys, dim=2) # (batch dim L) out = y if D is None else y + u * rearrange(D, "d -> d 1") if z is not None: out = out * F.silu(z) out = out.to(dtype=dtype_in) return out if not return_last_state else (out, last_state) class MambaInnerFn(torch.autograd.Function): @staticmethod @custom_fwd def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight=None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6): """ xz: (batch, dim, seqlen) """ assert causal_conv1d_fwd_function is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 
assert checkpoint_lvl in [0, 1] L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) if torch.is_autocast_enabled(): x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) if out_proj_bias is not None else None) if xz.stride(-1) != 1: xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None conv1d_out = causal_conv1d_fwd_function( x, conv1d_weight, conv1d_bias, None, None, None, True ) # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) ctx.is_variable_B = B is None ctx.is_variable_C = C is None ctx.B_proj_bias_is_None = B_proj_bias is None ctx.C_proj_bias_is_None = C_proj_bias is None if B is None: # variable B B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() else: B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() else: if B.stride(-1) != 1: B = B.contiguous() if C is None: # variable C C = x_dbl[:, -d_state:] # (bl dstate) if C_proj_bias is not None: C = C + C_proj_bias.to(dtype=C.dtype) if not A.is_complex(): # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() else: C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() else: if C.stride(-1) != 1: C = C.contiguous() if D is not None: D = D.contiguous() if b_rms_weight is not None: B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous() B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps) B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() if c_rms_weight is not None: C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous() C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps) C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() if dt_rms_weight is not None: delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous() delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps) delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous() out, scan_intermediates, out_z = selective_scan_cuda.fwd( 
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus ) ctx.delta_softplus = delta_softplus ctx.out_proj_bias_is_None = out_proj_bias is None ctx.checkpoint_lvl = checkpoint_lvl ctx.b_rms_weight = b_rms_weight ctx.c_rms_weight = c_rms_weight ctx.dt_rms_weight = dt_rms_weight ctx.b_c_dt_rms_eps = b_c_dt_rms_eps if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @custom_bwd def backward(ctx, dout): # dout: (batch, seqlen, dim) assert causal_conv1d_fwd_function is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out) = ctx.saved_tensors L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) if dout.stride(-1) != 1: dout = dout.contiguous() if ctx.checkpoint_lvl == 1: conv1d_out = causal_conv1d_fwd_function( x, conv1d_weight, conv1d_bias, None, None, None, True ) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) if dt_rms_weight is not None: delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous() delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps) delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous() if b_rms_weight is not None: # Recompute & RMSNorm B B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous() B = rms_norm_forward( B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps ) B = 
rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() if c_rms_weight is not None: # Recompute & RMSNorm C C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous() C = rms_norm_forward( C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps ) C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). dxz = torch.empty_like(xz) # (batch, dim, seqlen) dx, dz = dxz.chunk(2, dim=1) dout = rearrange(dout, "b l e -> e (b l)") dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, ctx.delta_softplus, True # option to recompute out_z ) dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None dD = dD if D is not None else None dx_dbl = torch.empty_like(x_dbl) dB_proj_bias = None if ctx.is_variable_B: if not A.is_complex(): dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() else: dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) dB = None dC_proj_bias = None if ctx.is_variable_C: if not A.is_complex(): dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() else: dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None dx_dbl[:, -d_state:] = dC # (bl d) dC = None ddelta = rearrange(ddelta, "b d l -> d (b l)") ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, 
delta_proj_weight) dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_bwd_function( x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True ) dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, dout_proj_weight, dout_proj_bias, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps dB_proj_bias, dC_proj_bias, None, None, None, None, None, None) def mamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight= None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6 ): return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps) def mamba_inner_ref( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True ): assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. 
Please install causal-conv1d." L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu") # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() delta = rearrange(delta, "d (b l) -> b d l", l=L) if B is None: # variable B B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() else: B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() if C is None: # variable B C = x_dbl[:, -d_state:] # (bl d) if C_proj_bias is not None: C = C + C_proj_bias.to(dtype=C.dtype) if not A.is_complex(): C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() else: C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) ================================================ FILE: mamba_ssm/ops/tilelang/mamba3/mamba3_mimo.py ================================================ """Mamba-3 Tilelang Autograd Wrapper Interface for Mamba-3 Tilelang kernels with automatic differentiation Copyright (c) 2026, Dao AI Lab, Goombalab """ from __future__ import annotations from typing import Optional, Tuple, Union import torch from torch import Tensor # Import kernels from mamba_ssm.ops.tilelang.mamba3.mamba3_mimo_fwd import mamba_mimo_forward from 
mamba_ssm.ops.triton.mamba3.mamba3_mimo_utils import compute_dacs_segsum_triton from mamba_ssm.ops.tilelang.mamba3.mamba3_mimo_bwd import mamba_mimo_bwd_combined # ============================================================================= # Autograd Function # ============================================================================= class _Mamba3Function(torch.autograd.Function): """Custom autograd function for Mamba-3 with Triton/Tilelang kernels.""" @staticmethod def forward( ctx, Q: Tensor, K: Tensor, V: Tensor, ADT: Tensor, DT: Tensor, Trap: Tensor, Q_bias: Tensor, K_bias: Tensor, MIMO_V: Tensor, MIMO_Z: Tensor, MIMO_Out: Union[Tensor, None], Angles: Tensor, D: Tensor, Z: Tensor, chunk_size: int, rotary_dim_divisor: int, dtype: torch.dtype, return_state: bool, ) -> Tensor | Tuple[Tensor, Tuple]: """Forward pass: call Triton/Tilelang kernel and save tensors for backward.""" ctx.chunk_size = chunk_size ctx.rotary_dim_divisor = rotary_dim_divisor ctx.dtype = dtype (Q, K, V, ADT, DT, Trap, Q_bias, K_bias, MIMO_V, MIMO_Z, MIMO_Out, Angles, D, Z) = tuple( t.contiguous() if t is not None else None for t in ( Q, K, V, ADT, DT, Trap, Q_bias, K_bias, MIMO_V, MIMO_Z, MIMO_Out, Angles, D, Z, ) ) DA_CS, DA_CS_REV, Segsum = compute_dacs_segsum_triton(ADT, chunk_size) Out, Final_SSM_State, Final_K = mamba_mimo_forward( Q, K, V, Q_bias, K_bias, MIMO_V, MIMO_Out, Z, D, MIMO_Z, Angles, DA_CS, DA_CS_REV, DT, Trap, Segsum, return_state=return_state, chunk_size=chunk_size, rotary_dim_divisor=rotary_dim_divisor, dtype=dtype, ) ctx.chunk_size = chunk_size ctx.save_for_backward( Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z, MIMO_V, MIMO_Out, MIMO_Z, ) if not return_state: return Out else: Final_Angle = torch.remainder(Angles[:, -1, :, :], 2 * torch.pi).contiguous().detach() Final_SSM_State = Final_SSM_State.permute(0, 1, 3, 2).contiguous().detach() Final_K = Final_K.contiguous().detach() Final_V = V[:, -1, :, :].contiguous().detach() 
ctx.mark_non_differentiable(Final_Angle, Final_SSM_State, Final_K, Final_V) return Out, Final_Angle, Final_SSM_State, Final_K, Final_V @staticmethod def backward(ctx, dout, *args) -> tuple: """Backward pass: compute gradients using Triton backward kernels.""" if len(ctx.saved_tensors) == 0: raise RuntimeError( "Backward called but forward ran without gradient tracking. " "Ensure inputs require grad or run under torch.enable_grad()." ) dout = dout.contiguous() (Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z, MIMO_V, MIMO_Out, MIMO_Z, ) = ctx.saved_tensors DA_CS, DA_CS_REV, Segsum = compute_dacs_segsum_triton(ADT, ctx.chunk_size) (dQ, dK, dV, dADT, dDT, dTrap, dQ_bias, dK_bias, dMIMO_V, dMIMO_Z, dMIMO_Out, dAngles, dD, dZ) = mamba_mimo_bwd_combined( dout, Q, K, V, Q_bias, K_bias, MIMO_V, MIMO_Out, Z, MIMO_Z, Angles, DA_CS, DA_CS_REV, DT, Trap, D, Segsum, ctx.chunk_size, ctx.rotary_dim_divisor, ctx.dtype, ) return ( dQ, dK, dV, dADT, dDT, dTrap, dQ_bias, dK_bias, dMIMO_V, dMIMO_Z, dMIMO_Out, dAngles, dD, dZ, None, None, None, None, ) # ============================================================================= # Public API # ============================================================================= def mamba3_mimo( Q: Tensor, K: Tensor, V: Tensor, ADT: Tensor, DT: Tensor, Trap: Tensor, Q_bias: Tensor, K_bias: Tensor, MIMO_V: Tensor, MIMO_Z: Tensor, MIMO_Out: Tensor, Angles: Tensor, D: Tensor, Z: Tensor, chunk_size: int, rotary_dim_divisor: int, dtype: torch.dtype, return_state: bool = False, ) -> Tensor | Tuple[Tensor, Tuple]: """Mamba-3 attention with Tilelang kernels and automatic differentiation. 
Args: Q: Query tensor (batch, seqlen, mimo_rank, nheads_qk, headdim_qk) K: Key tensor (batch, seqlen, mimo_rank, nheads_qk, headdim_qk) V: Value tensor (batch, seqlen, nheads, headdim_v) ADT: Decay factor A * dt (batch, nheads, seqlen) DT: Time delta tensor dt (batch, nheads, seqlen) Trap: Trapezoidal mixing factor, pre-sigmoid (batch, nheads, seqlen) Q_bias: Query bias (nheads, mimo_rank, headdim_qk) K_bias: Key bias (nheads, mimo_rank, headdim_qk) MIMO_V: Mimo up projection for V (nheads, mimo_rank, headdim_v), MIMO_Z: Mimo up projection for Z (nheads, mimo_rank, headdim_v), MIMO_Out: Mimo down projection for output (nheads, mimo_rank, headdim_v). If None, does not reduce output with MIMO_Out, Angles: Rotary position embeddings (batch, seqlen, nheads, headangles) D: Optional skip connection weight (nheads,) Z: Optional gating tensor (batch, seqlen, nheads, headdim_v) chunk_size: Chunk size for state computation (default: 64//R) rotary_dim_divisor: Divisor for rotary embedding dimensions (default: 4, meaning angles have 1/4 of headdim_qk) Returns: output: (batch, seqlen, nheads, headdim_v) if MIMO_Out is not None (batch, seqlen, mimo_rank, nheads, headdim_v) if MIMO_Out is None final_state: Tuple of tensors representing the running Angle sum, final SSM state, final K, and final V for autoregressive decoding. Only returned if return_state=True. NOTE: The kernel is most optimized for seqlen: 2048, nheads_qk: 1, nheads: 32 headdim_qk: 128, headdim_v: 64, mimo_rank: 4, and chunk_size: 16. On H100. NOTE: The code is still prone to smem over-allocation and Tilelang compilation error once headdim_qk, headdim_v, mimo_rank, chunk_size, or hardware type deviate from the combinations tested. NOTE: Chunk size of 64/R is recommended, where R is the MIMO rank. However, it may be necessary to reduce chunk size in case of smem over-allocation, which can occur with larger headdim_qk, headdim_v, or mimo_rank values. 
NOTE: Currently final_state is currently intended to be a non-differentiable side output. In particular, loss = f(output) is fine, but loss = f(output, final_state) will not work properly since the backward does not compute gradients for final_state components. """ batch, seqlen, mimo_rank, nheads_qk, headdim_qk = Q.shape _, _, nheads, headdim_v = V.shape assert chunk_size >= 8, f"chunk_size must be at least 8" assert nheads % nheads_qk == 0, f"nheads ({nheads}) must be divisible by nheads_qk ({nheads_qk})" assert headdim_qk % 2 == 0, f"headdim_qk ({headdim_qk}) must be even for rotary embeddings" assert rotary_dim_divisor in [2, 4], f"currently only supports rotary embedding on entire or half of headdim_qk" # NOTE: the following (headdim_qk, headdim_v) values currently can result in compilation errors: (16, 32), (256, 128) if headdim_qk not in [16, 32, 64, 128, 256]: print(f"WARNING: The value headdim_qk={headdim_qk} has not been tested. " +\ "Proceed with caution and consider one of the tested headdim_qk: 16, 32, 64, 128, 256.") if headdim_v not in [32, 64, 128]: print(f"WARNING: The value headdim_v={headdim_v} has not been tested. " +\ "Proceed with caution and consider one of the tested headdim_v: 32, 64, 128.") if mimo_rank not in [1, 2, 4, 8]: print(f"WARNING: The value mimo_rank={mimo_rank} has not been tested. " +\ "Proceed with caution and consider one of the tested mimo_rank: 1, 2, 4, 8.") if chunk_size*mimo_rank > 64: print(f"WARNING: chunk_size * mimo_rank = {chunk_size*mimo_rank} exceeds 64, which may result in smem over-allocation. Consider decreasing chunk_size.") return _Mamba3Function.apply( Q, K, V, ADT, DT, Trap, Q_bias, K_bias, MIMO_V, MIMO_Z, MIMO_Out, Angles, D, Z, chunk_size, rotary_dim_divisor, dtype, return_state, ) ================================================ FILE: mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd.py ================================================ """ Tilelang implementation of Mamba3 backward kernels, with MIMO support. 
Copyright (c) 2026, Dao AI Lab, Goombalab """ import torch import tilelang import tilelang.language as T from triton.testing import do_bench from tilelang.autotuner import autotune import itertools import argparse from einops import rearrange from typing import Optional, Tuple from mamba_ssm.ops.triton.mamba3.mamba3_mimo_utils import bwd_dadt_fused_triton, bwd_dtrap_ddt_triton # def get_configs(): # iter_params = dict(num_stages=[0, 1, 2, 3], threads=[128, 256, 512]) # # iter_params = dict(num_stages=[2], threads=[128]) # return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] # @autotune( # configs=get_configs(), # warmup=3, # rep=20, # ) @tilelang.jit( out_idx=[], pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) def mamba_mimo_bwd_fwd( B, S, H, G, N, P, R, hasZ, hasD, reduceO, chunk_size: int = 16, rotary_dim_divisor: int = 4, dtype: str = 'float16', threads: int = 128, num_stages: int = 0, ) -> torch.Tensor: accum_dtype = 'float32' nchunks = tilelang.cdiv(S, chunk_size) fused_chunk_size = chunk_size * R if reduceO: DOUT_shape = (B, S, H, P) else: DOUT_shape = (B, S, R, H, P) @T.prim_func def mamba_mimo_bwd_fwd_kernel( DOUT: T.Tensor(DOUT_shape, dtype), # type: ignore Q: T.Tensor([B, S, R, G, N], dtype), # type: ignore K: T.Tensor([B, S, R, G, N], dtype), # type: ignore V: T.Tensor([B, S, H, P], dtype), # type: ignore Q_BIAS: T.Tensor([H, R, N], T.float32), # type: ignore K_BIAS: T.Tensor([H, R, N], T.float32), # type: ignore MIMO_V: T.Tensor([H, R, P], T.float32), # type: ignore MIMO_O: T.Tensor([H, R, P], T.float32), # type: ignore DMIMO_O: T.Tensor([B, H, R, P], T.float32), # type: ignore STATES: T.Tensor([B, H, nchunks, N, P], dtype), # type: ignore Z: T.Tensor([B, S, H, P], dtype), # type: ignore MIMO_Z: T.Tensor([H, R, P], T.float32), # type: ignore DZ: T.Tensor([B, S, H, P], dtype), # type: 
ignore DMIMO_Z: T.Tensor([B, H, R, P], T.float32), # type: ignore ANGLES: T.Tensor([B, S, H, N//rotary_dim_divisor], T.float32), # type: ignore DA_CS: T.Tensor([B, H, S], T.float32), # type: ignore DA_CS_REV: T.Tensor([B, H, S], T.float32), # type: ignore DT: T.Tensor([B, H, S], T.float32), # type: ignore TRAP: T.Tensor([B, H, S], dtype), # type: ignore D: T.Tensor([H], T.float32), # type: ignore QK_DOT: T.Tensor([B, H, S, R, R], dtype), # type: ignore SEGSUM: T.Tensor([B, H, nchunks, chunk_size, chunk_size], T.float32), # type: ignore ): """ Overview: Fused backward-forward pass over chunks. Recomputes local forward intermediates, accumulates projection gradients (DMIMO_O and optional DMIMO_Z), emits optional DZ, stores per-chunk recurrent STATES, and materializes QK_DOT for the second backward pass. Inputs: - Activations and upstream grad: DOUT, Q, K, V. - Projection weights/biases: Q_BIAS, K_BIAS, MIMO_V (Psi), MIMO_O (Phi), optional MIMO_Z (Zeta). - Optional forward modifiers: Z, D. - Discretization tensors: DA_CS, DA_CS_REV, DT, TRAP, and SEGSUM. Outputs: - MIMO projection grads: DMIMO_O and optional DMIMO_Z. - Optional activation grad: DZ. - Cached intermediates for pass 2: STATES and QK_DOT. Notation: - Psi: MIMO X projection. - Phi: MIMO O projection. - Zeta: MIMO Z projection. - Trap: convex-combination modulator used in exponential-trapezoidal discretization. 
""" with T.Kernel(H, B, threads=threads) as (i_h, i_b): # --- Kernel Setup --- # GQA support: map V head to Q/K head i_h_qk = i_h // (H // G) # --- Buffer Allocation --- q_shared = T.alloc_shared([fused_chunk_size, N], dtype) k_shared = T.alloc_shared([fused_chunk_size, N], dtype) PsiV_shared = T.alloc_shared([fused_chunk_size, P], dtype) qs_shared = T.alloc_shared([fused_chunk_size, P], dtype) o_shared = T.alloc_shared([chunk_size, P], dtype) v_shared = T.alloc_shared([chunk_size, P], dtype) states_accum_cast_shared = T.alloc_shared([N, P], dtype) qk_dot_full_shared = T.alloc_shared([fused_chunk_size, fused_chunk_size], dtype) # --- Output Accumulators --- if reduceO: dPhi_shared = T.alloc_shared([R, P], accum_dtype) T.clear(dPhi_shared) dout_shared = T.alloc_shared([chunk_size, P], dtype) z_shared = T.alloc_shared([chunk_size, P], dtype) dZeta_shared = T.alloc_shared([R, P], accum_dtype) T.clear(dZeta_shared) # --- Swizzling Annotation --- T.annotate_layout({ q_shared: tilelang.layout.make_swizzled_layout(q_shared), k_shared: tilelang.layout.make_swizzled_layout(k_shared), PsiV_shared: tilelang.layout.make_swizzled_layout(PsiV_shared), qs_shared: tilelang.layout.make_swizzled_layout(qs_shared), o_shared: tilelang.layout.make_swizzled_layout(o_shared), states_accum_cast_shared: tilelang.layout.make_swizzled_layout(states_accum_cast_shared), qk_dot_full_shared: tilelang.layout.make_swizzled_layout(qk_dot_full_shared), dout_shared: tilelang.layout.make_swizzled_layout(dout_shared), z_shared: tilelang.layout.make_swizzled_layout(z_shared), }) T.use_swizzle(10, "row") T.no_set_max_nreg() # --- Per-Head Constants / Running State --- states_frag = T.alloc_fragment([N, P], accum_dtype) T.clear(states_frag) if reduceO: phi_frag_intrachunk = T.alloc_fragment([R, P], dtype=dtype) T.copy(MIMO_O[i_h, :, :], phi_frag_intrachunk) Psi_frag = T.alloc_fragment([R, P], dtype) T.copy(MIMO_V[i_h, :, :], Psi_frag) q_bias_frag = T.alloc_fragment([R, N], dtype) k_bias_frag = 
T.alloc_fragment([R, N], dtype) T.copy(Q_BIAS[i_h, :, :], q_bias_frag) T.copy(K_BIAS[i_h, :, :], k_bias_frag) # --- Chunk Loop --- for i in T.Pipelined(0, nchunks, num_stages=num_stages): chunk_start = i * chunk_size fused_chunk_start = chunk_start * R # --- Discretization Factors (Shifted Gamma + Trap Scale) --- trap_shifted_frag = T.alloc_fragment([chunk_size], T.float32) dt_shifted_frag = T.alloc_fragment([chunk_size], dtype) for cs in T.Parallel(chunk_size): trap_shifted_frag[cs] = T.if_then_else( chunk_start + cs + 1 < S, TRAP[i_b, i_h, chunk_start + cs + 1], 0.0, ) dt_shifted_frag[cs] = T.if_then_else( chunk_start + cs + 1 < S, DT[i_b, i_h, chunk_start + cs + 1], 0.0, ) shifted_gamma_frag = T.alloc_fragment([chunk_size], dtype) for cs in T.Parallel(chunk_size): shifted_gamma_frag[cs] = T.if_then_else(chunk_start + cs < (S - 1), dt_shifted_frag[cs] * (T.sigmoid(-trap_shifted_frag[cs])), 0.0) shifted_gamma_shared = T.alloc_shared([chunk_size], dtype) T.copy(shifted_gamma_frag, shifted_gamma_shared) trap_frag = T.alloc_fragment([chunk_size], T.float32) T.copy(TRAP[i_b, i_h, chunk_start: chunk_start+chunk_size], trap_frag) dt_frag = T.alloc_fragment([chunk_size], dtype) T.copy(DT[i_b, i_h, chunk_start: chunk_start+chunk_size], dt_frag) gamma_frag = T.alloc_fragment([chunk_size], T.float32) for cs in T.Parallel(chunk_size): gamma_frag[cs] = dt_frag[cs] * T.sigmoid(trap_frag[cs]) trap_scale_frag = T.alloc_fragment([chunk_size], dtype) for cs in T.Parallel(chunk_size): trap_scale_frag[cs] = gamma_frag[cs] + shifted_gamma_shared[cs] trap_scale_shared = T.alloc_shared([chunk_size], dtype) T.copy(trap_scale_frag, trap_scale_shared) # --- Up-Project V and Prepare Biased Q/K --- PsiV_frag = T.alloc_fragment([chunk_size, R, P], dtype) T.copy(V[i_b, chunk_start:chunk_start+chunk_size, i_h, :], v_shared) for cs, r, p in T.Parallel(chunk_size, R, P): PsiV_frag[cs, r, p] = v_shared[cs, p] * Psi_frag[r, p] PsiV_reshaped_frag = T.view(PsiV_frag, shape=[fused_chunk_size, P]) 
T.copy(PsiV_reshaped_frag, PsiV_shared) q_reshaped_shared = T.view(q_shared, shape=[chunk_size, R, N]) T.copy(Q[i_b, chunk_start:chunk_start+chunk_size, :, i_h_qk, :], q_reshaped_shared) q_frag = T.alloc_fragment([chunk_size, R, N], dtype) T.copy(q_reshaped_shared, q_frag) for cs, r, n in T.Parallel(chunk_size, R, N): q_frag[cs, r, n] += q_bias_frag[r, n] T.copy(q_frag, q_reshaped_shared) k_reshaped_shared = T.view(k_shared, shape=[chunk_size, R, N]) T.copy(K[i_b, chunk_start:chunk_start+chunk_size, :, i_h_qk, :], k_reshaped_shared) k_frag = T.alloc_fragment([chunk_size, R, N], dtype) T.copy(k_reshaped_shared, k_frag) for cs, r, n in T.Parallel(chunk_size, R, N): k_frag[cs, r, n] += k_bias_frag[r, n] T.copy(k_frag, k_reshaped_shared) # --- Cache Diagonal qk_dot Path --- # Keep full qk_dot in shared memory to reuse per-step R x R blocks. qk_dot_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype=accum_dtype) T.gemm(q_shared, k_shared, qk_dot_frag, transpose_B=True, clear_accum=True) T.copy(qk_dot_frag, qk_dot_full_shared) # Output QK_DOT for the bwd_bwd kernel (per-time-step blocks only) for cs, r_out, r_in in T.Parallel(chunk_size, R, R): QK_DOT[i_b, i_h, chunk_start + cs, r_out, r_in] = \ qk_dot_full_shared[cs * R + r_out, cs * R + r_in] # --- Rotary Q/K + Trap Scaling --- q_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) q_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): q_first_half_frag[cs, r, n] = q_shared[cs*R + r, n] q_second_half_frag[cs, r, n] = q_shared[cs*R + r, N//2 + n] # NOTE: angles are casted to fp32 for numerical stability angles_frag = T.alloc_fragment([chunk_size, N//rotary_dim_divisor], T.float32) T.copy(ANGLES[i_b, chunk_start:chunk_start+chunk_size, i_h, :], angles_frag) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): q_shared[cs*R + r, n] = T.cos(angles_frag[cs, n]) * 
q_first_half_frag[cs, r, n] - T.sin(angles_frag[cs, n]) * q_second_half_frag[cs, r, n] q_shared[cs*R + r, N//2 + n] = T.sin(angles_frag[cs, n]) * q_first_half_frag[cs, r, n] + T.cos(angles_frag[cs, n]) * q_second_half_frag[cs, r, n] k_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) k_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): k_first_half_frag[cs, r, n] = k_shared[cs*R + r, n] k_second_half_frag[cs, r, n] = k_shared[cs*R + r, N//2 + n] for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): k_shared[cs*R + r, n] = T.cos(angles_frag[cs, n]) * k_first_half_frag[cs, r, n] - T.sin(angles_frag[cs, n]) * k_second_half_frag[cs, r, n] k_shared[cs*R + r, N//2 + n] = T.sin(angles_frag[cs, n]) * k_first_half_frag[cs, r, n] + T.cos(angles_frag[cs, n]) * k_second_half_frag[cs, r, n] k_trap_scaled_frag = T.alloc_fragment([fused_chunk_size, N], dtype) T.copy(k_shared, k_trap_scaled_frag) for csr, n in T.Parallel(fused_chunk_size, N): k_trap_scaled_frag[csr, n] *= trap_scale_shared[csr//R] T.copy(k_trap_scaled_frag, k_shared) # --- Interchunk + Intrachunk Output Accumulation --- q_state_out_frag = T.alloc_fragment([fused_chunk_size, P], dtype=accum_dtype) # NOTE: Tilelang unable to infer correct layout when trying to cast # states_frag to 16-bit within rmem T.copy(states_frag, states_accum_cast_shared) T.gemm(q_shared, states_accum_cast_shared, q_state_out_frag, clear_accum=True) qk_intrachunk_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype=accum_dtype) T.gemm(q_shared, k_shared, qk_intrachunk_frag, transpose_B=True, clear_accum=True) # Strictly causal masking over chunk steps (exclude same-step diagonal). 
da_cs__or__exp_da_cs_shared = T.alloc_shared([chunk_size], T.float32) T.copy(DA_CS[i_b, i_h, chunk_start:chunk_start+chunk_size], da_cs__or__exp_da_cs_shared) for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size): qk_intrachunk_frag[csr_i, csr_j] = T.if_then_else( csr_i//R > csr_j//R, qk_intrachunk_frag[csr_i, csr_j] * T.exp(SEGSUM[i_b, i_h, i, csr_i//R, csr_j//R]), 0.0 ) qk_intrachunk_masked_shared = T.alloc_shared([fused_chunk_size, fused_chunk_size], dtype=dtype) for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size): qk_intrachunk_masked_shared[csr_i, csr_j] = qk_intrachunk_frag[csr_i, csr_j] # Exponentiate da_cs__or__exp_da_cs_shared so that later usage does not have to: for cs in T.Parallel(chunk_size): da_cs__or__exp_da_cs_shared[cs] = T.exp(da_cs__or__exp_da_cs_shared[cs]) exp_da_cs_frag = T.alloc_fragment([chunk_size], dtype=T.float32) T.copy(da_cs__or__exp_da_cs_shared, exp_da_cs_frag) for csr, p in T.Parallel(fused_chunk_size, P): q_state_out_frag[csr, p] *= exp_da_cs_frag[csr//R] o_mimo_accum_frag = T.alloc_fragment([fused_chunk_size, P], dtype=accum_dtype) T.gemm(qk_intrachunk_masked_shared, PsiV_shared, o_mimo_accum_frag, clear_accum=True) # Merge interchunk and intrachunk contributions. 
for cs, p in T.Parallel(fused_chunk_size, P): o_mimo_accum_frag[cs, p] += q_state_out_frag[cs, p] # --- Add Diagonal Terms (qk_dot and optional D) --- qkdot_psiv_frag = T.alloc_fragment([chunk_size, R, P], dtype=dtype) T.clear(qkdot_psiv_frag) for cs, r_out, p in T.Parallel(chunk_size, R, P): for r_in in T.serial(R): qkdot_psiv_frag[cs, r_out, p] += qk_dot_full_shared[cs * R + r_out, cs * R + r_in] * PsiV_shared[cs * R + r_in, p] qkdot_psiv_frag[cs, r_out, p] *= gamma_frag[cs] # Apply gamma qkdot_psiv_reshaped_frag = T.view(qkdot_psiv_frag, shape=[fused_chunk_size, P]) for csr, p in T.Parallel(fused_chunk_size, P): o_mimo_accum_frag[csr, p] += qkdot_psiv_reshaped_frag[csr, p] if hasD: D_var = T.alloc_var(T.float32) T.copy(D[i_h], D_var) PsiV_D_frag = T.alloc_fragment([fused_chunk_size, P], T.float32) T.copy(PsiV_shared, PsiV_D_frag) for csr, p in T.Parallel(fused_chunk_size, P): o_mimo_accum_frag[csr, p] += D_var * PsiV_D_frag[csr, p] # --- Project to dMIMO_O and Optional Z Backward Path --- if reduceO: out_prereduced_shared = T.alloc_shared([fused_chunk_size, P], dtype) T.copy(o_mimo_accum_frag, out_prereduced_shared) o_gated_frag = T.alloc_fragment([chunk_size, R, P], T.float32) if hasZ: # Apply Z gating to out: T.copy(Z[i_b, chunk_start:chunk_start+chunk_size, i_h, :], z_shared) z_o_frag = T.alloc_fragment([chunk_size, P], T.float32) T.copy(z_shared, z_o_frag) Zeta_o_frag = T.alloc_fragment([R, P], T.float32) T.copy(MIMO_Z[i_h, :, :], Zeta_o_frag) for cs, r, p in T.Parallel(chunk_size, R, P): # Apply SiLU to o_gated_frag: tmp = z_o_frag[cs, p] * Zeta_o_frag[r, p] * 0.5 o_gated_frag[cs, r, p] = tmp * T.tanh(tmp) + tmp for cs, r, p in T.Parallel(chunk_size, R, P): o_gated_frag[cs, r, p] *= out_prereduced_shared[cs*R + r, p] else: for cs, r, p in T.Parallel(chunk_size, R, P): o_gated_frag[cs, r, p] = out_prereduced_shared[cs*R + r, p] # NOTE: keeping dPhi_frag in fp32 for numerical reason dPhi_frag = T.alloc_fragment([R, P], T.float32) T.copy(dPhi_shared, 
dPhi_frag) dout_frag = T.alloc_fragment([chunk_size, P], dtype) T.copy(DOUT[i_b, chunk_start:chunk_start+chunk_size, i_h, :], dout_shared) T.copy(dout_shared, dout_frag) for r, p in T.Parallel(R, P): for cs in T.serial(chunk_size): dPhi_frag[r, p] += o_gated_frag[cs, r, p] * dout_frag[cs, p] T.copy(dPhi_frag, dPhi_shared) if hasZ: # Up-project DOUT from SISO to MIMO. Phi_frag = T.alloc_fragment([R, P], dtype) T.copy(MIMO_O[i_h, :, :], Phi_frag) dPhiO_frag = T.alloc_fragment([chunk_size, R, P], dtype) dout_preexpand_frag = T.alloc_fragment([chunk_size, P], dtype) T.copy(dout_shared, dout_preexpand_frag) for cs, r, p in T.Parallel(chunk_size, R, P): dPhiO_frag[cs, r, p] = dout_frag[cs, p] * Phi_frag[r, p] # NOTE: layout issue when trying to reuse o_mimo_accum_frag # NOTE: note that it uses out_prereduced_shared, which is the pre-Z-gate version # of out for cs, r, p in T.Parallel(chunk_size, R, P): dPhiO_frag[cs, r, p] *= out_prereduced_shared[cs*R + r, p] # Backward of SILU(z) is sigmoid(z) * (1 + z * (1 - sigmoid(z))) z_frag = T.alloc_fragment([chunk_size, P], T.float32) T.copy(z_shared, z_frag) Zeta_frag = T.alloc_fragment([R, P], T.float32) T.copy(MIMO_Z[i_h, :, :], Zeta_frag) dZetaZ_frag = T.alloc_fragment([chunk_size, R, P], T.float32) for cs, r, p in T.Parallel(chunk_size, R, P): dZetaZ_frag[cs, r, p] = z_frag[cs, p] * Zeta_frag[r, p] dZetaZ_frag[cs, r, p] = dPhiO_frag[cs, r, p]* T.sigmoid(dZetaZ_frag[cs, r, p]) * \ (1 + dZetaZ_frag[cs, r, p] * (T.sigmoid(-dZetaZ_frag[cs, r, p]))) dZ_frag = T.alloc_fragment([chunk_size, P], dtype) T.clear(dZ_frag) for cs, p in T.Parallel(chunk_size, P): for r in T.serial(R): dZ_frag[cs, p] += dZetaZ_frag[cs, r, p] * Zeta_frag[r, p] T.copy(dZ_frag, DZ[i_b, chunk_start:chunk_start+chunk_size, i_h, :]) for cs, r, p in T.Parallel(chunk_size, R, P): dZetaZ_frag[cs, r, p] *= z_frag[cs, p] dZeta_frag = T.alloc_fragment([R, P], T.float32) T.copy(dZeta_shared, dZeta_frag) T.reduce_sum(dZetaZ_frag, dZeta_frag, clear=False, dim=0) 
T.copy(dZeta_frag, dZeta_shared) else: if hasZ: out_prereduced_shared = T.alloc_shared([fused_chunk_size, P], dtype) T.copy(o_mimo_accum_frag, out_prereduced_shared) T.copy(Z[i_b, chunk_start:chunk_start+chunk_size, i_h, :], z_shared) dPhiO_frag = T.alloc_fragment([chunk_size, R, P], dtype) for cs, r, p in T.Parallel(chunk_size, R, P): dPhiO_frag[cs, r, p] = DOUT[i_b, chunk_start + cs, r, i_h, p] for cs, r, p in T.Parallel(chunk_size, R, P): dPhiO_frag[cs, r, p] *= out_prereduced_shared[cs*R + r, p] # Backward of SILU(z) is sigmoid(z) * (1 + z * (1 - sigmoid(z))) z_frag = T.alloc_fragment([chunk_size, P], T.float32) T.copy(z_shared, z_frag) Zeta_frag = T.alloc_fragment([R, P], T.float32) T.copy(MIMO_Z[i_h, :, :], Zeta_frag) dZetaZ_frag = T.alloc_fragment([chunk_size, R, P], T.float32) for cs, r, p in T.Parallel(chunk_size, R, P): dZetaZ_frag[cs, r, p] = z_frag[cs, p] * Zeta_frag[r, p] dZetaZ_frag[cs, r, p] = dPhiO_frag[cs, r, p]* T.sigmoid(dZetaZ_frag[cs, r, p]) * \ (1 + dZetaZ_frag[cs, r, p] * (T.sigmoid(-dZetaZ_frag[cs, r, p]))) ## Compute DZ dZ_frag = T.alloc_fragment([chunk_size, P], dtype) T.clear(dZ_frag) for cs, p in T.Parallel(chunk_size, P): for r in T.serial(R): dZ_frag[cs, p] += dZetaZ_frag[cs, r, p] * Zeta_frag[r, p] T.copy(dZ_frag, DZ[i_b, chunk_start:chunk_start+chunk_size, i_h, :]) ## Compute DMIMO_Z for cs, r, p in T.Parallel(chunk_size, R, P): dZetaZ_frag[cs, r, p] *= z_frag[cs, p] dZeta_frag = T.alloc_fragment([R, P], T.float32) T.copy(dZeta_shared, dZeta_frag) T.reduce_sum(dZetaZ_frag, dZeta_frag, clear=False, dim=0) T.copy(dZeta_frag, dZeta_shared) # --- Save and Update Recurrent State --- T.copy(states_frag, STATES[i_b, i_h, i, :, :]) # DA_CS_REV scales stepwise K contribution into the new state. 
dA_cs_rev_frag = T.alloc_fragment([chunk_size], T.float32) T.copy(DA_CS_REV[i_b, i_h, chunk_start:chunk_start+chunk_size], dA_cs_rev_frag) # NOTE: we can recycle k_trap_scaled_frag from earlier, however, # that is slower, so choose to recopy from smem: k_state_frag = T.alloc_fragment([fused_chunk_size, N], dtype) T.copy(k_shared, k_state_frag) for csr, n in T.Parallel(fused_chunk_size, N): k_state_frag[csr, n] *= T.exp(dA_cs_rev_frag[csr//R]) # DA_CS(last) applies chunk-level decay to the carried state. da_cs_sum = T.alloc_var(T.float32) T.copy(DA_CS[i_b, i_h, chunk_start+chunk_size-1], da_cs_sum) for n, p in T.Parallel(N, P): states_frag[n, p] *= T.exp(da_cs_sum) T.gemm(k_state_frag, PsiV_shared, states_frag, transpose_A=True, clear_accum=False) if reduceO: T.copy(dPhi_shared, DMIMO_O[i_b, i_h, :, :]) if hasZ: T.copy(dZeta_shared, DMIMO_Z[i_b, i_h, :, :]) return mamba_mimo_bwd_fwd_kernel # def get_configs(): # iter_params = dict(num_stages=[0], threads=[128, 256]) # return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] # @autotune( # configs=get_configs(), # warmup=3, # rep=20, # ) @tilelang.jit( out_idx=[], pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) def mamba_mimo_bwd_bwd( B, S, H, G, N, P, R, hasZ, hasD, reduceO, chunk_size: int = 16, rotary_dim_divisor: int = 4, dtype: str = 'float16', threads: int = 256, num_stages: int = 0, ) -> torch.Tensor: accum_dtype = 'float32' nchunks = tilelang.cdiv(S, chunk_size) fused_chunk_size = chunk_size * R if reduceO: DOUT_shape = (B, S, H, P) else: DOUT_shape = (B, S, R, H, P) @T.prim_func def mamba_mimo_bwd_bwd_kernel( DOUT: T.Tensor(DOUT_shape, dtype), # type: ignore Q: T.Tensor([B, S, R, G, N], dtype), # type: ignore K: T.Tensor([B, S, R, G, N], dtype), # type: ignore V: T.Tensor([B, S, H, P], dtype), # type: ignore Q_BIAS: T.Tensor([H, R, N], 
T.float32), # type: ignore K_BIAS: T.Tensor([H, R, N], T.float32), # type: ignore MIMO_V: T.Tensor([H, R, P], T.float32), # type: ignore MIMO_O: T.Tensor([H, R, P], T.float32), # type: ignore DK: T.Tensor([B, S*R, H, N], dtype), # type: ignore DV: T.Tensor([B, S, H, P], dtype), # type: ignore DMIMO_V: T.Tensor([B, H, R, P], T.float32), # type: ignore STATES: T.Tensor([B, H, nchunks, N, P], dtype), # type: ignore DQ: T.Tensor([B, S*R, H, N], dtype), # type: ignore Z: T.Tensor([B, S, H, P], dtype), # type: ignore MIMO_Z: T.Tensor([H, R, P], T.float32), # type: ignore ANGLES: T.Tensor([B, S, H, N//rotary_dim_divisor], T.float32), # type: ignore DA_CS: T.Tensor([B, H, S], T.float32), # type: ignore DA_CS_REV: T.Tensor([B, H, S], T.float32), # type: ignore DT: T.Tensor([B, H, S], T.float32), # type: ignore TRAP: T.Tensor([B, H, S], dtype), # type: ignore DFACTOR: T.Tensor([B, H, S], T.float32), # type: ignore DGAMMA_DIAG: T.Tensor([B, H, S], T.float32), # type: ignore DANGLES: T.Tensor([B, S, H, N//rotary_dim_divisor], T.float32), # type: ignore D: T.Tensor([H], T.float32), # type: ignore DD: T.Tensor([B, H], T.float32), # type: ignore QK_DOT: T.Tensor([B, H, S, R, R], dtype), # type: ignore # DQK_DOT: T.Tensor([B, H, S, R, R], dtype), # type: ignore DDA: T.Tensor([B, H, S], T.float32), # type: ignore DSSDA: T.Tensor([B, H, nchunks, chunk_size, chunk_size], T.float32), # type: ignore DDA_CS_REV: T.Tensor([B, H, S], T.float32), # type: ignore DDA_CS: T.Tensor([B, H, S], T.float32), # type: ignore SEGSUM: T.Tensor([B, H, nchunks, chunk_size, chunk_size], T.float32), # type: ignore ): """ Overview: Reverse-chunk backward pass that consumes cached STATES and QK_DOT from the first pass to produce gradients for the fused Mamba3 attention block. Inputs: - Forward activations/tensors: DOUT, Q, K, V, optional Z, optional D. - Projection weights/biases: Q_BIAS, K_BIAS, MIMO_V (Psi), MIMO_O (Phi), optional MIMO_Z (Zeta). - Cached intermediates: STATES and QK_DOT. 
- Discretization grads and factors: DA_CS, DA_CS_REV, DT, TRAP, DDA, DSSDA, DDA_CS_REV, DDA_CS, and SEGSUM. Outputs: - QKV grads: DQ, DK, DV. - MIMO projection grads: DMIMO_V. - Discretization/rotation grads: DANGLES, DFACTOR, DGAMMA_DIAG, DDA_CS_REV, DDA_CS, DDA. - Additional grads: optional DD. Notation: - Psi: MIMO X projection. - Phi: MIMO O projection. - Zeta: MIMO Z projection. - Trap: convex-combination modulator used in exponential-trapezoidal discretization. """ with T.Kernel(H, B, threads=threads) as (i_h, i_b): # --- Kernel Setup --- # GQA support: map V head to Q/K head i_h_qk = i_h // (H // G) # --- Buffer Allocation --- dstates_shared = T.alloc_shared([N, P], dtype) dstates_frag = T.alloc_fragment([N, P], accum_dtype) dout_shared = T.alloc_shared([chunk_size, P], dtype) dPhiO_shared = T.alloc_shared([fused_chunk_size, P], dtype) q_shared = T.alloc_shared([fused_chunk_size, N], dtype) k_shared = T.alloc_shared([fused_chunk_size, N], dtype) v_shared = T.alloc_shared([chunk_size, P], dtype) states_shared = T.alloc_shared([N, P], dtype) lkq_masked__or__dkq_masked_shared = T.alloc_shared([fused_chunk_size, fused_chunk_size], dtype) dPsiV_combined_shared = T.alloc_shared([fused_chunk_size, P], dtype) dqk_from_diag_shared = T.alloc_shared([fused_chunk_size, fused_chunk_size], accum_dtype) q_pre_rot_shared = T.alloc_shared([fused_chunk_size, N], dtype) k_pre_rot_shared = T.alloc_shared([fused_chunk_size, N], dtype) dk_shared = T.alloc_shared([fused_chunk_size, N], dtype) dq_shared = T.alloc_shared([fused_chunk_size, N], dtype) qk_dot_shared = T.alloc_shared([chunk_size, R, R], dtype) k_pre_trap_shared = T.alloc_shared([fused_chunk_size, N], dtype) dangle_dk__or__dq_shared = T.alloc_shared([fused_chunk_size, N//rotary_dim_divisor], T.float32) # --- Swizzling Annotation --- noswizzle_annot = threads == 256 and (N <= 32 or P >= 128) # NOTE: heuristics for when swizzling annotation causes kernel hang, needs more investigation if not noswizzle_annot: 
T.annotate_layout({ dstates_shared: tilelang.layout.make_swizzled_layout(dstates_shared), dout_shared: tilelang.layout.make_swizzled_layout(dout_shared), q_shared: tilelang.layout.make_swizzled_layout(q_shared), k_shared: tilelang.layout.make_swizzled_layout(k_shared), v_shared: tilelang.layout.make_swizzled_layout(v_shared), states_shared: tilelang.layout.make_swizzled_layout(states_shared), lkq_masked__or__dkq_masked_shared: tilelang.layout.make_swizzled_layout(lkq_masked__or__dkq_masked_shared), dPsiV_combined_shared: tilelang.layout.make_swizzled_layout(dPsiV_combined_shared), dqk_from_diag_shared: tilelang.layout.make_swizzled_layout(dqk_from_diag_shared), k_pre_rot_shared: tilelang.layout.make_swizzled_layout(k_pre_rot_shared), q_pre_rot_shared: tilelang.layout.make_swizzled_layout(q_pre_rot_shared), dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), dq_shared: tilelang.layout.make_swizzled_layout(dq_shared), k_pre_trap_shared: tilelang.layout.make_swizzled_layout(k_pre_trap_shared), dangle_dk__or__dq_shared: tilelang.layout.make_swizzled_layout(dangle_dk__or__dq_shared), }) T.use_swizzle(10, "row") T.no_set_max_nreg() # --- Per-Head Constants / Running State --- T.clear(dstates_frag) T.clear(dstates_shared) if reduceO: Phi_frag = T.alloc_fragment([R, P], dtype) T.copy(MIMO_O[i_h, :, :], Phi_frag) Psi_frag = T.alloc_fragment([R, P], dtype) T.copy(MIMO_V[i_h, :, :], Psi_frag) dPsi_acc = T.alloc_fragment([R, P], accum_dtype) # TODO T.clear(dPsi_acc) if hasD: dD_frag = T.alloc_fragment([1], accum_dtype) T.clear(dD_frag) q_bias_frag = T.alloc_fragment([R, N], dtype) k_bias_frag = T.alloc_fragment([R, N], dtype) T.copy(Q_BIAS[i_h, :, :], q_bias_frag) T.copy(K_BIAS[i_h, :, :], k_bias_frag) # --- Reverse Chunk Loop --- for chunk_idx_rev in T.Pipelined(0, nchunks, num_stages=num_stages): chunk_idx = nchunks - 1 - chunk_idx_rev chunk_start = chunk_idx * chunk_size fused_chunk_start = chunk_start * R # --- Discretization Factors (Shifted Gamma + Trap Scale) 
--- trap_shifted_frag = T.alloc_fragment([chunk_size], T.float32) dt_shifted_frag = T.alloc_fragment([chunk_size], dtype) for cs in T.Parallel(chunk_size): trap_shifted_frag[cs] = T.if_then_else( chunk_start + cs + 1 < S, TRAP[i_b, i_h, chunk_start + cs + 1], 0.0, ) dt_shifted_frag[cs] = T.if_then_else( chunk_start + cs + 1 < S, DT[i_b, i_h, chunk_start + cs + 1], 0.0, ) shifted_gamma_frag = T.alloc_fragment([chunk_size], dtype) for cs in T.Parallel(chunk_size): shifted_gamma_frag[cs] = T.if_then_else(chunk_start + cs < (S - 1), dt_shifted_frag[cs] * T.sigmoid(-trap_shifted_frag[cs]), 0.0) trap_frag = T.alloc_fragment([chunk_size], T.float32) T.copy(TRAP[i_b, i_h, chunk_start: chunk_start+chunk_size], trap_frag) dt_frag = T.alloc_fragment([chunk_size], dtype) T.copy(DT[i_b, i_h, chunk_start: chunk_start+chunk_size], dt_frag) gamma_frag = T.alloc_fragment([chunk_size], T.float32) for cs in T.Parallel(chunk_size): gamma_frag[cs] = dt_frag[cs] * T.sigmoid(trap_frag[cs]) gamma_cached_frag = T.alloc_fragment([chunk_size], T.float32) T.copy(gamma_frag, gamma_cached_frag) trap_scale_frag = T.alloc_fragment([chunk_size], dtype) for cs in T.Parallel(chunk_size): trap_scale_frag[cs] = gamma_frag[cs] + shifted_gamma_frag[cs] trap_scale_shared = T.alloc_shared([chunk_size], dtype) T.copy(trap_scale_frag, trap_scale_shared) # --- DOUT Projection and Optional Z / D Paths --- dPhiO_frag = T.alloc_fragment([chunk_size, R, P], dtype) if reduceO: for cs, p in T.Parallel(chunk_size, P): dout_shared[cs, p] = DOUT[i_b, chunk_start+cs, i_h, p] for cs, r, p in T.Parallel(chunk_size, R, P): dPhiO_frag[cs, r, p] = dout_shared[cs, p] * Phi_frag[r, p] else: for cs, r, p in T.Parallel(chunk_size, R, P): dPhiO_frag[cs, r, p] = DOUT[i_b, chunk_start + cs, r, i_h, p] if hasZ: ## Backpropagate via *SILU(Z) Zeta_frag = T.alloc_fragment([R, P], dtype) T.copy(MIMO_Z[i_h, :, :], Zeta_frag) z_frag = T.alloc_fragment([chunk_size, P], dtype) T.copy(Z[i_b, chunk_start:chunk_start+chunk_size, i_h, :], 
z_frag) for cs, r, p in T.Parallel(chunk_size, R, P): tmp = z_frag[cs, p] * Zeta_frag[r, p] * 0.5 dPhiO_frag[cs, r, p] *= tmp * T.tanh(tmp) + tmp T.copy(T.view(dPhiO_frag, shape=[fused_chunk_size, P]), dPhiO_shared) T.copy(V[i_b, chunk_start:chunk_start+chunk_size, i_h, :], v_shared) if hasD: # Compute dD via projected DOUT and V/Psi factors. v_dD_frag = T.alloc_fragment([chunk_size, P], accum_dtype) Psi_dD_frag = T.alloc_fragment([R, P], accum_dtype) T.copy(v_shared, v_dD_frag) T.copy(MIMO_V[i_h, :, :], Psi_dD_frag) for cs, r, p in T.Parallel(chunk_size, R, P): dPhiO_frag[cs, r, p] *= v_dD_frag[cs, p] * Psi_dD_frag[r, p] T.reduce_sum(T.view(dPhiO_frag, shape=[fused_chunk_size*P]), dD_frag, clear=False) # --- Prepare Rotated/Scaled QK and Compute dPsiV --- # Load q and apply q_bias to it: for cs, r, n in T.Parallel(chunk_size, R, N): q_shared[cs*R + r, n] = Q[i_b, chunk_start+cs, r, i_h_qk, n] q_frag = T.alloc_fragment([chunk_size, R, N], dtype) for cs, r, n in T.Parallel(chunk_size, R, N): q_frag[cs, r, n] = q_shared[cs*R + r, n] for cs, r, n in T.Parallel(chunk_size, R, N): q_frag[cs, r, n] += q_bias_frag[r, n] for cs, r, n in T.Parallel(chunk_size, R, N): q_shared[cs*R + r, n] = q_frag[cs, r, n] T.copy(q_shared, q_pre_rot_shared) # Save pre-rotated q for later: # Apply rotary to q: q_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) q_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): q_first_half_frag[cs, r, n] = q_shared[cs*R + r, n] q_second_half_frag[cs, r, n] = q_shared[cs*R + r, N//2 + n] angles_frag = T.alloc_fragment([chunk_size, N//rotary_dim_divisor], T.float32) T.copy(ANGLES[i_b, chunk_start:chunk_start+chunk_size, i_h, :], angles_frag) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): q_shared[cs*R + r, n] = T.cos(angles_frag[cs, n]) * q_first_half_frag[cs, r, n] - T.sin(angles_frag[cs, n]) * 
q_second_half_frag[cs, r, n] q_shared[cs*R + r, N//2 + n] = T.sin(angles_frag[cs, n]) * q_first_half_frag[cs, r, n] + T.cos(angles_frag[cs, n]) * q_second_half_frag[cs, r, n] # Load k and apply k_bias to it: k_reshaped_shared = T.view(k_pre_trap_shared, shape=[chunk_size, R, N]) T.copy(K[i_b, chunk_start:chunk_start+chunk_size, :, i_h_qk, :], k_reshaped_shared) k_frag = T.alloc_fragment([chunk_size, R, N], dtype) T.copy(k_reshaped_shared, k_frag) for cs, r, n in T.Parallel(chunk_size, R, N): k_frag[cs, r, n] += k_bias_frag[r, n] T.copy(k_frag, k_reshaped_shared) # Save pre-rotated k for later: for csr, n in T.Parallel(fused_chunk_size, N): k_pre_rot_shared[csr, n] = k_pre_trap_shared[csr, n] # Apply rotary to k: k_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) k_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): k_first_half_frag[cs, r, n] = k_reshaped_shared[cs, r, n] k_second_half_frag[cs, r, n] = k_reshaped_shared[cs, r, N//2 + n] for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): k_reshaped_shared[cs, r, n] = T.cos(angles_frag[cs, n]) * k_first_half_frag[cs, r, n] - T.sin(angles_frag[cs, n]) * k_second_half_frag[cs, r, n] k_reshaped_shared[cs, r, N//2 + n] = T.sin(angles_frag[cs, n]) * k_first_half_frag[cs, r, n] + T.cos(angles_frag[cs, n]) * k_second_half_frag[cs, r, n] # Apply Trap-specific scaling: k_trap_scaled_frag = T.alloc_fragment([fused_chunk_size, N], dtype) T.copy(k_pre_trap_shared, k_trap_scaled_frag) for csr, n in T.Parallel(fused_chunk_size, N): k_trap_scaled_frag[csr, n] *= trap_scale_shared[csr//R] T.copy(k_trap_scaled_frag, k_shared) # Apply the effect of interchunk (state update): dPsiV_frag = T.alloc_fragment([fused_chunk_size, P], accum_dtype) T.gemm(k_shared, dstates_shared, dPsiV_frag, clear_accum=True) dA_cs_rev_frag = T.alloc_fragment([chunk_size], T.float32) dA_cs_rev_shared = 
T.alloc_shared([chunk_size], T.float32) T.copy(DA_CS_REV[i_b, i_h, chunk_start:chunk_start+chunk_size], dA_cs_rev_shared) T.copy(dA_cs_rev_shared, dA_cs_rev_frag) for csr, p in T.Parallel(fused_chunk_size, P): # DA_CS_REV scales per-step state contribution into dPsiV. dPsiV_frag[csr, p] *= T.exp(dA_cs_rev_frag[csr//R]) # Apply the effect of intrachunk: lkq_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], accum_dtype) T.gemm(k_shared, q_shared, lkq_frag, transpose_B=True, clear_accum=True) T.copy(lkq_frag, lkq_masked__or__dkq_masked_shared) # NOTE: Save later for the computation of DSSDA, using lkq_masked__or__dkq_masked_shared which has the same shape if R == 1: # More smem efficient which is necessary for R=1, but slower due to the need for casting lkq_masked_dtype_buf = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype) T.copy(lkq_masked__or__dkq_masked_shared, lkq_masked_dtype_buf) for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size): # Reverse-causal mask for backward flow across chunk steps. lkq_masked_dtype_buf[csr_i, csr_j] = T.if_then_else( csr_i//R < csr_j//R, lkq_masked_dtype_buf[csr_i, csr_j] * T.exp(SEGSUM[i_b, i_h, chunk_idx, csr_j//R, csr_i//R]), 0.0 ) else: for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size): # Reverse-causal mask for backward flow across chunk steps. 
lkq_frag[csr_i, csr_j] = T.if_then_else( csr_i//R < csr_j//R, lkq_frag[csr_i, csr_j] * T.exp(SEGSUM[i_b, i_h, chunk_idx, csr_j//R, csr_i//R]), 0.0 ) lkq_masked_dtype_buf = T.alloc_shared([fused_chunk_size, fused_chunk_size], dtype) T.copy(lkq_frag, lkq_masked_dtype_buf) # Convert to dtype T.gemm(lkq_masked_dtype_buf, dPhiO_shared, dPsiV_frag, clear_accum=False) # --- Add Diagonal Contributions to dPsiV (D and qk_dot) --- dPsiV_D_fused_frag = T.alloc_fragment([fused_chunk_size, P], accum_dtype) if hasD: D_frag = T.alloc_var(T.float32) T.copy(D[i_h], D_frag) for csr, p in T.Parallel(fused_chunk_size, P): dPsiV_D_fused_frag[csr, p] = dPsiV_frag[csr, p] + dPhiO_shared[csr, p]*D_frag else: T.copy(dPsiV_frag, dPsiV_D_fused_frag) # Compute the contribution from the qk_dot term: # NOTE: recomputing qk_dot here is much slower than just loading from # the result of the bwd_fwd kernel qk_dot_frag = T.alloc_fragment([chunk_size, R, R], dtype) T.copy(QK_DOT[i_b, i_h, chunk_start:chunk_start+chunk_size, :, :], qk_dot_shared) T.copy(qk_dot_shared, qk_dot_frag) gamma_dPsiV_frag = T.alloc_fragment([chunk_size], dtype) T.copy(gamma_frag, gamma_dPsiV_frag) for csr, p in T.Parallel(fused_chunk_size, P): cs = csr // R r_in = csr % R for r_out in T.serial(R): csr_out = cs * R + r_out dPsiV_D_fused_frag[csr, p] += dPhiO_shared[csr_out, p] * qk_dot_frag[cs, r_out, r_in] * gamma_dPsiV_frag[cs] T.copy(dPsiV_D_fused_frag, dPsiV_combined_shared) # --- Compute dV and dPsi from dPsiV --- # Compute dV dv_frag = T.alloc_fragment([chunk_size, P], dtype) T.clear(dv_frag) for cs, p in T.Parallel(chunk_size, P): for r in T.serial(R): dv_frag[cs, p] += dPsiV_combined_shared[cs*R + r, p] * Psi_frag[r, p] T.copy(dv_frag, DV[i_b, chunk_start:chunk_start+chunk_size, i_h, :]) dPsi_frag = T.alloc_fragment([R, P], accum_dtype) T.copy(dPsi_acc, dPsi_frag) v_frag = T.alloc_fragment([chunk_size, P], accum_dtype) T.copy(v_shared, v_frag) for r, p in T.Parallel(R, P): for cs in T.serial(chunk_size): dPsi_frag[r, 
p] += dPsiV_combined_shared[cs*R + r, p] * v_frag[cs, p] T.copy(dPsi_frag, dPsi_acc) # Compute Psi_V PsiV_frag = T.alloc_fragment([chunk_size, R, P], dtype) T.clear(PsiV_frag) for cs, p in T.Parallel(chunk_size, P): for r in T.serial(R): PsiV_frag[cs, r, p] += v_frag[cs, p] * Psi_frag[r, p] # NOTE: Tilelang unable to perform gemm with reshaped PsiV_frag # so have to copy to smem PsiV_shared = T.alloc_shared([fused_chunk_size, P], dtype) for cs, r, p in T.Parallel(chunk_size, R, P): PsiV_shared[cs*R + r, p] = PsiV_frag[cs, r, p] # Compute dqk_from_diag, which is the contribution to dQ/dK from qk_dot: dqk_from_diag_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], accum_dtype) T.gemm(dPhiO_shared, PsiV_shared, dqk_from_diag_frag, transpose_B=True, clear_accum=True) # (cs*r_out, cs*r_in) # Compute dgamma_diag dgamma_diag_prereduce_frag = T.alloc_fragment([chunk_size, R, R], accum_dtype) T.copy(qk_dot_shared, dgamma_diag_prereduce_frag) T.copy(dqk_from_diag_frag, dqk_from_diag_shared) for cs, r_out, r_in in T.Parallel(chunk_size, R, R): dgamma_diag_prereduce_frag[cs, r_out, r_in] *= dqk_from_diag_shared[cs*R + r_out, cs*R + r_in] dgamma_diag_reduced_frag = T.alloc_fragment([chunk_size], accum_dtype) T.reduce_sum( T.view(dgamma_diag_prereduce_frag, shape=[chunk_size, R*R]), dgamma_diag_reduced_frag, dim=-1, clear=True ) T.copy(dgamma_diag_reduced_frag, DGAMMA_DIAG[i_b, i_h, chunk_start:chunk_start+chunk_size]) # Apply shifted gamma to dqk: gamma_qk_frag = T.alloc_fragment([chunk_size], accum_dtype) T.copy(gamma_cached_frag, gamma_qk_frag) # Apply shifted gamma for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size): dqk_from_diag_frag[csr_i, csr_j] *= gamma_qk_frag[csr_i//R] T.copy(dqk_from_diag_frag, dqk_from_diag_shared) # --- dK Path + ddA Terms --- dk_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype) T.gemm(PsiV_shared, dstates_shared, dk_frag, transpose_B=True, clear_accum=True) # Compute contribution to ddA from KV part of state 
update (part 1 of 4) ddA_state_kv_prereduce_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype) T.copy(k_shared, ddA_state_kv_prereduce_frag) for csr, n in T.Parallel(fused_chunk_size, N): ddA_state_kv_prereduce_frag[csr, n] *= dk_frag[csr, n] ddA_state_kv_prereduce_frag_reshaped = T.view(ddA_state_kv_prereduce_frag, shape=[chunk_size, R*N]) ddA_state_kv_frag = T.alloc_fragment([chunk_size], accum_dtype) T.reduce_sum(ddA_state_kv_prereduce_frag_reshaped, ddA_state_kv_frag, dim=-1, clear=True) T.copy(ddA_state_kv_frag, DDA_CS_REV[i_b, i_h, chunk_start:chunk_start+chunk_size]) # Interchunk path uses k_scaled * exp(dA_cs_rev) in forward, # so apply exp(dA_cs_rev) to the interchunk dk term only. dA_cs_rev_dk_frag = T.alloc_fragment([chunk_size], T.float32) T.copy(dA_cs_rev_shared, dA_cs_rev_dk_frag) for cs in T.Parallel(chunk_size): dA_cs_rev_dk_frag[cs] = T.exp(dA_cs_rev_dk_frag[cs]) for csr, n in T.Parallel(fused_chunk_size, N): dk_frag[csr, n] *= dA_cs_rev_dk_frag[csr//R] dk_intrachunk_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], accum_dtype) T.gemm(PsiV_shared, dPhiO_shared, dk_intrachunk_frag, transpose_B=True, clear_accum=True) # Compute contribution to ddA from intrachunk (part 2 of 4) kq_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype) T.copy(lkq_masked__or__dkq_masked_shared, kq_frag) for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size): kq_frag[csr_i, csr_j] *= dk_intrachunk_frag[csr_i, csr_j] kq_frag_reshaped = T.view(kq_frag, shape=[fused_chunk_size, chunk_size, R]) interchunk_dda_prereduce_frag = T.alloc_fragment([fused_chunk_size, chunk_size], accum_dtype) T.reduce_sum(kq_frag_reshaped, interchunk_dda_prereduce_frag, dim=-1, clear=True) interchunk_dda_prereduce_frag_reshaped = T.view(interchunk_dda_prereduce_frag, shape=[chunk_size, R, chunk_size]) interchunk_dda_frag = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) T.reduce_sum(interchunk_dda_prereduce_frag_reshaped, interchunk_dda_frag, 
dim=1, clear=True) T.copy(interchunk_dda_frag, DSSDA[i_b, i_h, chunk_idx, :, :]) for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size): # Reverse-causal mask for intrachunk gradient flow. dk_intrachunk_frag[csr_i, csr_j] = T.if_then_else( csr_i//R < csr_j//R, dk_intrachunk_frag[csr_i, csr_j] * T.exp(SEGSUM[i_b, i_h, chunk_idx, csr_j//R, csr_i//R]), 0.0 ) T.copy(dk_intrachunk_frag, lkq_masked__or__dkq_masked_shared) # denote lkq_masked__or__dkq_masked_shared as dkq_intrachunk T.copy(dk_frag, dk_shared) dk_nodiag_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype) T.copy(dk_shared, dk_nodiag_frag) T.gemm(lkq_masked__or__dkq_masked_shared, q_shared, dk_nodiag_frag, clear_accum=False) # Adding dk_interchunk to dkq_intrachunk @ q # Compute dfactor, using dk_nodiag_frag: k_factor_frag = T.alloc_fragment([chunk_size, R, N], accum_dtype) T.copy(k_pre_trap_shared, T.view(k_factor_frag, shape=[fused_chunk_size, N])) dfactor_prereduce_frag = T.alloc_fragment([chunk_size, R, N], accum_dtype) for cs, r, n in T.Parallel(chunk_size, R, N): dfactor_prereduce_frag[cs, r, n] = k_factor_frag[cs, r, n] * dk_nodiag_frag[cs*R + r, n] dfactor_frag = T.alloc_fragment([chunk_size], accum_dtype) T.reduce_sum(T.view(dfactor_prereduce_frag, shape=[chunk_size, R*N]), dfactor_frag, dim=-1, clear=True) T.copy(dfactor_frag, DFACTOR[i_b, i_h, chunk_start:chunk_start+chunk_size]) # Account for the effect of trap_scale = gamma + shifted_gamma: trap_scale_dk_frag = T.alloc_fragment([chunk_size], dtype) T.copy(trap_scale_shared, trap_scale_dk_frag) for csr, n in T.Parallel(fused_chunk_size, N): dk_nodiag_frag[csr, n] *= trap_scale_dk_frag[csr//R] T.copy(dk_nodiag_frag, dk_shared) # --- State-Passing ddA Terms + Interchunk dQ --- T.copy(STATES[i_b, i_h, chunk_idx, :, :], states_shared) # Load cached states from bwd_fwd # NOTE: Compute the contribution of state passing (part 3 of 4) states_frag = T.alloc_fragment([N, P], T.float32) T.copy(states_shared, states_frag) 
ddA_state_passing = T.alloc_fragment([1], T.float32) ddA_state_passing_prereduce_frag = T.alloc_fragment([N, P], T.float32) da_cs_sum = T.alloc_var(T.float32) T.copy(DA_CS[i_b, i_h, chunk_start+chunk_size-1], da_cs_sum) for n, p in T.Parallel(N, P): ddA_state_passing_prereduce_frag[n, p] = ( states_frag[n, p] * dstates_frag[n, p] * T.exp(da_cs_sum) ) T.reduce_sum( T.view(ddA_state_passing_prereduce_frag, shape=[N*P]), ddA_state_passing, dim=-1, clear=True, ) dda_frag = T.alloc_fragment([chunk_size,], T.float32) for cs in T.Parallel(chunk_size): dda_frag[cs] = ddA_state_passing[0] T.copy(dda_frag, DDA[i_b, i_h, chunk_start:chunk_start+chunk_size]) dq_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype) T.gemm(dPhiO_shared, states_shared, dq_frag, transpose_B=True, clear_accum=True) # NOTE: Compute the contribution to ddA from applying it to q*state (part 4 of 4) dda_cs_prereduce_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype) T.copy(q_shared, dda_cs_prereduce_frag) for csr, n in T.Parallel(fused_chunk_size, N): dda_cs_prereduce_frag[csr, n] *= dq_frag[csr, n] dda_cs_frag = T.alloc_fragment([chunk_size], accum_dtype) T.reduce_sum(T.view(dda_cs_prereduce_frag, shape=[chunk_size, R*N]), dda_cs_frag, dim=-1, clear=True) T.copy(dda_cs_frag, DDA_CS[i_b, i_h, chunk_start:chunk_start+chunk_size]) dA_cs_dq_frag = T.alloc_fragment([chunk_size], T.float32) dA_cs_shared = T.alloc_shared([chunk_size], T.float32) T.copy(DA_CS[i_b, i_h, chunk_start:chunk_start+chunk_size], dA_cs_shared) T.copy(dA_cs_shared, dA_cs_dq_frag) for csr, n in T.Parallel(fused_chunk_size, N): # DA_CS scales interchunk q-state contribution in backward. 
dq_frag[csr, n] *= T.exp(dA_cs_dq_frag[csr//R]) # NOTE: Unable to reuse dk_intrachunk_frag_dtype due to layout issue # (we do gemm with the transpose of dk_intrachunk_frag_dtype) T.copy(dq_frag, dq_shared) dq_combined_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype) T.copy(dq_shared, dq_combined_frag) T.gemm(lkq_masked__or__dkq_masked_shared, k_shared, dq_combined_frag, transpose_A=True, clear_accum=False) T.copy(dq_combined_frag, dq_shared) # --- Inverse Rotary for dK and dQ + dAngles --- angles_dk_frag = T.alloc_fragment([chunk_size, N//rotary_dim_divisor], T.float32) T.copy(ANGLES[i_b, chunk_start:chunk_start+chunk_size, i_h, :], angles_dk_frag) dk_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) dk_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) k_prerot_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) k_prerot_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): dk_first_half_frag[cs, r, n] = dk_shared[cs*R + r, n] dk_second_half_frag[cs, r, n] = dk_shared[cs*R + r, N//2 + n] k_prerot_first_half_frag[cs, r, n] = k_pre_rot_shared[cs*R + r, n] k_prerot_second_half_frag[cs, r, n] = k_pre_rot_shared[cs*R + r, N//2 + n] # Compute the contribution of dk to dangle: dangle_dk_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], T.float32) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): dangle_dk_frag[cs, r, n] = dk_first_half_frag[cs, r, n] * (-k_prerot_first_half_frag[cs, r, n] * T.sin(angles_dk_frag[cs, n]) - k_prerot_second_half_frag[cs, r, n] * T.cos(angles_dk_frag[cs, n])) +\ dk_second_half_frag[cs, r, n] * (k_prerot_first_half_frag[cs, r, n] * T.cos(angles_dk_frag[cs, n]) - k_prerot_second_half_frag[cs, r, n] * T.sin(angles_dk_frag[cs, n])) T.copy(T.view(dangle_dk_frag, shape=[fused_chunk_size, N//rotary_dim_divisor]), 
dangle_dk__or__dq_shared) # Rotate dk_shared: for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): dk_shared[cs*R + r, n] = T.cos(angles_dk_frag[cs, n]) * dk_first_half_frag[cs, r, n] + T.sin(angles_dk_frag[cs, n]) * dk_second_half_frag[cs, r, n] dk_shared[cs*R + r, N//2 + n] = -T.sin(angles_dk_frag[cs, n]) * dk_first_half_frag[cs, r, n] + T.cos(angles_dk_frag[cs, n]) * dk_second_half_frag[cs, r, n] dk_combined_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype) T.copy(dk_shared, dk_combined_frag) # Compute the effect of dqk_from_diag q_dk_frag = T.alloc_fragment([fused_chunk_size, N], accum_dtype) # Keeping q_dk_frag in accum_dtype to avoid casting instructions T.copy(q_pre_rot_shared, q_dk_frag) # NOTE: we need to use the pre-rotated version of q q_dk_frag_reshaped = T.view(q_dk_frag, [chunk_size, R, N]) for csr_in, n in T.Parallel(fused_chunk_size, N): cs = csr_in // R for r_out in T.serial(R): csr_out = cs*R + r_out dk_combined_frag[csr_in, n] += dqk_from_diag_shared[csr_out, csr_in] * q_dk_frag_reshaped[cs, r_out, n] # Copy to gmem: T.copy(dk_combined_frag, DK[i_b, fused_chunk_start:fused_chunk_start+fused_chunk_size, i_h, :]) angles_dq_frag = T.alloc_fragment([chunk_size, N//rotary_dim_divisor], T.float32) T.copy(ANGLES[i_b, chunk_start:chunk_start+chunk_size, i_h, :], angles_dq_frag) dq_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) dq_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): dq_first_half_frag[cs, r, n] = dq_shared[cs*R + r, n] dq_second_half_frag[cs, r, n] = dq_shared[cs*R + r, N//2 + n] # Compute the contribution of dq to dangle: q_prerot_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) q_prerot_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): 
q_prerot_first_half_frag[cs, r, n] = q_pre_rot_shared[cs*R + r, n] q_prerot_second_half_frag[cs, r, n] = q_pre_rot_shared[cs*R + r, N//2 + n] dangle_dq_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], T.float32) T.copy(dangle_dk__or__dq_shared, T.view(dangle_dq_frag, shape=[fused_chunk_size, N//rotary_dim_divisor])) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): dangle_dq_frag[cs, r, n] += dq_first_half_frag[cs, r, n] * (-q_prerot_first_half_frag[cs, r, n] * T.sin(angles_dq_frag[cs, n]) - q_prerot_second_half_frag[cs, r, n] * T.cos(angles_dq_frag[cs, n])) +\ dq_second_half_frag[cs, r, n] * (q_prerot_first_half_frag[cs, r, n] * T.cos(angles_dq_frag[cs, n]) - q_prerot_second_half_frag[cs, r, n] * T.sin(angles_dq_frag[cs, n])) # Sum dangle across R, and copy to gmem dangle_frag_reduced = T.alloc_fragment([chunk_size, N//rotary_dim_divisor], T.float32) T.clear(dangle_frag_reduced) for cs, n in T.Parallel(chunk_size, N//rotary_dim_divisor): for r in T.serial(R): dangle_frag_reduced[cs, n] += dangle_dq_frag[cs, r, n] T.copy(dangle_frag_reduced, DANGLES[i_b, chunk_start:chunk_start+chunk_size, i_h, :]) # Rotate dq_shared: for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): dq_shared[cs*R + r, n] = T.cos(angles_dk_frag[cs, n]) * dq_first_half_frag[cs, r, n] + T.sin(angles_dk_frag[cs, n]) * dq_second_half_frag[cs, r, n] dq_shared[cs*R + r, N//2 + n] = -T.sin(angles_dk_frag[cs, n]) * dq_first_half_frag[cs, r, n] + T.cos(angles_dk_frag[cs, n]) * dq_second_half_frag[cs, r, n] T.copy(dq_shared, dq_frag) # Compute the effect of dqk_from_diag for csr_out, n in T.Parallel(fused_chunk_size, N): cs = csr_out // R for r_in in T.serial(R): csr_in = cs*R + r_in dq_frag[csr_out, n] += dqk_from_diag_shared[csr_out, csr_in] * k_pre_rot_shared[csr_in, n] # Copy to gmem: T.copy(dq_frag, DQ[i_b, fused_chunk_start:fused_chunk_start+fused_chunk_size, i_h, :]) # --- Update Reverse-Passed State Gradient --- da_cs_sum_dstates = 
T.alloc_var(T.float32) T.copy(DA_CS[i_b, i_h, chunk_start+chunk_size-1], da_cs_sum_dstates) for n, p in T.Parallel(N, P): dstates_frag[n, p] *= T.exp(da_cs_sum_dstates) dPhiO_scaled_frag = T.alloc_fragment([fused_chunk_size, P], dtype) T.copy(dPhiO_shared, dPhiO_scaled_frag) dA_cs_dPhiO_frag = T.alloc_fragment([chunk_size], T.float32) T.copy(dA_cs_shared, dA_cs_dPhiO_frag) for csr, p in T.Parallel(fused_chunk_size, P): # DA_CS applies chunk-level decay to the passed gradient state. dPhiO_scaled_frag[csr, p] *= T.exp(dA_cs_dPhiO_frag[csr//R]) T.gemm(q_shared, dPhiO_scaled_frag, dstates_frag, transpose_A=True, clear_accum=False) T.copy(dstates_frag, dstates_shared) T.copy(dPsi_acc, DMIMO_V[i_b, i_h, :, :]) if hasD: T.copy(dD_frag, DD[i_b, i_h]) return mamba_mimo_bwd_bwd_kernel def mamba_mimo_bwd_combined( dout, q, k, v, q_bias, k_bias, mimo_v, mimo_o, z, mimo_z, angles, dA_cs, dA_cs_rev, dt, trap, D, segsum, chunk_size, rotary_dim_divisor, dtype, bf_threads=128, bf_num_stages=0, bb_threads=256, bb_num_stages=0, ): # TileLang kernel expects contiguous last-dim strides for DOUT. 
B, S, R, G, N = q.shape H, P = v.shape[-2], v.shape[-1] reduceO = mimo_o is not None dmimo_o = torch.empty([B, H, R, P], dtype=mimo_v.dtype, device=mimo_v.device) if reduceO else None states = torch.empty([B, H, S//chunk_size, N, P], dtype=v.dtype, device=v.device) # NOTE: states dtype is set to v.dtype if z is not None: dz_tilelang = torch.empty_like(v) dmimo_z = torch.empty([B, H, R, P], dtype=mimo_v.dtype, device=mimo_v.device) else: dz_tilelang = None dmimo_z = None qk_dot = torch.zeros([B, H, S, R, R], dtype=q.dtype, device=q.device) if isinstance(dtype, torch.dtype): dtype_str = str(dtype).replace("torch.", "") else: dtype_str = dtype bwd_fwd_kernel = mamba_mimo_bwd_fwd(B, S, H, G, N, P, R, z is not None, D is not None, reduceO, chunk_size, rotary_dim_divisor, dtype_str, bf_threads, bf_num_stages) bwd_fwd_kernel( dout, q, k, v, q_bias, k_bias, mimo_v, mimo_o, dmimo_o, states, z, mimo_z, dz_tilelang, dmimo_z, angles, dA_cs, dA_cs_rev, dt, trap, D, qk_dot, segsum, ) if reduceO: dmimo_o = dmimo_o.sum(dim=0) dq_tilelang = torch.empty([B, S, R, H, N], dtype=q.dtype, device=q.device) dk_tilelang = torch.empty([B, S, R, H, N], dtype=k.dtype, device=k.device) dv_tilelang = torch.empty_like(v) dmimo_v = torch.empty([B, H, R, P], dtype=mimo_v.dtype, device=mimo_v.device) dD = torch.empty([B, H], dtype=D.dtype, device=D.device) if D is not None else None dangles = torch.zeros([B, S, H, N//rotary_dim_divisor], dtype=angles.dtype, device=angles.device) dfactor = torch.zeros([B, H, S], dtype=torch.float32, device=trap.device) dgamma_diag = torch.zeros([B, H, S], dtype=torch.float32, device=trap.device) ddA = torch.zeros([B, H, S], dtype=torch.float32, device=dt.device) dSSdA = torch.zeros([B, H, S//chunk_size, chunk_size, chunk_size], dtype=torch.float32, device=dt.device) ddA_cs_rev = torch.zeros([B, H, S], dtype=torch.float32, device=dt.device) ddA_cs = torch.zeros([B, H, S], dtype=torch.float32, device=dt.device) bwd_bwd_kernel = mamba_mimo_bwd_bwd(B, S, H, G, N, P, R, 
z is not None, D is not None, reduceO, chunk_size, rotary_dim_divisor, dtype_str, bb_threads, bb_num_stages) bwd_bwd_kernel( dout, q, k, v, q_bias, k_bias, mimo_v, mimo_o, dk_tilelang.view(B, S*R, H, N), dv_tilelang, dmimo_v, states, dq_tilelang.view(B, S*R, H, N), z, mimo_z, angles, dA_cs, dA_cs_rev, dt, trap, dfactor, dgamma_diag, dangles, D, dD, qk_dot, ddA, dSSdA, ddA_cs_rev, ddA_cs, segsum, ) if G == 1: dq_bias_tilelang = dq_tilelang.sum(dim=(0, 1)).permute((1, 0, 2)) dk_bias_tilelang = dk_tilelang.sum(dim=(0, 1)).permute((1, 0, 2)) dq_tilelang = dq_tilelang.sum(dim=3, keepdim=True) dk_tilelang = dk_tilelang.sum(dim=3, keepdim=True) dmimo_v = dmimo_v.sum(dim=0) dmimo_z = dmimo_z.sum(dim=0) if dmimo_z is not None else None dD = dD.sum(dim=0) if dD is not None else None elif G == H: dq_bias_tilelang = dq_tilelang.sum(dim=(0, 1)).permute((1, 0, 2)) dk_bias_tilelang = dk_tilelang.sum(dim=(0, 1)).permute((1, 0, 2)) dmimo_v = dmimo_v.sum(dim=0) dmimo_z = dmimo_z.sum(dim=0) if dmimo_z is not None else None dD = dD.sum(dim=0) if dD is not None else None else: raise ValueError(f"G value of {G} is not currently supported!") ddt, dtrap = bwd_dtrap_ddt_triton( trap, dt, dfactor, dgamma_diag, chunk_size ) ddA += bwd_dadt_fused_triton( dSSdA, segsum, ddA_cs, ddA_cs_rev, dA_cs, dA_cs_rev, chunk_size ) return (dq_tilelang, dk_tilelang, dv_tilelang, ddA, ddt, dtrap, dq_bias_tilelang, dk_bias_tilelang, dmimo_v, dmimo_z, dmimo_o, dangles, dD, dz_tilelang) ================================================ FILE: mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py ================================================ """ Tilelang implementation of Mamba3 forward kernel, with MIMO support. 
Copyright (c) 2026, Dao AI Lab, Goombalab """ import torch import tilelang import tilelang.language as T from tilelang.profiler import do_bench from tilelang.autotuner import autotune import itertools import argparse from typing import Optional, Tuple # NOTE: Uncomment the following to autotune: # def get_configs(): # iter_params = dict(num_stages=[0, 1, 2, 3], threads=[128, 256, 512]) # # iter_params = dict(num_stages=[2], threads=[128]) # return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] # @autotune( # configs=get_configs(), # warmup=3, # rep=20, # ) @tilelang.jit( out_idx=[], pass_configs={ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, }) def mamba_mimo_fwd( B, S, H, G, N, P, R, hasZ, hasD, reduceO, return_final_state=False, chunk_size: int = 16, rotary_dim_divisor = 4, dtype: str = 'bfloat16', threads: int = 128, num_stages: int = 0, ) -> torch.Tensor: accum_dtype = 'float32' # Block sizes for K and V dimensions - use full dimensions (no tiling) assert S % chunk_size == 0, "Sequence length must be divisible by chunk_size" nchunks = tilelang.cdiv(S, chunk_size) fused_chunk_size = chunk_size * R if reduceO: O_shape = (B, S, H, P) else: O_shape = (B, S, R, H, P) @T.prim_func def mamba_mimo_fwd_kernel( Q: T.Tensor([B, S, R, G, N], dtype), # type: ignore K: T.Tensor([B, S, R, G, N], dtype), # type: ignore V: T.Tensor([B, S, H, P], dtype), # type: ignore O: T.Tensor(O_shape, dtype), # type: ignore Q_BIAS: T.Tensor([H, R, N], T.float32), # type: ignore K_BIAS: T.Tensor([H, R, N], T.float32), # type: ignore MIMO_V: T.Tensor([H, R, P], T.float32), # type: ignore MIMO_O: T.Tensor([H, R, P], T.float32), # type: ignore Z: T.Tensor([B, S, H, P], dtype), # type: ignore D: T.Tensor([H], T.float32), # type: ignore MIMO_Z: T.Tensor([H, R, P], T.float32), # type: ignore ANGLES: T.Tensor([B, S, H, N//rotary_dim_divisor], 
T.float32), # type: ignore DA_CS: T.Tensor([B, H, S], T.float32), # type: ignore DA_CS_REV: T.Tensor([B, H, S], T.float32), # type: ignore DT: T.Tensor([B, H, S], T.float32), # type: ignore TRAP: T.Tensor([B, H, S], dtype), # type: ignore SEGSUM: T.Tensor([B, H, nchunks, chunk_size, chunk_size], T.float32), # type: ignore FINAL_STATE: T.Tensor([B, H, N, P], T.float32), # type: ignore FINAL_K: T.Tensor([B, R, H, N], dtype) # type: ignore ): """ Overview: Fused chunked forward pass that combines MIMO projections with recurrent state updates. Computes interchunk and intrachunk contributions with optional D and Z paths, then writes output activations. Inputs: - Activations: Q, K, V. - Projection parameters/biases: MIMO_V (Psi), MIMO_O (Phi), optional MIMO_Z (Zeta), ANGLES, and Q_BIAS/K_BIAS. - Optional modifiers: Z, and D. - Discretization tensors: DA_CS, DA_CS_REV, DT, TRAP, and SEGSUM. Outputs: - O: fused forward output activations. - FINAL_STATE: final recurrent states (if return_state is True). - FINAL_K: final K tensor (if return_state is True, for use in decode) Notation: - Psi: MIMO X projection. - Phi: MIMO O projection. - Zeta: MIMO Z projection. - Trap: convex-combination modulator used in exponential-trapezoidal discretization. 
""" with T.Kernel(H, B, threads=threads) as (i_h, i_b): # --- Kernel Setup --- # GQA support: map V head to Q/K head i_h_qk = i_h // (H // G) # --- Buffer Allocation --- q_shared = T.alloc_shared([fused_chunk_size, N], dtype) k_shared = T.alloc_shared([fused_chunk_size, N], dtype) q_bias_frag = T.alloc_fragment([R, N], dtype) k_bias_frag = T.alloc_fragment([R, N], dtype) angles_shared = T.alloc_shared([chunk_size, N], dtype) PsiV_shared = T.alloc_shared([fused_chunk_size, P], dtype) qs_shared = T.alloc_shared([fused_chunk_size, P], dtype) o_shared = T.alloc_shared([chunk_size, P], dtype) v_shared = T.alloc_shared([chunk_size, P], dtype) states_accum_cast_shared = T.alloc_shared([N, P], dtype) qk_intrachunk_shared = T.alloc_shared([fused_chunk_size, fused_chunk_size], dtype) qk_dot_full_shared = T.alloc_shared([fused_chunk_size, fused_chunk_size], dtype) # --- Swizzling Annotation --- T.annotate_layout({ q_shared: tilelang.layout.make_swizzled_layout(q_shared), k_shared: tilelang.layout.make_swizzled_layout(k_shared), v_shared: tilelang.layout.make_swizzled_layout(v_shared), angles_shared: tilelang.layout.make_swizzled_layout(angles_shared), PsiV_shared: tilelang.layout.make_swizzled_layout(PsiV_shared), qs_shared: tilelang.layout.make_swizzled_layout(qs_shared), o_shared: tilelang.layout.make_swizzled_layout(o_shared), states_accum_cast_shared: tilelang.layout.make_swizzled_layout(states_accum_cast_shared), qk_dot_full_shared: tilelang.layout.make_swizzled_layout(qk_dot_full_shared), qk_intrachunk_shared: tilelang.layout.make_swizzled_layout(qk_intrachunk_shared), }) T.use_swizzle(10, "row") T.no_set_max_nreg() # --- Per-Head Constants / Running State --- states_frag = T.alloc_fragment([N, P], accum_dtype) T.clear(states_frag) phi_frag_intrachunk = T.alloc_fragment([R, P], dtype=dtype) if reduceO: T.copy(MIMO_O[i_h, :, :], phi_frag_intrachunk) Psi_frag = T.alloc_fragment([R, P], dtype) T.copy(MIMO_V[i_h, :, :], Psi_frag) T.copy(Q_BIAS[i_h, :, :], q_bias_frag) 
T.copy(K_BIAS[i_h, :, :], k_bias_frag) # --- Chunk Loop --- for i in T.Pipelined(0, nchunks, num_stages=num_stages): chunk_start = i * chunk_size # --- Discretization Factors (Shifted Gamma + Trap Scale) --- trap_shifted_frag = T.alloc_fragment([chunk_size], T.float32) T.copy(TRAP[i_b, i_h, chunk_start+1: chunk_start+chunk_size+1], trap_shifted_frag) dt_shifted_frag = T.alloc_fragment([chunk_size], dtype) T.copy(DT[i_b, i_h, chunk_start+1: chunk_start+chunk_size+1], dt_shifted_frag) shifted_gamma_frag = T.alloc_fragment([chunk_size], dtype) for cs in T.Parallel(chunk_size): shifted_gamma_frag[cs] = T.if_then_else(chunk_start + cs < (S - 1), dt_shifted_frag[cs] * T.sigmoid(-trap_shifted_frag[cs]), 0.0) shifted_gamma_shared = T.alloc_shared([chunk_size], dtype) T.copy(shifted_gamma_frag, shifted_gamma_shared) trap_frag = T.alloc_fragment([chunk_size], T.float32) T.copy(TRAP[i_b, i_h, chunk_start: chunk_start+chunk_size], trap_frag) dt_frag = T.alloc_fragment([chunk_size], dtype) T.copy(DT[i_b, i_h, chunk_start: chunk_start+chunk_size], dt_frag) gamma_frag = T.alloc_fragment([chunk_size], T.float32) for cs in T.Parallel(chunk_size): gamma_frag[cs] = dt_frag[cs] * T.sigmoid(trap_frag[cs]) trap_scale_frag = T.alloc_fragment([chunk_size], dtype) for cs in T.Parallel(chunk_size): trap_scale_frag[cs] = gamma_frag[cs] + shifted_gamma_shared[cs] trap_scale_shared = T.alloc_shared([chunk_size], dtype) T.copy(trap_scale_frag, trap_scale_shared) # --- Up-Project V and Prepare Biased Q/K --- PsiV_frag = T.alloc_fragment([chunk_size, R, P], dtype) for cs, p in T.Parallel(chunk_size, P): v_shared[cs, p] = V[i_b, chunk_start+cs, i_h, p] for cs, r, p in T.Parallel(chunk_size, R, P): PsiV_frag[cs, r, p] = v_shared[cs, p] * Psi_frag[r, p] PsiV_reshaped_frag = T.view(PsiV_frag, shape=[fused_chunk_size, P]) T.copy(PsiV_reshaped_frag, PsiV_shared) q_frag = T.alloc_fragment([chunk_size, R, N], dtype) T.copy(Q[i_b, chunk_start:chunk_start+chunk_size, :, i_h_qk, :], q_frag) for cs, r, n in 
T.Parallel(chunk_size, R, N): q_frag[cs, r, n] += q_bias_frag[r, n] T.copy(T.view(q_frag, shape=[fused_chunk_size, N]), q_shared) k_frag = T.alloc_fragment([chunk_size, R, N], dtype) T.copy(K[i_b, chunk_start:chunk_start+chunk_size, :, i_h_qk, :], k_frag) for cs, r, n in T.Parallel(chunk_size, R, N): k_frag[cs, r, n] += k_bias_frag[r, n] T.copy(T.view(k_frag, shape=[fused_chunk_size, N]), k_shared) # --- Cache Diagonal qk_dot Path --- # Keep full qk_dot in shared memory because we reuse same-step R x R blocks later. qk_dot_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype=accum_dtype) T.gemm(q_shared, k_shared, qk_dot_frag, transpose_B=True, clear_accum=True) T.copy(qk_dot_frag, qk_dot_full_shared) # Option B: extremely slow # qk_dot_frag = T.alloc_fragment([chunk_size, R, R], dtype=accum_dtype) # T.clear(qk_dot_frag) # for cs, r_out, r_in in T.Parallel(chunk_size, R, R): # for n in T.serial(N): # qk_dot_frag[cs, r_out, r_in] += ( # q_frag[cs, r_out, n] * k_frag[cs, r_in, n] # ) # T.copy(T.view(qk_dot_frag, shape=[fused_chunk_size, R]), qk_dot_shared) # NOTE ("option C"): The following fails Tilelang compilation: # qk_predot_frag = T.alloc_fragment([chunk_size, R, R, N], dtype) # for cs, r_out, r_in, n in T.Parallel(chunk_size, R, R, N): # qk_predot_frag[cs, r_out, r_in, n] = q_frag[cs, r_out, n] * k_frag[cs, r_in, n] # qk_dot_frag = T.alloc_fragment([chunk_size, R, R], dtype) # T.reduce_sum(qk_predot_frag, qk_dot_frag, dim=-1, clear=True) # T.copy(T.view(qk_dot_frag, shape=[fused_chunk_size, R]), qk_dot_shared) # --- Rotary Q + Interchunk Contribution --- q_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) q_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): q_first_half_frag[cs, r, n] = q_shared[cs*R + r, n] q_second_half_frag[cs, r, n] = q_shared[cs*R + r, N//2 + n] # NOTE: angles are casted to fp32 for numerical 
stability angles_frag = T.alloc_fragment([chunk_size, N//rotary_dim_divisor], T.float32) T.copy(ANGLES[i_b, chunk_start:chunk_start+chunk_size, i_h, :], angles_frag) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): q_shared[cs*R + r, n] = T.cos(angles_frag[cs, n]) * q_first_half_frag[cs, r, n] - T.sin(angles_frag[cs, n]) * q_second_half_frag[cs, r, n] q_shared[cs*R + r, N//2 + n] = T.sin(angles_frag[cs, n]) * q_first_half_frag[cs, r, n] + T.cos(angles_frag[cs, n]) * q_second_half_frag[cs, r, n] o_mimo_accum_frag = T.alloc_fragment([fused_chunk_size, P], dtype=accum_dtype) T.copy(states_frag, states_accum_cast_shared) T.gemm(q_shared, states_accum_cast_shared, o_mimo_accum_frag, clear_accum=True) # --- Rotary K + Trap Scaling + Intrachunk Contribution --- k_first_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) k_second_half_frag = T.alloc_fragment([chunk_size, R, N//rotary_dim_divisor], dtype) for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): k_first_half_frag[cs, r, n] = k_shared[cs*R + r, n] k_second_half_frag[cs, r, n] = k_shared[cs*R + r, N//2 + n] for cs, r, n in T.Parallel(chunk_size, R, N//rotary_dim_divisor): k_shared[cs*R + r, n] = T.cos(angles_frag[cs, n]) * k_first_half_frag[cs, r, n] - T.sin(angles_frag[cs, n]) * k_second_half_frag[cs, r, n] k_shared[cs*R + r, N//2 + n] = T.sin(angles_frag[cs, n]) * k_first_half_frag[cs, r, n] + T.cos(angles_frag[cs, n]) * k_second_half_frag[cs, r, n] if i == nchunks - 1 and return_final_state: seq_boundary = T.min(chunk_start + chunk_size, S) - chunk_start for csr, n in T.Parallel(fused_chunk_size, N): if csr >= (seq_boundary - 1) * R and csr < seq_boundary * R: # Only copy the last chunk's R rows to FINAL_K FINAL_K[i_b, csr % R, i_h, n] = k_shared[csr, n] k_trap_scaled_frag = T.alloc_fragment([fused_chunk_size, N], dtype) T.copy(k_shared, k_trap_scaled_frag) for csr, n in T.Parallel(fused_chunk_size, N): k_trap_scaled_frag[csr, n] *= trap_scale_shared[csr//R] 
T.copy(k_trap_scaled_frag, k_shared) qk_intrachunk_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype=accum_dtype) T.gemm(q_shared, k_shared, qk_intrachunk_frag, transpose_B=True, clear_accum=True) # Strictly causal mask over chunk steps (exclude same-step diagonal). da_cs__or__exp_da_cs_shared = T.alloc_shared([chunk_size], T.float32) T.copy(DA_CS[i_b, i_h, chunk_start:chunk_start+chunk_size], da_cs__or__exp_da_cs_shared) qk_intrachunk_masked_frag = T.alloc_fragment([fused_chunk_size, fused_chunk_size], dtype=dtype) for csr_i, csr_j in T.Parallel(fused_chunk_size, fused_chunk_size): qk_intrachunk_masked_frag[csr_i, csr_j] = T.if_then_else( csr_i//R > csr_j//R, # NOTE: we do indeed want to exclude the diagonal qk_intrachunk_frag[csr_i, csr_j] * T.exp(SEGSUM[i_b, i_h, i, csr_i//R, csr_j//R]), 0.0 ) # Exponentiate da_cs__or__exp_da_cs_shared so that later usage does not have to: for cs in T.Parallel(chunk_size): da_cs__or__exp_da_cs_shared[cs] = T.exp(da_cs__or__exp_da_cs_shared[cs]) exp_da_cs_frag = T.alloc_fragment([chunk_size], dtype=T.float32) T.copy(da_cs__or__exp_da_cs_shared, exp_da_cs_frag) for csr, p in T.Parallel(fused_chunk_size, P): o_mimo_accum_frag[csr, p] *= exp_da_cs_frag[csr//R] # NOTE: if we gemm with qk_intrachunk_masked_frag the compiler will # error with layout issue if threads != 128: # Copy via shared memory to satisfy layout constraints before GEMM. 
T.copy(qk_intrachunk_masked_frag, qk_intrachunk_shared) # Adding the two intermediate outputs together (interchunk += intrachunk) T.gemm(qk_intrachunk_shared, PsiV_shared, o_mimo_accum_frag, clear_accum=False) # --- Add Diagonal Terms (qk_dot and optional D) --- qkdot_psiv_frag = T.alloc_fragment([chunk_size, R, P], dtype=dtype) T.clear(qkdot_psiv_frag) for cs, r_out, p in T.Parallel(chunk_size, R, P): for r_in in T.serial(R): qkdot_psiv_frag[cs, r_out, p] += qk_dot_full_shared[cs * R + r_out, cs * R + r_in] * PsiV_shared[cs * R + r_in, p] qkdot_psiv_frag[cs, r_out, p] *= gamma_frag[cs] # Apply shifted gamma if hasD: PsiV_D_frag = T.alloc_fragment([chunk_size, R, P], T.float32) for cs, r, p in T.Parallel(chunk_size, R, P): PsiV_D_frag[cs, r, p] = PsiV_shared[cs * R + r, p] D_var = T.alloc_var(T.float32) T.copy(D[i_h], D_var) for cs, r_out, p in T.Parallel(chunk_size, R, P): qkdot_psiv_frag[cs, r_out, p] += D_var * PsiV_D_frag[cs, r_out, p] qkdot_psiv_reshaped_frag = T.view(qkdot_psiv_frag, shape=[fused_chunk_size, P]) for csr, p in T.Parallel(fused_chunk_size, P): o_mimo_accum_frag[csr, p] += qkdot_psiv_reshaped_frag[csr, p] # --- Optional Z Gating + Down-Projection --- if reduceO: if hasZ: z_frag = T.alloc_fragment([chunk_size, P], dtype) T.copy(Z[i_b, chunk_start:chunk_start+chunk_size, i_h, :], z_frag) z_expanded_frag = T.alloc_fragment([chunk_size, R, P], dtype) for cs, r, p in T.Parallel(chunk_size, R, P): # Apply SiLU to z_expanded_frag[cs, r, p]: o_gated = z_frag[cs, p] * MIMO_Z[i_h, r, p] * 0.5 z_expanded_frag[cs, r, p] = o_gated * T.tanh(o_gated) + o_gated lqk_PsiV_reshaped_frag = T.view(o_mimo_accum_frag, shape=[chunk_size, R, P]) if hasZ: for cs, r, p in T.Parallel(chunk_size, R, P): lqk_PsiV_reshaped_frag[cs, r, p] *= phi_frag_intrachunk[r, p] * z_expanded_frag[cs, r, p] else: for cs, r, p in T.Parallel(chunk_size, R, P): lqk_PsiV_reshaped_frag[cs, r, p] *= phi_frag_intrachunk[r, p] lqk_PsiV_reshaped_shared = T.alloc_shared([chunk_size, R, P], dtype) 
T.copy(lqk_PsiV_reshaped_frag, lqk_PsiV_reshaped_shared) o_frag = T.alloc_fragment([chunk_size, P], dtype) T.clear(o_frag) for r in T.serial(R): for cs, p in T.Parallel(chunk_size, P): o_frag[cs, p] += lqk_PsiV_reshaped_shared[cs, r, p] T.copy(o_frag, O[i_b, chunk_start:chunk_start+chunk_size, i_h, :]) else: if hasZ: z_frag = T.alloc_fragment([chunk_size, P], dtype) T.copy(Z[i_b, chunk_start:chunk_start+chunk_size, i_h, :], z_frag) z_expanded_frag = T.alloc_fragment([chunk_size, R, P], dtype) for cs, r, p in T.Parallel(chunk_size, R, P): # Apply SiLU to z_expanded_frag[cs, r, p]: o_gated = z_frag[cs, p] * MIMO_Z[i_h, r, p] * 0.5 z_expanded_frag[cs, r, p] = o_gated * T.tanh(o_gated) + o_gated lqk_PsiV_reshaped_shared = T.alloc_shared([chunk_size, R, P], dtype) for cs, r, p in T.Parallel(chunk_size, R, P): lqk_PsiV_reshaped_shared[cs, r, p] = o_mimo_accum_frag[cs* R + r, p] * z_expanded_frag[cs, r, p] # T.copy(lqk_PsiV_frag, lqk_PsiV_reshaped_shared) # for cs, r, p in T.Parallel(chunk_size, R, P): # lqk_PsiV_reshaped_shared[cs, r, p] *= z_expanded_frag[cs, r, p] else: lqk_PsiV_reshaped_shared = T.alloc_shared([chunk_size, R, P], dtype) # T.copy(lqk_PsiV_reshaped_frag, lqk_PsiV_reshaped_shared) for cs, r, p in T.Parallel(chunk_size, R, P): lqk_PsiV_reshaped_shared[cs, r, p] = o_mimo_accum_frag[cs* R + r, p] T.copy(lqk_PsiV_reshaped_shared, O[i_b, chunk_start:chunk_start+chunk_size, :, i_h, :]) # --- Recurrent State Update --- # DA_CS_REV scales per-step K contributions for state accumulation. dA_cs_rev_frag = T.alloc_fragment([chunk_size], T.float32) T.copy(DA_CS_REV[i_b, i_h, chunk_start:chunk_start+chunk_size], dA_cs_rev_frag) k_state_frag = T.alloc_fragment([fused_chunk_size, N], dtype) T.copy(k_shared, k_state_frag) for csr, n in T.Parallel(fused_chunk_size, N): k_state_frag[csr, n] *= T.exp(dA_cs_rev_frag[csr//R]) # DA_CS(last) applies the chunk-level decay to the carried state. 
da_cs_sum = T.alloc_var(T.float32) T.copy(DA_CS[i_b, i_h, chunk_start+chunk_size-1], da_cs_sum) for n, p in T.Parallel(N, P): states_frag[n, p] *= T.exp(da_cs_sum) T.gemm(k_state_frag, PsiV_shared, states_frag, transpose_A=True, clear_accum=False) # --- Save Last State (if applicable) --- if return_final_state: T.copy(states_frag, FINAL_STATE[i_b, i_h, :, :]) return mamba_mimo_fwd_kernel def mamba_mimo_forward(q, k, v, q_bias, k_bias, mimo_v, mimo_o, z, D, mimo_z, angles, dA_cs, dA_cs_rev, dt, trap, segsum, chunk_size, rotary_dim_divisor, dtype, return_state=False, threads=128, num_stages=0): B, S, R, G, N = q.shape H, P = v.shape[-2], v.shape[-1] if isinstance(dtype, torch.dtype): tl_dtype = str(dtype).replace("torch.", "") else: tl_dtype = dtype reduceO = mimo_o is not None kernel = mamba_mimo_fwd(B, S, H, G, N, P, R, z is not None, D is not None, reduceO, return_final_state=return_state, chunk_size=chunk_size, rotary_dim_divisor=rotary_dim_divisor, dtype=tl_dtype, threads=threads, num_stages=num_stages) # print(kernel.get_kernel_source()) # NOTE: prints compiled CUDA code if reduceO: o = torch.empty((B, S, H, P), device='cuda', dtype=dtype) else: o = torch.empty((B, S, R, H, P), device='cuda', dtype=dtype) # Kernel always declares all tensor parameters; pass dummies for None args mimo_o_arg = mimo_o if reduceO else torch.empty((H, R, P), device=q.device, dtype=torch.float32) z_arg = z if z is not None else torch.empty((B, S, H, P), device=q.device, dtype=dtype) D_arg = D if D is not None else torch.empty((H,), device=q.device, dtype=torch.float32) mimo_z_arg = mimo_z if mimo_z is not None else torch.empty((H, R, P), device=q.device, dtype=torch.float32) h = torch.empty((B, H, N, P), device='cuda', dtype=torch.float32) if return_state else None k_final = torch.empty((B, R, H, N), device='cuda', dtype=dtype) if return_state else None kernel( q, k, v, o, q_bias, k_bias, mimo_v, mimo_o_arg, z_arg, D_arg, mimo_z_arg, angles, dA_cs, dA_cs_rev, dt, trap, segsum, h, 
k_final ) return o, h, k_final ================================================ FILE: mamba_ssm/ops/triton/__init__.py ================================================ ================================================ FILE: mamba_ssm/ops/triton/angle_cumsum.py ================================================ # Copyright (c) 2025, Tri Dao. from typing import Optional import math import torch import triton import triton.language as tl from triton.language.extra import libdevice class AngleDtFn(torch.autograd.Function): @staticmethod def forward(ctx, angle: torch.Tensor, # (B, S, H, D) dt: torch.Tensor, # (B, S, H) chunk_size: int = 128 # power of 2 ) -> torch.Tensor: # run Triton fwd out = apply_angle_dt_fwd(angle, dt, chunk_size=chunk_size) # save for bwd ctx.save_for_backward(angle, dt) ctx.chunk_size = int(chunk_size) return out @staticmethod def backward(ctx, grad_out: torch.Tensor): angle, dt = ctx.saved_tensors # run Triton bwd grad_dt, grad_angle = apply_angle_dt_bwd( grad_out=grad_out, angle=angle, dt=dt, chunk_size=ctx.chunk_size ) # grads align with (angle, dt, chunk_size) return grad_angle, grad_dt, None def angle_dt(angle: torch.Tensor, dt: torch.Tensor, *, chunk_size: int = 128) -> torch.Tensor: return AngleDtFn.apply(angle, dt, chunk_size) @triton.jit def cumsum_kernel( OUT, # Output tensor (batch, seqlen, nheads, dim) X, # Input tensor (batch, seqlen, nheads, dim) seqlen, dim, stride_out, # (batch, seqlen, nheads, dim) stride_x, # (batch, seqlen, nheads, dim) # Meta-parameters BLOCK_S: tl.constexpr, BLOCK_D: tl.constexpr, ): # Program IDs pid_h = tl.program_id(axis=0) # Head index (one per head) pid_d = tl.program_id(axis=1) # Dim block pid_b = tl.program_id(axis=2) # Batch index (one per batch element) # Offset pointers by batch and head X = X + pid_b * stride_x[0] + pid_h * stride_x[2] OUT = OUT + pid_b * stride_out[0] + pid_h * stride_out[2] # Compute ranges dim_range = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) dim_mask = dim_range < dim # Load entire 
sequence for this batch, head, and dim block seq_range = tl.arange(0, BLOCK_S)[:, None] # (BLOCK_S, 1) # Load input: (seqlen, dim) for this batch and head x_ptrs = X + seq_range * stride_x[1] + dim_range[None, :] * stride_x[3] x_mask = (seq_range < seqlen) & dim_mask[None, :] x_vals = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32) # Compute cumulative sum along sequence dimension (axis 0) cumsum_vals = tl.cumsum(x_vals, axis=0) # Store output: (seqlen, dim) for this batch and head out_ptrs = OUT + seq_range * stride_out[1] + dim_range[None, :] * stride_out[3] out_mask = (seq_range < seqlen) & dim_mask[None, :] tl.store(out_ptrs, cumsum_vals, mask=out_mask) @triton.jit def angle_dt_fwd_kernel( OUT, # Output tensor (batch, seqlen, nheads, dim) OUT_SUM, # Output sum tensor (batch, seqlen // chunk_size, nheads, dim) ANGLE, # Angle tensor (batch, seqlen, nheads, dim) DT, # Delta time tensor (batch, seqlen, nheads) PREFIX, # Prefix tensor (batch, numchunks, nheads, dim) - optional seqlen, dim, chunk_size, stride_out, # (batch, seqlen, nheads, dim) stride_out_sum, # (batch, seqlen // chunk_size, nheads, dim) stride_angle, # (batch, seqlen, nheads, dim) stride_dt, # (batch, seqlen, nheads) stride_prefix, # (batch, numchunks, nheads, dim) # Meta-parameters BLOCK_S: tl.constexpr, BLOCK_D: tl.constexpr, WRITE_OUTPUT: tl.constexpr, # Whether to write the full output WRITE_CHUNK_SUM: tl.constexpr, # Whether to write the chunk sum HAS_PREFIX: tl.constexpr, # Whether prefix is provided ): # Program IDs pid_b = tl.program_id(axis=2) # Batch index (one per batch element) pid_s = tl.program_id(axis=1) # Sequence block (chunk index) pid_h = tl.program_id(axis=0) # Head index (one per head) # Offset pointers by batch and head ANGLE = ANGLE + pid_b * stride_angle[0] + pid_h * stride_angle[2] DT = DT + pid_b * stride_dt[0] + pid_h * stride_dt[2] if WRITE_OUTPUT: OUT = OUT + pid_b * stride_out[0] + pid_h * stride_out[2] if WRITE_CHUNK_SUM: OUT_SUM = OUT_SUM + pid_b * 
stride_out_sum[0] + pid_h * stride_out_sum[2] if HAS_PREFIX: PREFIX = PREFIX + pid_b * stride_prefix[0] + pid_h * stride_prefix[2] # Compute ranges - each block processes exactly chunk_size elements seq_start = pid_s * chunk_size seq_range = seq_start + tl.arange(0, BLOCK_S) dim_range = tl.arange(0, BLOCK_D) # Masks seq_mask = seq_range < seqlen dim_mask = dim_range < dim # Load angle: (seqlen, dim) for this batch and head angle_ptrs = ANGLE + seq_range[:, None] * stride_angle[1] + dim_range[None, :] * stride_angle[3] angle_mask = (seq_mask[:, None] & dim_mask[None, :]) angle_vals = tl.load(angle_ptrs, mask=angle_mask, other=0.0).to(tl.float32) # Load dt: (seqlen,) for this batch and head dt_ptrs = DT + seq_range * stride_dt[1] dt_mask = seq_mask dt_vals = tl.load(dt_ptrs, mask=dt_mask, other=0.0).to(tl.float32) # Multiply: angle (S, D) * dt (S, 1) -> output (S, D) # angle_vals: (BLOCK_S, BLOCK_D) # dt_vals: (BLOCK_S,) #output_vals = angle_vals * dt_vals[:, None] # (BLOCK_S, BLOCK_D) output_vals = tl.sigmoid(2.0 * angle_vals) * 2.0 - 1.0 # output_vals = libdevice.tanh(output_vals) # This is pretty slow # This is still not super fast, idk how to enable fastmath #output_vals = tl.sigmoid(2.0 * output_vals) * 2.0 - 1.0 output_vals = output_vals * dt_vals[:, None] # This is the fastest, but with reduced accuracy. 
We probably don't need it # output_vals = tl.inline_asm_elementwise( # "tanh.approx.f32 $0, $1;", # "=f,f", # [output_vals], # dtype=tl.float32, # is_pure=True, # pack=1, # ) output_vals *= 3.141592653589793 # pi # Conditionally compute and store chunk sum if WRITE_CHUNK_SUM: # Compute sum along sequence dimension (within this chunk) # Sum over the sequence dimension (axis 0) chunk_sum = tl.sum(output_vals, axis=0) # (BLOCK_D,) # Store chunk sum: (seqlen // chunk_size, dim) for this batch and head sum_ptrs = OUT_SUM + pid_s * stride_out_sum[1] + dim_range * stride_out_sum[3] sum_mask = dim_mask tl.store(sum_ptrs, chunk_sum, mask=sum_mask) # Conditionally store output: (seqlen, dim) for this batch and head if WRITE_OUTPUT: output_vals = tl.cumsum(output_vals, axis=0) # Cumulative sum along sequence dimension (axis 0) # Add prefix if provided if HAS_PREFIX: # If chunk idx is 0, prefix is 0. If chunk idx is i, read from prefix at location i-1 if pid_s > 0: # Load prefix for this chunk from location pid_s - 1 prefix_ptrs = PREFIX + (pid_s - 1) * stride_prefix[1] + dim_range * stride_prefix[3] prefix_mask = dim_mask prefix_vals = tl.load(prefix_ptrs, mask=prefix_mask, other=0.0).to(tl.float32) # Add prefix to all elements in this chunk output_vals = output_vals + prefix_vals[None, :] # Broadcast prefix across sequence dimension # For pid_s == 0, prefix is implicitly 0, so no addition needed out_ptrs = OUT + seq_range[:, None] * stride_out[1] + dim_range[None, :] * stride_out[3] out_mask = (seq_mask[:, None] & dim_mask[None, :]) tl.store(out_ptrs, output_vals, mask=out_mask) # The kernel expects inputs to be flipped in the sequence dimension. # This is because it processes chunks in reverse order. 
@triton.jit def angle_dt_bwd_kernel( GRAD_DT, # Grad dt tensor (batch, seqlen, nheads) GRAD_ANGLE, # Grad angle tensor (batch, seqlen, nheads, dim) GRAD_SUM, # Grad sum tensor (batch, seqlen // chunk_size, nheads, dim) GRAD_OUT, # Grad input tensor (batch, seqlen, nheads, dim) ANGLE, # Angle tensor (batch, seqlen, nheads, dim) DT, # Delta time tensor (batch, seqlen, nheads) PREFIX, # Prefix tensor (batch, numchunks, nheads, dim) - optional seqlen, dim, chunk_size, stride_grad_dt, # (batch, seqlen, nheads) stride_grad_angle, # (batch, seqlen, nheads, dim) stride_grad_sum, # (batch, seqlen // chunk_size, nheads, dim) stride_grad_out, # (batch, seqlen, nheads, dim) stride_angle, # (batch, seqlen, nheads, dim) stride_dt, # (batch, seqlen, nheads) stride_prefix, # (batch, numchunks, nheads, dim) # Meta-parameters BLOCK_S: tl.constexpr, BLOCK_D: tl.constexpr, WRITE_GRAD: tl.constexpr, # Whether to write the full output WRITE_CHUNK_SUM: tl.constexpr, # Whether to write the chunk sum HAS_PREFIX: tl.constexpr, # Whether prefix is provided ): # Program IDs pid_b = tl.program_id(axis=2) # Batch index (one per batch element) pid_s = tl.program_id(axis=1) # Sequence block (chunk index) pid_h = tl.program_id(axis=0) # Head index (one per head) # Offset pointers by batch and head GRAD_OUT = GRAD_OUT + pid_b * stride_grad_out[0] + pid_h * stride_grad_out[2] if WRITE_GRAD: GRAD_DT = GRAD_DT + pid_b * stride_grad_dt[0] + pid_h * stride_grad_dt[2] GRAD_ANGLE = GRAD_ANGLE + pid_b * stride_grad_angle[0] + pid_h * stride_grad_angle[2] DT = DT + pid_b * stride_dt[0] + pid_h * stride_dt[2] ANGLE = ANGLE + pid_b * stride_angle[0] + pid_h * stride_angle[2] if WRITE_CHUNK_SUM: GRAD_SUM = GRAD_SUM + pid_b * stride_grad_sum[0] + pid_h * stride_grad_sum[2] if HAS_PREFIX: PREFIX = PREFIX + pid_b * stride_prefix[0] + pid_h * stride_prefix[2] # Compute ranges - each block processes exactly chunk_size elements seq_start = pid_s * chunk_size seq_range = seq_start + tl.arange(0, BLOCK_S) dim_range = 
tl.arange(0, BLOCK_D) # Masks seq_mask = seq_range < seqlen dim_mask = dim_range < dim # Load angle: (seqlen, dim) for this batch and head grad_out_ptrs = GRAD_OUT + seq_range[:, None] * stride_grad_out[1] + dim_range[None, :] * stride_grad_out[3] grad_out_mask = (seq_mask[:, None] & dim_mask[None, :]) grad_out_vals = tl.load(grad_out_ptrs, mask=grad_out_mask, other=0.0).to(tl.float32) # Conditionally compute and store chunk sum if WRITE_CHUNK_SUM: # Compute sum along sequence dimension (within this chunk) # Sum over the sequence dimension (axis 0) chunk_sum = tl.sum(grad_out_vals, axis=0) # (BLOCK_D,) # Store chunk sum: (seqlen // chunk_size, dim) for this batch and head sum_ptrs = GRAD_SUM + pid_s * stride_grad_sum[1] + dim_range * stride_grad_sum[3] sum_mask = dim_mask tl.store(sum_ptrs, chunk_sum, mask=sum_mask) # Conditionally store output: (seqlen, dim) for this batch and head if WRITE_GRAD: grad_out_vals = tl.cumsum(grad_out_vals, axis=0) # Cumulative sum along sequence dimension (axis 0) # Add prefix if provided if HAS_PREFIX: # If chunk idx is 0, prefix is 0. 
If chunk idx is i, read from prefix at location i-1 if pid_s > 0: # Load prefix for this chunk from location pid_s - 1 prefix_ptrs = PREFIX + (pid_s - 1) * stride_prefix[1] + dim_range * stride_prefix[3] prefix_mask = dim_mask prefix_vals = tl.load(prefix_ptrs, mask=prefix_mask, other=0.0).to(tl.float32) # Add prefix to all elements in this chunk grad_out_vals = grad_out_vals + prefix_vals[None, :] # Broadcast prefix across sequence dimension # For pid_s == 0, prefix is implicitly 0, so no addition needed # Load angle: (seqlen, dim) for this batch and head angle_ptrs = ANGLE + seq_range[:, None] * stride_angle[1] + dim_range[None, :] * stride_angle[3] angle_mask = (seq_mask[:, None] & dim_mask[None, :]) angle_vals = tl.load(angle_ptrs, mask=angle_mask, other=0.0).to(tl.float32) # Load dt: (seqlen,) for this batch and head dt_ptrs = DT + seq_range * stride_dt[1] dt_mask = seq_mask dt_vals = tl.load(dt_ptrs, mask=dt_mask, other=0.0).to(tl.float32) # (BLOCK_S,) # Compute dt gradients tanh_angle_vals = tl.sigmoid(2.0 * angle_vals) * 2.0 - 1.0 # (BLOCK_S, BLOCK_D) pi_tanh_angle_vals = tanh_angle_vals*3.141592653589793 dt_grad_vals = grad_out_vals * pi_tanh_angle_vals # (BLOCK_S, BLOCK_D) dt_grad_vals = tl.sum(dt_grad_vals, axis=1) # Sum over dim to get (BLOCK_S,) # Store dt gradients grad_dt_ptrs = GRAD_DT + seq_range * stride_grad_dt[1] grad_dt_mask = seq_mask tl.store(grad_dt_ptrs, dt_grad_vals, mask=grad_dt_mask) # Compute angle gradients d_tanh = 1.0 - tanh_angle_vals * tanh_angle_vals grad_angle_vals = (3.141592653589793 * dt_vals[:, None]) * d_tanh * grad_out_vals # Store angle gradients grad_angle_ptrs = GRAD_ANGLE + seq_range[:, None] * stride_grad_angle[1] + dim_range[None, :] * stride_grad_angle[3] grad_angle_mask = (seq_mask[:, None] & dim_mask[None, :]) tl.store(grad_angle_ptrs, grad_angle_vals, mask=grad_angle_mask) def apply_angle_dt_fwd( angle: torch.Tensor, # (batch, seqlen, nheads, dim) dt: torch.Tensor, # (batch, seqlen, nheads) chunk_size: int = 128, 
) -> tuple[torch.Tensor, torch.Tensor]: """ Multiply angle and dt tensors element-wise and compute chunk sums. Arguments: angle: (batch, seqlen, nheads, dim) dt: (batch, seqlen, nheads) chunk_size: Size of chunks for summing (must be power of 2) write_output: Whether to write the full output tensor write_chunk_sum: Whether to write the chunk sum tensor prefix: Optional prefix to add before cumsum (batch, numchunks, nheads, dim) Returns: output: (batch, seqlen, nheads, dim) - may contain uninitialized data if write_output=False output_sum: (batch, seqlen // chunk_size, nheads, dim) - may contain uninitialized data if write_chunk_sum=False """ batch, seqlen, nheads, dim = angle.shape assert angle.shape == (batch, seqlen, nheads, dim) assert dt.shape == (batch, seqlen, nheads) assert chunk_size > 0 and (chunk_size & (chunk_size - 1)) == 0, "chunk_size must be power of 2" # Calculate output dimensions num_chunks = math.ceil(seqlen / chunk_size) # Create output tensors (always fp32) output = torch.empty(batch, seqlen, nheads, dim, device=angle.device, dtype=torch.float32) output_sum = torch.empty(batch, num_chunks, nheads, dim, device=angle.device, dtype=torch.float32) # Launch kernel BLOCK_S = chunk_size # Use chunk_size as BLOCK_S BLOCK_D = triton.next_power_of_2(dim) # Step 1: compute the sum of each chunk. 
Don't write the output grid = lambda META: (nheads, num_chunks, batch) with torch.cuda.device(angle.device.index): torch.library.wrap_triton(angle_dt_fwd_kernel)[grid]( None, # output output_sum, angle, dt, None, # prefix seqlen, dim, chunk_size, (0, 0, 0, 0), # output_stride output_sum.stride(), angle.stride(), dt.stride(), (0, 0, 0, 0), # prefix_stride BLOCK_S=BLOCK_S, BLOCK_D=BLOCK_D, WRITE_OUTPUT=False, WRITE_CHUNK_SUM=True, HAS_PREFIX=False, ) # Step 2: compute cumsum on output_sum to get prefix prefix = apply_cumsum(output_sum) # Shape: (batch, num_chunks, nheads, dim) # Step 3: call angle_dt_kernel again with output and prefix, don't need to write output_sum with torch.cuda.device(angle.device.index): torch.library.wrap_triton(angle_dt_fwd_kernel)[grid]( output, # output None, # output_sum (don't need to write) angle, dt, prefix, # prefix seqlen, dim, chunk_size, output.stride(), # output_stride (0, 0, 0, 0), # output_sum_stride angle.stride(), dt.stride(), prefix.stride(), # prefix_stride BLOCK_S=BLOCK_S, BLOCK_D=BLOCK_D, WRITE_OUTPUT=True, WRITE_CHUNK_SUM=False, HAS_PREFIX=True, ) return output def apply_angle_dt_bwd( grad_out: torch.Tensor, # (batch, seqlen, nheads, dim) angle: torch.Tensor, # (batch, seqlen, nheads, dim) dt: torch.Tensor, # (batch, seqlen, nheads) chunk_size: int = 128, ) -> tuple[torch.Tensor, torch.Tensor]: """ Multiply angle and dt tensors element-wise and compute chunk sums. 
Arguments: grad_out: (batch, seqlen, nheads, dim) - gradient of the output angle: (batch, seqlen, nheads, dim) - stored angle tensor dt: (batch, seqlen, nheads) - stored delta time tensor chunk_size: Size of chunks for summing (must be power of 2) write_output: Whether to write the full output tensor write_chunk_sum: Whether to write the chunk sum tensor prefix: Optional prefix to add before cumsum (batch, numchunks, nheads, dim) Returns: output: (batch, seqlen, nheads, dim) - may contain uninitialized data if write_output=False output_sum: (batch, seqlen // chunk_size, nheads, dim) - may contain uninitialized data if write_chunk_sum=False """ batch, seqlen, nheads, dim = grad_out.shape assert grad_out.shape == (batch, seqlen, nheads, dim) assert angle.shape == (batch, seqlen, nheads, dim) assert dt.shape == (batch, seqlen, nheads) assert chunk_size > 0 and (chunk_size & (chunk_size - 1)) == 0, "chunk_size must be power of 2" # Calculate output dimensions num_chunks = math.ceil(seqlen / chunk_size) # Reverse the sequence dimension of grad_out, angle, dt grad_out = grad_out.flip(dims=(1,)) # Reverse along sequence dimension angle = angle.flip(dims=(1,)) dt = dt.flip(dims=(1,)) # Create output tensors (always fp32) grad_dt = torch.empty_like(dt) # (batch, seqlen, nheads) grad_angle = torch.empty_like(angle) # (batch, seqlen, nheads, dim) grad_sum = torch.empty(batch, num_chunks, nheads, dim, device=angle.device, dtype=torch.float32) # Launch kernel BLOCK_S = chunk_size # Use chunk_size as BLOCK_S BLOCK_D = triton.next_power_of_2(dim) # Step 1: compute the sum of each chunk. 
Don't write the output grid = lambda META: (nheads, num_chunks, batch) with torch.cuda.device(angle.device.index): torch.library.wrap_triton(angle_dt_bwd_kernel)[grid]( None, # GRAD_DT None, # GRAD_ANGLE grad_sum, # GRAD_SUM grad_out, # GRAD_OUT angle, dt, None, # PREFIX seqlen, dim, chunk_size, (0, 0, 0), # stride_grad_dt (0, 0, 0, 0), # stride_grad_angle grad_sum.stride(), # stride_grad_sum grad_out.stride(), # stride_grad_out angle.stride(), dt.stride(), (0, 0, 0, 0), # stride_prefix BLOCK_S=BLOCK_S, BLOCK_D=BLOCK_D, WRITE_GRAD=False, # Don't write grad_dt and grad_angle yet WRITE_CHUNK_SUM=True, # Write chunk sums to grad_sum HAS_PREFIX=False, # No prefix provided ) # Step 2: compute cumsum on output_sum to get prefix prefix = apply_cumsum(grad_sum) # Shape: (batch, num_chunks, nheads, dim) # Step 3: call angle_dt_fwd_chunksum_kernel again with output and prefix, don't need to write output_sum with torch.cuda.device(angle.device.index): torch.library.wrap_triton(angle_dt_bwd_kernel)[grid]( grad_dt, grad_angle, None, # GRAD_SUM (don't need to write) grad_out, angle, dt, prefix, # prefix seqlen, dim, chunk_size, grad_dt.stride(), # stride_grad_dt grad_angle.stride(), # stride_grad_angle (0, 0, 0), # stride_grad_sum grad_out.stride(), # stride_grad_out angle.stride(), dt.stride(), prefix.stride(), # stride_prefix BLOCK_S=BLOCK_S, BLOCK_D=BLOCK_D, WRITE_GRAD=True, # Write grad_dt and grad_angle WRITE_CHUNK_SUM=False, # Don't write chunk sums again HAS_PREFIX=True, # Use the computed prefix ) grad_dt = grad_dt.flip(dims=(1,)) grad_angle = grad_angle.flip(dims=(1,)) return grad_dt, grad_angle def apply_cumsum( x: torch.Tensor, # (batch, seqlen, nheads, dim) ) -> torch.Tensor: """ Compute cumulative sum along sequence dimension using Triton. 
Arguments: x: (batch, seqlen, nheads, dim) Returns: output: (batch, seqlen, nheads, dim) - cumulative sum along seqlen dimension """ batch, seqlen, nheads, dim = x.shape assert seqlen <= 512, f"seqlen must be <= 512, got {seqlen}" # Create output tensor (always fp32) output = torch.empty_like(x, dtype=torch.float32) # Launch kernel BLOCK_S = triton.next_power_of_2(seqlen) BLOCK_D = triton.next_power_of_2(min(dim, 16)) grid = lambda META: (nheads, triton.cdiv(dim, META["BLOCK_D"]), batch) with torch.cuda.device(x.device.index): torch.library.wrap_triton(cumsum_kernel)[grid]( output, x, seqlen, dim, output.stride(), x.stride(), BLOCK_S=BLOCK_S, BLOCK_D=BLOCK_D, ) return output def apply_angle_dt_reference( angle: torch.Tensor, # (batch, seqlen, nheads, dim) dt: torch.Tensor, # (batch, seqlen, nheads) chunk_size: int = 64, ) -> tuple[torch.Tensor, torch.Tensor]: """Reference PyTorch implementation.""" batch, seqlen, nheads, dim = angle.shape # Element-wise multiply: angle (B, S, H, D) * dt (B, S, H, 1) -> (B, S, H, D) #base_vals = (angle * dt[..., None]).to(torch.float32) # Always return fp32 base_vals = (angle).to(torch.float32) # Apply tanh then multiply by pi base_vals = torch.tanh(base_vals) * dt[..., None].to(torch.float32) * torch.pi # Simple cumulative sum along seqlen dimension output = torch.cumsum(base_vals, dim=1) return output def test_correctness(): """Test correctness against reference implementation.""" print("Testing angle_dt kernel correctness...") # Test parameters batch, seqlen, nheads, dim = 2, 512, 4, 32 chunk_size = 64 device = "cuda" dtype = torch.float32 # Create test tensors #torch.manual_seed(42) angle = torch.randn(batch, seqlen, nheads, dim, device=device, dtype=dtype) dt = torch.randn(batch, seqlen, nheads, device=device, dtype=dtype) # Test kernel vs reference out_triton = apply_angle_dt_fwd(angle, dt, chunk_size) out_ref = apply_angle_dt_reference(angle, dt, chunk_size) max_diff = (out_triton - out_ref).abs().max().item() print(f"Output 
max difference: {max_diff:.6f}") assert max_diff < 1e-3, f"Too large difference in output: {max_diff}" print("Test passed! ✓") print("All basic tests passed! ✓") def test_cumsum_correctness(): """Test cumsum kernel correctness against PyTorch.""" print("Testing cumsum kernel correctness...") # Test parameters batch, seqlen, nheads, dim = 4, 128, 8, 64 device = "cuda" dtype = torch.float32 # Create test tensors #torch.manual_seed(42) x = torch.randn(batch, seqlen, nheads, dim, device=device, dtype=dtype) # Test kernel vs PyTorch out_triton = apply_cumsum(x) out_ref = torch.cumsum(x, dim=1).to(torch.float32) max_diff = (out_triton - out_ref).abs().max().item() print(f"Cumsum max difference: {max_diff:.6f}") assert max_diff < 1e-4, f"Too large difference in cumsum: {max_diff}" print("Cumsum test passed! ✓") def test_backward_correctness(): """Backward correctness vs PyTorch autograd on small cases.""" print("Testing backward correctness...") device = "cuda" tol = 5e-3 # fp32 cases = [ (2, 257, 3, 17, 64), # odd S/D, non-power-of-two (1, 129, 4, 33, 32), ] for (batch, seqlen, nheads, dim, chunk_size) in cases: angle = torch.randn(batch, seqlen, nheads, dim, device=device, dtype=torch.float32) dt = torch.randn(batch, seqlen, nheads, device=device, dtype=torch.float32) grad_out = torch.randn(batch, seqlen, nheads, dim, device=device, dtype=torch.float32) # Triton bwd grad_dt_tri, grad_angle_tri = apply_angle_dt_bwd(grad_out, angle, dt, chunk_size) # Reference bwd via autograd angle_ref = angle.detach().clone().requires_grad_(True) dt_ref = dt.detach().clone().requires_grad_(True) out_ref = apply_angle_dt_reference(angle_ref, dt_ref, chunk_size) out_ref.backward(grad_out) grad_angle_ref = angle_ref.grad.detach() grad_dt_ref = dt_ref.grad.detach() max_da = (grad_angle_tri - grad_angle_ref).abs().max().item() max_dd = (grad_dt_tri - grad_dt_ref ).abs().max().item() print(f" Case B={batch} S={seqlen} H={nheads} D={dim} chunk={chunk_size} | " f"max|Δ angle|={max_da:.3e} max|Δ 
dt|={max_dd:.3e}") assert max_da < tol, f"angle grad mismatch {max_da}" assert max_dd < tol, f"dt grad mismatch {max_dd}" print("Backward correctness test passed! ✓") def benchmark_angle_dt(): """Benchmark angle_dt kernel and measure memory bandwidth.""" print("\nBenchmarking angle_dt kernel...") # Benchmark parameters batch, seqlen, nheads, dim = 8, 4096, 32, 32 # batch, seqlen, nheads, dim = 1, 128, 1, 1 chunk_size = 128 device = "cuda" dtype = torch.bfloat16 # Create input tensors #torch.manual_seed(42) # Generate angle by expanding from (batch, seqlen, 1, dim) to (batch, seqlen, nheads, dim) angle_base = torch.randn(batch, seqlen, 1, dim, device=device, dtype=dtype) angle = angle_base.expand(batch, seqlen, nheads, dim) dt = torch.randn(batch, seqlen, nheads, device=device, dtype=dtype) fn = lambda: apply_angle_dt_fwd(angle, dt, chunk_size) out = fn() # Warmup for _ in range(10): fn() # Benchmark torch.cuda.synchronize() import time time.sleep(0.5) # Run benchmark time_ms = triton.testing.do_bench(fn, warmup=10, rep=100) # Calculate memory bandwidth # Read: angle_base (actual underlying data) + dt # Write: output + output_sum (always fp32, so 4 bytes per element) # Note: angle is expanded so actual memory read is only angle_base.numel() bytes_read = angle_base.untyped_storage().nbytes() + dt.untyped_storage().nbytes() bytes_write = out.untyped_storage().nbytes() # Both output and output_sum (fp32 = 4 bytes) total_bytes = bytes_read + bytes_write # Convert to GB/s time_s = time_ms / 1000.0 bandwidth_gb_s = (total_bytes / 1e9) / time_s print(f"Angle base shape: {angle_base.shape}") print(f"Angle expanded shape: {angle.shape}") print(f"Angle stride: {angle.stride()}") print(f"DT shape: {dt.shape}") print(f"Output shape: {out.shape}") print(f"Chunk size: {chunk_size}") print(f"Time: {time_ms:.3f} ms") print(f"Memory transferred: {total_bytes / 1e9:.3f} GB") print(f"Memory bandwidth: {bandwidth_gb_s:.1f} GB/s") # from flash_attn.utils.benchmark import 
pytorch_profiler # pytorch_profiler(fn) return time_ms, bandwidth_gb_s def benchmark_angle_dt_backward(): """Benchmark backward pass and report rough memory bandwidth.""" print("\nBenchmarking angle_dt backward...") batch, seqlen, nheads, dim = 8, 4096, 32, 32 chunk_size = 128 device = "cuda" # Use fp32 for bwd accumulations angle = torch.randn(batch, seqlen, nheads, dim, device=device, dtype=torch.float32) dt = torch.randn(batch, seqlen, nheads, device=device, dtype=torch.float32) grad_out = torch.randn(batch, seqlen, nheads, dim, device=device, dtype=torch.float32) fn = lambda: apply_angle_dt_bwd(grad_out, angle, dt, chunk_size) _ = fn() # Warmup for _ in range(10): fn() torch.cuda.synchronize() import time time.sleep(0.5) time_ms = triton.testing.do_bench(fn, warmup=10, rep=100) # Rough traffic estimate (two-stage + prefixes), conservative: num_chunks = (seqlen + chunk_size - 1) // chunk_size bytes_read = ( grad_out.numel() * 4 + # read grad_out angle.numel() * 4 + # read angle dt.numel() * 4 + # read dt (batch * num_chunks * nheads * dim) * 4 + # read grad_sum for prefix (batch * num_chunks * nheads * dim) * 4 # read prefix in stage 2 ) bytes_write = ( (batch * num_chunks * nheads * dim) * 4 + # write grad_sum (stage 1) (batch * seqlen * nheads) * 4 + # write grad_dt (batch * seqlen * nheads * dim) * 4 # write grad_angle ) total_bytes = bytes_read + bytes_write bandwidth_gb_s = (total_bytes / 1e9) / (time_ms / 1000.0) print(f"B={batch} S={seqlen} H={nheads} D={dim} chunk={chunk_size}") print(f"Time: {time_ms:.3f} ms") print(f"Memory transferred (est): {total_bytes / 1e9:.3f} GB") print(f"Memory bandwidth (est): {bandwidth_gb_s:.1f} GB/s") return time_ms, bandwidth_gb_s if __name__ == "__main__": test_correctness() test_cumsum_correctness() benchmark_angle_dt() ================================================ FILE: mamba_ssm/ops/triton/k_activations.py ================================================ # Copyright (c) 2024, Tri Dao, Albert Gu. 
import torch import triton import triton.language as tl from mamba_ssm.utils.determinism import autotune_configs @triton.autotune( configs=autotune_configs([ triton.Config({'BLOCK_N': 32}), triton.Config({'BLOCK_N': 64}), triton.Config({'BLOCK_N': 128}), triton.Config({'BLOCK_N': 256}), triton.Config({'BLOCK_N': 512}), triton.Config({'BLOCK_N': 1024}), ]), key=['ncols'], ) @triton.jit def _swiglu_fwd_kernel( X, Y, OUT, stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_out_row, ncols, BLOCK_N: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. row = tl.program_id(0) start_col = tl.program_id(1) * BLOCK_N X += row * stride_x_row Y += row * stride_y_row OUT += row * stride_out_row cols = start_col + tl.arange(0, BLOCK_N) x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32) y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32) out = x * tl.sigmoid(x) * y tl.store(OUT + cols, out, mask=cols < ncols) def _swiglu_fwd(xy, out=None): if xy.stride(-1) != 1: xy = xy.contiguous() batch_shape = xy.shape[:-1] xy = xy.reshape(-1, xy.shape[-1]) x, y = xy.chunk(2, dim=-1) if out is None: out = torch.empty_like(x) else: out = out.reshape(-1, out.shape[-1]) assert out.shape == x.shape assert out.stride(-1) == 1 M, N = x.shape grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N'])) with torch.cuda.device(x.device.index): _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N) return out.reshape(*batch_shape, out.shape[-1]) @triton.autotune( configs=autotune_configs([ triton.Config({'BLOCK_N': 32}), triton.Config({'BLOCK_N': 64}), triton.Config({'BLOCK_N': 128}), triton.Config({'BLOCK_N': 256}), triton.Config({'BLOCK_N': 512}), triton.Config({'BLOCK_N': 1024}), ]), key=['ncols'], ) @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None}) @triton.jit def _swiglu_bwd_kernel( X, Y, DOUT, OUT, DX, DY, stride_x_row, # how much to increase the pointer 
when moving by 1 row stride_y_row, stride_dout_row, stride_out_row, stride_dx_row, stride_dy_row, ncols, BLOCK_N: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. row = tl.program_id(0) start_col = tl.program_id(1) * BLOCK_N X += row * stride_x_row Y += row * stride_y_row DOUT += row * stride_dout_row if RECOMPUTE_OUTPUT: OUT += row * stride_out_row DX += row * stride_dx_row DY += row * stride_dy_row cols = start_col + tl.arange(0, BLOCK_N) x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32) y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32) dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32) x_sigmoid = tl.sigmoid(x) dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout dy = x * x_sigmoid * dout tl.store(DX + cols, dx, mask=cols < ncols) tl.store(DY + cols, dy, mask=cols < ncols) if RECOMPUTE_OUTPUT: out = x * x_sigmoid * y tl.store(OUT + cols, out, mask=cols < ncols) def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None): if xy.stride(-1) != 1: xy = xy.contiguous() if dout.stride(-1) != 1: dout = dout.contiguous() batch_shape = xy.shape[:-1] xy = xy.reshape(-1, xy.shape[-1]) x, y = xy.chunk(2, dim=-1) dout = dout.reshape(-1, dout.shape[-1]) assert dout.shape == x.shape if dxy is None: dxy = torch.empty_like(xy) else: dxy = dxy.reshape(-1, dxy.shape[-1]) assert dxy.shape == xy.shape dx, dy = dxy.chunk(2, dim=-1) assert dx.stride(-1) == 1 assert dy.stride(-1) == 1 if recompute_output: if out is None: out = torch.empty_like(x) else: out = out.reshape(-1, out.shape[-1]) assert out.shape == x.shape assert out.stride(-1) == 1 M, N = x.shape grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N'])) with torch.cuda.device(x.device.index): _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy, x.stride(0), y.stride(0), dout.stride(0), out.stride(0) if recompute_output else 0, dx.stride(0), dy.stride(0), N) if not 
recompute_output: return dxy.reshape(*batch_shape, dxy.shape[-1]) else: return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1]) class SwiGLU(torch.autograd.Function): @staticmethod def forward(ctx, xy): ctx.save_for_backward(xy) return _swiglu_fwd(xy) @staticmethod def backward(ctx, dout): xy, = ctx.saved_tensors return _swiglu_bwd(xy, dout) swiglu = SwiGLU.apply ================================================ FILE: mamba_ssm/ops/triton/layer_norm.py ================================================ # Copyright (c) 2024, Tri Dao. # Implement dropout + residual + layer_norm / rms_norm. # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 
import math import warnings import torch import torch.nn.functional as F from mamba_ssm.utils.torch import custom_bwd, custom_fwd import triton import triton.language as tl from mamba_ssm.utils.determinism import autotune_configs def layer_norm_ref( x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, dropout_mask=None, dropout_mask1=None, upcast=False, ): dtype = x.dtype if upcast: x = x.float() weight = weight.float() bias = bias.float() if bias is not None else None residual = residual.float() if residual is not None else residual x1 = x1.float() if x1 is not None else None weight1 = weight1.float() if weight1 is not None else None bias1 = bias1.float() if bias1 is not None else None if x1 is not None: assert rowscale is None, "rowscale is not supported with parallel LayerNorm" if rowscale is not None: x = x * rowscale[..., None] if dropout_p > 0.0: if dropout_mask is not None: x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) else: x = F.dropout(x, p=dropout_p) if x1 is not None: if dropout_mask1 is not None: x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) else: x1 = F.dropout(x1, p=dropout_p) if x1 is not None: x = x + x1 if residual is not None: x = (x + residual).to(x.dtype) out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( dtype ) if weight1 is None: return out if not prenorm else (out, x) else: out1 = F.layer_norm( x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps ).to(dtype) return (out, out1) if not prenorm else (out, out1, x) def rms_norm_ref( x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, dropout_mask=None, dropout_mask1=None, upcast=False, ): dtype = x.dtype if upcast: x = x.float() weight = weight.float() bias = bias.float() if bias is not None else None residual = residual.float() if residual is not None else residual x1 = 
x1.float() if x1 is not None else None weight1 = weight1.float() if weight1 is not None else None bias1 = bias1.float() if bias1 is not None else None if x1 is not None: assert rowscale is None, "rowscale is not supported with parallel LayerNorm" if rowscale is not None: x = x * rowscale[..., None] if dropout_p > 0.0: if dropout_mask is not None: x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) else: x = F.dropout(x, p=dropout_p) if x1 is not None: if dropout_mask1 is not None: x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) else: x1 = F.dropout(x1, p=dropout_p) if x1 is not None: x = x + x1 if residual is not None: x = (x + residual).to(x.dtype) rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype) if weight1 is None: return out if not prenorm else (out, x) else: out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to( dtype ) return (out, out1) if not prenorm else (out, out1, x) def config_prune(configs): if torch.version.hip: try: # set warp size based on gcn architecure gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name: # radeon warp_size = 32 else: # instinct warp_size = 64 except AttributeError as e: # fall back to crude method to set warp size device_name = torch.cuda.get_device_properties(0).name if 'instinct' in device_name.lower(): warp_size = 64 else: warp_size = 32 warnings.warn(f"{e}, warp size set to {warp_size} based on device name: {device_name}", UserWarning) else: # cuda warp_size = 32 max_block_sz = 1024 max_num_warps = max_block_sz // warp_size pruned_configs = [config for config in configs if config.num_warps <= max_num_warps] return pruned_configs configs_autotune = [ triton.Config({}, num_warps=1), triton.Config({}, num_warps=2), triton.Config({}, num_warps=4), triton.Config({}, num_warps=8), triton.Config({}, 
num_warps=16), triton.Config({}, num_warps=32), ] pruned_configs_autotune = config_prune(configs_autotune) @triton.autotune( configs=autotune_configs(pruned_configs_autotune), key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) @triton.jit def _layer_norm_fwd_1pass_kernel( X, # pointer to the input Y, # pointer to the output W, # pointer to the weights B, # pointer to the biases RESIDUAL, # pointer to the residual X1, W1, B1, Y1, RESIDUAL_OUT, # pointer to the residual ROWSCALE, SEEDS, # Dropout seeds for each row DROPOUT_MASK, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_res_row, stride_res_out_row, stride_x1_row, stride_y1_row, M, # number of rows in X N, # number of columns in X eps, # epsilon to avoid division by zero dropout_p, # Dropout probability IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_DROPOUT: tl.constexpr, STORE_DROPOUT_MASK: tl.constexpr, HAS_ROWSCALE: tl.constexpr, HAS_X1: tl.constexpr, HAS_W1: tl.constexpr, HAS_B1: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. 
row = tl.program_id(0) X += row * stride_x_row Y += row * stride_y_row if HAS_RESIDUAL: RESIDUAL += row * stride_res_row if STORE_RESIDUAL_OUT: RESIDUAL_OUT += row * stride_res_out_row if HAS_X1: X1 += row * stride_x1_row if HAS_W1: Y1 += row * stride_y1_row # Compute mean and variance cols = tl.arange(0, BLOCK_N) x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) if HAS_ROWSCALE: rowscale = tl.load(ROWSCALE + row).to(tl.float32) x *= rowscale if HAS_DROPOUT: # Compute dropout mask # 7 rounds is good enough, and reduces register pressure keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) if STORE_DROPOUT_MASK: tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) if HAS_X1: x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) if HAS_ROWSCALE: rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) x1 *= rowscale if HAS_DROPOUT: # Compute dropout mask # 7 rounds is good enough, and reduces register pressure keep_mask = ( tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p ) x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) if STORE_DROPOUT_MASK: tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) x += x1 if HAS_RESIDUAL: residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) x += residual if STORE_RESIDUAL_OUT: tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) if not IS_RMS_NORM: mean = tl.sum(x, axis=0) / N tl.store(Mean + row, mean) xbar = tl.where(cols < N, x - mean, 0.0) var = tl.sum(xbar * xbar, axis=0) / N else: xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) tl.store(Rstd + row, rstd) # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) if HAS_BIAS: b = tl.load(B + cols, mask=mask).to(tl.float32) x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd y = x_hat 
* w + b if HAS_BIAS else x_hat * w # Write output tl.store(Y + cols, y, mask=mask) if HAS_W1: w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) if HAS_B1: b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 tl.store(Y1 + cols, y1, mask=mask) def _layer_norm_fwd( x, weight, bias, eps, residual=None, x1=None, weight1=None, bias1=None, dropout_p=0.0, rowscale=None, out_dtype=None, residual_dtype=None, is_rms_norm=False, return_dropout_mask=False, ): if residual is not None: residual_dtype = residual.dtype M, N = x.shape assert x.stride(-1) == 1 if residual is not None: assert residual.stride(-1) == 1 assert residual.shape == (M, N) assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 assert bias.shape == (N,) if x1 is not None: assert x1.shape == x.shape assert rowscale is None assert x1.stride(-1) == 1 if weight1 is not None: assert weight1.shape == (N,) assert weight1.stride(-1) == 1 if bias1 is not None: assert bias1.shape == (N,) assert bias1.stride(-1) == 1 if rowscale is not None: assert rowscale.is_contiguous() assert rowscale.shape == (M,) # allocate output y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) assert y.stride(-1) == 1 if weight1 is not None: y1 = torch.empty_like(y) assert y1.stride(-1) == 1 else: y1 = None if ( residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None ): residual_out = torch.empty( M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype ) assert residual_out.stride(-1) == 1 else: residual_out = None mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None rstd = torch.empty((M,), dtype=torch.float32, device=x.device) if dropout_p > 0.0: seeds = torch.randint( 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 ) else: seeds = None if 
return_dropout_mask and dropout_p > 0.0: dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool) else: dropout_mask = None # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") with torch.cuda.device(x.device.index): _layer_norm_fwd_1pass_kernel[(M,)]( x, y, weight, bias, residual, x1, weight1, bias1, y1, residual_out, rowscale, seeds, dropout_mask, mean, rstd, x.stride(0), y.stride(0), residual.stride(0) if residual is not None else 0, residual_out.stride(0) if residual_out is not None else 0, x1.stride(0) if x1 is not None else 0, y1.stride(0) if y1 is not None else 0, M, N, eps, dropout_p, is_rms_norm, BLOCK_N, residual is not None, residual_out is not None, bias is not None, dropout_p > 0.0, dropout_mask is not None, rowscale is not None, ) # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 if dropout_mask is not None and x1 is not None: dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) else: dropout_mask1 = None return ( y, y1, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask, dropout_mask1, ) @triton.autotune( configs=autotune_configs(pruned_configs_autotune), key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) @triton.heuristics({"HAS_B1": lambda 
args: args["DB1"] is not None}) @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) @triton.jit def _layer_norm_bwd_kernel( X, # pointer to the input W, # pointer to the weights B, # pointer to the biases Y, # pointer to the output to be recomputed DY, # pointer to the output gradient DX, # pointer to the input gradient DW, # pointer to the partial sum of weights gradient DB, # pointer to the partial sum of biases gradient DRESIDUAL, W1, DY1, DX1, DW1, DB1, DRESIDUAL_IN, ROWSCALE, SEEDS, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_dy_row, stride_dx_row, stride_dres_row, stride_dy1_row, stride_dx1_row, stride_dres_in_row, M, # number of rows in X N, # number of columns in X eps, # epsilon to avoid division by zero dropout_p, rows_per_program, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_DRESIDUAL: tl.constexpr, STORE_DRESIDUAL: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_DROPOUT: tl.constexpr, HAS_ROWSCALE: tl.constexpr, HAS_DY1: tl.constexpr, HAS_DX1: tl.constexpr, HAS_B1: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr, ): # Map the program id to the elements of X, DX, and DY it should compute. 
row_block_id = tl.program_id(0) row_start = row_block_id * rows_per_program # Do not early exit if row_start >= M, because we need to write DW and DB cols = tl.arange(0, BLOCK_N) mask = cols < N X += row_start * stride_x_row if HAS_DRESIDUAL: DRESIDUAL += row_start * stride_dres_row if STORE_DRESIDUAL: DRESIDUAL_IN += row_start * stride_dres_in_row DY += row_start * stride_dy_row DX += row_start * stride_dx_row if HAS_DY1: DY1 += row_start * stride_dy1_row if HAS_DX1: DX1 += row_start * stride_dx1_row if RECOMPUTE_OUTPUT: Y += row_start * stride_y_row w = tl.load(W + cols, mask=mask).to(tl.float32) if RECOMPUTE_OUTPUT and HAS_BIAS: b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) if HAS_DY1: w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) dw = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_BIAS: db = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_DY1: dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_B1: db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) row_end = min((row_block_id + 1) * rows_per_program, M) for row in range(row_start, row_end): # Load data to SRAM x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) if HAS_DY1: dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) if not IS_RMS_NORM: mean = tl.load(Mean + row) rstd = tl.load(Rstd + row) # Compute dx xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd xhat = tl.where(mask, xhat, 0.0) if RECOMPUTE_OUTPUT: y = xhat * w + b if HAS_BIAS else xhat * w tl.store(Y + cols, y, mask=mask) wdy = w * dy dw += dy * xhat if HAS_BIAS: db += dy if HAS_DY1: wdy += w1 * dy1 dw1 += dy1 * xhat if HAS_B1: db1 += dy1 if not IS_RMS_NORM: c1 = tl.sum(xhat * wdy, axis=0) / N c2 = tl.sum(wdy, axis=0) / N dx = (wdy - (xhat * c1 + c2)) * rstd else: c1 = tl.sum(xhat * wdy, axis=0) / N dx = (wdy - xhat * c1) * rstd if HAS_DRESIDUAL: dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) dx += dres # Write dx if STORE_DRESIDUAL: 
tl.store(DRESIDUAL_IN + cols, dx, mask=mask) if HAS_DX1: if HAS_DROPOUT: keep_mask = ( tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p ) dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) else: dx1 = dx tl.store(DX1 + cols, dx1, mask=mask) if HAS_DROPOUT: keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) if HAS_ROWSCALE: rowscale = tl.load(ROWSCALE + row).to(tl.float32) dx *= rowscale tl.store(DX + cols, dx, mask=mask) X += stride_x_row if HAS_DRESIDUAL: DRESIDUAL += stride_dres_row if STORE_DRESIDUAL: DRESIDUAL_IN += stride_dres_in_row if RECOMPUTE_OUTPUT: Y += stride_y_row DY += stride_dy_row DX += stride_dx_row if HAS_DY1: DY1 += stride_dy1_row if HAS_DX1: DX1 += stride_dx1_row tl.store(DW + row_block_id * N + cols, dw, mask=mask) if HAS_BIAS: tl.store(DB + row_block_id * N + cols, db, mask=mask) if HAS_DY1: tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) if HAS_B1: tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) def _layer_norm_bwd( dy, x, weight, bias, eps, mean, rstd, dresidual=None, dy1=None, weight1=None, bias1=None, seeds=None, dropout_p=0.0, rowscale=None, has_residual=False, has_x1=False, is_rms_norm=False, x_dtype=None, recompute_output=False, ): M, N = x.shape assert x.stride(-1) == 1 assert dy.stride(-1) == 1 assert dy.shape == (M, N) if dresidual is not None: assert dresidual.stride(-1) == 1 assert dresidual.shape == (M, N) assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 assert bias.shape == (N,) if dy1 is not None: assert weight1 is not None assert dy1.shape == dy.shape assert dy1.stride(-1) == 1 if weight1 is not None: assert weight1.shape == (N,) assert weight1.stride(-1) == 1 if bias1 is not None: assert bias1.shape == (N,) assert bias1.stride(-1) == 1 if seeds is not None: assert seeds.is_contiguous() assert seeds.shape == (M if not has_x1 else M * 
2,) if rowscale is not None: assert rowscale.is_contiguous() assert rowscale.shape == (M,) # allocate output dx = ( torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device) ) dresidual_in = ( torch.empty_like(x) if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1) else None ) dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None if recompute_output: assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) _db = ( torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None ) _dw1 = torch.empty_like(_dw) if weight1 is not None else None _db1 = torch.empty_like(_db) if bias1 is not None else None rows_per_program = math.ceil(M / sm_count) grid = (sm_count,) with torch.cuda.device(x.device.index): _layer_norm_bwd_kernel[grid]( x, weight, bias, y, dy, dx, _dw, _db, dresidual, weight1, dy1, dx1, _dw1, _db1, dresidual_in, rowscale, seeds, mean, rstd, x.stride(0), 0 if not recompute_output else y.stride(0), dy.stride(0), dx.stride(0), dresidual.stride(0) if dresidual is not None else 0, dy1.stride(0) if dy1 is not None else 0, dx1.stride(0) if dx1 is not None else 0, dresidual_in.stride(0) if dresidual_in is not None else 0, M, N, eps, dropout_p, rows_per_program, is_rms_norm, BLOCK_N, dresidual is not None, dresidual_in is not None, bias is not None, dropout_p > 0.0, ) dw = _dw.sum(0).to(weight.dtype) db = _db.sum(0).to(bias.dtype) if bias is 
not None else None dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None # Don't need to compute dresidual_in separately in this case if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: dresidual_in = dx if has_x1 and dropout_p == 0.0: dx1 = dx return ( (dx, dw, db, dresidual_in, dx1, dw1, db1) if not recompute_output else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) ) class LayerNormFn(torch.autograd.Function): @staticmethod def forward( ctx, x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, residual_in_fp32=False, is_rms_norm=False, return_dropout_mask=False, ): x_shape_og = x.shape # reshape input data into 2D tensor x = x.reshape(-1, x.shape[-1]) if x.stride(-1) != 1: x = x.contiguous() if residual is not None: assert residual.shape == x_shape_og residual = residual.reshape(-1, residual.shape[-1]) if residual.stride(-1) != 1: residual = residual.contiguous() if x1 is not None: assert x1.shape == x_shape_og assert rowscale is None, "rowscale is not supported with parallel LayerNorm" x1 = x1.reshape(-1, x1.shape[-1]) if x1.stride(-1) != 1: x1 = x1.contiguous() weight = weight.contiguous() if bias is not None: bias = bias.contiguous() if weight1 is not None: weight1 = weight1.contiguous() if bias1 is not None: bias1 = bias1.contiguous() if rowscale is not None: rowscale = rowscale.reshape(-1).contiguous() residual_dtype = ( residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None) ) y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( x, weight, bias, eps, residual, x1, weight1, bias1, dropout_p=dropout_p, rowscale=rowscale, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm, return_dropout_mask=return_dropout_mask, ) ctx.save_for_backward( residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd ) 
ctx.x_shape_og = x_shape_og ctx.eps = eps ctx.dropout_p = dropout_p ctx.is_rms_norm = is_rms_norm ctx.has_residual = residual is not None ctx.has_x1 = x1 is not None ctx.prenorm = prenorm ctx.x_dtype = x.dtype y = y.reshape(x_shape_og) y1 = y1.reshape(x_shape_og) if y1 is not None else None residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None if not return_dropout_mask: if weight1 is None: return y if not prenorm else (y, residual_out) else: return (y, y1) if not prenorm else (y, y1, residual_out) else: if weight1 is None: return ( (y, dropout_mask, dropout_mask1) if not prenorm else (y, residual_out, dropout_mask, dropout_mask1) ) else: return ( (y, y1, dropout_mask, dropout_mask1) if not prenorm else (y, y1, residual_out, dropout_mask, dropout_mask1) ) @staticmethod def backward(ctx, dy, *args): x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors dy = dy.reshape(-1, dy.shape[-1]) if dy.stride(-1) != 1: dy = dy.contiguous() assert dy.shape == x.shape if weight1 is not None: dy1, args = args[0], args[1:] dy1 = dy1.reshape(-1, dy1.shape[-1]) if dy1.stride(-1) != 1: dy1 = dy1.contiguous() assert dy1.shape == x.shape else: dy1 = None if ctx.prenorm: dresidual = args[0] dresidual = dresidual.reshape(-1, dresidual.shape[-1]) if dresidual.stride(-1) != 1: dresidual = dresidual.contiguous() assert dresidual.shape == x.shape else: dresidual = None dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( dy, x, weight, bias, ctx.eps, mean, rstd, dresidual, dy1, weight1, bias1, seeds, ctx.dropout_p, rowscale, ctx.has_residual, ctx.has_x1, ctx.is_rms_norm, x_dtype=ctx.x_dtype, ) return ( dx.reshape(ctx.x_shape_og), dw, db, dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, dx1.reshape(ctx.x_shape_og) if dx1 is not None else 
None, dw1, db1, None, None, None, None, None, None, None, ) def layer_norm_fn( x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, residual_in_fp32=False, is_rms_norm=False, return_dropout_mask=False, ): return LayerNormFn.apply( x, weight, bias, residual, x1, weight1, bias1, eps, dropout_p, rowscale, prenorm, residual_in_fp32, is_rms_norm, return_dropout_mask, ) def rms_norm_fn( x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, residual_in_fp32=False, return_dropout_mask=False, ): return LayerNormFn.apply( x, weight, bias, residual, x1, weight1, bias1, eps, dropout_p, rowscale, prenorm, residual_in_fp32, True, return_dropout_mask, ) class RMSNorm(torch.nn.Module): def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps if dropout_p > 0.0: self.drop = torch.nn.Dropout(dropout_p) else: self.drop = None self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self): torch.nn.init.ones_(self.weight) def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): return rms_norm_fn( x, self.weight, self.bias, residual=residual, eps=self.eps, dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, prenorm=prenorm, residual_in_fp32=residual_in_fp32, ) class LayerNormLinearFn(torch.autograd.Function): @staticmethod @custom_fwd def forward( ctx, x, norm_weight, norm_bias, linear_weight, linear_bias, residual=None, eps=1e-6, prenorm=False, residual_in_fp32=False, is_rms_norm=False, ): x_shape_og = x.shape # reshape input data into 2D tensor x = x.reshape(-1, x.shape[-1]) if x.stride(-1) != 1: x = x.contiguous() if residual is not None: assert residual.shape == x_shape_og residual = 
residual.reshape(-1, residual.shape[-1]) if residual.stride(-1) != 1: residual = residual.contiguous() norm_weight = norm_weight.contiguous() if norm_bias is not None: norm_bias = norm_bias.contiguous() residual_dtype = ( residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None) ) y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd( x, norm_weight, norm_bias, eps, residual, out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), residual_dtype=residual_dtype, is_rms_norm=is_rms_norm, ) y = y.reshape(x_shape_og) dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype linear_weight = linear_weight.to(dtype) linear_bias = linear_bias.to(dtype) if linear_bias is not None else None out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) # We don't store y, will be recomputed in the backward pass to save memory ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) ctx.x_shape_og = x_shape_og ctx.eps = eps ctx.is_rms_norm = is_rms_norm ctx.has_residual = residual is not None ctx.prenorm = prenorm ctx.x_dtype = x.dtype ctx.linear_bias_is_none = linear_bias is None return out if not prenorm else (out, residual_out.reshape(x_shape_og)) @staticmethod @custom_bwd def backward(ctx, dout, *args): x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors dout = dout.reshape(-1, dout.shape[-1]) dy = F.linear(dout, linear_weight.t()) dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) if dy.stride(-1) != 1: dy = dy.contiguous() assert dy.shape == x.shape if ctx.prenorm: dresidual = args[0] dresidual = dresidual.reshape(-1, dresidual.shape[-1]) if dresidual.stride(-1) != 1: dresidual = dresidual.contiguous() assert dresidual.shape == x.shape else: dresidual = None dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd( dy, x, norm_weight, norm_bias, ctx.eps, mean, rstd, dresidual=dresidual, 
has_residual=ctx.has_residual, is_rms_norm=ctx.is_rms_norm, x_dtype=ctx.x_dtype, recompute_output=True, ) dlinear_weight = torch.einsum("bo,bi->oi", dout, y) return ( dx.reshape(ctx.x_shape_og), dnorm_weight, dnorm_bias, dlinear_weight, dlinear_bias, dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, None, None, None, None, ) def layer_norm_linear_fn( x, norm_weight, norm_bias, linear_weight, linear_bias, residual=None, eps=1e-6, prenorm=False, residual_in_fp32=False, is_rms_norm=False, ): return LayerNormLinearFn.apply( x, norm_weight, norm_bias, linear_weight, linear_bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm, ) ================================================ FILE: mamba_ssm/ops/triton/layernorm_gated.py ================================================ # Copyright (c) 2024, Tri Dao. # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. import math import torch import torch.nn.functional as F import triton import triton.language as tl from einops import rearrange def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True): dtype = x.dtype N = x.shape[-1] weight = weight.float() bias = bias.float() if bias is not None else None if upcast: x = x.float() z = z.float() if z is not None else z if z is not None and not norm_before_gate: x = x * F.silu(z) if group_size is None: rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) else: x_group = rearrange(x, "... (g d) -> ... 
g d", d=group_size) rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps) out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight if bias is not None: out = out + bias if z is not None and norm_before_gate: out *= F.silu(z) return out.to(dtype) @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) @triton.jit def _layer_norm_fwd_1pass_kernel( X, # pointer to the input Y, # pointer to the output W, # pointer to the weights B, # pointer to the biases Z, # pointer to the other branch Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_z_row, M, # number of rows in X N, # number of columns in X eps, # epsilon to avoid division by zero BLOCK_N: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_Z: tl.constexpr, NORM_BEFORE_GATE: tl.constexpr, IS_RMS_NORM: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. row = tl.program_id(0) group = tl.program_id(1) X += row * stride_x_row + group * N Y += row * stride_y_row + group * N if HAS_Z: Z += row * stride_z_row + group * N if not IS_RMS_NORM: Mean += group * M Rstd += group * M W += group * N if HAS_BIAS: B += group * N # Compute mean and variance cols = tl.arange(0, BLOCK_N) x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) if HAS_Z and not NORM_BEFORE_GATE: z = tl.load(Z + cols, mask=cols < N).to(tl.float32) x *= z * tl.sigmoid(z) if not IS_RMS_NORM: mean = tl.sum(x, axis=0) / N tl.store(Mean + row, mean) xbar = tl.where(cols < N, x - mean, 0.) var = tl.sum(xbar * xbar, axis=0) / N else: xbar = tl.where(cols < N, x, 0.) 
var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) tl.store(Rstd + row, rstd) # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) if HAS_BIAS: b = tl.load(B + cols, mask=mask).to(tl.float32) x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd y = x_hat * w + b if HAS_BIAS else x_hat * w if HAS_Z and NORM_BEFORE_GATE: z = tl.load(Z + cols, mask=mask).to(tl.float32) y *= z * tl.sigmoid(z) # Write output tl.store(Y + cols, y, mask=mask) def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False): M, N = x.shape if group_size is None: group_size = N assert N % group_size == 0 ngroups = N // group_size assert x.stride(-1) == 1 if z is not None: assert z.stride(-1) == 1 assert z.shape == (M, N) assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 assert bias.shape == (N,) # allocate output if out is not None: assert out.shape == x.shape else: out = torch.empty_like(x) assert out.stride(-1) == 1 mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) if group_size > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # heuristics for number of warps num_warps = min(max(BLOCK_N // 256, 1), 8) grid = (M, ngroups) with torch.cuda.device(x.device.index): _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd, x.stride(0), out.stride(0), z.stride(0) if z is not None else 0, M, group_size, eps, BLOCK_N=BLOCK_N, NORM_BEFORE_GATE=norm_before_gate, IS_RMS_NORM=is_rms_norm, num_warps=num_warps) return out, mean, rstd @triton.heuristics({"HAS_BIAS": lambda args: 
args["B"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) @triton.jit def _layer_norm_bwd_kernel( X, # pointer to the input W, # pointer to the weights B, # pointer to the biases Z, # pointer to the other branch Y, # pointer to the output to be recomputed DY, # pointer to the output gradient DX, # pointer to the input gradient DW, # pointer to the partial sum of weights gradient DB, # pointer to the partial sum of biases gradient DZ, # pointer to the other branch Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_z_row, stride_y_row, stride_dy_row, stride_dx_row, stride_dz_row, stride_dw_row, stride_db_row, M, # number of rows in X N, # number of columns in X eps, # epsilon to avoid division by zero rows_per_program, NORM_BEFORE_GATE: tl.constexpr, IS_RMS_NORM: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_Z: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr, BLOCK_N: tl.constexpr, ): # Map the program id to the elements of X, DX, and DY it should compute. 
row_block_id = tl.program_id(0) group = tl.program_id(1) row_start = row_block_id * rows_per_program cols = tl.arange(0, BLOCK_N) mask = cols < N X += row_start * stride_x_row + group * N if HAS_Z: Z += row_start * stride_z_row + group * N DZ += row_start * stride_dz_row + group * N DY += row_start * stride_dy_row + group * N DX += row_start * stride_dx_row + group * N if RECOMPUTE_OUTPUT: Y += row_start * stride_y_row + group * N if not IS_RMS_NORM: Mean += group * M Rstd += group * M W += group * N w = tl.load(W + cols, mask=mask).to(tl.float32) if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS: B += group * N b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32) dw = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_BIAS: db = tl.zeros((BLOCK_N,), dtype=tl.float32) row_end = min((row_block_id + 1) * rows_per_program, M) for row in range(row_start, row_end): # Load data to SRAM x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) if not IS_RMS_NORM: mean = tl.load(Mean + row) if HAS_Z and not NORM_BEFORE_GATE: z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32) x_og = x x = x_og * z * tl.sigmoid(z) rstd = tl.load(Rstd + row) # Compute dx xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd xhat = tl.where(mask, xhat, 0.) 
if HAS_Z and NORM_BEFORE_GATE: z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32) z_sigmoid = tl.sigmoid(z) y = xhat * w + b if HAS_BIAS else xhat * w if RECOMPUTE_OUTPUT: tl.store(Y + cols, y * z * z_sigmoid, mask=mask) dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid)) tl.store(DZ + cols, dz, mask=mask) dy *= z * z_sigmoid else: if RECOMPUTE_OUTPUT: y = xhat * w + b if HAS_BIAS else xhat * w tl.store(Y + cols, y, mask=mask) wdy = w * dy c1 = tl.sum(xhat * wdy, axis=0) / N if not IS_RMS_NORM: c2 = tl.sum(wdy, axis=0) / N dx = (wdy - (xhat * c1 + c2)) * rstd else: dx = (wdy - xhat * c1) * rstd dw += dy * xhat if HAS_BIAS: db += dy if HAS_Z and not NORM_BEFORE_GATE: z_sigmoid = tl.sigmoid(z) dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid)) tl.store(DZ + cols, dz, mask=mask) dx *= z * z_sigmoid # Write dx tl.store(DX + cols, dx, mask=mask) X += stride_x_row if HAS_Z: Z += stride_z_row DZ += stride_dz_row if RECOMPUTE_OUTPUT: Y += stride_y_row DY += stride_dy_row DX += stride_dx_row tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask) if HAS_BIAS: tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask) def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None, norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None): M, N = x.shape if group_size is None: group_size = N assert N % group_size == 0 ngroups = N // group_size assert x.stride(-1) == 1 assert dy.stride(-1) == 1 assert dy.shape == (M, N) if z is not None: assert z.stride(-1) == 1 assert z.shape == (M, N) assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 assert bias.shape == (N,) # allocate output dx = torch.empty_like(x) if dz is not None: assert z is not None assert dz.shape == z.shape assert dz.stride(-1) == 1 else: dz = torch.empty_like(z) if z is not None else None if recompute_output: if out is None: out = torch.empty_like(x) assert 
out.shape == x.shape # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) if group_size > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # heuristics for number of warps num_warps = min(max(BLOCK_N // 256, 1), 8) sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs # would limit the occupancy. nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups) _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device) _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None rows_per_program = math.ceil(M / nrow_groups) grid = (nrow_groups, ngroups) with torch.cuda.device(x.device.index): _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None, dy, dx, _dw, _db, dz, mean, rstd, x.stride(0), z.stride(0) if z is not None else 0, 0 if not recompute_output else out.stride(0), dy.stride(0), dx.stride(0), dz.stride(0) if dz is not None else 0, _dw.stride(0), _db.stride(0) if _db is not None else 0, M, group_size, eps, rows_per_program, BLOCK_N=BLOCK_N, NORM_BEFORE_GATE=norm_before_gate, IS_RMS_NORM=is_rms_norm, num_warps=num_warps) dw = _dw.sum(0).to(weight.dtype) db = _db.sum(0).to(bias.dtype) if bias is not None else None return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out) class LayerNormFn(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False): """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) """ x_shape_og = x.shape # reshape input data into 2D tensor x = x.reshape(-1, x.shape[-1]) if x.stride(-1) != 1: x = x.contiguous() if z is not None: assert z.shape == 
x_shape_og z = z.reshape(-1, z.shape[-1]) if z.stride(-1) != 1: z = z.contiguous() weight = weight.contiguous() if bias is not None: bias = bias.contiguous() y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm) ctx.save_for_backward(x, weight, bias, mean, rstd, z) ctx.x_shape_og = x_shape_og ctx.eps = eps ctx.group_size = group_size ctx.norm_before_gate = norm_before_gate ctx.is_rms_norm = is_rms_norm return y.reshape(x_shape_og) @staticmethod def backward(ctx, dy): x, weight, bias, mean, rstd, z = ctx.saved_tensors dy = dy.reshape(-1, dy.shape[-1]) if dy.stride(-1) != 1: dy = dy.contiguous() assert dy.shape == x.shape dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size, ctx.norm_before_gate, ctx.is_rms_norm) return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False): return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm) def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True): return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True) class LayerNorm(torch.nn.Module): def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None): """If group_size is not None, we do GroupNorm with each group having group_size elements. group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). 
""" factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.group_size = group_size self.norm_before_gate = norm_before_gate self.reset_parameters() def reset_parameters(self): torch.nn.init.ones_(self.weight) torch.nn.init.zeros_(self.bias) def forward(self, x, z=None): """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) """ return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps, norm_before_gate=self.norm_before_gate) class RMSNorm(torch.nn.Module): def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None): """If group_size is not None, we do GroupNorm with each group having group_size elements. group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). """ factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.group_size = group_size self.norm_before_gate = norm_before_gate self.reset_parameters() def reset_parameters(self): torch.nn.init.ones_(self.weight) def forward(self, x, z=None): """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) """ return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size, norm_before_gate=self.norm_before_gate) ================================================ FILE: mamba_ssm/ops/triton/mamba3/angle_dt.py ================================================ from typing import Tuple, Optional import torch from torch import Tensor import triton import triton.language as tl from mamba_ssm.ops.triton.mamba3.utils import tanh_approx, sech2_approx # 
----------------------------------------------------------------------------- # Forward kernel # ----------------------------------------------------------------------------- @triton.autotune( configs=[ triton.Config({}, num_stages=s, num_warps=w) for s in [1, 2, 3] for w in [2, 4, 8] ], key=["CHUNK_SIZE", "BLOCK_D", "HAS_INIT_STATE", "RETURN_OUTPUT_STATE", "IS_VARLEN"], ) @triton.jit def angle_dt_fwd_kernel( # Outputs OUT, OUTPUT_STATE, # Inputs ANGLE, DT, INIT_STATE, CU_SEQLENS, # Strides for OUT (batch, seqlen, nheads, dim) stride_out_batch, stride_out_seq, stride_out_head, stride_out_dim, # Strides for OUTPUT_STATE (num_sequences, nheads, dim) stride_output_state_seq, stride_output_state_head, stride_output_state_dim, # Strides for ANGLE (batch, seqlen, nheads, dim) stride_angle_batch, stride_angle_seq, stride_angle_head, stride_angle_dim, # Strides for DT (batch, nheads, seqlen) stride_dt_batch, stride_dt_head, stride_dt_seq, # Strides for INIT_STATE (num_sequences, nheads, dim) stride_init_seq, stride_init_head, stride_init_dim, # Stride for CU_SEQLENS stride_cu_seqlen, # Dimensions seqlen, dim, # Meta-parameters CHUNK_SIZE: tl.constexpr, BLOCK_D: tl.constexpr, HAS_INIT_STATE: tl.constexpr, RETURN_OUTPUT_STATE: tl.constexpr, IS_VARLEN: tl.constexpr, ): pid_h = tl.program_id(0) pid_b = tl.program_id(1) # Handle varlen mode if IS_VARLEN: pid_seq = tl.program_id(2) seq_idx = pid_seq cu_seqlen_start = tl.load(CU_SEQLENS + pid_seq * stride_cu_seqlen).to(tl.int32) cu_seqlen_end = tl.load(CU_SEQLENS + (pid_seq + 1) * stride_cu_seqlen).to(tl.int32) seq_len = cu_seqlen_end - cu_seqlen_start seq_offset = cu_seqlen_start else: seq_idx = pid_b seq_len = seqlen seq_offset = 0 nchunks = tl.cdiv(seq_len, CHUNK_SIZE) # Offset base pointers by batch and head ANGLE += pid_b * stride_angle_batch + pid_h * stride_angle_head + seq_offset * stride_angle_seq DT += pid_b * stride_dt_batch + pid_h * stride_dt_head + seq_offset * stride_dt_seq OUT += pid_b * stride_out_batch + pid_h * 
stride_out_head + seq_offset * stride_out_seq dim_range = tl.arange(0, BLOCK_D) dim_mask = dim_range < dim # Initialize state from init_state or zeros if HAS_INIT_STATE: init_ptrs = INIT_STATE + seq_idx * stride_init_seq + pid_h * stride_init_head + dim_range * stride_init_dim state = tl.load(init_ptrs, mask=dim_mask, other=0.0).to(tl.float32) else: state = tl.zeros((BLOCK_D,), dtype=tl.float32) PI = 3.141592653589793 TWO_PI = 2 * PI for chunk_idx in range(nchunks): chunk_start = chunk_idx * CHUNK_SIZE seq_range = tl.arange(0, CHUNK_SIZE) seq_mask = (chunk_start + seq_range) < seq_len # Load angle (CHUNK_SIZE, BLOCK_D) angle_ptrs = ANGLE + (chunk_start + seq_range[:, None]) * stride_angle_seq + dim_range[None, :] * stride_angle_dim angle_vals = tl.load(angle_ptrs, mask=seq_mask[:, None] & dim_mask[None, :], other=0.0).to(tl.float32) angle_vals = tanh_approx(angle_vals) * PI # Load dt (CHUNK_SIZE,) dt_ptrs = DT + (chunk_start + seq_range) * stride_dt_seq dt_vals = tl.load(dt_ptrs, mask=seq_mask, other=0.0).to(tl.float32) # Compute vals = angle * dt vals = angle_vals * dt_vals[:, None] # Cumsum within chunk + add state from previous chunks chunk_cumsum = tl.cumsum(vals, axis=0) out_vals = chunk_cumsum + state[None, :] # Apply mod 2*pi for rotary angle normalization out_vals = out_vals - TWO_PI * tl.floor(out_vals / TWO_PI) # Store output out_ptrs = OUT + (chunk_start + seq_range[:, None]) * stride_out_seq + dim_range[None, :] * stride_out_dim tl.store(out_ptrs, out_vals, mask=seq_mask[:, None] & dim_mask[None, :]) # Update state: add chunk sum and apply mod 2*pi chunk_sum = tl.sum(vals, axis=0) state = state + chunk_sum state = state - TWO_PI * tl.floor(state / TWO_PI) # Store final state if requested if RETURN_OUTPUT_STATE: output_state_ptrs = OUTPUT_STATE + seq_idx * stride_output_state_seq + pid_h * stride_output_state_head + dim_range * stride_output_state_dim tl.store(output_state_ptrs, state, mask=dim_mask) def angle_dt_fwd( angle: Tensor, dt: Tensor, 
init_state: Optional[Tensor] = None, chunk_size: int = 64, return_output_state: bool = False, cu_seqlens: Optional[Tensor] = None, ) -> Tensor | Tuple[Tensor, Tensor]: """Forward pass for angle * dt cumsum. Args: angle: Angle tensor (batch, seqlen, nheads, dim) dt: Time delta tensor (batch, nheads, seqlen) init_state: Initial state (num_sequences, nheads, dim) or None chunk_size: Chunk size for chunked computation return_output_state: Whether to return final state cu_seqlens: Cumulative sequence lengths (num_sequences + 1,) for varlen mode Returns: If return_output_state=False: out: Cumulative output (batch, seqlen, nheads, dim) If return_output_state=True: Tuple of: out: Cumulative output (batch, seqlen, nheads, dim) output_state: Final state (num_sequences, nheads, dim) """ batch, seqlen, nheads, dim = angle.shape is_varlen = cu_seqlens is not None # Determine number of sequences if is_varlen: assert batch == 1, "Varlen mode requires batch=1" num_sequences = cu_seqlens.shape[0] - 1 else: num_sequences = batch assert dt.shape == (batch, nheads, seqlen), f"dt shape mismatch: {dt.shape}" if init_state is not None: assert init_state.shape == (num_sequences, nheads, dim), f"init_state shape mismatch: {init_state.shape}" out = torch.empty_like(angle) BLOCK_D = triton.next_power_of_2(dim) # Handle None init_state for kernel HAS_INIT_STATE = init_state is not None if not HAS_INIT_STATE: init_state = angle # dummy, won't be accessed stride_init = (0, 0, 0) else: stride_init = init_state.stride() # Handle output_state if return_output_state: output_state = torch.empty(num_sequences, nheads, dim, device=angle.device, dtype=angle.dtype) stride_output_state = output_state.stride() else: output_state = out # dummy, won't be accessed stride_output_state = (0, 0, 0) # Handle cu_seqlens if cu_seqlens is not None: stride_cu_seqlen = cu_seqlens.stride(0) else: cu_seqlens = angle # dummy, won't be accessed stride_cu_seqlen = 0 # Grid setup if is_varlen: grid = (nheads, batch, 
num_sequences) else: grid = (nheads, batch) angle_dt_fwd_kernel[grid]( out, output_state, angle, dt, init_state, cu_seqlens, out.stride(0), out.stride(1), out.stride(2), out.stride(3), stride_output_state[0], stride_output_state[1], stride_output_state[2], angle.stride(0), angle.stride(1), angle.stride(2), angle.stride(3), dt.stride(0), dt.stride(1), dt.stride(2), stride_init[0], stride_init[1], stride_init[2], stride_cu_seqlen, seqlen, dim, CHUNK_SIZE=chunk_size, BLOCK_D=BLOCK_D, HAS_INIT_STATE=HAS_INIT_STATE, RETURN_OUTPUT_STATE=return_output_state, IS_VARLEN=is_varlen, ) if return_output_state: return out, output_state return out # ----------------------------------------------------------------------------- # Backward kernel # ----------------------------------------------------------------------------- @triton.autotune( configs=[ triton.Config({}, num_stages=s, num_warps=w) for s in [1, 2, 3] for w in [2, 4, 8] ], key=["CHUNK_SIZE", "BLOCK_D", "HAS_INIT_STATE", "HAS_GRAD_OUTPUT_STATE", "IS_VARLEN"], ) @triton.jit def angle_dt_bwd_kernel( # Outputs GRAD_ANGLE, GRAD_DT, GRAD_INIT_STATE, # Inputs GRAD_OUT, GRAD_OUTPUT_STATE, ANGLE, DT, CU_SEQLENS, # Strides for GRAD_ANGLE (batch, seqlen, nheads, dim) stride_grad_angle_batch, stride_grad_angle_seq, stride_grad_angle_head, stride_grad_angle_dim, # Strides for GRAD_DT (batch, nheads, seqlen) stride_grad_dt_batch, stride_grad_dt_head, stride_grad_dt_seq, # Strides for GRAD_INIT_STATE (num_sequences, nheads, dim) stride_grad_init_seq, stride_grad_init_head, stride_grad_init_dim, # Strides for GRAD_OUT (batch, seqlen, nheads, dim) stride_grad_out_batch, stride_grad_out_seq, stride_grad_out_head, stride_grad_out_dim, # Strides for GRAD_OUTPUT_STATE (num_sequences, nheads, dim) stride_grad_output_state_seq, stride_grad_output_state_head, stride_grad_output_state_dim, # Strides for ANGLE (batch, seqlen, nheads, dim) stride_angle_batch, stride_angle_seq, stride_angle_head, stride_angle_dim, # Strides for DT (batch, nheads, 
seqlen) stride_dt_batch, stride_dt_head, stride_dt_seq, # Stride for CU_SEQLENS stride_cu_seqlen, # Dimensions seqlen, dim, # Meta-parameters CHUNK_SIZE: tl.constexpr, BLOCK_D: tl.constexpr, HAS_INIT_STATE: tl.constexpr, HAS_GRAD_OUTPUT_STATE: tl.constexpr, IS_VARLEN: tl.constexpr, ): pid_h = tl.program_id(0) pid_b = tl.program_id(1) # Handle varlen mode if IS_VARLEN: pid_seq = tl.program_id(2) seq_idx = pid_seq cu_seqlen_start = tl.load(CU_SEQLENS + pid_seq * stride_cu_seqlen).to(tl.int32) cu_seqlen_end = tl.load(CU_SEQLENS + (pid_seq + 1) * stride_cu_seqlen).to(tl.int32) seq_len = cu_seqlen_end - cu_seqlen_start seq_offset = cu_seqlen_start else: seq_idx = pid_b seq_len = seqlen seq_offset = 0 nchunks = tl.cdiv(seq_len, CHUNK_SIZE) # Offset base pointers by batch and head GRAD_ANGLE += pid_b * stride_grad_angle_batch + pid_h * stride_grad_angle_head + seq_offset * stride_grad_angle_seq GRAD_DT += pid_b * stride_grad_dt_batch + pid_h * stride_grad_dt_head + seq_offset * stride_grad_dt_seq GRAD_OUT += pid_b * stride_grad_out_batch + pid_h * stride_grad_out_head + seq_offset * stride_grad_out_seq ANGLE += pid_b * stride_angle_batch + pid_h * stride_angle_head + seq_offset * stride_angle_seq DT += pid_b * stride_dt_batch + pid_h * stride_dt_head + seq_offset * stride_dt_seq dim_range = tl.arange(0, BLOCK_D) dim_mask = dim_range < dim PI = 3.141592653589793 # Initialize gradient state from grad_output_state or zeros if HAS_GRAD_OUTPUT_STATE: grad_output_state_ptrs = GRAD_OUTPUT_STATE + seq_idx * stride_grad_output_state_seq + pid_h * stride_grad_output_state_head + dim_range * stride_grad_output_state_dim grad_state = tl.load(grad_output_state_ptrs, mask=dim_mask, other=0.0).to(tl.float32) else: grad_state = tl.zeros((BLOCK_D,), dtype=tl.float32) # Loop in reverse: derivative of cumsum is reverse cumsum for chunk_idx in range(nchunks - 1, -1, -1): chunk_start = chunk_idx * CHUNK_SIZE seq_range = tl.arange(0, CHUNK_SIZE) seq_mask = (chunk_start + seq_range) < seq_len # 
Load grad_out (CHUNK_SIZE, BLOCK_D) grad_out_ptrs = GRAD_OUT + (chunk_start + seq_range[:, None]) * stride_grad_out_seq + dim_range[None, :] * stride_grad_out_dim grad_out_vals = tl.load(grad_out_ptrs, mask=seq_mask[:, None] & dim_mask[None, :], other=0.0).to(tl.float32) # Reverse cumsum within chunk: rev_cumsum = total - cumsum + x # But we need to handle the mask properly for partial chunks chunk_sum = tl.sum(grad_out_vals, axis=0) fwd_cumsum = tl.cumsum(grad_out_vals, axis=0) rev_cumsum = chunk_sum[None, :] - fwd_cumsum + grad_out_vals # Add gradient from future chunks grad_vals = rev_cumsum + grad_state[None, :] # Load angle and dt angle_ptrs = ANGLE + (chunk_start + seq_range[:, None]) * stride_angle_seq + dim_range[None, :] * stride_angle_dim pretanh_angle_vals = tl.load(angle_ptrs, mask=seq_mask[:, None] & dim_mask[None, :], other=0.0).to(tl.float32) angle_vals = tanh_approx(pretanh_angle_vals) * PI dt_ptrs = DT + (chunk_start + seq_range) * stride_dt_seq dt_vals = tl.load(dt_ptrs, mask=seq_mask, other=0.0).to(tl.float32) # Compute gradients: out = angle * dt grad_angle_vals = grad_vals * dt_vals[:, None] * PI * sech2_approx(pretanh_angle_vals) grad_dt_vals = tl.sum(grad_vals * angle_vals, axis=1) # Store gradients grad_angle_ptrs = GRAD_ANGLE + (chunk_start + seq_range[:, None]) * stride_grad_angle_seq + dim_range[None, :] * stride_grad_angle_dim tl.store(grad_angle_ptrs, grad_angle_vals, mask=seq_mask[:, None] & dim_mask[None, :]) grad_dt_ptrs = GRAD_DT + (chunk_start + seq_range) * stride_grad_dt_seq tl.store(grad_dt_ptrs, grad_dt_vals, mask=seq_mask) # Update state for previous chunk grad_state = grad_state + chunk_sum # Store gradient for init_state if provided if HAS_INIT_STATE: grad_init_ptrs = GRAD_INIT_STATE + seq_idx * stride_grad_init_seq + pid_h * stride_grad_init_head + dim_range * stride_grad_init_dim tl.store(grad_init_ptrs, grad_state, mask=dim_mask) def angle_dt_bwd( grad_out: Tensor, angle: Tensor, dt: Tensor, has_init_state: bool = False, 
chunk_size: int = 64, grad_output_state: Optional[Tensor] = None, cu_seqlens: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: """Backward pass for angle * dt cumsum. Args: grad_out: Gradient of output (batch, seqlen, nheads, dim) angle: Angle tensor (batch, seqlen, nheads, dim) dt: Time delta tensor (batch, nheads, seqlen) has_init_state: Whether init_state was provided in forward chunk_size: Chunk size for chunked computation grad_output_state: Gradient of output state (num_sequences, nheads, dim) or None cu_seqlens: Cumulative sequence lengths (num_sequences + 1,) for varlen mode Returns: grad_angle: Gradient for angle (batch, seqlen, nheads, dim) grad_dt: Gradient for dt (batch, nheads, seqlen) grad_init_state: Gradient for init_state (num_sequences, nheads, dim) or None """ batch, seqlen, nheads, dim = angle.shape is_varlen = cu_seqlens is not None # Determine number of sequences if is_varlen: assert batch == 1, "Varlen mode requires batch=1" num_sequences = cu_seqlens.shape[0] - 1 else: num_sequences = batch grad_angle = torch.empty_like(angle) grad_dt = torch.empty_like(dt) BLOCK_D = triton.next_power_of_2(dim) # Handle init_state gradient if has_init_state: grad_init_state = torch.empty(num_sequences, nheads, dim, device=angle.device, dtype=torch.float32) stride_grad_init = grad_init_state.stride() else: grad_init_state = None stride_grad_init = (0, 0, 0) grad_init_dummy = grad_angle # dummy pointer # Handle grad_output_state HAS_GRAD_OUTPUT_STATE = grad_output_state is not None if not HAS_GRAD_OUTPUT_STATE: grad_output_state = grad_angle # dummy, won't be accessed stride_grad_output_state = (0, 0, 0) else: stride_grad_output_state = grad_output_state.stride() # Handle cu_seqlens if cu_seqlens is not None: stride_cu_seqlen = cu_seqlens.stride(0) else: cu_seqlens = angle # dummy, won't be accessed stride_cu_seqlen = 0 # Grid setup if is_varlen: grid = (nheads, batch, num_sequences) else: grid = (nheads, batch) angle_dt_bwd_kernel[grid]( 
grad_angle, grad_dt, grad_init_state if has_init_state else grad_init_dummy, grad_out, grad_output_state, angle, dt, cu_seqlens, grad_angle.stride(0), grad_angle.stride(1), grad_angle.stride(2), grad_angle.stride(3), grad_dt.stride(0), grad_dt.stride(1), grad_dt.stride(2), stride_grad_init[0], stride_grad_init[1], stride_grad_init[2], grad_out.stride(0), grad_out.stride(1), grad_out.stride(2), grad_out.stride(3), stride_grad_output_state[0], stride_grad_output_state[1], stride_grad_output_state[2], angle.stride(0), angle.stride(1), angle.stride(2), angle.stride(3), dt.stride(0), dt.stride(1), dt.stride(2), stride_cu_seqlen, seqlen, dim, CHUNK_SIZE=chunk_size, BLOCK_D=BLOCK_D, HAS_INIT_STATE=has_init_state, HAS_GRAD_OUTPUT_STATE=HAS_GRAD_OUTPUT_STATE, IS_VARLEN=is_varlen, ) return grad_angle, grad_dt, grad_init_state ================================================ FILE: mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py ================================================ # Copyright (c) 2025, Tri Dao. # We need a pretty recent version of triton to support tuples. 3.3 definitely will work, # idk which is the minimum version. 
import math from typing import Optional, Tuple import torch import triton import triton.language as tl import triton.testing #from flash_attn.cute.benchmark import pytorch_profiler @triton.jit def rotary_qk_inference_kernel( OUT_Q, # Pointers to matrices OUT_K, OUT_ANGLE_STATE, Q, K, ANGLE_STATE, ANGLE_PROJ, DT, BIAS_Q, BIAS_K, nheads, headdim, stride_out_q, # (batch, mimo_dim, nheads, headdim) stride_out_k, # (batch, mimo_dim, nheads, headdim) stride_out_angle_state, # (batch, nheads, rotary_dim // 2) stride_q, # (batch, mimo_dim, nheads, headdim) stride_k, # (batch, mimo_dim, nheads, headdim) stride_angle_state, # (batch, nheads, rotary_dim // 2) stride_angle_proj, # (batch, nheads, rotary_dim // 2) stride_dt, # (batch, nheads) stride_bias_q, # (mimo_dim, nheads, headdim) stride_bias_k, # (mimo_dim, nheads, headdim) # Meta-parameters ROTARY_DIM: tl.constexpr, CONJUGATE: tl.constexpr, HAS_BIAS_Q: tl.constexpr, HAS_BIAS_K: tl.constexpr, MIMO_DIM: tl.constexpr, BLOCK_D: tl.constexpr, # headdim, no chunking ROTATE_PAIRWISE: tl.constexpr, # If true, rotate every pair of dimensions together. 
Otherwise, rotate the first half and second half separately (like in the original RoPE paper) ): pid_nheads = tl.program_id(axis=0) # heads pid_batch = tl.program_id(axis=1) Q = Q + pid_batch * stride_q[0] + pid_nheads * stride_q[2] K = K + pid_batch * stride_k[0] + pid_nheads * stride_k[2] ANGLE_STATE = ANGLE_STATE + pid_batch * stride_angle_state[0] + pid_nheads * stride_angle_state[1] # FIX: [1] ANGLE_PROJ = ANGLE_PROJ + pid_batch * stride_angle_proj[0] + pid_nheads * stride_angle_proj[1] # FIX: [1] DT = DT + pid_batch * stride_dt[0] + pid_nheads * stride_dt[1] OUT_Q = OUT_Q + pid_batch * stride_out_q[0] + pid_nheads * stride_out_q[2] OUT_K = OUT_K + pid_batch * stride_out_k[0] + pid_nheads * stride_out_k[2] OUT_ANGLE_STATE = OUT_ANGLE_STATE + pid_batch * stride_out_angle_state[0] + pid_nheads * stride_out_angle_state[1] # FIX: [1] rm = tl.arange(0, MIMO_DIM) rd = tl.arange(0, BLOCK_D) rd_half = tl.arange(0, BLOCK_D // 2) # Load angle and compute cos/sin (same for both q and k) ANGLE_STATE = ANGLE_STATE + rd_half * stride_angle_state[2] # (rotary_dim // 2) mask_angle = rd_half < ROTARY_DIM // 2 angle_state = tl.load(ANGLE_STATE, mask=mask_angle, other=0.0).to(tl.float32) ANGLE_PROJ = ANGLE_PROJ + rd_half * stride_angle_proj[2] # (rotary_dim // 2) angle_proj = tl.load(ANGLE_PROJ, mask=mask_angle, other=0.0).to(tl.float32) dt = tl.load(DT, mask=True, other=0.0).to(tl.float32) # Match angle_dt: tanh(angle_proj) * dt * pi angle_proj = tl.sigmoid(2.0 * angle_proj) * 2.0 - 1.0 # tanh angle = angle_state + angle_proj * dt * 3.141592653589793 # (rotary_dim // 2) OUT_ANGLE_STATE = OUT_ANGLE_STATE + rd_half * stride_out_angle_state[2] tl.store(OUT_ANGLE_STATE, angle, mask=mask_angle) angle = angle[None, :] # (1, rotary_dim // 2) for mimo_dim broadcasting cos = tl.cos(angle) sin = tl.sin(angle) if CONJUGATE: sin = -sin # Process Q tensor Q = Q + (rm[:, None] * stride_q[1] + rd[None, :] * stride_q[3]) OUT_Q = OUT_Q + (rm[:, None] * stride_out_q[1] + rd[None, :] * 
stride_out_q[3]) mask = rd[None, :] < headdim q = tl.load(Q, mask=mask, other=0.0).to(tl.float32) # (mimo_dim, headdim) # Add bias to Q if present if HAS_BIAS_Q: BIAS_Q = BIAS_Q + pid_nheads * stride_bias_q[1] BIAS_Q = BIAS_Q + (rm[:, None] * stride_bias_q[0] + rd[None, :] * stride_bias_q[2]) bias_q = tl.load(BIAS_Q, mask=mask, other=0.0).to(tl.float32) q = q + bias_q if ROTATE_PAIRWISE: # Apply rotary to Q q0, q1 = tl.split(tl.reshape(q, [MIMO_DIM, BLOCK_D // 2, 2])) qo0 = q0 * cos - q1 * sin qo1 = q0 * sin + q1 * cos qo = tl.reshape(tl.join(qo0, qo1), [MIMO_DIM, BLOCK_D]) tl.store(OUT_Q, qo, mask=mask) else: # Apply rotary to Q q_reshaped = tl.reshape(q, [MIMO_DIM, 2, BLOCK_D // 2]) q_permuted = tl.permute(q_reshaped, (0, 2, 1)) # (mimo_dim, block_d // 2, 2) q0, q1 = tl.split(q_permuted) qo0 = q0 * cos - q1 * sin qo1 = q0 * sin + q1 * cos q_joined = tl.join(qo0, qo1) q_final = tl.permute(q_joined, (0, 2, 1)) # (mimo_dim, 2, block_d // 2) qo = tl.reshape(q_final, [MIMO_DIM, BLOCK_D]) tl.store(OUT_Q, qo, mask=mask) # Process K tensor K = K + (rm[:, None] * stride_k[1] + rd[None, :] * stride_k[3]) OUT_K = OUT_K + (rm[:, None] * stride_out_k[1] + rd[None, :] * stride_out_k[3]) k = tl.load(K, mask=mask, other=0.0).to(tl.float32) # Add bias to K if present if HAS_BIAS_K: BIAS_K = BIAS_K + pid_nheads * stride_bias_k[1] BIAS_K = BIAS_K + (rm[:, None] * stride_bias_k[0] + rd[None, :] * stride_bias_k[2]) bias_k = tl.load(BIAS_K, mask=mask, other=0.0).to(tl.float32) k = k + bias_k if ROTATE_PAIRWISE: # Apply rotary to K k0, k1 = tl.split(tl.reshape(k, [MIMO_DIM, BLOCK_D // 2, 2])) ko0 = k0 * cos - k1 * sin ko1 = k0 * sin + k1 * cos ko = tl.reshape(tl.join(ko0, ko1), [MIMO_DIM, BLOCK_D]) tl.store(OUT_K, ko, mask=mask) else: # Apply rotary to K k_reshaped = tl.reshape(k, [MIMO_DIM, 2, BLOCK_D // 2]) k_permuted = tl.permute(k_reshaped, (0, 2, 1)) # (mimo_dim, block_d // 2, 2) k0, k1 = tl.split(k_permuted) ko0 = k0 * cos - k1 * sin ko1 = k0 * sin + k1 * cos k_joined = 
tl.join(ko0, ko1) k_final = tl.permute(k_joined, (0, 2, 1)) # (mimo_dim, 2, block_d // 2) ko = tl.reshape(k_final, [MIMO_DIM, BLOCK_D]) tl.store(OUT_K, ko, mask=mask) def apply_rotary_qk_inference_fwd( q: torch.Tensor, k: torch.Tensor, angle_state: torch.Tensor, angle_proj: torch.Tensor, dt: torch.Tensor, bias_q: Optional[torch.Tensor] = None, bias_k: Optional[torch.Tensor] = None, inplace=False, conjugate=False, rotate_pairwise=True, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Apply rotary embedding to both q and k tensors using the same angle. Also computes output angle state for next step. Arguments: q: (batch, mimo_dim, nheads, headdim) k: (batch, mimo_dim, nheads, headdim) angle_state: (batch, nheads, rotary_dim / 2) angle_proj: (batch, nheads, rotary_dim / 2) dt: (batch, nheads) bias_q: Optional (mimo_dim, nheads, headdim) - bias to add to q before rotary bias_k: Optional (mimo_dim, nheads, headdim) - bias to add to k before rotary Returns: (q_out, k_out, angle_state_out): q_out and k_out are (batch, mimo_dim, nheads, headdim), angle_state_out is (batch, nheads, rotary_dim / 2) """ batch, mimo_dim, nheads, headdim = q.shape assert headdim % 2 == 0 assert k.shape == q.shape, f"k shape {k.shape} != q shape {q.shape}" rotary_dim = angle_state.shape[-1] * 2 assert angle_state.shape == (batch, nheads, rotary_dim // 2) assert angle_state.shape == angle_proj.shape assert dt.shape == (batch, nheads) assert rotary_dim <= headdim, "rotary_dim must be <= headdim" assert headdim <= 256, "Only support headdim <= 256" if bias_q is not None: assert bias_q.shape == (mimo_dim, nheads, headdim), f"bias_q shape {bias_q.shape} != (mimo_dim, nheads, headdim) {(mimo_dim, nheads, headdim)}" bias_q = bias_q.contiguous() if bias_k is not None: assert bias_k.shape == (mimo_dim, nheads, headdim), f"bias_k shape {bias_k.shape} != (mimo_dim, nheads, headdim) {(mimo_dim, nheads, headdim)}" bias_k = bias_k.contiguous() output_q = torch.empty_like(q) if not inplace else q 
output_k = torch.empty_like(k) if not inplace else k output_angle_state = torch.empty_like(angle_state) if not inplace else angle_state grid = lambda META: (nheads, batch) # noqa with torch.cuda.device(q.device.index): torch.library.wrap_triton(rotary_qk_inference_kernel)[grid]( output_q, # data ptrs output_k, output_angle_state, q, k, angle_state, angle_proj, dt, bias_q, bias_k, nheads, headdim, output_q.stride(), # output strides tuples output_k.stride(), output_angle_state.stride(), q.stride(), # input strides tuples k.stride(), angle_state.stride(), angle_proj.stride(), dt.stride(), bias_q.stride() if bias_q is not None else (0, 0, 0), bias_k.stride() if bias_k is not None else (0, 0, 0), rotary_dim, conjugate, bias_q is not None, bias_k is not None, MIMO_DIM=mimo_dim, BLOCK_D=triton.next_power_of_2(headdim), num_warps=8, # important, 4 warps is slower if we compute qk_sum ROTATE_PAIRWISE=rotate_pairwise, ) return output_q, output_k, output_angle_state def apply_rotary_qk_inference_reference( q: torch.Tensor, # (B, R, N, D) k: torch.Tensor, # (B, R, N, D) angle_state: torch.Tensor, # (B, N, S) S: num_rope_angles angle_proj: torch.Tensor, # (B, N, S) dt: torch.Tensor, # (B, N) bias_q: Optional[torch.Tensor] = None, # (R, N, D) bias_k: Optional[torch.Tensor] = None, # (R, N, D) conjugate=False, rotate_pairwise=True, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Reference PyTorch implementation for QK rotary embedding with qk_sum.""" batch, mimo_dim, nheads, headdim = q.shape rotary_dim = angle_state.shape[-1] * 2 # Match angle_dt: tanh(angle_proj) * dt * pi angle_proj = torch.tanh(angle_proj) angle = angle_state + angle_proj * dt[:, :, None] * math.pi # (B, N, S) angle_state_new = angle angle = angle.unsqueeze(1).expand(-1, mimo_dim, -1, -1) # (B, R, N, S) # Add biases if present if bias_q is not None: q = q + bias_q[None, :, :, :] # Broadcast bias_q if bias_k is not None: k = k + bias_k[None, :, :, :] # Broadcast bias_k # Only apply rotary to the 
rotary dimensions q_rot = q[..., :rotary_dim] # (B, R, N, rotary_dim) q_pass = q[..., rotary_dim:] k_rot = k[..., :rotary_dim] k_pass = k[..., rotary_dim:] # Compute cos and sin from angle (same for both q and k) cos = torch.cos(angle) # (B, N, S) sin = torch.sin(angle) if conjugate: sin = -sin if rotate_pairwise: # Interleaved rotary: pairs are (x0,x1), (x2,x3), ... q_rot = q_rot.reshape(batch, mimo_dim, nheads, rotary_dim // 2, 2) q0, q1 = q_rot[..., 0], q_rot[..., 1] k_rot = k_rot.reshape(batch, mimo_dim, nheads, rotary_dim // 2, 2) k0, k1 = k_rot[..., 0], k_rot[..., 1] qo0 = q0 * cos - q1 * sin qo1 = q0 * sin + q1 * cos ko0 = k0 * cos - k1 * sin ko1 = k0 * sin + k1 * cos qout_rot = torch.stack([qo0, qo1], dim=-1).reshape(batch, mimo_dim, nheads, rotary_dim) kout_rot = torch.stack([ko0, ko1], dim=-1).reshape(batch, mimo_dim, nheads, rotary_dim) # Concatenate rotated and pass-through dimensions if rotary_dim < headdim: q_out = torch.cat([qout_rot, q_pass], dim=-1) k_out = torch.cat([kout_rot, k_pass], dim=-1) else: q_out = qout_rot k_out = kout_rot else: # Halved rotary: split full headdim in half, pairs are (dim_i, dim_{i+D/2}) # Matches kernel which splits BLOCK_D in half; cos(0)=1/sin(0)=0 gives identity # for pairs beyond rotary_dim//2 half = headdim // 2 q0, q1 = q[..., :half], q[..., half:] k0, k1 = k[..., :half], k[..., half:] # Pad cos/sin from rotary_dim//2 to headdim//2 with cos=1, sin=0 rdim_half = rotary_dim // 2 if half > rdim_half: pad_shape = list(cos.shape) pad_shape[-1] = half - rdim_half cos = torch.cat([cos, torch.ones(pad_shape, device=cos.device, dtype=cos.dtype)], dim=-1) sin = torch.cat([sin, torch.zeros(pad_shape, device=sin.device, dtype=sin.dtype)], dim=-1) qo0 = q0 * cos - q1 * sin qo1 = q0 * sin + q1 * cos ko0 = k0 * cos - k1 * sin ko1 = k0 * sin + k1 * cos q_out = torch.cat([qo0, qo1], dim=-1) k_out = torch.cat([ko0, ko1], dim=-1) return q_out, k_out, angle_state_new def test_correctness_qk_inference(): print("Testing QK Inference 
correctness...") device = "cuda" torch.manual_seed(2025) dtype_qk = torch.bfloat16 # common inference dtype dtype_ang = torch.float32 def run_case(B, R, N, D, RD, with_bias, conjugate, expanded_heads, rotate_pairwise): assert D % 2 == 0 # Build q,k with optional head broadcasting q0 = torch.randn(B, R, 1 if expanded_heads else N, D, device=device, dtype=dtype_qk) k0 = torch.randn(B, R, 1 if expanded_heads else N, D, device=device, dtype=dtype_qk) q = q0.expand(B, R, N, D) if expanded_heads else q0 k = k0.expand(B, R, N, D) if expanded_heads else k0 angle_state = torch.randn(B, N, RD // 2, device=device, dtype=dtype_ang) angle_proj = torch.randn(B, N, RD // 2, device=device, dtype=dtype_ang) dt = torch.randn(B, N, device=device, dtype=dtype_ang) bias_q = torch.randn(R, N, D, device=device, dtype=dtype_qk) if with_bias else None bias_k = torch.randn(R, N, D, device=device, dtype=dtype_qk) if with_bias else None # Reference q_ref, k_ref, updated_angle_ref = apply_rotary_qk_inference_reference( q, k, angle_state, angle_proj, dt, bias_q=bias_q, bias_k=bias_k, conjugate=conjugate, rotate_pairwise=rotate_pairwise, ) # Kernel q_out, k_out, updated_angle = apply_rotary_qk_inference_fwd( q, k, angle_state, angle_proj, dt, bias_q=bias_q, bias_k=bias_k, conjugate=conjugate, inplace=False, rotate_pairwise=rotate_pairwise, ) def _chk(name, a, b, atol=1e-1, rtol=1e-1): diff = (a - b).abs().max().item() if not torch.allclose(a, b, atol=atol, rtol=rtol): raise AssertionError(f"{name} mismatch: max|Δ|={diff:.3e} got={tuple(a.shape)} ref={tuple(b.shape)}") print(f" {name:18s} ok max|Δ|={diff:.2e}") print(f"\nInference [{B=}, {R=}, {N=}, {D=}, {RD=} | bias={with_bias}, conj={conjugate}, expanded={expanded_heads}, pairwise={rotate_pairwise}]") _chk("q_out", q_out.float(), q_ref.float(), atol=1e-1, rtol=1e-1) _chk("k_out", k_out.float(), k_ref.float(), atol=1e-1, rtol=1e-1) _chk("updated_angle", updated_angle, updated_angle_ref, atol=1e-1, rtol=1e-1) # standard config B, R, N, D, RD = 
2, 4, 64, 128, 64 for with_bias in [False, True]: for conjugate in [False, True]: for expanded in [True, False]: for pairwise in [True, False]: run_case(B, R, N, D, RD, with_bias, conjugate, expanded, pairwise) # light shape sweep for (BB, RR, NN, DD, RRd) in [ (1, 2, 64, 64, 32), (3, 1, 32, 128, 64), (2, 8, 32, 128, 64), ]: for pairwise in [True, False]: run_case(BB, RR, NN, DD, RRd, with_bias=True, conjugate=False, expanded_heads=True, rotate_pairwise=pairwise) run_case(BB, RR, NN, DD, RRd, with_bias=False, conjugate=True, expanded_heads=False, rotate_pairwise=pairwise) print("\nAll QK Inference tests passed! ✓") if __name__ == "__main__": test_correctness_qk_inference() ================================================ FILE: mamba_ssm/ops/triton/mamba3/mamba3_mimo_utils.py ================================================ """ Fused Triton kernels for Mamba3 backward pass ddt computation. This module implements fused kernels that combine three separate backward operations: 1. bwd_segsum_ddt_from_dSSdA - Complex 2D segsum operation 2. bwd_ddt_from_ddA_cs_rev - Forward exclusive cumsum operation 3. bwd_ddt_from_ddA_cs - Reverse cumsum operation The fusion reduces memory traffic and kernel launch overhead. 
""" import torch import triton import triton.language as tl import math from typing import Optional, Tuple # Constants LOG2 = math.log(2.0) NEG_LOG2E = -math.log2(math.e) # ============================================================================ # Kernel 1: Fused Cumsum Operations (forward exclusive + reverse) # ============================================================================ @triton.autotune( configs=[ triton.Config({}, num_stages=s, num_warps=w) for s in [1, 2, 3] for w in [4, 8] ], key=["CHUNK_SIZE"], restore_value=["ddt_out_ptr"], ) @triton.jit def bwd_dadt_cumsum_fused_kernel( ddA_cs_ptr, # [B, H, S] ddA_cs_rev_ptr, # [B, H, S] dA_cs_ptr, # [B, H, S] dA_cs_rev_ptr, # [B, H, S] ddt_out_ptr, # [B, H, S] - output stride_batch, stride_head, stride_seq, B: tl.constexpr, H: tl.constexpr, S: tl.constexpr, CHUNK_SIZE: tl.constexpr, ): """ Fused kernel that computes contributions from: - bwd_ddt_from_ddA_cs: reverse cumsum operation - bwd_ddt_from_ddA_cs_rev: forward exclusive cumsum operation Each program handles one chunk for one (batch, head) pair. 
Grid: (B, H, nchunks) """ # Get program indices pid_batch = tl.program_id(0) pid_head = tl.program_id(1) pid_chunk = tl.program_id(2) # Calculate chunk boundaries chunk_start = pid_chunk * CHUNK_SIZE offs_seq = chunk_start + tl.arange(0, CHUNK_SIZE) mask = offs_seq < S # Compute base offset for this (batch, head) pair base_offset = pid_batch * stride_batch + pid_head * stride_head # Load chunk data for all four input tensors ddA_cs = tl.load(ddA_cs_ptr + base_offset + offs_seq * stride_seq, mask=mask, other=0.0) ddA_cs_rev = tl.load(ddA_cs_rev_ptr + base_offset + offs_seq * stride_seq, mask=mask, other=0.0) dA_cs = tl.load(dA_cs_ptr + base_offset + offs_seq * stride_seq, mask=mask, other=0.0) dA_cs_rev = tl.load(dA_cs_rev_ptr + base_offset + offs_seq * stride_seq, mask=mask, other=0.0) # ======================================================================== # Operation 1: bwd_ddt_from_ddA_cs (reverse cumsum) # ======================================================================== # Scale by log(2) * exp2(dA_cs) # Use literal constants instead of globals scaled_ddA_cs = tl.exp(dA_cs) * ddA_cs # LOG2 # Apply reverse cumsum within chunk ddt_cs = tl.cumsum(scaled_ddA_cs, axis=0, reverse=True) # ======================================================================== # Operation 2: bwd_ddt_from_ddA_cs_rev (forward exclusive cumsum) # ======================================================================== # Scale by log(2) * exp2(dA_cs_rev) # Use literal constants instead of globals scaled_ddA_cs_rev = tl.exp(dA_cs_rev) * ddA_cs_rev # LOG2 # Apply forward cumsum within chunk (inclusive) ddt_cs_rev_inclusive = tl.cumsum(scaled_ddA_cs_rev, axis=0) # Roll one to the right: i = tl.arange(0, CHUNK_SIZE)[:, None] # [N,1] j = tl.arange(0, CHUNK_SIZE)[None, :] # [1,N] S = (i == j + 1) # strictly lower diagonal (one below main) ddt_cs_rev_exclusive = tl.sum(tl.where(S, ddt_cs_rev_inclusive, 0), axis=1) # # Convert to exclusive cumsum # # Exclusive cumsum: output[i] = 
sum(input[0:i]) # # Inclusive cumsum: cumsum[i] = sum(input[0:i+1]) # # Therefore: exclusive[i] = inclusive[i] - input[i] # # Which is: exclusive[i] = cumsum[i] - scaled_ddA_cs_rev[i] # ddt_cs_rev_shifted = ddt_cs_rev_inclusive - scaled_ddA_cs_rev # ======================================================================== # Combine contributions and apply final scaling # ======================================================================== # Use literal constant instead of global ddt_total = ddt_cs + ddt_cs_rev_exclusive # Store result tl.store(ddt_out_ptr + base_offset + offs_seq * stride_seq, ddt_total, mask=mask) # ============================================================================ # Kernel 2: Segsum Operation with 2D Matrix Processing # ============================================================================ @triton.autotune( configs=[ triton.Config({}, num_stages=s, num_warps=w) for s in [2, 3] for w in [4, 8] ], key=["CHUNK_SIZE"], restore_value=["ddt_out_ptr"], ) @triton.jit def bwd_segsum_dadt_kernel( dSSdA_ptr, # [B, H, nchunks, C, C] SSdA_cs_ptr, # [B, H, S] ddt_out_ptr, # [B, H, S] - accumulated output stride_dSSdA_batch, stride_dSSdA_head, stride_dSSdA_chunk, stride_dSSdA_row, stride_dSSdA_col, stride_SSdA_batch, stride_SSdA_head, stride_SSdA_chunk, stride_SSdA_row, stride_SSdA_col, stride_ddt_batch, stride_ddt_head, stride_ddt_seq, B: tl.constexpr, H: tl.constexpr, nchunks: tl.constexpr, C: tl.constexpr, CHUNK_SIZE: tl.constexpr, ): """ Kernel for bwd_segsum_ddt_from_dSSdA operation. Matches the reference implementation: 1. Permute dSSdA last two dims 2. Compute seg = dA_cs[i] - dA_cs[j] 3. Scale by log(2) * exp2(seg) 4. Reverse cumsum along dim -2 (column-wise for each row) 5. Apply lower triangular mask (i > j) 6. Sum along dim -1 (sum over j for each i) Each program handles one chunk for one (batch, head) pair. 
Grid: (B, H, nchunks) """ # Get program indices pid_batch = tl.program_id(0) pid_head = tl.program_id(1) pid_chunk = tl.program_id(2) # Calculate chunk boundaries chunk_start = pid_chunk * CHUNK_SIZE offs_c = tl.arange(0, CHUNK_SIZE) offs_seq = chunk_start + offs_c # Load dA_cs for this chunk [C] # dA_cs_offset = pid_batch * stride_dA_batch + pid_head * stride_dA_head # dA_cs_chunk = tl.load(dA_cs_ptr + dA_cs_offset + offs_seq * stride_dA_seq) # Base offset for dSSdA matrix [nchunks, C, C] dSSdA_offset = dSSdA_ptr + (pid_batch * stride_dSSdA_batch + pid_head * stride_dSSdA_head + pid_chunk * stride_dSSdA_chunk) SSdA_offset = SSdA_cs_ptr + (pid_batch * stride_SSdA_batch + pid_head * stride_SSdA_head + pid_chunk * stride_SSdA_chunk) ddt_ptrs = ddt_out_ptr + (pid_batch * stride_ddt_batch + pid_head * stride_ddt_head + offs_seq * stride_ddt_seq) # NOTE: dSSdA is actually the transpose corresponding to seq_k \time seq_q dSSdA_block = tl.load(dSSdA_offset + offs_c[:, None]*stride_dSSdA_col + offs_c[None, :]*stride_dSSdA_row) SSdA_block = tl.load(SSdA_offset + offs_c[:, None]*stride_SSdA_row + offs_c[None, :]*stride_SSdA_col) dSSdA_block = dSSdA_block * tl.exp(SSdA_block) dSSdA_block = tl.cumsum(dSSdA_block, axis=0, reverse=True) offs_i = tl.arange(0, CHUNK_SIZE)[:, None] offs_j = tl.arange(0, CHUNK_SIZE)[None, :] SS_mask = offs_i > offs_j dSSdA = tl.where(SS_mask, dSSdA_block, 0.0) ddt_chunk = tl.load(ddt_ptrs) ddt_chunk += tl.sum(dSSdA, axis=1) tl.store(ddt_ptrs, ddt_chunk) # ============================================================================ # Kernel 3: backwards from gamma terms to trap # ============================================================================ @triton.autotune( configs=[ triton.Config({}, num_stages=s, num_warps=w) for s in [2, 3] for w in [4, 8] ], key=["CHUNK_SIZE"], ) @triton.jit def bwd_dtrap_ddt_kernel( trap_ptr, dt_ptr, dfactor_ptr, dgamma_diag_ptr, ddt_ptr, dtrap_ptr, stride_trap_batch, stride_trap_head, stride_trap_seq, 
stride_dt_batch, stride_dt_head, stride_dt_seq, stride_dfactor_batch, stride_dfactor_head, stride_dfactor_seq, stride_dgamma_diag_batch, stride_dgamma_diag_head, stride_dgamma_diag_seq, stride_ddt_batch, stride_ddt_head, stride_ddt_seq, stride_dtrap_batch, stride_dtrap_head, stride_dtrap_seq, SEQLEN: tl.constexpr, CHUNK_SIZE: tl.constexpr, ): # Get program indices pid_batch = tl.program_id(0) pid_head = tl.program_id(1) pid_chunk = tl.program_id(2) # Calculate chunk boundaries chunk_start = pid_chunk * CHUNK_SIZE offs_c = tl.arange(0, CHUNK_SIZE) offs_seq = chunk_start + offs_c trap_offset = pid_batch*stride_trap_batch + pid_head*stride_trap_head dt_offset = pid_batch*stride_dt_batch + pid_head*stride_dt_head dfactor_offset = pid_batch*stride_dfactor_batch + pid_head*stride_dfactor_head dgamma_diag_offset = pid_batch*stride_dgamma_diag_batch + pid_head*stride_dgamma_diag_head strap_block = tl.load( trap_ptr + trap_offset + (offs_seq + 1)*stride_trap_seq, mask=(offs_seq + 1) < SEQLEN, other=0.0 ) sdt_block = tl.load( dt_ptr + dt_offset + (offs_seq + 1)*stride_dt_seq, mask=(offs_seq + 1) < SEQLEN, other=0.0 ) trap_block = tl.load( trap_ptr + trap_offset + offs_seq * stride_trap_seq, mask=offs_seq < SEQLEN, other=0.0 ) dt_block = tl.load( dt_ptr + dt_offset + offs_seq * stride_dt_seq, mask=offs_seq < SEQLEN, other=0.0 ) dfactor_block = tl.load( dfactor_ptr + dfactor_offset + offs_seq * stride_dfactor_seq, mask=offs_seq < SEQLEN, other=0.0 ) dgamma_diag_input_block = tl.load( dgamma_diag_ptr + dgamma_diag_offset + offs_seq * stride_dgamma_diag_seq, mask=offs_seq < SEQLEN, other=0.0 ) # dgamma and dsgamma for current positions dgamma_block = dfactor_block + dgamma_diag_input_block dsgamma_block = dfactor_block #+ dsgamma_input_block # dsdt and dstrap for current positions (using shifted strap/sdt) dsdt_block = tl.sigmoid(-strap_block.to(tl.float32)) * dsgamma_block dstrap_block = -sdt_block * dsgamma_block # Compute dsdt/dstrap at previous position for cross-chunk shift 
prev_seq = chunk_start - 1 prev_mask = prev_seq >= 0 prev_dgamma = tl.load( dfactor_ptr + dfactor_offset + prev_seq * stride_dfactor_seq, mask=prev_mask, other=0.0 ) # prev_dsgamma_input = tl.load( # dsgamma_ptr + dsgamma_offset + prev_seq * stride_dsgamma_seq, # mask=prev_mask, other=0.0 # ) prev_dsgamma = prev_dgamma #+ prev_dsgamma_input prev_strap = tl.load( trap_ptr + trap_offset + chunk_start * stride_trap_seq, mask=chunk_start < SEQLEN, other=0.0 ) prev_sdt = tl.load( dt_ptr + dt_offset + chunk_start * stride_dt_seq, mask=chunk_start < SEQLEN, other=0.0 ) prev_dsdt = tl.sigmoid(-prev_strap.to(tl.float32)) * prev_dsgamma prev_dstrap = -prev_sdt * prev_dsgamma # Shift right by one within chunk: out[i] = in[i-1], with cross-chunk value at i=0 offs_i = tl.arange(0, CHUNK_SIZE)[:, None] offs_j = tl.arange(0, CHUNK_SIZE)[None, :] shift_mask = offs_i == (offs_j + 1) dsdt_shift = tl.sum(tl.where(shift_mask, dsdt_block[None, :], 0.0), axis=1) dstrap_shift = tl.sum(tl.where(shift_mask, dstrap_block[None, :], 0.0), axis=1) offs = tl.arange(0, CHUNK_SIZE) dsdt_shift = tl.where(offs == 0, prev_dsdt, dsdt_shift) dstrap_shift = tl.where(offs == 0, prev_dstrap, dstrap_shift) # Add dgamma path ddt_out = dsdt_shift + dgamma_block * tl.sigmoid(trap_block.to(tl.float32)) dtrap_out = dstrap_shift + dgamma_block * dt_block dtrap_out *= tl.sigmoid(trap_block.to(tl.float32)) * tl.sigmoid(-trap_block.to(tl.float32)) ddt_ptrs = ddt_ptr + (pid_batch * stride_ddt_batch + pid_head * stride_ddt_head + offs_seq * stride_ddt_seq) dtrap_ptrs = dtrap_ptr + (pid_batch * stride_dtrap_batch + pid_head * stride_dtrap_head + offs_seq * stride_dtrap_seq) tl.store(ddt_ptrs, ddt_out, mask=offs_seq < SEQLEN) tl.store(dtrap_ptrs, dtrap_out, mask=offs_seq < SEQLEN) # ============================================================================ # Kernel 4: compute da_cs, da_cs_rev, segsum from da # ============================================================================ @triton.autotune( configs=[ 
triton.Config({}, num_stages=s, num_warps=w) for s in [2, 3] for w in [4, 8] ], key=["CHUNK_SIZE"], ) @triton.jit def dacs_segsum_kernel( da_ptr, da_cs_ptr, da_cs_rev_ptr, segsum_ptr, stride_da_batch, stride_da_head, stride_da_seq, stride_da_cs_batch, stride_da_cs_head, stride_da_cs_seq, stride_da_cs_rev_batch, stride_da_cs_rev_head, stride_da_cs_rev_seq, stride_segsum_batch, stride_segsum_head, stride_segsum_chunk, stride_segsum_row, stride_segsum_col, SEQLEN: tl.constexpr, CHUNK_SIZE: tl.constexpr, ): pid_batch = tl.program_id(0) pid_head = tl.program_id(1) pid_chunk = tl.program_id(2) chunk_start = pid_chunk * CHUNK_SIZE offs = tl.arange(0, CHUNK_SIZE) offs_seq = chunk_start + offs mask = offs_seq < SEQLEN base_da = pid_batch * stride_da_batch + pid_head * stride_da_head da_chunk = tl.load(da_ptr + base_da + offs_seq * stride_da_seq, mask=mask, other=0.0) da_cs = tl.cumsum(da_chunk, axis=0) da_cs = tl.minimum(da_cs, 0.0) da_cs_rev = tl.cumsum(da_chunk, axis=0, reverse=True) # Roll one to the left: i = tl.arange(0, CHUNK_SIZE)[:, None] # [N,1] j = tl.arange(0, CHUNK_SIZE)[None, :] # [1,N] S = (i == j - 1) # strictly upper diagonal (one above main) da_cs_rev = tl.sum(tl.where(S, da_cs_rev, 0), axis=1) da_cs_rev = tl.minimum(da_cs_rev, 0.0) base_da_cs = pid_batch * stride_da_cs_batch + pid_head * stride_da_cs_head base_da_cs_rev = pid_batch * stride_da_cs_rev_batch + pid_head * stride_da_cs_rev_head tl.store(da_cs_ptr + base_da_cs + offs_seq * stride_da_cs_seq, da_cs, mask=mask) tl.store(da_cs_rev_ptr + base_da_cs_rev + offs_seq * stride_da_cs_rev_seq, da_cs_rev, mask=mask) broadcasted_indices = tl.zeros_like(offs) segsum = tl.load(da_ptr + base_da + offs_seq[:, None] * stride_da_seq + broadcasted_indices[None, :]) offs_i = offs[:, None] offs_j = offs[None, :] segsum = tl.where(offs_i > offs_j, segsum, 0.0) segsum = tl.cumsum(segsum, axis=0) segsum = tl.minimum(segsum, 0.0) base_segsum = (pid_batch * stride_segsum_batch + pid_head * stride_segsum_head + pid_chunk * 
stride_segsum_chunk) tl.store(segsum_ptr + base_segsum + offs_i * stride_segsum_row + offs_j * stride_segsum_col, segsum) # ============================================================================ # Wrapper Function # ============================================================================ def bwd_dadt_fused_triton( dSSdA: torch.Tensor, # [B, H, nchunks, C, C] SSdA: torch.Tensor, # [B, H, nchunks, C, C] ddA_cs: torch.Tensor, # [B, H, S] ddA_cs_rev: torch.Tensor, # [B, H, S] dA_cs: torch.Tensor, # [B, H, S] dA_cs_rev: torch.Tensor, # [B, H, S] chunk_size: int, ) -> torch.Tensor: # Validate inputs B, H, S = ddA_cs.shape nchunks = S // chunk_size assert S % chunk_size == 0, f"Sequence length {S} must be divisible by chunk_size {chunk_size}" assert dSSdA.shape == (B, H, nchunks, chunk_size, chunk_size), \ f"dSSdA shape mismatch: expected {(B, H, nchunks, chunk_size, chunk_size)}, got {dSSdA.shape}" # Initialize output tensor dadt_out = torch.zeros(B, H, S, device=ddA_cs.device, dtype=torch.float32) # Kernel 1: Fused ddA_cs and ddA_cs_rev contributions grid1 = (B, H, nchunks) bwd_dadt_cumsum_fused_kernel[grid1]( ddA_cs, ddA_cs_rev, dA_cs, dA_cs_rev, dadt_out, ddA_cs.stride(0), ddA_cs.stride(1), ddA_cs.stride(2), B, H, S, CHUNK_SIZE=chunk_size, ) # Kernel 2: dSSdA segsum contribution grid2 = (B, H, nchunks) bwd_segsum_dadt_kernel[grid2]( dSSdA, SSdA, dadt_out, dSSdA.stride(0), dSSdA.stride(1), dSSdA.stride(2), dSSdA.stride(3), dSSdA.stride(4), SSdA.stride(0), SSdA.stride(1), SSdA.stride(2), SSdA.stride(3), SSdA.stride(4), dadt_out.stride(0), dadt_out.stride(1), dadt_out.stride(2), B, H, nchunks, chunk_size, CHUNK_SIZE=chunk_size, ) return dadt_out def bwd_dtrap_ddt_triton( trap: torch.Tensor, # [B, H, S] dt: torch.Tensor, # [B, H, S] dfactor: torch.Tensor, # [B, H, S] dgamma_diag: torch.Tensor, # [B, H, S] chunk_size: int, # NOTE: the chunk_size does not have to be the same as the other kernels ): B, H, S = dt.shape nchunks = S // chunk_size ddt = 
torch.zeros_like(dt) dtrap = torch.zeros_like(trap) grid = (B, H, nchunks) bwd_dtrap_ddt_kernel[grid]( trap, dt, dfactor, dgamma_diag, ddt, dtrap, trap.stride(0), trap.stride(1), trap.stride(2), dt.stride(0), dt.stride(1), dt.stride(2), dfactor.stride(0), dfactor.stride(1), dfactor.stride(2), dgamma_diag.stride(0), dgamma_diag.stride(1), dgamma_diag.stride(2), ddt.stride(0), ddt.stride(1), ddt.stride(2), dtrap.stride(0), dtrap.stride(1), dtrap.stride(2), S, chunk_size, ) return ddt, dtrap def compute_dacs_segsum_triton( da: torch.Tensor, # (B, H, S) chunk_size: int, ): B, H, S = da.shape nchunks = (S + chunk_size - 1) // chunk_size da_cs = torch.empty_like(da) da_cs_rev = torch.empty_like(da) segsum = torch.empty(B, H, nchunks, chunk_size, chunk_size, device=da.device, dtype=da.dtype) grid = (B, H, nchunks) dacs_segsum_kernel[grid]( da, da_cs, da_cs_rev, segsum, da.stride(0), da.stride(1), da.stride(2), da_cs.stride(0), da_cs.stride(1), da_cs.stride(2), da_cs_rev.stride(0), da_cs_rev.stride(1), da_cs_rev.stride(2), segsum.stride(0), segsum.stride(1), segsum.stride(2), segsum.stride(3), segsum.stride(4), S, chunk_size, ) return da_cs, da_cs_rev, segsum # ============================================================================ # Reference Implementations (for testing) # ============================================================================ def bwd_segsum_ddt_from_dSSdA_ref( dSSdA: torch.Tensor, dA_cs: torch.Tensor, chunk_size: int, ): """Reference implementation of bwd_segsum_ddt_from_dSSdA.""" B, H, nchunks, C, C_ = dSSdA.shape assert C == chunk_size == C_ dA_cs_chunk = dA_cs.view(B, H, nchunks, C) dSSdA = dSSdA.permute([0, 1, 2, 4, 3]) seg = dA_cs_chunk[..., :, None] - dA_cs_chunk[..., None, :] dSSdA = dSSdA * torch.exp(seg) ddA = torch.flip(torch.cumsum(torch.flip(dSSdA, dims=[-2]), dim=-2), dims=[-2]) mask = torch.tril(torch.ones(C, C, device=dSSdA.device, dtype=dSSdA.dtype), -1) ddA = ddA * mask ddA = ddA.sum(-1) ddt = ddA * (-math.log2(math.e)) return 
ddt.reshape(B, H, nchunks*C) def bwd_ddt_from_ddA_cs_rev_ref( ddA_cs_rev: torch.Tensor, dA_cs_rev: torch.Tensor, chunk_size: int, ): """Reference implementation of bwd_ddt_from_ddA_cs_rev.""" B, H, S = ddA_cs_rev.shape nchunks = S // chunk_size ddA_cs_rev = torch.exp(dA_cs_rev) * ddA_cs_rev dA_cs_rev = dA_cs_rev.view(B, H, nchunks, chunk_size) ddA_cs_rev = ddA_cs_rev.view(B, H, nchunks, chunk_size) ddA = torch.cumsum(ddA_cs_rev, dim=-1) ddA = torch.cat([torch.zeros_like(ddA[..., :1]), ddA[..., :-1]], dim=-1) ddt = ddA * (-math.log2(math.e)) return ddt.reshape(B, H, nchunks*chunk_size) def bwd_ddt_from_ddA_cs_ref( ddA_cs: torch.Tensor, dA_cs: torch.Tensor, chunk_size: int, ): """Reference implementation of bwd_ddt_from_ddA_cs.""" B, H, S = ddA_cs.shape nchunks = S // chunk_size ddA_cs = torch.exp(dA_cs) * ddA_cs dA_cs = dA_cs.view(B, H, nchunks, chunk_size) ddA_cs = ddA_cs.view(B, H, nchunks, chunk_size) ddA = torch.flip(torch.cumsum(torch.flip(ddA_cs, dims=[-1]), dim=-1), dims=[-1]) ddt = ddA * (-math.log2(math.e)) return ddt.reshape(B, H, nchunks*chunk_size) def compute_dtrap_ddt_ref(dfactor: torch.Tensor, dgamma_diag_input: torch.Tensor, trap_presigmoid, dt, ) -> Tuple[torch.Tensor, torch.Tensor]: trap = torch.nn.functional.sigmoid(trap_presigmoid) strap = torch.nn.functional.pad(trap[:, :, 1:], (0, 1), value=0.0) sdt = torch.nn.functional.pad(dt[:, :, 1:], (0, 1), value=0.0) dgamma = dfactor.detach().clone() + dgamma_diag_input.detach().clone() dsgamma = dfactor.detach().clone() # + dsgamma_input.detach().clone() dsdt = (1 - strap) * dsgamma dstrap = -sdt * dsgamma # shift rightward: ddt = torch.nn.functional.pad(dsdt[:, :, :-1], (1, 0), value=0.0) dtrap = torch.nn.functional.pad(dstrap[:, :, :-1], (1, 0), value=0.0) # Add the dgamma path: dtrap += dgamma*dt # grad of sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)) dtrap *= trap * torch.nn.functional.sigmoid(-trap_presigmoid) ddt += dgamma*trap return ddt, dtrap def compute_dacs_segsum_ref(da: torch.Tensor, # (B, H, 
S) chunk_size: int, ): B, H, S = da.shape nchunks = S // chunk_size da_reshaped = da.view(B, H, nchunks, chunk_size) da_cs = torch.cumsum(da_reshaped, dim=-1) da_cs_sum = torch.sum(da_reshaped, dim=-1) da_cs_rev = da_cs_sum[..., None] - da_cs #torch.flip(torch.cumsum(torch.flip(da_reshaped, dims=[-1]), dim=-1), dims=[-1]) from einops import repeat segsum = repeat(da_reshaped, "... d -> ... d e", e=chunk_size) mask = torch.tril(torch.ones(chunk_size, chunk_size, device=da_cs.device, dtype=bool), diagonal=-1) segsum = segsum.masked_fill(~mask, 0) segsum = torch.cumsum(segsum, dim=-2) return da_cs.view(B, H, S), da_cs_rev.view(B, H, S), segsum # ============================================================================ # Testing Functions # ============================================================================ def test_bwd_ddt_fused_correctness(): """Test the fused kernel against reference implementation.""" print("=" * 70) print("Test: basic_correctness") print("=" * 70) B, H, S = 16, 32, 2048 chunk_size = 16 nchunks = S // chunk_size C = chunk_size # Generate random inputs torch.manual_seed(42) dSSdA = torch.randn(B, H, nchunks, C, C, device='cuda', dtype=torch.float32) ddA_cs = torch.randn(B, H, S, device='cuda', dtype=torch.float32) ddA_cs_rev = torch.randn(B, H, S, device='cuda', dtype=torch.float32) dA_cs = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1 # Scale to avoid overflow dA_cs_rev = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1 dA_cs_reshape = dA_cs.view(B, H, nchunks, chunk_size) SSdA = dA_cs_reshape[:, :, :, :, None] - dA_cs_reshape[:, :, :, None, :] # Reference implementation (separate functions) ddt_ref1 = bwd_segsum_ddt_from_dSSdA_ref(dSSdA.clone(), dA_cs.clone(), chunk_size) ddt_ref2 = bwd_ddt_from_ddA_cs_rev_ref(ddA_cs_rev.clone(), dA_cs_rev.clone(), chunk_size) ddt_ref3 = bwd_ddt_from_ddA_cs_ref(ddA_cs.clone(), dA_cs.clone(), chunk_size) ddt_ref = ddt_ref1 + ddt_ref2 + ddt_ref3 # TODO: # Fused Triton 
implementation ddt_triton = bwd_dadt_fused_triton( dSSdA, SSdA, ddA_cs, ddA_cs_rev, dA_cs, dA_cs_rev, chunk_size ) * -1.4426950408889634 # i.e., -log2(e) # Compare max_diff = (ddt_ref - ddt_triton).abs().max().item() mean_diff = (ddt_ref - ddt_triton).abs().mean().item() print(f" Max difference: {max_diff:.2e}") print(f" Mean difference: {mean_diff:.2e}") passed = max_diff < 1e-4 print(f" Status: {'PASS' if passed else 'FAIL'}") print() return passed def test_dtrap_ddt_correctness(): """Test the fused kernel against reference implementation.""" import torch.nn.functional as F print("=" * 70) print("Test: basic_correctness") print("=" * 70) B, H, S = 16, 32, 2048 chunk_size = 16 nchunks = S // chunk_size C = chunk_size # Generate random inputs torch.manual_seed(42) trap = torch.rand(B, H, S, device='cuda', dtype=torch.float16) dt = F.softplus(-3.0 + torch.randn(B, H, S, device='cuda', dtype=torch.float)) dfactor = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1 dgamma_diag = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1 # Reference implementation ddt_ref, dtrap_ref = compute_dtrap_ddt_ref(dfactor, dgamma_diag, trap, dt) # Triton implementation ddt_triton, dtrap_triton = bwd_dtrap_ddt_triton( trap, dt, dfactor, dgamma_diag, chunk_size ) # Compare max_diff_ddt = (ddt_ref - ddt_triton).abs().max().item() mean_diff_ddt = (ddt_ref - ddt_triton).abs().mean().item() max_diff_dtrap = (dtrap_ref - dtrap_triton).abs().max().item() mean_diff_dtrap = (dtrap_ref - dtrap_triton).abs().mean().item() print(f" ddt max difference: {max_diff_ddt:.2e}") print(f" ddt mean difference: {mean_diff_ddt:.2e}") print(f" dtrap max difference: {max_diff_dtrap:.2e}") print(f" dtrap mean difference:{mean_diff_dtrap:.2e}") passed = max(max_diff_ddt, max_diff_dtrap) < 1e-3 print(f" Status: {'PASS' if passed else 'FAIL'}") print() return passed def test_dacs_segsum_correctness(): import torch.nn.functional as F B, H, S = 16, 32, 2048 chunk_size = 16 da = 
-F.softplus(-3.0 + torch.randn(B, H, S, device='cuda', dtype=torch.float)) da_cs_ref, da_cs_rev_ref, segsum_ref = compute_dacs_segsum_ref(da, chunk_size) da_cs_triton, da_cs_rev_triton, segsum_triton = compute_dacs_segsum_triton(da, chunk_size) max_diff_cs = (da_cs_ref - da_cs_triton).abs().max().item() mean_diff_cs = (da_cs_ref - da_cs_triton).abs().mean().item() max_diff_cs_rev = (da_cs_rev_ref - da_cs_rev_triton).abs().max().item() mean_diff_cs_rev = (da_cs_rev_ref - da_cs_rev_triton).abs().mean().item() max_diff_segsum = (segsum_ref - segsum_triton).abs().max().item() mean_diff_segsum = (segsum_ref - segsum_triton).abs().mean().item() print(f" da_cs max difference: {max_diff_cs:.2e}") print(f" da_cs mean difference: {mean_diff_cs:.2e}") print(f" da_cs_rev max difference: {max_diff_cs_rev:.2e}") print(f" da_cs_rev mean difference:{mean_diff_cs_rev:.2e}") print(f" segsum max difference: {max_diff_segsum:.2e}") print(f" segsum mean difference: {mean_diff_segsum:.2e}") passed = max(max_diff_cs, max_diff_cs_rev, max_diff_segsum) < 1e-4 print(f" Status: {'PASS' if passed else 'FAIL'}") print() return passed # ============================================================================ # Benchmarking Functions # ============================================================================ def benchmark_bwd_ddt(): """Benchmark fused kernel against unfused baseline.""" from triton.testing import do_bench print("=" * 70) print("Benchmark: bwd_ddt_fused") print("=" * 70) B, H, S = 16, 32, 2048 chunk_size = 16 nchunks = S // chunk_size C = chunk_size print(f"Configuration: B={B}, H={H}, S={S}, chunk_size={chunk_size}") print() # Setup inputs torch.manual_seed(42) dSSdA = torch.randn(B, H, nchunks, C, C, device='cuda', dtype=torch.float32) ddA_cs = torch.randn(B, H, S, device='cuda', dtype=torch.float32) ddA_cs_rev = torch.randn(B, H, S, device='cuda', dtype=torch.float32) dA_cs = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1 dA_cs_rev = torch.randn(B, H, S, 
device='cuda', dtype=torch.float32) * 0.1 dA_cs_reshape = dA_cs.view(B, H, nchunks, chunk_size) SSdA = dA_cs_reshape[:, :, :, :, None] - dA_cs_reshape[:, :, :, None, :] # Benchmark reference (unfused) def ref_impl(): ddt1 = bwd_segsum_ddt_from_dSSdA_ref(dSSdA, dA_cs, chunk_size) ddt2 = bwd_ddt_from_ddA_cs_rev_ref(ddA_cs_rev, dA_cs_rev, chunk_size) ddt3 = bwd_ddt_from_ddA_cs_ref(ddA_cs, dA_cs, chunk_size) return ddt1 + ddt2 + ddt3 # Benchmark individual components ref1_time = do_bench(lambda: bwd_segsum_ddt_from_dSSdA_ref(dSSdA, dA_cs, chunk_size), warmup=25, rep=100) ref2_time = do_bench(lambda: bwd_ddt_from_ddA_cs_rev_ref(ddA_cs_rev, dA_cs_rev, chunk_size), warmup=25, rep=100) ref3_time = do_bench(lambda: bwd_ddt_from_ddA_cs_ref(ddA_cs, dA_cs, chunk_size), warmup=25, rep=100) ref_time = do_bench(ref_impl, warmup=25, rep=100) # Benchmark fused def fused_impl(): return bwd_dadt_fused_triton( dSSdA, SSdA, ddA_cs, ddA_cs_rev, dA_cs, dA_cs_rev, chunk_size ) fused_time = do_bench(fused_impl, warmup=25, rep=100) print("Reference (unfused):") print(f" Function 1 (segsum): {ref1_time:.3f} ms") print(f" Function 2 (cs_rev): {ref2_time:.3f} ms") print(f" Function 3 (cs): {ref3_time:.3f} ms") print(f" Total: {ref_time:.3f} ms") print() print("Fused Triton:") print(f" Total: {fused_time:.3f} ms") print(f" Speedup: {ref_time / fused_time:.2f}x") print() return ref_time, fused_time def benchmark_dacs_segsum(): """Benchmark dacs+segsum Triton against reference implementation.""" from triton.testing import do_bench import torch.nn.functional as F print("=" * 70) print("Benchmark: dacs_segsum") print("=" * 70) B, H, S = 16, 32, 2048 chunk_size = 16 print(f"Configuration: B={B}, H={H}, S={S}, chunk_size={chunk_size}") print() torch.manual_seed(42) da = F.softplus(-3.0 + torch.randn(B, H, S, device='cuda', dtype=torch.float)) def ref_impl(): return compute_dacs_segsum_ref(da, chunk_size) def triton_impl(): return compute_dacs_segsum_triton(da, chunk_size) ref_time = 
do_bench(ref_impl, warmup=25, rep=100) triton_time = do_bench(triton_impl, warmup=25, rep=100) print("Reference:") print(f" Total: {ref_time:.3f} ms") print("Triton:") print(f" Total: {triton_time:.3f} ms") print(f" Speedup: {ref_time / triton_time:.2f}x") print() return ref_time, triton_time def benchmark_dtrap_ddt(): """Benchmark dtrap/ddt kernel against reference implementation.""" from triton.testing import do_bench import torch.nn.functional as F print("=" * 70) print("Benchmark: bwd_dtrap_ddt") print("=" * 70) B, H, S = 16, 32, 2048 chunk_size = 16 print(f"Configuration: B={B}, H={H}, S={S}, chunk_size={chunk_size}") print() torch.manual_seed(42) trap = torch.ones(B, H, S, device='cuda', dtype=torch.float16) * 0.5 dt = F.softplus(-3.0 + torch.randn(B, H, S, device='cuda', dtype=torch.float)) dfactor = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1 dgamma_diag = torch.randn(B, H, S, device='cuda', dtype=torch.float32) * 0.1 def ref_impl(): return compute_dtrap_ddt_ref(dfactor, dgamma_diag, trap, dt) def triton_impl(): return bwd_dtrap_ddt_triton(trap, dt, dfactor, dgamma_diag, chunk_size) ref_time = do_bench(ref_impl, warmup=25, rep=100) triton_time = do_bench(triton_impl, warmup=25, rep=100) print("Reference:") print(f" Total: {ref_time:.3f} ms") print("Triton:") print(f" Total: {triton_time:.3f} ms") print(f" Speedup: {ref_time / triton_time:.2f}x") print() return ref_time, triton_time # ============================================================================ # Main execution # ============================================================================ if __name__ == "__main__": test_bwd_ddt_fused_correctness() # benchmark_bwd_ddt() test_dtrap_ddt_correctness() # benchmark_dtrap_ddt() # benchmark_dacs_segsum() test_dacs_segsum_correctness() # benchmark_dacs_segsum() ================================================ FILE: mamba_ssm/ops/triton/mamba3/mamba3_siso_bwd.py ================================================ """ Mamba-3 Backward 
Pass Triton Kernels. Copyright (c) 2026, Dao AI Lab, Goombalab """ from typing import Optional, Tuple import math import torch import torch.nn.functional as F from einops import rearrange, repeat import triton import triton.language as tl from mamba_ssm.ops.triton.mamba3.utils import cos_approx, sin_approx, sigmoid_approx # ============================================================================= # dZ Kernel # ============================================================================= @triton.autotune( configs=[ triton.Config({"CHUNK_SIZE": cs}, num_stages=s, num_warps=w) for cs in [32, 64] for s in [1, 2, 3] for w in [2, 4, 8] ], key=["HEADDIM_V"] ) @triton.jit def mamba3_siso_bwd_kernel_dzdo( # Input tensors DO, Z, O, # Output tensors Dz, DO_scaled, # Strides for DO: (batch, seqlen, nheads, headdim_v) stride_do_batch, stride_do_seqlen, stride_do_head, stride_do_vdim, # Strides for Z: (batch, seqlen, nheads, headdim_v) stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_vdim, # Strides for O: (batch, seqlen, nheads, headdim_v) stride_o_batch, stride_o_seqlen, stride_o_head, stride_o_vdim, # Strides for Dz: (batch, seqlen, nheads, headdim_v) stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_vdim, # Strides for DO_scaled: (batch, seqlen, nheads, headdim_v) stride_do_scaled_batch, stride_do_scaled_seqlen, stride_do_scaled_head, stride_do_scaled_vdim, # Dimensions seqlen, headdim_v, # Compile-time constants CHUNK_SIZE: tl.constexpr, HEADDIM_V: tl.constexpr, ): """ Backward kernel for Z-gating: computes dZ and scales dO. In the forward pass, output is gated as: out = O * Z * sigmoid(Z) = O * silu(Z) This kernel computes: - dZ = dO * O * sigmoid(Z) * (1 + Z * (1 - sigmoid(Z))) - dO_scaled = dO * sigmoid(Z) * Z (for downstream gradient computation) Each program instance processes one (chunk, head, batch) triplet. 
""" pid_chunk = tl.program_id(0) pid_head = tl.program_id(1) pid_batch = tl.program_id(2) # Compute offsets for this (batch, head) pair do_offset = pid_batch * stride_do_batch + pid_head * stride_do_head z_offset = pid_batch * stride_z_batch + pid_head * stride_z_head o_offset = pid_batch * stride_o_batch + pid_head * stride_o_head dz_offset = pid_batch * stride_dz_batch + pid_head * stride_dz_head do_scaled_offset = pid_batch * stride_do_scaled_batch + pid_head * stride_do_scaled_head chunk_start = pid_chunk * CHUNK_SIZE offs_seq = chunk_start + tl.arange(0, CHUNK_SIZE) offs_dim = tl.arange(0, HEADDIM_V) mask = (offs_seq[:, None] < seqlen) & (offs_dim[None, :] < HEADDIM_V) # Load dO block: (CHUNK_SIZE, headdim_v) do_ptrs = DO + do_offset + offs_seq[:, None] * stride_do_seqlen + offs_dim[None, :] * stride_do_vdim do_block = tl.load(do_ptrs, mask=mask, other=0.0) # Load Z block: (CHUNK_SIZE, headdim_v) z_ptrs = Z + z_offset + offs_seq[:, None] * stride_z_seqlen + offs_dim[None, :] * stride_z_vdim z_block = tl.load(z_ptrs, mask=mask, other=0.0) # Load O block (pre-gating output): (CHUNK_SIZE, headdim_v) o_ptrs = O + o_offset + offs_seq[:, None] * stride_o_seqlen + offs_dim[None, :] * stride_o_vdim o_block = tl.load(o_ptrs, mask=mask, other=0.0) # Compute sigmoid(Z) for gating sigmoid_z = tl.sigmoid(z_block.to(tl.float32)) # Scale dO by sigmoid(Z) do_block = do_block * sigmoid_z # Compute dZ gradient # d/dZ [O * Z * sigmoid(Z)] = O * sigmoid(Z) * (1 + Z * (1 - sigmoid(Z))) # = O * sigmoid(Z) + O * Z * sigmoid(Z) * (1 - sigmoid(Z)) dz_block = do_block * o_block * (1 + z_block * (1 - sigmoid_z)) # Store dZ dz_ptrs = Dz + dz_offset + offs_seq[:, None] * stride_dz_seqlen + offs_dim[None, :] * stride_dz_vdim tl.store(dz_ptrs, dz_block, mask=mask) # Complete scaling of dO: dO * sigmoid(Z) * Z do_block = do_block * z_block # Store scaled dO for downstream gradient computation do_scaled_ptrs = DO_scaled + do_scaled_offset + offs_seq[:, None] * stride_do_scaled_seqlen + 
offs_dim[None, :] * stride_do_scaled_vdim tl.store(do_scaled_ptrs, do_block, mask=mask) def compute_dzdo( do: torch.Tensor, z: torch.Tensor, o: torch.Tensor, chunk_size: int = 64, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute Z-gating gradients for Mamba-3 backward pass. When Z-gating is used in the forward pass (out = O * silu(Z)), this function computes the gradient with respect to Z and scales dO for downstream gradient computation. Args: do: Output gradient tensor of shape (batch, seqlen, nheads, headdim_v) z: Gating tensor from forward pass of shape (batch, seqlen, nheads, headdim_v) o: Pre-gating output from forward pass of shape (batch, seqlen, nheads, headdim_v) chunk_size: Chunk size used in forward pass (default: 64) Returns: Tuple containing: - dz: Gradient for Z tensor of shape (batch, seqlen, nheads, headdim_v) - do_scaled: Scaled output gradient of shape (batch, seqlen, nheads, headdim_v) This should be used as input to subsequent gradient kernels. """ batch, seqlen, nheads, headdim_v = do.shape # Validate inputs assert z is not None and o is not None and do is not None, "Z, O, and DO tensors must be provided" assert z.is_cuda and o.is_cuda and do.is_cuda, "All tensors must be on CUDA" assert z.shape == do.shape and o.shape == do.shape, f"Shape mismatch: Z={z.shape}, O={o.shape}, DO={do.shape}" # Ensure contiguity for optimal memory access if do.stride(-1) != 1: do = do.contiguous() if z.stride(-1) != 1: z = z.contiguous() if o.stride(-1) != 1: o = o.contiguous() # Allocate output tensors dz = torch.empty_like(z, dtype=do.dtype) do_scaled = torch.empty_like(do, dtype=do.dtype) # Round up head dimension to power of 2 for efficient loading HEADDIM_V = triton.next_power_of_2(headdim_v) # Launch kernel: grid = (nchunks, nheads, batch) # CHUNK_SIZE is autotuned, so we compute nchunks dynamically via a lambda def grid(META): return (triton.cdiv(seqlen, META["CHUNK_SIZE"]), nheads, batch) mamba3_siso_bwd_kernel_dzdo[grid]( do, z, o, dz, do_scaled, # DO 
strides do.stride(0), do.stride(1), do.stride(2), do.stride(3), # Z strides z.stride(0), z.stride(1), z.stride(2), z.stride(3), # O strides o.stride(0), o.stride(1), o.stride(2), o.stride(3), # Dz strides dz.stride(0), dz.stride(1), dz.stride(2), dz.stride(3), # DO_scaled strides do_scaled.stride(0), do_scaled.stride(1), do_scaled.stride(2), do_scaled.stride(3), # Dimensions seqlen, headdim_v, # Compile-time constants HEADDIM_V=HEADDIM_V, ) return dz, do_scaled # ============================================================================= # dQKV Kernel # ============================================================================= @triton.autotune( configs=[ triton.Config({}, num_stages=s, num_warps=w) for s in [1, 2, 3] for w in [2, 4, 8] ], key=["CHUNK_SIZE", "HEADDIM_QK", "HEADDIM_V", "IS_VARLEN"] ) @triton.jit def mamba3_siso_bwd_kernel_dqkv( # Input tensors Q, K, V, DA_CS, DA_CS_SUM, QK_Dot, D, SSM_States, dO, d_OSSM_State, Cu_Seqlens, # dO is scaled with Z # Output tensors dQ, dK, dV, dADT, dQK_Dot, dD, d_ISSM_State, # dQK_Dot is scaled with scale # Strides for Inputs # Strides for Q: (batch, seqlen, nheads_qk, HEADDIM_QK) stride_q_batch, stride_q_seqlen, stride_q_head, stride_q_qkdim, # Strides for K: (batch, seqlen, nheads_qk, HEADDIM_QK) stride_k_batch, stride_k_seqlen, stride_k_head, stride_k_qkdim, # Strides for V: (batch, seqlen, nheads, HEADDIM_V) stride_v_batch, stride_v_seqlen, stride_v_head, stride_v_vdim, # Strides for DA_CS: (batch, nheads, seqlen) stride_da_cs_batch, stride_da_cs_head, stride_da_cs_seqlen, # Strides for DA_CS_SUM: (batch, nheads, nchunks) stride_da_cs_sum_batch, stride_da_cs_sum_head, stride_da_cs_sum_seqlen, # Strides for QK (QK dot products): (batch, nheads, nchunks*CHUNK_SIZE) stride_qk_dot_batch, stride_qk_dot_head, stride_qk_dot_seqlen, # Strides for D: (nheads,) stride_d_head, # Strides for SSM_States: (batch, nheads, HEADDIM_V, nchunks*HEADDIM_QK) stride_ssm_states_batch, stride_ssm_states_head, stride_ssm_states_vdim, 
stride_ssm_states_qkdim, # Strides for dO: (batch, seqlen, nheads, HEADDIM_V) stride_do_batch, stride_do_seqlen, stride_do_head, stride_do_vdim, # Strides for d_OSSM_State: (num_sequences, nheads, HEADDIM_V, HEADDIM_QK) stride_d_ossm_state_batch, stride_d_ossm_state_head, stride_d_ossm_state_vdim, stride_d_ossm_state_qkdim, # Strides for Cu_Seqlens: (num_sequences + 1,) stride_cu_seqlen, # Strides for Outputs # Strides for dQ: (batch, seqlen, nheads, HEADDIM_QK) stride_dq_batch, stride_dq_seqlen, stride_dq_head, stride_dq_qkdim, # Strides for dK: (batch, seqlen, nheads, HEADDIM_QK) stride_dk_batch, stride_dk_seqlen, stride_dk_head, stride_dk_qkdim, # Strides for dV: (batch, seqlen, nheads, HEADDIM_V) stride_dv_batch, stride_dv_seqlen, stride_dv_head, stride_dv_vdim, # Strides for dAdt: (batch, nheads, seqlen) stride_dadt_batch, stride_dadt_head, stride_dadt_seqlen, # Strides for dQK_dot: (batch, nheads, seqlen) stride_dQK_dot_batch, stride_dQK_dot_head, stride_dQK_dot_seqlen, # Strides for dD: (nheads,) stride_dd_batch, stride_dd_head, # Strides for d_ISSM_State: (num_sequences, nheads, HEADDIM_V, HEADDIM_QK) stride_d_issm_state_batch, stride_d_issm_state_head, stride_d_issm_state_vdim, stride_d_issm_state_qkdim, # Dimensions seqlen, nheads_qk, headdim_qk, headdim_v, CHUNK_SIZE: tl.constexpr, HEADDIM_QK: tl.constexpr, HEADDIM_V: tl.constexpr, RECOMPUTE_MASK: tl.constexpr, HAS_D_OSSM_STATE: tl.constexpr, RETURN_D_ISSM_STATE: tl.constexpr, IS_VARLEN: tl.constexpr, ): """ Backward kernel for Mamba-3 attention mechanism. Each program instance handles one (head, batch/seq) pair and iterates through all chunks in reverse order. This reverse iteration is necessary because state gradients flow backward through the sequence. 
The kernel computes: - dQ, dK: Gradients for query/key from both intra-chunk attention and inter-chunk states - dV: Gradient for values - dADT: Gradient for the decay parameter (A * dt) - dQK_Dot: Gradient for the QK dot product term - dD: Gradient for the skip connection (if present) - dISSM_State: Gradient for the input SSM state (if present) Grid: - Normal mode: (nheads, batch) - Varlen mode: (nheads, num_sequences) """ # ==================== Program Indexing ==================== pid_head = tl.program_id(0) pid_batch = tl.program_id(1) if IS_VARLEN: pid_seq = pid_batch pid_batch = 0 cu_seqlen = tl.load(Cu_Seqlens + pid_seq * stride_cu_seqlen).to(tl.int32) cu_seqlen_next = tl.load(Cu_Seqlens + (pid_seq + 1) * stride_cu_seqlen).to(tl.int32) seqlen = cu_seqlen_next - cu_seqlen cu_chunks = pid_seq + cu_seqlen // CHUNK_SIZE else: cu_seqlen = 0 cu_chunks = 0 pid_seq = 0 # Compute Q/K head index for GQA (grouped query attention) # Multiple output heads may share the same Q/K head nheads = tl.num_programs(0) head_idx_qk = pid_head // (nheads // nheads_qk) # Input Pointer Offsets q_offset = pid_batch * stride_q_batch + head_idx_qk * stride_q_head + IS_VARLEN * cu_seqlen * stride_q_seqlen k_offset = pid_batch * stride_k_batch + head_idx_qk * stride_k_head + IS_VARLEN * cu_seqlen * stride_k_seqlen v_offset = pid_batch * stride_v_batch + pid_head * stride_v_head + IS_VARLEN * cu_seqlen * stride_v_seqlen da_cs_offset = pid_batch * stride_da_cs_batch + pid_head * stride_da_cs_head + IS_VARLEN * cu_seqlen * stride_da_cs_seqlen da_cs_sum_offset = pid_batch * stride_da_cs_sum_batch + pid_head * stride_da_cs_sum_head + IS_VARLEN * cu_chunks * stride_da_cs_sum_seqlen qk_dot_offset = pid_batch * stride_qk_dot_batch + pid_head * stride_qk_dot_head + IS_VARLEN * cu_seqlen * stride_qk_dot_seqlen ssm_states_offset = pid_batch * stride_ssm_states_batch + pid_head * stride_ssm_states_head + IS_VARLEN * cu_chunks * HEADDIM_QK * stride_ssm_states_qkdim do_offset = pid_batch * 
stride_do_batch + pid_head * stride_do_head + IS_VARLEN * cu_seqlen * stride_do_seqlen if HAS_D_OSSM_STATE: d_ossm_state_offset = (pid_batch + IS_VARLEN * pid_seq) * stride_d_ossm_state_batch + pid_head * stride_d_ossm_state_head # Load skip connection value D if present if D is not None: D_offset = pid_head * stride_d_head D_val = tl.load(D + D_offset) # Output Pointer Offsets dq_offset = pid_batch * stride_dq_batch + pid_head * stride_dq_head + IS_VARLEN * cu_seqlen * stride_dq_seqlen dk_offset = pid_batch * stride_dk_batch + pid_head * stride_dk_head + IS_VARLEN * cu_seqlen * stride_dk_seqlen dv_offset = pid_batch * stride_dv_batch + pid_head * stride_dv_head + IS_VARLEN * cu_seqlen * stride_dv_seqlen dadt_offset = pid_batch * stride_dadt_batch + pid_head * stride_dadt_head + IS_VARLEN * cu_seqlen * stride_dadt_seqlen dQK_dot_offset = pid_batch * stride_dQK_dot_batch + pid_head * stride_dQK_dot_head + IS_VARLEN * cu_seqlen * stride_dQK_dot_seqlen if D is not None: dD_offset = pid_head * stride_dd_head + pid_batch * stride_dd_batch + IS_VARLEN * pid_seq * stride_dd_batch dD_acc = tl.zeros([1], dtype=tl.float32) if RETURN_D_ISSM_STATE: d_issm_state_offset = (pid_batch + IS_VARLEN * pid_seq) * stride_d_issm_state_batch + pid_head * stride_d_issm_state_head # Accumulates gradients flowing backward through states across chunks if HAS_D_OSSM_STATE: d_ssm_ptrs = d_OSSM_State + d_ossm_state_offset + tl.arange(0, HEADDIM_V)[:, None] * stride_d_ossm_state_vdim + tl.arange(0, HEADDIM_QK)[None, :] * stride_d_ossm_state_qkdim d_ssm_states_mask = (tl.arange(0, HEADDIM_V)[:, None] < headdim_v) & (tl.arange(0, HEADDIM_QK)[None, :] < headdim_qk) d_ssm_states_acc = tl.load(d_ssm_ptrs, mask=d_ssm_states_mask, other=0.0).to(tl.float32) else: d_ssm_states_acc = tl.zeros([HEADDIM_V, HEADDIM_QK], dtype=tl.float32) num_chunks = tl.cdiv(seqlen, CHUNK_SIZE) # TMA Descriptors for Efficient Memory Access q_desc = tl.make_tensor_descriptor( Q + q_offset, shape=[seqlen, headdim_qk], 
strides=[stride_q_seqlen, stride_q_qkdim], block_shape=[CHUNK_SIZE, HEADDIM_QK], ) k_desc = tl.make_tensor_descriptor( K + k_offset, shape=[seqlen, headdim_qk], strides=[stride_k_seqlen, stride_k_qkdim], block_shape=[CHUNK_SIZE, HEADDIM_QK], ) v_desc = tl.make_tensor_descriptor( V + v_offset, shape=[seqlen, headdim_v], strides=[stride_v_seqlen, stride_v_vdim], block_shape=[CHUNK_SIZE, HEADDIM_V], ) ssm_states_desc = tl.make_tensor_descriptor( SSM_States + ssm_states_offset, shape=[headdim_v, num_chunks * headdim_qk], strides=[stride_ssm_states_vdim, stride_ssm_states_qkdim], block_shape=[HEADDIM_V, HEADDIM_QK], ) do_desc = tl.make_tensor_descriptor( dO + do_offset, shape=[seqlen, headdim_v], strides=[stride_do_seqlen, stride_do_vdim], block_shape=[CHUNK_SIZE, HEADDIM_V], ) dq_desc = tl.make_tensor_descriptor( dQ + dq_offset, shape=[seqlen, headdim_qk], strides=[stride_dq_seqlen, stride_dq_qkdim], block_shape=[CHUNK_SIZE, HEADDIM_QK], ) dk_desc = tl.make_tensor_descriptor( dK + dk_offset, shape=[seqlen, headdim_qk], strides=[stride_dk_seqlen, stride_dk_qkdim], block_shape=[CHUNK_SIZE, HEADDIM_QK], ) dv_desc = tl.make_tensor_descriptor( dV + dv_offset, shape=[seqlen, headdim_v], strides=[stride_dv_seqlen, stride_dv_vdim], block_shape=[CHUNK_SIZE, HEADDIM_V], ) for chunk_idx_loop in range(num_chunks): chunk_idx = num_chunks - 1 - chunk_idx_loop # Reverse order for backward pass chunk_start = chunk_idx * CHUNK_SIZE # Sequence-length mask for non-TMA loads/stores offs_cs = chunk_start + tl.arange(0, CHUNK_SIZE) seq_mask = offs_cs < seqlen # ============================================================ # Load Decay Values # We load these first to overlap computation with TMA loads # ============================================================ da_cs_ptrs = DA_CS + da_cs_offset + offs_cs * stride_da_cs_seqlen da_cs = tl.load(da_cs_ptrs, mask=seq_mask, other=0.0) # Cumulative decay within chunk: (CHUNK_SIZE,) da_cs_sum_ptrs = DA_CS_SUM + da_cs_sum_offset + chunk_idx * 
stride_da_cs_sum_seqlen da_cs_chunk_sum = tl.load(da_cs_sum_ptrs) # Total decay for this chunk: scalar # ============================================================ # Load Q, K, V, dO, SSM_States via TMA # ============================================================ do_block = do_desc.load([chunk_start, 0]) # (CHUNK_SIZE, HEADDIM_V) v_block = v_desc.load([chunk_start, 0]) # (CHUNK_SIZE, HEADDIM_V) q_block = q_desc.load([chunk_start, 0]) # (CHUNK_SIZE, HEADDIM_QK) k_block = k_desc.load([chunk_start, 0]) # (CHUNK_SIZE, HEADDIM_QK) ssm_states_block = ssm_states_desc.load([0, chunk_idx * headdim_qk]) # (HEADDIM_V, HEADDIM_QK) # ============================================================ # Compute Decay Scaling Factors # ============================================================ # Reverse cumsum: how much decay from position i to end of chunk da_cs_rev = da_cs_chunk_sum - da_cs exp_da_cs_rev = tl.math.exp2(da_cs_rev) # For scaling inter-chunk contributions exp_da_cs = tl.math.exp2(da_cs) # For scaling intra-chunk contributions # Compute strictly causal mask with exponential decay (this is L^T) if not RECOMPUTE_MASK: causal_decay_mask = tl.where( tl.arange(0, CHUNK_SIZE)[None, :] > tl.arange(0, CHUNK_SIZE)[:, None], tl.math.exp2(tl.minimum(da_cs[None, :] - da_cs[:, None], 0.0)), 0.0 ) # ============================================================ # Compute dADT Gradient (Part 1): From Intra-chunk Attention # This is register-heavy so we compute it early before spilling # ============================================================ # Gradient contribution from (QK^T ⊙ L) V term dAinv = tl.dot(v_block, tl.trans(do_block)) # V @ dO^T if RECOMPUTE_MASK: dAinv *= tl.math.exp2(tl.minimum(da_cs[None, :] - da_cs[:, None], 0.0)) dAinv = tl.where( tl.arange(0, CHUNK_SIZE)[None, :] > tl.arange(0, CHUNK_SIZE)[:, None], dAinv, 0.0 ) else: dAinv *= causal_decay_mask dAinv *= tl.dot(k_block, tl.trans(q_block)) # Element-wise with K @ Q^T dM_rev_vector = tl.sum(dAinv, axis=0) - 
tl.sum(dAinv, axis=1) # (CHUNK_SIZE,) # ============================================================ # Compute dK: Key Gradient # dK = (V @ dO^T ⊙ mask)^T @ Q + V @ dStates * scale # ============================================================ # Intra-chunk: dP^T @ Q where dP = dO @ V^T ⊙ mask dp_t_block = tl.dot(v_block, tl.trans(do_block)) # V @ dO^T: (CHUNK_SIZE, CHUNK_SIZE) if RECOMPUTE_MASK: dp_t_block *= tl.math.exp2(tl.minimum(da_cs[None, :] - da_cs[:, None], 0.0)) dp_t_block = tl.where( tl.arange(0, CHUNK_SIZE)[None, :] > tl.arange(0, CHUNK_SIZE)[:, None], dp_t_block, 0.0 ) else: dp_t_block *= causal_decay_mask acc_dk = tl.dot(dp_t_block.to(q_block.dtype), q_block) # (CHUNK_SIZE, HEADDIM_QK) # Inter-chunk: gradient flowing through accumulated states acc_dk += tl.dot(v_block, d_ssm_states_acc.to(v_block.dtype)) * exp_da_cs_rev[:, None] dk_desc.store([chunk_start, 0], acc_dk) # ============================================================ # Compute dQ: Query Gradient # dQ = (V @ dO^T ⊙ mask) @ K + dO @ States * scale # ============================================================ # Intra-chunk: S^T @ K where S = V @ dO^T ⊙ mask s_block = tl.dot(v_block, tl.trans(do_block)) # (CHUNK_SIZE, CHUNK_SIZE) if RECOMPUTE_MASK: s_block *= tl.math.exp2(tl.minimum(da_cs[None, :] - da_cs[:, None], 0.0)) s_block = tl.where( tl.arange(0, CHUNK_SIZE)[None, :] > tl.arange(0, CHUNK_SIZE)[:, None], s_block, 0.0 ) else: s_block *= causal_decay_mask acc_dq = tl.dot(tl.trans(s_block).to(k_block.dtype), k_block) # (CHUNK_SIZE, HEADDIM_QK) # Inter-chunk: gradient through states from previous chunks acc_dq += tl.dot(do_block, ssm_states_block) * exp_da_cs[:, None] dq_desc.store([chunk_start, 0], acc_dq) # ============================================================ # Compute dV: Value Gradient # dV = (K @ Q^T ⊙ mask) @ dO + K @ dStates^T * scale + dO * (D + qk_dot) # ============================================================ # Intra-chunk: P^T @ dO where P = Q @ K^T ⊙ mask p_t_block 
= tl.dot(k_block, tl.trans(q_block)) # K @ Q^T: (CHUNK_SIZE, CHUNK_SIZE) if RECOMPUTE_MASK: p_t_block *= tl.math.exp2(tl.minimum(da_cs[None, :] - da_cs[:, None], 0.0)) p_t_block = tl.where( tl.arange(0, CHUNK_SIZE)[None, :] > tl.arange(0, CHUNK_SIZE)[:, None], p_t_block, 0.0 ) else: p_t_block *= causal_decay_mask acc_dv = tl.dot(p_t_block.to(do_block.dtype), do_block) # (CHUNK_SIZE, HEADDIM_V) # Inter-chunk: gradient through states acc_dv += tl.dot(k_block, tl.trans(d_ssm_states_acc).to(k_block.dtype)) * exp_da_cs_rev[:, None] # Skip connection gradient contribution # Load dO again with volatile to avoid cache conflicts dO_reloaded = tl.load( dO + do_offset + offs_cs[:, None] * stride_do_seqlen + tl.arange(0, HEADDIM_V)[None, :] * stride_do_vdim, mask=seq_mask[:, None] & (tl.arange(0, HEADDIM_V)[None, :] < headdim_v), other=0.0, volatile=True ) qk_dot = tl.load(QK_Dot + qk_dot_offset + offs_cs * stride_qk_dot_seqlen, mask=seq_mask, other=0.0) if D is not None: acc_dv += dO_reloaded * (D_val + qk_dot[:, None]) else: acc_dv += dO_reloaded * qk_dot[:, None] dv_desc.store([chunk_start, 0], acc_dv) # ============================================================ # Compute dQK_Dot and dD: Skip Connection Gradients # ============================================================ v_block_reloaded = tl.load( V + v_offset + offs_cs[:, None] * stride_v_seqlen + tl.arange(0, HEADDIM_V)[None, :] * stride_v_vdim, mask=seq_mask[:, None] & (tl.arange(0, HEADDIM_V)[None, :] < headdim_v), other=0.0, volatile=True ) # dQK_dot = sum_v(dO * V) for each position dQK_dot_block = tl.dot( dO_reloaded * v_block_reloaded, tl.full([HEADDIM_V, 1], 1, dtype=dO_reloaded.dtype) ) tl.store( dQK_Dot + dQK_dot_offset + offs_cs * stride_dQK_dot_seqlen, dQK_dot_block.reshape(CHUNK_SIZE), mask=seq_mask ) # Accumulate dD gradient if D is not None: dD_acc += tl.dot( tl.full([1, CHUNK_SIZE], 1, dtype=tl.float32), dQK_dot_block ).reshape(1) # ============================================================ # 
Compute dADT Gradient (Part 2): From Inter-chunk States # ============================================================ # Gradient from Q @ States^T term QS = tl.dot(q_block, tl.trans(ssm_states_block)) # (CHUNK_SIZE, HEADDIM_V) dM_rev_vector += tl.sum(QS * dO_reloaded, axis=1) * exp_da_cs # (CHUNK_SIZE,) # ============================================================ # Compute dADT Gradient (Part 3): From State Accumulation # ============================================================ # Gradient flowing through d_ssm_states_acc @ SSM_States SSM_States_ptrs = (SSM_States + ssm_states_offset + tl.arange(0, HEADDIM_V)[:, None] * stride_ssm_states_vdim + (chunk_idx * headdim_qk + tl.arange(0, HEADDIM_QK)[None, :]) * stride_ssm_states_qkdim) SSM_States_mask = (tl.arange(0, HEADDIM_V)[:, None] < headdim_v) & ((chunk_idx * headdim_qk + tl.arange(0, HEADDIM_QK)[None, :]) < num_chunks * headdim_qk) SSM_States_reloaded = tl.load(SSM_States_ptrs, volatile=True, mask=SSM_States_mask) # (HEADDIM_V, HEADDIM_QK) dM_scalar = tl.sum(SSM_States_reloaded * d_ssm_states_acc) * tl.math.exp2(da_cs_chunk_sum) # ============================================================ # Compute dADT Gradient (Part 4): From K @ dStates # ============================================================ dSK = tl.dot(k_block, tl.trans(d_ssm_states_acc).to(k_block.dtype)) # (CHUNK_SIZE, HEADDIM_V) dM_vector = tl.sum(dSK * v_block_reloaded, axis=1) * exp_da_cs_rev # (CHUNK_SIZE,) # ============================================================ # Combine dADT Gradient Components via Reverse Cumsum # ============================================================ dM_rev_vector += (tl.sum(dM_rev_vector) + dM_scalar) + tl.cumsum(dM_vector - dM_rev_vector) - dM_vector # Store dADT dadt_ptrs = dADT + dadt_offset + offs_cs * stride_dadt_seqlen tl.store(dadt_ptrs, dM_rev_vector, mask=seq_mask) # ============================================================ # Accumulate State Gradients for Previous Chunks # 
============================================================ dO_reloaded *= exp_da_cs[:, None] d_ssm_states_acc = (tl.math.exp2(da_cs_chunk_sum) * d_ssm_states_acc + tl.dot(tl.trans(dO_reloaded).to(q_block.dtype), q_block)) # Store Final dD Gradient if D is not None: tl.store(dD + dD_offset + tl.arange(0, 1), dD_acc) # Store d_ISSM_State if RETURN_D_ISSM_STATE: d_ISSM_State_ptrs = d_ISSM_State + d_issm_state_offset + tl.arange(0, HEADDIM_V)[:, None] * stride_d_issm_state_vdim + tl.arange(0, HEADDIM_QK)[None, :] * stride_d_issm_state_qkdim d_ISSM_State_mask = (tl.arange(0, HEADDIM_V)[:, None] < headdim_v) & (tl.arange(0, HEADDIM_QK)[None, :] < headdim_qk) tl.store(d_ISSM_State_ptrs, d_ssm_states_acc, mask=d_ISSM_State_mask) def compute_dqkv( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, da_cs: torch.Tensor, da_cs_sum: torch.Tensor, qk_dot: torch.Tensor, SSM_States: torch.Tensor, do: torch.Tensor, d_ossm_state: Optional[torch.Tensor] = None, d_ov_state: Optional[torch.Tensor] = None, D: Optional[torch.Tensor] = None, chunk_size: int = 64, has_input_state: bool = False, Cu_Seqlens: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Compute gradients dQ_mid, dK_mid, dV, dADT, dQK_dot, dD, d_issm_state for Mamba-3 backward pass. This kernel operates on the rotated/scaled Q and K tensors (Q_mid, K_mid from forward). 
Args: q: Rotated query tensor Q_mid (batch, seqlen, headdim_qk, headdim_qk) k: Rotated+scaled key tensor K_mid (batch, seqlen, headdim_qk, headdim_qk) v: Value tensor (batch, seqlen, nheads, headdim_v) da_cs: Cumulative decay per chunk (batch, nheads, seqlen) da_cs_sum: Sum of decay per chunk (batch, nheads, nchunks) qk_dot: QK dot products from forward (batch, nheads, seqlen) SSM_States: SSM states from forward pass (batch, nheads, headdim_v, nchunks * headdim_qk) do: Output gradient, possibly scaled by Z (batch, seqlen, nheads, headdim_v) d_ossm_state: Gradient of output SSM states (num_sequences, nheads, headdim_v, headdim_qk) d_ov_state: Gradient of output V state (num_sequences, nheads, headdim_v) - added to last token of dV D: Optional skip connection weight (nheads,) chunk_size: Chunk size (default: 64) has_input_state: Whether to compute gradient for input states Returns: Tuple of (dQ_mid, dK_mid, dV, dADT, dQK_dot, dD, d_issm_state) where d_issm_state is None if has_input_state=False """ batch, seqlen, nheads_qk, headdim_qk = q.shape _, _, nheads, headdim_v = v.shape is_varlen = Cu_Seqlens is not None if is_varlen: num_sequences = Cu_Seqlens.shape[0] - 1 assert batch == 1 nchunks = num_sequences + seqlen // chunk_size else: num_sequences = batch nchunks = (seqlen + chunk_size - 1) // chunk_size assert nheads % nheads_qk == 0, "nheads must be divisible by nheads_qk (for GQA support)" assert q.is_cuda and k.is_cuda and v.is_cuda and da_cs.is_cuda and da_cs_sum.is_cuda and do.is_cuda, "All tensors must be on CUDA" assert k.shape == q.shape assert v.shape == (batch, seqlen, nheads, headdim_v) assert da_cs.shape == (batch, nheads, seqlen) assert da_cs_sum.shape == (batch, nheads, nchunks) assert qk_dot.shape == (batch, nheads, seqlen) assert SSM_States.shape == (batch, nheads, headdim_v, nchunks * headdim_qk) assert do.shape == (batch, seqlen, nheads, headdim_v) assert d_ossm_state is None or d_ossm_state.shape == (num_sequences, nheads, headdim_v, headdim_qk) 
assert d_ov_state is None or d_ov_state.shape == (num_sequences, nheads, headdim_v) if D is not None: assert D.shape == (nheads,) # Ensure all tensors are contiguous for optimal memory access # Check if tensors have expected strides (innermost dimension stride = 1) if q.stride(-1) != 1: q = q.contiguous() if k.stride(-1) != 1: k = k.contiguous() if v.stride(-1) != 1: v = v.contiguous() if da_cs.stride(-1) != 1: da_cs = da_cs.contiguous() if da_cs_sum.stride(-1) != 1: da_cs_sum = da_cs_sum.contiguous() if qk_dot.stride(-1) != 1: qk_dot = qk_dot.contiguous() if SSM_States.stride(-1) != 1: SSM_States = SSM_States.contiguous() if do.stride(-1) != 1: do = do.contiguous() if D is not None and D.stride(-1) != 1: D = D.contiguous() if d_ossm_state is not None and d_ossm_state.stride(-1) != 1: d_ossm_state = d_ossm_state.contiguous() if d_ov_state is not None and d_ov_state.stride(-1) != 1: d_ov_state = d_ov_state.contiguous() # Allocate output tensors dq = torch.empty((batch, seqlen, nheads, headdim_qk), dtype=q.dtype, device=q.device) dk = torch.empty((batch, seqlen, nheads, headdim_qk), dtype=k.dtype, device=k.device) dv = torch.empty_like(v) dAdt = torch.empty_like(da_cs) dQK = torch.empty_like(da_cs) dD = torch.empty((num_sequences, nheads), dtype=torch.float32, device=q.device) if D is not None else None d_issm_state = torch.empty((num_sequences, nheads, headdim_v, headdim_qk), dtype=torch.float32, device=q.device) if has_input_state else None # Round up head dimensions to power of 2 for efficient loading HEADDIM_QK = triton.next_power_of_2(headdim_qk) HEADDIM_V = triton.next_power_of_2(headdim_v) # Grid: each program handles one (head, batch/num_sequences) pair if is_varlen: grid = (nheads, num_sequences) else: grid = (nheads, batch) # Launch kernel mamba3_siso_bwd_kernel_dqkv[grid]( q, k, v, da_cs, da_cs_sum, qk_dot, D, SSM_States, do, d_ossm_state, Cu_Seqlens, dq, dk, dv, dAdt, dQK, dD, d_issm_state, # Q strides q.stride(0), q.stride(1), q.stride(2), q.stride(3), # 
K strides k.stride(0), k.stride(1), k.stride(2), k.stride(3), # V strides v.stride(0), v.stride(1), v.stride(2), v.stride(3), # DA_CS strides da_cs.stride(0), da_cs.stride(1), da_cs.stride(2), # DA_CS_SUM strides da_cs_sum.stride(0), da_cs_sum.stride(1), da_cs_sum.stride(2), # QK_Dot strides qk_dot.stride(0), qk_dot.stride(1), qk_dot.stride(2), # D stride D.stride(0) if D is not None else 0, # SSM_States strides: (batch, nheads, headdim_v, nchunks*headdim_qk) SSM_States.stride(0), SSM_States.stride(1), SSM_States.stride(2), SSM_States.stride(3), # dO strides do.stride(0), do.stride(1), do.stride(2), do.stride(3), # d_ossm_state strides d_ossm_state.stride(0) if d_ossm_state is not None else 0, d_ossm_state.stride(1) if d_ossm_state is not None else 0, d_ossm_state.stride(2) if d_ossm_state is not None else 0, d_ossm_state.stride(3) if d_ossm_state is not None else 0, # Cu_Seqlens strides Cu_Seqlens.stride(0) if Cu_Seqlens is not None else 0, # dQ strides dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3), # dK strides dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3), # dV strides dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), # dAdt strides dAdt.stride(0), dAdt.stride(1), dAdt.stride(2), # dQK strides dQK.stride(0), dQK.stride(1), dQK.stride(2), # dD strides dD.stride(0) if D is not None else 0, dD.stride(1) if D is not None else 0, # d_issm_state strides d_issm_state.stride(0) if d_issm_state is not None else 0, d_issm_state.stride(1) if d_issm_state is not None else 0, d_issm_state.stride(2) if d_issm_state is not None else 0, d_issm_state.stride(3) if d_issm_state is not None else 0, # Dimensions seqlen, nheads_qk, headdim_qk, headdim_v, # Compile-time constants CHUNK_SIZE=chunk_size, HEADDIM_QK=HEADDIM_QK, HEADDIM_V=HEADDIM_V, RECOMPUTE_MASK=False, HAS_D_OSSM_STATE=d_ossm_state is not None, RETURN_D_ISSM_STATE=has_input_state, IS_VARLEN=is_varlen, ) # Add output V state gradients to the last token if d_ov_state is not None: if is_varlen: 
last_token_idx = Cu_Seqlens[1:] - 1 dv[0, last_token_idx] += d_ov_state else: dv[:, -1, :, :] += d_ov_state dD = dD.sum(dim=0) if dD is not None else None return dq, dk, dv, dAdt, dQK, dD, d_issm_state # ============================================================================= # d Rotary+Bias Kernel # ============================================================================= @triton.autotune( configs=[ triton.Config({}, num_stages=s, num_warps=w) for s in [1, 2, 3] for w in [2, 4, 8] ], key=["CHUNK_SIZE", "BLOCK_HEADDIM_QK", "HEADDIM_QK", "GQA_RATIO"] ) @triton.jit def mamba3_siso_bwd_kernel_rotary_bias_angles( # Input tensors Q, K, Scale, Gamma, Q_bias, K_bias, Angles, dQ_in, dK_in, dQK, # Output tensors dQ, dK, dAngles, dScale, dGamma, dQ_bias, dK_bias, # Strides for inputs ------------------------------------------------------- # Q: (batch, seqlen, nheads_qk, BLOCK_HEADDIM_QK) stride_q_batch, stride_q_seqlen, stride_q_head, stride_q_qkdim, # K: (batch, seqlen, nheads_qk, BLOCK_HEADDIM_QK) stride_k_batch, stride_k_seqlen, stride_k_head, stride_k_qkdim, # Scale: (batch, nheads, seqlen) stride_scale_batch, stride_scale_head, stride_scale_seqlen, # Gamma: (batch, nheads, seqlen) stride_gamma_batch, stride_gamma_head, stride_gamma_seqlen, # Q_bias: (nheads, BLOCK_HEADDIM_QK) stride_q_bias_head, stride_q_bias_qkdim, # K_bias: (nheads, BLOCK_HEADDIM_QK) stride_k_bias_head, stride_k_bias_qkdim, # Angles: (batch, seqlen, nheads, BLOCK_HEADDIM_QK/2) stride_angles_batch, stride_angles_seqlen, stride_angles_head, stride_angles_qkdim, # dQ_in: (batch, seqlen, nheads, BLOCK_HEADDIM_QK) stride_dq_in_batch, stride_dq_in_seqlen, stride_dq_in_head, stride_dq_in_qkdim, # dK_in: (batch, seqlen, nheads, BLOCK_HEADDIM_QK) stride_dk_in_batch, stride_dk_in_seqlen, stride_dk_in_head, stride_dk_in_qkdim, # dQK: (batch, nheads, seqlen) stride_dqk_batch, stride_dqk_head, stride_dqk_seqlen, # Strides for outputs ------------------------------------------------------ # dQ: (batch, 
seqlen, nheads_qk, BLOCK_HEADDIM_QK) stride_dq_batch, stride_dq_seqlen, stride_dq_head, stride_dq_qkdim, # dK: (batch, seqlen, nheads_qk, BLOCK_HEADDIM_QK) stride_dk_batch, stride_dk_seqlen, stride_dk_head, stride_dk_qkdim, # dAngles: (batch, seqlen, nheads, BLOCK_HEADDIM_QK/2) stride_dangles_batch, stride_dangles_seqlen, stride_dangles_head, stride_dangles_qkdim, # dScale: (batch, nheads, HEADDIM_QK // BLOCK_HEADDIM_QK, seqlen) stride_dscale_batch, stride_dscale_head, stride_dscale_nqkchunks ,stride_dscale_seqlen, # dGamma: (batch, nheads, HEADDIM_QK // BLOCK_HEADDIM_QK, seqlen) stride_dgamma_batch, stride_dgamma_head, stride_dgamma_nqkchunks, stride_dgamma_seqlen, # dQ_bias: (batch, nchunks, nheads, BLOCK_HEADDIM_QK) stride_dq_bias_batch, stride_dq_bias_nchunks, stride_dq_bias_head, stride_dq_bias_qkdim, # dK_bias: (batch, nchunks, nheads, BLOCK_HEADDIM_QK) stride_dk_bias_batch, stride_dk_bias_nchunks, stride_dk_bias_head, stride_dk_bias_qkdim, # ---- sizes ---- seqlen, nheads_qk, nheads, headdim_qk, headdim_angles, CHUNK_SIZE: tl.constexpr, HEADDIM_QK: tl.constexpr, BLOCK_HEADDIM_QK: tl.constexpr, GQA_RATIO: tl.constexpr, ): """ Grid: (nchunks, batch) Each program processes one (batch, chunk) pair. 
Loop structure: - Outer loop: iterate over qk_heads (nheads_qk) - Inner loop: iterate over GQA group (GQA_RATIO heads per qk_head) """ pid_nchunk = tl.program_id(0) pid_batch = tl.program_id(1) nchunks = tl.cdiv(seqlen, CHUNK_SIZE) # Base offsets for inputs q_offset_base = pid_batch * stride_q_batch k_offset_base = pid_batch * stride_k_batch scale_offset_base = pid_batch * stride_scale_batch gamma_offset_base = pid_batch * stride_gamma_batch angle_offset_base = pid_batch * stride_angles_batch dq_in_offset_base = pid_batch * stride_dq_in_batch dk_in_offset_base = pid_batch * stride_dk_in_batch dqk_offset_base = pid_batch * stride_dqk_batch # Base offsets for outputs dq_offset_base = pid_batch * stride_dq_batch dk_offset_base = pid_batch * stride_dk_batch dangle_offset_base = pid_batch * stride_dangles_batch dscale_offset_base = pid_batch * stride_dscale_batch dgamma_offset_base = pid_batch * stride_dgamma_batch dq_bias_offset_base = pid_batch * stride_dq_bias_batch + pid_nchunk * stride_dq_bias_nchunks dk_bias_offset_base = pid_batch * stride_dk_bias_batch + pid_nchunk * stride_dk_bias_nchunks num_nheads_qk = HEADDIM_QK // BLOCK_HEADDIM_QK for nhead_qk_id in range(num_nheads_qk): offs_s = tl.arange(0, CHUNK_SIZE) + pid_nchunk * CHUNK_SIZE offs_d = tl.arange(0, BLOCK_HEADDIM_QK) + nhead_qk_id * BLOCK_HEADDIM_QK offs_dr = tl.arange(0, BLOCK_HEADDIM_QK // 2) + nhead_qk_id * (BLOCK_HEADDIM_QK // 2) # Outer loop: iterate over qk_heads for qk_head_idx in range(nheads_qk): # ============================================================ # Load Q, K for this qk_head (once per GQA group) # ============================================================ q_offset = q_offset_base + qk_head_idx * stride_q_head k_offset = k_offset_base + qk_head_idx * stride_k_head q_ptrs = Q + q_offset + offs_s[:, None] * stride_q_seqlen + offs_d[None, :] * stride_q_qkdim k_ptrs = K + k_offset + offs_s[:, None] * stride_k_seqlen + offs_d[None, :] * stride_k_qkdim # Zero accumulators for this qk_head 
dq_acc = tl.zeros((CHUNK_SIZE, BLOCK_HEADDIM_QK), dtype=tl.float32) dk_acc = tl.zeros((CHUNK_SIZE, BLOCK_HEADDIM_QK), dtype=tl.float32) # Inner loop: iterate over GQA group for gqa_idx in range(GQA_RATIO): nhead_idx = qk_head_idx * GQA_RATIO + gqa_idx # ============================================================ # Load per-head data # ============================================================ # Bias for this head q_bias = tl.load( Q_bias + nhead_idx * stride_q_bias_head + offs_d * stride_q_bias_qkdim, mask=offs_d < headdim_qk).to(tl.float32) k_bias = tl.load( K_bias + nhead_idx * stride_k_bias_head + offs_d * stride_k_bias_qkdim, mask=offs_d < headdim_qk).to(tl.float32) # Q + bias, K + bias q0 = tl.load(q_ptrs, mask=(offs_s[:, None] < seqlen) & (offs_d[None, :] < headdim_qk), other=0.0) # [CHUNK_SIZE, BLOCK_HEADDIM_QK] k0 = tl.load(k_ptrs, mask=(offs_s[:, None] < seqlen) & (offs_d[None, :] < headdim_qk), other=0.0) # [CHUNK_SIZE, BLOCK_HEADDIM_QK] Q_wbias = q0 + q_bias[None, :] K_wbias = k0 + k_bias[None, :] # dQK for this head dqk_offset = dqk_offset_base + nhead_idx * stride_dqk_head dqk = tl.load(dQK + dqk_offset + offs_s * stride_dqk_seqlen, mask=offs_s < seqlen, other=0.0) # Scale, Gamma for this head scale_offset = scale_offset_base + nhead_idx * stride_scale_head gamma_offset = gamma_offset_base + nhead_idx * stride_gamma_head scale = tl.load(Scale + scale_offset + offs_s * stride_scale_seqlen, mask=offs_s < seqlen, other=0.0).to(tl.float32) gamma = tl.load(Gamma + gamma_offset + offs_s * stride_gamma_seqlen, mask=offs_s < seqlen, other=0.0).to(tl.float32) # Angles for this head angle_offset = angle_offset_base + nhead_idx * stride_angles_head theta = tl.load( Angles + angle_offset + offs_s[:, None] * stride_angles_seqlen + offs_dr[None, :] * stride_angles_qkdim, mask=(offs_dr[None, :] < headdim_angles) & (offs_s[:, None] < seqlen), other=0.0).to(tl.float32) # dQ_in, dK_in for this head dq_in_offset = dq_in_offset_base + nhead_idx * stride_dq_in_head 
dk_in_offset = dk_in_offset_base + nhead_idx * stride_dk_in_head dQ_in_load = tl.load(dQ_in + dq_in_offset + offs_s[:, None] * stride_dq_in_seqlen + offs_d[None, :] * stride_dq_in_qkdim, mask=(offs_s[:, None] < seqlen) & (offs_d[None, :] < headdim_qk), other=0.0) dK_in_load = tl.load(dK_in + dk_in_offset + offs_s[:, None] * stride_dk_in_seqlen + offs_d[None, :] * stride_dk_in_qkdim, mask=(offs_s[:, None] < seqlen) & (offs_d[None, :] < headdim_qk), other=0.0) # ============================================================ # Compute dGamma = dQK * (Q_wbias · K_wbias) # ============================================================ QK_dot = tl.sum(Q_wbias * K_wbias, axis=1) d_gamma = dqk * QK_dot dgamma_store_offset = dgamma_offset_base + nhead_idx * stride_dgamma_head tl.store( dGamma + dgamma_store_offset + offs_s * stride_dgamma_seqlen + nhead_qk_id * stride_dgamma_nqkchunks, d_gamma, mask=offs_s < seqlen) # ============================================================ # Compute cos/sin for rotary # ============================================================ cos_angle = cos_approx(theta.to(tl.float32)) sin_angle = sin_approx(theta.to(tl.float32)) # ============================================================ # Compute dScale = sum(dK_in * K_rot) # ============================================================ K_r = tl.reshape(K_wbias, [CHUNK_SIZE, BLOCK_HEADDIM_QK // 2, 2]) K_r0, K_r1 = tl.split(K_r) K_rot0 = K_r0 * cos_angle - K_r1 * sin_angle K_rot1 = K_r0 * sin_angle + K_r1 * cos_angle K_rot = tl.reshape(tl.join(K_rot0, K_rot1), [CHUNK_SIZE, BLOCK_HEADDIM_QK]) dscale_val = tl.sum(dK_in_load * K_rot, axis=1) dscale_store_offset = dscale_offset_base + nhead_idx * stride_dscale_head tl.store( dScale + dscale_store_offset + offs_s * stride_dscale_seqlen + nhead_qk_id * stride_dscale_nqkchunks, dscale_val, mask=offs_s < seqlen) # ============================================================ # Compute dQ_pre, dK_pre through inverse rotary # 
============================================================ dK_in_scaled = dK_in_load * scale[:, None] # shape: (CHUNK_SIZE, BLOCK_HEADDIM_QK) Q_r = tl.reshape(Q_wbias, [CHUNK_SIZE, BLOCK_HEADDIM_QK // 2, 2]) Q_r0, Q_r1 = tl.split(Q_r) dQ_in_r = tl.reshape(dQ_in_load, [CHUNK_SIZE, BLOCK_HEADDIM_QK // 2, 2]) dK_in_r = tl.reshape(dK_in_scaled, [CHUNK_SIZE, BLOCK_HEADDIM_QK // 2, 2]) dQ_in_r0, dQ_in_r1 = tl.split(dQ_in_r) dK_in_r0, dK_in_r1 = tl.split(dK_in_r) # Inverse rotary dq0 = dQ_in_r0 * cos_angle + dQ_in_r1 * sin_angle dq1 = -dQ_in_r0 * sin_angle + dQ_in_r1 * cos_angle dk0 = dK_in_r0 * cos_angle + dK_in_r1 * sin_angle dk1 = -dK_in_r0 * sin_angle + dK_in_r1 * cos_angle dQ_pre = tl.reshape(tl.join(dq0, dq1), [CHUNK_SIZE, BLOCK_HEADDIM_QK]) dK_pre = tl.reshape(tl.join(dk0, dk1), [CHUNK_SIZE, BLOCK_HEADDIM_QK]) # Add dQK path dqk_scaled = (dqk * gamma)[:, None] dQ_pre = dQ_pre + dqk_scaled * K_wbias dK_pre = dK_pre + dqk_scaled * Q_wbias # ============================================================ # Accumulate dQ, dK for GQA reduction # ============================================================ dq_acc += dQ_pre dk_acc += dK_pre # ============================================================ # Store dQ_bias, dK_bias for this head (sum over chunk) # ============================================================ dq_bias_out = tl.sum(dQ_pre, axis=0) dk_bias_out = tl.sum(dK_pre, axis=0) dq_bias_store_offset = dq_bias_offset_base + nhead_idx * stride_dq_bias_head dk_bias_store_offset = dk_bias_offset_base + nhead_idx * stride_dk_bias_head tl.store(dQ_bias + dq_bias_store_offset + offs_d * stride_dq_bias_qkdim, dq_bias_out, mask=offs_d < headdim_qk) tl.store(dK_bias + dk_bias_store_offset + offs_d * stride_dk_bias_qkdim, dk_bias_out, mask=offs_d < headdim_qk) # ============================================================ # Compute and store dAngles for this head # ============================================================ dtheta_q = dQ_in_r0 * (-Q_r0 * sin_angle - 
Q_r1 * cos_angle) + dQ_in_r1 * (Q_r0 * cos_angle - Q_r1 * sin_angle) dtheta_k = dK_in_r0 * (-K_r0 * sin_angle - K_r1 * cos_angle) + dK_in_r1 * (K_r0 * cos_angle - K_r1 * sin_angle) dtheta = dtheta_q + dtheta_k dangle_store_offset = dangle_offset_base + nhead_idx * stride_dangles_head tl.store( dAngles + dangle_store_offset + offs_s[:, None] * stride_dangles_seqlen + offs_dr[None, :] * stride_dangles_qkdim, dtheta, mask=(offs_dr[None, :] < headdim_angles) & (offs_s[:, None] < seqlen)) # ============================================================ # End of GQA group: store accumulated dQ, dK # ============================================================ dq_offset = dq_offset_base + qk_head_idx * stride_dq_head dk_offset = dk_offset_base + qk_head_idx * stride_dk_head dq_ptrs = dQ + dq_offset + offs_s[:, None] * stride_dq_seqlen + offs_d[None, :] * stride_dq_qkdim dk_ptrs = dK + dk_offset + offs_s[:, None] * stride_dk_seqlen + offs_d[None, :] * stride_dk_qkdim tl.store(dq_ptrs, dq_acc, mask=(offs_s[:, None] < seqlen) & (offs_d[None, :] < headdim_qk)) tl.store(dk_ptrs, dk_acc, mask=(offs_s[:, None] < seqlen) & (offs_d[None, :] < headdim_qk)) # NOTE: Do not autotune this kernel. It overwrites dK, dK_bias, dAngles via atomic adds and autotuning will lead to multiple overwrites. 
@triton.jit def mamba3_siso_bwd_kernel_dk_state_post( # Inputs tensors dK_State, Angles, K, K_bias, Cu_Seqlens, # Outputs tensors dK, dK_bias, dAngles, # Strides for dK_State: (num_sequences, nheads, headdim_qk) stride_dk_state_batch, stride_dk_state_head, stride_dk_state_qkdim, # Strides for Angles: (batch, seqlen, nheads, headdim_angles) stride_angles_batch, stride_angles_seqlen, stride_angles_head, stride_angles_qkdim, # Strides for K: (batch, seqlen, nheads_qk, headdim_qk) stride_k_batch, stride_k_seqlen, stride_k_head, stride_k_qkdim, # Strides for K_bias: (nheads, headdim_qk) stride_k_bias_head, stride_k_bias_qkdim, # Strides for Cu_Seqlens: (num_sequences + 1,) stride_cu_seqlen, # Strides for dK: (batch, seqlen, nheads_qk, headdim_qk) stride_dk_batch, stride_dk_seqlen, stride_dk_head, stride_dk_qkdim, # Strides for dK_bias: (nheads, headdim_qk) stride_dk_bias_head, stride_dk_bias_qkdim, # Strides for dAngles: (batch, seqlen, nheads, headdim_angles) stride_dangles_batch, stride_dangles_seqlen, stride_dangles_head, stride_dangles_qkdim, # Dimensions seqlen, headdim_qk, headdim_angles, HEADDIM_QK: tl.constexpr, GQA_RATIO: tl.constexpr, IS_VARLEN: tl.constexpr, ): """ Post-kernel for d_ok_state contributions. Grid: (nheads, batch) Each program handles one (batch, nhead) pair and computes: 1. dK via inverse rotary + GQA reduction (atomic add) 2. dK_bias via inverse rotary + batch reduction (atomic add) 3. 
dAngles via rotary gradient (atomic add) """ pid_head = tl.program_id(0) pid_batch = tl.program_id(1) if IS_VARLEN: pid_seq = pid_batch pid_batch = 0 cu_seqlen = tl.load(Cu_Seqlens + (pid_seq + 1) * stride_cu_seqlen).to(tl.int32) last_pos = cu_seqlen - 1 else: pid_seq = 0 last_pos = seqlen - 1 qk_head_idx = pid_head // GQA_RATIO offs_d = tl.arange(0, HEADDIM_QK) offs_dr = tl.arange(0, HEADDIM_QK // 2) # Load dK_State as interleaved pairs dk_state_base = dK_State + (pid_batch + pid_seq) * stride_dk_state_batch + pid_head * stride_dk_state_head dk_state = tl.load(dk_state_base + offs_d * stride_dk_state_qkdim, mask=offs_d < headdim_qk, other=0.0).to(tl.float32) dk_state_r = tl.reshape(dk_state, [HEADDIM_QK // 2, 2]) dk_state_r0, dk_state_r1 = tl.split(dk_state_r) # shape: (HEADDIM_QK // 2,) # Load angles at last position angles_base = Angles + pid_batch * stride_angles_batch + last_pos * stride_angles_seqlen + pid_head * stride_angles_head angles_val = tl.load(angles_base + offs_dr * stride_angles_qkdim, mask=offs_dr < headdim_angles, other=0.0).to(tl.float32) # shape: (HEADDIM_QK // 2,) cos_ang = cos_approx(angles_val) sin_ang = sin_approx(angles_val) # Inverse rotary: dk_rotated dk0 = dk_state_r0 * cos_ang + dk_state_r1 * sin_ang dk1 = -dk_state_r0 * sin_ang + dk_state_r1 * cos_ang dk_rotated = tl.reshape(tl.join(dk0, dk1), [HEADDIM_QK]) # 1. Accumulate to dK (GQA reduction via atomic) dk_base = dK + pid_batch * stride_dk_batch + last_pos * stride_dk_seqlen + qk_head_idx * stride_dk_head tl.atomic_add(dk_base + offs_d * stride_dk_qkdim, dk_rotated, mask=offs_d < headdim_qk) # 2. Accumulate to dK_bias (batch reduction via atomic) dk_bias_base = dK_bias + pid_head * stride_dk_bias_head tl.atomic_add(dk_bias_base + offs_d * stride_dk_bias_qkdim, dk_rotated, mask=offs_d < headdim_qk) # 3. 
Compute dAngles # Load K at last position (using qk_head_idx for GQA) k_base = K + pid_batch * stride_k_batch + last_pos * stride_k_seqlen + qk_head_idx * stride_k_head k_val = tl.load(k_base + offs_d * stride_k_qkdim, mask=offs_d < headdim_qk, other=0.0).to(tl.float32) kr = tl.reshape(k_val, [HEADDIM_QK // 2, 2]) k_r0, k_r1 = tl.split(kr) # shape: (HEADDIM_QK // 2,) # Load K_bias k_bias_base = K_bias + pid_head * stride_k_bias_head k_bias_val = tl.load(k_bias_base + offs_d * stride_k_bias_qkdim, mask=offs_d < headdim_qk, other=0.0).to(tl.float32) kbr = tl.reshape(k_bias_val, [HEADDIM_QK // 2, 2]) kb_r0, kb_r1 = tl.split(kbr) # shape: (HEADDIM_QK // 2,) # K_wbias = K + K_bias K_wbias_r0 = k_r0 + kb_r0 K_wbias_r1 = k_r1 + kb_r1 # dtheta = dk_r0 * (-K0*sin - K1*cos) + dk_r1 * (K0*cos - K1*sin) dtheta_k = (dk_state_r0 * (-K_wbias_r0 * sin_ang - K_wbias_r1 * cos_ang) + dk_state_r1 * (K_wbias_r0 * cos_ang - K_wbias_r1 * sin_ang)) # Accumulate to dAngles at last position da_base = dAngles + pid_batch * stride_dangles_batch + last_pos * stride_dangles_seqlen + pid_head * stride_dangles_head tl.atomic_add(da_base + offs_dr * stride_dangles_qkdim, dtheta_k, mask=offs_dr < headdim_angles) def compute_dqktheta( q: torch.Tensor, k: torch.Tensor, scale: torch.Tensor, gamma: torch.Tensor, q_bias: torch.Tensor, k_bias: torch.Tensor, angles: torch.Tensor, dq_in: torch.Tensor, dk_in: torch.Tensor, dqk: torch.Tensor, d_ok_state: Optional[torch.Tensor] = None, chunk_size: int = 64, Cu_Seqlens: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute gradients through rotary embeddings and biases for Mamba-3 backward pass. This kernel undoes the rotary embedding and computes gradients for the original Q, K, angles, scaling factors, and biases. 
Args: q: Original query tensor before bias/rotary (batch, seqlen, nheads_qk, headdim_qk) k: Original key tensor before bias/rotary (batch, seqlen, nheads_qk, headdim_qk) scale: Combined scale factor gamma + gamma (batch, nheads, seqlen) gamma: gamma factor (batch, nheads, seqlen) q_bias: Query bias (nheads, headdim_qk) k_bias: Key bias (nheads, headdim_qk) angles: Rotary angles (batch, seqlen, nheads, headdim_angles) dq_in: Gradient from downstream for Q_mid (batch, seqlen, nheads, headdim_qk) dk_in: Gradient from downstream for K_mid (batch, seqlen, nheads, headdim_qk) dqk: Gradient for QK dot products (batch, nheads, seqlen) d_ok_state: Gradient of output K state (batch, nheads, headdim_qk) - added to last token of dK (without scaling) chunk_size: Chunk size (default: 64) Returns: Tuple of (dQ, dK, dQ_bias, dK_bias, dAngles, dScale, dSGamma) - dQ: (batch, seqlen, nheads_qk, headdim_qk) - dK: (batch, seqlen, nheads_qk, headdim_qk) - dQ_bias: (nheads, headdim_qk) - dK_bias: (nheads, headdim_qk) - dAngles: (batch, seqlen, nheads, headdim_angles) - dScale: (batch, nheads, seqlen) - dGamma: (batch, nheads, seqlen) """ batch, seqlen, nheads_qk, headdim_qk = q.shape assert q.shape == k.shape nheads = scale.shape[1] nchunks = triton.cdiv(seqlen, chunk_size) GQA_RATIO = nheads // nheads_qk assert scale.shape == (batch, nheads, seqlen) assert gamma.shape == (batch, nheads, seqlen) assert q_bias.shape == (nheads, headdim_qk) assert k_bias.shape == (nheads, headdim_qk) headdim_angles = angles.shape[-1] assert angles.shape == (batch, seqlen, nheads, headdim_angles) assert dq_in.shape == (batch, seqlen, nheads, headdim_qk) assert dk_in.shape == (batch, seqlen, nheads, headdim_qk) assert dqk.shape == (batch, nheads, seqlen) if d_ok_state is not None: num_sequences = Cu_Seqlens.shape[0] - 1 if Cu_Seqlens is not None else batch assert d_ok_state.shape == (num_sequences, nheads, headdim_qk) assert nheads % nheads_qk == 0, "nheads must be multiple of nheads_qk for GQA support" # 
Ensure contiguity after reshaping if not q.is_contiguous(): q = q.contiguous() if not k.is_contiguous(): k = k.contiguous() if not scale.is_contiguous(): scale = scale.contiguous() if not gamma.is_contiguous(): gamma = gamma.contiguous() if not dqk.is_contiguous(): dqk = dqk.contiguous() if not angles.is_contiguous(): angles = angles.contiguous() if not dq_in.is_contiguous(): dq_in = dq_in.contiguous() if not dk_in.is_contiguous(): dk_in = dk_in.contiguous() if q_bias.stride(-1) != 1: q_bias = q_bias.contiguous() if k_bias.stride(-1) != 1: k_bias = k_bias.contiguous() if d_ok_state is not None and (not d_ok_state.is_contiguous()): d_ok_state = d_ok_state.contiguous() HEADDIM_QK = triton.next_power_of_2(headdim_qk) BLOCK_HEADDIM_QK = min(HEADDIM_QK, 64) # Allocate output tensors layout dq = torch.empty((batch, seqlen, nheads_qk, headdim_qk), dtype=dq_in.dtype, device=q.device) dk = torch.empty((batch, seqlen, nheads_qk, headdim_qk), dtype=dk_in.dtype, device=k.device) dangles = torch.empty((batch, seqlen, nheads, headdim_angles), dtype=angles.dtype, device=angles.device) dscale = torch.empty((batch, nheads, HEADDIM_QK // BLOCK_HEADDIM_QK, seqlen), dtype=scale.dtype, device=scale.device) dgamma = torch.empty((batch, nheads, HEADDIM_QK // BLOCK_HEADDIM_QK, seqlen), dtype=gamma.dtype, device=gamma.device) dq_bias_partial = torch.empty((batch, nchunks, nheads, headdim_qk), dtype=torch.float32, device=q.device) dk_bias_partial = torch.empty((batch, nchunks, nheads, headdim_qk), dtype=torch.float32, device=k.device) # Grid: (nchunks, batch) grid = (nchunks, batch) mamba3_siso_bwd_kernel_rotary_bias_angles[grid]( # Input tensors q, k, scale, gamma, q_bias, k_bias, angles, dq_in, dk_in, dqk, # Output tensors dq, dk, dangles, dscale, dgamma, dq_bias_partial, dk_bias_partial, # Q strides: (batch, seqlen, nheads_qk, headdim_qk) q.stride(0), q.stride(1), q.stride(2), q.stride(3), # K strides k.stride(0), k.stride(1), k.stride(2), k.stride(3), # Scale strides: (batch, nheads, 
seqlen) scale.stride(0), scale.stride(1), scale.stride(2), # SGamma strides gamma.stride(0), gamma.stride(1), gamma.stride(2), # Q_bias strides: (nheads, headdim_qk) q_bias.stride(0), q_bias.stride(1), # K_bias strides k_bias.stride(0), k_bias.stride(1), # Angles strides: (batch, seqlen, nheads, headdim_qk//2) angles.stride(0), angles.stride(1), angles.stride(2), angles.stride(3), # dQ_in strides: (batch, seqlen, nheads, headdim_qk) dq_in.stride(0), dq_in.stride(1), dq_in.stride(2), dq_in.stride(3), # dK_in strides dk_in.stride(0), dk_in.stride(1), dk_in.stride(2), dk_in.stride(3), # dQK strides: (batch, nheads, seqlen) dqk.stride(0), dqk.stride(1), dqk.stride(2), # Output tensors # dQ strides: (batch, seqlen, nheads_qk, headdim_qk) dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3), # dK strides dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3), # dAngles strides: (batch, seqlen, nheads, headdim_qk//2) dangles.stride(0), dangles.stride(1), dangles.stride(2), dangles.stride(3), # dScale strides: (batch, nheads, seqlen) dscale.stride(0), dscale.stride(1), dscale.stride(2), dscale.stride(3), # dSGamma strides dgamma.stride(0), dgamma.stride(1), dgamma.stride(2), dgamma.stride(3), # dQ_bias_partial strides: (batch, nchunks, nheads, headdim_qk) dq_bias_partial.stride(0), dq_bias_partial.stride(1), dq_bias_partial.stride(2), dq_bias_partial.stride(3), # dK_bias_partial strides dk_bias_partial.stride(0), dk_bias_partial.stride(1), dk_bias_partial.stride(2), dk_bias_partial.stride(3), # Sizes seqlen, nheads_qk, nheads, headdim_qk, headdim_angles, CHUNK_SIZE=chunk_size, HEADDIM_QK=HEADDIM_QK, BLOCK_HEADDIM_QK=BLOCK_HEADDIM_QK, GQA_RATIO=GQA_RATIO, ) # Reshape outputs back to original layout dscale = torch.sum(dscale, dim=2) # Sum over headdim blocks dgamma = torch.sum(dgamma, dim=2) # Sum over headdim blocks # Reduce bias gradients: (batch, nchunks, nheads, headdim_qk) -> (nheads, headdim_qk) dq_bias = dq_bias_partial.sum(dim=(0, 1)) dk_bias = 
dk_bias_partial.sum(dim=(0, 1))

    # NOTE: We handle d_ok_state contributions in a different kernel because merging it in
    # causes a +800% increase in register spillage and a +200us increase in runtime. For now
    # this new kernel only introduces +5us.
    if d_ok_state is not None:
        apply_dk_state_post(
            d_ok_state, angles, k, k_bias, dk, dk_bias, dangles, Cu_Seqlens
        )

    return dq, dk, dq_bias, dk_bias, dangles, dscale, dgamma


def apply_dk_state_post(
    d_ok_state: torch.Tensor,
    angles: torch.Tensor,
    k: torch.Tensor,
    k_bias: torch.Tensor,
    dk: torch.Tensor,
    dk_bias: torch.Tensor,
    dangles: torch.Tensor,
    Cu_Seqlens: Optional[torch.Tensor] = None,
) -> None:
    """Apply the final-K-state gradient (``d_ok_state``) to dk, dk_bias and dangles.

    Launches ``mamba3_siso_bwd_kernel_dk_state_post`` on a ``(nheads, num_sequences)``
    grid. ``dk``, ``dk_bias`` and ``dangles`` are passed as the kernel's output tensors
    and are updated in place; nothing is returned. ``compute_dqktheta`` calls this after
    its main kernel has already written those tensors, so the kernel presumably
    accumulates onto them (its body is not visible here -- confirm).

    Args:
        d_ok_state: Gradient of the final K state, indexed as
            (num_sequences, nheads, headdim_qk).
        angles: Cumulative rotary angles, (batch, seqlen, nheads, headdim_angles).
        k: Keys, (batch, seqlen, nheads_qk, headdim_qk).
        k_bias: Key bias, (nheads, headdim_qk).
        dk: Gradient buffer for ``k`` (same layout as ``k``); modified in place.
        dk_bias: Gradient buffer for ``k_bias``; modified in place.
        dangles: Gradient buffer for ``angles``; modified in place.
        Cu_Seqlens: Optional cumulative sequence lengths, shape (num_sequences + 1,),
            for variable-length mode. Requires batch == 1.
    """
    batch, seqlen, nheads, headdim_angles = angles.shape
    _, _, headdim_qk = d_ok_state.shape
    nheads_qk = k.shape[2]
    # Number of heads that share each Q/K head (grouped-query attention).
    GQA_RATIO = nheads // nheads_qk

    is_varlen = Cu_Seqlens is not None
    if is_varlen:
        num_sequences = Cu_Seqlens.shape[0] - 1
        # Varlen inputs pack every sequence along the seqlen axis of one batch row.
        assert batch == 1
    else:
        num_sequences = batch

    # Ensure contiguity
    # Only read-only inputs are copied. dk / dk_bias / dangles are written in place, so
    # they keep their own storage and are described to the kernel by their strides.
    if not d_ok_state.is_contiguous():
        d_ok_state = d_ok_state.contiguous()
    if not angles.is_contiguous():
        angles = angles.contiguous()
    if not k.is_contiguous():
        k = k.contiguous()
    if not k_bias.is_contiguous():
        k_bias = k_bias.contiguous()

    # Head-dimension block width, rounded up to a power of two.
    HEADDIM_QK = triton.next_power_of_2(headdim_qk)

    grid = (nheads, num_sequences)
    mamba3_siso_bwd_kernel_dk_state_post[grid](
        # Input tensors
        d_ok_state, angles, k, k_bias, Cu_Seqlens,
        # Output tensors
        dk, dk_bias, dangles,
        # dK_State strides: (batch, nheads, headdim_qk)
        d_ok_state.stride(0), d_ok_state.stride(1), d_ok_state.stride(2),
        # Angles strides: (batch, seqlen, nheads, headdim_angles)
        angles.stride(0), angles.stride(1), angles.stride(2), angles.stride(3),
        # K strides: (batch, seqlen, nheads_qk, headdim_qk)
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        # K_bias strides: (nheads, headdim_qk)
        k_bias.stride(0), k_bias.stride(1),
        # Cu_Seqlens strides: (num_sequences + 1,)
        Cu_Seqlens.stride(0) if is_varlen else 0,
        # dK strides: (batch, seqlen, nheads_qk, headdim_qk)
        dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
        # dK_bias strides: (nheads, headdim_qk)
        dk_bias.stride(0), dk_bias.stride(1),
        # dAngles strides: (batch, seqlen, nheads, headdim_angles)
        dangles.stride(0), dangles.stride(1), dangles.stride(2), dangles.stride(3),
        # Dimensions
        seqlen, headdim_qk, headdim_angles,
        HEADDIM_QK=HEADDIM_QK,
        GQA_RATIO=GQA_RATIO,
        IS_VARLEN=is_varlen,
        num_warps=2,
        num_stages=3,
    )


# =============================================================================
# dDT, dTrap, and dInput States Kernel
# =============================================================================

# CHUNK_SIZE is an autotuned parameter, which is why the launcher below does not pass it.
@triton.autotune(
    configs=[
        triton.Config({"CHUNK_SIZE": cs}, num_stages=s, num_warps=w)
        for cs in [64, 128, 256]
        for s in [1, 2, 3]
        for w in [2, 4, 8]
    ],
    key=["HEADDIM_V", "HEADDIM_QK", "HAS_INPUT_STATE", "IS_VARLEN"]
)
@triton.jit
def mamba3_siso_bwd_kernel_ddt_dtrap_dinput_states(
    # Input tensors
    dScale, dGamma, DT, Trap, d_ISSM_State, Input_K_State, Input_V_State, Cu_Seqlens,
    # Output tensors
    dDT, dTrap, dInput_SSM_State, dInput_K_State, dInput_V_State,
    # Strides for dScale: (batch, nheads, seqlen)
    stride_dscale_batch, stride_dscale_head, stride_dscale_seqlen,
    # Strides for dGamma: (batch, nheads, seqlen)
    stride_dgamma_batch, stride_dgamma_head, stride_dgamma_seqlen,
    # Strides for DT: (batch, nheads, seqlen)
    stride_dt_batch, stride_dt_head, stride_dt_seqlen,
    # Strides for Trap: (batch, nheads, seqlen)
    stride_trap_batch, stride_trap_head, stride_trap_seqlen,
    # Strides for d_ISSM_State: (num_sequences, nheads, headdim_v, headdim_qk)
    stride_d_issm_state_batch, stride_d_issm_state_head, stride_d_issm_state_vdim, stride_d_issm_state_qkdim,
    # Strides for Input_K_State: (num_sequences, nheads, headdim_qk)
    stride_input_k_state_batch, stride_input_k_state_head, stride_input_k_state_qkdim,
    # Strides for Input_V_State: (num_sequences, nheads, headdim_v)
    stride_input_v_state_batch, stride_input_v_state_head, stride_input_v_state_vdim,
    # Stride for Cu_Seqlens
    stride_cu_seqlen,
    # Strides for dDT: (batch, nheads, seqlen)
    stride_ddt_batch, stride_ddt_head, stride_ddt_seqlen,
    # Strides for dTrap: (batch, nheads, seqlen)
    stride_dtrap_batch, stride_dtrap_head, stride_dtrap_seqlen,
    # Strides for dInput_SSM_State: (num_sequences, nheads, headdim_v, headdim_qk)
    stride_dinput_ssm_state_batch, stride_dinput_ssm_state_head, stride_dinput_ssm_state_vdim, stride_dinput_ssm_state_qkdim,
    # Strides for dInput_K_State: (num_sequences, nheads, headdim_qk)
    stride_dinput_k_state_batch, stride_dinput_k_state_head, stride_dinput_k_state_qkdim,
    # Strides for dInput_V_State: (num_sequences, nheads, headdim_v)
    stride_dinput_v_state_batch, stride_dinput_v_state_head, stride_dinput_v_state_vdim,
    # Dimensions
    seqlen, headdim_v, headdim_qk,
    # Compile-time constants
    CHUNK_SIZE: tl.constexpr,
    HEADDIM_V: tl.constexpr,
    HEADDIM_QK: tl.constexpr,
    HAS_INPUT_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    Backward kernel for computing dDT, dTrap, and input state gradients.

    One program handles one (head, sequence) pair. Trap holds pre-sigmoid values:
    the kernel uses trap_t = sigmoid(Trap_t) and writes the gradient with respect to
    the pre-sigmoid input.

    Part 1 - dDT and dTrap from dScale and dGamma:
        The gradients below correspond to forward quantities of this form. The form
        is inferred from this kernel; the forward kernel is not shown here.
            gamma_t = DT_t * trap_t
            scale_t = DT_t * trap_t + DT_{t+1} * (1 - trap_{t+1})
        Backward, with dScale_{t-1} taken as 0 at the first token of each sequence:
            dDT_t   = (dGamma_t + dScale_t) * trap_t + dScale_{t-1} * (1 - trap_t)
            dtrap_t = (dGamma_t + dScale_t) * DT_t - dScale_{t-1} * DT_t
            dTrap_t = dtrap_t * trap_t * (1 - trap_t)    (chain rule through the sigmoid)

    Part 2 - Input state gradients (first token only, if HAS_INPUT_STATE):
        Forward:
            scalar = DT_0 * (1 - trap_0)
            SSM_State = Input_SSM_State + outer(Input_V, Input_K) * scalar
        Backward:
            dInput_SSM_State = d_ISSM_State
            d_scalar = sum(d_ISSM_State * outer(Input_V, Input_K))
            dInput_V = sum over qkdim of (d_ISSM_State * Input_K) * scalar
            dInput_K = sum over vdim of (d_ISSM_State * Input_V) * scalar
            dDT_0   += d_scalar * (1 - trap_0)
            dTrap_0 += d_scalar * (-DT_0) * trap_0 * (1 - trap_0)
        The position-0 updates are atomic adds applied on top of the values that
        Part 1 stored.

    Grid:
        - Normal mode: (nheads, batch)
        - Varlen mode: (nheads, num_sequences)
    """
    pid_head = tl.program_id(0)
    pid_batch = tl.program_id(1)
    if IS_VARLEN:
        # Grid axis 1 enumerates sequences; all of them live in batch row 0.
        pid_seq = pid_batch
        pid_batch = 0
        cu_seqlen = tl.load(Cu_Seqlens + pid_seq * stride_cu_seqlen).to(tl.int32)
        cu_seqlen_next = tl.load(Cu_Seqlens + (pid_seq + 1) * stride_cu_seqlen).to(tl.int32)
        # From here on, seqlen is this sequence's own length, not the packed total.
        seqlen = cu_seqlen_next - cu_seqlen
    else:
        pid_seq = 0
        cu_seqlen = 0

    # ==================== Pointer Offsets ====================
    # In varlen mode the extra term moves each base pointer to this sequence's first token.
    dscale_offset = pid_batch * stride_dscale_batch + pid_head * stride_dscale_head + IS_VARLEN * cu_seqlen * stride_dscale_seqlen
    dgamma_offset = pid_batch * stride_dgamma_batch + pid_head * stride_dgamma_head + IS_VARLEN * cu_seqlen * stride_dgamma_seqlen
    dt_offset = pid_batch * stride_dt_batch + pid_head * stride_dt_head + IS_VARLEN * cu_seqlen * stride_dt_seqlen
    trap_offset = pid_batch * stride_trap_batch + pid_head * stride_trap_head + IS_VARLEN * cu_seqlen * stride_trap_seqlen
    ddt_offset = pid_batch * stride_ddt_batch + pid_head * stride_ddt_head + IS_VARLEN * cu_seqlen * stride_ddt_seqlen
    dtrap_offset = pid_batch * stride_dtrap_batch + pid_head * stride_dtrap_head + IS_VARLEN * cu_seqlen * stride_dtrap_seqlen

    # ==================== Part 1: dDT and dTrap ====================
    num_chunks = tl.cdiv(seqlen, CHUNK_SIZE)
    for chunk_idx in range(num_chunks):
        offs_s = chunk_idx * CHUNK_SIZE + tl.arange(0, CHUNK_SIZE)
        mask = offs_s < seqlen

        # Load dscale_t, dGamma_t, Trap_t, DT_t for current positions
        dscale_t = tl.load(dScale + dscale_offset + offs_s * stride_dscale_seqlen, mask=mask, other=0.0)
        dgamma_t = tl.load(dGamma + dgamma_offset + offs_s * stride_dgamma_seqlen, mask=mask, other=0.0)
        trap_presig_t = tl.load(Trap + trap_offset + offs_s * stride_trap_seqlen, mask=mask, other=0.0).to(tl.float32)
        trap_t = sigmoid_approx(trap_presig_t)
        dt_t = tl.load(DT + dt_offset + offs_s * stride_dt_seqlen, mask=mask, other=0.0)

        # Load dScale_{t-1} (shifted by 1, with 0 at t=0)
        # shifted_gamma[t-1] = DT[t] * (1 - Trap[t]) feeds into scale[t-1]
        # The >= 0 mask also stops a varlen sequence from reading the previous
        # sequence's last dScale.
        offs_s_prev = offs_s - 1
        mask_prev = (offs_s_prev >= 0) & (offs_s_prev < seqlen)
        dscale_prev = tl.load(
            dScale + dscale_offset + offs_s_prev * stride_dscale_seqlen, mask=mask_prev, other=0.0
        )

        # Compute gradients:
        ddt_t = (dgamma_t + dscale_t) * trap_t + dscale_prev * (1.0 - trap_t)
        dtrap_t = (dgamma_t + dscale_t) * dt_t - dscale_prev * dt_t
        # Chain rule through the sigmoid: d sigmoid(x)/dx = s * (1 - s).
        dtrap_presig_t = dtrap_t * trap_t * (1.0 - trap_t)

        # Store results
        tl.store(dDT + ddt_offset + offs_s * stride_ddt_seqlen, ddt_t, mask=mask)
        tl.store(dTrap + dtrap_offset + offs_s * stride_dtrap_seqlen, dtrap_presig_t, mask=mask)

    # ==================== Part 2: Input State Gradients ====================
    if HAS_INPUT_STATE:
        # Pointer offsets for input states
        # pid_seq is 0 in batched mode and pid_batch is 0 in varlen mode, so
        # (pid_batch + pid_seq) is the per-sequence state index in both modes.
        d_issm_offset = (pid_batch + pid_seq) * stride_d_issm_state_batch + pid_head * stride_d_issm_state_head
        input_k_offset = (pid_batch + pid_seq) * stride_input_k_state_batch + pid_head * stride_input_k_state_head
        input_v_offset = (pid_batch + pid_seq) * stride_input_v_state_batch + pid_head * stride_input_v_state_head
        dinput_ssm_offset = (pid_batch + pid_seq) * stride_dinput_ssm_state_batch + pid_head * stride_dinput_ssm_state_head
        dinput_k_offset = (pid_batch + pid_seq) * stride_dinput_k_state_batch + pid_head * stride_dinput_k_state_head
        dinput_v_offset = (pid_batch + pid_seq) * stride_dinput_v_state_batch + pid_head * stride_dinput_v_state_head

        # Load DT_0 and Trap_0 (first token)
        # NOTE(review): nothing here guards against an empty varlen sequence (seqlen == 0).
        # dt_offset / trap_offset would then point at the next sequence's first token, or
        # one past the end for the last sequence, and the atomic adds below would land
        # there. Confirm that empty sequences are excluded upstream.
        dt_0 = tl.load(DT + dt_offset).to(tl.float32)
        trap_presig_0 = tl.load(Trap + trap_offset).to(tl.float32)
        trap_0 = sigmoid_approx(trap_presig_0)
        scalar = dt_0 * (1.0 - trap_0)

        # Dimension offsets
        offs_v = tl.arange(0, HEADDIM_V)
        offs_qk = tl.arange(0, HEADDIM_QK)

        # Load Input_K_State and Input_V_State
        input_k = tl.load(
            Input_K_State + input_k_offset + offs_qk * stride_input_k_state_qkdim,
            mask=offs_qk < headdim_qk, other=0.0).to(tl.float32)
        input_v = tl.load(
            Input_V_State + input_v_offset + offs_v * stride_input_v_state_vdim,
            mask=offs_v < headdim_v, other=0.0
        ).to(tl.float32)

        # Load d_ISSM_State: (headdim_v, headdim_qk)
        d_issm = tl.load(
            d_ISSM_State + d_issm_offset + offs_v[:, None] * stride_d_issm_state_vdim + offs_qk[None, :] * stride_d_issm_state_qkdim,
            mask=(offs_v[:, None] < headdim_v) & (offs_qk[None, :] < headdim_qk),
            other=0.0
        ).to(tl.float32)

        # dInput_SSM_State = d_ISSM_State (direct copy)
        tl.store(
            dInput_SSM_State + dinput_ssm_offset + offs_v[:, None] * stride_dinput_ssm_state_vdim + offs_qk[None, :] * stride_dinput_ssm_state_qkdim,
            d_issm,
            mask=(offs_v[:, None] < headdim_v) & (offs_qk[None, :] < headdim_qk),
        )

        # d_scalar = sum(d_ISSM_State * outer(Input_V, Input_K))
        outer_product = input_v[:, None] * input_k[None, :]
        d_scalar = tl.sum(d_issm * outer_product)

        # dInput_V = (sum over qkdim, axis 1, of d_ISSM_State * Input_K) * scalar
        # dInput_K = (sum over vdim, axis 0, of d_ISSM_State * Input_V) * scalar
        dinput_v = tl.sum(d_issm * input_k[None, :], axis=1) * scalar
        dinput_k = tl.sum(d_issm * input_v[:, None], axis=0) * scalar

        # Store dInput_V_State and dInput_K_State
        tl.store(dInput_V_State + dinput_v_offset + offs_v * stride_dinput_v_state_vdim, dinput_v, mask=offs_v < headdim_v)
        tl.store(dInput_K_State + dinput_k_offset + offs_qk * stride_dinput_k_state_qkdim, dinput_k, mask=offs_qk < headdim_qk)

        # Add contributions to dDT_0 and dTrap_0 from input state gradient
        ddt_0_contrib = d_scalar * (1.0 - trap_0)
        dtrap_0_contrib = d_scalar * (-dt_0)
        dtrap_0_presig_contrib = dtrap_0_contrib * trap_0 * (1.0 - trap_0)

        # Atomically add to the first position (already written in Part 1)
        # For a non-empty sequence, Part 1 has just overwritten position 0 with a plain
        # store, so re-launching the kernel (e.g. autotune benchmark runs) does not
        # accumulate stale contributions.
        tl.atomic_add(dDT + ddt_offset, ddt_0_contrib)
        tl.atomic_add(dTrap + dtrap_offset, dtrap_0_presig_contrib)


def compute_ddt_dtrap_dinput_states(
    dscale: torch.Tensor,
    dgamma: torch.Tensor,
    dt: torch.Tensor,
    trap: torch.Tensor,
    d_issm_state: Optional[torch.Tensor] = None,
    input_k_state: Optional[torch.Tensor] = None,
    input_v_state: Optional[torch.Tensor] = None,
    Cu_Seqlens: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Compute dDT, dTrap from dScale/dGamma, and optionally input state gradients.

    Args:
        dscale: Gradient of scale, shape (batch, nheads, seqlen)
        dgamma: Gradient of gamma, shape (batch, nheads, seqlen)
        dt: DT tensor from forward pass, shape (batch, nheads, seqlen)
        trap: Trap tensor from forward pass, shape (batch, nheads, seqlen). The kernel
            applies a sigmoid to these values, so they are pre-sigmoid inputs.
        d_issm_state: Gradient of SSM_State_mid (optional),
            shape (num_sequences, nheads, headdim_v, headdim_qk)
        input_k_state: Input K state from forward pass (optional; required together with
            d_issm_state), shape (num_sequences, nheads, headdim_qk)
        input_v_state: Input V state from forward pass (optional; required together with
            d_issm_state), shape (num_sequences, nheads, headdim_v)
        Cu_Seqlens: Optional cumulative sequence lengths, shape (num_sequences + 1,).
            When given, batch must be 1 and num_sequences = len(Cu_Seqlens) - 1.
            Otherwise num_sequences = batch.

    Returns:
        Tuple containing:
            - dDT: Gradient for DT, shape (batch, nheads, seqlen), float32
            - dTrap: Gradient for the pre-sigmoid Trap, shape (batch, nheads, seqlen), float32
            - dInput_SSM_State: Gradient for Input_SSM_State, same shape/dtype as
              d_issm_state (None if no input state)
            - dInput_K_State: Gradient for Input_K_State, float32 (None if no input state)
            - dInput_V_State: Gradient for Input_V_State, float32 (None if no input state)

    Raises:
        AssertionError: On any of:
            - a shape mismatch;
            - d_issm_state given without input_k_state/input_v_state;
            - Cu_Seqlens used with batch != 1.
    """
    batch, nheads, seqlen = dscale.shape
    has_input_state = d_issm_state is not None
    is_varlen = Cu_Seqlens is not None
    if is_varlen:
        num_sequences = Cu_Seqlens.shape[0] - 1
        assert batch == 1, "Batch size must be 1 when using variable-length sequences."
    else:
        num_sequences = batch

    # Validate inputs
    assert dgamma.shape == (batch, nheads, seqlen), f"dgamma shape mismatch: {dgamma.shape}"
    assert dt.shape == (batch, nheads, seqlen), f"dt shape mismatch: {dt.shape}"
    assert trap.shape == (batch, nheads, seqlen), f"trap shape mismatch: {trap.shape}"
    if has_input_state:
        assert input_k_state is not None and input_v_state is not None, \
            "input_k_state and input_v_state must be provided with d_issm_state"
        headdim_v, headdim_qk = d_issm_state.shape[2], d_issm_state.shape[3]
        assert d_issm_state.shape == (num_sequences, nheads, headdim_v, headdim_qk), \
            f"d_issm_state shape mismatch: {d_issm_state.shape}"
        assert input_k_state.shape == (num_sequences, nheads, headdim_qk), \
            f"input_k_state shape mismatch: {input_k_state.shape}"
        assert input_v_state.shape == (num_sequences, nheads, headdim_v), \
            f"input_v_state shape mismatch: {input_v_state.shape}"
    else:
        headdim_v, headdim_qk = 64, 128  # Dummy values for block size calculation

    # Ensure contiguity
    dscale = dscale.contiguous() if not dscale.is_contiguous() else dscale
    dgamma = dgamma.contiguous() if not dgamma.is_contiguous() else dgamma
    dt = dt.contiguous() if not dt.is_contiguous() else dt
    trap = trap.contiguous() if not trap.is_contiguous() else trap
    if has_input_state:
        d_issm_state = d_issm_state.contiguous() if not d_issm_state.is_contiguous() else d_issm_state
        input_k_state = input_k_state.contiguous() if not input_k_state.is_contiguous() else input_k_state
        input_v_state = input_v_state.contiguous() if not input_v_state.is_contiguous() else input_v_state

    # Allocate outputs
    # dDT / dTrap are always float32, whatever the dtype of dt / trap.
    dDT = torch.empty_like(dt, dtype=torch.float32)
    dTrap = torch.empty_like(trap, dtype=torch.float32)
    if has_input_state:
        d_Input_SSM_State = torch.empty_like(d_issm_state)
        d_Input_K_State = torch.empty((num_sequences, nheads, headdim_qk), dtype=torch.float32, device=dt.device)
        d_Input_V_State = torch.empty((num_sequences, nheads, headdim_v), dtype=torch.float32, device=dt.device)
    else:
        d_Input_SSM_State = None
        d_Input_K_State = None
        d_Input_V_State = None

    # Launch kernel
    HEADDIM_V = triton.next_power_of_2(headdim_v) if has_input_state else 64
    HEADDIM_QK = triton.next_power_of_2(headdim_qk) if has_input_state else 128

    # Grid
    # One program per (head, sequence). CHUNK_SIZE is picked by the kernel's autotuner.
    if is_varlen:
        grid = (nheads, num_sequences)
    else:
        grid = (nheads, batch)

    mamba3_siso_bwd_kernel_ddt_dtrap_dinput_states[grid](
        # Inputs
        # The state pointers are only dereferenced inside `if HAS_INPUT_STATE` in the kernel.
        dscale, dgamma, dt, trap,
        d_issm_state if has_input_state else dscale,  # Dummy pointer if not used
        input_k_state if has_input_state else dscale,
        input_v_state if has_input_state else dscale,
        Cu_Seqlens,
        # Outputs
        dDT, dTrap,
        d_Input_SSM_State if has_input_state else dDT,  # Dummy pointer if not used
        d_Input_K_State if has_input_state else dDT,
        d_Input_V_State if has_input_state else dDT,
        # Strides for dScale
        dscale.stride(0), dscale.stride(1), dscale.stride(2),
        # Strides for dGamma
        dgamma.stride(0), dgamma.stride(1), dgamma.stride(2),
        # Strides for DT
        dt.stride(0), dt.stride(1), dt.stride(2),
        # Strides for Trap
        trap.stride(0), trap.stride(1), trap.stride(2),
        # Strides for d_ISSM_State
        d_issm_state.stride(0) if has_input_state else 0,
        d_issm_state.stride(1) if has_input_state else 0,
        d_issm_state.stride(2) if has_input_state else 0,
        d_issm_state.stride(3) if has_input_state else 0,
        # Strides for Input_K_State
        input_k_state.stride(0) if has_input_state else 0,
        input_k_state.stride(1) if has_input_state else 0,
        input_k_state.stride(2) if has_input_state else 0,
        # Strides for Input_V_State
        input_v_state.stride(0) if has_input_state else 0,
        input_v_state.stride(1) if has_input_state else 0,
        input_v_state.stride(2) if has_input_state else 0,
        # Stride for Cu_Seqlens
        Cu_Seqlens.stride(0) if Cu_Seqlens is not None else 0,
        # Strides for dDT
        dDT.stride(0), dDT.stride(1), dDT.stride(2),
        # Strides for dTrap
        dTrap.stride(0), dTrap.stride(1), dTrap.stride(2),
        # Strides for d_Input_SSM_State
        d_Input_SSM_State.stride(0) if has_input_state else 0,
        d_Input_SSM_State.stride(1) if has_input_state else 0,
        d_Input_SSM_State.stride(2) if has_input_state else 0,
        d_Input_SSM_State.stride(3) if has_input_state else 0,
        # Strides for d_Input_K_State
        d_Input_K_State.stride(0) if has_input_state else 0,
        d_Input_K_State.stride(1) if has_input_state else 0,
        d_Input_K_State.stride(2) if has_input_state else 0,
        # Strides for d_Input_V_State
        d_Input_V_State.stride(0) if has_input_state else 0,
        d_Input_V_State.stride(1) if has_input_state else 0,
        d_Input_V_State.stride(2) if has_input_state else 0,
        # Dimensions
        seqlen, headdim_v, headdim_qk,
        # Constants
        HEADDIM_V=HEADDIM_V,
        HEADDIM_QK=HEADDIM_QK,
        HAS_INPUT_STATE=has_input_state,
        IS_VARLEN=is_varlen,
    )

    return dDT, dTrap, d_Input_SSM_State, d_Input_K_State, d_Input_V_State


# =============================================================================
# Memory Allocator for TMA Descriptors
# =============================================================================

def _alloc_fn(size: int, alignment: int, stream: Optional[int]) -> torch.Tensor:
    """Custom allocator for TMA descriptor global memory allocation.

    Returns a ``size``-byte int8 buffer on the current CUDA device. ``alignment`` and
    ``stream`` are accepted to match the allocator callback signature but are ignored.
    NOTE(review): this relies on torch.empty's allocation already satisfying the
    requested alignment -- confirm.
    """
    return torch.empty(size, device="cuda", dtype=torch.int8)


# Registered at import time. This registration is not wrapped in try/except.
triton.set_allocator(_alloc_fn)

================================================
FILE: mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py
================================================
"""Mamba-3 Triton Autograd Wrapper

Copyright (c) 2025, Dao AI Lab, Goombalab
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import Tensor
import triton

# Import kernels
from mamba_ssm.ops.triton.mamba3.mamba3_siso_fwd import mamba3_siso_fwd
from mamba_ssm.ops.triton.mamba3.mamba3_siso_bwd import compute_dzdo, compute_dqkv, compute_dqktheta, compute_ddt_dtrap_dinput_states
from mamba_ssm.ops.triton.mamba3.angle_dt import angle_dt_fwd, angle_dt_bwd


def _triton_alloc_fn(size: int, alignment: int, stream: Optional[int]) -> Tensor:
    """Allocator for Triton runtime memory (TMA descriptors, scratch).

    Returns a ``size``-byte int8 buffer on the current CUDA device; ``alignment`` and
    ``stream`` are ignored.
    """
    return torch.empty(size, device="cuda", dtype=torch.int8)


# Set allocator immediately at import time.
# Best effort: any exception is swallowed. The same call is repeated at the start of
# forward() and backward().
try:
    triton.set_allocator(_triton_alloc_fn)
except Exception:
    pass  # Allocator may already be set


@dataclass(frozen=True)
class Mamba3Output:
    """Container for Mamba-3 outputs and optional intermediates.

    Note: this module never instantiates it; ``mamba3_siso_combined`` returns a plain
    tensor or tuple.

    Attributes:
        out: Main output tensor (batch, seqlen, nheads, headdim_v)
        final_angle_state: Final angle state (num_sequences, nheads, headdim_angles)
        final_ssm_state: Final SSM state (num_sequences, nheads, headdim_v, headdim_qk)
        final_k_state: Final K state (num_sequences, nheads, headdim_qk)
        final_v_state: Final V state (num_sequences, nheads, headdim_v)
    """

    out: Tensor
    final_angle_state: Optional[Tensor] = None
    final_ssm_state: Optional[Tensor] = None
    final_k_state: Optional[Tensor] = None
    final_v_state: Optional[Tensor] = None


class _Mamba3Function(torch.autograd.Function):
    """Custom autograd function for Mamba-3 with Triton kernels."""

    @staticmethod
    def forward(
        ctx,
        Q: Tensor,
        K: Tensor,
        V: Tensor,
        ADT: Tensor,
        DT: Tensor,
        Trap: Tensor,
        Q_bias: Tensor,
        K_bias: Tensor,
        Angles: Tensor,
        D: Optional[Tensor],
        Z: Optional[Tensor],
        Input_Angle_State: Optional[Tensor],
        Input_SSM_State: Optional[Tensor],
        Input_K_State: Optional[Tensor],
        Input_V_State: Optional[Tensor],
        cu_seqlens: Optional[Tensor],
        chunk_size: int,
        return_final_states: bool,
    ) -> Tensor | Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Forward pass: call Triton kernel and save tensors for backward.

        Steps:
            1. ``angle_dt_fwd`` accumulates the raw ``Angles`` (scaled by ``DT``) into
               ``Angles_Cumsum``, optionally seeded by ``Input_Angle_State``.
            2. ``mamba3_siso_fwd`` runs the main kernel. It only stores the intermediates
               the backward pass needs when at least one input requires grad.

        Returns:
            ``Out``, or, when ``return_final_states`` is True, the tuple
            ``(Out, Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State)``.
        """
        try:
            triton.set_allocator(_triton_alloc_fn)
        except Exception:
            pass

        # When no input needs a gradient, skip storing backward-only intermediates.
        needs_backward = any(ctx.needs_input_grad)
        has_varlen = cu_seqlens is not None

        all_states_present = (Input_SSM_State is not None) and (Input_K_State is not None) and (Input_V_State is not None) and (Input_Angle_State is not None)
        all_states_absent = (Input_SSM_State is None) and (Input_K_State is None) and (Input_V_State is None) and (Input_Angle_State is None)
        assert all_states_present or all_states_absent, "Input states must be provided together or all be None."

        Angles_Cumsum, Final_Angle_State = angle_dt_fwd(
            Angles,
            DT,
            init_state=Input_Angle_State,
            chunk_size=chunk_size,
            return_output_state=True,
            cu_seqlens=cu_seqlens,
        )

        # The SISO kernel takes the SSM/K/V states as one tuple; the angle state was
        # consumed above.
        Input_States = (
            (Input_SSM_State, Input_K_State, Input_V_State)
            if Input_SSM_State is not None else None
        )

        Out, Out_v, SSM_States, DA_CS, DA_CS_SUM, Q_rot, K_scaled, QK_dot, Scale, Gamma, Final_States = mamba3_siso_fwd(
            Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles_Cumsum, D, Z, Input_States,
            chunk_size=chunk_size,
            store_states_adt_outv=needs_backward,
            return_final_states=return_final_states,
            cu_seqlens=cu_seqlens,
        )

        Final_SSM_State = Final_States[0] if Final_States is not None else None
        Final_K_State = Final_States[1] if Final_States is not None else None
        Final_V_State = Final_States[2] if Final_States is not None else None

        if needs_backward:
            ctx.chunk_size = chunk_size
            ctx.has_D = D is not None
            ctx.has_Z = Z is not None
            ctx.has_input_state = Input_SSM_State is not None
            ctx.return_final_states = return_final_states
            ctx.has_varlen = has_varlen

            # Save tensors - use empty tensor placeholders for None values
            # (the ctx.has_* flags above tell backward which slots are placeholders).
            D_save = D if D is not None else torch.empty((), device=Q.device)
            Z_save = Z if Z is not None else torch.empty((), device=Q.device)
            Input_SSM_State_save = Input_SSM_State if Input_SSM_State is not None else torch.empty((), device=Q.device)
            Input_K_State_save = Input_K_State if Input_K_State is not None else torch.empty((), device=Q.device)
            Input_V_State_save = Input_V_State if Input_V_State is not None else torch.empty((), device=Q.device)
            Final_SSM_State_save = Final_SSM_State if Final_SSM_State is not None else torch.empty((), device=Q.device)
            cu_seqlens_save = cu_seqlens if cu_seqlens is not None else torch.empty((), device=Q.device, dtype=torch.int32)

            ctx.save_for_backward(
                Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, Angles_Cumsum,
                D_save, Z_save, Input_SSM_State_save, Input_K_State_save, Input_V_State_save,
                Out, Out_v, SSM_States, DA_CS, DA_CS_SUM, Q_rot, K_scaled, QK_dot,
                Scale, Gamma, Final_SSM_State_save, cu_seqlens_save
            )
        else:
            ctx.chunk_size = chunk_size
            ctx.has_D = D is not None
            ctx.has_Z = Z is not None
            ctx.has_input_state = Input_SSM_State is not None
            ctx.return_final_states = return_final_states
            ctx.has_varlen = has_varlen
            # Nothing is saved; backward() detects the empty saved_tensors and raises.
            ctx.save_for_backward()

        if return_final_states:
            return Out, Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State
        return Out

    @staticmethod
    def backward(
        ctx,
        grad_out: Optional[Tensor] = None,
        grad_final_angle_state: Optional[Tensor] = None,
        grad_final_ssm_state: Optional[Tensor] = None,
        grad_final_k_state: Optional[Tensor] = None,
        grad_final_v_state: Optional[Tensor] = None
    ) -> tuple:
        """Backward pass: compute gradients using Triton backward kernels.

        Pipeline:
            1. compute_dzdo: dZ and the gated grad_out (only when Z was given).
            2. compute_dqkv: dQ_mid, dK_mid, dV, dADT, dQK_dot, dD and an SSM-state
               gradient, starting from the saved rotated/scaled Q and K.
            3. compute_dqktheta: through the rotary embedding and biases to dQ, dK,
               dQ_bias, dK_bias, d(Angles_Cumsum), dScale and dGamma. This step also
               folds in grad_final_k_state.
            4. compute_ddt_dtrap_dinput_states: dScale/dGamma to dDT and dTrap, plus the
               input SSM/K/V state gradients.
            5. angle_dt_bwd: through the angle cumsum to dAngles, an extra dDT term and
               the input angle-state gradient.

        Returns:
            One entry per forward input (18 total). The input-state gradients are None
            when no input states were given. cu_seqlens, chunk_size and
            return_final_states always get None.

        Raises:
            RuntimeError: If forward saved nothing (no input required grad), or if no
                output gradient was provided.
        """
        try:
            triton.set_allocator(_triton_alloc_fn)
        except Exception:
            pass

        if len(ctx.saved_tensors) == 0:
            raise RuntimeError(
                "Backward called but forward ran without gradient tracking. "
                "Ensure inputs require grad or run under torch.enable_grad()."
            )

        if grad_out is None and grad_final_ssm_state is None and grad_final_k_state is None and grad_final_v_state is None and grad_final_angle_state is None:
            raise RuntimeError("No gradients provided for backward pass.")

        # NOTE(review): Final_SSM_State_save is unpacked here but not used below.
        (Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, Angles_Cumsum,
         D_save, Z_save, Input_SSM_State_save, Input_K_State_save, Input_V_State_save,
         Out, Out_v, SSM_States, DA_CS, DA_CS_SUM, Q_rot, K_scaled, QK_dot,
         Scale, Gamma, Final_SSM_State_save, cu_seqlens_save) = ctx.saved_tensors

        # Map placeholders back to None using the flags recorded in forward.
        D = D_save if ctx.has_D else None
        Z = Z_save if ctx.has_Z else None
        Input_SSM_State = Input_SSM_State_save if ctx.has_input_state else None
        Input_K_State = Input_K_State_save if ctx.has_input_state else None
        Input_V_State = Input_V_State_save if ctx.has_input_state else None
        cu_seqlens = cu_seqlens_save if ctx.has_varlen else None

        # Defensive fallback: by default autograd already passes zeros for outputs that
        # received no gradient.
        if grad_out is None:
            grad_out = torch.zeros_like(Out)

        # Step 1: Compute dZ and scale grad_out if Z gating is present
        if Z is not None:
            dZ, grad_out_scaled = compute_dzdo(
                grad_out, Z, Out_v,
                chunk_size=ctx.chunk_size
            )
        else:
            dZ = None
            grad_out_scaled = grad_out

        # Step 2: Compute main gradients (dQ_mid, dK_mid, dV, dADT, dQK_dot, dD, dInput_SSM_State)
        dQ_mid, dK_mid, dV, dADT, dQK_dot, dD, dInput_SSM_State = compute_dqkv(
            q=Q_rot,
            k=K_scaled,
            v=V,
            da_cs=DA_CS,
            da_cs_sum=DA_CS_SUM,
            qk_dot=QK_dot,
            SSM_States=SSM_States,
            do=grad_out_scaled,
            d_ossm_state=grad_final_ssm_state,
            d_ov_state=grad_final_v_state,
            D=D,
            chunk_size=ctx.chunk_size,
            has_input_state=ctx.has_input_state,
            Cu_Seqlens=cu_seqlens,
        )

        # Step 3: Compute gradients through rotary embeddings and biases
        dQ, dK, dQ_bias, dK_bias, dAngles_Cumsum, dScale, dGamma = compute_dqktheta(
            q=Q,
            k=K,
            scale=Scale,
            gamma=Gamma,
            q_bias=Q_bias,
            k_bias=K_bias,
            angles=Angles_Cumsum,
            dq_in=dQ_mid,
            dk_in=dK_mid,
            dqk=dQK_dot,
            d_ok_state=grad_final_k_state,
            chunk_size=ctx.chunk_size,
            Cu_Seqlens=cu_seqlens,
        )

        # Step 4: Compute dDT, dTrap, and input state gradients
        dDT, dTrap, dInput_SSM_State_final, dInput_K_State, dInput_V_State = compute_ddt_dtrap_dinput_states(
            dscale=dScale,
            dgamma=dGamma,
            dt=DT,
            trap=Trap.float(),
            d_issm_state=dInput_SSM_State if ctx.has_input_state else None,
            input_k_state=Input_K_State,
            input_v_state=Input_V_State,
            Cu_Seqlens=cu_seqlens,
        )

        # Step 5: Compute gradients through angle_dt cumsum
        dAngles, dDT_angle, dInput_Angle_State = angle_dt_bwd(
            grad_out=dAngles_Cumsum,
            angle=Angles,
            dt=DT,
            has_init_state=ctx.has_input_state,
            chunk_size=ctx.chunk_size,
            grad_output_state=grad_final_angle_state if ctx.return_final_states else None,
            cu_seqlens=cu_seqlens,
        )

        # Accumulate DT gradients from angle_dt backward
        # (DT feeds both the SISO kernel via Scale/Gamma and the angle cumsum).
        dDT = dDT + dDT_angle

        if ctx.has_input_state:
            dInput_SSM_State = dInput_SSM_State_final
        else:
            dInput_SSM_State = None
            dInput_K_State = None
            dInput_V_State = None
            dInput_Angle_State = None

        return (
            dQ,                    # Q
            dK,                    # K
            dV,                    # V
            dADT,                  # ADT
            dDT,                   # DT
            dTrap,                 # Trap
            dQ_bias,               # Q_bias
            dK_bias,               # K_bias
            dAngles,               # Angles
            dD,                    # D
            dZ,                    # Z
            dInput_Angle_State,    # Input_Angle_State
            dInput_SSM_State,      # Input_SSM_State
            dInput_K_State,        # Input_K_State
            dInput_V_State,        # Input_V_State
            None,                  # cu_seqlens (not differentiable)
            None,                  # chunk_size (not differentiable)
            None,                  # return_final_states (not differentiable)
        )


def mamba3_siso_combined(
    Q: Tensor,
    K: Tensor,
    V: Tensor,
    ADT: Tensor,
    DT: Tensor,
    Trap: Tensor,
    Q_bias: Tensor,
    K_bias: Tensor,
    Angles: Tensor,
    D: Optional[Tensor] = None,
    Z: Optional[Tensor] = None,
    Input_States: Optional[Tuple[Tensor, Tensor, Tensor, Tensor]] = None,
    chunk_size: int = 64,
    return_final_states: bool = False,
    cu_seqlens: Optional[Tensor] = None,
) -> Tensor | Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Mamba-3 attention with Triton kernels and automatic differentiation.

    This is the main entry point for Mamba-3 forward and backward passes using
    optimized Triton kernels. Supports:
        - GQA (grouped-query attention)
        - rotary position embeddings
        - optional gating and skip connections
        - state passing for recurrent inference
        - variable-length sequences

    Internally computes cumulative angles: Angles_Cumsum = cumsum(Angles * DT) mod 2π

    Args:
        Q: Query tensor (batch, seqlen, nheads_qk, headdim_qk)
        K: Key tensor (batch, seqlen, nheads_qk, headdim_qk)
        V: Value tensor (batch, seqlen, nheads, headdim_v)
        ADT: Decay factor A * dt (batch, nheads, seqlen)
        DT: Time delta tensor dt (batch, nheads, seqlen)
        Trap: Trapezoidal factor (batch, nheads, seqlen)
            Pre-sigmoid values. The backward kernels apply a sigmoid to Trap, and
            sigmoid(Trap), which lies in [0, 1], is the mixing factor for trapezoidal
            discretization.
        Q_bias: Query bias (nheads, headdim_qk)
        K_bias: Key bias (nheads, headdim_qk)
        Angles: Rotary angle rates (batch, seqlen, nheads, headdim_angles)
            Raw angle values that get accumulated via cumsum(Angles * DT).
            If headdim_angles < headdim_qk // 2, remaining dims are unrotated.
        D: Skip connection (nheads,)
            Optional per-head skip connection weight applied to V.
        Z: Gating tensor (batch, seqlen, nheads, headdim_v)
            Optional gating applied as: out = out * silu(Z).
        Input_States: Optional initial state tuple for recurrent inference, ordered
            (Angle State, SSM State, K State, V State). Provide all four or none.
            Angle State: (num_sequences, nheads, headdim_angles)
            SSM State: (num_sequences, nheads, headdim_v, headdim_qk)
            K State: (num_sequences, nheads, headdim_qk)
            V State: (num_sequences, nheads, headdim_v)
        chunk_size: Chunk size for chunked state computation (default: 64).
        return_final_states: If True, return final states for recurrent inference.
        cu_seqlens: Cumulative sequence lengths for variable-length support.
            Shape: (num_sequences + 1,), dtype: torch.int32.
            Example: [0, 128, 256, 512] for 3 sequences of lengths 128, 128, 256.
            When using cu_seqlens, batch must be 1 and the seqlen dimension
            contains all sequences concatenated.

    Returns:
        If return_final_states=False:
            out: Output tensor (batch, seqlen, nheads, headdim_v)
        If return_final_states=True:
            Tuple of:
                out: Output tensor (batch, seqlen, nheads, headdim_v)
                final_angle_state: Angle state (num_sequences, nheads, headdim_angles)
                final_ssm_state: SSM state (num_sequences, nheads, headdim_v, headdim_qk)
                final_k_state: K state (num_sequences, nheads, headdim_qk)
                final_v_state: V state (num_sequences, nheads, headdim_v)

    Raises:
        ValueError: If cu_seqlens is given and batch != 1.
        AssertionError: If any of the following holds:
            - nheads is not divisible by nheads_qk;
            - headdim_qk is odd;
            - only some of the input states are provided.

    Notes:
        - For GQA: nheads must be divisible by nheads_qk.
        - headdim_qk and headdim_v must be powers of two for TMA compatibility.
          This wrapper only checks that headdim_qk is even.
        - Variable-length mode (cu_seqlens is not None) requires batch == 1.
        - num_sequences = batch for batched mode, len(cu_seqlens)-1 for varlen mode.

    Performance Notes:
        The kernel is optimized for: nheads_qk=1, nheads=32, headdim_qk=128,
        headdim_v=64, chunk_size=64.
    """
    batch, seqlen, nheads_qk, headdim_qk = Q.shape
    _, _, nheads, headdim_v = V.shape
    assert nheads % nheads_qk == 0, f"nheads ({nheads}) must be divisible by nheads_qk ({nheads_qk})"
    assert headdim_qk % 2 == 0, f"headdim_qk ({headdim_qk}) must be even for rotary embeddings"

    # Varlen mode checks
    has_varlen = cu_seqlens is not None
    if has_varlen:
        if batch != 1:
            raise ValueError(f"Batch size must be 1 with variable-length sequences (cu_seqlens), got {batch}.")

    # Unpack in (angle, ssm, k, v) order; all four are None when no states are given.
    Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State = (
        Input_States if Input_States is not None else (None, None, None, None)
    )
    all_states_present = (Input_SSM_State is not None) and (Input_K_State is not None) and (Input_V_State is not None) and (Input_Angle_State is not None)
    all_states_absent = (Input_SSM_State is None) and (Input_K_State is None) and (Input_V_State is None) and (Input_Angle_State is None)
    assert all_states_present or all_states_absent, "Input states must be provided together or all be None."

    return _Mamba3Function.apply(
        Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z,
        Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State,
        cu_seqlens, chunk_size, return_final_states
    )

================================================
FILE: mamba_ssm/ops/triton/mamba3/mamba3_siso_fwd.py
================================================
"""
Mamba-3 SISO Forward Pass Triton Kernel.
Copyright (c) 2025, Dao AI Lab, Goombalab """ from typing import Optional, Tuple import math import torch import torch.nn.functional as F from einops import rearrange, repeat import triton import triton.language as tl from mamba_ssm.ops.triton.mamba3.utils import cos_approx, sin_approx, tanh_approx, silu, sigmoid_approx @triton.autotune( configs=[ triton.Config({}, num_stages=s, num_warps=w) for s in [1, 2, 3] for w in [2, 4, 8] ], key=[ "CHUNK_SIZE", "HEADDIM_QK", "HEADDIM_V", "STORE_SSM_STATES_ADT_OUTV", "HAS_D", "HAS_Z", "HAS_INITIAL_STATES", "RETURN_FINAL_STATES", "IS_VARLEN"], ) @triton.jit def mamba3_siso_fwd_kernel( # Inputs Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z, Initial_SSM_State, Initial_K_State, Initial_V_State, Cu_Seqlens, # Outputs Out, Out_v, SSM_States, DA_CS_Store, DA_CS_SUM_Store, Q_store, K_store, QK_store, Scale_store, Gamma_store, Final_SSM_State, Final_K_State, # Input Strides stride_q_batch, stride_q_seqlen, stride_q_head, stride_q_qkdim, stride_k_batch, stride_k_seqlen, stride_k_head, stride_k_qkdim, stride_v_batch, stride_v_seqlen, stride_v_head, stride_v_vdim, stride_adt_batch, stride_adt_head, stride_adt_seqlen, stride_dt_batch, stride_dt_head, stride_dt_seqlen, stride_trap_batch, stride_trap_head, stride_trap_seqlen, stride_q_bias_head, stride_q_bias_qkdim, stride_k_bias_head, stride_k_bias_qkdim, stride_angles_batch, stride_angles_seqlen, stride_angles_head, stride_angles_qkdim, stride_d_head, stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_vdim, stride_init_ssm_state_seq, stride_init_ssm_state_head, stride_init_ssm_state_vdim, stride_init_ssm_state_qkdim, stride_init_k_state_seq, stride_init_k_state_head, stride_init_k_state_qkdim, stride_init_v_state_seq, stride_init_v_state_head, stride_init_v_state_vdim, stride_cu_seqlen, # Output Strides stride_o_batch, stride_o_seqlen, stride_o_head, stride_o_vdim, stride_o_v_batch, stride_o_v_seqlen, stride_o_v_head, stride_o_v_vdim, stride_ssm_states_batch, 
stride_ssm_states_head, stride_ssm_states_vdim, stride_ssm_states_qkdim, stride_da_cs_store_batch, stride_da_cs_store_head, stride_da_cs_store_seqlen, stride_da_cs_sum_store_batch, stride_da_cs_sum_store_head, stride_da_cs_sum_store_seqlen, stride_q_store_batch, stride_q_store_seqlen, stride_q_store_head, stride_q_store_qkdim, stride_k_store_batch, stride_k_store_seqlen, stride_k_store_head, stride_k_store_qkdim, stride_qk_store_batch, stride_qk_store_head, stride_qk_store_seqlen, stride_scale_store_batch, stride_scale_store_head, stride_scale_store_seqlen, stride_gamma_store_batch, stride_gamma_store_head, stride_gamma_store_seqlen, stride_final_ssm_state_seq, stride_final_ssm_state_head, stride_final_ssm_state_vdim, stride_final_ssm_state_qkdim, stride_final_k_state_seq, stride_final_k_state_head, stride_final_k_state_chunk, stride_final_k_state_qkdim, # Dimensions seqlen, nheads_qk, headdim_qk, headdim_v, headdim_angles, CHUNK_SIZE: tl.constexpr, HEADDIM_QK: tl.constexpr, HEADDIM_V: tl.constexpr, STORE_SSM_STATES_ADT_OUTV: tl.constexpr, HAS_INITIAL_STATES: tl.constexpr, RETURN_FINAL_STATES: tl.constexpr, HAS_D: tl.constexpr, HAS_Z: tl.constexpr, IS_VARLEN: tl.constexpr, ): """ Mamba-3 forward kernel. Grid: (nheads, batch) for batched, (nheads, 1, num_sequences) for varlen Inputs: Q, K: (batch, seqlen, nheads_qk, headdim_qk) V: (batch, seqlen, nheads, headdim_v) ADT, DT, Trap: (batch, nheads, seqlen) Q_bias, K_bias: (nheads, headdim_qk) Angles: (batch, seqlen, nheads, headdim_angles) D: (nheads,) Z: (batch, seqlen, nheads, headdim_v) Initial SSM State: (num_sequences, nheads, headdim_v, headdim_qk) Initial K State: (num_sequences, nheads, headdim_qk) Initial V State: (num_sequences, nheads, headdim_v) Cu_Seqlens: (num_sequences + 1,) NOTE: num_sequences = batch for batched mode, or len(cu_seqlens)-1 for varlen mode. 
Compile-time constants: CHUNK_SIZE: Chunk size for processing sequences HEADDIM_QK: Head dimension for Q/K HEADDIM_V: Head dimension for V STORE_SSM_STATES_ADT_OUTV: Whether to store SSM states, ADT, and Out_v for backward pass Set to FALSE for inference-only runs for efficiency HAS_INITIAL_STATES: Whether input SSM states are provided for state passing RETURN_FINAL_STATES: Whether to return final SSM states for state passing HAS_D: Whether D-skip connection is used HAS_Z: Whether Z-gating is used IS_VARLEN: Whether the input is a variable-length sequence NOTE: 1. nheads % nheads_qk == 0 2. Kernel is optimized for headdim_qk = 128 and headdim_v = 64 Outputs: Out: (batch, seqlen, nheads, headdim_v) Out_v: (batch, seqlen, nheads, headdim_v) (if STORE_SSM_STATES_ADT_OUTV) SSM_States: (batch, nheads, headdim_v, nchunks * headdim_qk) (if STORE_SSM_STATES_ADT_OUTV) DA_CS_Store: (batch, nheads, seqlen) (if STORE_SSM_STATES_ADT_OUTV) DA_CS_SUM_Store: (batch, nheads, nchunks) (if STORE_SSM_STATES_ADT_OUTV) Q_store: (batch, seqlen, nheads, headdim_qk) K_store: (batch, seqlen, nheads, headdim_qk) QK_store: (batch, seqlen, nheads) Scale_store: (batch, seqlen, nheads) Gamma_store: (batch, seqlen, nheads) Final SSM State: (num_sequences, nheads, headdim_v, headdim_qk) (if RETURN_FINAL_STATES) Final K State: (num_sequences, nheads, chunk_size, headdim_qk) (if RETURN_FINAL_STATES) NOTE: 1. For batched inputs, nchunks = ceil(seqlen / CHUNK_SIZE) and for varlen inputs, nchunks = num_sequences + total_seqlen//CHUNK_SIZE. 2. Final K state has an additional chunk_size dimension since triton does not allow indexing within a chunk. We pick the correct index in the wrapper. 
""" pid_head = tl.program_id(0) pid_batch = tl.program_id(1) if IS_VARLEN: pid_seq = tl.program_id(2) seq_idx = pid_seq cu_seqlen_start = tl.load(Cu_Seqlens + pid_seq * stride_cu_seqlen).to(tl.int32) cu_seqlen_end = tl.load(Cu_Seqlens + (pid_seq + 1) * stride_cu_seqlen).to(tl.int32) total_seqlen = seqlen seqlen = cu_seqlen_end - cu_seqlen_start seq_offset = cu_seqlen_start chunk_offset = pid_seq + cu_seqlen_start // CHUNK_SIZE else: seq_idx = pid_batch seq_offset = 0 chunk_offset = 0 num_chunks = tl.cdiv(seqlen, CHUNK_SIZE) # Compute head index for Q/K (supports Grouped Query Attention) nheads = tl.num_programs(0) head_idx_qk = pid_head // (nheads // nheads_qk) # Setup input pointers q_ptr = Q + pid_batch * stride_q_batch + head_idx_qk * stride_q_head + seq_offset * stride_q_seqlen k_ptr = K + pid_batch * stride_k_batch + head_idx_qk * stride_k_head + seq_offset * stride_k_seqlen v_ptr = V + pid_batch * stride_v_batch + pid_head * stride_v_head + seq_offset * stride_v_seqlen adt_ptr = ADT + pid_batch * stride_adt_batch + pid_head * stride_adt_head + seq_offset * stride_adt_seqlen dt_ptr = DT + pid_batch * stride_dt_batch + pid_head * stride_dt_head + seq_offset * stride_dt_seqlen trap_ptr = Trap + pid_batch * stride_trap_batch + pid_head * stride_trap_head + seq_offset * stride_trap_seqlen q_bias_ptr = Q_bias + pid_head * stride_q_bias_head k_bias_ptr = K_bias + pid_head * stride_k_bias_head angle_ptr = Angles + pid_batch * stride_angles_batch + pid_head * stride_angles_head + seq_offset * stride_angles_seqlen if HAS_D: D_ptr = D + pid_head * stride_d_head D_val = tl.load(D_ptr).to(tl.float32) if HAS_Z: z_ptr = Z + pid_batch * stride_z_batch + pid_head * stride_z_head + seq_offset * stride_z_seqlen # State pointers use seq_idx (unified for batched and varlen) if HAS_INITIAL_STATES: init_ssm_state_ptr = Initial_SSM_State + seq_idx * stride_init_ssm_state_seq + pid_head * stride_init_ssm_state_head init_k_state_ptr = Initial_K_State + seq_idx * 
stride_init_k_state_seq + pid_head * stride_init_k_state_head init_v_state_ptr = Initial_V_State + seq_idx * stride_init_v_state_seq + pid_head * stride_init_v_state_head # Setup output pointers o_ptr = Out + pid_batch * stride_o_batch + pid_head * stride_o_head + seq_offset * stride_o_seqlen if STORE_SSM_STATES_ADT_OUTV: out_v_ptr = Out_v + pid_batch * stride_o_v_batch + pid_head * stride_o_v_head + seq_offset * stride_o_v_seqlen ssm_states_ptr = SSM_States + pid_batch * stride_ssm_states_batch + pid_head * stride_ssm_states_head + chunk_offset * HEADDIM_QK * stride_ssm_states_qkdim da_cs_store_ptr = DA_CS_Store + pid_batch * stride_da_cs_store_batch + pid_head * stride_da_cs_store_head + seq_offset * stride_da_cs_store_seqlen da_cs_sum_store_ptr = DA_CS_SUM_Store + pid_batch * stride_da_cs_sum_store_batch + pid_head * stride_da_cs_sum_store_head + chunk_offset * stride_da_cs_sum_store_seqlen q_store_ptr = Q_store + pid_batch * stride_q_store_batch + pid_head * stride_q_store_head + seq_offset * stride_q_store_seqlen k_store_ptr = K_store + pid_batch * stride_k_store_batch + pid_head * stride_k_store_head + seq_offset * stride_k_store_seqlen qk_store_ptr = QK_store + pid_batch * stride_qk_store_batch + pid_head * stride_qk_store_head + seq_offset * stride_qk_store_seqlen scale_store_ptr = Scale_store + pid_batch * stride_scale_store_batch + pid_head * stride_scale_store_head + seq_offset * stride_scale_store_seqlen gamma_store_ptr = Gamma_store + pid_batch * stride_gamma_store_batch + pid_head * stride_gamma_store_head + seq_offset * stride_gamma_store_seqlen if RETURN_FINAL_STATES: final_ssm_state_ptr = Final_SSM_State + seq_idx * stride_final_ssm_state_seq + pid_head * stride_final_ssm_state_head final_k_state_ptr = Final_K_State + seq_idx * stride_final_k_state_seq + pid_head * stride_final_k_state_head # Create TMA tensor descriptors q_desc = tl.make_tensor_descriptor( q_ptr, shape=[seqlen, headdim_qk], strides=[stride_q_seqlen, stride_q_qkdim], 
block_shape=[CHUNK_SIZE, HEADDIM_QK], ) k_desc = tl.make_tensor_descriptor( k_ptr, shape=[seqlen, headdim_qk], strides=[stride_k_seqlen, stride_k_qkdim], block_shape=[CHUNK_SIZE, HEADDIM_QK], ) v_desc = tl.make_tensor_descriptor( v_ptr, shape=[seqlen, headdim_v], strides=[stride_v_seqlen, stride_v_vdim], block_shape=[CHUNK_SIZE, HEADDIM_V], ) if HAS_Z: z_desc = tl.make_tensor_descriptor( z_ptr, shape=[seqlen, headdim_v], strides=[stride_z_seqlen, stride_z_vdim], block_shape=[CHUNK_SIZE, HEADDIM_V], ) q_store_desc = tl.make_tensor_descriptor( q_store_ptr, shape=[seqlen, headdim_qk], strides=[stride_q_store_seqlen, stride_q_store_qkdim], block_shape=[CHUNK_SIZE, HEADDIM_QK], ) k_store_desc = tl.make_tensor_descriptor( k_store_ptr, shape=[seqlen, headdim_qk], strides=[stride_k_store_seqlen, stride_k_store_qkdim], block_shape=[CHUNK_SIZE, HEADDIM_QK], ) o_desc = tl.make_tensor_descriptor( o_ptr, shape=[seqlen, headdim_v], strides=[stride_o_seqlen, stride_o_vdim], block_shape=[CHUNK_SIZE, HEADDIM_V], ) if STORE_SSM_STATES_ADT_OUTV: ssm_states_desc = tl.make_tensor_descriptor( ssm_states_ptr, shape=[headdim_v, num_chunks * headdim_qk], strides=[stride_ssm_states_vdim, stride_ssm_states_qkdim], block_shape=[HEADDIM_V, HEADDIM_QK], ) # Phase 1: Preprocessing - Apply bias, rotary embeddings, compute QK dots. 
for chunk_idx in range(num_chunks): chunk_start = chunk_idx * CHUNK_SIZE offs_seqlen = chunk_start + tl.arange(0, CHUNK_SIZE) offs_hd = tl.arange(0, HEADDIM_QK) offs_hdr = tl.arange(0, HEADDIM_QK // 2) # Load Q and K blocks via TMA q_pre_block = q_desc.load([chunk_start, 0]) k_pre_block = k_desc.load([chunk_start, 0]) # Load rotary angles angle_block = tl.load( angle_ptr + offs_seqlen[:, None] * stride_angles_seqlen + offs_hdr[None, :] * stride_angles_qkdim, mask=(offs_seqlen[:, None] < seqlen) & (offs_hdr[None, :] < headdim_angles), other=0.0 ) # Compute shifted gamma and scale dt = tl.load(dt_ptr + offs_seqlen * stride_dt_seqlen, mask=offs_seqlen < seqlen, other=0.0).to(tl.float32) dt_shifted = tl.load( dt_ptr + (offs_seqlen + 1) * stride_dt_seqlen, mask=offs_seqlen + 1 < seqlen, other=0.0).to(tl.float32) trap = tl.load(trap_ptr + offs_seqlen * stride_trap_seqlen, mask=offs_seqlen < seqlen, other=0.0).to(tl.float32) trap = sigmoid_approx(trap) trap_shifted = tl.load( trap_ptr + (offs_seqlen + 1) * stride_trap_seqlen, mask=offs_seqlen + 1 < seqlen, other=0.0).to(tl.float32) trap_shifted = sigmoid_approx(trap_shifted) shifted_gamma = dt_shifted * (1 - trap_shifted) gamma = dt * trap scale = shifted_gamma + gamma # Store scale and shifted gamma for backward pass tl.store(gamma_store_ptr + offs_seqlen * stride_gamma_store_seqlen, gamma, mask=offs_seqlen < seqlen) tl.store(scale_store_ptr + offs_seqlen * stride_scale_store_seqlen, scale, mask=offs_seqlen < seqlen) # Add biases to Q and K q_bias_block = tl.load(q_bias_ptr + offs_hd * stride_q_bias_qkdim, offs_hd < headdim_qk) q_pre_block += q_bias_block[None, :] k_bias_block = tl.load(k_bias_ptr + offs_hd * stride_k_bias_qkdim, offs_hd < headdim_qk) k_pre_block += k_bias_block[None, :] # Compute QK dot products for skip connection store_qk_dot = tl.dot( q_pre_block * k_pre_block, tl.full([HEADDIM_QK, 1], 1, dtype=q_pre_block.dtype) ).to(q_pre_block.dtype) store_qk_dot = store_qk_dot.reshape(CHUNK_SIZE) store_qk_dot *= 
gamma tl.store(qk_store_ptr + offs_seqlen * stride_qk_store_seqlen, store_qk_dot, mask=offs_seqlen < seqlen) # Compute rotary embedding cos/sin cos_block = cos_approx(angle_block.to(tl.float32)) sin_block = sin_approx(angle_block.to(tl.float32)) # Apply rotary embeddings to K and scale k0, k1 = tl.split(tl.reshape(k_pre_block, [CHUNK_SIZE, HEADDIM_QK // 2, 2])) ko0 = k0 * cos_block - k1 * sin_block ko1 = k0 * sin_block + k1 * cos_block k_pre_block = tl.reshape(tl.join(ko0, ko1), [CHUNK_SIZE, HEADDIM_QK]).to(k_pre_block.dtype) if chunk_idx == num_chunks - 1 and RETURN_FINAL_STATES: tl.store(final_k_state_ptr + tl.arange(0, CHUNK_SIZE)[:, None] * stride_final_k_state_chunk + offs_hd[None, :] * stride_final_k_state_qkdim, k_pre_block, mask=(offs_hd[None, :] < headdim_qk)) k_pre_block *= scale[:, None] k_store_desc.store([chunk_start, 0], k_pre_block) # Apply rotary embeddings to Q q0, q1 = tl.split(tl.reshape(q_pre_block, [CHUNK_SIZE, HEADDIM_QK // 2, 2])) qo0 = q0 * cos_block - q1 * sin_block qo1 = q0 * sin_block + q1 * cos_block q_pre_block = tl.reshape(tl.join(qo0, qo1), [CHUNK_SIZE, HEADDIM_QK]).to(q_pre_block.dtype) q_store_desc.store([chunk_start, 0], q_pre_block) # Phase 2: Main computation and output generation. 
if HAS_INITIAL_STATES: acc_ssm_states = tl.load( init_ssm_state_ptr + tl.arange(0, HEADDIM_V)[:, None] * stride_init_ssm_state_vdim + tl.arange(0, HEADDIM_QK)[None, :] * stride_init_ssm_state_qkdim, mask= (tl.arange(0, HEADDIM_V)[:, None] < headdim_v) & (tl.arange(0, HEADDIM_QK)[None, :] < headdim_qk), other=0.0).to(tl.float32) input_k_state = tl.load( init_k_state_ptr + tl.arange(0, HEADDIM_QK) * stride_init_k_state_qkdim, mask=tl.arange(0, HEADDIM_QK) < headdim_qk, other=0.0).to(tl.float32) input_v_state = tl.load( init_v_state_ptr + tl.arange(0, HEADDIM_V) * stride_init_v_state_vdim, mask=tl.arange(0, HEADDIM_V) < headdim_v, other=0.0).to(tl.float32) dt_scalar = tl.load(dt_ptr).to(tl.float32) trap_scalar = tl.load(trap_ptr).to(tl.float32) trap_scalar = sigmoid_approx(trap_scalar) # Step on the SSM states with input K/V states to account for trapezoidal discretization acc_ssm_states += input_v_state[:, None] * input_k_state[None, :] * dt_scalar * (1 - trap_scalar) else: acc_ssm_states = tl.zeros([HEADDIM_V, HEADDIM_QK], dtype=tl.float32) if HAS_D: D_val = tl.load(D_ptr).to(tl.float32) else: D_val = 0.0 for chunk_idx in range(num_chunks): chunk_start = chunk_idx * CHUNK_SIZE offs_seqlen = chunk_start + tl.arange(0, CHUNK_SIZE) # Load decay factors (log2 scale for exp2 computation) adt_ptrs = adt_ptr + offs_seqlen * stride_adt_seqlen da = tl.load(adt_ptrs, mask=offs_seqlen < seqlen, other=0.0) * 1.44269504089 # log2(e) # Load preprocessed Q, K, V blocks q_block = q_store_desc.load([chunk_start, 0]) k_block = k_store_desc.load([chunk_start, 0]) v_block = v_desc.load([chunk_start, 0]) if HAS_Z: z_block = z_desc.load([chunk_start, 0]) # Compute cumulative decay for this chunk da_cs = tl.cumsum(da) da_cs_last = tl.sum(da) da_cs_rev = da_cs_last - da_cs # Store decay info for backward pass if STORE_SSM_STATES_ADT_OUTV: tl.store(da_cs_store_ptr + offs_seqlen * stride_da_cs_store_seqlen, da_cs, mask=offs_seqlen < seqlen) tl.store(da_cs_sum_store_ptr + chunk_idx * 
stride_da_cs_sum_store_seqlen, da_cs_last) # Output contribution from previous state: Q @ SSM_States^T * exp(da_cs) acc_o = tl.dot(q_block, tl.trans(acc_ssm_states).to(q_block.dtype)) acc_o *= tl.math.exp2(da_cs)[:, None] # Output contribution from current chunk: causal(Q @ K^T * exp(decay)) @ V # NOTE: We compute the (i,i) component using QK dot to prevent non-causal numerical leakage s_block = tl.dot(q_block, tl.trans(k_block)) s_block *= tl.math.exp2(tl.minimum((da_cs[:, None] - da_cs[None, :]), 0.0)) s_block = tl.where( tl.arange(0, CHUNK_SIZE)[:, None] > tl.arange(0, CHUNK_SIZE)[None, :], s_block, 0.0 ) acc_o += tl.dot(s_block.to(v_block.dtype), v_block) # Add D-skip connection and subtract QK dot contribution qk_dot = tl.load(qk_store_ptr + offs_seqlen * stride_qk_store_seqlen, mask=offs_seqlen < seqlen, other=0.0) acc_o += (D_val + qk_dot)[:, None] * v_block if STORE_SSM_STATES_ADT_OUTV: tl.store(out_v_ptr + offs_seqlen[:, None] * stride_o_v_seqlen + tl.arange(0, HEADDIM_V)[None, :] * stride_o_v_vdim, acc_o, mask=(offs_seqlen[:, None] < seqlen) & (tl.arange(0, HEADDIM_V)[None, :] < headdim_v)) # Apply Z-gating if present if HAS_Z: acc_o = acc_o * silu(z_block.to(tl.float32)) # Store output o_desc.store([chunk_start, 0], acc_o) if STORE_SSM_STATES_ADT_OUTV: ssm_states_desc.store([0, chunk_idx * headdim_qk], acc_ssm_states.to(ssm_states_desc.dtype)) # Update recurrent states scale = tl.math.exp2(da_cs_rev) v_block *= scale[:, None] acc_ssm_states = acc_ssm_states * tl.math.exp2(da_cs_last) + tl.dot( tl.trans(v_block).to(k_block.dtype), k_block ) # Store final states if requested if RETURN_FINAL_STATES: tl.store(final_ssm_state_ptr + tl.arange(0, HEADDIM_V)[:, None] * stride_final_ssm_state_vdim + tl.arange(0, HEADDIM_QK)[None, :] * stride_final_ssm_state_qkdim, acc_ssm_states, mask=(tl.arange(0, HEADDIM_V)[:, None] < headdim_v) & (tl.arange(0, HEADDIM_QK)[None, :] < headdim_qk)) # Memory Allocator for TMA Descriptors def _alloc_fn(size: int, alignment: int, 
stream: Optional[int]): """Custom allocator for TMA descriptor global memory allocation.""" return torch.empty(size, device="cuda", dtype=torch.int8) triton.set_allocator(_alloc_fn) def mamba3_siso_fwd( Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, ADT: torch.Tensor, DT: torch.Tensor, Trap: torch.Tensor, Q_bias: torch.Tensor, K_bias: torch.Tensor, Angles: torch.Tensor, D: Optional[torch.Tensor] = None, Z: Optional[torch.Tensor] = None, Initial_States: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, chunk_size: int = 64, store_states_adt_outv: bool = False, return_final_states: bool = False, cu_seqlens: Optional[torch.Tensor] = None, ): """ Mamba-3 forward pass wrapper. Args: Q: Query tensor (batch, seqlen, nheads_qk, headdim_qk) K: Key tensor (batch, seqlen, nheads_qk, headdim_qk) V: Value tensor (batch, seqlen, nheads, headdim_v) ADT: Decay tensor (batch, nheads, seqlen) DT: DT tensor (batch, nheads, seqlen) Trap: Trap tensor (batch, nheads, seqlen) Q_bias: Query bias (nheads, headdim_qk) K_bias: Key bias (nheads, headdim_qk) Angles: Rotary angles (batch, seqlen, nheads, headdim_angles) - headdim_angles <= headdim_qk // 2 and headdim_angles % 2 == 0. D: Skip connection weight (nheads,) Z: Gating tensor (batch, seqlen, nheads, headdim_v) - Applies SiLU gating: out = out * silu(Z). Initial_States: Tuple of (SSM_State, K_State, V_State) SSM State shape: (num_sequences, nheads, headdim_v, headdim_qk). K state shape: (num_sequences, nheads, headdim_qk). V state shape: (num_sequences, nheads, headdim_v). 
- K state is post bias and rotation and pre scaling cu_seqlens: Cumulative sequence lengths (num_sequences + 1,) for varlen chunk_size: Chunk size for processing store_states_adt_outv: Store intermediate states for backward pass return_final_states: Return final states Returns: Out: Output tensor (batch, seqlen, nheads, headdim_v) Out_v: Pre-gate output tensor (batch, seqlen, nheads, headdim_v) (if store_states_adt_outv) SSM_States: Per-chunk SSM States (batch, nheads, headdim_v, nchunks * headdim_qk) (if store_states_adt_outv) DA_CS_Store: Cumulative decay (batch, nheads, seqlen) (if store_states_adt_outv) DA_CS_SUM_Store: Chunk decay sum (batch, nheads, nchunks) (if store_states_adt_outv) Q_store: Rotated Q+bias (batch, seqlen, nheads, headdim_qk) (None if store_states_adt_outv=False) K_store: Rotated K+bias (batch, seqlen, nheads, headdim_qk) (None if store_states_adt_outv=False) QK_store: QK dot products (batch, nheads, seqlen) (None if store_states_adt_outv=False) Scale_store: Scale factors (batch, nheads, seqlen) (None if store_states_adt_outv=False) Gamma_store: Gamma factors (batch, nheads, seqlen) (None if store_states_adt_outv=False) Final States: Final output state (None if return_output_state=False) Final SSM State (num_sequences, nheads, headdim_v, headdim_qk) Final K state (num_sequences, nheads, headdim_qk) Final V state (num_sequences, nheads, headdim_v) Notes: 1. For varlen mode: batch must be 1, cu_seqlens required 2. num_sequences = batch for batched mode, len(cu_seqlens)-1 for varlen 3. nheads % nheads_qk == 0 4. nchunks = ceil(seqlen / chunk_size) for batched mode, num_sequences + total_seqlen//chunk_size for varlen mode. COMMENT: Design choice to store: Q_store, K_store, QK_store, is primarily an artifact of Triton's lack of programmatic access to shared memory---In the forward pass, we compute, store and then re-load these tensors in shared memory (using TMA) to prevent register spilling. 
""" batch, seqlen, nheads_qk, headdim_qk = Q.shape _, _, nheads, headdim_v = V.shape device = Q.device is_varlen = cu_seqlens is not None assert seqlen > 0, "Sequence length must be greater than 0" # Determine number of sequences if is_varlen: assert batch == 1, "Varlen mode requires batch=1" num_sequences = cu_seqlens.shape[0] - 1 else: num_sequences = batch cu_seqlens = None # Validate shapes assert Q.shape == K.shape, f"Q and K shape mismatch: {Q.shape} vs {K.shape}" assert nheads % nheads_qk == 0, f"nheads ({nheads}) must be divisible by nheads_qk ({nheads_qk})" assert ADT.shape == (batch, nheads, seqlen) assert DT.shape == (batch, nheads, seqlen) assert Trap.shape == (batch, nheads, seqlen) assert Q_bias.shape == (nheads, headdim_qk) assert K_bias.shape == (nheads, headdim_qk) headdim_angles = Angles.shape[-1] assert headdim_angles <= headdim_qk // 2 and headdim_angles % 2 == 0 assert Angles.shape == (batch, seqlen, nheads, headdim_angles) if D is not None: assert D.shape == (nheads,) if Z is not None: assert Z.shape == (batch, seqlen, nheads, headdim_v) if Initial_States is not None: Init_SSM_State, Init_K_State, Init_V_State = Initial_States assert Init_SSM_State.shape == (num_sequences, nheads, headdim_v, headdim_qk), \ f"Initial_States[0] shape mismatch: expected {(num_sequences, nheads, headdim_v, headdim_qk)}, got {Init_SSM_State.shape}" assert Init_K_State.shape == (num_sequences, nheads, headdim_qk), \ f"Initial_States[1] shape mismatch: expected {(num_sequences, nheads, headdim_qk)}, got {Init_K_State.shape}" assert Init_V_State.shape == (num_sequences, nheads, headdim_v), \ f"Initial_States[2] shape mismatch: expected {(num_sequences, nheads, headdim_v)}, got {Init_V_State.shape}" else: Init_SSM_State, Init_K_State, Init_V_State = None, None, None # Ensure contiguous Q = Q.contiguous() if not Q.is_contiguous() else Q K = K.contiguous() if not K.is_contiguous() else K V = V.contiguous() if not V.is_contiguous() else V ADT = ADT.contiguous() if not 
ADT.is_contiguous() else ADT DT = DT.contiguous() if not DT.is_contiguous() else DT Trap = Trap.contiguous() if not Trap.is_contiguous() else Trap Q_bias = Q_bias.contiguous() if not Q_bias.is_contiguous() else Q_bias K_bias = K_bias.contiguous() if not K_bias.is_contiguous() else K_bias Angles = Angles.contiguous() if not Angles.is_contiguous() else Angles if D is not None: D = D.contiguous() if not D.is_contiguous() else D if Z is not None: Z = Z.contiguous() if not Z.is_contiguous() else Z if Initial_States is not None: Init_SSM_State = Init_SSM_State.contiguous() if not Init_SSM_State.is_contiguous() else Init_SSM_State Init_K_State = Init_K_State.contiguous() if not Init_K_State.is_contiguous() else Init_K_State Init_V_State = Init_V_State.contiguous() if not Init_V_State.is_contiguous() else Init_V_State # Calculate nchunks if is_varlen: nchunks = num_sequences + seqlen // chunk_size else: nchunks = (seqlen + chunk_size - 1) // chunk_size # Allocate output tensors Out = torch.empty((batch, seqlen, nheads, headdim_v), device=device, dtype=V.dtype) if store_states_adt_outv: SSM_States = torch.zeros((batch, nheads, headdim_v, nchunks * headdim_qk), device=device, dtype=torch.bfloat16) DA_CS_Store = torch.empty((batch, nheads, seqlen), device=device, dtype=torch.float32) DA_CS_SUM_Store = torch.zeros((batch, nheads, nchunks), device=device, dtype=torch.float32) Out_v = torch.empty((batch, seqlen, nheads, headdim_v), device=device, dtype=V.dtype) else: SSM_States, DA_CS_Store, DA_CS_SUM_Store, Out_v = None, None, None, None Q_store = torch.empty((batch, seqlen, nheads, headdim_qk), device=device, dtype=Q.dtype) K_store = torch.empty((batch, seqlen, nheads, headdim_qk), device=device, dtype=K.dtype) QK_store = torch.empty((batch, nheads, seqlen), device=device, dtype=torch.float32) Scale_store = torch.empty((batch, nheads, seqlen), device=device, dtype=torch.float32) Gamma_store = torch.empty((batch, nheads, seqlen), device=device, dtype=torch.float32) if 
return_final_states: Final_SSM_State = torch.empty((num_sequences, nheads, headdim_v, headdim_qk), device=device, dtype=torch.float32) Final_K_State = torch.empty((num_sequences, nheads, chunk_size, headdim_qk), device=device, dtype=torch.float32) else: Final_SSM_State, Final_K_State = None, None HEADDIM_V = triton.next_power_of_2(headdim_v) HEADDIM_QK = triton.next_power_of_2(headdim_qk) # Grid setup if is_varlen: grid = (nheads, batch, num_sequences) # batch = 1 else: grid = (nheads, batch) mamba3_siso_fwd_kernel[grid]( # Inputs Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z, Init_SSM_State, Init_K_State, Init_V_State, cu_seqlens, # Outputs Out, Out_v, SSM_States, DA_CS_Store, DA_CS_SUM_Store, Q_store, K_store, QK_store, Scale_store, Gamma_store, Final_SSM_State, Final_K_State, # Input strides Q.stride(0), Q.stride(1), Q.stride(2), Q.stride(3), K.stride(0), K.stride(1), K.stride(2), K.stride(3), V.stride(0), V.stride(1), V.stride(2), V.stride(3), ADT.stride(0), ADT.stride(1), ADT.stride(2), DT.stride(0), DT.stride(1), DT.stride(2), Trap.stride(0), Trap.stride(1), Trap.stride(2), Q_bias.stride(0), Q_bias.stride(1), K_bias.stride(0), K_bias.stride(1), Angles.stride(0), Angles.stride(1), Angles.stride(2), Angles.stride(3), D.stride(0) if D is not None else 0, Z.stride(0) if Z is not None else 0, Z.stride(1) if Z is not None else 0, Z.stride(2) if Z is not None else 0, Z.stride(3) if Z is not None else 0, Init_SSM_State.stride(0) if Init_SSM_State is not None else 0, Init_SSM_State.stride(1) if Init_SSM_State is not None else 0, Init_SSM_State.stride(2) if Init_SSM_State is not None else 0, Init_SSM_State.stride(3) if Init_SSM_State is not None else 0, Init_K_State.stride(0) if Init_K_State is not None else 0, Init_K_State.stride(1) if Init_K_State is not None else 0, Init_K_State.stride(2) if Init_K_State is not None else 0, Init_V_State.stride(0) if Init_V_State is not None else 0, Init_V_State.stride(1) if Init_V_State is not None else 0, 
Init_V_State.stride(2) if Init_V_State is not None else 0, cu_seqlens.stride(0) if cu_seqlens is not None else 0, # Output strides Out.stride(0), Out.stride(1), Out.stride(2), Out.stride(3), Out_v.stride(0) if Out_v is not None else 0, Out_v.stride(1) if Out_v is not None else 0, Out_v.stride(2) if Out_v is not None else 0, Out_v.stride(3) if Out_v is not None else 0, SSM_States.stride(0) if SSM_States is not None else 0, SSM_States.stride(1) if SSM_States is not None else 0, SSM_States.stride(2) if SSM_States is not None else 0, SSM_States.stride(3) if SSM_States is not None else 0, DA_CS_Store.stride(0) if DA_CS_Store is not None else 0, DA_CS_Store.stride(1) if DA_CS_Store is not None else 0, DA_CS_Store.stride(2) if DA_CS_Store is not None else 0, DA_CS_SUM_Store.stride(0) if DA_CS_SUM_Store is not None else 0, DA_CS_SUM_Store.stride(1) if DA_CS_SUM_Store is not None else 0, DA_CS_SUM_Store.stride(2) if DA_CS_SUM_Store is not None else 0, Q_store.stride(0), Q_store.stride(1), Q_store.stride(2), Q_store.stride(3), K_store.stride(0), K_store.stride(1), K_store.stride(2), K_store.stride(3), QK_store.stride(0), QK_store.stride(1), QK_store.stride(2), Scale_store.stride(0), Scale_store.stride(1), Scale_store.stride(2), Gamma_store.stride(0), Gamma_store.stride(1), Gamma_store.stride(2), Final_SSM_State.stride(0) if Final_SSM_State is not None else 0, Final_SSM_State.stride(1) if Final_SSM_State is not None else 0, Final_SSM_State.stride(2) if Final_SSM_State is not None else 0, Final_SSM_State.stride(3) if Final_SSM_State is not None else 0, Final_K_State.stride(0) if Final_K_State is not None else 0, Final_K_State.stride(1) if Final_K_State is not None else 0, Final_K_State.stride(2) if Final_K_State is not None else 0, Final_K_State.stride(3) if Final_K_State is not None else 0, # Dimensions seqlen, nheads_qk, headdim_qk, headdim_v, headdim_angles, # Compile-time constants chunk_size, HEADDIM_QK, HEADDIM_V, STORE_SSM_STATES_ADT_OUTV=store_states_adt_outv, 
HAS_INITIAL_STATES=Initial_States is not None, RETURN_FINAL_STATES=return_final_states, HAS_D=D is not None, HAS_Z=Z is not None, IS_VARLEN=is_varlen, ) Final_States = None if return_final_states: if is_varlen: seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] last_chunk_pos = (seq_lens - 1) % chunk_size final_k = Final_K_State[ torch.arange(num_sequences, device=device), :, last_chunk_pos, : ] last_token_idx = cu_seqlens[1:] - 1 final_v = V[0, last_token_idx] else: k_state_idx = (seqlen - 1) % chunk_size final_k = Final_K_State[:, :, k_state_idx, :] final_v = V[:, -1] Final_States = (Final_SSM_State, final_k, final_v) return (Out, Out_v, SSM_States, DA_CS_Store, DA_CS_SUM_Store, Q_store, K_store, QK_store, Scale_store, Gamma_store, Final_States) ================================================ FILE: mamba_ssm/ops/triton/mamba3/mamba3_siso_step.py ================================================ """ Mamba-3 Step Kernel. Copyright (c) 2025, Dao AI Lab, Goombalab """ from typing import Optional, Tuple import math import torch import triton import triton.language as tl from mamba_ssm.ops.triton.mamba3.utils import cos_approx, sin_approx, silu, tanh_approx, sigmoid_approx @triton.autotune( configs=[ triton.Config({}, num_stages=s, num_warps=w) for s in [1, 2, 3] for w in [2, 4, 8] ], key=[ "HEADDIM_QK", "HEADDIM_V", "HAS_D", "HAS_Z",], ) @triton.jit def mamba3_siso_step_kernel( # Inputs Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z, Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State, # Outputs Out, Output_Angle_State, Output_SSM_State, Output_K_State, # Input Strides stride_q_batch, stride_q_head, stride_q_qkdim, stride_k_batch, stride_k_head, stride_k_qkdim, stride_v_batch, stride_v_head, stride_v_vdim, stride_adt_batch, stride_adt_head, stride_dt_batch, stride_dt_head, stride_trap_batch, stride_trap_head, stride_q_bias_head, stride_q_bias_qkdim, stride_k_bias_head, stride_k_bias_qkdim, stride_angles_batch, stride_angles_head, stride_angles_qkdim, 
stride_d_head, stride_z_batch, stride_z_head, stride_z_vdim, stride_angle_state_batch, stride_angle_state_head, stride_angle_state_anglesdim, stride_input_ssm_state_batch, stride_input_ssm_state_head, stride_input_ssm_state_vdim, stride_input_ssm_state_qkdim, stride_input_k_state_batch, stride_input_k_state_head, stride_input_k_state_qkdim, stride_input_v_state_batch, stride_input_v_state_head, stride_input_v_state_vdim, # Output Strides stride_o_batch, stride_o_head, stride_o_vdim, stride_output_angle_state_batch, stride_output_angle_state_head, stride_output_angle_state_anglesdim, stride_output_ssm_state_batch, stride_output_ssm_state_head, stride_output_ssm_state_vdim, stride_output_ssm_state_qkdim, stride_output_k_state_batch, stride_output_k_state_head, stride_output_k_state_qkdim, # Dimensions nheads_qk, HEADDIM_QK: tl.constexpr, HEADDIM_V: tl.constexpr, HEADDIM_ANGLES: tl.constexpr, HAS_D: tl.constexpr, HAS_Z: tl.constexpr, ): """ Mamba-3 Step kernel. Inputs: Q, K: (batch, nheads_qk, headdim_qk) V: (batch, nheads, headdim_v) ADT, DT, Trap: (batch, nheads) Q_bias, K_bias: (nheads, headdim_qk) Angles: (batch, nheads, headdim_angles) D: (nheads,) Z: (batch, nheads, headdim_v) Out: (batch, nheads, headdim_v) SSM_States: (batch, nheads, headdim_v, headdim_qk) Input/Output Angle State: (batch, nheads, headdim_angles) Input/Output SSM State: (batch, nheads, headdim_v, headdim_qk) Input/Output K State: (batch, nheads, headdim_qk) Input/Output V State: (batch, nheads, headdim_v) Compile-time constants: HEADDIM_QK: Head dimension for Q/K HEADDIM_V: Head dimension for V HEADDIM_ANGLES: Head dimension for Angles HAS_D: Whether D-skip connection is used HAS_Z: Whether Z-gating is used Outputs: Out: (batch, nheads, headdim_v) Output_Angle_State: (batch, nheads, headdim_angles) Output_SSM_State: (batch, nheads, headdim_v, headdim_qk) Output_K_State: (batch, nheads, headdim_qk) """ # Program ID determines which (head, batch) pair this instance processes pid_head = 
tl.program_id(0) pid_batch = tl.program_id(1) # Compute head index for Q/K (supports Grouped Query Attention) nheads = tl.num_programs(0) head_idx_qk = pid_head // (nheads // nheads_qk) # Setup input pointers q_ptr = Q + pid_batch * stride_q_batch + head_idx_qk * stride_q_head k_ptr = K + pid_batch * stride_k_batch + head_idx_qk * stride_k_head v_ptr = V + pid_batch * stride_v_batch + pid_head * stride_v_head adt_ptr = ADT + pid_batch * stride_adt_batch + pid_head * stride_adt_head dt_ptr = DT + pid_batch * stride_dt_batch + pid_head * stride_dt_head trap_ptr = Trap + pid_batch * stride_trap_batch + pid_head * stride_trap_head q_bias_ptr = Q_bias + pid_head * stride_q_bias_head k_bias_ptr = K_bias + pid_head * stride_k_bias_head angle_ptr = Angles + pid_batch * stride_angles_batch + pid_head * stride_angles_head if HAS_D: D_ptr = D + pid_head * stride_d_head D_val = tl.load(D_ptr).to(tl.float32) if HAS_Z: z_ptr = Z + pid_batch * stride_z_batch + pid_head * stride_z_head input_angle_state_ptr = Input_Angle_State + pid_batch * stride_angle_state_batch + pid_head * stride_angle_state_head input_ssm_state_ptr = Input_SSM_State + pid_batch * stride_input_ssm_state_batch + pid_head * stride_input_ssm_state_head input_k_state_ptr = Input_K_State + pid_batch * stride_input_k_state_batch + pid_head * stride_input_k_state_head input_v_state_ptr = Input_V_State + pid_batch * stride_input_v_state_batch + pid_head * stride_input_v_state_head # Setup output pointers o_ptr = Out + pid_batch * stride_o_batch + pid_head * stride_o_head output_angle_state_ptr = Output_Angle_State + pid_batch * stride_output_angle_state_batch + pid_head * stride_output_angle_state_head output_ssm_state_ptr = Output_SSM_State + pid_batch * stride_output_ssm_state_batch + pid_head * stride_output_ssm_state_head output_k_state_ptr = Output_K_State + pid_batch * stride_output_k_state_batch + pid_head * stride_output_k_state_head PI = 3.141592653589793 TWO_PI = 2 * PI offs_qk = tl.arange(0, HEADDIM_QK) 
offs_v = tl.arange(0, HEADDIM_V) offs_qkr = tl.arange(0, HEADDIM_QK // 2) # Load Q and K blocks q_pre_block = tl.load(q_ptr + offs_qk * stride_q_qkdim) # (HEADDIM_QK) k_pre_block = tl.load(k_ptr + offs_qk * stride_k_qkdim) # (HEADDIM_QK) # Load Q and K biases q_bias_block = tl.load(q_bias_ptr + offs_qk * stride_q_bias_qkdim) # (HEADDIM_QK) k_bias_block = tl.load(k_bias_ptr + offs_qk * stride_k_bias_qkdim) # (HEADDIM_QK) q_pre_block += q_bias_block k_pre_block += k_bias_block # Load rotary angles (smaller block, direct load is faster than TMA) dt = tl.load(dt_ptr) angle_block = tl.load( angle_ptr + offs_qkr * stride_angles_qkdim, mask=offs_qkr < HEADDIM_ANGLES, other=0.0 ) # (HEADDIM_QK) angle_block = tanh_approx(angle_block.to(tl.float32)) * PI * dt angle_state = tl.load( input_angle_state_ptr + offs_qkr * stride_angle_state_anglesdim, mask=offs_qkr < HEADDIM_ANGLES, other=0.0 ) # (HEADDIM_QK) angle_block += angle_state angle_block -= TWO_PI * tl.floor(angle_block / TWO_PI) # angles mod 2pi tl.store(output_angle_state_ptr + offs_qkr * stride_output_angle_state_anglesdim, angle_block, mask=offs_qkr < HEADDIM_ANGLES) # Rotate Q and K with angles cos_block = cos_approx(angle_block.to(tl.float32)) sin_block = sin_approx(angle_block.to(tl.float32)) # Apply rotary embeddings to K and scale q0, q1 = tl.split(tl.reshape(q_pre_block, [HEADDIM_QK // 2, 2])) qo0 = q0 * cos_block - q1 * sin_block qo1 = q0 * sin_block + q1 * cos_block q_block = tl.reshape(tl.join(qo0, qo1), [HEADDIM_QK]).to(q_pre_block.dtype) k0, k1 = tl.split(tl.reshape(k_pre_block, [HEADDIM_QK // 2, 2])) ko0 = k0 * cos_block - k1 * sin_block ko1 = k0 * sin_block + k1 * cos_block k_block = tl.reshape(tl.join(ko0, ko1), [HEADDIM_QK]).to(k_pre_block.dtype) # Store K state tl.store(output_k_state_ptr + offs_qk * stride_output_k_state_qkdim, k_block) # Load previous K, V and current V k_prev_state = tl.load(input_k_state_ptr + offs_qk * stride_input_k_state_qkdim) # (HEADDIM_QK) v_prev_state = 
tl.load(input_v_state_ptr + offs_v * stride_input_v_state_vdim) # (HEADDIM_V) v_block = tl.load(v_ptr + offs_v * stride_v_vdim) # (HEADDIM_V) # Load ADT, DT and Trap adt = tl.load(adt_ptr) * 1.44269504089 trap = tl.load(trap_ptr) trap = sigmoid_approx(trap.to(tl.float32)) alpha = tl.math.exp2(adt) beta = alpha * dt * (1 - trap) gamma = trap * dt ssm_state_diff = (beta * v_prev_state)[:, None] * k_prev_state[None, :] + (gamma * v_block)[:, None] * k_block[None, :] # Load previous SSM state ssm_state = tl.load( input_ssm_state_ptr + offs_v[:, None] * stride_input_ssm_state_vdim + offs_qk[None, :] * stride_input_ssm_state_qkdim).to(tl.float32) # (HEADDIM_V, HEADDIM_QK) ssm_state = ssm_state * alpha + ssm_state_diff # Store updated SSM state tl.store(output_ssm_state_ptr + offs_v[:, None] * stride_output_ssm_state_vdim + offs_qk[None, :] * stride_output_ssm_state_qkdim, ssm_state) # Compute output out = tl.dot(ssm_state.to(tl.bfloat16), q_block.reshape([HEADDIM_QK, 1]).to(tl.bfloat16)) # (HEADDIM_V, 1) out = out.reshape([HEADDIM_V]).to(tl.float32) # out = tl.sum(ssm_state * q_block[None, :], axis=1) # (HEADDIM_V,) # Add D-skip connection if HAS_D: out += D_val * v_block # Apply Z-gating if HAS_Z: z_block = tl.load(z_ptr + offs_v * stride_z_vdim) # (HEADDIM_V) out = out * silu(z_block.to(tl.float32)) # Store output tl.store(o_ptr + offs_v * stride_o_vdim, out) # Memory Allocator for TMA Descriptors def _alloc_fn(size: int, alignment: int, stream: Optional[int]): """Custom allocator for TMA descriptor global memory allocation.""" return torch.empty(size, device="cuda", dtype=torch.int8) triton.set_allocator(_alloc_fn) def mamba3_siso_step( Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, ADT: torch.Tensor, DT: torch.Tensor, Trap: torch.Tensor, Q_bias: torch.Tensor, K_bias: torch.Tensor, Angles: torch.Tensor, D: Optional[torch.Tensor] = None, Z: Optional[torch.Tensor] = None, Input_States: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, ): """ Mamba-3 
step wrapper. Inputs: Q: Query tensor (batch, nheads_qk, headdim_qk). K: Key tensor (batch, nheads_qk, headdim_qk). V: Value tensor (batch, nheads, headdim_v). ADT: Decay tensor (batch, nheads). DT: DT tensor (batch, nheads). Trap: Trap tensor (batch, nheads). Q_bias: Query bias (nheads, headdim_qk). K_bias: Key bias (nheads, headdim_qk). Angles: Rotary angles (batch, nheads, headdim_angles) - headdim_angles <= headdim_qk // 2 and headdim_angles % 2 == 0. D: Skip connection weight (nheads,). Z: Gating tensor of shape (batch, nheads, headdim_v). - Applies SiLU gating: out = out * silu(Z). Input_States: Tuple of (Angle State SSM State, K state, V state) Angle state shape: (batch, nheads, headdim_angles). SSM state shape: (batch, nheads, headdim_v, headdim_qk). K state shape: (batch, nheads, headdim_qk). V state shape: (batch, nheads, headdim_v). NOTE: nheads % nheads_qk == 0 Outputs: Out: Output tensor (batch, nheads, headdim_v) Output_States: Final output state (None if return_output_state=False) - Output_Angle_State: Angle State (batch, nheads, headdim_angles) - Output_SSM_State: SSM State (batch, nheads, headdim_v, headdim_qk) - K_State: K state (batch, nheads, headdim_qk) - V_State: V state (batch, nheads, headdim_v) """ # Get dimensions batch, nheads_qk, headdim_qk = Q.shape _, nheads, headdim_v = V.shape device = Q.device # Validate input shapes assert Q.shape == K.shape, f"Q and K shape mismatch: {Q.shape} vs {K.shape}" assert nheads % nheads_qk == 0, f"nheads ({nheads}) must be divisible by nheads_qk ({nheads_qk})" assert ADT.shape == (batch, nheads), f"ADT shape mismatch: expected {(batch, nheads)}, got {ADT.shape}" assert DT.shape == (batch, nheads), f"DT shape mismatch: expected {(batch, nheads)}, got {DT.shape}" assert Trap.shape == (batch, nheads), f"Trap shape mismatch: expected {(batch, nheads)}, got {Trap.shape}" assert Q_bias.shape == (nheads, headdim_qk), f"Q_bias shape mismatch: expected {(nheads, headdim_qk)}, got {Q_bias.shape}" assert 
K_bias.shape == (nheads, headdim_qk), f"K_bias shape mismatch: expected {(nheads, headdim_qk)}, got {K_bias.shape}" headdim_angles = Angles.shape[-1] assert headdim_angles <= headdim_qk // 2 and headdim_angles % 2 == 0, f"headdim_angles ({headdim_angles}) must be <= headdim_qk // 2 ({headdim_qk // 2}) and even." assert Angles.shape == (batch, nheads, headdim_angles), f"Angles shape mismatch: expected {(batch, nheads, headdim_angles)}, got {Angles.shape}" if D is not None: assert D.shape == (nheads,), f"D shape mismatch: expected {(nheads,)}, got {D.shape}" if Z is not None: assert Z.shape == (batch, nheads, headdim_v), f"Z shape mismatch: expected {(batch, nheads, headdim_v)}, got {Z.shape}" Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State = Input_States assert Input_Angle_State.shape == (batch, nheads, headdim_angles), f"Input_Angle_State shape mismatch: expected {(batch, nheads, headdim_angles)}, got {Input_Angle_State.shape}" assert Input_SSM_State.shape == (batch, nheads, headdim_v, headdim_qk), f"Input_SSM_State shape mismatch: expected {(batch, nheads, headdim_v, headdim_qk)}, got {Input_SSM_State.shape}" assert Input_K_State.shape == (batch, nheads, headdim_qk), f"Input_K_State shape mismatch: expected {(batch, nheads, headdim_qk)}, got {Input_K_State.shape}" assert Input_V_State.shape == (batch, nheads, headdim_v), f"Input_V_State shape mismatch: expected {(batch, nheads, headdim_v)}, got {Input_V_State.shape}" # Ensure all tensors are contiguous Q = Q.contiguous() if not Q.is_contiguous() else Q K = K.contiguous() if not K.is_contiguous() else K V = V.contiguous() if not V.is_contiguous() else V ADT = ADT.contiguous() if not ADT.is_contiguous() else ADT DT = DT.contiguous() if not DT.is_contiguous() else DT Trap = Trap.contiguous() if not Trap.is_contiguous() else Trap Q_bias = Q_bias.contiguous() if not Q_bias.is_contiguous() else Q_bias K_bias = K_bias.contiguous() if not K_bias.is_contiguous() else K_bias Angles = Angles.contiguous() if 
not Angles.is_contiguous() else Angles if D is not None: D = D.contiguous() if not D.is_contiguous() else D if Z is not None: Z = Z.contiguous() if not Z.is_contiguous() else Z if Input_States is not None: Input_Angle_State = Input_Angle_State.contiguous() if not Input_Angle_State.is_contiguous() else Input_Angle_State Input_SSM_State = Input_SSM_State.contiguous() if not Input_SSM_State.is_contiguous() else Input_SSM_State Input_K_State = Input_K_State.contiguous() if not Input_K_State.is_contiguous() else Input_K_State Input_V_State = Input_V_State.contiguous() if not Input_V_State.is_contiguous() else Input_V_State # Allocate output tensors Out = torch.empty((batch, nheads, headdim_v), device=device, dtype=V.dtype) Output_Angle_State = torch.empty((batch, nheads, headdim_angles), device=device, dtype=torch.float32) Output_SSM_State = torch.empty((batch, nheads, headdim_v, headdim_qk), device=device, dtype=torch.float32) Output_K_State = torch.empty((batch, nheads, headdim_qk), device=device, dtype=torch.float32) grid = (nheads, batch) mamba3_siso_step_kernel[grid]( # Inputs Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z, Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State, # Outputs Out, Output_Angle_State, Output_SSM_State, Output_K_State, # Input strides Q.stride(0), Q.stride(1), Q.stride(2), K.stride(0), K.stride(1), K.stride(2), V.stride(0), V.stride(1), V.stride(2), ADT.stride(0), ADT.stride(1), DT.stride(0), DT.stride(1), Trap.stride(0), Trap.stride(1), Q_bias.stride(0), Q_bias.stride(1), K_bias.stride(0), K_bias.stride(1), Angles.stride(0), Angles.stride(1), Angles.stride(2), D.stride(0) if D is not None else 0, Z.stride(0) if Z is not None else 0, Z.stride(1) if Z is not None else 0, Z.stride(2) if Z is not None else 0, Input_Angle_State.stride(0), Input_Angle_State.stride(1), Input_Angle_State.stride(2), Input_SSM_State.stride(0), Input_SSM_State.stride(1), Input_SSM_State.stride(2), Input_SSM_State.stride(3), Input_K_State.stride(0), 
Input_K_State.stride(1), Input_K_State.stride(2), Input_V_State.stride(0), Input_V_State.stride(1), Input_V_State.stride(2), # Output strides Out.stride(0), Out.stride(1), Out.stride(2), Output_Angle_State.stride(0), Output_Angle_State.stride(1), Output_Angle_State.stride(2), Output_SSM_State.stride(0), Output_SSM_State.stride(1), Output_SSM_State.stride(2), Output_SSM_State.stride(3), Output_K_State.stride(0), Output_K_State.stride(1), Output_K_State.stride(2), # Dimensions nheads_qk, # Compile-time constants headdim_qk, headdim_v, headdim_angles, HAS_D=D is not None, HAS_Z=Z is not None, ) Output_States = [Output_Angle_State, Output_SSM_State, Output_K_State, V] return Out, Output_States ================================================ FILE: mamba_ssm/ops/triton/mamba3/utils.py ================================================ """ Mamba-3 Util Functions. Copyright (c) 2025, Dao AI Lab, Goombalab """ import triton import triton.language as tl # We use PTX approximations instead of triton built-in functions # to trade off a bit of accuracy for much faster speed. @triton.jit def cos_approx(x): """ (Fast) Cosine approximation using PTX inline assembly. Args: x: Input triton tensor (any shape) in float32 Returns: Approximate cosine values in float32 """ return tl.inline_asm_elementwise( "cos.approx.f32 $0, $1;", constraints="=f,f", args=[x], dtype=tl.float32, is_pure=True, pack=1, ) @triton.jit def sin_approx(x): """ (Fast) Sine approximation using PTX inline assembly. Args: x: Input triton tensor (any shape) in float32 Returns: Approximate sine values in float32 """ return tl.inline_asm_elementwise( "sin.approx.f32 $0, $1;", constraints="=f,f", args=[x], dtype=tl.float32, is_pure=True, pack=1, ) @triton.jit def tanh_approx(x): """ (Fast) hyperbolic tangent approximation using PTX inline assembly. 
Args: x: Input triton tensor (any shape) in float32 Returns: Approximate tanh values in float32 """ return tl.inline_asm_elementwise( "tanh.approx.f32 $0, $1;", constraints="=f,f", args=[x], dtype=tl.float32, is_pure=True, pack=1, ) @triton.jit def sech2_approx(x): """ (Fast) square of the hyperbolic secant approximation using PTX inline assembly. Args: x: Input triton tensor (any shape) in float32 Returns: Approximate sech^2 values in float32 """ tanh_x = tl.inline_asm_elementwise( "tanh.approx.f32 $0, $1;", constraints="=f,f", args=[x], dtype=tl.float32, is_pure=True, pack=1, ) return 1.0 - tanh_x * tanh_x @triton.jit def sigmoid_approx(x): """ (Fast) Sigmoid approximation using PTX inline assembly. Formula: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x)) Leverages fast tanh approximation for speed. Args: x: Input triton tensor (any shape) in float32 Returns: Approximate sigmoid values in float32 """ # tanh_half_x = tl.inline_asm_elementwise( # "tanh.approx.f32 $0, $1;", # constraints="=f,f", # args=[0.5 * x], # dtype=tl.float32, # is_pure=True, # pack=1, # ) # return 0.5 * (1.0 + tanh_half_x) # NOTE: We ended up using the built-in sigmoid for better performance, as the PTX approximation was not faster in this case. return tl.sigmoid(x) @triton.jit def silu(x): """ SiLU (Swish) activation function: x * sigmoid(x). Formula: silu(x) = 0.5*x * (1 + tanh(0.5*x)) + 0.5*x. Leverages fast tanh_approx for speed. Args: x: Input triton tensor (any shape) in float32 Returns: SiLU activation output in float32 """ # x_half = 0.5 * x # return x_half * tanh_approx(x_half) + x_half # NOTE: We ended up using the built-in sigmoid for better performance, as the PTX approximation was not faster in this case. return x*tl.sigmoid(x) ================================================ FILE: mamba_ssm/ops/triton/selective_state_update.py ================================================ # Copyright (c) 2024, Tri Dao, Albert Gu. 
"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this """ import math import torch import torch.nn.functional as F import triton import triton.language as tl from einops import rearrange, repeat from mamba_ssm.ops.triton.softplus import softplus @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) @triton.heuristics({"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] is not None}) @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) @triton.jit def _selective_scan_update_kernel( # Pointers to matrices state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, state_batch_indices_ptr, # Matrix dimensions batch, nheads, dim, dstate, nheads_ngroups_ratio, # Strides stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate, stride_x_batch, stride_x_head, stride_x_dim, stride_dt_batch, stride_dt_head, stride_dt_dim, stride_dt_bias_head, stride_dt_bias_dim, stride_A_head, stride_A_dim, stride_A_dstate, stride_B_batch, stride_B_group, stride_B_dstate, stride_C_batch, stride_C_group, stride_C_dstate, stride_D_head, stride_D_dim, stride_z_batch, stride_z_head, stride_z_dim, stride_out_batch, stride_out_head, stride_out_dim, # Meta-parameters DT_SOFTPLUS: tl.constexpr, TIE_HDIM: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, HAS_DT_BIAS: tl.constexpr, HAS_D: tl.constexpr, HAS_Z: tl.constexpr, HAS_STATE_BATCH_INDICES: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): pid_m = tl.program_id(axis=0) pid_b = tl.program_id(axis=1) pid_h = tl.program_id(axis=2) offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head out_ptrs = out_ptr + offs_m * stride_out_dim if HAS_STATE_BATCH_INDICES: state_batch_indices_ptr += pid_b state_batch_idx = 
tl.load(state_batch_indices_ptr) # Skip padding tokens if state_batch_idx < 0: tl.store(out_ptrs, 0.0, mask=offs_m < dim) return state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head else: state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head if HAS_DT_BIAS: dt_bias_ptr += pid_h * stride_dt_bias_head A_ptr += pid_h * stride_A_head B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group if HAS_Z: z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) x_ptrs = x_ptr + offs_m * stride_x_dim dt_ptrs = dt_ptr + offs_m * stride_dt_dim if HAS_DT_BIAS: dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim if HAS_D: D_ptr += pid_h * stride_D_head A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) B_ptrs = B_ptr + offs_n * stride_B_dstate C_ptrs = C_ptr + offs_n * stride_C_dstate if HAS_D: D_ptrs = D_ptr + offs_m * stride_D_dim if HAS_Z: z_ptrs = z_ptr + offs_m * stride_z_dim state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if not TIE_HDIM: dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if HAS_DT_BIAS: dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if DT_SOFTPLUS: dt = tl.where(dt <= 20.0, softplus(dt), dt) A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) dA = tl.exp(A * dt[:, None]) else: dt = tl.load(dt_ptr).to(tl.float32) if HAS_DT_BIAS: dt += tl.load(dt_bias_ptr).to(tl.float32) if DT_SOFTPLUS: dt = tl.where(dt <= 20.0, 
softplus(dt), dt) A = tl.load(A_ptr).to(tl.float32) dA = tl.exp(A * dt) # scalar, not a matrix B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) if HAS_D: D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if HAS_Z: z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if not TIE_HDIM: dB = B[None, :] * dt[:, None] else: dB = B * dt # vector of size (dstate,) state = state * dA + dB * x[:, None] tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) out = tl.sum(state * C[None, :], axis=1) if HAS_D: out += x * D if HAS_Z: out *= z * tl.sigmoid(z) tl.store(out_ptrs, out, mask=offs_m < dim) def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, state_batch_indices=None): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) x: (batch, dim) or (batch, nheads, dim) dt: (batch, dim) or (batch, nheads, dim) A: (dim, dstate) or (nheads, dim, dstate) B: (batch, dstate) or (batch, ngroups, dstate) C: (batch, dstate) or (batch, ngroups, dstate) D: (dim,) or (nheads, dim) z: (batch, dim) or (batch, nheads, dim) dt_bias: (dim,) or (nheads, dim) Return: out: (batch, dim) or (batch, nheads, dim) """ has_heads = state.dim() > 3 if state.dim() == 3: state = state.unsqueeze(1) if x.dim() == 2: x = x.unsqueeze(1) if dt.dim() == 2: dt = dt.unsqueeze(1) if A.dim() == 2: A = A.unsqueeze(0) if B.dim() == 2: B = B.unsqueeze(1) if C.dim() == 2: C = C.unsqueeze(1) if D is not None and D.dim() == 1: D = D.unsqueeze(0) if z is not None and z.dim() == 2: z = z.unsqueeze(1) if dt_bias is not None and dt_bias.dim() == 1: dt_bias = dt_bias.unsqueeze(0) _, nheads, dim, dstate = state.shape batch = x.shape[0] if x.shape != (batch, nheads, dim): print(f"{state.shape} {x.shape} {batch} {nheads} {dim}") assert x.shape == (batch, nheads, dim) assert dt.shape == x.shape assert A.shape == (nheads, 
dim, dstate) ngroups = B.shape[1] assert nheads % ngroups == 0, "nheads must be divisible by ngroups" assert B.shape == (batch, ngroups, dstate) assert C.shape == B.shape if D is not None: assert D.shape == (nheads, dim) if z is not None: assert z.shape == x.shape if dt_bias is not None: assert dt_bias.shape == (nheads, dim) if state_batch_indices is not None: assert state_batch_indices.shape == (batch,) out = torch.empty_like(x) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)) # We don't want autotune since it will overwrite the state # We instead tune by hand. BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else ((16, 4) if dstate <= 32 else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8)))))) tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0 with torch.cuda.device(x.device.index): _selective_scan_update_kernel[grid]( state, x, dt, dt_bias, A, B, C, D, z, out, state_batch_indices, batch, nheads, dim, dstate, nheads // ngroups, state.stride(0), state.stride(1), state.stride(2), state.stride(3), x.stride(0), x.stride(1), x.stride(2), dt.stride(0), dt.stride(1), dt.stride(2), *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, A.stride(0), A.stride(1), A.stride(2), B.stride(0), B.stride(1), B.stride(2), C.stride(0), C.stride(1), C.stride(2), *(D.stride(0), D.stride(1)) if D is not None else 0, z_strides[0], z_strides[1], z_strides[2], out.stride(0), out.stride(1), out.stride(2), dt_softplus, tie_hdim, BLOCK_SIZE_M, num_warps=num_warps, ) if not has_heads: out = out.squeeze(1) return out def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) x: (batch, dim) or (batch, nheads, dim) dt: (batch, dim) or (batch, nheads, dim) A: (dim, dstate) or 
(nheads, dim, dstate) B: (batch, dstate) or (batch, ngroups, dstate) C: (batch, dstate) or (batch, ngroups, dstate) D: (dim,) or (nheads, dim) z: (batch, dim) or (batch, nheads, dim) dt_bias: (dim,) or (nheads, dim) Return: out: (batch, dim) or (batch, nheads, dim) """ has_heads = state.dim() > 3 if state.dim() == 3: state = state.unsqueeze(1) if x.dim() == 2: x = x.unsqueeze(1) if dt.dim() == 2: dt = dt.unsqueeze(1) if A.dim() == 2: A = A.unsqueeze(0) if B.dim() == 2: B = B.unsqueeze(1) if C.dim() == 2: C = C.unsqueeze(1) if D is not None and D.dim() == 1: D = D.unsqueeze(0) if z is not None and z.dim() == 2: z = z.unsqueeze(1) if dt_bias is not None and dt_bias.dim() == 1: dt_bias = dt_bias.unsqueeze(0) batch, nheads, dim, dstate = state.shape assert x.shape == (batch, nheads, dim) assert dt.shape == x.shape assert A.shape == (nheads, dim, dstate) ngroups = B.shape[1] assert nheads % ngroups == 0, "nheads must be divisible by ngroups" assert B.shape == (batch, ngroups, dstate) assert C.shape == B.shape if D is not None: assert D.shape == (nheads, dim) if z is not None: assert z.shape == x.shape if dt_bias is not None: assert dt_bias.shape == (nheads, dim) dt = dt + dt_bias dt = F.softplus(dt) if dt_softplus else dt dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate) B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) if D is not None: out += (x * D).to(out.dtype) out = (out if z is None else out * F.silu(z)).to(x.dtype) if not has_heads: out = out.squeeze(1) return out ================================================ FILE: mamba_ssm/ops/triton/softplus.py 
================================================ import triton import triton.language as tl from packaging import version TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") if TRITON3: @triton.jit def softplus(dt): return tl.math.log(tl.math.exp(dt) + 1) else: @triton.jit def softplus(dt): return tl.math.log1p(tl.exp(dt)) ================================================ FILE: mamba_ssm/ops/triton/ssd_bmm.py ================================================ # Copyright (c) 2024, Tri Dao, Albert Gu. """We want triton==2.1.0 or 2.2.0 for this """ import math import torch import torch.nn.functional as F import triton import triton.language as tl from einops import rearrange, repeat from mamba_ssm.utils.determinism import autotune_configs def init_to_zero(names): return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] @triton.autotune( configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), ]), key=['chunk_size', 'K', 'IS_CAUSAL'], ) @triton.jit def _bmm_chunk_fwd_kernel( # Pointers to matrices a_ptr, b_ptr, out_ptr, 
seq_idx_ptr, # Matrix dimensions seqlen, chunk_size, K, ngroups, stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak, stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk, stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn, stride_seq_idx_batch, stride_seq_idx_seqlen, # Meta-parameters IS_CAUSAL: tl.constexpr, dot_dtype: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_b = tl.program_id(axis=1) pid_ch = tl.program_id(axis=2) pid_c = pid_ch // ngroups pid_h = pid_ch - pid_c * ngroups num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n if IS_CAUSAL: if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: return a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype) b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype) acc += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, 
BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if HAS_SEQ_IDX: chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) out = acc.to(out_ptr.dtype.element_ty) out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) @triton.autotune( configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2), ]), key=['chunk_size', 'K'], ) @triton.jit def _bmm_chunk_bwd_kernel( # Pointers to matrices a_ptr, dout_ptr, db_ptr, res_ptr, # Matrix dimensions seqlen, chunk_size, K, ngroups, stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak, stride_dout_batch, 
stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n, stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k, stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k, # Meta-parameters dot_dtype: tl.constexpr, HAS_RESIDUAL: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr, ): pid_b = tl.program_id(axis=1) pid_ch = tl.program_id(axis=2) pid_c = pid_ch // ngroups pid_h = pid_ch - pid_c * ngroups num_pid_n = tl.cdiv(K, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cs = tl.arange(0, BLOCK_SIZE_CS) dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m) a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)): dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype) a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype) acc += tl.dot(dout, a) dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if HAS_RESIDUAL: res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head res_ptrs = res_ptr + (offs_m[:, 
None] * stride_res_seqlen + offs_n[None, :] * stride_res_k) res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32) acc += res db = acc.to(db_ptr.dtype.element_ty) db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k) tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)) def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None): """ Argument: a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are guaranteed to be correct. Return: out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) """ # Check constraints. has_groups = a.dim() == 4 if not has_groups: batch, seqlen, k = a.shape else: batch, seqlen, ngroups, k = a.shape assert b.shape == a.shape if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) if a.stride(-1) != 1 and a.stride(1) != 1: a = a.contiguous() if b.stride(-1) != 1 and b.stride(1) != 1: b = b.contiguous() nchunks = math.ceil(seqlen / chunk_size) # Allocates output. 
out_dtype = a.dtype if output_dtype is None else output_dtype out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=out_dtype) dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32)) grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), batch, nchunks if not has_groups else nchunks * ngroups) with torch.cuda.device(a.device.index): _bmm_chunk_fwd_kernel[grid]( a, b, out, seq_idx, seqlen, chunk_size, k, ngroups if has_groups else 1, a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1), b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1), out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), causal, dot_dtype, HAS_SEQ_IDX=seq_idx is not None, ) return out def _bmm_chunk_bwd(a, dout, residual=None, out=None): """ Argument: a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k) Return: out: (batch, seqlen, k) or (batch, seqlen, ngroups, k) If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be zeroed out before calling this function. """ # Check constraints. 
has_groups = a.dim() == 4 if not has_groups: batch, seqlen, k = a.shape else: batch, seqlen, ngroups, k = a.shape nchunks, chunk_size = dout.shape[1], dout.shape[-1] if a.stride(-1) != 1 and a.stride(-2) != 1: a = a.contiguous() if dout.stride(-1) != 1 and dout.stride(-2) != 1: dout = dout.contiguous() if residual is not None: assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k) if residual.stride(-1) != 1 and residual.stride(1) != 1: residual = residual.contiguous() # Allocates output. if out is not None: assert out.shape == a.shape assert out.stride(-1) == 1 or out.stride(1) == 1 else: out = torch.empty_like(a) dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32)) grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch, nchunks if not has_groups else nchunks * ngroups) residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2), residual.stride(-1)) if residual is not None else (0, 0, 0, 0)) with torch.cuda.device(a.device.index): _bmm_chunk_bwd_kernel[grid]( a, dout, out, residual, seqlen, chunk_size, k, ngroups if has_groups else 1, a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1), dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1), out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1), residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3], dot_dtype, HAS_RESIDUAL=residual is not None, ) return out ================================================ FILE: mamba_ssm/ops/triton/ssd_chunk_scan.py ================================================ # Copyright (c) 2024, Tri Dao, Albert Gu. 
"""We want triton==2.1.0 or 2.2.0 for this """ import math from packaging import version import torch import torch.nn.functional as F import triton import triton.language as tl from einops import rearrange, repeat from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd from mamba_ssm.utils.determinism import ( alloc_tile_workspace, finalize_tile_workspace, use_deterministic_mode, autotune_configs, ) TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') def init_to_zero(names): return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] @triton.autotune( configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), ]), key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], ) @triton.jit def _chunk_scan_fwd_kernel( # Pointers to matrices cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, 
seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr, # Matrix dimensions chunk_size, hdim, dstate, batch, seqlen, nheads_ngroups_ratio, # Strides stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, stride_D_head, # Meta-parameters IS_CAUSAL: tl.constexpr, HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr, HAS_Z: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, IS_TRITON_22: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * 
stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) if HAS_SEQ_IDX: seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # Without the if (pid_c > -1), with Triton 2.1.0, I get # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. # With Triton 2.2.0, this works if IS_TRITON_22 or pid_c > -1: # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) if not HAS_SEQ_IDX: scale_m = tl.exp(dA_cs_m) else: scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) if BLOCK_SIZE_DSTATE <= 128: C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] else: for k in range(0, dstate, BLOCK_SIZE_K): C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0) # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) prev_states = tl.load(prev_states_ptrs, 
mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K prev_states_ptrs += BLOCK_SIZE_K acc *= scale_m[:, None] offs_k = tl.arange(0, BLOCK_SIZE_K) cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) dt_ptrs = dt_ptr + offs_k * stride_dt_csize dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) for k in range(0, K_MAX, BLOCK_SIZE_K): cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. 
# cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) cb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_k[None, :]), 0.0)) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: mask = offs_m[:, None] >= k + offs_k[None, :] cb = tl.where(mask, cb, 0.0) cb = cb.to(x_ptr.dtype.element_ty) x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0) acc += tl.dot(cb, x) cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k x_ptrs += BLOCK_SIZE_K * stride_x_seqlen dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if HAS_D: if D_HAS_HDIM: D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) else: D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) acc += x_residual * D if HAS_Z: out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) acc *= z * tl.sigmoid(z) out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + 
offs_out_n[None, :] * stride_out_hdim) tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) @triton.autotune( configs=autotune_configs([ # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4), # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8), triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8), ]), key=['chunk_size', 'hdim', 'dstate'], ) @triton.jit def _chunk_scan_fwd_kernel_wip( # Pointers to matrices cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, B_ptr, prev_states_ptr, D_ptr, # Matrix dimensions chunk_size, hdim, dstate, batch, seqlen, nheads_ngroups_ratio, # Strides stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, stride_B_batch, stride_B_seqlen, stride_B_head, stride_B_dstate, stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, stride_D_head, # Meta-parameters HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr, HAS_Z: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) pid_n = tl.program_id(axis=0) cb_ptr += pid_b * 
stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head B_ptr += pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head offs_m = tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE) C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate) prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k) x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) dt_ptrs = dt_ptr + offs_m * stride_dt_csize out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) # if pid_c == 0: # if pid_b == 0: # if pid_h == 0: # tl.device_print("", prev_states) chunk_size_limit = min(chunk_size, seqlen - 
pid_c * chunk_size) # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) # scale_m = tl.exp(dA_cs_m) # C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) # acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32) # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) # cb *= dt_m # mask = offs_m[:, None] >= offs_m[None, :] # cb = tl.where(mask, cb, 0.0) # cb = cb.to(x_ptr.dtype.element_ty) # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0) # acc += tl.dot(cb, x) # if HAS_D: # if D_HAS_HDIM: # D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) # else: # D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) # acc += x.to(tl.float32) * D # tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) for start_m in range(0, chunk_size_limit, BLOCK_SIZE_M): start_m = tl.multiple_of(start_m, BLOCK_SIZE_M) dA_cs_m = tl.load(dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) if HAS_SEQ_IDX: seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) seq_idx_m = tl.load(seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit - start_m, other=-1) if not HAS_SEQ_IDX: scale_m = tl.exp(dA_cs_m) else: scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0) acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] # cb = 
tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32) # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) # cb *= dt_m # mask = offs_m[:, None] >= offs_m[None, :] # cb = tl.where(mask, cb, 0.0) # cb = cb.to(x_ptr.dtype.element_ty) x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0) # acc += tl.dot(cb, x) if HAS_D: if D_HAS_HDIM: D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) else: D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) acc += x.to(tl.float32) * D # if HAS_Z: # out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head # out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) # tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) # z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head # z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) # z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) # acc *= z * tl.sigmoid(z) tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim)) # TODO: this is not correct, and quite a bit slower if start_m + BLOCK_SIZE_M < chunk_size_limit: # B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32) B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0) dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32) # TODO: seq_idx scale = 
tl.exp((dA_cs_last - dA_cs_m)) * dt_m # B *= scale B = B.to(x_ptr.dtype.element_ty) tmp = tl.dot(B, x) prev_states += tmp.to(prev_states.dtype) C_ptrs += BLOCK_SIZE_M * stride_C_seqlen B_ptrs += BLOCK_SIZE_M * stride_B_seqlen cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k x_ptrs += BLOCK_SIZE_M * stride_x_seqlen dt_ptrs += BLOCK_SIZE_M * stride_dt_csize out_ptrs += BLOCK_SIZE_M * stride_out_seqlen @triton.autotune( configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32}), triton.Config({'BLOCK_SIZE_M': 64}), triton.Config({'BLOCK_SIZE_M': 128}), triton.Config({'BLOCK_SIZE_M': 256}), ]), key=["chunk_size", "hdim"], ) @triton.jit def _chunk_scan_bwd_dz_kernel( # Pointers to matrices dout_ptr, out_ptr, z_ptr, x_ptr, D_ptr, outz_ptr, dz_ptr, dout_x_ptr, dD_ptr, ddA_cumsum_ptr, # Matrix dimensions chunk_size, hdim, batch, seqlen, # Strides stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, stride_D_head, stride_outz_batch, stride_outz_seqlen, stride_outz_head, stride_outz_hdim, stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_hdim, stride_doutx_batch, stride_doutx_seqlen, stride_doutx_head, stride_doutx_hdim, stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, # Meta-parameters HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr, HAS_DDACS: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) pid_m = tl.program_id(axis=0) dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head 
dout_x_ptr += pid_b * stride_doutx_batch + pid_c * chunk_size * stride_doutx_seqlen + pid_h * stride_doutx_head out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head dz_ptr += pid_b * stride_dz_batch + pid_c * chunk_size * stride_dz_seqlen + pid_h * stride_dz_head if RECOMPUTE_OUTPUT: outz_ptr += pid_b * stride_outz_batch + pid_c * chunk_size * stride_outz_seqlen + pid_h * stride_outz_head if HAS_DDACS: ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head if HAS_D: x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = tl.arange(0, BLOCK_SIZE_N) dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) dout_x_ptrs = dout_x_ptr + (offs_m[:, None] * stride_doutx_seqlen + offs_n[None, :] * stride_doutx_hdim) out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) z_ptrs = z_ptr + (offs_m[:, None] * stride_z_seqlen + offs_n[None, :] * stride_z_hdim) dz_ptrs = dz_ptr + (offs_m[:, None] * stride_dz_seqlen + offs_n[None, :] * stride_dz_hdim) if RECOMPUTE_OUTPUT: outz_ptrs = outz_ptr + (offs_m[:, None] * stride_outz_seqlen + offs_n[None, :] * stride_outz_hdim) if HAS_D: x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) if D_HAS_HDIM: dD_ptrs = dD_ptr + offs_n * stride_dD_hdim chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), 
other=0.0).to(tl.float32) z = tl.load(z_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) z_sigmoid = tl.sigmoid(z) if RECOMPUTE_OUTPUT: outz = out * z * z_sigmoid tl.store(outz_ptrs, outz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) dz = dout * out * z_sigmoid * (1 + z * (1 - z_sigmoid)) tl.store(dz_ptrs, dz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) dout *= z * z_sigmoid tl.store(dout_x_ptrs, dout, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) if HAS_D: x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) if D_HAS_HDIM: dD = tl.sum(dout * x, axis=0) tl.store(dD_ptrs, dD, mask=offs_n < hdim) D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) else: dD = tl.sum(dout * x) tl.store(dD_ptr, dD) D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) out -= x * D if HAS_DDACS: ddA_cs = tl.sum(dout * out, axis=1) tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) @triton.autotune( configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 
'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), ]), key=['hdim', 'dstate', 'chunk_size'], ) @triton.jit def _chunk_scan_bwd_dstates_kernel( # Pointers to matrices dout_ptr, c_ptr, dprev_states_ptr, dA_cumsum_ptr, seq_idx_ptr, # Matrix dimensions hdim, dstate, chunk_size, batch, seqlen, nchunks, nheads_ngroups_ratio, # Strides stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, stride_c_batch, stride_c_seqlen, stride_c_head, stride_c_dstate, stride_dprev_states_batch, stride_dprev_states_chunk, stride_dprev_states_head, stride_dprev_states_hdim, stride_dprev_states_dstate, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, # Meta-parameters HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n c_ptr += pid_b * stride_c_batch + pid_c * chunk_size * stride_c_seqlen + (pid_h // nheads_ngroups_ratio) * stride_c_head dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + offs_k[None, :] * stride_dout_seqlen) c_ptrs = c_ptr + (offs_n[None, :] * stride_c_dstate + offs_k[:, None] * stride_c_seqlen) dA_cumsum_ptrs = 
dA_cumsum_ptr + offs_k * stride_dA_cs_csize if HAS_SEQ_IDX: seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) if HAS_SEQ_IDX: seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0).to(tl.float32) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) if not HAS_SEQ_IDX: scale_k = tl.exp(dA_cs_k) else: seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0) dout = (dout * scale_k).to(dout_ptr.dtype.element_ty) c = tl.load(c_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(dout_ptr.dtype.element_ty) acc += tl.dot(dout, c) dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen c_ptrs += BLOCK_SIZE_K * stride_c_seqlen dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize if HAS_SEQ_IDX: seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen out = acc.to(dprev_states_ptr.dtype.element_ty) dprev_states_ptr += pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dprev_states_ptrs = dprev_states_ptr + (offs_m[:, None] * stride_dprev_states_hdim + offs_n[None, :] * stride_dprev_states_dstate) tl.store(dprev_states_ptrs, out, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)) @triton.autotune( configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, 
num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
    ]),
    key=['chunk_size', 'dstate', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_dc_kernel(
    # Pointers to matrices
    dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr,
    dc_ptr, ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, dstate, hdim,
    batch, seqlen, nheads, nheads_per_program, ngroups,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate,
    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile,
    # Meta-parameters
    HAS_DDA_CS: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    DETERMINISTIC_REDUCTION: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # Computes dc (presumably the gradient w.r.t. C from the prev_states term).
    # For chunk position m and dstate column n:
    #   dc[m, n] = sum_h scale[h, m] * sum_k dout[m, h, k] * prev_states[h, k, n]
    # with scale[h, m] = exp(dA_cumsum[h, m]); when HAS_SEQ_IDX, scale is forced to 0
    # for tokens whose seq_idx differs from the last token of the previous chunk.
    # The heads h summed here are the slice of group pid_g owned by split pid_s
    # (at most nheads_per_program heads). Each split writes its own dc slice
    # (stride_dc_split); presumably the caller reduces over splits.
    # With HAS_DDA_CS it also emits, per head, sum_n dc_h[m, n] * C[m, n] into
    # ddA_cumsum.
    #
    # Launch grid: axis 0 = flattened (m tile, n tile) with the n tile varying fastest;
    # axis 1 = chunk * batch + batch index; axis 2 = split * ngroups + group.
    # NOTE(review): hdim is covered by a single BLOCK_SIZE_K tile (there is no k loop),
    # so the caller presumably passes BLOCK_SIZE_K >= hdim -- confirm.
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_sg = tl.program_id(axis=2)
    pid_s = pid_sg // ngroups
    pid_g = pid_sg - pid_s * ngroups
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n

    # First head handled by this program: pid_g * (nheads // ngroups) + pid_s * nheads_per_program.
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head
    dc_ptr += pid_b * stride_dc_batch + pid_c * chunk_size * stride_dc_seqlen + pid_g * stride_dc_group + pid_s * stride_dc_split
    prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_prev_states_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
    if HAS_DDA_CS:
        # C is indexed by group, not head.
        C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head
        # pid_n * stride_ddA_tile selects this n tile's private slot (used by the
        # deterministic, atomic-free store below).
        ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_n * stride_ddA_tile
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
    if HAS_DDA_CS:
        C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate)
        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize

    # Number of valid rows in this chunk (the last chunk may be partial).
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if HAS_DDA_CS:
        c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
    if HAS_SEQ_IDX:
        # seq_idx of the last token of the previous chunk; for the first chunk the
        # load is masked off and 0 is used.
        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
    # The last split of a group may own fewer than nheads_per_program heads.
    nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
    for h in range(nheads_iter):
        dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
        prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
        # Cast so both tl.dot operands share dout's element type.
        prev_states = prev_states.to(dout_ptrs.dtype.element_ty)
        dc = tl.dot(dout, prev_states)
        dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
        if not HAS_SEQ_IDX:
            scale = tl.exp(dA_cs_m)
        else:
            # Tokens that start a new sequence inside this chunk get no contribution
            # from the previous chunk's state.
            scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
        dc *= scale[:, None]
        if HAS_DDA_CS:
            ddA_cs = tl.sum(dc * c, axis=1)
            if DETERMINISTIC_REDUCTION:
                # Each n tile owns its own slot (pid_n * stride_ddA_tile), so a plain
                # store is race-free; the autotune pre_hook zeroes the buffer.
                tl.store(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
            else:
                tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
        acc += dc
        dout_ptrs += stride_dout_head
        prev_states_ptrs += stride_prev_states_head
        dA_cumsum_ptrs += stride_dA_cs_head
        if HAS_DDA_CS:
            ddA_cumsum_ptrs += stride_ddA_cs_head
    # if HAS_SEQ_IDX:
    #     seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
    #     seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
    #     acc = tl.where(seq_idx_m[:, None] == seq_idx_prev, acc, 0.0)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dc_ptrs = dc_ptr + (offs_m[:, None] * stride_dc_seqlen + offs_n[None, :] * stride_dc_dstate)
    tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))


# Smallest BLOCK_SIZE_N among the dc kernel's autotune configs. The number of n tiles
# (and hence of ddA_cumsum tile slots written under DETERMINISTIC_REDUCTION) is largest
# for this block size; presumably the host uses it to size that buffer -- confirm
# against the caller.
_CHUNK_SCAN_BWD_DC_MIN_BLOCK_N = min(
    cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_scan_bwd_dc_kernel.configs
)


@triton.autotune(
    configs=autotune_configs([
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
    ]),
    key=['chunk_size', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_dx_kernel(
    # Pointers to matrices
    x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, D_ptr,
    dx_ptr, ddt_ptr, # dD_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_D_head,
    stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile,
    # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    DETERMINISTIC_REDUCTION: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # Computes dx and this tile's ddt contribution (presumably from the intra-chunk
    # term). For chunk row m and hdim column p:
    #   acc[m, p] = sum_{m <= k < chunk_size_limit} cb[m, k] * exp(min(dA_cs[k] - dA_cs[m], 0)) * dout[k, p]
    #   dx[m, p]  = acc[m, p] * dt[m]   (+ D * dout[m, p] when HAS_D)
    #   ddt[m]   += sum_{p in this n tile} acc[m, p] * x[m, p]
    # cb is read for group pid_h // nheads_ngroups_ratio.
    #
    # Launch grid: axis 0 = flattened (m tile, n tile) with the n tile varying fastest;
    # axis 1 = chunk * batch + batch index; axis 2 = head.
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n

    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    # pid_n * stride_ddt_tile selects this n tile's private ddt slot (deterministic path).
    ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    # if HAS_D:
    #     dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
    dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    # Number of valid rows in this chunk (the last chunk may be partial).
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Idk why limiting K_MAX gives wrong results, is it a Triton bug?
    # K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
    # NOTE(review): the causal mask below keeps only columns k >= m, so columns past
    # (pid_m + 1) * BLOCK_SIZE_M DO contribute; capping the upper bound drops them, which
    # explains the wrong results. The valid restriction would be a lower bound of
    # pid_m * BLOCK_SIZE_M (columns below it are all masked out).
    K_MAX = chunk_size_limit
    for k in range(0, K_MAX, BLOCK_SIZE_K):
        # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
        dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
        # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
        # The exponent is clamped at 0, so the decay factor never exceeds 1.
        cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
        # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
        # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
        # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
        # This will cause NaN in acc, and hence NaN in dx and ddt.
        mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
        cb = tl.where(mask, cb, 0.0)
        # Match dout's element type for tl.dot.
        cb = cb.to(dout_ptr.dtype.element_ty)
        acc += tl.dot(cb, dout)
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    dx = acc * dt_m[:, None]
    dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
    dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
    if HAS_D:
        # Skip-connection term: dx += D * dout, with D either per head or per (head, hdim).
        dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
        dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
        if D_HAS_HDIM:
            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        dx += dout_res * D
    tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))

    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    # Partial ddt over this n tile's hdim columns; tiles are combined either through
    # atomics or through the per-tile slots (zeroed by the autotune pre_hook).
    ddt = tl.sum(acc * x, axis=1)
    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    if DETERMINISTIC_REDUCTION:
        tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size)
    else:
        tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)

    # if HAS_D:
    #     dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim)
    #     dout = tl.load(dout_new_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32)
    #     dD = tl.sum(x * dout, axis=0)
    #     tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N)


# Smallest BLOCK_SIZE_N among the dx kernel's autotune configs; the number of ddt tile
# slots written under DETERMINISTIC_REDUCTION is largest for this block size.
# Presumably used on the host to size that buffer -- confirm against the caller.
_CHUNK_SCAN_BWD_DX_MIN_BLOCK_N = min(
    cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_scan_bwd_dx_kernel.configs
)


# Disabling HAS_DDA_CS for now since it's much slower
@triton.autotune(
    configs=autotune_configs([
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8),
    ]),
    key=['chunk_size', 'hdim'],
)
# @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)})
# @triton.heuristics({"BLOCK_SIZE_N": lambda args: 32})
@triton.jit
def _chunk_scan_bwd_dcb_kernel(
    # Pointers to matrices
    x_ptr, dout_ptr, cb_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
    dcb_ptr, ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads, nheads_per_program, ngroups,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n,
    # Meta-parameters
    HAS_DDA_CS: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # Computes dcb for group pid_g, summed over the heads owned by split pid_s:
    #   dcb[m, n] = sum_h dt[h, n] * exp(min(dA_cs[h, m] - dA_cs[h, n], 0)) * sum_k dout[m, h, k] * x[n, h, k]
    # then keeps only m >= n and, with HAS_SEQ_IDX, pairs inside the same sequence.
    # Each split writes its own dcb slice (stride_dcb_split); presumably the caller
    # reduces over splits.
    # HAS_DDA_CS (not supported together with HAS_SEQ_IDX) also writes per-m-tile
    # ddA_cumsum partials; per the note on the decorator this path is disabled for
    # speed -- presumably callers pass HAS_DDA_CS=False.
    #
    # Launch grid: axis 0 = flattened (m tile, n tile) with the n tile varying fastest;
    # axis 1 = chunk * batch + batch index; axis 2 = split * ngroups + group.
    # NOTE(review): hdim is covered by a single BLOCK_SIZE_K tile (no k loop), so the
    # caller presumably passes BLOCK_SIZE_K >= hdim -- confirm.
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_sg = tl.program_id(axis=2)
    pid_s = pid_sg // ngroups
    pid_g = pid_sg - pid_s * ngroups
    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n

    # First head handled by this program: pid_g * (nheads // ngroups) + pid_s * nheads_per_program.
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
    if HAS_DDA_CS:
        cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + pid_g * stride_cb_head
        ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_n * stride_dt_csize
    if HAS_DDA_CS:
        cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n

    # Tiles entirely above the diagonal (every n exceeds every m) are zero after the
    # causal mask, so write zeros and exit early.
    if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
        dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split
        dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n)
        tl.store(dcb_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
        return

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    # Columns n beyond the last row of this m tile are masked out anyway, so x is only
    # loaded up to there.
    chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if HAS_DDA_CS:
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32)
    nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
    for h in range(nheads_iter):
        dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
        x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)
        dcb = tl.dot(dout, x)
        dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
        dcb *= dt_n
        dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
        dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32)
        # dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
        # The exponent is clamped at 0, so the decay factor never exceeds 1.
        dcb *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
        if HAS_DDA_CS:
            tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet")
            # Strictly-lower-triangular products, row-wise cumsum over n, summed over
            # rows; the value for column n is stored at position n + 1 and position 0
            # is set to 0.
            ddA_cs = dcb * cb
            mask = offs_m[:, None] >= offs_n[None, :] + 1
            ddA_cs = tl.where(mask, ddA_cs, 0.0)
            ddA_cs = tl.cumsum(ddA_cs, axis=1)
            ddA_cs = tl.where(mask, ddA_cs, 0.0)
            ddA_cs = tl.sum(ddA_cs, axis=0)
            tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1)
            tl.store(ddA_cumsum_ptr, 0.0)
        acc += dcb
        dout_ptrs += stride_dout_head
        x_ptrs += stride_x_head
        dt_ptrs += stride_dt_head
        dA_cumsum_ptr += stride_dA_cs_head
        if HAS_DDA_CS:
            ddA_cumsum_ptr += stride_ddA_cs_head
            ddA_cumsum_ptrs += stride_ddA_cs_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if HAS_SEQ_IDX:
        # Different fill values (-1 vs -2) keep out-of-range rows and columns from
        # matching each other.
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
        seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
        acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
    mask = offs_m[:, None] >= offs_n[None, :]
    acc = tl.where(mask, acc, 0.0)
    dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split
    dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n)
    tl.store(dcb_ptrs, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))


# Not numerically stable and should not be used. Leaving here for reference.
@triton.autotune(
    configs=autotune_configs([
        triton.Config({'BLOCK_SIZE_M': 32}),
        triton.Config({'BLOCK_SIZE_M': 64}),
        triton.Config({'BLOCK_SIZE_M': 128}),
        triton.Config({'BLOCK_SIZE_M': 256}),
    ]),
    key=["chunk_size", "hdim"],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_unstable_kernel(
    # Pointers to matrices
    dout_ptr, out_ptr, dt_ptr, ddt_ptr, x_ptr, D_ptr,
    ddA_cumsum_ptr, dD_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_D_head,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
    stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    SUBTRACT_DDTDT: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # Writes, for each chunk row m of this m tile,
    #   ddA_cumsum[m] = sum_p dout[m, p] * (out[m, p] - D * x[m, p])   (D term only with HAS_D)
    # minus dt[m] * ddt[m] when SUBTRACT_DDTDT. With HAS_D it also writes this m tile's
    # partial dD (per hdim column when D_HAS_HDIM, else a scalar).
    # The kernel name marks this formulation as the "unstable" variant.
    # NOTE(review): hdim is covered by a single BLOCK_SIZE_N tile (no loop over hdim), so
    # the caller presumably passes BLOCK_SIZE_N >= hdim -- confirm.
    # Launch grid: axis 0 = m tile; axis 1 = chunk * batch + batch index; axis 2 = head.
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)

    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
    if HAS_D:
        x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
        # Each m tile writes its own dD partial (pid_m * stride_dD_csize).
        dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
    out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)
    if HAS_D:
        x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
        if D_HAS_HDIM:
            dD_ptrs = dD_ptr + offs_n * stride_dD_hdim

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    if HAS_D:
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
        if D_HAS_HDIM:
            dD = tl.sum(dout * x, axis=0)
            tl.store(dD_ptrs, dD, mask=offs_n < hdim)
            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
        else:
            dD = tl.sum(dout * x)
            tl.store(dD_ptr, dD)
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        # Remove the D * x skip term from out before forming ddA_cs.
        out -= x * D
    ddA_cs = tl.sum(dout * out, axis=1)
    if SUBTRACT_DDTDT:
        dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
        ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
        ddA_cs -= dt * ddt
    tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size)


@triton.autotune(
    configs=autotune_configs([
        # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8),
    ]),
    key=['chunk_size', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_stable_kernel_old(
    # Pointers to matrices
    x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr,
    ddAcs_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
    stride_ddAcs_batch, stride_ddAcs_chunk, stride_ddAcs_head, stride_ddAcs_csize_m, stride_ddAcs_csize_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # Older variant of the "stable" ddA_cumsum kernel, on a 2-D grid over
    # (m tile, n tile). It forms
    #   acc[m, j] = cb[m, j] * dt[j] * exp(min(dA_cs[m] - dA_cs[j], 0)) * sum_k dout[m, k] * x[j, k]
    # restricted to m > j, takes a row-wise cumsum over j, re-masks, sums over m, and
    # stores the value for column j at position j + 1 (position 0 is set to 0).
    # NOTE(review): the cumsum is NOT carried across n tiles (no running row sum), so
    # the result looks complete only when one BLOCK_SIZE_N tile spans the chunk;
    # BLOCK_SIZE_N and BLOCK_SIZE_K are not autotuned here and must come from the
    # caller -- confirm whether this kernel is still used.
    # Launch grid: axis 0 = flattened (m tile, n tile) with the n tile varying fastest;
    # axis 1 = chunk * batch + batch index; axis 2 = head.
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n

    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_n * stride_dt_csize
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)
    # Doing a matmul loop with cumsum later on will cause Triton to crash
    # Instead we do just one big matmul
    # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # for k in range(0, hdim, BLOCK_SIZE_K):
    #     dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0)
    #     x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0)
    #     acc += tl.dot(dout, x)
    #     dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim
    #     x_ptrs += BLOCK_SIZE_K * stride_x_hdim
    dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
    x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)
    acc = tl.dot(dout, x)
    cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32)
    acc *= cb
    dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
    acc *= dt_n
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
    # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
    acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
    mask = offs_m[:, None] >= offs_n[None, :] + 1
    acc = tl.where(mask, acc, 0.0)
    acc = tl.cumsum(acc, axis=1)
    acc = tl.where(mask, acc, 0.0)
    ddA_cs = tl.sum(acc, axis=0)
    ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n
    # Shifted by one position: the value for column j lands at j + 1.
    tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1)
    tl.store(ddAcs_ptr, 0.0)

    # Earlier looped formulation, kept for reference:
    # offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 64)
    # offs_k = tl.arange(0, BLOCK_SIZE_K)
    # dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    # x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
    # dt_ptrs = dt_ptr + offs_n * stride_dt_csize
    # cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)

    # chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    # chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)
    # rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
    # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    # ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m
    # ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n
    # for n in range(0, chunk_size_limit_n, 64):
    #     x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n - n), other=0.0)
    #     acc = tl.dot(dout, x)
    #     cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - n), other=0.0).to(tl.float32)
    #     acc *= cb
    #     dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32)
    #     acc *= dt_n
    #     dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32)
    #     acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
    #     mask = offs_m[:, None] >= offs_n[None, :] + 1 + n
    #     acc = tl.where(mask, acc, 0.0)
    #     acc = tl.cumsum(acc, axis=1)
    #     acc = tl.where(mask, acc, 0.0)
    #     ddA_cs = tl.sum(acc, axis=0)
    #     tl.store(ddAcs_ptrs, ddA_cs, mask=offs_n < chunk_size - 1 - n)
    # # tl.store(ddAcs_ptr, 0.0)


@triton.autotune(
    configs=autotune_configs([
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
    ]),
    key=['chunk_size', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_stable_kernel(
    # Pointers to matrices
    x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr,
    ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # For the rows m of this program's m tile it forms, over columns j in BLOCK_SIZE_N steps,
    #   acc[m, j] = cb[m, j] * dt[j] * exp(min(dA_cs[m] - dA_cs[j], 0)) * sum_k dout[m, k] * x[j, k]
    # restricted to m > j, and writes at position t (1 <= t < chunk_size)
    #   sum_{m in tile, m >= t} sum_{j < t} acc[m, j]
    # with 0 at position 0. `rowsum` carries each row's prefix sum across column tiles.
    # Each m tile writes its own row (pid_m * stride_ddA_cs_csize_m); per the comment on
    # the zero-fill loop, the rows are summed together afterwards.
    # Launch grid: axis 0 = m tile; axis 1 = chunk * batch + batch index; axis 2 = head.
    # NOTE(review): hdim is covered by a single BLOCK_SIZE_K tile (no k loop), so the
    # caller presumably passes BLOCK_SIZE_K >= hdim -- confirm.
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)

    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_n * stride_dt_csize
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
    ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n
    # Position 0 is always 0; results for column j are written at position j + 1.
    tl.store(ddA_cumsum_ptr, 0.0)

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    # Columns j >= (pid_m + 1) * BLOCK_SIZE_M cannot satisfy m > j for any row in this tile.
    # Actually hi is (pid_m + 1) * BLOCK_SIZE_M - 1 but subtracting 1 makes it slower
    lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M
    # lo, hi = 0, chunk_size
    for start_n in range(lo, hi, BLOCK_SIZE_N):
        start_n = tl.multiple_of(start_n, BLOCK_SIZE_N)
        # Doing a matmul loop with cumsum later on will cause Triton to crash
        # Instead we do just one big matmul
        # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        # for k in range(0, hdim, BLOCK_SIZE_K):
        #     dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0)
        #     x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0)
        #     acc += tl.dot(dout, x)
        #     dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim
        #     x_ptrs += BLOCK_SIZE_K * stride_x_hdim
        # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)
        x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0)
        acc = tl.dot(dout, x)
        dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
        acc *= dt_n
        # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j]
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32)
        acc *= cb
        dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
        # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
        # The exponent is clamped at 0, so the decay factor never exceeds 1.
        acc *= tl.exp(tl.minimum((dA_cs_m[:, None] - dA_cs_n[None, :]), 0.0))
        mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1
        acc = tl.where(mask, acc, 0.0)
        # Row-wise prefix sum continued from the previous column tiles via `rowsum`.
        rowsum_new = rowsum + tl.sum(acc, axis=1)
        acc = rowsum[:, None] + tl.cumsum(acc, axis=1)
        rowsum = rowsum_new
        acc = tl.where(mask, acc, 0.0)
        ddA_cs = tl.sum(acc, axis=0)
        tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1)
        x_ptrs += BLOCK_SIZE_N * stride_x_seqlen
        dt_ptrs += BLOCK_SIZE_N * stride_dt_csize
        cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n
        ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n

    # Need to zero out the rest, since we'll be summing the rows together
    # NOTE(review): this loop assumes ddAcs_ptrs now points at column `hi`. That holds only
    # when (pid_m + 1) * BLOCK_SIZE_M is a multiple of BLOCK_SIZE_N, i.e. when
    # BLOCK_SIZE_M % BLOCK_SIZE_N == 0. Every active config above satisfies this (those with
    # BLOCK_SIZE_N > BLOCK_SIZE_M are commented out); re-enabling them would make these
    # stores start past column `hi` and run beyond chunk_size.
    for start_n in range(hi, chunk_size, BLOCK_SIZE_N):
        tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1)
        ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n


@triton.autotune(
    configs=autotune_configs([
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
    ]),
    key=['chunk_size', 'dstate', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_prev_kernel(
    # Pointers to matrices
    dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr,
    ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, dstate, hdim,
    batch, seqlen, nchunks, nheads_ngroups_ratio,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate,
    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, # Meta-parameters HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) 
prev_states = prev_states.to(dout_ptrs.dtype.element_ty) acc = tl.dot(dout, prev_states) c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) ddA_cs = tl.sum(acc * c, axis=1) dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) if not HAS_SEQ_IDX: scale = tl.exp(dA_cs_m) if HAS_SEQ_IDX: seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) ddA_cs *= scale offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = C.shape assert nheads % ngroups == 0 assert C.shape == (batch, seqlen, ngroups, dstate) assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) if z is not None: assert z.shape == x.shape if D is not None: assert D.shape == (nheads, headdim) or D.shape == (nheads,) assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert states.shape == (batch, nchunks, nheads, headdim, dstate) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) # Allocates output. 
out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
    if z is not None:
        out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
        # Only out's strides are passed to the kernel, so out_x must share its layout.
        assert out_x.stride() == out.stride()
    else:
        out_x = None
    grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
                         batch * nchunks, nheads)
    z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))
                 if z is not None else (0, 0, 0, 0))
    _chunk_scan_fwd_kernel[grid](
        cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,
        chunk_size, headdim, dstate,
        batch, seqlen, nheads // ngroups,
        cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
        x.stride(0), x.stride(1), x.stride(2), x.stride(3),
        z_strides[0], z_strides[1], z_strides[2], z_strides[3],
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        # dt / dA_cumsum are laid out (batch, nheads, nchunks, chunk_size); their
        # strides are passed in (batch, chunk, head, csize) order.
        dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
        dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
        *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
        C.stride(0), C.stride(1), C.stride(2), C.stride(3),
        states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
        D.stride(0) if D is not None else 0,
        # NOTE(review): positional flag whose meaning is fixed by
        # _chunk_scan_fwd_kernel's signature (not visible here).
        True,
        D is not None,
        D.dim() == 2 if D is not None else True,
        BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        HAS_Z=z is not None,
        HAS_SEQ_IDX=seq_idx is not None,
        IS_TRITON_22=TRITON_22,
    )
    return out, out_x


def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None):
    """Experimental variant of `_chunk_scan_fwd` that also receives ``B``.

    Launches `_chunk_scan_fwd_kernel_wip` with a fixed ``BLOCK_SIZE_M=128``. Its
    grid tiles only headdim (axis 0), (batch, chunk) pairs (axis 1) and heads
    (axis 2); the chunk dimension is not tiled by the grid. Returns
    ``(out, out_x)`` with ``out_x`` set only when ``z`` is given.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = C.shape
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    assert B.shape == C.shape
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert states.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    # Allocates output.
    out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
    if z is not None:
        out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
        # Only out's strides are passed to the kernel, so out_x must share its layout.
        assert out_x.stride() == out.stride()
    else:
        out_x = None
    grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads)
    z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))
                 if z is not None else (0, 0, 0, 0))
    _chunk_scan_fwd_kernel_wip[grid](
        cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, B, states, D,
        chunk_size, headdim, dstate,
        batch, seqlen, nheads // ngroups,
        cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
        x.stride(0), x.stride(1), x.stride(2), x.stride(3),
        z_strides[0], z_strides[1], z_strides[2], z_strides[3],
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
        dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
        *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
        C.stride(0), C.stride(1), C.stride(2), C.stride(3),
        B.stride(0), B.stride(1), B.stride(2), B.stride(3),
        states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
        D.stride(0) if D is not None else 0,
        D is not None,
        D.dim() == 2 if D is not None else True,
        BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        BLOCK_SIZE_M=128,
        HAS_Z=z is not None,
        HAS_SEQ_IDX=seq_idx is not None,
    )
    return out, out_x


def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False):
    """Backward pass through the ``z`` gating of the chunk scan.

    Launches `_chunk_scan_bwd_dz_kernel`. The gating is presumably the
    ``out * silu(z)`` form used by `chunk_scan_ref`; the kernel itself is not
    visible here.

    Returns ``(dz, dout_x, dD, ddA_cumsum)`` when ``has_ddAcs`` is true, else
    ``(dz, dout_x, dD)``. ``dD`` is ``None`` when ``D`` is ``None``. When
    ``recompute_output`` is set, the recomputed ``outz`` is appended as the last
    element. A preallocated ``dz`` may be passed in and is filled in place.
    """
    batch, seqlen, nheads, headdim = x.shape
    assert z.shape == x.shape
    assert out.shape == x.shape
    assert dout.shape == out.shape
    nchunks = math.ceil(seqlen / chunk_size)
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
        assert D.stride(-1) == 1
    if has_ddAcs:
        ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
    if D is not None:
        # One partial dD slab per row tile, sized for tiles of at least 32 rows.
        # NOTE(review): assumes no autotune config of _chunk_scan_bwd_dz_kernel uses
        # BLOCK_SIZE_M < 32; otherwise it would index past dim 0 — confirm.
        BLOCK_SIZE_min = 32
        dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
                         headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
    else:
        dD = None
    if dz is not None:
        assert dz.shape == z.shape
    else:
        dz = torch.empty_like(z)
    if recompute_output:
        outz = torch.empty_like(x)
    dout_x = torch.empty_like(dout)
    dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
                  if D is not None else (0, 0, 0, 0, 0))
    grid_dz = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)
    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_dz_kernel[grid_dz](
            dout, out, z, x, D, outz if recompute_output else None,
            dz, dout_x, dD, ddA_cumsum if has_ddAcs else None,
            chunk_size, headdim,
            batch, seqlen,
            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
            out.stride(0), out.stride(1), out.stride(2), out.stride(3),
            z.stride(0), z.stride(1), z.stride(2), z.stride(3),
            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
            D.stride(0) if D is not None else 0,
            *((outz.stride(0), outz.stride(1), outz.stride(2), outz.stride(3)) if recompute_output else (0, 0, 0, 0)),
            dz.stride(0), dz.stride(1), dz.stride(2), dz.stride(3),
            dout_x.stride(0), dout_x.stride(1), dout_x.stride(2), dout_x.stride(3),
            # dD's tile axis (dim 0) is passed after batch/chunk/head, before hdim.
            dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
            *((ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3))
              if has_ddAcs else (0, 0, 0, 0)),
            D is not None,
            D.dim() == 2 if D is not None else True,
            has_ddAcs,
            BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16),
            RECOMPUTE_OUTPUT=recompute_output,
        )
    if D is not None:
        # Only the first n_valid_blocks slabs were written by the selected config;
        # reduce them, then over batch and chunk.
        BLOCK_SIZE_actual = _chunk_scan_bwd_dz_kernel.best_config.kwargs["BLOCK_SIZE_M"]
        n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
        dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
        if D.dim() == 1:
            dD = rearrange(dD, "h 1 -> h")
    return_vals = (dz, dout_x, dD, ddA_cumsum) if has_ddAcs else (dz, dout_x, dD)
    return return_vals if not recompute_output else (*return_vals, outz)


def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None):
    """Gradient w.r.t. the per-chunk previous states via `_chunk_scan_bwd_dstates_kernel`.

    Returns a tensor of shape (batch, nchunks, nheads, headdim, dstate) whose
    dtype is ``dtype`` if given, else ``C.dtype``.
    """
    batch, seqlen, nheads, headdim = dout.shape
    _, _, nchunks, chunk_size = dA_cumsum.shape
    _, _, ngroups, dstate = C.shape
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    dtype = C.dtype if dtype is None else dtype
    dprev_states = torch.empty(batch, nchunks, nheads, headdim, dstate, device=C.device, dtype=dtype)
    # Axis 0 tiles (headdim, dstate); axis 1 is (batch, chunk); axis 2 is head.
    grid_dstates = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
                        batch * nchunks, nheads)
    with torch.cuda.device(C.device.index):
        _chunk_scan_bwd_dstates_kernel[grid_dstates](
            dout, C, dprev_states, dA_cumsum, seq_idx,
            headdim, dstate, chunk_size,
            batch, seqlen, nchunks, nheads // ngroups,
            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
            C.stride(0), C.stride(1), C.stride(2), C.stride(3),
            dprev_states.stride(0), dprev_states.stride(1), dprev_states.stride(2), dprev_states.stride(3), dprev_states.stride(4),
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            HAS_SEQ_IDX=seq_idx is not None,
        )
    return dprev_states


def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1):
    batch, nchunks, nheads, headdim, dstate = prev_states.shape
    _, seqlen, _, _ = dout.shape
    _, _, _, chunk_size = dA_cumsum.shape
    assert prev_states.shape == (batch, nchunks, nheads,
headdim, dstate)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert dout.shape == (batch, seqlen, nheads, headdim)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    deterministic = use_deterministic_mode()
    if C is not None:
        assert C.shape == (batch, seqlen, ngroups, dstate)
        C_strides = (C.stride(0), C.stride(1), C.stride(2), C.stride(3))
        # When C is given, ddA_cumsum_prev is also produced. It goes through a tile
        # workspace with one slot per dstate tile of at least
        # _CHUNK_SCAN_BWD_DC_MIN_BLOCK_N columns. How the slots are reduced in
        # deterministic vs. atomic mode is decided by the determinism helpers
        # (not visible here).
        tile_count = math.ceil(dstate / _CHUNK_SCAN_BWD_DC_MIN_BLOCK_N)
        ddA_cumsum_prev, stride_ddA_tile = alloc_tile_workspace(
            (batch, nheads, nchunks, chunk_size),
            tile_count,
            torch.float32,
            dout.device,
            deterministic,
            zero_init=True,
        )
        ddA_cumsum_prev_strides = (
            ddA_cumsum_prev.stride(0),
            ddA_cumsum_prev.stride(2),
            ddA_cumsum_prev.stride(1),
            ddA_cumsum_prev.stride(3),
        )
    else:
        C_strides = (0, 0, 0, 0)
        ddA_cumsum_prev = None
        ddA_cumsum_prev_strides = (0, 0, 0, 0)
        stride_ddA_tile = 0
    # Each program handles nheads_per_program heads of one group. The nsplits
    # partial results are stored along dC's dim 2 and summed after the launch.
    nheads_ngroups_ratio = nheads // ngroups
    sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count
    nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)
    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
    dC = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=dout.device, dtype=torch.float32)
    grid_dc = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
                        batch * nchunks, nsplits * ngroups)
    with torch.cuda.device(dout.device.index):
        _chunk_scan_bwd_dc_kernel[grid_dc](
            dout, prev_states, C, dA_cumsum, seq_idx, dC, ddA_cumsum_prev,
            chunk_size, dstate, headdim,
            batch, seqlen, nheads, nheads_per_program, ngroups,
            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
            prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4),
            *C_strides,
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), dC.stride(4),
            *ddA_cumsum_prev_strides,
            stride_ddA_tile,
            HAS_DDA_CS=ddA_cumsum_prev is not None,
            HAS_SEQ_IDX=seq_idx is not None,
            DETERMINISTIC_REDUCTION=deterministic,
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    dC = dC.sum(2)
    if ddA_cumsum_prev is not None:
        ddA_cumsum_prev = finalize_tile_workspace(ddA_cumsum_prev, deterministic)
    return dC if C is None else (dC, ddA_cumsum_prev)


def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1):
    """Gradient w.r.t. the intra-chunk CB matrix via `_chunk_scan_bwd_dcb_kernel`.

    Returns ``dcb`` of shape (batch, nchunks, ngroups, chunk_size, chunk_size)
    in float32. When ``CB`` is given, returns ``(dcb, ddA_cumsum)``, where
    ``ddA_cumsum`` has shape (batch, nheads, nchunks, chunk_size).
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dout.shape == x.shape
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if CB is not None:
        assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
        CB_strides = (CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(3), CB.stride(4))
        # One ddA_cumsum partial row per BLOCK_SIZE_M row tile (dim 3), sized for
        # tiles of at least 16 rows; the used slabs are summed after the launch.
        BLOCK_SIZE_M_min = 16
        ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
                                 chunk_size, device=x.device, dtype=torch.float32)
        ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4))
    else:
        CB_strides = (0, 0, 0, 0, 0)
        ddA_cumsum = None
        ddA_cumsum_strides = (0, 0, 0, 0, 0)
    # Same head-splitting scheme as _chunk_scan_bwd_dC: nsplits partials along dim 2.
    nheads_ngroups_ratio = nheads // ngroups
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)
    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
    dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32)
    grid_dcb = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
                        batch * nchunks, nsplits * ngroups)
    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_dcb_kernel[grid_dcb](
            x, dout, CB, dt, dA_cumsum, seq_idx, dcb, ddA_cumsum,
            chunk_size, headdim,
            batch, seqlen, nheads, nheads_per_program, ngroups,
            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
            *CB_strides,
            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            dcb.stride(0), dcb.stride(1), dcb.stride(2), dcb.stride(3), dcb.stride(4), dcb.stride(5),
            *ddA_cumsum_strides,
            HAS_DDA_CS=ddA_cumsum is not None,
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    dcb = dcb.sum(2)
    if ddA_cumsum is not None:
        BLOCK_SIZE_M_actual = _chunk_scan_bwd_dcb_kernel.best_config.kwargs["BLOCK_SIZE_M"]
        n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual
        ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)
    return dcb if CB is None else (dcb, ddA_cumsum)


def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None):
    """Gradients w.r.t. ``x`` and ``dt`` via `_chunk_scan_bwd_dx_kernel`.

    Returns ``(dx, ddt)``. ``ddt`` goes through a tile workspace with one slot
    per headdim tile of at least _CHUNK_SCAN_BWD_DX_MIN_BLOCK_N columns, and is
    finalized by `finalize_tile_workspace`. The dD computation is commented out,
    so no dD is returned here.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    ngroups = cb.shape[2]
    assert nheads % ngroups == 0
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dout.shape == x.shape
    # if D is not None:
    #     BLOCK_SIZE_M_min = 32
    #     dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_M_min), batch, nchunks, nheads, headdim, device=D.device, dtype=torch.float32)
    # else:
    #     dD = None
    dx = torch.empty_like(x)
    deterministic = use_deterministic_mode()
    tile_count = math.ceil(headdim / _CHUNK_SCAN_BWD_DX_MIN_BLOCK_N)
    ddt, stride_ddt_tile = alloc_tile_workspace(
        (batch, nheads, nchunks, chunk_size),
        tile_count,
        torch.float32,
        dout.device,
        deterministic,
        zero_init=True,
    )
    grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
                        batch * nchunks, nheads)
    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_dx_kernel[grid_dx](
            x, cb, dout, dt, dA_cumsum, D, dx, ddt, # dD,
            chunk_size, headdim,
            batch, seqlen, nheads // ngroups,
            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
            # cb's last two strides are passed swapped (stride(-1) before stride(-2)).
            # NOTE(review): presumably the kernel reads cb transposed — confirm
            # against _chunk_scan_bwd_dx_kernel's signature.
            cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(-1), cb.stride(-2),
            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            D.stride(0) if D is not None else 0,
            dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
            ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), stride_ddt_tile,
            # dD.stride(1) if dD is not None else 0, dD.stride(2) if dD is not None else 0, dD.stride(3) if dD is not None else 0, dD.stride(4) if dD is not None else 0, dD.stride(0) if dD is not None else 0,
            D is not None,
            D.dim() == 2 if D is not None else True,
            DETERMINISTIC_REDUCTION=deterministic,
        )
    # if D is not None:
    #     BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"]
    #     n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
    #     dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
    ddt = finalize_tile_workspace(ddt, deterministic)
    return dx, ddt


def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True):
    """Not numerically stable and should not be used. Leaving here for reference.
"""
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert ddt.shape == dt.shape
    assert out.shape == x.shape
    assert dout.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    ddA_cumsum = torch.empty_like(dt)
    grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)
    if D is not None:  # Triton gives wrong results if we write to the same location
        # Hence one partial dD slab per row tile (sized for BLOCK_SIZE_M >= 32),
        # reduced on the host after the launch.
        BLOCK_SIZE_min = 32
        dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
                         headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
    else:
        dD = None
    dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
                  if D is not None else (0, 0, 0, 0, 0))
    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_ddAcs_unstable_kernel[grid_ddtcs](
            dout, out, dt, ddt, x, D, ddA_cumsum, dD,
            chunk_size, headdim,
            batch, seqlen,
            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
            out.stride(0), out.stride(1), out.stride(2), out.stride(3),
            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
            ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
            D.stride(0) if D is not None else 0,
            ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
            dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
            D is not None,
            D.dim() == 2 if D is not None else True,
            subtract_ddtdt,
            BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16),
        )
    if D is not None:
        BLOCK_SIZE_actual = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs["BLOCK_SIZE_M"]
        n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
        dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
        if D.dim() == 1:
            dD = rearrange(dD, "h 1 -> h")
    return ddA_cumsum, dD


def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout,
                                     cb):
    """Older launcher for `_chunk_scan_bwd_ddAcs_stable_kernel_old`.

    Passes BLOCK_SIZE_N covering the whole chunk (next power of two, at least 16).
    Each row tile writes a partial ddA_cumsum row (dim 3, sized for
    BLOCK_SIZE_M >= 16). The slabs used by the selected config are summed and
    returned as a (batch, nheads, nchunks, chunk_size) tensor.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dout.shape == x.shape
    assert dA_cumsum.shape == dt.shape
    ngroups = cb.shape[2]
    assert nheads % ngroups == 0
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    BLOCK_SIZE_M_min = 16
    ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
                             chunk_size, device=x.device, dtype=torch.float32)
    grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)
    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_ddAcs_stable_kernel_old[grid_ddtcs](
            x, dout, dt, dA_cumsum, cb, ddA_cumsum,
            chunk_size, headdim,
            batch, seqlen, nheads // ngroups,
            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
            ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4),
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
            BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16),
        )
    BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs["BLOCK_SIZE_M"]
    n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual
    ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)
    return ddA_cumsum


def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb):
    """Intra-chunk contribution to the gradient of dA_cumsum.

    Launches `_chunk_scan_bwd_ddAcs_stable_kernel`, where each BLOCK_SIZE_M row
    tile writes its own partial row into dim 3 of a scratch buffer (sized for
    tiles of at least 32 rows). The slabs used by the selected config are summed
    into a (batch, nheads, nchunks, chunk_size) float32 tensor.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dout.shape == x.shape
    assert dA_cumsum.shape == dt.shape
    ngroups = cb.shape[2]
    assert nheads % ngroups == 0
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    BLOCK_SIZE_M_min = 32
    ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
                             chunk_size, device=x.device, dtype=torch.float32)
    grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)
    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs](
            x, dout, dt, dA_cumsum, cb, ddA_cumsum,
            chunk_size, headdim,
            batch, seqlen, nheads // ngroups,
            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
            ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4),
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"]
    n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual
    ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)
    return ddA_cumsum


def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None):
    """Contribution of ``prev_states`` to the gradient of dA_cumsum.

    Returns a float32 tensor of shape (batch, nheads, nchunks, chunk_size). It
    is allocated with ``torch.empty`` because the autotune pre_hook of
    `_chunk_scan_bwd_ddAcs_prev_kernel` zeroes ``ddA_cumsum_ptr`` before the
    kernel accumulates into it atomically.
    """
    batch, nchunks, nheads, headdim, dstate = prev_states.shape
    _, seqlen, _, _ = dout.shape
    _, _, _, chunk_size = dA_cumsum.shape
    assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert dout.shape == (batch, seqlen, nheads, headdim)
    ngroups = C.shape[2]
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
    grid_ddAcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
                         batch * nchunks, nheads)
    with torch.cuda.device(dout.device.index):
        _chunk_scan_bwd_ddAcs_prev_kernel[grid_ddAcs](
            dout, prev_states, C, dA_cumsum, seq_idx, ddA_cumsum_prev,
            chunk_size, dstate, headdim,
            batch, seqlen, nchunks, nheads // ngroups,
            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
            prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4),
            C.stride(0), C.stride(1), C.stride(2), C.stride(3),
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3),
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    return ddA_cumsum_prev


class ChunkScanFn(torch.autograd.Function):
    """Autograd wrapper around the Triton chunk-scan forward/backward launchers above."""

    @staticmethod
    def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
        # Check constraints.
        batch, seqlen, nheads, headdim = x.shape
        _, _, ngroups, dstate = B.shape
        assert B.shape == (batch, seqlen, ngroups, dstate)
        _, _, nchunks, chunk_size = dt.shape
        # This path requires seqlen to be an exact multiple of chunk_size.
        assert seqlen == nchunks * chunk_size
        assert C.shape == B.shape
        if z is not None:
            assert z.shape == x.shape
        if D is not None:
            assert D.shape == (nheads, headdim) or D.shape == (nheads,)
        assert dt.shape == (batch, nheads, nchunks, chunk_size)
        assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
        assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if x.stride(-1) != 1 and x.stride(1) != 1:  # Either M or K dimension should be contiguous
            x = x.contiguous()
        if z is not None and z.stride(-1) != 1 and z.stride(1) != 1:  # Either M or K dimension should be contiguous
            z = z.contiguous()
        if D is not None and D.stride(-1) != 1:
            D = D.contiguous()
        CB = _bmm_chunk_fwd(C, B, chunk_size)
        out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states,
D=D, z=z)
        # When z is given, the pre-gating output out_x is saved in the "out" slot;
        # backward hands it to _chunk_scan_bwd_dz as `out`.
        ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, prev_states, D, z)
        return out

    @staticmethod
    def backward(ctx, dout):
        # Returns gradients in the order of forward's inputs:
        # (B, C, x, dt, dA_cumsum, prev_states, D, z).
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensors
        batch, seqlen, nheads, headdim = x.shape
        _, _, nchunks, chunk_size = dt.shape
        _, _, ngroups, dstate = B.shape
        assert dout.shape == (batch, seqlen, nheads, headdim)
        if z is not None:
            # dout is replaced by the gradient flowing through the gating, which
            # all remaining terms below consume.
            dz, dout, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, D=D)
        else:
            dz = None
        dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, dtype=prev_states.dtype)
        dC = _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, ngroups=ngroups)
        dC = dC.to(C.dtype)
        dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, ngroups=ngroups)
        dCB = dCB.to(CB.dtype)
        # CB was built by _bmm_chunk_fwd(C, B, ...). Back-propagate dCB into B, and
        # into C on top of the prev-state contribution already in dC (residual).
        dB = _bmm_chunk_bwd(C, dCB)
        dC = _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC)
        dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, dout, D=D)
        # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
        # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
        if z is not None:
            ddA_cumsum -= ddt * dt
        else:  # If z is not None, we already calculated ddA_cumsum and dD when computing dz
            ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=D)
        ddA_cumsum = ddA_cumsum.to(dA_cumsum.dtype)
        return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz


def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
    """
    prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1.
Argument:
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        prev_states: (batch, nchunks, nheads, headdim, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    return ChunkScanFn.apply(B, C, x, dt, dA_cumsum, prev_states, D, z)


def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
    """
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        prev_states: (batch, nchunks, nheads, headdim, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    # Pure-PyTorch reference (einsum based) for the chunk scan.
    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    assert B.shape == (batch, seqlen, ngroups, dstate)
    _, _, nchunks, chunk_size = dt.shape
    assert seqlen == nchunks * chunk_size
    assert C.shape == B.shape
    # Expand groups to heads so every head has its own B and C.
    B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
    C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
    CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
                      rearrange(B, "b (c s) h n -> b c s h n", c=nchunks))
    # (batch, nheads, nchunks, chunksize, chunksize)
    dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
    # Intra-chunk term: scores weighted by exp(dA_cumsum[l] - dA_cumsum[s]),
    # restricted to s <= l by the lower-triangular mask.
    decay = torch.exp(dt_segment_sum)
    scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s")
    causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
    scores_decay = scores_decay.masked_fill(~causal_mask, 0)
    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
                       rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
    # Inter-chunk term: the state carried into each chunk, read out through C and
    # decayed by exp(dA_cumsum[l]).
    state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
    out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
                            prev_states.to(C.dtype)) * state_decay_out
    out = out + out_prev
    out = rearrange(out, "b c l h p -> b (c l) h p")
    if D is not None:
        if D.dim() == 1:
            D = rearrange(D, "h -> h 1")
        out = out + x * D
    return out if z is None else out * F.silu(z)

================================================
FILE: mamba_ssm/ops/triton/ssd_chunk_state.py
================================================
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""

import math
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat

from mamba_ssm.ops.triton.softplus import softplus
from mamba_ssm.utils.determinism import (
    alloc_tile_workspace,
    finalize_tile_workspace,
    use_deterministic_mode,
    autotune_configs,
)


def init_to_zero(names):
    # Build an autotune pre_hook that zeroes, in place, each named kernel argument
    # that is not None. NOTE(review): nargs[name] raises KeyError if a name is not
    # an argument of the kernel being tuned.
    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]


@triton.autotune(
    configs=autotune_configs([
        triton.Config({'BLOCK_SIZE_H': 1}),
        triton.Config({'BLOCK_SIZE_H': 2}),
        triton.Config({'BLOCK_SIZE_H': 4}),
        triton.Config({'BLOCK_SIZE_H': 8}),
        triton.Config({'BLOCK_SIZE_H': 16}),
        triton.Config({'BLOCK_SIZE_H': 32}),
        triton.Config({'BLOCK_SIZE_H': 64}),
    ]),
    key=['chunk_size', 'nheads'],
)
@triton.jit
def _chunk_cumsum_fwd_kernel(
    # Pointers to matrices
    dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,
    # Matrix dimension
    batch, seqlen, nheads, chunk_size,
    dt_min, dt_max,
    # Strides
    stride_dt_batch, stride_dt_seqlen, stride_dt_head,
    stride_A_head,
    stride_dt_bias_head,
    stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
):
    # Grid: axis 0 = batch, axis 1 = chunk, axis 2 = block of BLOCK_SIZE_H heads.
    # Each program handles a (BLOCK_SIZE_H, BLOCK_SIZE_CHUNK) tile of dt.
    pid_b = tl.program_id(axis=0)
    pid_c = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
    dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk

    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
    offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
    dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
    A_ptrs = A_ptr + offs_h * stride_A_head
    dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)
    dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)
    # Number of valid positions in this chunk (the last chunk may be partial).
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
    if HAS_DT_BIAS:
        dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
        dt += dt_bias[:, None]
    if DT_SOFTPLUS:
        # softplus only where dt <= 20; larger values are passed through unchanged.
        dt = tl.where(dt <= 20.0, softplus(dt), dt)
    # As of Triton 2.2.0, tl.clamp is not available yet
    # dt = tl.clamp(dt, dt_min, dt_max)
    dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
    # Zero dt outside valid heads/positions so padded entries add nothing to the cumsum.
    dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
    # Stores are masked by chunk_size (not chunk_size_limit), so the padded tail of
    # a partial last chunk is written with dt = 0.
    tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
    A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
    dA = dt * A[:, None]
    # Running sum of dt * A along the chunk axis.
    dA_cs = tl.cumsum(dA, axis=1)
    tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))


@triton.autotune(
    configs=autotune_configs([
        triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
        triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
        triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
        triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])), ]), key=['chunk_size', 'nheads'], ) @triton.jit def _chunk_cumsum_bwd_kernel( # Pointers to matrices ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr, ddt_ptr, dA_ptr, ddt_bias_ptr, # Matrix dimensions batch, seqlen, nheads, chunk_size, dt_min, dt_max, # Strides stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize, stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize, stride_dt_batch, stride_dt_seqlen, stride_dt_head, stride_A_head, stride_dt_bias_head, stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head, stride_dA_batch, stride_dA_chunk, stride_dA_head, stride_ddt_bias_batch, stride_ddt_bias_chunk, stride_ddt_bias_head, # Meta-parameters DT_SOFTPLUS: tl.constexpr, HAS_DT_BIAS: tl.constexpr, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, DETERMINISTIC_REDUCTION: tl.constexpr, ): pid_b = tl.program_id(axis=0) pid_c = tl.program_id(axis=1) pid_h = tl.program_id(axis=2) ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize) ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize) dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen) ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen) A_ptrs = 
A_ptr + offs_h * stride_A_head chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) ddt = ddA * A[:, None] + ddt_out dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) if HAS_DT_BIAS: dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32) dt += dt_bias[:, None] if DT_SOFTPLUS: dt_presoftplus = dt dt = tl.where(dt <= 20.0, softplus(dt), dt) clamp_mask = (dt < dt_min) | (dt > dt_max) # As of Triton 2.2.0, tl.clamp is not available yet # dt = tl.clamp(dt, dt_min, dt_max) dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0) ddt = tl.where(clamp_mask, 0.0, ddt) if DT_SOFTPLUS: ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt) tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) dA = tl.sum(ddA * dt, axis=1) dA_ptr += pid_b * stride_dA_batch + pid_c * stride_dA_chunk if DETERMINISTIC_REDUCTION: tl.store(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) else: tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads) if HAS_DT_BIAS: ddt_bias = tl.sum(ddt, axis=1) ddt_bias_ptr += pid_b * stride_ddt_bias_batch + pid_c * stride_ddt_bias_chunk if DETERMINISTIC_REDUCTION: tl.store(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) else: tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads) 
# Chunk-end state for each (batch, chunk, head):
#   states[p, n] = sum_k x[k, p] * B[k, n] * dt[k] * exp(min(dA_cs_last - dA_cumsum[k], 0))
# over the valid timesteps k of the chunk (p over hdim, n over dstate). The exponent is clamped to
# <= 0; the commented-out lines in the loop show the unclamped formula.
# Grid: axis 0 flattens (hdim tile pid_m, dstate tile pid_n); axis 1 = pid_c * batch + pid_b;
# axis 2 = head. Heads map to B's groups via pid_h // nheads_ngroups_ratio.
@triton.autotune(
    configs=autotune_configs([
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
    ]),
    key=['hdim', 'dstate', 'chunk_size'],
)
@triton.jit
def _chunk_state_fwd_kernel(
    # Pointers to matrices
    x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
    # Matrix dimensions
    hdim, dstate, chunk_size,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    # Meta-parameters
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # x tile is (hdim, time) and B tile is (time, dstate), so tl.dot(x, b) contracts over time.
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    # NOTE(review): read at chunk_size - 1 even for a partial last chunk, which assumes dA_cumsum is
    # flat past the valid length (as _chunk_cumsum_fwd_kernel writes it, by zeroing padded dt) —
    # confirm for other producers. The varlen kernel below reads at the sequence end instead.
    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    if HAS_SEQ_IDX:
        seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    if HAS_SEQ_IDX:
        # Sequence id of the chunk's last valid timestep; only timesteps with that id contribute.
        seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
        if HAS_SEQ_IDX:
            seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
        if not HAS_SEQ_IDX:
            # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
            scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k
        else:
            # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
            # Everything is also dropped when seq_idx_last is negative.
            scale = tl.where((seq_idx_last >= 0) & (seq_idx_k == seq_idx_last), tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
        b *= scale[:, None]
        # Cast so both tl.dot operands share x's element type.
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
        if HAS_SEQ_IDX:
            seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
    states = acc.to(states_ptr.dtype.element_ty)

    states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)


# Gradients w.r.t. x, dt and dA_cumsum from dstates (the gradient of the chunk-end states), per
# (batch, chunk, head). With
#   acc[m, p] = exp(min(dA_cs_last - dA_cumsum[m], 0)) * sum_n B[m, n] * dstates[p, n]:
#   dx[m, p] = acc[m, p] * dt[m]
#   ddt[m] = sum_p acc[m, p] * x[m, p]
#   ddA_cumsum[m] gets -ddt[m] * dt[m]; the last slot also gets +sum_m ddt[m] * dt[m] (this tile's m).
# Grid: axis 0 flattens (chunk-position tile pid_m, hdim tile pid_n); axis 1 = pid_c * batch + pid_b;
# axis 2 = head. The ddt / ddA_cumsum pointers are offset by pid_n * stride_*_tile, so each hdim tile
# can write its own partial (DETERMINISTIC_REDUCTION) instead of atomically accumulating.
# BLOCK_SIZE_DSTATE is not set by any autotune config, so it comes from the launch site.
@triton.autotune(
    configs=autotune_configs([
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
    ]),
    key=['chunk_size', 'hdim', 'dstate'],
)
@triton.jit
def _chunk_state_bwd_dx_kernel(
    # Pointers to matrices
    x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr,
    dx_ptr, ddt_ptr, ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, hdim, dstate,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
    stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile,
    # Meta-parameters
    DETERMINISTIC_REDUCTION: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
    dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile
    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_n * stride_ddA_tile
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
    offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
    b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
    dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
    if BLOCK_SIZE_DSTATE <= 128:
        b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
        dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
        dstates = dstates.to(b_ptr.dtype.element_ty)
        acc = tl.dot(b, dstates)
    else:
        # dstate too large for one block: accumulate over dstate in BLOCK_SIZE_K steps.
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, dstate, BLOCK_SIZE_K):
            b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
            dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
            dstates = dstates.to(b_ptr.dtype.element_ty)
            acc += tl.dot(b, dstates)
            b_ptrs += BLOCK_SIZE_K * stride_b_dstate
            dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
    dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    # acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
    acc *= tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))[:, None]

    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    ddt = tl.sum(acc * x, axis=1)
    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    if DETERMINISTIC_REDUCTION:
        tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size)
    else:
        tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
    ddA_cs = -(ddt * dt_m)
    ddA_cs_last = -tl.sum(ddA_cs)
    ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
    if DETERMINISTIC_REDUCTION:
        tl.store(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
    else:
        tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
    # NOTE(review): as laid out here, this atomic_add runs in both modes. Under
    # DETERMINISTIC_REDUCTION it would race with the tl.store above (the pid_m tile covering
    # chunk_size - 1 writes the same slot) and sum pid_m tiles in unspecified order. Verify against
    # upstream whether it belongs in the else branch or is compensated for by the caller.
    tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)

    dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
    dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
    dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
    tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))


# Smallest BLOCK_SIZE_N among _chunk_state_bwd_dx_kernel's autotune configs. The most hdim tiles any
# config can produce is cdiv(hdim, this value); presumably used to size the per-tile workspaces.
_CHUNK_STATE_BWD_DX_MIN_BLOCK_N = min(
    cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_state_bwd_dx_kernel.configs
)


# Gradient w.r.t. B (and, with HAS_DDA_CS, the ddA_cumsum contributions) for one
# (batch, chunk, head split, group). For each head h handled by this program:
#   db[m, n] += dt_h[m] * exp(min(dA_cs_last_h - dA_cumsum_h[m], 0)) * sum_p x_h[m, p] * dstates_h[p, n]
# and, with HAS_DDA_CS, sum_n db_h[m, n] * B[m, n] is written to ddA_cumsum at position m + 1.
# Each group's heads are split into runs of nheads_per_program. db is written per split
# (stride_db_split), presumably summed by the caller.
# Grid: axis 0 flattens (chunk-position tile pid_m, dstate tile pid_n); axis 1 = pid_c * batch + pid_b;
# axis 2 = pid_s * ngroups + pid_g. BLOCK_SIZE_K is not set by any autotune config (launch site).
@triton.autotune(
    configs=autotune_configs([
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
    ]),
    key=['chunk_size', 'dstate', 'hdim'],
)
@triton.jit
def _chunk_state_bwd_db_kernel(
    # Pointers to matrices
    x_ptr, dstates_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
    db_ptr, ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, dstate, hdim,
    batch, seqlen, nheads, nheads_per_program, ngroups,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
    stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile,
    # Meta-parameters
    HAS_DDA_CS: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    DETERMINISTIC_REDUCTION: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_sg = tl.program_id(axis=2)
    pid_s = pid_sg // ngroups
    pid_g = pid_sg - pid_s * ngroups
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # First head of this program: start of group pid_g plus pid_s runs of nheads_per_program.
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
    db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * stride_db_group + pid_s * stride_db_split
    dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
    if HAS_DDA_CS:
        b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head
        ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_n * stride_ddA_tile
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim)
    dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
    if HAS_DDA_CS:
        b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate)
        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if HAS_DDA_CS:
        b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
    if HAS_SEQ_IDX:
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
        seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
    # Heads left in this group for split pid_s, capped at nheads_per_program.
    nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
    for h in range(nheads_iter):
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
        dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
        dstates = dstates.to(x_ptrs.dtype.element_ty)
        db = tl.dot(x, dstates)
        # Per-head values; the head pointers advance at the end of each iteration.
        dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
        dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
        dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
        if not HAS_SEQ_IDX:
            # scale = tl.exp(dA_cs_last - dA_cs_m)
            scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
        else:
            # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
            # NOTE(review): unlike _chunk_state_fwd_kernel, this mask does not also require
            # seq_idx_last >= 0; confirm negative (padding) seq_idx values are handled consistently.
            scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
        db *= (scale * dt_m)[:, None]
        if HAS_DDA_CS:
            # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
            # (written one position to the right; slot 0 is never written here).
            ddA_cs = tl.sum(db * b, axis=1)
            if DETERMINISTIC_REDUCTION:
                tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
            else:
                tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
        acc += db
        x_ptrs += stride_x_head
        dstates_ptrs += stride_states_head
        dt_ptrs += stride_dt_head
        dA_cumsum_ptr += stride_dA_cs_head
        dA_cumsum_ptrs += stride_dA_cs_head
        if HAS_DDA_CS:
            ddA_cumsum_ptrs += stride_ddA_cs_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # if HAS_SEQ_IDX:
    #     seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
    #     seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
    #     acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
    db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate)
    tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))


# Smallest BLOCK_SIZE_N among _chunk_state_bwd_db_kernel's autotune configs (the most dstate tiles
# any config can produce is cdiv(dstate, this value)); presumably used to size per-tile workspaces.
_CHUNK_STATE_BWD_DB_MIN_BLOCK_N = min(
    cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_state_bwd_db_kernel.configs
)


# Per (batch, chunk, head): recomputes
#   acc[m, p] = scale[m] * sum_n B[m, n] * dstates[p, n],  scale[m] = exp(min(dA_cs_last - dA_cumsum[m], 0))
# (zeroed by seq_idx when given), then ddt[m] = sum_p acc[m, p] * x[m, p], and writes ddt[m] * dt[m]
# into ddA_cumsum at position m + 1. Per the in-kernel comment, the cumulative sum itself is applied
# with torch.cumsum outside this kernel.
# BLOCK_SIZE_M and BLOCK_SIZE_DSTATE are not set by the active autotune configs, so they come from the
# launch site. The ddA_cumsum pointer is offset by pid_n * stride_ddA_tile (one partial per hdim tile).
@triton.autotune(
    configs=autotune_configs([
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
    ]),
    key=['chunk_size', 'hdim', 'dstate'],
)
@triton.jit
def _chunk_state_bwd_ddAcs_stable_kernel(
    # Pointers to matrices
    x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
    ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, hdim, dstate,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
    stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_ddA_tile,
    # Meta-parameters
    HAS_SEQ_IDX: tl.constexpr,
    DETERMINISTIC_REDUCTION: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
    dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_n * stride_ddA_tile
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
    offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
    b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
    dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
    if BLOCK_SIZE_DSTATE <= 128:
        b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
        dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
        dstates = dstates.to(b_ptr.dtype.element_ty)
        acc = tl.dot(b, dstates)
    else:
        # dstate too large for one block: accumulate over dstate in BLOCK_SIZE_K steps.
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, dstate, BLOCK_SIZE_K):
            b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
            dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
            dstates = dstates.to(b_ptr.dtype.element_ty)
            acc += tl.dot(b, dstates)
            b_ptrs += BLOCK_SIZE_K * stride_b_dstate
            dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
    if not HAS_SEQ_IDX:
        # scale = tl.exp(dA_cs_last - dA_cs_m)
        scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
    else:
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
        seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
        # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
        # NOTE(review): unlike _chunk_state_fwd_kernel, this mask does not also require
        # seq_idx_last >= 0; confirm negative (padding) seq_idx values are handled consistently.
        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
    acc *= scale[:, None]

    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    ddt = tl.sum(acc * x, axis=1)
    # ddA_cs = -(ddt * dt_m)
    # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
    # then call torch.cumsum outside this kernel.
    # ddA_cs = tl.cumsum(ddt * dt_m)
    ddA_cs = ddt * dt_m
    ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
    # Shifted by one position (slot 0 is never written by this kernel).
    if DETERMINISTIC_REDUCTION:
        tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
    else:
        tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)


# Smallest BLOCK_SIZE_N among _chunk_state_bwd_ddAcs_stable_kernel's autotune configs (the most hdim
# tiles any config can produce is cdiv(hdim, this value)); presumably used to size per-tile workspaces.
_CHUNK_STATE_BWD_DDACS_MIN_BLOCK_N = min(
    cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_state_bwd_ddAcs_stable_kernel.configs
)


# Appears to be the variable-length counterpart of _chunk_state_fwd_kernel (same configs, no batch
# strides on the inputs). One program per (tile, sequence pid_b, head). Sequence boundaries come from
# cu_seqlens, and the program works on the chunk pid_c that contains the sequence's last token
# (end_idx - 1).
@triton.autotune(
    configs=autotune_configs([
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
    ]),
    key=['hdim', 'dstate', 'chunk_size'],
)
@triton.jit
def _chunk_state_varlen_kernel(
    # Pointers to matrices
    x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr,
    # Matrix dimensions
    hdim, dstate, chunk_size,
    seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_b_seqlen, stride_b_head, stride_b_dstate,
    stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate,
    stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # End (exclusive) of sequence pid_b, and the chunk containing its last token.
    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
    pid_c = (end_idx - 1) // chunk_size
    b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    # dA_cumsum at the sequence's last token (not at the chunk's last slot).
    dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    # Valid length of this chunk up to the sequence end, and where the sequence starts inside it
    # (0 if it started in an earlier chunk).
    chunk_size_limit = end_idx - pid_c * chunk_size
    start_idx = tl.load(cu_seqlens_ptr + pid_b)
    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0,
chunk_size_limit, BLOCK_SIZE_K): x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0) b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) # scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), # tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0) b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) x_ptrs += BLOCK_SIZE_K * stride_x_seqlen b_ptrs += BLOCK_SIZE_K * stride_b_seqlen dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk if start_idx < pid_c * chunk_size: chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate) chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0) scale = tl.exp(dA_cs_last) acc += chunk_states * scale states = acc.to(states_ptr.dtype.element_ty) states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate) c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) 
tl.store(states_ptrs, states, mask=c_mask) def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): batch, seqlen, nheads = dt.shape assert A.shape == (nheads,) if dt_bias is not None: assert dt_bias.shape == (nheads,) nchunks = math.ceil(seqlen / chunk_size) dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( dt, A, dt_bias, dt_out, dA_cumsum, batch, seqlen, nheads, chunk_size, dt_limit[0], dt_limit[1], dt.stride(0), dt.stride(1), dt.stride(2), A.stride(0), dt_bias.stride(0) if dt_bias is not None else 0, dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), dt_softplus, HAS_DT_BIAS=dt_bias is not None, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) return dA_cumsum, dt_out def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), ddt=None): batch, seqlen, nheads = dt.shape _, _, nchunks, chunk_size = ddA.shape assert ddA.shape == (batch, nheads, nchunks, chunk_size) assert ddt_out.shape == (batch, nheads, nchunks, chunk_size) assert A.shape == (nheads,) deterministic = use_deterministic_mode() if dt_bias is not None: assert dt_bias.shape == (nheads,) if deterministic: ddt_bias_workspace = torch.zeros( batch, nchunks, nheads, device=dt.device, dtype=torch.float32 ) ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32) stride_ddt_bias_batch = ddt_bias_workspace.stride(0) stride_ddt_bias_chunk = ddt_bias_workspace.stride(1) else: ddt_bias_workspace = ddt_bias = torch.empty_like( dt_bias, dtype=torch.float32 ) stride_ddt_bias_batch = 0 
stride_ddt_bias_chunk = 0 else: ddt_bias = None ddt_bias_workspace = None stride_ddt_bias_batch = 0 stride_ddt_bias_chunk = 0 if ddt is not None: assert ddt.shape == dt.shape else: ddt = torch.empty_like(dt) dA = torch.empty_like(A, dtype=torch.float32) if deterministic: dA_workspace = torch.zeros( batch, nchunks, nheads, device=dt.device, dtype=torch.float32 ) stride_dA_batch = dA_workspace.stride(0) stride_dA_chunk = dA_workspace.stride(1) else: dA_workspace = dA stride_dA_batch = 0 stride_dA_chunk = 0 grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) with torch.cuda.device(dt.device.index): _chunk_cumsum_bwd_kernel[grid_chunk_cs]( ddA, ddt_out, dt, A, dt_bias, ddt, dA_workspace, ddt_bias_workspace if ddt_bias is not None else None, batch, seqlen, nheads, chunk_size, dt_limit[0], dt_limit[1], ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3), ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3), dt.stride(0), dt.stride(1), dt.stride(2), A.stride(0), dt_bias.stride(0) if dt_bias is not None else 0, ddt.stride(0), ddt.stride(1), ddt.stride(2), stride_dA_batch, stride_dA_chunk, dA_workspace.stride(-1), stride_ddt_bias_batch, stride_ddt_bias_chunk, (ddt_bias_workspace.stride(-1) if ddt_bias is not None else 0), dt_softplus, HAS_DT_BIAS=dt_bias is not None, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), DETERMINISTIC_REDUCTION=deterministic, ) if deterministic: dA.copy_(dA_workspace.sum(dim=(0, 1))) if ddt_bias is not None: ddt_bias.copy_(ddt_bias_workspace.sum(dim=(0, 1))) return ddt, dA, ddt_bias def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = B.shape assert nheads % ngroups == 0 assert B.shape == (batch, seqlen, ngroups, dstate) assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape if seq_idx is not None: assert 
seq_idx.shape == (batch, seqlen) if states is not None: assert states.shape == (batch, nchunks, nheads, headdim, dstate) else: states_dtype = torch.float32 if states_in_fp32 else B.dtype states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): _chunk_state_fwd_kernel[grid]( x, B, states, dt, dA_cumsum, seq_idx, headdim, dstate, chunk_size, batch, seqlen, nheads // ngroups, x.stride(0), x.stride(1), x.stride(2), x.stride(3), B.stride(0), B.stride(1), B.stride(2), B.stride(-1), states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), HAS_SEQ_IDX=seq_idx is not None, ) return states def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = B.shape assert nheads % ngroups == 0 assert B.shape == (batch, seqlen, ngroups, dstate) assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) if dx is not None: assert dx.shape == x.shape else: dx = torch.empty_like(x) deterministic = use_deterministic_mode() tile_count = math.ceil(headdim / _CHUNK_STATE_BWD_DX_MIN_BLOCK_N) ddt, stride_ddt_tile = alloc_tile_workspace( (batch, nheads, nchunks, chunk_size), tile_count, torch.float32, dt.device, deterministic, zero_init=True, ) ddA_cumsum, stride_ddA_tile = alloc_tile_workspace( (batch, nheads, nchunks, chunk_size), tile_count, torch.float32, dA_cumsum.device, deterministic, zero_init=True, ) grid_dx = lambda META: 
(triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): _chunk_state_bwd_dx_kernel[grid_dx]( x, B, dstates, dt, dA_cumsum, dx, ddt, ddA_cumsum, chunk_size, headdim, dstate, batch, seqlen, nheads // ngroups, x.stride(0), x.stride(1), x.stride(2), x.stride(3), B.stride(0), B.stride(1), B.stride(2), B.stride(-1), dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), stride_ddt_tile, ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), stride_ddA_tile, DETERMINISTIC_REDUCTION=deterministic, BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), ) ddt = finalize_tile_workspace(ddt, deterministic) ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic) if deterministic: # Match `_chunk_state_bwd_dx_kernel` atomic path (`tl.atomic_add(..., ddA_cs_last)` into last element). 
ddA_cumsum[..., -1] -= ddA_cumsum.sum(dim=-1) return dx, ddt, ddA_cumsum def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape dstate = dstates.shape[-1] assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) deterministic = use_deterministic_mode() if B is not None: assert B.shape == (batch, seqlen, ngroups, dstate) B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3)) # Use torch.empty since the Triton kernel will call init_to_zero tile_count = math.ceil(dstate / _CHUNK_STATE_BWD_DB_MIN_BLOCK_N) ddA_cumsum, stride_ddA_tile = alloc_tile_workspace( (batch, nheads, nchunks, chunk_size), tile_count, torch.float32, x.device, deterministic, zero_init=True, ) ddA_cumsum_strides = ( ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ) else: B_strides = (0, 0, 0, 0) ddA_cumsum = None ddA_cumsum_strides = (0, 0, 0, 0) stride_ddA_tile = 0 nheads_ngroups_ratio = nheads // ngroups sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32) grid_db = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), batch * nchunks, nsplits * ngroups) with torch.cuda.device(x.device.index): _chunk_state_bwd_db_kernel[grid_db]( x, dstates, B, dt, dA_cumsum, seq_idx, dB, ddA_cumsum, chunk_size, dstate, headdim, batch, seqlen, nheads, nheads_per_program, ngroups, x.stride(0), x.stride(1), x.stride(2), x.stride(3), dstates.stride(0), dstates.stride(1), 
dstates.stride(2), dstates.stride(3), dstates.stride(4), *B_strides, dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4), *ddA_cumsum_strides, stride_ddA_tile, HAS_DDA_CS=ddA_cumsum is not None, HAS_SEQ_IDX=seq_idx is not None, DETERMINISTIC_REDUCTION=deterministic, BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), ) dB = dB.sum(2) if ddA_cumsum is not None: ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic) # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute # to the state of the chunk. # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) # But it's easier to just do the cumsum for all elements, the result will be the same. torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum) return dB if B is None else (dB, ddA_cumsum) def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = B.shape assert nheads % ngroups == 0 assert B.shape == (batch, seqlen, ngroups, dstate) assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) # Use torch.empty since the Triton kernel will call init_to_zero deterministic = use_deterministic_mode() tile_count = math.ceil(headdim / _CHUNK_STATE_BWD_DDACS_MIN_BLOCK_N) ddA_cumsum, stride_ddA_tile = alloc_tile_workspace( (batch, nheads, nchunks, chunk_size), tile_count, torch.float32, x.device, deterministic, zero_init=True, ) grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, 
nheads) with torch.cuda.device(x.device.index): _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs]( x, B, dstates, dt, dA_cumsum, seq_idx, ddA_cumsum, chunk_size, headdim, dstate, batch, seqlen, nheads // ngroups, x.stride(0), x.stride(1), x.stride(2), x.stride(3), B.stride(0), B.stride(1), B.stride(2), B.stride(-1), dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4), dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), stride_ddA_tile, HAS_SEQ_IDX=seq_idx is not None, DETERMINISTIC_REDUCTION=deterministic, BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16), BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), ) ddA_cumsum = finalize_tile_workspace(ddA_cumsum, deterministic) torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:]) return ddA_cumsum def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): total_seqlen, nheads, headdim = x.shape _, nchunks, chunk_size = dt.shape _, ngroups, dstate = B.shape batch = cu_seqlens.shape[0] - 1 cu_seqlens = cu_seqlens.contiguous() assert nheads % ngroups == 0 assert B.shape == (total_seqlen, ngroups, dstate) assert dt.shape == (nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape assert chunk_states.shape == (nchunks, nheads, headdim, dstate) states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device) grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads) with torch.cuda.device(x.device.index): _chunk_state_varlen_kernel[grid]( x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states, headdim, dstate, chunk_size, total_seqlen, nheads // ngroups, x.stride(0), x.stride(1), 
x.stride(2), B.stride(0), B.stride(1), B.stride(2), dt.stride(1), dt.stride(0), dt.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2), chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3), states.stride(0), states.stride(1), states.stride(2), states.stride(3), ) return states class ChunkStateFn(torch.autograd.Function): @staticmethod def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape assert seqlen <= nchunks * chunk_size _, _, ngroups, dstate = B.shape assert B.shape == (batch, seqlen, ngroups, dstate) assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) if B.stride(-1) != 1: B = B.contiguous() if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous x = x.contiguous() states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32) ctx.save_for_backward(B, x, dt, dA_cumsum) return states @staticmethod def backward(ctx, dstates): B, x, dt, dA_cumsum = ctx.saved_tensors batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = B.shape assert dstates.shape == (batch, nchunks, nheads, headdim, dstate) if dstates.stride(-1) != 1: dstates = dstates.contiguous() dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates) dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups) dB = dB.to(B.dtype) return dB, dx, ddt, ddA_cumsum, None def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True): """ Argument: B: (batch, seqlen, ngroups, dstate) x: (batch, seqlen, nheads, headdim) dt: (batch, nheads, nchunks, chunk_size) dA_cumsum: (batch, nheads, nchunks, chunk_size) Return: states: (batch, nchunks, nheads, headdim, dstate) """ return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32) def chunk_state_ref(B, x, dt, dA_cumsum): """ Argument: B: 
(batch, seqlen, ngroups, dstate) x: (batch, seqlen, nheads, headdim) dt: (batch, nheads, nchunks, chunk_size) dA_cumsum: (batch, nheads, nchunks, chunk_size) Return: states: (batch, nchunks, nheads, headdim, dstate) """ # Check constraints. batch, seqlen, nheads, headdim = x.shape dstate = B.shape[-1] _, _, nchunks, chunk_size = dt.shape assert seqlen <= nchunks * chunk_size assert x.shape == (batch, seqlen, nheads, headdim) assert dt.shape == (batch, nheads, nchunks, chunk_size) ngroups = B.shape[2] assert nheads % ngroups == 0 assert B.shape == (batch, seqlen, ngroups, dstate) B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) if seqlen < nchunks * chunk_size: x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen)) x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size) B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size) decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)) return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x) ================================================ FILE: mamba_ssm/ops/triton/ssd_combined.py ================================================ # Copyright (c) 2024, Tri Dao, Albert Gu. 
"""We want triton==2.1.0 or 2.2.0 for this """ from typing import Optional import math from packaging import version import torch import torch.nn.functional as F from torch import Tensor from mamba_ssm.utils.torch import custom_bwd, custom_fwd import triton import triton.language as tl from einops import rearrange, repeat try: from causal_conv1d import causal_conv1d_fn from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function except ImportError: causal_conv1d_fn = None causal_conv1d_fwd_function = None causal_conv1d_bwd_function = None causal_conv1d_update_function = None from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd from mamba_ssm.utils.determinism import ( alloc_tile_workspace, autotune_configs, finalize_tile_workspace, 
use_deterministic_mode, ) TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') def init_to_zero(names): return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] def ensure_stride(inp): """ Return inp, while ensuring that stride(1) of the returned tensor is a multiple of 8. The inp tensor is of shape [batch, length, channels], where channels is assumed, and tested, to be a multiple of 8. If it is contiguous, inp will have strides [length*channels, channels, 1]. The output of this function will be rearranged to shape [batch, channels, length] before being passed to causal_conv1d. That rearranged tensor will have strides [length*channels, 1, channels]. causal_conv1d handles this stride configuration (which it calls channels_last) directly and efficiently, after first recognizing it (when stride[1]==1 and stride[2]>1). causal_conv1d cannot operate on a channels_last tensor for which stride[2] is not a multiple of 8, and in that case will raise an exception. This function prevents the aforementioned exception by returning a tensor with stride(1) equal to channels, by making the returned tensor contiguous, if inp.stride(1) is not already a multiple of 8. """ assert inp.shape[2] % 8 == 0, "Number of convolution channels is required to be a multiple of 8." 
return inp if inp.stride(1) % 8 == 0 else inp.contiguous() @triton.autotune( configs=autotune_configs([ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "dD_ptr"])), ]), key=['chunk_size', 'hdim', 'dstate'], ) @triton.jit def _chunk_scan_chunk_state_bwd_dx_kernel( # Pointers to matrices x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr, b_ptr, dstates_ptr, dx_ptr, ddt_ptr, dD_ptr, # Matrix dimensions chunk_size, hdim, dstate, batch, seqlen, nheads_ngroups_ratio, # Strides stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, stride_dout_batch, stride_dout_seqlen, stride_dout_head, 
stride_dout_hdim, stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_D_head, stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate, stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate, stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_ddt_tile, stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, # Meta-parameters HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, IS_TRITON_22: tl.constexpr, DETERMINISTIC_REDUCTION: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + pid_n * stride_ddt_tile dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head if 
HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) if not HAS_SEQ_IDX: # scale = tl.exp(dA_cs_last - dA_cs_m) scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)) else: seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0) scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0) # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 # However, we're getting error with the Triton compiler 2.1.0 for that code path: # Unexpected mma -> mma layout conversion # Triton 2.2.0 fixes this offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate) dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate) if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128: b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0) dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) dstates = dstates.to(b_ptr.dtype.element_ty) acc = tl.dot(b, dstates) * scale[:, None] else: for k in range(0, dstate, 
BLOCK_SIZE_K): b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0) dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) dstates = dstates.to(b_ptr.dtype.element_ty) acc += tl.dot(b, dstates) b_ptrs += BLOCK_SIZE_K * stride_b_dstate dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate acc *= scale[:, None] # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) # dt_ptrs = dt_ptr + offs_m * stride_dt_csize # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) # ddt = tl.sum(acc * x, axis=1) * dt_m # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) offs_k = tl.arange(0, BLOCK_SIZE_K) cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize K_MAX = chunk_size_limit K_MIN = pid_m * BLOCK_SIZE_M cb_ptrs += K_MIN * stride_cb_csize_k dout_ptrs += K_MIN * stride_dout_seqlen dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize for k in range(K_MIN, K_MAX, BLOCK_SIZE_K): k = tl.multiple_of(k, BLOCK_SIZE_K) # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0)) # If we don't have the (k + offs_k[None, :] < 
K_MAX) mask, for indices outside this range, # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. # This will cause NaN in acc, and hence NaN in dx and ddt. mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX) cb = tl.where(mask, cb, 0.0) cb = cb.to(dout_ptr.dtype.element_ty) acc += tl.dot(cb, dout) cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dt_ptrs = dt_ptr + offs_m * stride_dt_csize dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) dx = acc * dt_m[:, None] dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) if HAS_D: dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) if D_HAS_HDIM: D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) else: D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) dx += dout_res * D tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) if HAS_D: dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize if D_HAS_HDIM: dD_ptrs = dD_ptr + offs_n * stride_dD_hdim dD = tl.sum(dout_res * x, axis=0) 
            tl.store(dD_ptrs, dD, mask=offs_n < hdim)
        else:
            # D is a scalar per head: reduce the whole tile to a single value.
            dD = tl.sum(dout_res * x)
            if DETERMINISTIC_REDUCTION:
                # Write one partial per pid_n tile; the host sums them after the launch.
                tl.store(dD_ptr + pid_n * stride_dD_hdim, dD)
            else:
                tl.atomic_add(dD_ptr, dD)
    # This tile's contribution to ddt: row-wise sum of acc * x over the headdim columns.
    ddt = tl.sum(acc * x, axis=1)
    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    if DETERMINISTIC_REDUCTION:
        # NOTE(review): a plain store is race-free only if ddt_ptr points into a per-pid_n
        # tile slot (presumably offset by stride_ddt_tile earlier in the kernel) -- confirm.
        tl.store(ddt_ptrs, ddt, mask=offs_m < chunk_size)
    else:
        tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)


# Smallest BLOCK_SIZE_N among the dx kernel's autotune configs. headdim divided by it is an
# upper bound on the number of pid_n tiles, which sizes the per-tile reduction buffers below.
_CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N = min(
    cfg.kwargs['BLOCK_SIZE_N'] for cfg in _chunk_scan_chunk_state_bwd_dx_kernel.configs
)


def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):
    """
    Launch the fused backward kernel that produces dx, ddt and (when D is given) dD.

    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: same shape as dt
        B: (batch, seqlen, ngroups, dstate)
        CB: (batch, nchunks, ngroups, chunk_size, chunk_size)
        dout: same shape as x
        dstates: (batch, nchunks, nheads, headdim, dstate)
        D: (nheads, headdim) or (nheads,), optional; its last dim must have stride 1
        seq_idx: (batch, seqlen), optional
        dx: optional pre-allocated output with the same shape as x
    Return:
        dx: same shape as x
        ddt: gradient w.r.t. dt, accumulated in float32
        dD: same shape and dtype as D, or None if D is None
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dout.shape == x.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    deterministic = use_deterministic_mode()
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
        assert D.stride(-1) == 1
        # dD partials are kept per pid_m tile (zero-initialised) and summed on the host.
        # NOTE(review): the tile count uses BLOCK_SIZE_M = 32, which assumes no autotune config
        # uses a smaller BLOCK_SIZE_M -- confirm against the kernel's configs.
        BLOCK_SIZE_min = 32
        pid_m_tiles = triton.cdiv(chunk_size, BLOCK_SIZE_min)
        pid_n_tiles = math.ceil(headdim / _CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N)
        # Last dim of the dD buffer: one column per headdim entry for 2D D; for scalar D one
        # slot per pid_n tile in deterministic mode, else a single slot the kernel atomically adds into.
        if D.dim() == 2:
            dD_hdim = headdim
        elif deterministic:
            dD_hdim = pid_n_tiles
        else:
            dD_hdim = 1
        dD = torch.zeros(pid_m_tiles, batch, nchunks, nheads, dD_hdim, device=D.device, dtype=torch.float32)
        dD_strides = (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
    else:
        dD = None
        dD_strides = (0, 0, 0, 0, 0)
    if dx is None:
        dx = torch.empty_like(x)
    else:
        assert dx.shape == x.shape
    # Zero-initialised float32 ddt accumulator; tile_count bounds the number of pid_n tiles
    # (presumably one slice per tile in deterministic mode -- see alloc_tile_workspace).
    tile_count = math.ceil(headdim / _CHUNK_SCAN_CHUNK_STATE_BWD_DX_MIN_BLOCK_N)
    ddt, stride_ddt_tile = alloc_tile_workspace(
        (batch, nheads, nchunks, chunk_size),
        tile_count,
        torch.float32,
        dout.device,
        deterministic,
        zero_init=True,
    )
    grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
                            batch * nchunks, nheads)
    with torch.cuda.device(x.device.index):
        # Argument order must match the kernel signature: dt/dA_cumsum/ddt pass their chunk
        # stride (dim 2) before their head stride (dim 1), and dD passes its pid_m tile stride
        # (dim 0) fourth, after batch/chunk/head.
        _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
            x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,
            chunk_size, headdim, dstate,
            batch, seqlen, nheads // ngroups,
            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
            CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),
            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            D.stride(0) if D is not None else 0,
            B.stride(0), B.stride(1), B.stride(2), B.stride(3),
            dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
            dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
            ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), stride_ddt_tile,
            dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
            D is not None,
            D.dim() == 2 if D is not None else True,
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
            IS_TRITON_22=TRITON_22,
            DETERMINISTIC_REDUCTION=deterministic,
        )

    if D is not None:
        # Only the first cdiv(chunk_size, BLOCK_SIZE_M) pid_m tiles can be written by the chosen
        # config; keep those, then reduce over the tile, batch and chunk dims.
        BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"]
        n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
        dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2))
        if D.dim() == 1:
            # Fold the per-pid_n partials (or the single slot) into one value per head.
            dD = dD.sum(dim=-1)
        dD = dD.to(dtype=D.dtype)
    ddt = finalize_tile_workspace(ddt, deterministic)
    return dx, ddt, dD


def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
    """
    Forward pass of the chunked SSD scan.

    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads,)
        B, C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,), optional
        z: same shape as x, optional
        initial_states: (batch, nheads, headdim, dstate), optional
        seq_idx: (batch, seqlen), optional
        cu_seqlens: optional; when given (batch must be 1) varlen states are also returned
    Return:
        (out, out_x, dt, dA_cumsum, states, final_states), with varlen_states appended when
        cu_seqlens is given. The returned dt is the processed dt from _chunk_cumsum_fwd.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert x.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, seqlen, nheads)
    assert A.shape == (nheads,)
    assert C.shape == B.shape
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if x.stride(-1) != 1 and x.stride(1) != 1:  # Either M or K dimension should be contiguous
        x = x.contiguous()
    if z is not None and z.stride(-1) != 1 and z.stride(1) != 1:  # Either M or K dimension should be contiguous
        z = z.contiguous()
    if D is not None and D.stride(-1) != 1:
        D = D.contiguous()
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
    # 1. Per-chunk cumulative sums of dA, plus the processed dt (bias / softplus / limits applied).
    dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
    # 2. State contributed by each chunk on its own (fp32).
    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
    # 3. Propagate states across chunks; (p n) is flattened for the state-passing kernel.
    states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
                                              initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
                                              seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype)
    states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]]
    # 4. Chunk-local C/B products in fp32: (batch, nchunks, ngroups, chunk_size, chunk_size).
    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
    # 5. Combine intra-chunk and inter-chunk terms into the output.
    out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)
    if cu_seqlens is None:
        return out, out_x, dt, dA_cumsum, states, final_states
    else:
        assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
        varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0),
                                           cu_seqlens, states.squeeze(0))
        return out, out_x, dt, dA_cumsum, states, final_states, varlen_states


def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None,
                                   dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False,
                                   dt_limit=(0.0, float("inf")),
                                   dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False):
    """
    Backward pass of the chunked SSD scan.

    Recomputes the forward intermediates (dA_cumsum, CB, states) from the inputs, then runs
    the backward kernels. dx, ddt, dB, dC and dz may be passed as pre-allocated output buffers.
    Return:
        (dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states), with the recomputed output
        appended when recompute_output is True.
    """
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch, seqlen, nheads, headdim = x.shape
    nchunks = math.ceil(seqlen / chunk_size)
    _, _, ngroups, dstate = B.shape
    assert dout.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, seqlen, nheads)
    assert A.shape == (nheads,)
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert C.shape == B.shape
    assert out.shape == x.shape
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if dx is not None:
        assert dx.shape == x.shape
    if dB is not None:
        assert dB.shape == B.shape
        dB_given = dB
    else:
        dB_given = torch.empty_like(B)
    if dC is not None:
        assert dC.shape == C.shape
        dC_given = dC
    else:
        dC_given = torch.empty_like(C)
    if dz is not None:
        assert z is not None
        assert dz.shape == z.shape
    if ddt is not None:
        assert ddt.shape == dt.shape
        ddt_given = ddt
    else:
        ddt_given = torch.empty_like(dt)
    # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
    # "[CUDA]: invalid device context" (e.g. during the varlen test), and cloning makes it work. Idk why.
    dt_in = dt.clone()
    # Recompute the forward intermediates.
    dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus,
                                      dt_limit=dt_limit)
    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
    states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
                                   initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
                                   seq_idx=seq_idx, chunk_size=chunk_size)
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    if z is not None:
        # Backprop through the z-gating first; this also replaces dout with the gradient before gating.
        dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output)
        outz = rest[0] if recompute_output else out
    else:
        dz = None
        outz = out
    dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype)
    # dstates has length nchunks, containing the gradient to initial states at index 0 and
    # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
    # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
    # will be used in matmul in the next kernels.
    dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        rearrange(dstates, "... p n -> ... (p n)"),
        dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None,
        seq_idx=seq_idx,
        has_initial_states=initial_states is not None,
        dstates_dtype=x.dtype,
        states_dtype=x.dtype,
        chunk_size=chunk_size,
    )
    # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
    # gradient to the final states at index (nchunks - 1)
    # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
    # The final states is not stored.
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
    dinitial_states = rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None
    # dx, ddt (partial) and dD from the intra-chunk (CB) and inter-chunk (dstates) paths.
    dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)
    # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
    dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups)
    # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
    dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups)
    # Computing ddA with the dcb kernel is much slower, so we're not using it for now
    dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
    # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
    dCB = dCB.to(CB.dtype)
    # Add the CB-path contributions to dB / dC, writing into the (possibly caller-provided) outputs.
    _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
    _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
    # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
    # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
    if z is None:
        dD = dD_from_x
    # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
    # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
    # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
    # be a lot of underflow.

    # This is already done as part of bwd_dC kernel
    # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
    ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
    # Reverse cumulative sum along the chunk-position axis.
    ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
    # This is already done as part of bwd_dB kernel
    # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
    # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
    ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
    ddA += ddA_next + ddA_prev

    # Map ddA (and the partial ddt) back to the raw dt input, A and dt_bias.
    ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given)

    # These 2 lines are just to test ddt and dA being computed by old code
    # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
    # ddt_given.copy_(ddt)

    return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states)
    return return_vals if not recompute_output else (*return_vals, outz)


def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
    """
    Compute ddt and dA by running the CUDA selective_scan extension (forward, then backward).

    Argument:
        dout: (batch, seqlen, nheads, headdim)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
        A: (nheads) or (dim, dstate)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        ddt: (batch, nheads, nchunks, chunk_size) in float32 if dt is 4D,
            otherwise (batch, nheads, headdim, nchunks, chunk_size)
        dA: (nheads,) if A is 1D, otherwise as returned by selective_scan.bwd
    """
    import selective_scan
    batch, seqlen, nheads, headdim = x.shape
    chunk_size = dt.shape[-1]
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    # Convert everything to the (batch, channels, seqlen) layout used by selective_scan.
    x = rearrange(x, "b l h p -> b (h p) l")
    squeeze_dt = dt.dim() == 4
    if dt.dim() == 4:
        dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
    dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
    squeeze_A = A.dim() == 1
    if A.dim() == 1:
        A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
    else:
        A = A.to(dtype=torch.float32)
    B = rearrange(B, "b l g n -> b g n l")
    C = rearrange(C, "b l g n -> b g n l")
    if D is not None:
        if D.dim() == 2:
            D = rearrange(D, "h p -> (h p)")
        else:
            D = repeat(D, "h -> (h p)", p=headdim)
    if z is not None:
        z = rearrange(z, "b l h p -> b (h p) l")

    if x.stride(-1) != 1:
        x = x.contiguous()
    if dt.stride(-1) != 1:
        dt = dt.contiguous()
    if D is not None:
        D = D.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()
    # The forward pass is rerun to obtain the intermediate values the backward kernel consumes.
    _, intermediate, *rest = selective_scan.fwd(x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False)
    if z is not None:
        out = rest[0]
    else:
        out = None

    dout = rearrange(dout, "b l h p -> b (h p) l")

    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
    # backward of selective_scan with the backward of chunk).
    # Here we just pass in None and dz will be allocated in the C++ code.
_, ddt, dA, *rest = selective_scan.bwd( x, dt.to(dtype=x.dtype), A, B, C, D, z, None, dout, intermediate, out, None, False, False # option to recompute out_z, not used here ) ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size) if squeeze_dt: ddt = ddt.float().sum(dim=2) if squeeze_A: dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2)) return ddt, dA class MambaChunkScanCombinedFn(torch.autograd.Function): @staticmethod def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False): ctx.dt_dtype = dt.dtype if not return_varlen_states: cu_seqlens = None else: assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit) ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx) ctx.dt_softplus = dt_softplus ctx.chunk_size = chunk_size ctx.dt_limit = dt_limit ctx.return_final_states = return_final_states ctx.return_varlen_states = return_varlen_states if not return_varlen_states: return out if not return_final_states else (out, final_states) else: varlen_states = rest[0] return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states) @staticmethod def backward(ctx, dout, *args): out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors assert not ctx.return_varlen_states, "return_varlen_states is not supported in backward" dfinal_states = args[0] if ctx.return_final_states else None dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = 
_mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit) return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None, None, None def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False): """ Argument: x: (batch, seqlen, nheads, headdim) dt: (batch, seqlen, nheads) A: (nheads) B: (batch, seqlen, ngroups, dstate) C: (batch, seqlen, ngroups, dstate) chunk_size: int D: (nheads, headdim) or (nheads,) z: (batch, seqlen, nheads, headdim) dt_bias: (nheads,) initial_states: (batch, nheads, headdim, dstate) seq_idx: (batch, seqlen) cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True dt_softplus: Whether to apply softplus to dt Return: out: (batch, seqlen, nheads, headdim) """ return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states) def mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False): """ Argument: x: (batch, seqlen, nheads, headdim) dt: (batch, seqlen, nheads) A: (nheads) B: (batch, seqlen, ngroups, dstate) C: (batch, seqlen, ngroups, dstate) D: (nheads, headdim) or (nheads,) z: (batch, seqlen, nheads, headdim) dt_bias: (nheads,) Return: out: (batch, seqlen, nheads, headdim) """ batch, seqlen, nheads, headdim = x.shape dstate = B.shape[-1] if seqlen % chunk_size != 0: dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size)) dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size) dt = dt.float() # We want high precision for this before cumsum if dt_bias is not None: dt = dt + 
rearrange(dt_bias, "h -> h 1 1") if dt_softplus: dt = F.softplus(dt) dA = dt * rearrange(A, "h -> h 1 1") dA = dt * rearrange(A, "h -> h 1 1") dA_cumsum = torch.cumsum(dA, dim=-1) # 1. Compute the state for each chunk states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True) # 2. Pass the state to all the chunks by weighted cumsum. states = rearrange(state_passing(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0], "... (p n) -> ... p n", n=dstate) # 3. Compute the output for each chunk out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z) return out def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False): """ Argument: x: (batch, seqlen, nheads, headdim) dt: (batch, seqlen, nheads) A: (nheads) B: (batch, seqlen, ngroups, dstate) C: (batch, seqlen, ngroups, dstate) D: (nheads, headdim) or (nheads,) z: (batch, seqlen, nheads, headdim) dt_bias: (nheads,) Return: out: (batch, seqlen, nheads, headdim) """ batch, seqlen, nheads, headdim = x.shape dstate = B.shape[-1] if seqlen % chunk_size != 0: dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size)) dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size) dt = dt.float() # We want high precision for this before cumsum if dt_bias is not None: dt = dt + rearrange(dt_bias, "h -> h 1 1") if dt_softplus: dt = F.softplus(dt) dA = dt * rearrange(A, "h -> h 1 1") dA_cumsum = torch.cumsum(dA, dim=-1) # 1. Compute the state for each chunk states = chunk_state_ref(B, x, dt, dA_cumsum) states_dtype = states.dtype if states.dtype not in [torch.float32, torch.float64]: states = states.to(torch.float32) # 2. Pass the state to all the chunks by weighted cumsum. # state_passing_ref is much less numerically stable states = rearrange(state_passing_ref(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0], "... (p n) -> ... p n", n=dstate) states = states.to(states_dtype) # 3. 
Compute the output for each chunk out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z) return out def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): """ Argument: x: (batch, seqlen, nheads, headdim) dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim) A: (nheads) or (dim, dstate) B: (batch, seqlen, ngroups, dstate) C: (batch, seqlen, ngroups, dstate) D: (nheads, headdim) or (nheads,) z: (batch, seqlen, nheads, headdim) dt_bias: (nheads,) or (nheads, headdim) Return: out: (batch, seqlen, nheads, headdim) """ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape x = rearrange(x, "b l h p -> b (h p) l") if dt.dim() == 3: dt = repeat(dt, "b l h -> b l h p", p=headdim) dt = rearrange(dt, "b l h p -> b (h p) l") if A.dim() == 1: A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32) else: A = A.to(dtype=torch.float32) B = rearrange(B, "b l g n -> b g n l") C = rearrange(C, "b l g n -> b g n l") if D is not None: if D.dim() == 2: D = rearrange(D, "h p -> (h p)") else: D = repeat(D, "h -> (h p)", p=headdim) if z is not None: z = rearrange(z, "b l h p -> b (h p) l") if dt_bias is not None: if dt_bias.dim() == 1: dt_bias = repeat(dt_bias, "h -> h p", p=headdim) dt_bias = rearrange(dt_bias, "h p -> (h p)") if dt_limit != (0.0, float("inf")): if dt_bias is not None: dt = dt + rearrange(dt_bias, "d -> d 1") if dt_softplus: dt = F.softplus(dt) dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype) dt_bias = None dt_softplus = None out = selective_scan_fn(x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus) return rearrange(out, "b (h p) l -> b l h p", p=headdim) def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), activation="silu", headdim=None, ngroups=1): 
""" Argument: xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim conv1d_weight: (dim + 2 * ngroups * dstate, width) conv1d_bias: (dim + 2 * ngroups * dstate,) dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim) A: (nheads) D: (nheads, headdim) or (nheads,) z: (batch, seqlen, dim) dt_bias: (nheads) or (nheads, headdim) headdim: if D is 1D and z is None, headdim must be passed in Return: out: (batch, seqlen, dim) """ batch, seqlen, nheads = dt.shape[:3] assert nheads % ngroups == 0 if z is not None: dim = z.shape[-1] assert dim % nheads == 0 headdim = dim // nheads else: if D.dim() == 1: assert headdim is not None else: headdim = D.shape[1] dim = nheads * headdim xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation), "b d s -> b s d") dstate = (xBC.shape[-1] - dim) // ngroups // 2 x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1) x = rearrange(x, "b l (h p) -> b l h p", h=nheads) B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) return rearrange(out, "b s h p -> b s (h p)") class MambaSplitConv1dScanCombinedFn(torch.autograd.Function): @staticmethod @custom_fwd def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True): assert activation in [None, "silu", "swish"] if D.dim() == 1: assert headdim is not None nheads, = D.shape else: nheads, headdim = D.shape batch, seqlen, _ = zxbcdt.shape dim = nheads * headdim assert nheads 
% ngroups == 0 dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2 d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2 assert d_nonssm >= 0 assert zxbcdt.shape == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads) assert dt_bias.shape == (nheads,) assert A.shape == (nheads,) zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1) seq_idx = seq_idx.contiguous() if seq_idx is not None else None xBC_conv = rearrange( causal_conv1d_fwd_function(rearrange(ensure_stride(xBC), "b s d -> b d s"), conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]), "b d s -> b s d" ) x, B, C = torch.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1) x = rearrange(x, "b l (h p) -> b l h p", h=nheads) B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None if rmsnorm_weight is None: out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit) out = rearrange(out, "b s h p -> b s (h p)") rstd = None if d_nonssm > 0: out = torch.cat([_swiglu_fwd(zx0), out], dim=-1) else: out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit) # reshape input data into 2D tensor x_rms = rearrange(out_x, "b s h p -> (b s) (h p)") z_rms = rearrange(z, "b s h p -> (b s) (h p)") rmsnorm_weight = rmsnorm_weight.contiguous() if d_nonssm == 0: out = None else: out01 = torch.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype, device=x_rms.device) out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) 
d") _swiglu_fwd(zx0, out=out01[..., :d_nonssm]) out, _, rstd = _layer_norm_fwd(x_rms, rmsnorm_weight, None, rmsnorm_eps, z_rms, out=out, group_size=dim // ngroups, norm_before_gate=norm_before_gate, is_rms_norm=True) if d_nonssm == 0: out = rearrange(out, "(b s) d -> b s d", b=batch) else: out = out01 ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None if outproj_weight is not None: if torch.is_autocast_enabled(): dtype = torch.get_autocast_gpu_dtype() out, outproj_weight = out.to(dtype), outproj_weight.to(dtype) outproj_bias = outproj_bias.to(dtype) if outproj_bias is not None else None out = F.linear(out, outproj_weight, outproj_bias) else: assert outproj_bias is None ctx.save_for_backward(zxbcdt, conv1d_weight, conv1d_bias, out_x, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias) ctx.dt_limit = dt_limit ctx.return_final_states = return_final_states ctx.activation = activation ctx.rmsnorm_eps = rmsnorm_eps ctx.norm_before_gate = norm_before_gate ctx.chunk_size = chunk_size ctx.headdim = headdim ctx.ngroups = ngroups return out if not return_final_states else (out, final_states) @staticmethod @custom_bwd def backward(ctx, dout, *args): zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors dfinal_states = args[0] if ctx.return_final_states else None headdim = ctx.headdim nheads = D.shape[0] dim = nheads * headdim assert nheads % ctx.ngroups == 0 dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2 d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2 assert d_nonssm >= 0 recompute_output = outproj_weight is not None if recompute_output: out_recompute = torch.empty(*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype) out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], dim=-1) zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, 
dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1) # Recompute x, B, C xBC_conv = rearrange( causal_conv1d_fwd_function(rearrange(ensure_stride(xBC), "b s d -> b d s"), conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]), "b d s -> b s d" ) x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1) x = rearrange(x, "b l (h p) -> b l h p", h=nheads) B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups) C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups) dzxbcdt = torch.empty_like(zxbcdt) dzx0, dz, dxBC_given, ddt_given = torch.split(dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1) dxBC = torch.empty_like(xBC) dx, dB, dC = torch.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1) z = rearrange(z, "b l (h p) -> b l h p", h=nheads) dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads) dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups) dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups) if outproj_weight is not None: dout_og = dout dout = F.linear(dout, outproj_weight.t()) if d_nonssm > 0: dout0, dout = dout.split([d_nonssm, dim], dim=-1) _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute) dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim) if rmsnorm_weight is None: dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads) dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = _mamba_chunk_scan_combined_bwd( dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC, dz=dz, recompute_output=recompute_output ) out_for_linear = rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None drmsnorm_weight = None else: batch = dout.shape[0] dy_rms = rearrange(dout, "b s h p -> (b s) (h p)") dz = rearrange(dz, "b l d -> (b l) d") x_rms 
= rearrange(out, "b s h p -> (b s) (h p)") z_rms = rearrange(z, "b s h p -> (b s) (h p)") out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, group_size=dim//ctx.ngroups, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None) out_for_linear = out_recompute if recompute_output else None dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim) dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd( dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC ) if outproj_weight is not None: doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear) doutproj_bias = dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None else: doutproj_weight, doutproj_bias = None, None dxBC_given_update, dweight, dbias, *_ = causal_conv1d_bwd_function( rearrange(ensure_stride(xBC), "b s d -> b d s"), conv1d_weight, conv1d_bias, # It might be okay to not run ensure_stride on dxBC, but we're not sure. So playing safe here. 
rearrange(ensure_stride(dxBC), "b s d -> b d s"), seq_idx, None, None, rearrange(ensure_stride(dxBC_given), "b s d -> b d s"), False, ctx.activation in ["silu", "swish"] ) dxBC_given_update = rearrange(dxBC_given_update, "b d s -> b s d") if dxBC_given.stride() != dxBC_given_update.stride(): dxBC_given.copy_(dxBC_given_update) else: dxBC_given = dxBC_given_update return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True): """ Argument: zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim conv1d_weight: (dim + 2 * ngroups * dstate, width) conv1d_bias: (dim + 2 * ngroups * dstate,) dt_bias: (nheads,) A: (nheads) D: (nheads, headdim) or (nheads,) initial_states: (batch, nheads, headdim, dstate) seq_idx: (batch, seqlen), int32 rmsnorm_weight: (dim,) outproj_weight: (out_dim, dim) outproj_bias: (out_dim,) headdim: if D is 1D, headdim must be passed in norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). 
If False, we do RMSNorm(x * F.silu(z)) Return: out: (batch, seqlen, dim) """ return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate) def mamba_split_conv1d_scan_ref(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, dt_limit=(0.0, float("inf")), activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True): """ Argument: zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim conv1d_weight: (dim + 2 * ngroups * dstate, width) conv1d_bias: (dim + 2 * ngroups * dstate,) dt_bias: (nheads,) A: (nheads) D: (nheads, headdim) or (nheads,) rmsnorm_weight: (dim,) outproj_weight: (out_dim, dim) outproj_bias: (out_dim,) headdim: if D is 1D, headdim must be passed in norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). 
If False, we do RMSNorm(x * F.silu(z)) Return: out: (batch, seqlen, dim) """ if D.dim() == 1: assert headdim is not None nheads, = D.shape else: nheads, headdim = D.shape assert nheads % ngroups == 0 batch, seqlen, _ = zxbcdt.shape dim = nheads * headdim dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2 assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) assert dt_bias.shape == (nheads,) assert A.shape == (nheads,) if rmsnorm_weight is not None: assert rmsnorm_weight.shape == (dim,) z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1) xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation), "b d s -> b s d") x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1) x = rearrange(x, "b l (h p) -> b l h p", h=nheads) B = rearrange(B, "b l (g n) -> b l g n", g=ngroups) C = rearrange(C, "b l (g n) -> b l g n", g=ngroups) z = rearrange(z, "b l (h p) -> b l h p", h=nheads) out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), z=z if rmsnorm_weight is None else None, dt_bias=dt_bias, dt_softplus=True, dt_limit=dt_limit) out = rearrange(out, "b s h p -> b s (h p)") if rmsnorm_weight is not None: out = rmsnorm_fn(out, rmsnorm_weight, None, z=rearrange(z, "b l h p -> b l (h p)"), eps=rmsnorm_eps, norm_before_gate=norm_before_gate) if outproj_weight is not None: out = F.linear(out, outproj_weight, outproj_bias) return out ================================================ FILE: mamba_ssm/ops/triton/ssd_state_passing.py ================================================ # Copyright (c) 2024, Tri Dao, Albert Gu. 
"""We want triton==2.1.0 or 2.2.0 for this """ import math import torch import torch.nn.functional as F import triton import triton.language as tl from einops import rearrange, repeat from mamba_ssm.utils.determinism import autotune_configs @triton.autotune( configs=autotune_configs([ triton.Config({'BLOCK_SIZE': 64}), triton.Config({'BLOCK_SIZE': 128}), triton.Config({'BLOCK_SIZE': 256}), triton.Config({'BLOCK_SIZE': 512}), triton.Config({'BLOCK_SIZE': 1024}), triton.Config({'BLOCK_SIZE': 2048}), ]), key=['dim'], ) @triton.jit def _state_passing_fwd_kernel( # Pointers to matrices states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr, # Matrix dimensions dim, nchunks, seqlen, chunk_size, # Strides stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim, stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim, stride_final_states_batch, stride_final_states_head, stride_final_states_dim, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_initstates_batch, stride_initstates_head, stride_initstates_dim, stride_seq_idx_batch, stride_seq_idx_seqlen, # Meta-parameters HAS_INITSTATES: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid_b = tl.program_id(axis=1) pid_h = tl.program_id(axis=2) pid_m = tl.program_id(axis=0) states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head if HAS_INITSTATES: initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) states_ptrs = states_ptr + offs_m * stride_states_dim out_ptrs = out_ptr + offs_m * stride_out_dim final_states_ptrs = final_states_ptr + offs_m * 
stride_final_states_dim if not HAS_INITSTATES: states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) else: initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk seq_idx = 0 for c in range(nchunks): new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) scale = tl.exp(dA_cs) if HAS_SEQ_IDX: seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) seq_idx = seq_idx_new states = scale * states + new_states if c < nchunks - 1: tl.store(out_ptrs, states, mask=offs_m < dim) else: tl.store(final_states_ptrs, states, mask=offs_m < dim) states_ptrs += stride_states_chunk dA_cs_ptr += stride_dA_cs_chunk out_ptrs += stride_out_chunk @triton.autotune( configs=autotune_configs([ triton.Config({'BLOCK_SIZE': 64}), triton.Config({'BLOCK_SIZE': 128}), triton.Config({'BLOCK_SIZE': 256}), triton.Config({'BLOCK_SIZE': 512}), triton.Config({'BLOCK_SIZE': 1024}), triton.Config({'BLOCK_SIZE': 2048}), ]), key=['dim'], ) @triton.jit def _state_passing_bwd_kernel( # Pointers to matrices dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr, dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr, # Matrix dimensions dim, nchunks, seqlen, chunk_size, # Strides stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim, stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim, stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_dinitstates_batch, 
stride_dinitstates_head, stride_dinitstates_dim, # Meta-parameters CONVERT_STATES: tl.constexpr, HAS_DFINAL_STATES: tl.constexpr, HAS_DINITSTATES: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid_b = tl.program_id(axis=1) pid_h = tl.program_id(axis=2) pid_m = tl.program_id(axis=0) dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk if CONVERT_STATES: states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk if HAS_DFINAL_STATES: dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head if HAS_DINITSTATES: dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim out_ptrs = out_ptr + offs_m * stride_out_dim dout_ptrs = dout_ptr + offs_m * stride_dout_dim if CONVERT_STATES: states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim if HAS_DFINAL_STATES: dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32) else: dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) tl.store(dstates_ptrs, dstates, mask=offs_m < dim) if HAS_SEQ_IDX: seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen) dstates_ptrs -= stride_dstates_chunk for c in range(nchunks - 1): dA_cs = tl.load(dA_cs_ptr).to(tl.float32) scale = 
tl.exp(dA_cs) if HAS_SEQ_IDX: seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen)) scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) seq_idx = seq_idx_new out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if CONVERT_STATES: tl.store(states_converted_ptrs, out, mask=offs_m < dim) ddA = tl.sum(out * dstates) * scale tl.store(ddA_cs_ptr, ddA) dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dstates = scale * dstates + dout tl.store(dstates_ptrs, dstates, mask=offs_m < dim) dout_ptrs -= stride_dout_chunk dstates_ptrs -= stride_dstates_chunk dA_cs_ptr -= stride_dA_cs_chunk ddA_cs_ptr -= stride_ddA_cs_chunk out_ptrs -= stride_out_chunk if CONVERT_STATES: states_converted_ptrs -= stride_out_chunk if CONVERT_STATES: out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) tl.store(states_converted_ptrs, out, mask=offs_m < dim) if not HAS_DINITSTATES: tl.store(ddA_cs_ptr, 0.0) else: dA_cs = tl.load(dA_cs_ptr).to(tl.float32) scale = tl.exp(dA_cs) if HAS_SEQ_IDX: scale = tl.where(seq_idx == 0, scale, 0.0) out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) ddA = tl.sum(out * dstates) * scale tl.store(ddA_cs_ptr, ddA) dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dstates = scale * dstates + dout tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim) def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None, out_dtype=None): batch, nchunks, nheads, dim = states.shape assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) if initial_states is not None: assert initial_states.shape == (batch, nheads, dim) if seq_idx is not None: assert chunk_size is not None seqlen = seq_idx.shape[-1] assert seq_idx.shape == (batch, seqlen) out_dtype = states.dtype if out_dtype is None else out_dtype out = torch.empty((batch, nchunks, nheads, dim), device=states.device, 
dtype=out_dtype) final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx, dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0, states.stride(0), states.stride(1), states.stride(2), states.stride(3), out.stride(0), out.stride(1), out.stride(2), out.stride(3), final_states.stride(0), final_states.stride(1), final_states.stride(2), dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1), *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) if initial_states is not None else (0, 0, 0)), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), HAS_INITSTATES=initial_states is not None, HAS_SEQ_IDX=seq_idx is not None, ) return out, final_states def _state_passing_bwd( states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None, dstates_dtype=None, states_dtype=None, chunk_size=None ): """ states contains the initial_states at index 0. The final states are not included in states. 
""" batch, nchunks, nheads, dim = states.shape assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) assert dout.shape == (batch, nchunks, nheads, dim) if seq_idx is not None: assert chunk_size is not None seqlen = seq_idx.shape[-1] assert seq_idx.shape == (batch, seqlen) dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype) if states_dtype is not None and states_dtype != states.dtype: states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype) assert states_converted.stride() == states.stride() else: states_converted = None if has_initial_states: dinitstates = torch.empty_like(dstates[:, 0]) else: dinitstates = None if dfinal_states is not None: assert dfinal_states.shape == (batch, nheads, dim) BLOCK_SIZE_min = 64 n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks, dtype=torch.float32, device=dA_chunk_cumsum.device) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) with torch.cuda.device(dout.device.index): _state_passing_bwd_kernel[grid]( dout, states, dA_chunk_cumsum, dfinal_states, seq_idx, dstates, ddA_chunk_cumsum, dinitstates, states_converted, dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0, dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), states.stride(0), states.stride(1), states.stride(2), states.stride(3), dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1), *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2)) if dfinal_states is not None else (0, 0, 0)), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1), *((dinitstates.stride(0), dinitstates.stride(1), 
dinitstates.stride(2)) if dinitstates is not None else (0, 0, 0)), CONVERT_STATES=states_converted is not None, HAS_DFINAL_STATES=dfinal_states is not None, HAS_DINITSTATES=dinitstates is not None, HAS_SEQ_IDX=seq_idx is not None, ) BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"] n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype) if states_dtype is not None and states_dtype == states.dtype: states_converted = states return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted) class StatePassingFn(torch.autograd.Function): @staticmethod def forward(ctx, states, dA_chunk_cumsum, initial_states=None): batch, nchunks, nheads, dim = states.shape assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) if states.stride(-1) != 1: states = states.contiguous() out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states) ctx.save_for_backward(out, dA_chunk_cumsum) ctx.has_initial_states = initial_states is not None return out, final_states @staticmethod def backward(ctx, dout, dfinal_states): out, dA_chunk_cumsum = ctx.saved_tensors batch, nchunks, nheads, dim = out.shape assert dout.shape == (batch, nchunks, nheads, dim) assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) assert dfinal_states.shape == (batch, nheads, dim) if dout.stride(-1) != 1: dout = dout.contiguous() dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd( out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states , has_initial_states=ctx.has_initial_states ) return dstates, ddA_chunk_cumsum, dinitstates def state_passing(states, dA_chunk_cumsum, initial_states=None): """ Argument: states: (batch, nchunks, nheads, dim) dA_chunk_cumsum: (batch, nheads, nchunks) initial_states: (batch, nheads, dim) Return: out: (batch, nchunks, nheads, dim) final_states: 
(batch, nheads, dim) """ return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states) def state_passing_ref(states, dA_chunk_cumsum, initial_states=None): """ Argument: states: (batch, nchunks, nheads, dim) dA_chunk_cumsum: (batch, nheads, nchunks) initial_states: (batch, nheads, dim) Return: out: (batch, nchunks, nheads, dim) final_states: (batch, nheads, dim) """ if initial_states is None: initial_states = torch.zeros_like(states[:, 0]) states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1) dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0)) dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1) nchunks = dA_chunk_cumsum.shape[-1] # (batch, nheads, nchunks, nchunks) dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :] # (batch, nheads, nchunks, nchunks) decay_chunk = torch.exp(dt_chunk_segment_sum) causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0) decay_chunk = decay_chunk.masked_fill(~causal_mask, 0) out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states) return out[:, :-1], out[:, -1] ================================================ FILE: mamba_ssm/utils/__init__.py ================================================ ================================================ FILE: mamba_ssm/utils/determinism.py ================================================ # Copyright (c) 2024, Tri Dao, Albert Gu. 
import os import warnings from packaging import version import torch try: import triton TRITON_VERSION = version.parse(triton.__version__) except ImportError: TRITON_VERSION = version.parse("0.0.0") TRITON_HAS_CACHE_RESULTS = TRITON_VERSION >= version.parse("3.4.0") _autotune_warning_issued = False _deterministic_override = None def use_deterministic_mode(): if _deterministic_override is not None: return _deterministic_override env = os.environ.get('MAMBA_DETERMINISTIC') if env: return env[0] == '1' return torch.are_deterministic_algorithms_enabled() def set_deterministic_mode(value): global _deterministic_override _deterministic_override = value def _estimate_config_cost(cfg): """Estimate shared memory cost of a config. Lower is cheaper.""" block_product = 1 for key, val in cfg.kwargs.items(): if key.startswith('BLOCK_SIZE_'): block_product *= val return block_product * (getattr(cfg, 'num_stages', 1) or 1) def _filter_configs_by_block_sizes(configs): """Filter configs by TRITON_AUTOTUNE_BLOCK_SIZE_* env vars.""" env_filters = {} for suffix in ('M', 'N', 'K', 'DSTATE'): env_val = os.environ.get(f"TRITON_AUTOTUNE_BLOCK_SIZE_{suffix}") if env_val is not None: env_filters[f'BLOCK_SIZE_{suffix}'] = int(env_val) if not env_filters: return None matching = configs for key, target in env_filters.items(): matching = [c for c in matching if c.kwargs.get(key) == target] return matching[:1] if matching else None def autotune_configs(configs): """Select autotune configs for deterministic mode. Uses cached autotuning (TRITON_CACHE_AUTOTUNING=1) if Triton >= 3.4.0, otherwise auto-selects the cheapest config by block size * stages. """ if not configs or not use_deterministic_mode(): return configs if TRITON_HAS_CACHE_RESULTS and os.environ.get("TRITON_CACHE_AUTOTUNING") == "1": return configs global _autotune_warning_issued if not _autotune_warning_issued: _autotune_warning_issued = True msg = "Deterministic mode: set TRITON_CACHE_AUTOTUNING=1 for cached autotuning." 
if TRITON_HAS_CACHE_RESULTS else "Deterministic mode: upgrade to Triton >= 3.4.0 for cached autotuning." warnings.warn(msg) filtered = _filter_configs_by_block_sizes(configs) if filtered: return filtered return [min(configs, key=_estimate_config_cost)] def alloc_tile_workspace(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True): """Allocate buffer for deterministic per-program reductions.""" if base_shape is None: return None, 0 if deterministic: factory = torch.zeros if zero_init else torch.empty tensor = factory(*base_shape, tile_dim, device=device, dtype=dtype) return tensor, tensor.stride(-1) return torch.empty(*base_shape, device=device, dtype=dtype), 0 def finalize_tile_workspace(tensor, deterministic): if tensor is None: return None if deterministic: tensor = tensor.sum(dim=-1) return tensor ================================================ FILE: mamba_ssm/utils/generation.py ================================================ # Copyright (c) 2023, Albert Gu, Tri Dao. 
import gc import time from collections import namedtuple from dataclasses import dataclass, field from functools import partial from typing import Callable, Optional, Sequence, Union import torch import torch.nn.functional as F from einops import rearrange, repeat from torch import Tensor from torch.profiler import ProfilerActivity, profile, record_function from transformers.generation import GenerateDecoderOnlyOutput, TextStreamer @dataclass class InferenceParams: """Inference parameters that are passed to the main model in order to efficienly calculate and store the context during inference.""" max_seqlen: int max_batch_size: int seqlen_offset: int = 0 batch_size_offset: int = 0 key_value_memory_dict: dict = field(default_factory=dict) lengths_per_sample: Optional[Tensor] = None def reset(self, max_seqlen, max_batch_size): self.max_seqlen = max_seqlen self.max_batch_size = max_batch_size self.seqlen_offset = 0 if self.lengths_per_sample is not None: self.lengths_per_sample.zero_() def modify_logits_for_min_p_filtering(logits, min_p): """Set the logits for none min_p values to -inf. Done in-place.""" if min_p <= 0.0 or min_p >= 1.0: return indices_to_remove = logits < min_p logits.masked_fill_(indices_to_remove, float("-Inf")) # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 def modify_logits_for_top_k_filtering(logits, top_k): """Set the logits for none top-k values to -inf. 
Done in-place.""" indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] logits.masked_fill_(indices_to_remove, float("-Inf")) # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 def modify_logits_for_top_p_filtering(logits, top_p): """Set the logits for none top-p values to -inf. Done in-place.""" if top_p <= 0.0 or top_p >= 1.0: return # First sort and calculate cumulative sum of probabilities. sorted_logits, sorted_indices = torch.sort(logits, descending=False) cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) sorted_indices_to_remove = cumulative_probs <= (1 - top_p) # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter( 1, sorted_indices, sorted_indices_to_remove ) logits.masked_fill_(indices_to_remove, float("-inf")) def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0): """Apply repetition penalty. See https://arxiv.org/abs/1909.05858 logits: (batch_size, vocab_size) prev_output_tokens: (batch_size, seq_len) """ if repetition_penalty == 1.0: return logits score = torch.gather(logits, 1, prev_output_tokens) # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty) logits.scatter_(1, prev_output_tokens, score) return logits def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0): """Sample from top-k logits. 
Arguments: logits: Tensor of shape (batch_size, vocab_size) """ if top_k == 1: # Short-circuit for greedy decoding return logits.argmax(dim=-1) else: if top_p > 0.0: assert top_p <= 1.0, "top-p should be in (0, 1]." if top_k > 0: top_k = min(top_k, logits.size(-1)) # Safety check logits_top, indices = torch.topk(logits, top_k, dim=-1) if temperature != 1.0: logits_top /= temperature modify_logits_for_top_p_filtering(logits_top, top_p) return indices[ torch.arange(indices.shape[0], device=indices.device), torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1), ] else: if min_p > 0.0: logits_top = logits.clone() max_prob = logits_top[..., 0].item() min_prob = max_prob * min_p modify_logits_for_min_p_filtering(logits_top, min_prob) if temperature != 1.0: logits_top /= temperature return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1) # Clone so that when we modify for top_p we don't change the original logits logits_top = logits / temperature if temperature != 1.0 else logits.clone() modify_logits_for_top_p_filtering(logits_top, top_p) return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze( dim=-1 ) @torch.inference_mode() def decode( input_ids, model, max_length, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0, repetition_penalty=1.0, eos_token_id=None, teacher_outputs=None, vocab_size=None, cg=False, enable_timing=False, output_scores=False, streamer: Optional[TextStreamer] = None ): """Decoding, either greedy or with top-k or top-p sampling. If top-k = 0, don't limit the number of candidates (pure sampling). Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, then top-p. We assume that all sequences in the same batch have the same length. Arguments: input_ids: (batch, seq_len) max_length: int teacher_outputs (optional): (batch, seq_len). 
If provided, instead of sampling from the logits, the next token is taken from the teacher_outputs. Useful for testing. Returns: GenerateDecoderOnlyOutput, with the following fields: sequences: (batch, max_length) scores: tuples of (batch, vocab_size) """ if streamer is not None: streamer.put(input_ids.cpu()) batch_size, seqlen_og = input_ids.shape teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 if cg: if not hasattr(model, "_decoding_cache"): model._decoding_cache = None model._decoding_cache = update_graph_cache( model, model._decoding_cache, batch_size, seqlen_og, max_length, ) inference_params = model._decoding_cache.inference_params inference_params.reset(max_length, batch_size) else: inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) def get_logits(input_ids, inference_params): decoding = inference_params.seqlen_offset > 0 if decoding: position_ids = torch.full( (batch_size, 1), inference_params.seqlen_offset, dtype=torch.long, device=input_ids.device, ) else: position_ids = None if not cg or not decoding: logits = model( input_ids, position_ids=position_ids, inference_params=inference_params, num_last_tokens=1, ).logits.squeeze(dim=1) else: logits = model._decoding_cache.run( input_ids, position_ids, inference_params.seqlen_offset ).squeeze(dim=1) return logits[..., :vocab_size] if vocab_size is not None else logits def sample_tokens(logits, inference_params): if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature) else: token = teacher_outputs[:, inference_params.seqlen_offset] # return rearrange(token, "b -> b 1") return token.unsqueeze(1) def should_stop(current_token, inference_params): if inference_params.seqlen_offset == 0: return False if eos_token_id is not None and (current_token == eos_token_id).all(): return True if inference_params.seqlen_offset >= max_length - 1: 
return True return False start = torch.cuda.Event(enable_timing=enable_timing) end = torch.cuda.Event(enable_timing=enable_timing) if enable_timing: start.record() scores, sequences = [], [input_ids] sequences_cat = input_ids while not should_stop(sequences[-1], inference_params): logits = get_logits(sequences[-1], inference_params) if output_scores: scores.append(logits.clone()) inference_params.seqlen_offset += sequences[-1].shape[1] if repetition_penalty == 1.0: sampled_tokens = sample_tokens(logits, inference_params) else: logits = modify_logit_for_repetition_penalty( logits, sequences_cat, repetition_penalty ) sampled_tokens = sample_tokens(logits, inference_params) sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1) sequences.append(sampled_tokens) if streamer is not None: streamer.put(sampled_tokens.cpu()) if streamer is not None: streamer.end() if enable_timing: end.record() torch.cuda.synchronize() print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms") return GenerateDecoderOnlyOutput(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) class GenerationMixin: def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): raise NotImplementedError def generate( self, input_ids, max_length, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0, return_dict_in_generate=False, output_scores=False, **kwargs, ): output = decode( input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, output_scores=output_scores, **kwargs ) if not output_scores: output.scores = None return output if return_dict_in_generate else output.sequences @dataclass class DecodingCGCache: max_batch_size: int = 0 max_seqlen: int = 0 device = None dtype = None callables: dict = field(default_factory=dict) mempool = None inference_params: Optional[InferenceParams] = None run: Optional[Callable] = None @torch.inference_mode() def update_graph_cache( model, cache, batch_size, seqlen_og, max_seqlen, 
decoding_seqlens=(1,), dtype=None, n_warmups=2, ): if cache is None: cache = DecodingCGCache() param_example = next(iter(model.parameters())) device = param_example.device if dtype is None: dtype = param_example.dtype if ( (device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size or max_seqlen > cache.max_seqlen ): # Invalidate the cache cache.callables = {} cache.mempool = None cache.inference_params = None gc.collect() cache.device, cache.dtype = device, dtype cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache" inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) cache.inference_params = InferenceParams( max_seqlen=max_seqlen, max_batch_size=batch_size, seqlen_offset=seqlen_og, key_value_memory_dict=inf_cache, lengths_per_sample=lengths_per_sample, ) cache.mempool = torch.cuda.graphs.graph_pool_handle() for decoding_seqlen in decoding_seqlens: if (batch_size, decoding_seqlen) not in cache.callables: cache.callables[batch_size, decoding_seqlen] = capture_graph( model, cache.inference_params, batch_size, max_seqlen, decoding_seqlen=decoding_seqlen, mempool=cache.mempool, n_warmups=n_warmups, ) def dispatch(input_ids, position_ids, seqlen): batch_size, decoding_seqlen = input_ids.shape[:2] return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) cache.run = dispatch cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing return cache def capture_graph( model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 ): device = next(iter(model.parameters())).device input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) position_ids = torch.full((batch_size, decoding_seqlen), 0, 
dtype=torch.long, device=device) seqlen_offset_og = inference_params.seqlen_offset inference_params.seqlen_offset = max_seqlen - decoding_seqlen inference_params.lengths_per_sample[:] = inference_params.seqlen_offset # Warmup before capture s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(n_warmups): logits = model( input_ids, position_ids=position_ids, inference_params=inference_params, num_last_tokens=decoding_seqlen, ).logits s.synchronize() # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, # which requires that graph launch and non-captured launch to not overlap (I think, # that's how I interpret the documentation). I'm not sure if this is required. if torch.distributed.is_initialized(): torch.distributed.barrier() torch.cuda.current_stream().wait_stream(s) # Captures the graph # To allow capture, automatically sets a side stream as the current stream in the context graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, pool=mempool): logits = model( input_ids, position_ids=position_ids, inference_params=inference_params, num_last_tokens=decoding_seqlen, ).logits def run(new_input_ids, new_position_ids, seqlen): inference_params.lengths_per_sample[:] = seqlen input_ids.copy_(new_input_ids) position_ids.copy_(new_position_ids) graph.replay() return logits.clone() inference_params.seqlen_offset = seqlen_offset_og return run ================================================ FILE: mamba_ssm/utils/hf.py ================================================ import json import torch from transformers.utils import WEIGHTS_NAME, CONFIG_NAME from transformers.utils.hub import cached_file def load_config_hf(model_name): resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) return json.load(open(resolved_archive_file)) def load_state_dict_hf(model_name, device=None, dtype=None): # If not fp32, then we don't want to load directly to 
the GPU mapped_device = "cpu" if dtype not in [torch.float32, None] else device resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) return torch.load(resolved_archive_file, map_location=mapped_device) # Convert dtype before moving to GPU to save memory if dtype is not None: state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} state_dict = {k: v.to(device=device) for k, v in state_dict.items()} return state_dict ================================================ FILE: mamba_ssm/utils/torch.py ================================================ import torch from functools import partial from typing import Callable def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): def decorator(*args, **kwargs): if cuda_amp_deprecated: kwargs["device_type"] = "cuda" return dec(*args, **kwargs) return decorator if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] deprecated = True from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] else: deprecated = False from torch.cuda.amp import custom_fwd, custom_bwd custom_fwd = custom_amp_decorator(custom_fwd, deprecated) custom_bwd = custom_amp_decorator(custom_bwd, deprecated) ================================================ FILE: pyproject.toml ================================================ [project] name = "mamba_ssm" description = "Mamba state-space model" readme = "README.md" authors = [ { name = "Tri Dao", email = "tri@tridao.me" }, { name = "Albert Gu", email = "agu@cs.cmu.edu" } ] requires-python = ">= 3.9" dynamic = ["version"] license = { file = "LICENSE" } # Include a LICENSE file in your repo keywords = ["cuda", "pytorch", "state-space model"] classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", "Operating System :: Unix" ] dependencies = [ "torch", "triton", "ninja", "einops", "transformers", "packaging", "setuptools>=61.0.0", ] [project.urls] Repository = 
"https://github.com/state-spaces/mamba" [project.optional-dependencies] causal-conv1d = [ "causal-conv1d>=1.2.0" ] dev = [ "pytest" ] [build-system] # torch is intentionally excluded: pip's build isolation would install # torch-cpu from PyPI, ignoring the user's CUDA-enabled torch. # Users building from source should install torch first, then: # pip install mamba-ssm --no-build-isolation requires = [ "setuptools>=61.0.0", "wheel", "packaging", "ninja", ] build-backend = "setuptools.build_meta" ================================================ FILE: rocm_patch/rocm6_0.patch ================================================ --- /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h 2023-12-12 20:11:48.000000000 +0000 +++ rocm_update_files/amd_hip_bf16.h 2024-05-20 17:40:26.983349079 +0000 @@ -137,7 +137,7 @@ * \ingroup HIP_INTRINSIC_BFLOAT16_CONV * \brief Converts float to bfloat16 */ -__HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) { +__HOST_DEVICE__ static inline __hip_bfloat16 __float2bfloat16(float f) { __hip_bfloat16 ret; union { float fp32; @@ -181,7 +181,7 @@ * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Converts and moves bfloat162 to float2 */ -__HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) { +__HOST_DEVICE__ static inline float2 __bfloat1622float2(const __hip_bfloat162 a) { return float2{__bfloat162float(a.x), __bfloat162float(a.y)}; } @@ -209,7 +209,7 @@ * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Convert double to __hip_bfloat16 */ -__HOST_DEVICE__ __hip_bfloat16 __double2bfloat16(const double a) { +__HOST_DEVICE__ static inline __hip_bfloat16 __double2bfloat16(const double a) { return __float2bfloat16((float)a); } @@ -217,7 +217,7 @@ * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Convert float2 to __hip_bfloat162 */ -__HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) { +__HOST_DEVICE__ static inline __hip_bfloat162 __float22bfloat162_rn(const float2 a) { return __hip_bfloat162{__float2bfloat16(a.x), 
__float2bfloat16(a.y)}; } @@ -247,7 +247,7 @@ * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Converts high 16 bits of __hip_bfloat162 to float and returns the result */ -__HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); } +__HOST_DEVICE__ static inline float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV @@ -275,7 +275,7 @@ * \ingroup HIP_INTRINSIC_BFLOAT162_CONV * \brief Converts low 16 bits of __hip_bfloat162 to float and returns the result */ -__HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); } +__HOST_DEVICE__ static inline float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); } /** * \ingroup HIP_INTRINSIC_BFLOAT162_CONV ================================================ FILE: setup.py ================================================ # Copyright (c) 2023, Albert Gu, Tri Dao. import sys import warnings import os import re import ast from pathlib import Path from packaging.version import parse, Version import platform import shutil from setuptools import setup, find_packages import subprocess import urllib.request import urllib.error from wheel.bdist_wheel import bdist_wheel as _bdist_wheel import torch from torch.utils.cpp_extension import ( BuildExtension, CUDAExtension, CUDA_HOME, HIP_HOME ) with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) PACKAGE_NAME = "mamba_ssm" BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}" # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == 
"TRUE" SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE" def get_platform(): """ Returns the platform name as used in wheel filenames. """ if sys.platform.startswith("linux"): return f"linux_{platform.machine()}" elif sys.platform == "darwin": mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) return f"macosx_{mac_version}_x86_64" elif sys.platform == "win32": return "win_amd64" else: raise ValueError("Unsupported platform: {}".format(sys.platform)) def get_cuda_bare_metal_version(cuda_dir): raw_output = subprocess.check_output( [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True ) output = raw_output.split() release_idx = output.index("release") + 1 bare_metal_ver = parse(output[release_idx].split(",")[0]) return raw_output, bare_metal_ver def get_hip_version(rocm_dir): hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc") try: raw_output = subprocess.check_output( [hipcc_bin, "--version"], universal_newlines=True ) except Exception as e: print( f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}" ) return None, None for line in raw_output.split("\n"): if "HIP version" in line: rocm_version = parse(line.split()[-1].rstrip('-').replace('-', '+')) # local version is not parsed correctly return line, rocm_version return None, None def get_torch_hip_version(): if torch.version.hip: return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+')) else: return None def check_if_hip_home_none(global_option: str) -> None: if HIP_HOME is not None: return # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary # in that case. warnings.warn( f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?" 
) def check_if_cuda_home_none(global_option: str) -> None: if CUDA_HOME is not None: return # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary # in that case. warnings.warn( f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " "only images whose names contain 'devel' will provide nvcc." ) def append_nvcc_threads(nvcc_extra_args): return nvcc_extra_args + ["--threads", "4"] cmdclass = {} ext_modules = [] HIP_BUILD = bool(torch.version.hip) if not SKIP_CUDA_BUILD: print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) cc_flag = [] if HIP_BUILD: check_if_hip_home_none(PACKAGE_NAME) rocm_home = os.getenv("ROCM_PATH") _, hip_version = get_hip_version(rocm_home) if HIP_HOME is not None: if hip_version < Version("6.0"): raise RuntimeError( f"{PACKAGE_NAME} is only supported on ROCm 6.0 and above. " "Note: make sure HIP has a supported version by running hipcc --version." ) if hip_version == Version("6.0"): warnings.warn( f"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. " "Refer to the README.md for detailed instructions.", UserWarning ) cc_flag.append("-DBUILD_PYTHON_PACKAGE") else: check_if_cuda_home_none(PACKAGE_NAME) # Check, if CUDA11 is installed for compute capability 8.0 if CUDA_HOME is not None: _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version < Version("11.6"): raise RuntimeError( f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. " "Note: make sure nvcc has a supported version by running nvcc -V." 
) # If system CUDA and PyTorch CUDA have different major versions, # clear TORCH_CUDA_ARCH_LIST to prevent cpp_extension from erroring torch_cuda_version = parse(torch.version.cuda) if bare_metal_version.major != torch_cuda_version.major: os.environ["TORCH_CUDA_ARCH_LIST"] = "" cc_flag.append("-gencode") cc_flag.append("arch=compute_75,code=sm_75") cc_flag.append("-gencode") cc_flag.append("arch=compute_80,code=sm_80") cc_flag.append("-gencode") cc_flag.append("arch=compute_87,code=sm_87") if bare_metal_version >= Version("11.8"): cc_flag.append("-gencode") cc_flag.append("arch=compute_90,code=sm_90") if bare_metal_version >= Version("12.8"): cc_flag.append("-gencode") cc_flag.append("arch=compute_100,code=sm_100") cc_flag.append("-gencode") cc_flag.append("arch=compute_120,code=sm_120") if bare_metal_version >= Version("13.0"): cc_flag.append("-gencode") cc_flag.append("arch=compute_103,code=sm_103") cc_flag.append("-gencode") cc_flag.append("arch=compute_110,code=sm_110") cc_flag.append("-gencode") cc_flag.append("arch=compute_121,code=sm_121") # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as # torch._C._GLIBCXX_USE_CXX11_ABI # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 if FORCE_CXX11_ABI: torch._C._GLIBCXX_USE_CXX11_ABI = True if HIP_BUILD: extra_compile_args = { "cxx": ["-O3", "-std=c++17"], "nvcc": [ "-O3", "-std=c++17", f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-fgpu-flush-denormals-to-zero", ] + cc_flag, } else: extra_compile_args = { "cxx": ["-O3", "-std=c++17"], "nvcc": append_nvcc_threads( [ "-O3", "-std=c++17", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_BFLOAT16_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "-U__CUDA_NO_BFLOAT162_OPERATORS__", "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", "--expt-relaxed-constexpr", 
"--expt-extended-lambda", "--use_fast_math", "--ptxas-options=-v", "-lineinfo", ] + cc_flag ), } ext_modules.append( CUDAExtension( name="selective_scan_cuda", sources=[ "csrc/selective_scan/selective_scan.cpp", "csrc/selective_scan/selective_scan_fwd_fp32.cu", "csrc/selective_scan/selective_scan_fwd_fp16.cu", "csrc/selective_scan/selective_scan_fwd_bf16.cu", "csrc/selective_scan/selective_scan_bwd_fp32_real.cu", "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu", "csrc/selective_scan/selective_scan_bwd_fp16_real.cu", "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu", "csrc/selective_scan/selective_scan_bwd_bf16_real.cu", "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu", ], extra_compile_args=extra_compile_args, include_dirs=[Path(this_dir) / "csrc" / "selective_scan"], ) ) def get_package_version(): with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f: version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) public_version = ast.literal_eval(version_match.group(1)) local_version = os.environ.get("MAMBA_LOCAL_VERSION") if local_version: return f"{public_version}+{local_version}" else: return str(public_version) def get_wheel_url(): # Determine the version numbers that will be used to determine the correct wheel torch_version_raw = parse(torch.__version__) if HIP_BUILD: # We're using the HIP version used to build torch, not the one currently installed torch_hip_version = get_torch_hip_version() hip_ver = f"{torch_hip_version.major}{torch_hip_version.minor}" else: # We're using the CUDA version used to build torch, not the one currently installed # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) torch_cuda_version = parse(torch.version.cuda) # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3 # to save CI time. Minor versions should be compatible. 
if torch_cuda_version.major == 11: torch_cuda_version = parse("11.8") elif torch_cuda_version.major == 12: torch_cuda_version = parse("12.3") elif torch_cuda_version.major == 13: torch_cuda_version = parse("13.0") else: raise ValueError(f"CUDA version {torch_cuda_version} not supported") cuda_version = f"{torch_cuda_version.major}" gpu_compute_version = hip_ver if HIP_BUILD else cuda_version cuda_or_hip = "hip" if HIP_BUILD else "cu" python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() mamba_ssm_version = get_package_version() if os.environ.get("NVIDIA_PRODUCT_NAME", "") == "PyTorch": torch_version = str(os.environ.get("NVIDIA_PYTORCH_VERSION")) # On NGC images, use the container's CUDA version (matching how wheels are built) ngc_cuda_version = os.environ.get("CUDA_VERSION", "") if ngc_cuda_version: cuda_version = str(parse(ngc_cuda_version).major) else: torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() # Determine wheel URL based on CUDA version, torch version, python version and OS wheel_filename = f"{PACKAGE_NAME}-{mamba_ssm_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" wheel_url = BASE_WHEEL_URL.format( tag_name=f"v{mamba_ssm_version}", wheel_name=wheel_filename ) return wheel_url, wheel_filename class CachedWheelsCommand(_bdist_wheel): """ The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot find an existing wheel (which is currently the case for all installs). We use the environment parameters to detect whether there is already a pre-built version of a compatible wheel available and short-circuits the standard full build pipeline. 
""" def run(self): if FORCE_BUILD: return super().run() wheel_url, wheel_filename = get_wheel_url() print("Guessing wheel URL: ", wheel_url) try: urllib.request.urlretrieve(wheel_url, wheel_filename) # Make the archive # Lifted from the root wheel processing command # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 if not os.path.exists(self.dist_dir): os.makedirs(self.dist_dir) impl_tag, abi_tag, plat_tag = self.get_tag() archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") print("Raw wheel path", wheel_path) shutil.move(wheel_filename, wheel_path) except urllib.error.HTTPError: print("Precompiled wheel not found. Building from source...") # If the wheel could not be downloaded, build from source super().run() setup( name=PACKAGE_NAME, version=get_package_version(), packages=find_packages( exclude=( "build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "mamba_ssm.egg-info", ) ), author="Tri Dao, Albert Gu", author_email="tri@tridao.me, agu@cs.cmu.edu", description="Mamba state-space model", long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/state-spaces/mamba", classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", "Operating System :: Unix", ], ext_modules=ext_modules, cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} if ext_modules else { "bdist_wheel": CachedWheelsCommand, }, python_requires=">=3.9", install_requires=[ "torch", "packaging", "ninja", "einops", "triton>=3.5.0", "transformers", "tilelang>=0.1.7.post3", "nvidia-cutlass-dsl==4.4.1", "quack-kernels==0.3.1", # "causal_conv1d>=1.4.0", ], ) ================================================ FILE: tests/benchmark_determinism_kernels.py ================================================ #!/usr/bin/env python 
# Copyright (c) 2024, Tri Dao, Albert Gu.
#
# Benchmark the runtime and peak-memory overhead of deterministic mode
# (mamba_ssm.utils.determinism.set_deterministic_mode) for a set of Triton SSD
# backward kernels. Each kernel is measured once with deterministic mode off and
# once with it on; the relative change is printed per kernel plus a summary row.
# Requires a CUDA device (main() exits otherwise).
import gc
import math

import torch
from triton.testing import do_bench

from mamba_ssm.utils.determinism import set_deterministic_mode

# Problem-shape presets selectable via --preset; each entry supplies the
# nheads/headdim/dstate/ngroups arguments passed to make_tensors().
MODEL_PRESETS = {
    "small": {"nheads": 32, "headdim": 64, "dstate": 64, "ngroups": 1},
    "nemotronh-56b": {"nheads": 256, "headdim": 64, "dstate": 256, "ngroups": 8},
}


def _reset_peak_memory() -> None:
    """Free unreferenced/cached CUDA memory, reset the peak counter, and synchronize."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()


def _peak_memory_mb(fn, *, warmup: int = 3) -> float:
    """Return the peak CUDA memory allocated (in MiB) while running ``fn`` once.

    ``fn`` is first called ``warmup`` times before the peak counter is reset,
    presumably so one-time work (e.g. Triton compilation) is excluded. The value
    is the absolute peak reported by ``torch.cuda.max_memory_allocated``, so it
    includes tensors that were already live before the measured call (such as
    the benchmark inputs), not just the kernel's extra allocations.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    _reset_peak_memory()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / (1024 * 1024)


def make_tensors(*, batch: int, seqlen: int, nheads: int, headdim: int, dstate: int, ngroups: int,
                 chunk_size: int, dtype: torch.dtype = torch.bfloat16) -> dict[str, torch.Tensor]:
    """Allocate random CUDA inputs for every benchmarked kernel.

    ``dtype`` sets the dtype of x, B, C, dout, dt_raw and cb; all other tensors
    are float32. The sequence is split into ``nchunks = ceil(seqlen / chunk_size)``
    chunks for the chunked tensors. Values are random (``A`` is negated);
    presumably only shapes/dtypes matter for timing.

    Returns:
        dict mapping the names read by get_benchmarks() to tensors on "cuda".
    """
    device = "cuda"
    nchunks = math.ceil(seqlen / chunk_size)
    return {
        "x": torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype),
        "B": torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype),
        "C": torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype),
        "dt": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32),
        "dA_cumsum": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32),
        "dstates": torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32),
        "dout": torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype),
        "ddA": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32),
        "ddt_out": torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32),
        "dt_raw": torch.randn(batch, seqlen, nheads, device=device, dtype=dtype),
        "A": torch.randn(nheads, device=device, dtype=torch.float32) * -1,
        "dt_bias": torch.randn(nheads, device=device, dtype=torch.float32),
        "prev_states": torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32),
        "cb": torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=device, dtype=dtype),
    }


def get_benchmarks(t: dict[str, torch.Tensor], *, ngroups: int):
    """Build the list of ``(name, fn)`` pairs to benchmark.

    Each ``fn`` is a zero-argument lambda that invokes one backward kernel on
    tensors from ``t`` (as produced by make_tensors). The kernels are imported
    here, at call time. ``ngroups`` is forwarded to the kernels that accept it;
    main() passes the same value it used to build ``t``.
    """
    from mamba_ssm.ops.triton.ssd_chunk_state import (
        _chunk_cumsum_bwd,
        _chunk_state_bwd_db,
        _chunk_state_bwd_ddAcs_stable,
        _chunk_state_bwd_dx,
    )
    from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dx
    from mamba_ssm.ops.triton.ssd_combined import _chunk_scan_chunk_state_bwd_dx

    # Make the inputs shared by several kernels contiguous once, outside the
    # timed lambdas.
    x = t["x"].contiguous()
    B = t["B"].contiguous()
    C = t["C"].contiguous()
    dout = t["dout"].contiguous()
    dstates = t["dstates"].contiguous()

    return [
        ("chunk_cumsum_bwd", lambda: _chunk_cumsum_bwd(t["ddA"], t["ddt_out"], t["dt_raw"], t["A"], dt_bias=t["dt_bias"], dt_softplus=True)),
        ("chunk_state_bwd_dx", lambda: _chunk_state_bwd_dx(B, x, t["dt"], t["dA_cumsum"], dstates)),
        ("chunk_state_bwd_db", lambda: _chunk_state_bwd_db(x, t["dt"], t["dA_cumsum"], dstates, B=B, ngroups=ngroups)),
        ("chunk_state_bwd_ddAcs", lambda: _chunk_state_bwd_ddAcs_stable(B, x, t["dt"], t["dA_cumsum"], dstates)),
        ("chunk_scan_bwd_dC", lambda: _chunk_scan_bwd_dC(t["prev_states"], t["dA_cumsum"], dout, C=C, ngroups=ngroups)),
        ("chunk_scan_bwd_dx", lambda: _chunk_scan_bwd_dx(t["cb"], x, t["dt"], t["dA_cumsum"], dout)),
        ("combined_bwd_dx", lambda: _chunk_scan_chunk_state_bwd_dx(x, t["dt"], t["dA_cumsum"], B, t["cb"], dout, dstates)),
    ]


def _run_one(fn, *, deterministic: bool, warmup: int, rep: int):
    """Measure ``fn`` with deterministic mode set to ``deterministic``.

    Returns ``(ms, peak_mb)``: the median runtime from triton's ``do_bench`` and
    the peak allocated memory from a separate measurement with one warm-up call.
    The deterministic-mode setting is left in place; main() resets it afterwards.
    """
    set_deterministic_mode(deterministic)
    ms = do_bench(fn, warmup=warmup, rep=rep, return_mode="median")
    peak_mb = _peak_memory_mb(fn, warmup=1)
    return ms, peak_mb


def main() -> None:
    """CLI entry point: parse options, build inputs, and print the comparison table.

    Per kernel the columns are: default-mode ms, deterministic-mode ms, their %
    change, then the same three for peak MB. The final row sums runtimes and
    takes the maximum peak memory across kernels.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Benchmark determinism overhead for key Triton backward kernels")
    parser.add_argument("--preset", choices=sorted(MODEL_PRESETS.keys()), default="small")
    parser.add_argument("--warmup", type=int, default=25)
    parser.add_argument("--rep", type=int, default=100)
    parser.add_argument("--batch", type=int, default=4)
    parser.add_argument("--seqlen", type=int, default=2048)
    parser.add_argument("--chunk-size", type=int, default=256)
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise SystemExit("CUDA not available")

    p = MODEL_PRESETS[args.preset]
    tensors = make_tensors(
        batch=args.batch,
        seqlen=args.seqlen,
        nheads=p["nheads"],
        headdim=p["headdim"],
        dstate=p["dstate"],
        ngroups=p["ngroups"],
        chunk_size=args.chunk_size,
    )
    benches = get_benchmarks(tensors, ngroups=p["ngroups"])

    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"preset={args.preset} batch={args.batch} seqlen={args.seqlen} chunk_size={args.chunk_size}")
    print(f"{'kernel':<20} {'ms':>9} {'det_ms':>9} {'ms_%':>6} {'MB':>9} {'det_MB':>9} {'MB_%':>6}")

    rows = []
    try:
        for name, fn in benches:
            ms, mb = _run_one(fn, deterministic=False, warmup=args.warmup, rep=args.rep)
            det_ms, det_mb = _run_one(fn, deterministic=True, warmup=args.warmup, rep=args.rep)
            # Relative change of deterministic vs default mode; positive = slower / more memory.
            # NOTE(review): unlike mb_pct, ms_pct is not guarded against a zero baseline.
            ms_pct = (det_ms / ms - 1.0) * 100.0
            mb_pct = (det_mb / mb - 1.0) * 100.0 if mb else 0.0
            rows.append((name, ms, det_ms, ms_pct, mb, det_mb, mb_pct))
            print(f"{name:<20} {ms:>9.3f} {det_ms:>9.3f} {ms_pct:>+6.0f}% {mb:>9.1f} {det_mb:>9.1f} {mb_pct:>+6.0f}%")
    finally:
        # Always clear the override set by _run_one, even if a kernel fails.
        # NOTE(review): None presumably restores the default (non-forced) mode;
        # confirm in mamba_ssm.utils.determinism.
        set_deterministic_mode(None)

    # Summary row: total runtime across kernels and the largest peak memory.
    total_ms = sum(r[1] for r in rows)
    total_det_ms = sum(r[2] for r in rows)
    max_mb = max(r[4] for r in rows) if rows else 0.0
    max_det_mb = max(r[5] for r in rows) if rows else 0.0
    total_pct = (total_det_ms / total_ms - 1.0) * 100.0 if total_ms else 0.0
    max_mb_pct = (max_det_mb / max_mb - 1.0) * 100.0 if max_mb else 0.0
    print(f"{'TOTAL/MAX':<20} {total_ms:>9.3f} {total_det_ms:>9.3f} {total_pct:>+6.0f}% {max_mb:>9.1f} {max_det_mb:>9.1f} {max_mb_pct:>+6.0f}%")


if __name__ == "__main__":
    main()

================================================
FILE: tests/ops/cute/test_mamba3_mimo_step.py
================================================
"""
Mamba-3 MIMO Step Function Tests
Copyright (c) 2026, Dao AI Lab, Goombalab
Pytest coverage for Mamba3.step() and mixed forward/step decoding.

Two decoding paths are compared against an fp32 forward() reference over the
full sequence: step() applied token by token, and forward() over a prefix
followed by step() for the remaining tokens.

Usage:
    pytest -q -s -p no:warnings tests/ops/cute/test_mamba3_mimo_step.py  # For correctness tests
    python tests/ops/cute/test_mamba3_mimo_step.py  # For benchmark

Remove the -s flag for less verbose output.
"""

import logging
import sys
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import pytest
import torch
from torch import Tensor

# Suppress Python warnings and all logging at WARNING level and below.
warnings.filterwarnings("ignore")
logging.disable(logging.WARNING)

# Problem size, dtype/device, and comparison tolerances shared by the test and
# the benchmark. MIMO_DIM is passed as Mamba3's mimo_rank.
BATCH = 128
SEQLEN = 32
NHEADS = 64
HDIM = 64
DSTATE = 128
MIMO_DIM = 4
USE_TILELANG = True
DTYPE = torch.bfloat16
DEVICE = "cuda"
RTOL = 0.1
ATOL = 0.1


def _require_cuda_and_kernel_deps() -> None:
    """Skip unless CUDA is available and the tilelang and triton packages import.

    NOTE(review): run_step_benchmark() also calls this outside of a pytest run,
    where pytest.skip / importorskip raise pytest's Skipped exception instead.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required for mamba3 step tests")
    pytest.importorskip("tilelang")
    pytest.importorskip("triton")


def _mamba3_cls():
    """Import and return the Mamba3 module class at call time rather than at module import."""
    from mamba_ssm.modules.mamba3 import Mamba3

    return Mamba3


@pytest.fixture(scope="module", autouse=True)
def _kernel_deps() -> None:
    # Module-scoped autouse fixture: run the dependency check once for this module.
    _require_cuda_and_kernel_deps()


# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
@dataclass
class InferenceParams:
    """Inference parameters used to store context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    # Per-layer decoding state; _run_case reads the entry keyed by layer_idx
    # after the prefix forward() pass.
    key_value_memory_dict: dict = field(default_factory=dict)
    # Not referenced in this file; presumably read by Mamba3 internals.
    new_key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        """Set new capacity limits, rewind seqlen_offset to 0, and zero lengths_per_sample if present."""
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()


@dataclass
class RunOutputs:
    """Outputs of _run_case, compared by test_step_matches_forward_fp32."""

    # Human-readable configuration string used in assertion messages.
    config_label: str
    # Number of prefix tokens processed by forward() before switching to step().
    split: int
    # fp32 forward() output over the full sequence (the reference).
    out_fwd_fp32: Tensor
    # step() outputs for every token, stacked along dim 1.
    outputs_step: Tensor
    # forward() output for the first `split` tokens.
    prefix_out: Tensor
    # prefix_out followed by step() outputs for tokens split..SEQLEN-1 (dim 1).
    outputs_mixed: Tensor


def _case_config(*, is_outproj_norm: bool) -> dict:
    """Return the Mamba3 constructor kwargs for one test/benchmark configuration."""
    # NOTE(review): d_model is half of NHEADS * HDIM; presumably Mamba3 expands
    # by 2 so that d_inner == NHEADS * HDIM; confirm against Mamba3's defaults.
    d_model = NHEADS * HDIM // 2
    return {
        "d_model": d_model,
        "d_state": DSTATE,
        "headdim": HDIM,
        "is_mimo": True,
        "mimo_rank": MIMO_DIM,
        "chunk_size": 64 // MIMO_DIM,
        "dtype": DTYPE,
        "device": DEVICE,
        "layer_idx": 0,
        "use_tilelang": USE_TILELANG,
        "is_outproj_norm": is_outproj_norm,
    }


def _diff_stats(actual: Tensor, expected: Tensor) -> str:
    """Format the max and mean absolute difference (computed in fp32) for error messages."""
    diff = (actual.float() - expected.float()).abs()
    return f"max_abs={diff.max().item():.6e}, mean_abs={diff.mean().item():.6e}"


def _assert_close(
    actual: Tensor,
    expected: Tensor,
    *,
    label: str,
    cfg: str,
    step: Optional[int] = None,
) -> None:
    """Assert ``actual`` ~= ``expected`` in fp32 using the module RTOL/ATOL.

    On failure, re-raises AssertionError (chained to the original) with the
    label, configuration string, optional step index, and diff statistics.
    """
    try:
        torch.testing.assert_close(
            actual.float(),
            expected.float(),
            rtol=RTOL,
            atol=ATOL,
        )
    except AssertionError as err:
        location = f", step={step}" if step is not None else ""
        stats = _diff_stats(actual, expected)
        raise AssertionError(
            f"{label} assertion failed for {cfg}{location} ({stats})"
        ) from err


def _run_case(*, is_outproj_norm: bool) -> RunOutputs:
    """Run one Mamba3 configuration through the reference, pure-step and mixed paths.

    Four Mamba3 instances are constructed under the same seed. The DTYPE model
    ``model_fwd`` is the weight source: its state dict is loaded with
    ``strict=False`` into an fp32 reference copy (values cast with ``.float()``)
    and into two DTYPE copies used for stepping.

    * ``out_fwd_fp32``: fp32 ``forward()`` over the full sequence.
    * ``outputs_step``: ``step()`` token by token, starting from a freshly
      allocated inference cache.
    * ``outputs_mixed``: ``forward()`` over the first ``split`` tokens with an
      ``InferenceParams`` cache, then ``step()`` for the remaining tokens, starting
      from the state forward() stored under ``layer_idx``.
    """
    Mamba3 = _mamba3_cls()
    cfg = _case_config(is_outproj_norm=is_outproj_norm)
    config_label = (
        f"use_tilelang={cfg['use_tilelang']}, "
        f"is_outproj_norm={cfg['is_outproj_norm']}, "
        f"batch={BATCH}, seqlen={SEQLEN}, "
        f"nheads={NHEADS}, hdim={HDIM}, dstate={DSTATE}, mimo_dim={MIMO_DIM}"
    )

    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model_fwd = Mamba3(**cfg)
    model_fwd.eval()

    cfg_fp32 = {**cfg, "dtype": torch.float32}
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model_fwd_fp32 = Mamba3(**cfg_fp32)
    model_fwd_fp32.eval()
    model_fwd_fp32.load_state_dict(
        {k: v.float() for k, v in model_fwd.state_dict().items()},
        strict=False,
    )

    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model_step = Mamba3(**cfg)
    model_step.eval()
    model_step.load_state_dict(model_fwd.state_dict(), strict=False)

    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model_mix = Mamba3(**cfg)
    model_mix.eval()
    model_mix.load_state_dict(model_fwd.state_dict(), strict=False)

    u = torch.randn(BATCH, SEQLEN, cfg["d_model"], device=DEVICE, dtype=DTYPE)

    # NOTE(review): the source was flattened, so the extent of this no_grad block
    # is reconstructed; all three decoding paths are placed under it here.
    # Confirm against the original file.
    with torch.no_grad():
        out_fwd_fp32 = model_fwd_fp32(u.float())

        # Pure-step path: the state is the 4-tuple returned by step(), fed back
        # in on the next token.
        state = model_step.allocate_inference_cache(BATCH, 1, device=DEVICE, dtype=DTYPE)
        outputs_step = []
        for t in range(SEQLEN):
            out_step, nxt_angle_state, state_out, nxt_k_state, nxt_v_state = model_step.step(
                u[:, t], *state
            )
            state = (nxt_angle_state, state_out, nxt_k_state, nxt_v_state)
            outputs_step.append(out_step)
        outputs_step = torch.stack(outputs_step, dim=1)

        # Mixed path: forward() over the prefix populates the cache, then step()
        # continues from that cached state.
        split = SEQLEN // 2
        assert 0 < split < SEQLEN
        inference_params = InferenceParams(max_seqlen=SEQLEN, max_batch_size=BATCH)
        prefix_out = model_mix(u[:, :split], inference_params=inference_params)
        state = inference_params.key_value_memory_dict[model_mix.layer_idx]
        mixed_suffix = []
        for t in range(split, SEQLEN):
            out_step, nxt_angle_state, state_out, nxt_k_state, nxt_v_state = model_mix.step(
                u[:, t], *state
            )
            state = (nxt_angle_state, state_out, nxt_k_state, nxt_v_state)
            mixed_suffix.append(out_step)
        outputs_mixed = torch.cat([prefix_out, torch.stack(mixed_suffix, dim=1)], dim=1)

    return RunOutputs(
        config_label=config_label,
        split=split,
        out_fwd_fp32=out_fwd_fp32,
        outputs_step=outputs_step,
        prefix_out=prefix_out,
        outputs_mixed=outputs_mixed,
    )


@pytest.mark.parametrize(
    "is_outproj_norm",
    [
        pytest.param(False, id="outproj_norm_false"),
        pytest.param(True, id="outproj_norm_true"),
    ],
)
def test_step_matches_forward_fp32(is_outproj_norm: bool) -> None:
    """Check pure-step and mixed forward/step outputs against the fp32 forward() reference."""
    outputs = _run_case(is_outproj_norm=is_outproj_norm)

    # Pure-step path, token by token.
    for t in range(SEQLEN):
        _assert_close(
            outputs.outputs_step[:, t],
            outputs.out_fwd_fp32[:, t],
            label="pure-step",
            cfg=outputs.config_label,
            step=t,
        )

    # Mixed path: the forward() prefix as a whole, then each step() token.
    _assert_close(
        outputs.prefix_out,
        outputs.out_fwd_fp32[:, :outputs.split],
        label="mixed-prefix",
        cfg=outputs.config_label,
    )
    for t in range(outputs.split, SEQLEN):
        _assert_close(
            outputs.outputs_mixed[:, t],
            outputs.out_fwd_fp32[:, t],
            label="mixed-suffix",
            cfg=outputs.config_label,
            step=t,
        )


def run_step_benchmark(*, is_outproj_norm: bool) -> None:
    """Time one ``Mamba3.step()`` call with CUDA graphs and print estimated bandwidth.

    Timing uses ``triton.testing.do_bench_cudagraph`` (rep=30). The byte counts
    are an analytic estimate written out below, not a measurement. Intended to be
    run as a script (see the ``__main__`` block).
    """
    _require_cuda_and_kernel_deps()
    from triton.testing import do_bench_cudagraph

    Mamba3 = _mamba3_cls()
    cfg = _case_config(is_outproj_norm=is_outproj_norm)
    # Label only: reported as "halved" when USE_TILELANG is set, else "pairwise".
    rotate_str = "halved" if USE_TILELANG else "pairwise"

    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    model_step = Mamba3(**cfg)
    model_step.eval()
    state_bm = model_step.allocate_inference_cache(BATCH, 1, device=DEVICE, dtype=DTYPE)
    u_step_bm = torch.randn(BATCH, cfg["d_model"], device=DEVICE, dtype=DTYPE)

    # Warm-up call before timing.
    # NOTE(review): the source was flattened; only this call is placed under
    # no_grad here. Confirm whether full_step_fn/timing were also inside it.
    with torch.no_grad():
        model_step.step(u_step_bm, *state_bm)

    def full_step_fn():
        out, _, _, _, _ = model_step.step(u_step_bm, *state_bm)
        return out

    ms_full = do_bench_cudagraph(full_step_fn, rep=30)

    # Estimated memory traffic. The two terms scaled by state_dtype_size are
    # counted at 4 bytes/element; the rest at DTYPE's element size.
    dtype_size = torch.tensor([], dtype=DTYPE).element_size()
    state_dtype_size = 4
    num_rope_angles = model_step.num_rope_angles
    # Terms presumably correspond to: input u, SSM state, angle state, k state,
    # v state (verify against Mamba3.step).
    bytes_read = (
        BATCH * cfg["d_model"] * dtype_size
        + BATCH * NHEADS * HDIM * DSTATE * state_dtype_size
        + BATCH * NHEADS * num_rope_angles * state_dtype_size
        + BATCH * MIMO_DIM * NHEADS * DSTATE * dtype_size
        + BATCH * NHEADS * HDIM * dtype_size
    )
    # NOTE(review): the first write term uses NHEADS * HDIM (= 2 * d_model with
    # _case_config) rather than d_model; confirm which tensor it is meant to count.
    bytes_write = (
        BATCH * NHEADS * HDIM * dtype_size
        + BATCH * NHEADS * HDIM * DSTATE * state_dtype_size
        + BATCH * NHEADS * num_rope_angles * state_dtype_size
        + BATCH * MIMO_DIM * NHEADS * DSTATE * dtype_size
        + BATCH * NHEADS * HDIM * dtype_size
    )
    total_bytes = bytes_read + bytes_write
    # ms -> s, bytes -> GB (1e9).
    bw = total_bytes / (ms_full * 1e-3) / 1e9

    print("\n" + "=" * 70)
    print(
        "Benchmark: Mamba3.step() "
        f"(rotation={rotate_str}, is_outproj_norm={is_outproj_norm})"
    )
    print("=" * 70)
    print(
        f" batch={BATCH}, d_model={cfg['d_model']}, nheads={NHEADS}, "
        f"hdim={HDIM}, dstate={DSTATE}, mimo_dim={MIMO_DIM}"
    )
    print(f" Time per step: {ms_full:.4f} ms")
    print(
        " Memory I/O: "
        f"{total_bytes / 1e6:.2f} MB "
        f"(Read: {bytes_read / 1e6:.2f} MB, Write: {bytes_write / 1e6:.2f} MB)"
    )
    print(f" Bandwidth: {bw:.1f} GB/s")


if __name__ == "__main__":
    run_step_benchmark(is_outproj_norm=False)
    run_step_benchmark(is_outproj_norm=True)

================================================
FILE: tests/ops/test_selective_scan.py
================================================
# 
Copyright (C) 2023, Tri Dao. import math import torch import torch.nn.functional as F import pytest from einops import rearrange from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref # @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) @pytest.mark.parametrize('wtype', [torch.float32]) # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize('itype', [torch.float32]) # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) @pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) # @pytest.mark.parametrize('seqlen', [128]) # @pytest.mark.parametrize("return_last_state", [False, True]) @pytest.mark.parametrize("return_last_state", [True]) # @pytest.mark.parametrize('has_delta_bias', [False, True]) @pytest.mark.parametrize('has_delta_bias', [True]) # @pytest.mark.parametrize('delta_softplus', [False, True]) @pytest.mark.parametrize('delta_softplus', [True]) # @pytest.mark.parametrize('has_z', [False, True]) @pytest.mark.parametrize('has_z', [True]) # @pytest.mark.parametrize('has_D', [False, True]) @pytest.mark.parametrize('has_D', [True]) @pytest.mark.parametrize("varBC_groups", [1, 2]) # @pytest.mark.parametrize("varBC_groups", [1]) # @pytest.mark.parametrize("is_variable_C", [False, True]) @pytest.mark.parametrize("is_variable_C", [True]) # @pytest.mark.parametrize("is_variable_B", [False, True]) @pytest.mark.parametrize("is_variable_B", [True]) def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, delta_softplus, return_last_state, seqlen, itype, wtype): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable device = 'cuda' rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: 
rtol, atol = 3e-2, 5e-2 rtolw, atolw = (1e-3, 1e-3) if has_z: # If we have z, the errors on the weights seem higher rtolw = max(rtolw, rtol) atolw = max(atolw, atol) # set seed torch.random.manual_seed(0) batch_size = 2 dim = 4 dstate = 8 is_complex = wtype == torch.complex64 A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() if not is_variable_B: B_shape = (dim, dstate) elif varBC_groups == 1: B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) else: B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, requires_grad=True) if not is_variable_C: C_shape = (dim, dstate) elif varBC_groups == 1: C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2) else: C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2) C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, requires_grad=True) if has_D: D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) else: D = None if has_z: z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) else: z = None if has_delta_bias: delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() else: delta_bias = None u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True) delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_() A_ref = A.detach().clone().requires_grad_() B_ref = B.detach().clone().requires_grad_() C_ref = C.detach().clone().requires_grad_() D_ref = D.detach().clone().requires_grad_() if D is not None else None z_ref = z.detach().clone().requires_grad_() if z is not None else None u_ref = u.detach().clone().requires_grad_() delta_ref = delta.detach().clone().requires_grad_() delta_bias_ref = 
delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None out, *rest = selective_scan_fn( u, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus, return_last_state=return_last_state ) if return_last_state: state = rest[0] out_ref, *rest = selective_scan_ref( u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, delta_bias=delta_bias_ref, delta_softplus=delta_softplus, return_last_state=return_last_state ) if return_last_state: state_ref = rest[0] # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) # dt_u = delta * u print(f'Output max diff: {(out - out_ref).abs().max().item()}') print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) if return_last_state: print(f'State max diff: {(state - state_ref).abs().max().item()}') assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) g = torch.randn_like(out) out_ref.backward(g) out.backward(g) print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}') print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}') print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') if has_D: print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') if has_z: print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}') if has_delta_bias: print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, atol=atolw if not is_variable_B else atol) assert torch.allclose(C.grad, 
C_ref.grad, rtol=rtolw if not is_variable_C else rtol, atol=atolw if not is_variable_C else atol) if has_D: assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) if has_z: assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) if has_delta_bias: assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64]) # @pytest.mark.parametrize('wtype', [torch.complex64]) # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize('itype', [torch.float32]) # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096]) @pytest.mark.parametrize('seqlen', [128]) @pytest.mark.parametrize("is_variable_C", [False, True]) # @pytest.mark.parametrize("is_variable_C", [False]) @pytest.mark.parametrize("is_variable_B", [False, True]) # @pytest.mark.parametrize("is_variable_B", [True]) def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype): device = 'cuda' rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 rtolw, atolw = (1e-3, 1e-3) # If we have z, the errors on the weights seem higher rtolw = max(rtolw, rtol) atolw = max(atolw, atol) # set seed torch.random.manual_seed(0) batch_size = 2 dim = 768 dstate = 8 dt_rank = 48 is_complex = wtype == torch.complex64 xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True) conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True) conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate * (1 if not is_complex else 2), dim, device=device, dtype=itype, requires_grad=True) delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True) 
out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True) out_proj_bias = None A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) if not is_variable_B else None) C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True) if not is_variable_C else None) D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() B_proj_bias = None C_proj_bias = None xz_ref = xz.detach().clone().requires_grad_() conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_() conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_() x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_() delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_() out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_() out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_() if out_proj_bias is not None else None) A_ref = A.detach().clone().requires_grad_() B_ref = B.detach().clone().requires_grad_() if B is not None else None C_ref = C.detach().clone().requires_grad_() if C is not None else None D_ref = D.detach().clone().requires_grad_() delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B, C, D, delta_bias=delta_bias, delta_softplus=True) out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref, delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref, A_ref, B_ref, C_ref, D_ref, delta_bias=delta_bias_ref, delta_softplus=True) # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) # dt_u = delta * u print(f'Output max diff: {(out - out_ref).abs().max().item()}') 
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) g = torch.randn_like(out) out_ref.backward(g) out.backward(g) print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}') print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') if not is_variable_B: print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') if not is_variable_C: print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}') print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}') print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}') print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}') print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}') # assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) # assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) # assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) # assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, # atol=atolw if not is_variable_B else atol) # assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, # atol=atolw if not is_variable_C else atol) # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) ================================================ FILE: tests/ops/tilelang/test_mamba3_mimo.py 
================================================ """ Mamba-3 MIMO Kernel Tests Copyright (c) 2026, Dao AI Lab, Goombalab Usage: pytest -q -s -p no:warnings tests/ops/tilelang/test_mamba3_mimo.py -k bwd pytest -q -s -p no:warnings tests/ops/tilelang/test_mamba3_mimo.py -k fwd pytest -q -s -p no:warnings tests/ops/tilelang/test_mamba3_mimo.py -k smoke pytest -q -s -p no:warnings tests/ops/tilelang/test_mamba3_mimo.py -k chunk_ref_matches_step_ref Remove the -s flag for less verbose output. """ import sys from pathlib import Path from types import SimpleNamespace import math from typing import Optional, Tuple from einops import rearrange, repeat import pytest import torch from torch import Tensor F = torch.nn.functional FIXED_B = 4 FIXED_S = 2048 FIXED_H = 16 FIXED_G = 1 FIXED_ROTARY_DIM_DIVISOR = 4 FIXED_DTYPE = torch.bfloat16 REL_TOL = 0.10 CASE_GRID = [ pytest.param(16, 64, 4, 8, 128, id="N16_P64_R4_C8_BB128"), pytest.param(32, 64, 4, 16, 256, id="N32_P64_R4_C16_BB256"), pytest.param(64, 64, 4, 16, 256, id="N64_P64_R4_C16_BB256"), pytest.param(128, 64, 4, 16, 256, id="N128_P64_R4_C16_BB256"), pytest.param(256, 64, 4, 8, 256, id="N256_P64_R4_C8_BB256"), pytest.param(64, 128, 4, 16, 256, id="N64_P128_R4_C16_BB256"), pytest.param(128, 32, 4, 16, 256, id="N128_P32_R4_C16_BB256"), pytest.param(128, 128, 4, 8, 256, id="N128_P128_R4_C8_BB256"), pytest.param(128, 64, 8, 8, 256, id="N128_P64_R8_C8_BB256"), pytest.param(128, 64, 2, 32, 256, id="N128_P64_R2_C32_BB256"), pytest.param(128, 64, 1, 64, 256, id="N128_P64_R1_C64_BB256"), ] def _require_cuda_and_kernel_deps() -> None: if not torch.cuda.is_available(): pytest.skip("CUDA is required for mamba3 tilelang tests") pytest.importorskip("tilelang") pytest.importorskip("triton") @pytest.fixture(scope="module") def mods() -> SimpleNamespace: _require_cuda_and_kernel_deps() import mamba_ssm.ops.tilelang.mamba3.mamba3_mimo as mamba3_top import mamba_ssm.ops.tilelang.mamba3.mamba3_mimo_bwd as mamba3_bwd import 
mamba_ssm.ops.tilelang.mamba3.mamba3_mimo_fwd as mamba3_fwd import mamba_ssm.ops.triton.mamba3.mamba3_mimo_utils as mamba3_mimo_utils return SimpleNamespace( top=mamba3_top, bwd=mamba3_bwd, fwd=mamba3_fwd, utils=mamba3_mimo_utils, ) def max_rel_err(ours: Tensor, ref: Tensor, eps: float = 1e-5) -> float: ours_f = ours.float() ref_f = ref.float() num = (ours_f - ref_f).abs().max() den = ref_f.abs().max().clamp_min(eps) return float((num / den).item()) def assert_stable_rel( ours: Tensor, ref: Tensor, *, label: str, cfg: str, rel_tol: float = REL_TOL, ) -> None: ours_f = ours.float() ref_f = ref.float() rel = max_rel_err(ours_f, ref_f) close_mask = torch.isclose(ours_f, ref_f, rtol=REL_TOL, atol=0.1) bad_frac = float((~close_mask).float().mean().item()) max_abs = float((ours_f - ref_f).abs().max().item()) print( f"[debug] {label} ({cfg}) " f"stable_max_rel={rel:.6f} max_abs={max_abs:.6e} " f"bad_frac(rtol=0.1,atol=0.1)={bad_frac:.6f}" ) if rel < rel_tol: return raise AssertionError( f"{label} stable_max_rel >= {rel_tol} for {cfg}: " f"stable_max_rel={rel:.6f}, max_abs={max_abs:.6e}, " f"diag_bad_frac_at_rtol0.1_atol0.1={bad_frac:.6f}" ) def build_inputs( *, mods: SimpleNamespace, n: int, p: int, r: int, chunk_size: int, seed: int, b: int = FIXED_B, s: int = FIXED_S, h: int = FIXED_H, g: int = FIXED_G, dtype: torch.dtype = FIXED_DTYPE, has_z: bool = True, has_d: bool = True, rotary_dim_divisor: int = FIXED_ROTARY_DIM_DIVISOR, ) -> dict: assert s % chunk_size == 0 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) q = torch.randn((b, s, r, g, n), device="cuda", dtype=dtype) k = torch.randn((b, s, r, g, n), device="cuda", dtype=dtype) v = torch.randn((b, s, h, p), device="cuda", dtype=dtype) q_bias = torch.randn((h, r, n), device="cuda", dtype=torch.float32) k_bias = torch.randn((h, r, n), device="cuda", dtype=torch.float32) mimo_v = torch.randn((h, r, p), device="cuda", dtype=torch.float32) / r mimo_o = torch.randn((h, r, p), device="cuda", dtype=torch.float32) / 
r z = torch.randn_like(v) if has_z else None mimo_z = torch.randn_like(mimo_v) if has_z else None d = torch.randn((h,), device="cuda", dtype=torch.float32) if has_d else None angles = torch.rand( (b, s, h, n // rotary_dim_divisor), device="cuda", dtype=torch.float32 ) dt = F.softplus(-3.0 + torch.randn((b, h, s), device="cuda", dtype=torch.float32)) a = torch.rand((b, h, s), device="cuda", dtype=torch.float32) dA = (-dt * a).detach() dA_cs, dA_cs_rev, segsum = mods.utils.compute_dacs_segsum_triton(dA, chunk_size) trap = torch.rand((b, h, s), device="cuda", dtype=dtype) dout = torch.randn_like(v) return { "q": q, "k": k, "v": v, "q_bias": q_bias, "k_bias": k_bias, "mimo_v": mimo_v, "mimo_o": mimo_o, "z": z, "mimo_z": mimo_z, "D": d, "angles": angles, "dt": dt, "dA": dA, "dA_cs": dA_cs, "dA_cs_rev": dA_cs_rev, "segsum": segsum, "trap": trap, "dout": dout, "chunk_size": chunk_size, "rotary_dim_divisor": rotary_dim_divisor, } def make_smoke_inputs( *, batch: int = 1, seqlen: int = 64, mimo_rank: int = 4, nheads_qk: int = 1, nheads: int = 8, headdim_qk: int = 64, headdim_v: int = 32, chunk_size: int = 16, rotary_dim_divisor: int = 4, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, seed: int = 0, ): torch.manual_seed(seed) if device == "cuda": torch.cuda.manual_seed_all(seed) Q = torch.randn( (batch, seqlen, mimo_rank, nheads_qk, headdim_qk), device=device, dtype=dtype, requires_grad=True, ) K = torch.randn_like(Q, requires_grad=True) V = torch.randn( (batch, seqlen, nheads, headdim_v), device=device, dtype=dtype, requires_grad=True, ) import torch.nn.functional as F import math DT = F.softplus( -3.0 + torch.randn( batch, nheads, seqlen, device=device, dtype=torch.float, ) ).detach().requires_grad_(True) # Make ADT a leaf so .grad is populated without retain_grad(). 
ADT = (-DT.detach() * math.log2(math.e)).clone().detach().requires_grad_(True) Trap = ( torch.rand( (batch, nheads, seqlen), device=device, dtype=dtype, ) * 0.5 ).detach().requires_grad_(True) Q_bias = torch.randn( (nheads, mimo_rank, headdim_qk), device=device, dtype=torch.float32, requires_grad=True, ) K_bias = torch.randn_like(Q_bias, requires_grad=True) MIMO_V = torch.randn( (nheads, mimo_rank, headdim_v), device=device, dtype=torch.float32, requires_grad=True, ) MIMO_Z = (torch.randn_like(MIMO_V) / mimo_rank).detach().requires_grad_(True) MIMO_Out = (torch.randn_like(MIMO_V) / mimo_rank).detach().requires_grad_(True) Angles = torch.rand( (batch, seqlen, nheads, headdim_qk // rotary_dim_divisor), device=device, dtype=torch.float32, requires_grad=True, ) D = torch.randn( (nheads,), device=device, dtype=torch.float32, requires_grad=True, ) Z = torch.randn( (batch, seqlen, nheads, headdim_v), device=device, dtype=dtype, requires_grad=True, ) return dict( Q=Q, K=K, V=V, ADT=ADT, DT=DT, Trap=Trap, Q_bias=Q_bias, K_bias=K_bias, MIMO_V=MIMO_V, MIMO_Z=MIMO_Z, MIMO_Out=MIMO_Out, Angles=Angles, D=D, Z=Z, chunk_size=chunk_size, rotary_dim_divisor=rotary_dim_divisor, dtype=dtype, ) def grads_to_dA(grad_dA_cs: Tensor, grad_dA_cs_rev: Tensor, chunk_size: int) -> Tensor: b, h, s = grad_dA_cs.shape assert s % chunk_size == 0 nchunks = s // chunk_size g_f = grad_dA_cs.view(b, h, nchunks, chunk_size) grad_from_f = torch.flip(torch.cumsum(torch.flip(g_f, dims=[-1]), dim=-1), dims=[-1]) g_r = grad_dA_cs_rev.view(b, h, nchunks, chunk_size) prefix = torch.cumsum(g_r, dim=-1) grad_from_r = torch.cat([torch.zeros_like(prefix[..., :1]), prefix[..., :-1]], dim=-1) return (grad_from_f + grad_from_r).view(b, h, s) def mamba3_MIMO_step_ref( Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, ADT: torch.Tensor, DT: torch.Tensor, Trap: torch.Tensor, Q_bias: torch.Tensor, K_bias: torch.Tensor, Angles: torch.Tensor, MIMO_V: torch.Tensor, MIMO_O: torch.Tensor, D: Optional[torch.Tensor] = None, 
Z: Optional[torch.Tensor] = None, MIMO_Z: Optional[torch.Tensor] = None, Input_States: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: """Reference implementation of Mamba-3 MIMO in recurrent (step) mode. Args: Input_States: Optional tuple of (Angle_State, SSM_State, K_State, V_State) Returns: out: Output tensor (batch, seqlen, nheads, headdim_v) Final_States: Tuple of (Angle_State, SSM_State, K_State, V_State) """ batch, seqlen, mimo_rank, nheads_qk, headdim_qk = Q.shape _, _, nheads, headdim_v = V.shape headdim_angles = Angles.shape[-1] device = Q.device assert seqlen > 0 # Expand Q/K for GQA if Q.shape[3] != V.shape[2]: Q = repeat(Q, "b s r h_bc d -> b s r (h_bc g) d", g=V.shape[2] // Q.shape[3]) if K.shape[3] != V.shape[2]: K = repeat(K, "b s r h_bc d -> b s r (h_bc g) d", g=V.shape[2] // K.shape[3]) def apply_rotary_emb(tensor, cos, sin): tensor_reshaped = tensor.view(*tensor.shape[:-1], -1, 2) tensor_0 = tensor_reshaped[..., 0] tensor_1 = tensor_reshaped[..., 1] if cos.shape[-1] < tensor_0.shape[-1]: pad_size = tensor_0.shape[-1] - cos.shape[-1] cos = F.pad(cos, (0, pad_size), value=1.0) sin = F.pad(sin, (0, pad_size), value=0.0) rotated_0 = tensor_0 * cos - tensor_1 * sin rotated_1 = tensor_0 * sin + tensor_1 * cos rotated = torch.stack([rotated_0, rotated_1], dim=-1).view_as(tensor) return rotated q_bias = rearrange(Q_bias, "h r d -> r h d") k_bias = rearrange(K_bias, "h r d -> r h d") # Initialize states if Input_States is not None: Angle_State, SSM_State, K_State, V_State = Input_States Angle_State = Angle_State.clone() SSM_State = SSM_State.clone().to(torch.float32) K_State = K_State.clone() V_State = V_State.clone() else: Angle_State = torch.zeros((batch, nheads, headdim_angles), dtype=torch.float32, device=device) SSM_State = torch.zeros((batch, nheads, headdim_v, headdim_qk), dtype=torch.float32, device=device) K_State = 
torch.zeros((batch, nheads, mimo_rank, headdim_qk), dtype=Q.dtype, device=device) V_State = torch.zeros((batch, nheads, mimo_rank, headdim_v), dtype=V.dtype, device=device) # MIMO up project x and z: v_proj = torch.einsum("bthd,hrd->btrhd", V, MIMO_V) if Z is not None: z_proj = torch.einsum("bthd,hrd->btrhd", Z, MIMO_Z) else: z_proj = None TWO_PI = 2 * math.pi out_arr = [] # Main SSM recurrence for idx in range(seqlen): q = Q[:, idx, :, :, :] + q_bias.unsqueeze(0) k = K[:, idx, :, :, :] + k_bias.unsqueeze(0) v = v_proj[:, idx, :, :, :] # (B R H P) adt = ADT[:, :, idx] dt = DT[:, :, idx] trap = torch.nn.functional.sigmoid(Trap[:, :, idx]) z = z_proj[:, idx, :, :, :] if z_proj is not None else None angles = Angles[:, idx, :, :] # (B H N) q = q.permute(0, 2, 1, 3) # (B H R N) k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) if z is not None: z = z.permute(0, 2, 1, 3) # Update angle state with cumsum: Angle_State = (Angle_State + Angles * DT) mod 2π # Angle_State = Angle_State + angles * dt.unsqueeze(-1) # Angle_State = Angle_State - TWO_PI * torch.floor(Angle_State / TWO_PI) Angle_State = Angle_State + torch.tanh(angles) * dt.unsqueeze(-1) * math.pi # Apply rotary embeddings to Q and K using cumulative angles cos_angles = torch.cos(Angle_State).unsqueeze(2) # (B H 1 N) sin_angles = torch.sin(Angle_State).unsqueeze(2) q_rot = apply_rotary_emb(q, cos_angles, sin_angles) k_rot = apply_rotary_emb(k, cos_angles, sin_angles) alpha = torch.exp(adt) beta = (1 - trap) * dt * alpha gamma = trap * dt # Update SSM state using previous K_State and V_State prev_kv = torch.einsum("bhrd,bhrp->bhpd", K_State, V_State) curr_kv = torch.einsum("bhrd,bhrp->bhpd", k_rot, v) SSM_State = alpha.unsqueeze(-1).unsqueeze(-1) * SSM_State SSM_State = SSM_State + beta.unsqueeze(-1).unsqueeze(-1) * prev_kv SSM_State = SSM_State + gamma.unsqueeze(-1).unsqueeze(-1) * curr_kv # Compute output out = torch.einsum("bhpd,bhrd->bhrp", SSM_State, q_rot.to(SSM_State.dtype)) if D is not None: out = out + 
D[None, :, None, None] * v if z is not None: out = out * z * torch.sigmoid(z) out = torch.einsum("bhrp,hrp->bhp", out, MIMO_O) out_arr.append(out) # Update K and V states for next step K_State = k_rot V_State = v out = torch.stack(out_arr, dim=1) Final_States = (Angle_State, SSM_State, K_State, V_State) return out, Final_States def apply_angle_dt_reference( angle: Tensor, # (batch, seqlen, nheads, dim) dt: Tensor, # (batch, seqlen, nheads) ) -> Tensor: # Match debug_mimo_step.py preprocessing for chunk reference path. base_vals = angle.to(torch.float32) base_vals = torch.tanh(base_vals) * dt[..., None].to(torch.float32) * torch.pi return torch.cumsum(base_vals, dim=1) def mamba3_MIMO_chunk_ref( q: Tensor, k: Tensor, v: Tensor, q_bias: Tensor, k_bias: Tensor, mimo_v: Tensor, mimo_o: Optional[Tensor], z: Optional[Tensor], mimo_z: Optional[Tensor], angles: Tensor, dA_cs: Tensor, dA_cs_rev: Tensor, dt: Tensor, trap: Tensor, D: Optional[Tensor], chunk_size: int = 64, rotary_dim_divisor: int = 4, return_final_state: bool = False, dtype: torch.dtype = torch.float32, rotate_pairwise: bool = False, contract_mimo_out: bool = True, ) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: # Local copy of the reference program so tests remain valid even if module-level # debug/reference helpers are removed from shipped kernels. 
from einops import rearrange, repeat nchunks = q.shape[1] // chunk_size q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) if z is not None: z = z.to(dtype) mimo_z = mimo_z.to(dtype) if D is not None: D = D.to(dtype) q_bias, k_bias = q_bias.to(dtype), k_bias.to(dtype) mimo_v = mimo_v.to(dtype) if contract_mimo_out: assert mimo_o is not None mimo_o = mimo_o.to(dtype) if dA_cs is not None: dA_cs, dA_cs_rev = dA_cs.to(dtype), dA_cs_rev.to(dtype) dA_cs = rearrange(dA_cs, "b h (n c) -> b h n c", c=chunk_size) dA_cs_rev = rearrange(dA_cs_rev, "b h (n c) -> b h n c", c=chunk_size) batch, seqlen, mimo_rank, nheads_qk, dstate = q.shape nheads = v.shape[-2] if nheads_qk != nheads: q = repeat(q, "b s r h_qk d -> b s r (h_qk g) d", g=nheads // nheads_qk) k = repeat(k, "b s r h_qk d -> b s r (h_qk g) d", g=nheads // nheads_qk) angles = angles.to(dtype) if angles is not None else None trap = trap.to(dtype) if trap is not None else None dt = dt.to(dtype) if dt is not None else None q_bias = rearrange(q_bias, "h r d -> r h d") k_bias = rearrange(k_bias, "h r d -> r h d") q = q + q_bias[None, None, :, :, :] k = k + k_bias[None, None, :, :, :] qk_dot = torch.einsum("bsRhd,bsrhd->bsRrh", q, k) if angles is not None: angles = angles.unsqueeze(2) cos_angles = torch.cos(angles) sin_angles = torch.sin(angles) def apply_rotary_emb(tensor: Tensor, cos: Tensor, sin: Tensor) -> Tensor: if rotate_pairwise: # Pairwise convention used by mamba3_MIMO_step_ref / debug_mimo_step.py. tensor_reshaped = tensor.view(*tensor.shape[:-1], -1, 2) tensor_0 = tensor_reshaped[..., 0] tensor_1 = tensor_reshaped[..., 1] rotated_0 = tensor_0 * cos - tensor_1 * sin rotated_1 = tensor_0 * sin + tensor_1 * cos return torch.stack([rotated_0, rotated_1], dim=-1).view_as(tensor) # Kernel-aligned convention (kept as default for existing tests). 
tensor_reshaped = tensor.view(*tensor.shape[:-1], 2, -1) tensor_0 = tensor_reshaped[..., 0, :] tensor_1 = tensor_reshaped[..., 1, :] rotated_0 = tensor_0 * cos - tensor_1 * sin rotated_1 = tensor_0 * sin + tensor_1 * cos return torch.stack([rotated_0, rotated_1], dim=-2).view_as(tensor) def apply_rotary_emb_rotate_half(tensor: Tensor, cos: Tensor, sin: Tensor) -> Tensor: tensor_reshaped = tensor.view(*tensor.shape[:-1], 4, -1) tensor_0 = tensor_reshaped[..., 0, :] tensor_1 = tensor_reshaped[..., 2, :] rotated_0 = tensor_0 * cos - tensor_1 * sin rotated_1 = tensor_0 * sin + tensor_1 * cos return torch.stack( [ rotated_0, tensor_reshaped[..., 1, :], rotated_1, tensor_reshaped[..., 3, :], ], dim=-2, ).view_as(tensor) if rotary_dim_divisor == 4: q = apply_rotary_emb_rotate_half(q, cos_angles, sin_angles) k = apply_rotary_emb_rotate_half(k, cos_angles, sin_angles) elif rotary_dim_divisor == 2: q = apply_rotary_emb(q, cos_angles, sin_angles) k = apply_rotary_emb(k, cos_angles, sin_angles) else: raise ValueError(f"Invalid rotary_dim_divisor: {rotary_dim_divisor}") if return_final_state: final_k = k[:, -1].contiguous().clone() else: final_k = None trap = torch.nn.functional.sigmoid(trap) gamma = dt * trap dt_shifted = torch.nn.functional.pad(dt[:, :, 1:], (0, 1), value=0.0) trap_shifted = torch.nn.functional.pad(trap[:, :, 1:], (0, 1), value=0.0) shifted_gamma = dt_shifted * (1 - trap_shifted) factor = gamma + shifted_gamma k = torch.einsum("bsrhn,bhs->bsrhn", k, factor) qk_dot = torch.einsum("bsrRh,bhs->bsrRh", qk_dot, shifted_gamma) v = torch.einsum("bthd,hrd->btrhd", v, mimo_v) def segsum_unstable(x: Tensor) -> Tensor: x_segsum = x[..., :, None] - x[..., None, :] mask = torch.tril(torch.ones(x.size(-1), x.size(-1), device=x.device, dtype=torch.bool), diagonal=0) return x_segsum.masked_fill(~mask, -torch.inf) mimo_mask_outer = segsum_unstable(dA_cs) mimo_mask_inner = torch.ones(mimo_rank, mimo_rank, dtype=torch.bool, device=q.device) mimo_mask = 
torch.kron(mimo_mask_outer, mimo_mask_inner[None, None, None, :, :]) q = rearrange(q, "b (n c) r h d -> b h n (c r) d", c=chunk_size) k_scaled = rearrange(k, "b (n c) r h d -> b h n c r d", c=chunk_size) k_scaled = torch.einsum("bhncrd,bhnc->bhncrd", k_scaled, torch.exp(dA_cs_rev)) k_scaled = rearrange(k_scaled, "b h n c r d -> b h n (c r) d", c=chunk_size) k = rearrange(k, "b (n c) r h d -> b h n (c r) d", c=chunk_size) v = rearrange(v, "b (n c) r h d -> b h n (c r) d", c=chunk_size) kv = k_scaled.transpose(-1, -2) @ v curr_state = torch.zeros_like(kv[:, :, 0, :, :]) for n in range(nchunks): curr_dA_sum = dA_cs[:, :, n, -1] next_state = (torch.exp(curr_dA_sum[:, :, None, None]) * curr_state) + kv[:, :, n, :, :] kv[:, :, n, :, :] = curr_state curr_state = next_state if return_final_state: final_state = next_state.float() else: final_state = None q_inter = q * torch.exp(repeat(dA_cs, "b h n c -> b h n (c r)", r=mimo_rank).unsqueeze(-1)) inter = q_inter @ kv intra = ((q @ k.transpose(-1, -2)) * torch.exp(mimo_mask)) @ v o = inter + intra o = rearrange(o, "b h n (c r) d -> b h n c r d", r=mimo_rank) v = rearrange(v, "b h n (c r) d -> b h (n c) r d", r=mimo_rank) qk_dot = rearrange(qk_dot, "b t R r h -> b h t R r") qkv = torch.einsum("bhtRr,bhtrp->bhtRp", qk_dot, v) qkv = rearrange(qkv, "b h (n c) r d -> b h n c r d", c=chunk_size) o -= qkv if D is not None: vd = torch.einsum("bhtrp,h->bhtrp", v, D) vd = rearrange(vd, "b h (n c) r d -> b h n c r d", c=chunk_size) o += vd if z is not None: z = torch.einsum("bthd,hrd->btrhd", z, mimo_z) z = rearrange(z, "b (n c) r h d -> b h n c r d", c=chunk_size) o = o * torch.nn.functional.silu(z) if contract_mimo_out: assert mimo_o is not None o = torch.einsum("bhncrd,hrd->bhncd", o, mimo_o) return rearrange(o, "b h n c d -> b (n c) h d"), final_state, final_k return rearrange(o, "b h n c r d -> b (n c) r h d"), final_state, final_k def run_ref_backward_fp32( mods: SimpleNamespace, inputs: dict, *, contract_mimo_out: bool = True, 
grad_output: Optional[Tensor] = None, ) -> dict: ref_dtype = torch.float32 q = inputs["q"].detach().to(ref_dtype).requires_grad_(True) k = inputs["k"].detach().to(ref_dtype).requires_grad_(True) v = inputs["v"].detach().to(ref_dtype).requires_grad_(True) q_bias = inputs["q_bias"].detach().to(ref_dtype).requires_grad_(True) k_bias = inputs["k_bias"].detach().to(ref_dtype).requires_grad_(True) mimo_v = inputs["mimo_v"].detach().to(ref_dtype).requires_grad_(True) mimo_o = ( inputs["mimo_o"].detach().to(ref_dtype).requires_grad_(True) if contract_mimo_out else None ) z = ( inputs["z"].detach().to(ref_dtype).requires_grad_(True) if inputs["z"] is not None else None ) mimo_z = ( inputs["mimo_z"].detach().to(ref_dtype).requires_grad_(True) if inputs["mimo_z"] is not None else None ) angles = inputs["angles"].detach().to(ref_dtype).requires_grad_(True) dt = inputs["dt"].detach().to(ref_dtype).requires_grad_(True) trap = inputs["trap"].detach().to(ref_dtype).requires_grad_(True) d = inputs["D"].detach().to(ref_dtype).requires_grad_(True) dA_cs_base, dA_cs_rev_base, _ = mods.utils.compute_dacs_segsum_triton( inputs["dA"].detach().to(torch.float32), inputs["chunk_size"] ) dA_cs = dA_cs_base.detach().to(ref_dtype).requires_grad_(True) dA_cs_rev = dA_cs_rev_base.detach().to(ref_dtype).requires_grad_(True) out, _, _ = mamba3_MIMO_chunk_ref( q, k, v, q_bias, k_bias, mimo_v, mimo_o, z, mimo_z, angles, dA_cs, dA_cs_rev, dt, trap, d, chunk_size=inputs["chunk_size"], rotary_dim_divisor=inputs["rotary_dim_divisor"], dtype=ref_dtype, contract_mimo_out=contract_mimo_out, ) grad_input_items = [ ("q", q), ("k", k), ("v", v), ("q_bias", q_bias), ("k_bias", k_bias), ("mimo_v", mimo_v), ("angles", angles), ("dA_cs", dA_cs), ("dA_cs_rev", dA_cs_rev), ("dt", dt), ("trap", trap), ("dD", d), ] if z is not None: grad_input_items.append(("z", z)) if mimo_z is not None: grad_input_items.append(("mimo_z", mimo_z)) if contract_mimo_out: grad_input_items.append(("mimo_o", mimo_o)) if grad_output is 
None: grad_output = inputs["dout"] grads = torch.autograd.grad( outputs=out, inputs=tuple(t for _, t in grad_input_items), grad_outputs=grad_output.detach().to(ref_dtype), retain_graph=False, allow_unused=True, ) grad_map = {name: grad for (name, _), grad in zip(grad_input_items, grads)} grad_map["dA"] = grads_to_dA(grad_map["dA_cs"], grad_map["dA_cs_rev"], inputs["chunk_size"]) return { "dq": grad_map["q"], "dk": grad_map["k"], "dv": grad_map["v"], "dA": grad_map["dA"], "ddt": grad_map["dt"], "dtrap": grad_map["trap"], "dq_bias": grad_map["q_bias"], "dk_bias": grad_map["k_bias"], "dmimo_v": grad_map["mimo_v"], "dmimo_z": grad_map.get("mimo_z"), "dmimo_o": grad_map.get("mimo_o"), "dangles": grad_map["angles"], "dD": grad_map["dD"], "dz": grad_map.get("z"), } def test_mamba3_MIMO_chunk_ref_matches_step_ref() -> None: # Lightweight deterministic ref-vs-ref consistency test. B, S, H, G, P, N, R, chunk_size = 1, 128, 8, 1, 32, 64, 4, 16 dtype = torch.float32 device = "cpu" torch.manual_seed(0) q = torch.randn((B, S, R, G, N), device=device, dtype=dtype) k = torch.randn((B, S, R, G, N), device=device, dtype=dtype) v = torch.randn((B, S, H, P), device=device, dtype=dtype) q_bias = torch.randn((H, R, N), device=device, dtype=dtype) k_bias = torch.randn((H, R, N), device=device, dtype=dtype) mimo_v = torch.rand((H, R, P), device=device, dtype=dtype) mimo_o = torch.rand((H, R, P), device=device, dtype=dtype) z = torch.randn_like(v) mimo_z = torch.rand_like(mimo_v) D = torch.randn((H,), device=device, dtype=dtype) angles = torch.rand((B, S, H, N // 2), device=device, dtype=dtype) dt = F.softplus(-3.0 + torch.randn(B, H, S, device=device, dtype=torch.float32)) A_neg = -F.softplus(torch.randn((B, H, S), device=device, dtype=torch.float32)) A_neg = torch.clamp(A_neg, max=-1e-4) ADT = A_neg * dt trap = torch.rand(B, H, S, device=device, dtype=dtype) * 0.5 dA_cs = torch.cumsum(rearrange(ADT, "b h (n c) -> b h n c", c=chunk_size), dim=-1) dA_cs_rev = dA_cs[..., -1:] - dA_cs 
angles_prerotated = apply_angle_dt_reference(angles, dt.permute(0, 2, 1)) chunk_out, _, _ = mamba3_MIMO_chunk_ref( q, k, v, q_bias, k_bias, mimo_v, mimo_o, z, mimo_z, angles_prerotated, dA_cs.view(B, H, S), dA_cs_rev.view(B, H, S), dt, trap, D, chunk_size=chunk_size, rotary_dim_divisor=2, return_final_state=True, dtype=dtype, rotate_pairwise=True, ) step_out, _ = mamba3_MIMO_step_ref( q, k, v, ADT, dt, trap, q_bias, k_bias, angles, mimo_v, mimo_o, D=D, Z=z, MIMO_Z=mimo_z, ) assert chunk_out.shape == step_out.shape assert_stable_rel( chunk_out, step_out, label="chunk_ref_vs_step_ref", cfg=f"B={B}, S={S}, H={H}, P={P}, N={N}, R={R}, C={chunk_size}", rel_tol=0.02, ) @pytest.mark.parametrize("n,p,r,chunk_size,bb_threads", CASE_GRID) def test_fused_chunk_linear_attn_fwd_relative_error_lt_10pct( mods: SimpleNamespace, n: int, p: int, r: int, chunk_size: int, bb_threads: int ) -> None: del bb_threads inputs = build_inputs( mods=mods, n=n, p=p, r=r, chunk_size=chunk_size, seed=1234 + n + p + r + chunk_size, ) out_tilelang, _, _ = mods.fwd.mamba_mimo_forward( inputs["q"], inputs["k"], inputs["v"], inputs["q_bias"], inputs["k_bias"], inputs["mimo_v"], inputs["mimo_o"], inputs["z"], inputs["D"], inputs["mimo_z"], inputs["angles"], inputs["dA_cs"], inputs["dA_cs_rev"], inputs["dt"], inputs["trap"], inputs["segsum"], chunk_size=chunk_size, rotary_dim_divisor=inputs["rotary_dim_divisor"], dtype=FIXED_DTYPE, ) out_ref_fp32, _, _ = mamba3_MIMO_chunk_ref( inputs["q"].clone(), inputs["k"].clone(), inputs["v"].clone(), inputs["q_bias"].clone(), inputs["k_bias"].clone(), inputs["mimo_v"].clone(), inputs["mimo_o"].clone(), inputs["z"].clone(), inputs["mimo_z"].clone(), inputs["angles"].clone(), inputs["dA_cs"].clone(), inputs["dA_cs_rev"].clone(), inputs["dt"].clone(), inputs["trap"].clone(), inputs["D"].clone(), chunk_size=chunk_size, rotary_dim_divisor=inputs["rotary_dim_divisor"], dtype=torch.float32, ) assert_stable_rel( out_tilelang, out_ref_fp32, label="forward", cfg=f"N={n}, 
P={p}, R={r}, chunk={chunk_size}", ) @pytest.mark.parametrize("n,p,r,chunk_size,bb_threads", CASE_GRID) def test_fused_chunk_linear_attn_fwd_return_state_relative_error_lt_10pct( mods: SimpleNamespace, n: int, p: int, r: int, chunk_size: int, bb_threads: int ) -> None: del bb_threads inputs = build_inputs( mods=mods, n=n, p=p, r=r, chunk_size=chunk_size, seed=3456 + n + p + r + chunk_size, ) out_tilelang, final_state_tilelang, final_k_tilelang = mods.fwd.mamba_mimo_forward( inputs["q"], inputs["k"], inputs["v"], inputs["q_bias"], inputs["k_bias"], inputs["mimo_v"], inputs["mimo_o"], inputs["z"], inputs["D"], inputs["mimo_z"], inputs["angles"], inputs["dA_cs"], inputs["dA_cs_rev"], inputs["dt"], inputs["trap"], inputs["segsum"], return_state=True, chunk_size=chunk_size, rotary_dim_divisor=inputs["rotary_dim_divisor"], dtype=FIXED_DTYPE, ) out_ref_fp32, final_state_ref, final_k_ref = mamba3_MIMO_chunk_ref( inputs["q"].clone(), inputs["k"].clone(), inputs["v"].clone(), inputs["q_bias"].clone(), inputs["k_bias"].clone(), inputs["mimo_v"].clone(), inputs["mimo_o"].clone(), inputs["z"].clone(), inputs["mimo_z"].clone(), inputs["angles"].clone(), inputs["dA_cs"].clone(), inputs["dA_cs_rev"].clone(), inputs["dt"].clone(), inputs["trap"].clone(), inputs["D"].clone(), chunk_size=chunk_size, rotary_dim_divisor=inputs["rotary_dim_divisor"], return_final_state=True, dtype=torch.float32, ) assert_stable_rel( out_tilelang, out_ref_fp32, label="forward_return_state_out", cfg=f"N={n}, P={p}, R={r}, chunk={chunk_size}", ) assert_stable_rel( final_state_tilelang, final_state_ref, label="forward_return_state_final_state", cfg=f"N={n}, P={p}, R={r}, chunk={chunk_size}", ) assert_stable_rel( final_k_tilelang, final_k_ref, label="forward_return_state_final_k", cfg=f"N={n}, P={p}, R={r}, chunk={chunk_size}", ) @pytest.mark.parametrize("n,p,r,chunk_size,bb_threads", CASE_GRID) def test_fused_chunk_linear_attn_fwd_prereduce_relative_error_lt_10pct( mods: SimpleNamespace, n: int, p: int, r: 
int, chunk_size: int, bb_threads: int ) -> None: del bb_threads inputs = build_inputs( mods=mods, n=n, p=p, r=r, chunk_size=chunk_size, seed=2345 + n + p + r + chunk_size, has_z=False, ) out_tilelang, _, _ = mods.fwd.mamba_mimo_forward( inputs["q"], inputs["k"], inputs["v"], inputs["q_bias"], inputs["k_bias"], inputs["mimo_v"], None, inputs["z"], inputs["D"], inputs["mimo_z"], inputs["angles"], inputs["dA_cs"], inputs["dA_cs_rev"], inputs["dt"], inputs["trap"], inputs["segsum"], chunk_size=chunk_size, rotary_dim_divisor=inputs["rotary_dim_divisor"], dtype=FIXED_DTYPE, ) out_ref_fp32, _, _ = mamba3_MIMO_chunk_ref( inputs["q"].clone(), inputs["k"].clone(), inputs["v"].clone(), inputs["q_bias"].clone(), inputs["k_bias"].clone(), inputs["mimo_v"].clone(), None, None, None, inputs["angles"].clone(), inputs["dA_cs"].clone(), inputs["dA_cs_rev"].clone(), inputs["dt"].clone(), inputs["trap"].clone(), inputs["D"].clone(), chunk_size=chunk_size, rotary_dim_divisor=inputs["rotary_dim_divisor"], dtype=torch.float32, contract_mimo_out=False, ) assert_stable_rel( out_tilelang, out_ref_fp32, label="forward_prereduce", cfg=f"N={n}, P={p}, R={r}, chunk={chunk_size}", ) @pytest.mark.parametrize("n,p,r,chunk_size,bb_threads", CASE_GRID) def test_mamba_mimo_bwd_combined_relative_errors_lt_10pct( mods: SimpleNamespace, n: int, p: int, r: int, chunk_size: int, bb_threads: int ) -> None: inputs = build_inputs( mods=mods, n=n, p=p, r=r, chunk_size=chunk_size, seed=5678 + n + p + r + chunk_size, ) ref_grads = run_ref_backward_fp32(mods, inputs) ( dq, dk, dv, dA, ddt, dtrap, dq_bias, dk_bias, dmimo_v, dmimo_z, dmimo_o, dangles, dD, dz, ) = mods.bwd.mamba_mimo_bwd_combined( inputs["dout"], inputs["q"], inputs["k"], inputs["v"], inputs["q_bias"], inputs["k_bias"], inputs["mimo_v"], inputs["mimo_o"], inputs["z"], inputs["mimo_z"], inputs["angles"], inputs["dA_cs"], inputs["dA_cs_rev"], inputs["dt"], inputs["trap"], inputs["D"], inputs["segsum"], chunk_size, inputs["rotary_dim_divisor"], 
FIXED_DTYPE, bb_threads=bb_threads, ) comparisons = { "dq": (dq, ref_grads["dq"]), "dk": (dk, ref_grads["dk"]), "dv": (dv, ref_grads["dv"]), "dA": (dA, ref_grads["dA"]), "ddt": (ddt, ref_grads["ddt"]), "dtrap": (dtrap, ref_grads["dtrap"]), "dq_bias": (dq_bias, ref_grads["dq_bias"]), "dk_bias": (dk_bias, ref_grads["dk_bias"]), "dmimo_v": (dmimo_v, ref_grads["dmimo_v"]), "dmimo_z": (dmimo_z, ref_grads["dmimo_z"]), "dmimo_o": (dmimo_o, ref_grads["dmimo_o"]), "dangles": (dangles, ref_grads["dangles"]), "dD": (dD, ref_grads["dD"]), "dz": (dz, ref_grads["dz"]), } for name, (ours, ref) in comparisons.items(): assert_stable_rel( ours, ref, label=name, cfg=f"N={n}, P={p}, R={r}, chunk={chunk_size}, bb_threads={bb_threads}", ) @pytest.mark.parametrize("n,p,r,chunk_size,bb_threads", CASE_GRID) def test_mamba_mimo_bwd_combined_prereduce_relative_errors_lt_10pct( mods: SimpleNamespace, n: int, p: int, r: int, chunk_size: int, bb_threads: int ) -> None: inputs = build_inputs( mods=mods, n=n, p=p, r=r, chunk_size=chunk_size, seed=6789 + n + p + r + chunk_size, has_z=False, ) b, s, h, p_dim = inputs["v"].shape dout_prereduce = torch.randn((b, s, r, h, p_dim), device="cuda", dtype=FIXED_DTYPE) ref_grads = run_ref_backward_fp32( mods, inputs, contract_mimo_out=False, grad_output=dout_prereduce, ) ( dq, dk, dv, dA, ddt, dtrap, dq_bias, dk_bias, dmimo_v, dmimo_z, dmimo_o, dangles, dD, dz, ) = mods.bwd.mamba_mimo_bwd_combined( dout_prereduce, inputs["q"], inputs["k"], inputs["v"], inputs["q_bias"], inputs["k_bias"], inputs["mimo_v"], None, None, None, inputs["angles"], inputs["dA_cs"], inputs["dA_cs_rev"], inputs["dt"], inputs["trap"], inputs["D"], inputs["segsum"], chunk_size, inputs["rotary_dim_divisor"], FIXED_DTYPE, bb_threads=bb_threads, ) assert dmimo_o is None assert dmimo_z is None assert dz is None comparisons = { "dq_prereduce": (dq, ref_grads["dq"]), "dk_prereduce": (dk, ref_grads["dk"]), "dv_prereduce": (dv, ref_grads["dv"]), "dA_prereduce": (dA, ref_grads["dA"]), 
"ddt_prereduce": (ddt, ref_grads["ddt"]), "dtrap_prereduce": (dtrap, ref_grads["dtrap"]), "dq_bias_prereduce": (dq_bias, ref_grads["dq_bias"]), "dk_bias_prereduce": (dk_bias, ref_grads["dk_bias"]), "dmimo_v_prereduce": (dmimo_v, ref_grads["dmimo_v"]), "dangles_prereduce": (dangles, ref_grads["dangles"]), "dD_prereduce": (dD, ref_grads["dD"]), } for name, (ours, ref) in comparisons.items(): assert_stable_rel( ours, ref, label=name, cfg=f"N={n}, P={p}, R={r}, chunk={chunk_size}, bb_threads={bb_threads}", ) def test_mamba_mimo_smoke_forward_backward(mods: SimpleNamespace) -> None: inputs = make_smoke_inputs( batch=FIXED_B, seqlen=FIXED_S, mimo_rank=4, nheads_qk=FIXED_G, nheads=FIXED_H, headdim_qk=128, headdim_v=64, chunk_size=16, rotary_dim_divisor=FIXED_ROTARY_DIM_DIVISOR, device="cuda", dtype=FIXED_DTYPE, seed=999, ) out = mods.top.mamba3_mimo(**inputs) assert out.shape == (FIXED_B, FIXED_S, FIXED_H, 64) loss = out.float().sum() loss.backward() grad_names = [ "Q", "K", "V", "ADT", "DT", "Trap", "Q_bias", "K_bias", "MIMO_V", "MIMO_Z", "MIMO_Out", "Angles", "D", "Z", ] for name in grad_names: grad = inputs[name].grad assert grad is not None, f"Missing gradient for {name}" assert torch.isfinite(grad).all(), f"Non-finite gradient detected for {name}" ================================================ FILE: tests/ops/triton/test_layernorm_gated.py ================================================ import math import torch import torch.nn.functional as F import pytest from einops import rearrange, repeat from mamba_ssm.ops.triton.layernorm_gated import layernorm_fn, rms_norm_ref @pytest.mark.parametrize("norm_before_gate", [True, False]) # @pytest.mark.parametrize("norm_before_gate", [False]) @pytest.mark.parametrize("has_group", [False, True]) # @pytest.mark.parametrize("has_group", [False]) @pytest.mark.parametrize("is_rms_norm", [False, True]) # @pytest.mark.parametrize("is_rms_norm", [True]) @pytest.mark.parametrize("has_z", [False, True]) # 
@pytest.mark.parametrize("has_z", [True]) @pytest.mark.parametrize("has_bias", [False, True]) # @pytest.mark.parametrize("has_bias", [False]) # @pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize("wtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("wtype", [torch.float32]) @pytest.mark.parametrize('d', [2048, 4096]) # @pytest.mark.parametrize('d', [4096]) def test_layer_norm_gated(d, dtype, wtype, has_bias, has_z, is_rms_norm, has_group, norm_before_gate): if not has_z and not norm_before_gate: pytest.skip() if not norm_before_gate and not is_rms_norm: # Reference LN isn't implemented for this case yet pytest.skip() device = 'cuda' rtol, atol = (1e-5, 1e-5) if dtype == torch.float32 else (1e-2, 8e-3) group_size = None if not has_group else 64 # set seed torch.random.manual_seed(0) batch = 16 seqlen = 1024 x = torch.randn(batch, seqlen, d, dtype=dtype, device=device, requires_grad=True) if has_z: z = torch.randn(batch, seqlen, d, dtype=dtype, device=device, requires_grad=True) else: z = None weight = torch.randn(d, dtype=wtype, device=device, requires_grad=True) if has_bias: bias = torch.randn(d, dtype=wtype, device=device, requires_grad=True) else: bias = None x_ref = x.detach().clone().requires_grad_() x_pt = x.detach().clone().requires_grad_() z_ref = z.detach().clone().requires_grad_() if z is not None else None z_pt = z.detach().clone().requires_grad_() if z is not None else None weight_ref = weight.detach().clone().requires_grad_() weight_pt = weight.detach().clone().requires_grad_() bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None out = layernorm_fn(x, weight, bias, z=z, eps=1e-5, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm) if not is_rms_norm: if not has_group: out_ref = 
F.layer_norm(x_ref.float(), (d,), weight=weight_ref.float(), bias=bias_ref.float() if bias_ref is not None else None, eps=1e-5) out_pt = F.layer_norm(x_pt.to(wtype), (d,), weight=weight_pt, bias=bias_pt, eps=1e-5) else: out_ref = rearrange(F.layer_norm(rearrange(x_ref, "... (g d) -> ... g d", d=group_size).float(), (group_size,), eps=1e-5), "... g d -> ... (g d)") * weight_ref.float() if has_bias: out_ref = out_ref + bias_ref.float() out_pt = rearrange(F.layer_norm(rearrange(x_pt, "... (g d) -> ... g d", d=group_size), (group_size,), eps=1e-5), "... g d -> ... (g d)") * weight_pt if has_bias: out_pt = out_pt + bias_pt if has_z and norm_before_gate: out_ref = out_ref * F.silu(z_ref.float()) out_pt = out_pt * F.silu(z_pt) else: out_ref = rms_norm_ref(x_ref, weight_ref, bias_ref, z=z_ref, eps=1e-5, group_size=group_size, norm_before_gate=norm_before_gate) out_pt = rms_norm_ref(x_pt, weight_pt, bias_pt, z=z_pt, eps=1e-5, group_size=group_size, norm_before_gate=norm_before_gate, upcast=False) print(f"Max diff = {(out - out_ref).abs().max().item()}") print(f"Max diff Pytorch = {(out_pt - out_ref).abs().max().item()}") assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + atol g = torch.randn_like(out) out.backward(g) out_ref.backward(g) out_pt.backward(g) print(f"Max dx diff = {(x.grad - x_ref.grad).abs().max().item()}") print(f"Max dx diff Pytorch = {(x_pt.grad - x_ref.grad).abs().max().item()}") if has_z: print(f"Max dz diff = {(z.grad - z_ref.grad).abs().max().item()}") print(f"Max dz diff Pytorch = {(z_pt.grad - z_ref.grad).abs().max().item()}") print(f"Max dw diff = {(weight.grad - weight_ref.grad).abs().max().item()}") print(f"Max dw diff Pytorch = {(weight_pt.grad - weight_ref.grad).abs().max().item()}") if has_bias: print(f"Max db diff = {(bias.grad - bias_ref.grad).abs().max().item()}") print(f"Max db diff Pytorch = {(bias_pt.grad - bias_ref.grad).abs().max().item()}") assert (x.grad - x_ref.grad).abs().max().item() <= 2 * 
(x_pt.grad - x_ref.grad).abs().max().item() + atol if has_z: assert (z.grad - z_ref.grad).abs().max().item() <= 2 * (z_pt.grad - z_ref.grad).abs().max().item() + atol assert (weight.grad - weight_ref.grad).abs().max().item() <= 2 * (weight_pt.grad - weight_ref.grad).abs().max().item() + atol if has_bias: assert (bias.grad - bias_ref.grad).abs().max().item() <= 2 * (bias_pt.grad - bias_ref.grad).abs().max().item() + atol ================================================ FILE: tests/ops/triton/test_mamba3_siso.py ================================================ """ Mamba-3 SISO Kernel Tests Copyright (c) 2025, Dao AI Lab, Goombalab """ import copy import math from typing import Optional, Tuple import pytest import torch import torch.nn.functional as F from einops import rearrange, repeat from mamba_ssm.ops.triton.mamba3.mamba3_siso_combined import mamba3_siso_combined from mamba_ssm.ops.triton.mamba3.mamba3_siso_step import mamba3_siso_step # Reference Implementations def _segsum(x: torch.Tensor) -> torch.Tensor: """Segment sum helper for attention computation.""" T = x.size(-1) x = repeat(x, "... d -> ... d e", e=T) mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) x = x.masked_fill(~mask, 0) x_segsum = torch.cumsum(x, dim=-2) mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) x_segsum = x_segsum.masked_fill(~mask, -torch.inf) return x_segsum def mamba3_siso_step_ref( Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, ADT: torch.Tensor, DT: torch.Tensor, Trap: torch.Tensor, Q_bias: torch.Tensor, K_bias: torch.Tensor, Angles: torch.Tensor, D: Optional[torch.Tensor] = None, Z: Optional[torch.Tensor] = None, Input_States: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: """Reference implementation of Mamba-3 in recurrent (step) mode. 
Args: Input_States: Optional tuple of (Angle_State, SSM_State, K_State, V_State) Returns: out: Output tensor (batch, seqlen, nheads, headdim_v) Final_States: Tuple of (Angle_State, SSM_State, K_State, V_State) """ batch, seqlen, nheads_qk, headdim_qk = Q.shape _, _, nheads, headdim_v = V.shape headdim_angles = Angles.shape[-1] device = Q.device assert seqlen > 0 Angles = torch.tanh(Angles) * math.pi # Expand Q/K for GQA if Q.shape[2] != V.shape[2]: Q = repeat(Q, "b s h_bc d -> b s (h_bc g) d", g=V.shape[2] // Q.shape[2]) if K.shape[2] != V.shape[2]: K = repeat(K, "b s h_bc d -> b s (h_bc g) d", g=V.shape[2] // K.shape[2]) def apply_rotary_emb(tensor, cos, sin): tensor_reshaped = tensor.view(*tensor.shape[:-1], -1, 2) tensor_0 = tensor_reshaped[..., 0] tensor_1 = tensor_reshaped[..., 1] if cos.shape[-1] < tensor_0.shape[-1]: pad_size = tensor_0.shape[-1] - cos.shape[-1] cos = F.pad(cos, (0, pad_size), value=1.0) sin = F.pad(sin, (0, pad_size), value=0.0) rotated_0 = tensor_0 * cos - tensor_1 * sin rotated_1 = tensor_0 * sin + tensor_1 * cos rotated = torch.stack([rotated_0, rotated_1], dim=-1).view_as(tensor) return rotated # Initialize states if Input_States is not None: Angle_State, SSM_State, K_State, V_State = Input_States Angle_State = Angle_State.clone() SSM_State = SSM_State.clone().to(torch.float32) K_State = K_State.clone() V_State = V_State.clone() else: Angle_State = torch.zeros((batch, nheads, headdim_angles), dtype=torch.float32, device=device) SSM_State = torch.zeros((batch, nheads, headdim_v, headdim_qk), dtype=torch.float32, device=device) K_State = torch.zeros((batch, nheads, headdim_qk), dtype=Q.dtype, device=device) V_State = torch.zeros((batch, nheads, headdim_v), dtype=V.dtype, device=device) TWO_PI = 2 * math.pi out_arr = [] for idx in range(seqlen): q = Q[:, idx, :, :] + Q_bias.unsqueeze(0) k = K[:, idx, :, :] + K_bias.unsqueeze(0) v = V[:, idx, :, :] adt = ADT[:, :, idx] dt = DT[:, :, idx] trap = Trap[:, :, idx] z = Z[:, idx, :, :] if Z is 
not None else None angles = Angles[:, idx, :, :] # Update angle state with cumsum: Angle_State = (Angle_State + Angles * DT) mod 2π Angle_State = Angle_State + angles * dt.unsqueeze(-1) Angle_State = Angle_State - TWO_PI * torch.floor(Angle_State / TWO_PI) # Apply rotary embeddings to Q and K using cumulative angles cos_angles = torch.cos(Angle_State) sin_angles = torch.sin(Angle_State) q_rot = apply_rotary_emb(q, cos_angles, sin_angles) k_rot = apply_rotary_emb(k, cos_angles, sin_angles) trap = torch.sigmoid(trap) alpha = torch.exp(adt) beta = (1 - trap) * dt * alpha gamma = trap * dt # Update SSM state using previous K_State and V_State SSM_State = alpha.unsqueeze(-1).unsqueeze(-1) * SSM_State SSM_State = SSM_State + beta.unsqueeze(-1).unsqueeze(-1) * (K_State.unsqueeze(-2) * V_State.unsqueeze(-1)) SSM_State = SSM_State + gamma.unsqueeze(-1).unsqueeze(-1) * (k_rot.unsqueeze(-2) * v.unsqueeze(-1)) # Compute output out = torch.einsum("bhdD, bhD -> bhd", SSM_State, q_rot.to(SSM_State.dtype)) if D is not None: out = out + D[None, :, None] * v if Z is not None: out = out * z * torch.sigmoid(z) out_arr.append(out) # Update K and V states for next step K_State = k_rot V_State = v out = torch.stack(out_arr, dim=1) Final_States = (Angle_State, SSM_State, K_State, V_State) return out, Final_States def mamba3_siso_fwd_ref( Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, ADT: torch.Tensor, DT: torch.Tensor, Trap: torch.Tensor, Q_bias: torch.Tensor, K_bias: torch.Tensor, Angles: torch.Tensor, D: Optional[torch.Tensor] = None, Z: Optional[torch.Tensor] = None, Initial_States: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None, chunk_size: int = 64, dtype: torch.dtype = torch.float32, cu_seqlens: Optional[torch.Tensor] = None, ): """Reference implementation of Mamba-3 forward pass. 
Args: Initial_States: Optional tuple of (Angle_State, SSM_State, K_State, V_State) Returns: out_z: Output with Z gating applied final_states: (Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State) """ batch, total_seqlen, nheads_qk, headdim_qk = Q.shape _, _, nheads, headdim_v = V.shape headdim_angles = Angles.shape[-1] device = Q.device is_varlen = cu_seqlens is not None if is_varlen: assert batch == 1 # Cast inputs Q = Q.to(dtype) K = K.to(dtype) V = V.to(dtype) ADT = ADT.to(torch.float32) DT = DT.to(torch.float32) Trap = Trap.to(dtype) Q_bias = Q_bias.to(dtype) K_bias = K_bias.to(dtype) Angles = Angles.to(dtype) if D is not None: D = D.to(dtype) if Z is not None: Z = Z.to(dtype) if Initial_States is not None: Initial_Angle_State, Initial_SSM_State, Initial_K_State, Initial_V_State = Initial_States Angles = torch.tanh(Angles) * math.pi # Expand Q/K for GQA if Q.shape[2] != V.shape[2]: Q = repeat(Q, "b s h_bc d -> b s (h_bc g) d", g=V.shape[2] // Q.shape[2]) if K.shape[2] != V.shape[2]: K = repeat(K, "b s h_bc d -> b s (h_bc g) d", g=V.shape[2] // K.shape[2]) out_zs = [] Final_Angle_States = [] Final_SSM_States = [] Final_K_States = [] Final_V_States = [] TWO_PI = 2 * math.pi def _rotary(tensor, cos, sin): tensor_reshaped = tensor.view(*tensor.shape[:-1], -1, 2) tensor_0 = tensor_reshaped[..., 0] tensor_1 = tensor_reshaped[..., 1] if cos.shape[-1] < tensor_0.shape[-1]: pad_size = tensor_0.shape[-1] - cos.shape[-1] cos = F.pad(cos, (0, pad_size), value=1.0) sin = F.pad(sin, (0, pad_size), value=0.0) rotated_0 = tensor_0 * cos - tensor_1 * sin rotated_1 = tensor_0 * sin + tensor_1 * cos return torch.stack([rotated_0, rotated_1], dim=-1).view_as(tensor) def compute_one_sequence(seq_idx): if is_varlen: start_idx, end_idx = cu_seqlens[seq_idx].item(), cu_seqlens[seq_idx + 1].item() Q_curr = Q[0, start_idx:end_idx, :, :] K_curr = K[0, start_idx:end_idx, :, :] V_curr = V[0, start_idx:end_idx, :, :] ADT_curr = ADT[0, :, start_idx:end_idx] DT_curr = DT[0, :, 
start_idx:end_idx] Trap_curr = Trap[0, :, start_idx:end_idx] Angles_curr = Angles[0, start_idx:end_idx, :, :] Z_curr = Z[0, start_idx:end_idx, :, :] if Z is not None else None else: Q_curr = Q[seq_idx] K_curr = K[seq_idx] V_curr = V[seq_idx] ADT_curr = ADT[seq_idx] DT_curr = DT[seq_idx] Trap_curr = Trap[seq_idx] Angles_curr = Angles[seq_idx] Z_curr = Z[seq_idx] if Z is not None else None Trap_curr = torch.sigmoid(Trap_curr) seqlen_curr = Q_curr.shape[0] Angles_scaled = Angles_curr.float() * DT_curr.transpose(0, 1).unsqueeze(-1) Angles_Cumsum = torch.cumsum(Angles_scaled, dim=0) if Initial_States is not None: Initial_Angle_State_curr = Initial_Angle_State[seq_idx] Angles_Cumsum = Angles_Cumsum + Initial_Angle_State_curr.unsqueeze(0) Angles_Cumsum = Angles_Cumsum - TWO_PI * torch.floor(Angles_Cumsum / TWO_PI) Final_Angle_States.append(Angles_Cumsum[-1]) # Initialize acc_states if Initial_States is not None: Initial_SSM_State_curr = Initial_SSM_State[seq_idx] Initial_K_State_curr = Initial_K_State[seq_idx] Initial_V_State_curr = Initial_V_State[seq_idx] scalar = DT_curr[:, 0] * (1 - Trap_curr[:, 0]) acc_states = Initial_SSM_State_curr + Initial_V_State_curr[:, :, None] * Initial_K_State_curr[:, None, :] * scalar[:, None, None] else: acc_states = torch.zeros((nheads, headdim_v, headdim_qk), device=device, dtype=torch.float32) # Compute shifted gamma and scale DT_shifted = F.pad(DT_curr[:, 1:], (0, 1)) Trap_shifted = F.pad(Trap_curr[:, 1:], (0, 1)) shifted_gamma = DT_shifted * (1 - Trap_shifted) scale = DT_curr * Trap_curr + DT_shifted * (1 - Trap_shifted) # Add biases Q_curr = Q_curr + Q_bias.unsqueeze(0) K_curr = K_curr + K_bias.unsqueeze(0) # Compute QK dot for skip connection QK_dot = torch.sum(K_curr * Q_curr, dim=-1) * shifted_gamma.transpose(0, 1) # Rotary embeddings using Angles_Cumsum cos_angles_curr = torch.cos(Angles_Cumsum).to(Q_curr.dtype) sin_angles_curr = torch.sin(Angles_Cumsum).to(Q_curr.dtype) Q_curr = _rotary(Q_curr, cos_angles_curr, sin_angles_curr) 
K_curr = _rotary(K_curr, cos_angles_curr, sin_angles_curr) Final_K_States.append(K_curr[-1]) Final_V_States.append(V_curr[-1]) K_curr_scaled = K_curr * scale.transpose(0, 1).unsqueeze(-1).to(K_curr.dtype) # Compute output via quadratic attention QK = torch.einsum("thd,shd->hts", Q_curr, K_curr_scaled) QK_causal = torch.tril(QK) QK_causal = (QK_causal * torch.exp(_segsum(ADT_curr))).to(QK_causal.dtype) out = torch.einsum("hts,shd->thd", QK_causal, V_curr) if Initial_States is not None: da_cs = torch.cumsum(ADT_curr, dim=-1) exp_da_cs = torch.exp(da_cs) out = out + torch.einsum("hDd,thd,ht->thD", acc_states.to(Q_curr.dtype), Q_curr, exp_da_cs.to(Q_curr.dtype)) if D is not None: out = out + D[None, :, None] * V_curr out = out - V_curr * QK_dot.unsqueeze(-1) if Z_curr is not None: out = out * Z_curr * torch.sigmoid(Z_curr) out_zs.append(out) # Compute final state da_cs_last = torch.exp(torch.sum(ADT_curr, dim=-1)) da_cs_rev = torch.exp(torch.sum(ADT_curr, dim=-1, keepdim=True) - torch.cumsum(ADT_curr, dim=-1)) V_curr_scaled = V_curr * da_cs_rev.permute(1, 0).unsqueeze(-1).to(V_curr.dtype) final_acc_states = acc_states * da_cs_last.unsqueeze(-1).unsqueeze(-1) + torch.einsum( "thd,thD->hDd", K_curr_scaled, V_curr_scaled.to(K_curr_scaled.dtype)) Final_SSM_States.append(final_acc_states) num_sequences = cu_seqlens.size(0) - 1 if is_varlen else batch for seq_idx in range(num_sequences): compute_one_sequence(seq_idx) if not is_varlen: out_zs = torch.stack(out_zs, dim=0) Final_Angle_States = torch.stack(Final_Angle_States, dim=0) Final_SSM_States = torch.stack(Final_SSM_States, dim=0) Final_K_States = torch.stack(Final_K_States, dim=0) Final_V_States = torch.stack(Final_V_States, dim=0) else: out_zs = torch.cat(out_zs, dim=0).unsqueeze(0) Final_Angle_States = torch.stack(Final_Angle_States, dim=0) Final_SSM_States = torch.stack(Final_SSM_States, dim=0) Final_K_States = torch.stack(Final_K_States, dim=0) Final_V_States = torch.stack(Final_V_States, dim=0) return out_zs, 
(Final_Angle_States, Final_SSM_States, Final_K_States, Final_V_States)


# ==================================================================
# Test Utilities
# ==================================================================

def detach_clone(*args):
    """Detach and clone tensors, preserving None values.

    Each non-None argument is detached from its autograd graph, cloned and marked
    with ``requires_grad_()``, so it can serve as an independent leaf for a second
    backward pass. None arguments are passed through unchanged.
    """
    return tuple(arg.detach().clone().requires_grad_() if arg is not None else None for arg in args)


@torch.no_grad()
def relative_error(
    ker: torch.Tensor,
    ref: torch.Tensor,
    eps: float = 1e-6,
    ref_mag_mask: float = 1e-2,
    p: float = 0.95,
    name: str = "",
    print_top_errors: bool = True,
    angle: bool = False,  # if True: use circular absolute error; else: relative error
) -> float:
    """Return the ``p``-quantile of the element-wise error of ``ker`` against ``ref``.

    Only elements with ``|ref| >= ref_mag_mask`` are scored, so tiny reference values
    do not dominate the relative error. With ``angle=True`` the score is the absolute
    circular distance in radians (difference wrapped into [-pi, pi)), otherwise it is
    ``|ker - ref| / (|ref| + eps)``. Returns 0.0 when no element passes the mask.

    If ``print_top_errors`` is set and the error exceeds 0.01, the 10 largest absolute
    differences (taken over all elements, masked or not) are printed under ``name``.
    """
    assert ker.shape == ref.shape, (
        f"{name or 'tensor'}: shape mismatch {tuple(ker.shape)} vs {tuple(ref.shape)}"
    )
    ker_xx = ker.detach().to(torch.float32)
    ref_xx = ref.detach().to(torch.float32)
    abs_ref = ref_xx.abs()

    if angle:
        # Wrap the difference into [-pi, pi) so that e.g. 0.01 and 2*pi - 0.01 count as close.
        delta = torch.remainder((ker_xx - ref_xx) + math.pi, 2 * math.pi) - math.pi
        abs_diff = delta.abs()
    else:
        abs_diff = (ker_xx - ref_xx).abs()

    mask = abs_ref >= ref_mag_mask
    if not mask.any():
        return 0.0

    if angle:
        vals = abs_diff[mask].flatten()
    else:
        vals = (abs_diff[mask] / (abs_ref[mask] + eps)).flatten()
    n = vals.numel()
    # 1-based rank of the p-quantile, clamped to [1, n] for kthvalue.
    k = max(1, min(n, int(math.ceil(p * n))))
    err = vals.kthvalue(k).values.item()

    if print_top_errors and err > 0.01:
        print(f"\n Top 10 errors for {name}:")
        diff_flat = abs_diff.flatten()
        ref_flat = ref_xx.flatten()
        ker_flat = ker_xx.flatten()
        topk = diff_flat.topk(min(10, diff_flat.numel()))
        for i, idx in enumerate(topk.indices.tolist()):
            r = ref_flat[idx].item()
            k_val = ker_flat[idx].item()
            d = diff_flat[idx].item()
            if angle:
                # For angles, show absolute angular error (radians)
                print(f" {i}: ref={r:.6e}, ker={k_val:.6e}, ang_err={d:.6e} rad")
            else:
                rel_e = d / (abs(r) + eps) if abs(r) >= ref_mag_mask else float('nan')
                print(f" {i}: ref={r:.6e}, ker={k_val:.6e}, diff={d:.6e}, rel={rel_e:.2%}")
    return err


def create_mamba3_siso_inputs(
    batch: int,
    seqlen: int,
    nheads: int,
    nheads_qk: int,
    headdim_qk: int,
    headdim_v: int,
    dtype: torch.dtype,
    device: str,
    has_D: bool,
    has_Z: bool,
    has_input_states: bool,
    cu_seqlens: Optional[torch.Tensor] = None,
    requires_grad: bool = False,
):
    """Build random inputs for the Mamba-3 SISO kernels and reference implementations.

    Tensors produced:
      Q, K:            (batch, seqlen, nheads_qk, headdim_qk), RMS-normalized, ``dtype``
      V:               (batch, seqlen, nheads, headdim_v), ``dtype``
      ADT, DT:         (batch, nheads, seqlen), float32; DT is log-uniform in [1e-3, 1e-1]
                       and ADT = A * DT with A uniform in [-16, -1]
      Trap:            (batch, nheads, seqlen), uniform in [0, 1), ``dtype``
      Q_bias, K_bias:  (nheads, headdim_qk), ``dtype``
      Angles:          (batch, seqlen, nheads, headdim_qk // 4), float32
      D:               (nheads,) of ones, float32, only if ``has_D``
      Z:               (batch, seqlen, nheads, headdim_v), ``dtype``, only if ``has_Z``

    With ``has_input_states``, ``Input_States`` is a tuple of float32 tensors
    (angle, ssm, k, v) whose leading dimension is ``cu_seqlens.size(0) - 1`` when
    ``cu_seqlens`` is given and ``batch`` otherwise. Tensors are drawn in a fixed
    order, so a fixed RNG seed reproduces the same inputs. With ``requires_grad``,
    every returned (non-None) tensor, including the input states, requires grad.

    Returns:
        dict with keys 'Q', 'K', 'V', 'ADT', 'DT', 'Trap', 'Q_bias', 'K_bias',
        'Angles', 'D', 'Z', 'Input_States' (absent optional entries are None).
    """
    num_sequences = cu_seqlens.size(0) - 1 if cu_seqlens is not None else batch

    Q = torch.randn((batch, seqlen, nheads_qk, headdim_qk), device=device, dtype=dtype)
    Q = F.rms_norm(Q, normalized_shape=(headdim_qk,)).clone()
    K = torch.randn((batch, seqlen, nheads_qk, headdim_qk), device=device, dtype=dtype)
    K = F.rms_norm(K, normalized_shape=(headdim_qk,)).clone()
    V = torch.randn((batch, seqlen, nheads, headdim_v), device=device, dtype=dtype)

    dt_max, dt_min = 0.1, 0.001
    a_init = -torch.empty(batch, nheads, seqlen, device=device, dtype=torch.float32).uniform_(1.0, 16.0)
    dt = torch.exp(
        torch.rand(batch, nheads, seqlen, device=device, dtype=torch.float32)
        * (math.log(dt_max) - math.log(dt_min))
        + math.log(dt_min)
    )
    ADT = (a_init * dt).contiguous()
    DT = dt.contiguous()
    Trap = torch.empty(batch, nheads, seqlen, dtype=dtype, device=device).uniform_(0.0, 1.0).clone()
    Q_bias = torch.randn(nheads, headdim_qk, dtype=dtype, device=device)
    K_bias = torch.randn(nheads, headdim_qk, dtype=dtype, device=device)

    # headdim_angles constraint: 2*headdim_angles <= headdim_qk
    headdim_angles = headdim_qk // 4
    Angles = torch.randn(batch, seqlen, nheads, headdim_angles, dtype=torch.float32, device=device)
    D = torch.ones((nheads,), device=device, dtype=torch.float32) if has_D else None
    Z = torch.randn((batch, seqlen, nheads, headdim_v), device=device, dtype=dtype) if has_Z else None

    if has_input_states:
        Input_Angle_State = torch.randn((num_sequences, nheads, headdim_angles), device=device, dtype=torch.float32)
        Input_SSM_State = torch.randn((num_sequences, nheads, headdim_v, headdim_qk), device=device, dtype=torch.float32)
        Input_K_State = torch.randn((num_sequences, nheads, headdim_qk), device=device, dtype=torch.float32)
        Input_V_State = torch.randn((num_sequences, nheads, headdim_v), device=device, dtype=torch.float32)
        Input_States = (Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State)
    else:
        Input_States = None

    if requires_grad:
        leaves = [Q, K, V, ADT, DT, Trap, Q_bias, K_bias, Angles, D, Z]
        if Input_States is not None:
            leaves.extend(Input_States)
        for leaf in leaves:
            if leaf is not None:
                leaf.requires_grad_(True)

    return {
        'Q': Q,
        'K': K,
        'V': V,
        'ADT': ADT,
        'DT': DT,
        'Trap': Trap,
        'Q_bias': Q_bias,
        'K_bias': K_bias,
        'Angles': Angles,
        'D': D,
        'Z': Z,
        'Input_States': Input_States,
    }


# ==================================================================
# Triton Step Kernel Test
# ==================================================================

def test_mamba3_siso_step(nheads_qk=4, has_Z=True, has_D=True):
    """Test Mamba-3 step kernel against reference recurrent implementation.

    ``mamba3_siso_step`` is run token by token, each call receiving the states returned
    by the previous one. The stacked outputs and the last states are compared with a
    single ``mamba3_siso_step_ref`` call over the full sequence.
    """
    device = 'cuda'
    rtol = 5e-2
    dtype = torch.bfloat16
    torch.random.manual_seed(42)
    batch = 128
    seqlen = 2345
    nheads = 32
    headdim_qk = 128
    headdim_v = 64

    inputs = create_mamba3_siso_inputs(
        batch, seqlen, nheads, nheads_qk, headdim_qk, headdim_v, dtype, device,
        has_D=has_D, has_Z=has_Z, has_input_states=True, requires_grad=False,
    )
    Q_full, K_full, V_full = inputs['Q'], inputs['K'], inputs['V']
    ADT_full, DT_full, Trap_full = inputs['ADT'], inputs['DT'], inputs['Trap']
    Q_bias, K_bias = inputs['Q_bias'], inputs['K_bias']
    Angles_full, D, Z_full = inputs['Angles'], inputs['D'], inputs['Z']
    Input_States = inputs['Input_States']

    # Hand the kernel its own copies of the initial states so the reference below still
    # starts from the original values even if the kernel updates states in place.
    states_triton = tuple(state.clone() for state in Input_States)
    outputs_triton = []
    for step in range(seqlen):
        Z_step = Z_full[:, step, :, :].contiguous() if Z_full is not None else None
        out_triton, new_states = mamba3_siso_step(
            Q_full[:, step, :, :].contiguous(),
            K_full[:, step, :, :].contiguous(),
            V_full[:, step, :, :].contiguous(),
            ADT_full[:, :, step].contiguous(),
            DT_full[:, :, step].contiguous(),
            Trap_full[:, :, step].contiguous(),
            Q_bias,
            K_bias,
            Angles_full[:, step, :, :].contiguous(),
            D,
            Z_step,
            states_triton,
        )
        states_triton = tuple(new_states)
        outputs_triton.append(out_triton)
    outputs_triton = torch.stack(outputs_triton, dim=1)

    # Reference implementation
    outputs_ref, final_states_ref = mamba3_siso_step_ref(
        Q_full, K_full, V_full, ADT_full, DT_full, Trap_full, Q_bias, K_bias,
        Angles_full, D, Z_full, Input_States=Input_States,
    )

    out_rel_err = relative_error(outputs_triton, outputs_ref, name="Step output")
    print(f"Step output relative error: {out_rel_err:.2e}")
    assert out_rel_err < rtol, f"Step output relative error {out_rel_err} exceeds tolerance {rtol}"

    # Compare final states. As in the combined tests, the Angle state is compared with
    # circular error so that a 2*pi wrap difference is not reported as a mismatch.
    state_errs = {}
    for state_name, ker_state, ref_state in zip(('Angle', 'SSM', 'K', 'V'), states_triton, final_states_ref):
        state_errs[state_name] = relative_error(
            ker_state, ref_state, name=f"Final_{state_name}_State", angle=(state_name == 'Angle'),
        )
    print(
        f"Final state errors - Angle: {state_errs['Angle']:.2e}, SSM: {state_errs['SSM']:.2e}, "
        f"K: {state_errs['K']:.2e}, V: {state_errs['V']:.2e}"
    )
    for state_name, err in state_errs.items():
        assert err < rtol, f"{state_name} state error {err} exceeds tolerance {rtol}"


# ==================================================================
# Triton Forward+Backward Batched Kernel Test
================================================================== # Combined Forward+Backward batched mode test # NOTE: Relative erros for tensors are within 6-8% (especially when they are reduced). # The error for angle is ~20% because cumsum accumulates error over sequence length. This # error becomes ~3% when cumsum (angle-dt) kernel is removed def test_mamba3_siso_combined_batched(nheads_qk=4, has_Z=True, has_D=True, headdim_qk=128): """Test Mamba-3 combined forward+backward against fwd reference. """ device = 'cuda' rtol = 1e-1 dtype = torch.bfloat16 torch.random.manual_seed(42) batch = 16 seqlen = 2345 nheads = 32 headdim_v = 64 chunk_size = 64 half = seqlen // 2 inputs = create_mamba3_siso_inputs( batch, seqlen, nheads, nheads_qk, headdim_qk, headdim_v, dtype, device, has_D=has_D, has_Z=has_Z, has_input_states=True, requires_grad=True ) inputs_ref = copy.deepcopy(inputs) # Reference: use mamba3_siso_fwd_ref to compute full sequence output. Out_ref, Final_States_ref = mamba3_siso_fwd_ref( inputs_ref['Q'], inputs_ref['K'], inputs_ref['V'], inputs_ref['ADT'], inputs_ref['DT'], inputs_ref['Trap'], inputs_ref['Q_bias'], inputs_ref['K_bias'], inputs_ref['Angles'], inputs_ref['D'], inputs_ref['Z'], inputs_ref['Input_States'], ) # Kernel: two-pass forward via state passing. 
Out_first, Angle_State_1, SSM_State_1, K_State_1, V_State_1 = mamba3_siso_combined( inputs['Q'][:, :half], inputs['K'][:, :half], inputs['V'][:, :half], inputs['ADT'][:, :, :half], inputs['DT'][:, :, :half], inputs['Trap'][:, :, :half], inputs['Q_bias'], inputs['K_bias'], inputs['Angles'][:, :half], inputs['D'], inputs['Z'][:, :half] if has_Z else None, inputs['Input_States'], chunk_size=chunk_size, return_final_states=True, ) Out_second, Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State = mamba3_siso_combined( inputs['Q'][:, half:], inputs['K'][:, half:], inputs['V'][:, half:], inputs['ADT'][:, :, half:], inputs['DT'][:, :, half:], inputs['Trap'][:, :, half:], inputs['Q_bias'], inputs['K_bias'], inputs['Angles'][:, half:], inputs['D'], inputs['Z'][:, half:] if has_Z else None, (Angle_State_1, SSM_State_1, K_State_1, V_State_1), chunk_size=chunk_size, return_final_states=True, ) Out_kernel = torch.cat([Out_first, Out_second], dim=1) # Forward comparison out_err = relative_error(Out_kernel, Out_ref, name="Output") print(f"Forward output error: {out_err:.2e}") # assert out_err < rtol, f"Forward output error {out_err:.2e} exceeds tolerance {rtol}" # Compare final states Final_Angle_State_ref, Final_SSM_State_ref, Final_K_State_ref, Final_V_State_ref = Final_States_ref for state_name, ker_state, ref_state in [ ('Angle', Final_Angle_State, Final_Angle_State_ref), ('SSM', Final_SSM_State, Final_SSM_State_ref), ('K', Final_K_State, Final_K_State_ref), ('V', Final_V_State, Final_V_State_ref), ]: err = relative_error(ker_state, ref_state, name=f"Final_{state_name}_State", angle=(state_name=='Angle')) print(f"Final_{state_name}_State error: {err:.2e}") # assert err < rtol, f"Final_{state_name}_State error {err:.2e} exceeds tolerance" # Backward # Give gradients to both output and final states dO = torch.randn_like(Out_ref) dFinal_Angle_State = torch.randn_like(Final_Angle_State) dFinal_SSM_State = torch.randn_like(Final_SSM_State) dFinal_K_State = 
torch.randn_like(Final_K_State) dFinal_V_State = torch.randn_like(Final_V_State) # Reference backward torch.autograd.backward( [Out_ref, Final_Angle_State_ref, Final_SSM_State_ref, Final_K_State_ref, Final_V_State_ref], [dO, dFinal_Angle_State, dFinal_SSM_State, dFinal_K_State, dFinal_V_State], ) # Kernel backward torch.autograd.backward( [Out_kernel, Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State], [dO, dFinal_Angle_State, dFinal_SSM_State, dFinal_K_State, dFinal_V_State], ) # Compare gradients for grad_name in ['Q', 'K', 'V', 'ADT', 'DT', 'Trap', 'Q_bias', 'K_bias', 'Angles']: err = relative_error(inputs[grad_name].grad, inputs_ref[grad_name].grad, name=f"d{grad_name}") print(f"d{grad_name} error: {err:.2e}") # assert err < rtol, f"d{grad_name} error {err:.2e} exceeds tolerance" if has_D: err = relative_error(inputs['D'].grad, inputs_ref['D'].grad, name="dD") print(f"dD error: {err:.2e}") if has_Z: err = relative_error(inputs['Z'].grad, inputs_ref['Z'].grad, name="dZ") print(f"dZ error: {err:.2e}") # Input state gradients for i, state_name in enumerate(['Angle', 'SSM', 'K', 'V']): err = relative_error(inputs['Input_States'][i].grad, inputs_ref['Input_States'][i].grad, name=f"dInput_{state_name}_State") print(f"dInput_{state_name}_State error: {err:.2e}") # ================================================================== # Triton Forward+Backward Varlen Kernel Test # ================================================================== # Combined Forward+Backward varlen mode test # NOTE: Relative erros for tensors are within 6-8% (especially when they are reduced). # The error for angle is ~20% because cumsum accumulates error over sequence length. This # error becomes ~3% when cumsum (angle-dt) kernel is removed def test_mamba3_siso_combined_varlen(nheads_qk=4, has_Z=True, has_D=True, headdim_qk=128): """Test Mamba-3 combined forward+backward with variable-length sequences against fwd reference. 
""" device = 'cuda' rtol = 1e-1 dtype = torch.bfloat16 torch.random.manual_seed(42) num_sequences = 8 seq_lengths = [2345, 2346, 2347, 2348, 2349, 2350, 2351, 2352] total_seqlen = sum(seq_lengths) # Create cu_seqlens cu_seqlens = torch.tensor([0] + list(torch.cumsum(torch.tensor(seq_lengths), dim=0).tolist()), dtype=torch.int32, device=device) batch = 1 # Varlen requires batch=1 nheads = 32 headdim_v = 64 chunk_size = 64 headdim_angles = headdim_qk // 4 # Create packed inputs (batch=1, total_seqlen, ...) Q = torch.randn((batch, total_seqlen, nheads_qk, headdim_qk), device=device, dtype=dtype) Q = F.rms_norm(Q, normalized_shape=(headdim_qk,)).clone() K = torch.randn((batch, total_seqlen, nheads_qk, headdim_qk), device=device, dtype=dtype) K = F.rms_norm(K, normalized_shape=(headdim_qk,)).clone() V = torch.randn((batch, total_seqlen, nheads, headdim_v), device=device, dtype=dtype) dt_max, dt_min = 0.1, 0.001 a_init = -torch.empty(batch, nheads, total_seqlen, device=device, dtype=torch.float32).uniform_(1.0, 16.0) dt = torch.exp( torch.rand(batch, nheads, total_seqlen, device=device, dtype=torch.float32) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ) ADT = (a_init * dt).contiguous() DT = dt.contiguous() Trap = torch.empty(batch, nheads, total_seqlen, dtype=dtype, device=device).uniform_(0.0, 1.0).clone() Q_bias = torch.randn(nheads, headdim_qk, dtype=dtype, device=device) K_bias = torch.randn(nheads, headdim_qk, dtype=dtype, device=device) Angles = torch.randn(batch, total_seqlen, nheads, headdim_angles, dtype=dtype, device=device) * 0.1 D = torch.ones((nheads,), device=device, dtype=torch.float32) if has_D else None Z = torch.randn((batch, total_seqlen, nheads, headdim_v), device=device, dtype=dtype) if has_Z else None # Input states: one per sequence Input_Angle_State = torch.randn((num_sequences, nheads, headdim_angles), device=device, dtype=torch.float32) Input_SSM_State = torch.randn((num_sequences, nheads, headdim_v, headdim_qk), device=device, 
dtype=torch.float32) Input_K_State = torch.randn((num_sequences, nheads, headdim_qk), device=device, dtype=torch.float32) Input_V_State = torch.randn((num_sequences, nheads, headdim_v), device=device, dtype=torch.float32) Input_States = (Input_Angle_State, Input_SSM_State, Input_K_State, Input_V_State) # Enable gradients Q.requires_grad_(True) K.requires_grad_(True) V.requires_grad_(True) ADT.requires_grad_(True) DT.requires_grad_(True) Trap.requires_grad_(True) Q_bias.requires_grad_(True) K_bias.requires_grad_(True) Angles.requires_grad_(True) if D is not None: D.requires_grad_(True) if Z is not None: Z.requires_grad_(True) for state in Input_States: state.requires_grad_(True) # Create deep copies for reference inputs_ref = { 'Q': Q.detach().clone().requires_grad_(True), 'K': K.detach().clone().requires_grad_(True), 'V': V.detach().clone().requires_grad_(True), 'ADT': ADT.detach().clone().requires_grad_(True), 'DT': DT.detach().clone().requires_grad_(True), 'Trap': Trap.detach().clone().requires_grad_(True), 'Q_bias': Q_bias.detach().clone().requires_grad_(True), 'K_bias': K_bias.detach().clone().requires_grad_(True), 'Angles': Angles.detach().clone().requires_grad_(True), 'D': D.detach().clone().requires_grad_(True) if D is not None else None, 'Z': Z.detach().clone().requires_grad_(True) if Z is not None else None, 'Input_States': tuple(s.detach().clone().requires_grad_(True) for s in Input_States), } inputs_ker = { 'Q': Q, 'K': K, 'V': V, 'ADT': ADT, 'DT': DT, 'Trap': Trap, 'Q_bias': Q_bias, 'K_bias': K_bias, 'Angles': Angles, 'D': D, 'Z': Z, 'Input_States': Input_States, } # Reference: use mamba3_siso_fwd_ref with cu_seqlens Out_ref, Final_States_ref = mamba3_siso_fwd_ref( inputs_ref['Q'], inputs_ref['K'], inputs_ref['V'], inputs_ref['ADT'], inputs_ref['DT'], inputs_ref['Trap'], inputs_ref['Q_bias'], inputs_ref['K_bias'], inputs_ref['Angles'], inputs_ref['D'], inputs_ref['Z'], inputs_ref['Input_States'], cu_seqlens=cu_seqlens, ) # Kernel: single call with 
cu_seqlens Out_kernel, Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State = mamba3_siso_combined( inputs_ker['Q'], inputs_ker['K'], inputs_ker['V'], inputs_ker['ADT'], inputs_ker['DT'], inputs_ker['Trap'], inputs_ker['Q_bias'], inputs_ker['K_bias'], inputs_ker['Angles'], inputs_ker['D'], inputs_ker['Z'], inputs_ker['Input_States'], chunk_size=chunk_size, return_final_states=True, cu_seqlens=cu_seqlens, ) # Forward comparison out_err = relative_error(Out_kernel, Out_ref, name="Output") print(f"Forward output error: {out_err:.2e}") # Compare final states Final_Angle_State_ref, Final_SSM_State_ref, Final_K_State_ref, Final_V_State_ref = Final_States_ref for state_name, ker_state, ref_state in [ ('Angle', Final_Angle_State, Final_Angle_State_ref), ('SSM', Final_SSM_State, Final_SSM_State_ref), ('K', Final_K_State, Final_K_State_ref), ('V', Final_V_State, Final_V_State_ref), ]: err = relative_error(ker_state, ref_state, name=f"Final_{state_name}_State", angle=(state_name=='Angle')) print(f"Final_{state_name}_State error: {err:.2e}") # Backward dO = torch.randn_like(Out_ref) dFinal_Angle_State = torch.randn_like(Final_Angle_State) dFinal_SSM_State = torch.randn_like(Final_SSM_State) dFinal_K_State = torch.randn_like(Final_K_State) dFinal_V_State = torch.randn_like(Final_V_State) # Reference backward torch.autograd.backward( [Out_ref, Final_Angle_State_ref, Final_SSM_State_ref, Final_K_State_ref, Final_V_State_ref], [dO, dFinal_Angle_State, dFinal_SSM_State, dFinal_K_State, dFinal_V_State], ) # Kernel backward torch.autograd.backward( [Out_kernel, Final_Angle_State, Final_SSM_State, Final_K_State, Final_V_State], [dO, dFinal_Angle_State, dFinal_SSM_State, dFinal_K_State, dFinal_V_State], ) # Compare gradients for grad_name in ['Q', 'K', 'V', 'ADT', 'DT', 'Trap', 'Q_bias', 'K_bias', 'Angles']: err = relative_error(inputs_ker[grad_name].grad, inputs_ref[grad_name].grad, name=f"d{grad_name}") print(f"d{grad_name} error: {err:.2e}") if has_D: err = 
relative_error(inputs_ker['D'].grad, inputs_ref['D'].grad, name="dD") print(f"dD error: {err:.2e}") if has_Z: err = relative_error(inputs_ker['Z'].grad, inputs_ref['Z'].grad, name="dZ") print(f"dZ error: {err:.2e}") # Input state gradients for i, state_name in enumerate(['Angle', 'SSM', 'K', 'V']): err = relative_error(inputs_ker['Input_States'][i].grad, inputs_ref['Input_States'][i].grad, name=f"dInput_{state_name}_State") print(f"dInput_{state_name}_State error: {err:.2e}") # ================================================================== # Sanity check test: Step reference and Forward reference match # ================================================================== def test_mamba3_siso_step_ref_vs_fwd_ref(nheads_qk=4, has_Z=True, has_D=True): """Test that mamba3_siso_step_ref and mamba3_siso_fwd_ref produce identical outputs.""" device = 'cuda' rtol = 1e-4 # Both are pure Python/PyTorch, so should match very closely dtype = torch.float32 # Use float32 for reference-vs-reference comparison torch.random.manual_seed(42) batch = 16 seqlen = 2048 nheads = 32 headdim_qk = 128 headdim_v = 64 headdim_angles = headdim_qk // 4 inputs = create_mamba3_siso_inputs( batch, seqlen, nheads, nheads_qk, headdim_qk, headdim_v, dtype, device, has_D=has_D, has_Z=has_Z, has_input_states=True, requires_grad=False, ) # --- Step ref --- out_step, final_states_step = mamba3_siso_step_ref( inputs['Q'], inputs['K'], inputs['V'], inputs['ADT'], inputs['DT'], inputs['Trap'], inputs['Q_bias'], inputs['K_bias'], inputs['Angles'], inputs['D'], inputs['Z'], Input_States=inputs['Input_States'], ) angle_state_step, ssm_state_step, k_state_step, v_state_step = final_states_step # --- Fwd ref --- out_fwd, final_states_fwd = mamba3_siso_fwd_ref( inputs['Q'], inputs['K'], inputs['V'], inputs['ADT'], inputs['DT'], inputs['Trap'], inputs['Q_bias'], inputs['K_bias'], inputs['Angles'], inputs['D'], inputs['Z'], Initial_States=inputs['Input_States'], dtype=dtype, ) angle_state_fwd, ssm_state_fwd, 
k_state_fwd, v_state_fwd = final_states_fwd # --- Compare outputs --- out_err = relative_error(out_step, out_fwd, name="Output", ref_mag_mask=1e-3) print(f"Output error: {out_err:.2e}") # assert out_err < rtol, f"Output error {out_err:.2e} exceeds tolerance {rtol}" # --- Compare final states --- for state_name, step_state, fwd_state in [ ('Angle', angle_state_step, angle_state_fwd), ('SSM', ssm_state_step, ssm_state_fwd), ('K', k_state_step, k_state_fwd), ('V', v_state_step, v_state_fwd), ]: err = relative_error(step_state, fwd_state, name=f"Final_{state_name}_State", angle=(state_name == 'Angle'), ref_mag_mask=1e-3) print(f"Final_{state_name}_State error: {err:.2e}") # Main function if __name__ == "__main__": print("Running Mamba-3 step reference vs forward reference test...") test_mamba3_siso_step_ref_vs_fwd_ref() print("="*100) print("\nRunning Mamba-3 combined forward+backward batched test...") test_mamba3_siso_combined_batched() print("="*100) print("\nRunning Mamba-3 combined forward+backward varlen test...") test_mamba3_siso_combined_varlen() print("="*100) ================================================ FILE: tests/ops/triton/test_selective_state_update.py ================================================ # Copyright (C) 2023, Tri Dao. 
import math import torch import torch.nn.functional as F import pytest from einops import rearrange, repeat from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) # @pytest.mark.parametrize('itype', [torch.float16]) @pytest.mark.parametrize("has_z", [False, True]) # @pytest.mark.parametrize('has_z', [True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) # @pytest.mark.parametrize("dstate", [16]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) # @pytest.mark.parametrize("dim", [2048]) def test_selective_state_update(dim, dstate, has_z, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 if torch.version.hip: atol *= 2 # set seed torch.random.manual_seed(0) batch_size = 2 state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) x = torch.randn(batch_size, dim, device=device, dtype=itype) dt = torch.randn(batch_size, dim, device=device, dtype=itype) dt_bias = torch.rand(dim, device=device) - 4.0 A = -torch.rand(dim, dstate, device=device) - 1.0 B = torch.randn(batch_size, dstate, device=device) C = torch.randn(batch_size, dstate, device=device) D = torch.randn(dim, device=device) if has_z: z = torch.randn_like(x) else: z = None state_ref = state.detach().clone() out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) # 
@pytest.mark.parametrize('itype', [torch.float16]) @pytest.mark.parametrize("has_z", [False, True]) # @pytest.mark.parametrize('has_z', [True]) @pytest.mark.parametrize("tie_hdim", [False, True]) # @pytest.mark.parametrize('tie_hdim', [True]) @pytest.mark.parametrize("ngroups", [1, 2, 4]) # @pytest.mark.parametrize("ngroups", [2]) @pytest.mark.parametrize("dstate", [16, 32, 64]) # @pytest.mark.parametrize("dstate", [16]) @pytest.mark.parametrize("dim", [2048, 4096]) # @pytest.mark.parametrize("dim", [2048]) def test_selective_state_update_with_heads(dim, dstate, ngroups, has_z, tie_hdim, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2) if itype == torch.bfloat16: rtol, atol = 1e-2, 1e-1 # set seed torch.random.manual_seed(0) batch_size = 2 headdim = 64 nheads = dim // headdim state = torch.randn(batch_size, nheads, headdim, dstate, dtype=itype, device=device) x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) if not tie_hdim: dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) dt_bias = torch.rand(nheads, headdim, device=device) - 4.0 A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0 D = torch.randn(nheads, headdim, device=device) else: dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), "b h -> b h p", p=headdim) dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim) A = repeat(-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate) D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim) B = torch.randn(batch_size, ngroups, dstate, device=device) C = torch.randn(batch_size, ngroups, dstate, device=device) if has_z: z = torch.randn_like(x) else: z = None state_ref = state.detach().clone() state_og = state.detach().clone() out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, 
z=z, dt_bias=dt_bias, dt_softplus=True) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) # @pytest.mark.parametrize('itype', [torch.float16]) @pytest.mark.parametrize("has_z", [False, True]) # @pytest.mark.parametrize('has_z', [True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) # @pytest.mark.parametrize("dstate", [16]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) # @pytest.mark.parametrize("dim", [2048]) def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) if itype == torch.bfloat16: rtol, atol = 6e-2, 6e-2 if torch.version.hip: atol *= 2 # set seed torch.random.manual_seed(0) batch_size = 16 total_entries = 10 * batch_size state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) state_indices = torch.randperm(total_entries)[:batch_size].to(dtype=torch.int32, device=device) x = torch.randn(batch_size, dim, device=device, dtype=itype) dt = torch.randn(batch_size, dim, device=device, dtype=itype) dt_bias = torch.rand(dim, device=device) - 4.0 A = -torch.rand(dim, dstate, device=device) - 1.0 B = torch.randn(batch_size, dstate, device=device) C = torch.randn(batch_size, dstate, device=device) D = torch.randn(dim, device=device) if has_z: z = torch.randn_like(x) else: z = None state_ref = state[state_indices,:].detach().clone() out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices) out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output 
mean diff: {(out - out_ref).abs().mean().item()}") assert torch.allclose(state[state_indices,:], state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) #@pytest.mark.parametrize('itype', [torch.float32]) @pytest.mark.parametrize("has_z", [False, True]) # @pytest.mark.parametrize('has_z', [True]) @pytest.mark.parametrize("tie_hdim", [False, True]) # @pytest.mark.parametrize('tie_hdim', [True]) @pytest.mark.parametrize("ngroups", [1, 2, 4]) # @pytest.mark.parametrize("ngroups", [2]) @pytest.mark.parametrize("dstate", [16, 32, 64]) # @pytest.mark.parametrize("dstate", [16]) @pytest.mark.parametrize("dim", [2048, 4096]) # @pytest.mark.parametrize("dim", [2048]) def test_selective_state_update_with_heads_with_batch_indices(dim, dstate, ngroups, has_z, tie_hdim, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2) if itype == torch.bfloat16: rtol, atol = 1e-1, 1e-1 # set seed torch.random.manual_seed(0) batch_size = 16 headdim = 64 nheads = dim // headdim total_entries = 10 * batch_size state = torch.randn(total_entries, nheads, headdim, dstate, dtype=itype, device=device) state_indices = torch.randperm(total_entries)[:batch_size].to(dtype=torch.int32, device=device) x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) if not tie_hdim: dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) dt_bias = torch.rand(nheads, headdim, device=device) - 4.0 A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0 D = torch.randn(nheads, headdim, device=device) else: dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), "b h -> b h p", p=headdim) dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim) A = repeat(-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate) D = repeat(torch.randn(nheads, device=device), "h -> 
h p", p=headdim) B = torch.randn(batch_size, ngroups, dstate, device=device) C = torch.randn(batch_size, ngroups, dstate, device=device) if has_z: z = torch.randn_like(x) else: z = None state_ref = state[state_indices,:].detach().clone() state_og = state[state_indices,:].detach().clone() out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices) out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.allclose(state[state_indices,:], state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) ================================================ FILE: tests/ops/triton/test_ssd.py ================================================ import math import torch import torch.nn.functional as F import pytest from einops import rearrange, repeat from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_chunk_scan, ssd_chunk_scan_combined_ref, ssd_selective_scan from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined, mamba_split_conv1d_scan_ref def detach_clone(*args): return tuple([arg.detach().clone().requires_grad_() if arg is not None else None for arg in args]) @pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16]) # @pytest.mark.parametrize('dtype', 
[torch.bfloat16]) @pytest.mark.parametrize('ngroups', [1, 2, 8, "max"]) # @pytest.mark.parametrize('ngroups', [1]) @pytest.mark.parametrize('chunk_size', [64, 128]) # @pytest.mark.parametrize('chunk_size', [128]) def test_chunk_state_varlen(chunk_size, ngroups, dtype): device = 'cuda' rtol, atol = (1e-2, 3e-3) # set seed torch.random.manual_seed(chunk_size + (ngroups if ngroups != "max" else 64)) batch = 300 seqlens = torch.randint(1, 200, (batch,), device=device) # batch = 3 # seqlens = torch.tensor([201, 56, 5], device=device) cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0)) total_seqlen = seqlens.sum().item() seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(seqlens)], dim=0).unsqueeze(0) dim = 4096 # dim = 64 headdim = 64 # dim = 32 dstate = 32 assert dim % headdim == 0 nheads = dim // headdim if ngroups == "max": ngroups = nheads assert nheads % ngroups == 0 B = torch.randn(total_seqlen, ngroups, dstate, dtype=dtype, device=device) / 5 x = torch.randn(total_seqlen, nheads, headdim, dtype=dtype, device=device) A = -0.1 * (torch.rand(nheads, device=device)) dt = F.softplus(torch.randn(total_seqlen, nheads, device=device, dtype=torch.float32) - 4) dA_cumsum, dt_rounded = _chunk_cumsum_fwd(dt.unsqueeze(0), A, chunk_size) chunk_states = _chunk_state_fwd(B.unsqueeze(0), x.unsqueeze(0), dt_rounded, dA_cumsum, seq_idx=seq_idx) chunk_states, _ = _state_passing_fwd(rearrange(chunk_states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1], seq_idx=seq_idx, chunk_size=chunk_size) chunk_states = rearrange(chunk_states, "... (p n) -> ... 
p n", n=dstate) chunk_states = chunk_states.squeeze(0) dA_cumsum = dA_cumsum.squeeze(0) dt_rounded = dt_rounded.squeeze(0) out = chunk_state_varlen(B, x, dt_rounded, dA_cumsum, cu_seqlens, chunk_states) out_ref = [] for b in range(batch): x_s = x[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0) B_s = B[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0) dt_s = dt[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0) dA_cumsum_s, dt_rounded_s = _chunk_cumsum_fwd(dt_s, A, chunk_size) states = chunk_state(B_s, x_s, dt_rounded_s, dA_cumsum_s) _, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum_s[:, :, :, -1], chunk_size=chunk_size) final_states = rearrange(final_states, "... (p n) -> ... p n", n=dstate) out_ref.append(final_states) out_ref = torch.cat(out_ref, dim=0) print(f"Max diff = {(out - out_ref).abs().max().item()}") assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) ================================================ FILE: tests/test_determinism.py ================================================ # Copyright (c) 2024, Tri Dao, Albert Gu. 
import os import pytest import torch def _set_deterministic(enabled: bool) -> None: torch.use_deterministic_algorithms(enabled) if enabled: os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" def _set_seeds(seed: int) -> None: torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float: return (a.float() - b.float()).abs().max().item() def _make_inputs( *, seed: int, headdim: int, dstate: int, chunk_size: int = 256, ngroups: int = 1, dtype: torch.dtype = torch.bfloat16, d_has_hdim: bool = False, ) -> dict[str, torch.Tensor]: import math _set_seeds(seed) device = "cuda" batch = 2 seqlen = 2048 nheads = 8 nchunks = math.ceil(seqlen / chunk_size) x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype) dout = torch.randn_like(x) dt = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32) dA_cumsum = torch.randn_like(dt) cb = torch.randn(batch, nchunks, ngroups, chunk_size, chunk_size, device=device, dtype=dtype) B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype).contiguous() C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype).contiguous() dstates = torch.randn(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float32) prev_states = torch.randn_like(dstates) ddA = torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32) ddt_out = torch.randn_like(ddA) dt_raw = torch.randn(batch, seqlen, nheads, device=device, dtype=dtype) A = (torch.randn(nheads, device=device, dtype=torch.float32) * -1.0).contiguous() dt_bias = torch.randn(nheads, device=device, dtype=torch.float32).contiguous() # D shape: (nheads, headdim) when d_has_hdim=True, else (nheads,) if d_has_hdim: D = torch.randn(nheads, headdim, device=device, dtype=torch.float32) else: D = torch.randn(nheads, device=device, dtype=torch.float32) return { "x": x, "dout": dout, "dt": dt, "dA_cumsum": dA_cumsum, "cb": cb, "B": B, 
"C": C, "dstates": dstates, "prev_states": prev_states, "ddA": ddA, "ddt_out": ddt_out, "dt_raw": dt_raw, "A": A, "dt_bias": dt_bias, "D": D, } def _run_case_outputs( *, case: str, deterministic: bool, seed: int, headdim: int = 64, dstate: int = 64, chunk_size: int = 256, ngroups: int = 1, dtype: torch.dtype = torch.bfloat16, d_has_hdim: bool = False, ) -> dict[str, torch.Tensor]: _set_deterministic(deterministic) t = _make_inputs( seed=seed, headdim=headdim, dstate=dstate, chunk_size=chunk_size, ngroups=ngroups, dtype=dtype, d_has_hdim=d_has_hdim, ) if case == "chunk_scan_bwd_dx": from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dx dx, ddt = _chunk_scan_bwd_dx(t["cb"], t["x"], t["dt"], t["dA_cumsum"], t["dout"]) out = {"dx": dx, "ddt": ddt} elif case == "chunk_scan_bwd_dC": from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC dC, ddA_prev = _chunk_scan_bwd_dC(t["prev_states"], t["dA_cumsum"], t["dout"], C=t["C"], ngroups=1) out = {"dC": dC, "ddA_cumsum_prev": ddA_prev} elif case == "chunk_state_bwd_dx": from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_dx dx, ddt, ddA = _chunk_state_bwd_dx(t["B"], t["x"], t["dt"], t["dA_cumsum"], t["dstates"]) out = {"dx": dx, "ddt": ddt, "ddA_cumsum": ddA} elif case == "chunk_state_bwd_db": from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_db dB, ddA = _chunk_state_bwd_db(t["x"], t["dt"], t["dA_cumsum"], t["dstates"], B=t["B"], ngroups=1) out = {"dB": dB, "ddA_cumsum": ddA} elif case == "chunk_state_bwd_ddAcs_stable": from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable ddA = _chunk_state_bwd_ddAcs_stable(t["B"], t["x"], t["dt"], t["dA_cumsum"], t["dstates"]) out = {"ddA_cumsum": ddA} elif case == "chunk_cumsum_bwd": from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_bwd ddt, dA, ddt_bias = _chunk_cumsum_bwd(t["ddA"], t["ddt_out"], t["dt_raw"], t["A"], dt_bias=t["dt_bias"], dt_softplus=True) out = {"ddt": ddt, "dA": dA, "ddt_bias": 
ddt_bias} elif case.startswith("combined_bwd_dx"): from mamba_ssm.ops.triton.ssd_combined import _chunk_scan_chunk_state_bwd_dx dx, ddt, dD = _chunk_scan_chunk_state_bwd_dx(t["x"], t["dt"], t["dA_cumsum"], t["B"], t["cb"], t["dout"], t["dstates"], D=t["D"]) out = {"dx": dx, "ddt": ddt, "dD": dD} else: raise AssertionError(f"Unknown case: {case}") torch.cuda.synchronize() return {k: v.detach().clone().float() for k, v in out.items() if v is not None} _KERNEL_CASES = [ "chunk_scan_bwd_dx", "chunk_scan_bwd_dC", "chunk_state_bwd_dx", "chunk_state_bwd_db", "chunk_state_bwd_ddAcs_stable", "chunk_cumsum_bwd", ] _COMBINED_CASES = [ ("combined_bwd_dx", False), ("combined_bwd_dx_d_hdim", True), ] _HEADDIMS = [64, 128] _DSTATES = [64] def _kernel_is_reproducible(case: str, headdim: int, dstate: int, d_has_hdim: bool = False): runs = 5 outs = [ _run_case_outputs(case=case, deterministic=True, seed=123, headdim=headdim, dstate=dstate, d_has_hdim=d_has_hdim) for _ in range(runs) ] ref = outs[0] for i in range(1, runs): for k in ref: assert _max_abs_diff(ref[k], outs[i][k]) == 0.0, f"{case} output {k} differs (headdim={headdim}, dstate={dstate})" def _kernel_close_to_default(case: str, headdim: int, dstate: int, d_has_hdim: bool = False): atol = rtol = 1e-2 det = _run_case_outputs(case=case, deterministic=True, seed=123, headdim=headdim, dstate=dstate, d_has_hdim=d_has_hdim) for _ in range(3): default = _run_case_outputs(case=case, deterministic=False, seed=123, headdim=headdim, dstate=dstate, d_has_hdim=d_has_hdim) for k in det: assert torch.allclose(default[k], det[k], atol=atol, rtol=rtol), f"{case} output {k} not close (headdim={headdim}, dstate={dstate})" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @pytest.mark.parametrize("dstate", _DSTATES) @pytest.mark.parametrize("headdim", _HEADDIMS) @pytest.mark.parametrize("case", _KERNEL_CASES) def test_kernel_reproducible(case: str, headdim: int, dstate: int): _kernel_is_reproducible(case, headdim, 
dstate) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @pytest.mark.parametrize("dstate", _DSTATES) @pytest.mark.parametrize("headdim", _HEADDIMS) @pytest.mark.parametrize("case,d_has_hdim", _COMBINED_CASES) def test_combined_kernel_reproducible(case: str, d_has_hdim: bool, headdim: int, dstate: int): _kernel_is_reproducible(case, headdim, dstate, d_has_hdim) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @pytest.mark.parametrize("dstate", _DSTATES) @pytest.mark.parametrize("headdim", _HEADDIMS) @pytest.mark.parametrize("case", _KERNEL_CASES) def test_kernel_close_to_default(case: str, headdim: int, dstate: int): _kernel_close_to_default(case, headdim, dstate) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @pytest.mark.parametrize("dstate", _DSTATES) @pytest.mark.parametrize("headdim", _HEADDIMS) @pytest.mark.parametrize("case,d_has_hdim", _COMBINED_CASES) def test_combined_kernel_close_to_default(case: str, d_has_hdim: bool, headdim: int, dstate: int): _kernel_close_to_default(case, headdim, dstate, d_has_hdim) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") def test_default_mode_is_not_reproducible(): from mamba_ssm.modules.mamba2 import Mamba2 device = "cuda" dtype = torch.bfloat16 seed = 123 runs = 20 batch = 4 seqlen = 4096 _set_seeds(seed) model = Mamba2( d_model=256, d_state=64, headdim=64, expand=2, d_conv=4, chunk_size=256, use_mem_eff_path=True, device=device, dtype=dtype, ).train() x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype) def _run() -> dict[str, torch.Tensor]: _set_deterministic(False) model.zero_grad(set_to_none=True) x = x_data.clone().requires_grad_(True) y = model(x) (y.float().square().mean()).backward() torch.cuda.synchronize() grads = {"input": x.grad.detach().float().clone()} for name, p in model.named_parameters(): if p.grad is not None: grads[name] = p.grad.detach().float().clone() return grads 
_run() # warmup ref = _run() observed_diff = False for _ in range(runs - 1): g = _run() for k in ref: if _max_abs_diff(ref[k], g[k]) != 0.0: observed_diff = True break if observed_diff: break if not observed_diff: pytest.xfail( f"Did not observe nondeterminism in default mode after {runs} runs. " "This GPU may have deterministic atomic behavior at these shapes." ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") def test_mamba2_fwd_bwd_deterministic_reproducible(): from mamba_ssm.modules.mamba2 import Mamba2 device = "cuda" dtype = torch.bfloat16 seed = 123 runs = 5 batch = 2 seqlen = 2048 headdim = 64 _set_seeds(seed) _set_deterministic(True) model = Mamba2( d_model=headdim, d_state=16, headdim=headdim, expand=2, d_conv=4, chunk_size=16, use_mem_eff_path=True, device=device, dtype=dtype, ).train() x_data = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype) def _run() -> tuple[torch.Tensor, dict[str, torch.Tensor]]: model.zero_grad(set_to_none=True) x = x_data.clone().requires_grad_(True) y = model(x) (y.float().square().mean()).backward() torch.cuda.synchronize() grads: dict[str, torch.Tensor] = {"input": x.grad.detach().float().clone()} for name, p in model.named_parameters(): if p.grad is not None: grads[name] = p.grad.detach().float().clone() return y.detach().float().clone(), grads _run() # warmup y0, g0 = _run() for _ in range(runs - 1): y, g = _run() assert _max_abs_diff(y0, y) == 0.0 assert g.keys() == g0.keys() for k in g0: assert _max_abs_diff(g0[k], g[k]) == 0.0, f"Mamba2 grad {k} differs" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") def test_mamba2_fwd_bwd_deterministic_close_to_default(): from mamba_ssm.modules.mamba2 import Mamba2 device = "cuda" dtype = torch.bfloat16 seed = 123 batch = 2 seqlen = 2048 headdim = 64 atol = rtol = 1e-2 def _run(deterministic: bool) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: torch.use_deterministic_algorithms(deterministic, warn_only=True) 
_set_seeds(seed) model = Mamba2( d_model=headdim * 4, d_state=32, headdim=headdim, expand=2, d_conv=4, chunk_size=64, use_mem_eff_path=True, device=device, dtype=dtype, ).train() x = torch.randn(batch, seqlen, model.d_model, device=device, dtype=dtype).requires_grad_(True) y = model(x) (y.float().square().mean()).backward() torch.cuda.synchronize() grads: dict[str, torch.Tensor] = {"input": x.grad.detach().float().clone()} for name, p in model.named_parameters(): if p.grad is not None: grads[name] = p.grad.detach().float().clone() return y.detach().float().clone(), grads _run(False) # warmup y_default, g_default = _run(False) y_det, g_det = _run(True) assert torch.allclose(y_default, y_det, atol=atol, rtol=rtol), "Mamba2 output differs" for k in g_default: assert torch.allclose(g_default[k], g_det[k], atol=atol, rtol=rtol), f"Mamba2 grad {k} not close" ================================================ FILE: tests/test_generation.py ================================================ import torch import torch.nn.functional as F from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel from mamba_ssm.models.config_mamba import MambaConfig from mamba_ssm.utils.generation import InferenceParams import pytest from einops import rearrange, repeat def test_generation(): batch = 3 seqlen = 20 device = "cuda" dtype = torch.float16 config = MambaConfig( d_model=1024, n_layer=4, vocab_size=50277, ssm_cfg=dict(layer="Mamba2"), rms_norm=True, residual_in_fp32=True, fused_add_norm=True, pad_vocab_size_multiple=16, ) torch.manual_seed(2357) model = MambaLMHeadModel(config, device=device, dtype=dtype) x = torch.randint(0, 1000, (batch, seqlen), device=device, dtype=torch.long) out_ref = model(x).logits prompt_len = seqlen // 2 out = model.generate( input_ids = x[:, :prompt_len], max_length=seqlen, output_scores=True, return_dict_in_generate=True, cg=True, # Can turn off CUDA graph for easier debugging # instead of sampling, we take output tokens from x, to get logits for testing 
# For actual generation, don't pass in teacher_outputs teacher_outputs=x, ) out_scores = torch.stack(out.scores, dim=1) print(f"Max diff: {(out_scores - out_ref[:, prompt_len - 1: -1]).abs().max()}") assert torch.allclose(out_scores, out_ref[:, prompt_len - 1: -1], rtol=1e-3, atol=1e-2) def test_generation_varlen(): seqlens = [170, 65, 100] genlen = 20 total_seqlen = sum(seqlens) device = "cuda" dtype = torch.float16 config = MambaConfig( d_model=1024, n_layer=4, vocab_size=50277, ssm_cfg=dict(layer="Mamba2"), rms_norm=True, residual_in_fp32=True, fused_add_norm=True, pad_vocab_size_multiple=16, ) torch.manual_seed(2357) model = MambaLMHeadModel(config, device=device, dtype=dtype) xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens] # Reference 1: Forward pass with seq_idx x = torch.cat(xs, dim=1) seq_idx = torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device) for i, ids in enumerate(xs)], dim=0).unsqueeze(0) cu_seqlens = F.pad(torch.tensor(seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0)) out_ref = model(x, seq_idx=seq_idx).logits # Only take the last @genlen logits of each sequence out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1] for i in range(len(seqlens))], dim=0) # Reference 2: Generate the last @genlen tokens of each sequence in a for loop out_loop = [] for input_ids in xs: out = model.generate( input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True, return_dict_in_generate=True, cg=True, teacher_outputs=input_ids, ).scores out_loop.append(torch.stack(out, dim=1)) out_loop = torch.cat(out_loop, dim=0) print(f"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}") # Varlen generation input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1) prompt_seqlens = [seqlen - genlen for seqlen in seqlens] cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), 
(1, 0)) seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device) for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0) inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens)) scores, sequences = [], [] # Both seq_idx and cu_seqlens must be passed in for varlen generation logits = model(input_ids, inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits logits = rearrange(logits[0, cu_seqlens[1:] - 1], "b d -> b 1 d") scores.append(logits) # In practice we should sample. In this case we take from the teacher_output for testing sampled_tokens = rearrange(torch.stack([ids[0, -genlen] for ids in xs], dim=0), "b -> b 1") sequences.append(sampled_tokens) for i in range(1, genlen): inference_params.seqlen_offset += 1 logits = model(sampled_tokens, inference_params=inference_params, num_last_tokens=1).logits scores.append(logits) # In practice we should sample. In this case we take from the teacher_output for testing sampled_tokens = rearrange(torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), "b -> b 1") sequences.append(sampled_tokens) out_varlen = torch.cat(scores, dim=1) print(f"Max diff: {(out_varlen - out_ref).abs().max()}") assert (out_varlen - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max() def test_generation_varlen_with_padding(): seqlens = [170, 65, 100] non_padded_seqlen = sum(seqlens) padded_seqlen = 512 seqlens.append(padded_seqlen - non_padded_seqlen) genlen = 20 total_seqlen = sum(seqlens) assert total_seqlen == padded_seqlen device = "cuda" dtype = torch.float16 config = MambaConfig( d_model=1024, n_layer=4, vocab_size=50277, ssm_cfg=dict(layer="Mamba2"), rms_norm=True, residual_in_fp32=True, fused_add_norm=True, pad_vocab_size_multiple=16, ) torch.manual_seed(2357) model = MambaLMHeadModel(config, device=device, dtype=dtype) xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens] # Reference 1: Forward pass with 
seq_idx x = torch.cat(xs[:-1], dim=1) seq_idx = torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device) for i, ids in enumerate(xs[:-1])], dim=0).unsqueeze(0) cu_seqlens = F.pad(torch.tensor(seqlens[:-1], device=device, dtype=torch.int32).cumsum(dim=0), (1, 0)) out_ref = model(x, seq_idx=seq_idx).logits # Only take the last @genlen logits of each sequence out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1] for i in range(len(seqlens) - 1)], dim=0) # Reference 2: Generate the last @genlen tokens of each sequence in a for loop out_loop = [] for input_ids in xs[:-1]: out = model.generate( input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True, return_dict_in_generate=True, cg=True, teacher_outputs=input_ids, ).scores out_loop.append(torch.stack(out, dim=1)) out_loop = torch.cat(out_loop, dim=0) print(f"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}") # Varlen generation input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1) prompt_seqlens = [seqlen - genlen for seqlen in seqlens] cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0)) seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device) for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0) inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens)) # Account for padding offset = genlen * len(seqlens) seq_idx[non_padded_seqlen - offset : padded_seqlen - offset] = -1 cu_seqlens[-1] = cu_seqlens[-2] scores, sequences = [], [] # Both seq_idx and cu_seqlens must be passed in for varlen generation logits = model(input_ids, inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits logits = rearrange(logits[0, cu_seqlens[1:] - 1], "b d -> b 1 d") scores.append(logits) # In practice we should sample. 
In this case we take from the teacher_output for testing sampled_tokens = rearrange(torch.stack([ids[0, -genlen] for ids in xs], dim=0), "b -> b 1") sequences.append(sampled_tokens) for i in range(1, genlen): inference_params.seqlen_offset += 1 logits = model(sampled_tokens, inference_params=inference_params, num_last_tokens=1).logits scores.append(logits) # In practice we should sample. In this case we take from the teacher_output for testing sampled_tokens = rearrange(torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), "b -> b 1") sequences.append(sampled_tokens) out_varlen = torch.cat(scores, dim=1) print(f"Max diff: {(out_varlen[:-1] - out_ref).abs().max()}") assert (out_varlen[:-1] - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max() ================================================ FILE: usage.md ================================================ # Mamba adoption We've been very happy to see Mamba being adopted by many organizations and research labs to speed up their training / inference. This page contains a partial list of places where Mamba is being used. If you'd like to add links to your organization / product / codebase, please open a PR or email us. We'd very much like to hear from you! 
## Large language models and multi-modal models - [Tencent's Hunyuan-TurboS (560B)](https://arxiv.org/abs/2505.15431) - [Nvidia Nemotron-H (8B, 47B, 56B)](https://research.nvidia.com/labs/adlr/nemotronh/) - [AI21 Jamba (398B)](https://www.ai21.com/blog/announcing-jamba-model-family/) - [TII Falcon-H1 (34B)](https://falconllm.tii.ae/falcon-h1.html) - [IBM Bamba (9B)](https://research.ibm.com/blog/bamba-ssm-transformer-model) - [Mistral's Codestral (7B)](https://mistral.ai/news/codestral-mamba) - [Nvidia Mamba-2 Hybrid (8B)](https://arxiv.org/abs/2406.07887) - [Microsoft Samba (4B)](https://arxiv.org/abs/2406.07522v1) - [TII Falcon-Mamba (7B)](https://falconllm.tii.ae/tii-releases-first-sslm-with-falcon-mamba-7b.html) ## Inference frameworks - vLLM - Nvidia's TensorRT-LLM ## Hardware - Nvidia GPUs - [AMD GPUs](https://rocm.blogs.amd.com/artificial-intelligence/mamba/README.html) - [AWS Trainium 2](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/tutorials/fused_mamba.html)