[
  {
    "path": ".github/workflows/README.md",
    "content": "# GitHub Workflow Tagging Flow\n\nThis repository uses separate tag lanes so FA2 and FA4 publishing do not collide.\n\n## Release lanes\n\n| Tag pattern | Workflow | Package target | Version source |\n| --- | --- | --- | --- |\n| `v*` | `.github/workflows/publish.yml` | Root package (`flash-attn`) | Root package version metadata |\n| `fa4-v*` | `.github/workflows/publish-fa4.yml` | `flash_attn/cute` package (`flash-attn-4`) | `setuptools-scm` with `fa4-v*` tags |\n\n## How to publish\n\n### FA2 / root package lane\n\n1. Create a tag matching `v*` (example: `v2.9.0`).\n2. Push that tag.\n3. `publish.yml` creates a release, builds wheel matrix artifacts, and publishes to PyPI.\n\n### FA4 / CUTE package lane\n\n1. Create a tag matching `fa4-v*` (example: `fa4-v0.1.0`).\n2. Push that tag.\n3. `publish-fa4.yml` builds from `flash_attn/cute`, creates a GitHub release, and uploads `flash-attn-4` to PyPI.\n\n## Guardrails\n\n- Do not use `v*` tags for FA4 releases.\n- Do not use `fa4-v*` tags for FA2 releases.\n- Keep `flash_attn/cute/pyproject.toml` tag parsing in sync with the FA4 tag prefix.\n- The workflow filename (`publish-fa4.yml`) is part of the PyPI trusted publishing OIDC identity — do not rename without updating PyPI.\n"
  },
  {
    "path": ".github/workflows/_build.yml",
    "content": "name: ~Build wheel template\n\non:\n  workflow_call:\n    inputs:\n      runs-on:\n        description: \"The runner to use for the build\"\n        required: true\n        type: string\n      python-version:\n        description: \"The Python version to use for the build\"\n        required: true\n        type: string\n      cuda-version:\n        description: \"The CUDA version to use for the build\"\n        required: true\n        type: string\n      torch-version:\n        description: \"The PyTorch version to use for the build\"\n        required: true\n        type: string\n      cxx11_abi:\n        description: \"Whether to build with the C++11 ABI (TRUE/FALSE)\"\n        required: true\n        type: string\n      upload-to-release:\n        description: \"Whether to upload the built wheel to the release\"\n        required: false\n        type: boolean\n        default: false\n      release-version:\n        description: \"Release tag to check out and upload the wheel to\"\n        required: false\n        type: string\n\ndefaults:\n  run:\n    shell: bash -x -e -u -o pipefail {0}\n\njobs:\n  build-wheel:\n    runs-on: ${{ inputs.runs-on }}\n    name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }})\n    steps:\n      - name: Checkout\n        uses: actions/checkout@v5\n        with:\n          ref: ${{ inputs.release-version }}\n          submodules: recursive\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ inputs.python-version }}\n\n      - name: Set CUDA and PyTorch versions\n        run: |\n          echo \"MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \\. {'print $1 $2'})\" >> $GITHUB_ENV\n          echo \"MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \\. {'print $1 \".\" $2'})\" >> $GITHUB_ENV\n          echo \"WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \\. 
{'print $1'})\" >> $GITHUB_ENV\n          echo \"MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \\. {'print $1 $2'})\" >> $GITHUB_ENV\n\n      - name: Free up disk space\n        if: ${{ runner.os == 'Linux' }}\n        # https://github.com/easimon/maximize-build-space/blob/master/action.yml\n        # https://github.com/easimon/maximize-build-space/tree/test-report\n        run: |\n          sudo rm -rf /usr/share/dotnet\n          sudo rm -rf /opt/ghc\n          sudo rm -rf /opt/hostedtoolcache/CodeQL\n\n      - name: Set up swap space\n        if: runner.os == 'Linux'\n        uses: pierotofy/set-swap-space@v1.0\n        with:\n          swap-size-gb: 10\n\n      - name: Install CUDA ${{ inputs.cuda-version }}\n        if: ${{ inputs.cuda-version != 'cpu' }}\n        uses: Jimver/cuda-toolkit@v0.2.30\n        id: cuda-toolkit\n        with:\n          cuda: ${{ inputs.cuda-version }}\n          linux-local-args: '[\"--toolkit\"]'\n          # default method is \"local\", and we're hitting some error with caching for CUDA 11.8 and 12.1\n          # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }}\n          method: \"network\"\n          sub-packages: '[\"nvcc\"]'\n\n      - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }}\n        run: |\n          pip install --upgrade pip\n          # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error\n          # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable\n          pip install typing-extensions==4.12.2\n          # Pick the highest available PyTorch wheel CUDA version that doesn't exceed system CUDA\n          export TORCH_CUDA_VERSION=$(python -c \"from os import environ as env; \\\n            available = { \\\n              '2.6': [118, 124, 126], \\\n              '2.7': [118, 126, 128], \\\n              '2.8': [126, 128, 129], \\\n        
      '2.9': [126, 128, 130], \\\n              '2.10': [126, 128, 130], \\\n            }[env['MATRIX_TORCH_VERSION']]; \\\n            sys_cuda = int(env['MATRIX_CUDA_VERSION']); \\\n            print(max(v for v in available if v <= sys_cuda))\" \\\n          )\n          # detect if we're on ARM\n          if [ \"$(uname -m)\" = \"aarch64\" ] || [ \"$(uname -m)\" = \"arm64\" ]; then\n              PLAT=linux_aarch64\n          else\n              PLAT=manylinux_2_27_x86_64.manylinux_2_28_x86_64\n          fi\n          echo \"PLAT=$PLAT\" >> $GITHUB_ENV\n          if [[ ${{ inputs.torch-version }} == *\"dev\"* ]]; then\n            # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}\n            # Can't use --no-deps because we need cudnn etc.\n            # Hard-coding this version of pytorch-triton for torch 2.9.0.dev20250904\n            pip install jinja2\n            TRITON_URL=https://download.pytorch.org/whl/nightly/pytorch_triton-3.4.0%2Bgitf7888497-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-${PLAT}.whl\n            TORCH_URL=https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_$(uname -m).whl\n            pip install --no-cache-dir --pre \"${TRITON_URL}\"\n            pip install --no-cache-dir --pre \"${TORCH_URL}\"\n          else\n            pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}\n          fi\n          nvcc --version\n          python --version\n          python -c \"import torch; print('PyTorch:', torch.__version__)\"\n          python -c \"import torch; print('CUDA:', torch.version.cuda)\"\n          python -c \"from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)\"\n\n      - name: Restore 
build cache\n        uses: actions/cache/restore@v4\n        with:\n          path: build.tar\n          key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }}\n          restore-keys: |\n            build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-\n\n      - name: Unpack build cache\n        run: |\n          echo ::group::Adjust timestamps\n          sudo find / -exec touch -t 197001010000 {} + || true\n          echo ::endgroup::\n\n          if [ -f build.tar ]; then\n            find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} +\n            tar -xpvf build.tar -C .\n          else\n            echo \"No build.tar found, skipping\"\n          fi\n\n          ls -al ./\n          ls -al build/ || true\n          ls -al csrc/ || true\n\n      - name: Build wheel\n        id: build_wheel\n        run: |\n          # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6\n          # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810\n          # However this still fails so I'm using a newer version of setuptools\n          pip install setuptools==75.8.0\n          pip install ninja packaging wheel\n          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH\n          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH\n          # Limit MAX_JOBS otherwise the github runner goes OOM\n          # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM\n\n          export MAX_JOBS=$([ \"$MATRIX_CUDA_VERSION\" == \"129\" ] || [ \"$MATRIX_CUDA_VERSION\" == \"130\" ] && echo 1 || echo 2)\n          export NVCC_THREADS=2\n    
      export FLASH_ATTENTION_FORCE_BUILD=\"TRUE\"\n          export FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }}\n\n          # 5h timeout since GH allows max 6h and we want some buffer\n          EXIT_CODE=0\n          timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$?\n\n          if [ $EXIT_CODE -eq 0 ]; then\n            tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}\n            wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed \"s/-/+$tmpname-/2\")\n            ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}\n            echo \"wheel_name=${wheel_name}\" >> $GITHUB_ENV\n          fi\n\n          # Store exit code as a step output for later steps\n          echo \"build_exit_code=$EXIT_CODE\" | tee -a \"$GITHUB_OUTPUT\"\n\n          # Propagate the build exit code; the always() steps below still run\n          # and save the build cache when timeout killed the build (exit code 124)\n          exit $EXIT_CODE\n\n      - name: Pack build directory after timeout\n        if: always() && steps.build_wheel.outputs.build_exit_code == 124\n        run: |\n          ls -al ./\n          tar -cvf build.tar . 
--atime-preserve=replace\n\n      - name: Save build cache timeout\n        if: always() && steps.build_wheel.outputs.build_exit_code == 124\n        uses: actions/cache/save@v4\n        with:\n          key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }}\n          path: build.tar\n\n      - name: Log Built Wheels\n        run: |\n          ls dist\n\n      - name: Get Release with tag\n        id: get_current_release\n        uses: joutvhu/get-release@v1\n        with:\n          tag_name: ${{ inputs.release-version }}\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Upload Release Asset\n        id: upload_release_asset\n        if: inputs.upload-to-release\n        uses: actions/upload-release-asset@v1\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n        with:\n          upload_url: ${{ steps.get_current_release.outputs.upload_url }}\n          asset_path: ./dist/${{env.wheel_name}}\n          asset_name: ${{env.wheel_name}}\n          asset_content_type: application/*\n"
  },
  {
    "path": ".github/workflows/build.yml",
    "content": "name: Build wheels\n\non:\n  workflow_dispatch:\n    inputs:\n      runs-on:\n        description: \"The runner to use for the build\"\n        required: true\n        type: string\n        default: ubuntu-22.04\n      python-version:\n        description: \"The Python version to use for the build\"\n        required: true\n        type: string\n      cuda-version:\n        description: \"The CUDA version to use for the build\"\n        required: true\n        type: string\n      torch-version:\n        description: \"The PyTorch version to use for the build\"\n        required: true\n        type: string\n      cxx11_abi:\n        description: \"Enable torch flag C++11 ABI (TRUE/FALSE)\"\n        required: true\n        type: string\n      upload-to-release:\n        description: \"Whether to upload the built wheel to the release\"\n        required: false\n        type: boolean\n        default: false\n      release-version:\n        description: \"Release tag to check out and upload the wheel to\"\n        required: false\n        type: string\n\njobs:\n  build-wheels:\n    uses: ./.github/workflows/_build.yml\n    with:\n      runs-on: ${{ inputs.runs-on }}\n      python-version: ${{ inputs.python-version }}\n      cuda-version: ${{ inputs.cuda-version }}\n      torch-version: ${{ inputs.torch-version }}\n      cxx11_abi: ${{ inputs.cxx11_abi }}\n      upload-to-release: ${{ inputs.upload-to-release }}\n      release-version: ${{ inputs.release-version }}\n"
  },
  {
    "path": ".github/workflows/pre-commit.yaml",
    "content": "name: Lint\n\non:\n  pull_request:\n    paths:\n      - 'flash_attn/cute/flash_bwd_sm90.py'\n      - 'flash_attn/cute/flash_bwd_preprocess.py'\n      - 'flash_attn/cute/flash_bwd_postprocess.py'\n      - 'flash_attn/cute/softmax.py'\n      - '.pre-commit-config.yaml'\n  push:\n    branches:\n      - main\n    paths:\n      - 'flash_attn/cute/flash_bwd_sm90.py'\n      - 'flash_attn/cute/flash_bwd_preprocess.py'\n      - 'flash_attn/cute/flash_bwd_postprocess.py'\n      - 'flash_attn/cute/softmax.py'\n      - '.pre-commit-config.yaml'\n\njobs:\n  pre-commit:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v5\n\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: '3.11'\n\n      - name: Run pre-commit\n        uses: pre-commit/action@v3.0.1\n"
  },
  {
    "path": ".github/workflows/publish-fa4.yml",
    "content": "name: Publish flash-attn-4 to PyPI\n\non:\n  push:\n    tags:\n      - 'fa4-v*'\n\npermissions:\n  contents: write\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    steps:\n    - uses: actions/checkout@v4\n      with:\n        fetch-depth: 0\n    - uses: actions/setup-python@v5\n      with:\n        python-version: '3.12'\n    - name: Install build dependencies\n      run: pip install build twine\n    - name: Build package\n      run: python -m build\n      working-directory: flash_attn/cute\n    - name: Check package metadata\n      run: twine check dist/*\n      working-directory: flash_attn/cute\n    - name: Store distribution packages\n      uses: actions/upload-artifact@v4\n      with:\n        name: python-package-distributions\n        path: flash_attn/cute/dist/\n\n  github-release:\n    needs: build\n    runs-on: ubuntu-latest\n    steps:\n    - name: Download distribution packages\n      uses: actions/download-artifact@v4\n      with:\n        name: python-package-distributions\n        path: dist/\n    - name: Create GitHub Release\n      uses: softprops/action-gh-release@v2\n      with:\n        files: dist/*\n        generate_release_notes: true\n\n  publish-to-pypi:\n    needs: build\n    runs-on: ubuntu-latest\n    environment:\n      name: pypi\n      url: https://pypi.org/p/flash-attn-4\n    permissions:\n      id-token: write\n    steps:\n    - name: Download distribution packages\n      uses: actions/download-artifact@v4\n      with:\n        name: python-package-distributions\n        path: dist/\n    - name: Publish to PyPI\n      uses: pypa/gh-action-pypi-publish@release/v1\n"
  },
  {
    "path": ".github/workflows/publish.yml",
    "content": "# This workflow will:\n# - Create a new GitHub release\n# - Build wheels for supported architectures\n# - Deploy the wheels to the GitHub release\n# - Release the static code to PyPI\n# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries\n\nname: Build wheels and deploy\n\non:\n  push:\n    tags:\n      - 'v*'\n\njobs:\n  setup_release:\n    name: Create Release\n    runs-on: ubuntu-latest\n    outputs:\n      release-version: ${{ steps.extract_branch.outputs.branch }}\n    steps:\n      - name: Get the tag version\n        id: extract_branch\n        run: echo \"branch=${GITHUB_REF#refs/tags/}\" >> $GITHUB_OUTPUT\n        shell: bash\n      - name: Create Release\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n        run: gh release create ${{ steps.extract_branch.outputs.branch }} --repo $GITHUB_REPOSITORY --title ${{ steps.extract_branch.outputs.branch }} --generate-notes\n        shell: bash\n\n  build_wheels:\n    name: Build Wheel\n    needs: setup_release\n    strategy:\n      fail-fast: false\n      matrix:\n        # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). 
Ideally we'd use the\n        # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.\n        os: [ubuntu-22.04, ubuntu-22.04-arm]\n        python-version: [\"3.10\", \"3.11\", \"3.12\", \"3.13\"]\n        torch-version: [\"2.6.0\", \"2.7.1\", \"2.8.0\", \"2.9.1\", \"2.10.0\"]\n        cuda-version: [\"12.9.1\", \"13.0.1\"]\n        # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.\n        # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.\n        # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)\n        # when building without C++11 ABI and using it on nvcr images.\n        cxx11_abi: [\"FALSE\", \"TRUE\"]\n        exclude:\n          # CUDA 13.0 is only supported by PyTorch 2.9+\n          - torch-version: \"2.6.0\"\n            cuda-version: \"13.0.1\"\n          - torch-version: \"2.7.1\"\n            cuda-version: \"13.0.1\"\n          - torch-version: \"2.8.0\"\n            cuda-version: \"13.0.1\"\n          # No aarch64 PyTorch wheels for 2.6.0\n          - torch-version: \"2.6.0\"\n            os: ubuntu-22.04-arm\n          # PyTorch 2.7+ pip wheels use CXX11_ABI=1 by default, no need for FALSE\n          - torch-version: \"2.7.1\"\n            cxx11_abi: \"FALSE\"\n          - torch-version: \"2.8.0\"\n            cxx11_abi: \"FALSE\"\n          - torch-version: \"2.9.1\"\n            cxx11_abi: \"FALSE\"\n          - torch-version: \"2.10.0\"\n            cxx11_abi: \"FALSE\"\n    uses: ./.github/workflows/_build.yml\n    with:\n      runs-on: ${{ matrix.os }}\n      python-version: ${{ matrix.python-version }}\n      cuda-version: ${{ matrix.cuda-version }}\n      torch-version: ${{ matrix.torch-version }}\n      cxx11_abi: ${{ matrix.cxx11_abi }}\n      release-version: ${{ needs.setup_release.outputs.release-version }}\n      upload-to-release: true\n\n  publish_package:\n    name: Publish 
package\n    needs: [build_wheels]\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v5\n      - uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n      - name: Install dependencies\n        run: |\n          pip install ninja packaging wheel twine\n          # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv)\n          pip install setuptools==75.8.0\n          # We don't want to download anything CUDA-related here\n          pip install torch --index-url https://download.pytorch.org/whl/cpu\n      - name: Build core package\n        env:\n          FLASH_ATTENTION_SKIP_CUDA_BUILD: \"TRUE\"\n        run: |\n          python setup.py sdist --dist-dir=dist\n      - name: Deploy\n        env:\n          TWINE_USERNAME: \"__token__\"\n          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}\n        run: |\n          python -m twine upload dist/*\n"
  },
  {
    "path": ".gitignore",
    "content": "*.ncu-rep\n*.sass\n*.ptx\n*.cubin\n*.plk\n.DS_Store\n.vscode\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n\n# C extensions\n*.so\n\n# Distribution / packaging\nbin/\nbuild/\ndevelop-eggs/\ndist/\neggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\n*.egg-info/\n.installed.cfg\n*.egg\n.eggs/\n\n# IDE-related\n.idea/\n.vscode/\n\n# Dev\nvenv\n\n# compile-time generated file\nflash_attn_config.py"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"csrc/cutlass\"]\n\tpath = csrc/cutlass\n\turl = https://github.com/NVIDIA/cutlass.git\n[submodule \"csrc/composable_kernel\"]\n\tpath = csrc/composable_kernel\n\turl = https://github.com/ROCm/composable_kernel.git\n\tbranch = amd-master\n[submodule \"third_party/aiter\"]\n\tpath = third_party/aiter\n\turl = https://github.com/ROCm/aiter.git\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.11.13\n    hooks:\n      - id: ruff-check\n        args: [--fix, --exit-non-zero-on-fix]\n        files: ^flash_attn/cute/.*\\.py$\n        exclude: &cute_exclude |\n          (?x)^flash_attn/cute/(\n            flash_bwd|\n            flash_fwd|\n            flash_fwd_sm100|\n            interface|\n          )\\.py$\n      - id: ruff-format\n        files: ^flash_attn/cute/.*\\.py$\n        exclude: *cute_exclude\n"
  },
  {
    "path": "AI/DEBUG_2CTA.md",
    "content": "# Debugging GPU Kernel Hangs (Deadlocks) in CUTLASS DSL / 2CTA Kernels\n\n## General Approach to Debugging Kernel Hangs\n\n### Step 1: Build a minimal repro\n\nStrip the test case down to the smallest input that triggers the hang:\n- batch=1, nheads=1, smallest seqlen that hangs\n- Single config, no loops, no benchmarking\n- Add a timeout or run with `compute-sanitizer` so you can distinguish a hang from slow execution\n\n### Step 2: Add printf to locate the hang\n\nGPU `printf` (`cute.printf`) is the primary tool. The goal is binary search: narrow down which warp and which operation is blocked.\n\n**Printf guards** — avoid print storms:\n```python\n# One thread per warp:\nif cute.arch.thread_idx()[0] % 32 == 0:\n    cute.printf(\"...\")\n\n# One thread per CTA (elect_one is a context manager, not a bool):\nwith cute.arch.elect_one():\n    cute.printf(\"...\")\n\n# One specific thread:\nif tidx == 0:\n    cute.printf(\"...\")\n```\n\n**Strategy — coarse to fine:**\n1. First, print at the entry/exit of each warp's main function (load, mma, softmax, correction). This tells you which warp is stuck.\n2. Then add prints before/after each pipeline wait (`consumer_wait`, `producer_acquire`). This tells you which barrier is stuck.\n3. Then print the barrier index, phase, and stage to understand the pipeline state.\n\n**What to print:**\n- CTA index (`cute.arch.block_idx()[0]`) — critical for multi-CTA debugging\n- Pipeline stage index and phase\n- Loop iteration count\n- Whether a `try_wait` succeeds or fails (use `try_wait_token` parameter)\n\n### Step 3: Identify the deadlock chain\n\nA hang is always a cycle. 
Typical chain in a pipelined kernel:\n\n```\nMMA waiting for K from load (pipeline_kv full barrier)\n  -> Load finished but stuck in producer_tail (waiting for MMA to release empty barrier)\n    -> MMA can't release because it's waiting for K\n```\n\nOnce you see which barrier is stuck, trace backwards: who is supposed to signal it, and why haven't they?\n\n### Step 4: Vary the problem size systematically\n\nTest with different sequence lengths / block counts to find the pattern:\n\n| seqlen | n_blocks | Result |\n|--------|----------|--------|\n| 128    | 1        | ?      |\n| 256    | 2        | ?      |\n| 384    | 3        | ?      |\n| 512    | 4        | ?      |\n\nIf the hang correlates with the number of visits to a pipeline stage (e.g., works for n_blocks <= kv_stages but fails when stages wrap around), the problem is likely in barrier tx_count or phase tracking.\n\n### Step 5: Check barrier byte counts (tx_count)\n\nFor TMA-based pipelines, `arrive_and_expect_tx` sets the expected transaction byte count on an mbarrier. If the expected count doesn't match the actual bytes arriving, the barrier either:\n- Fires too early (expected < actual) — causes data races\n- Never fires (expected > actual) — causes hangs\n\nIn **2CTA / cluster mode**, both CTAs' TMAs signal the **same** cluster-level mbarrier. If each CTA's TMA contributes N bytes, the barrier receives 2N bytes total. The tx_count must be `N * cta_group_size`, not just `N`.\n\n**All TMA pipelines need doubling** — Q, K, and V. Even though each CTA loads a different M-tile for Q, both CTAs' TMA operations still signal the same cluster-level barrier, so the expected byte count must account for both.\n\n### Step 6: Check phase / parity tracking\n\n`mbarrier_try_wait_parity` uses a single parity bit (0 or 1). If your pipeline state tracks phase as a monotonically increasing counter (0, 1, 2, 3, ...), you need `phase % 2` before passing it to the barrier wait. 
Without this, phase=2 looks like phase=0 to the hardware, which can cause waits on already-completed barriers or misses on pending ones.\n\n### Step 7: Beware compiler-as-bug-source\n\nIf the kernel works WITH printf but hangs WITHOUT it, the printf is acting as a **compiler barrier**. The MLIR/LLVM backend cannot optimize through an opaque function call like printf, which prevents harmful instruction reordering.\n\nSigns this is happening:\n- A single `cute.printf(\"\\n\")` in the right function fixes the hang\n- PTX fences (`fence_view_async_shared`, `fence_acq_rel_cluster`, `sync_warp`, `fence_proxy`) do NOT fix it — these affect hardware memory ordering, not compiler scheduling\n- The fix is location-sensitive (printf in one function fixes it, in another doesn't)\n\nPossible workarounds:\n- `@dsl_user_op` decorator on pipeline methods to make them opaque to the compiler\n- `asm volatile` barriers (if available in the DSL)\n- Compare generated PTX/SASS with and without printf to identify what the compiler is reordering\n- File a bug against the CUTLASS DSL / MLIR pipeline\n\n---\n\n## 2CTA-Specific Pitfalls\n\n### tcgen05.commit with empty commit groups\n\n`tcgen05.commit(mbar, mask, cta_group::2)` is supposed to signal an mbarrier after all pending MMA operations complete. But if there are **no pending operations** (empty commit group), the signal only reaches the local CTA's barrier, not the remote CTA's. Fix: use explicit `mbarrier_arrive(barrier, dst_cta_rank)` to both CTAs.\n\n### producer_tail deadlock\n\nThe default `producer_tail` (inherited from sm90 pipelines) drains the pipeline by calling `producer_acquire` in a loop. In 2CTA mode this deadlocks because the consumer (MMA warp) may have already exited without releasing all stages. Fix: make `producer_tail` a no-op for 2CTA.\n\n### Tile scheduler must account for cluster shape\n\nBoth CTAs in a cluster must get the **same** tile coordinate. 
Raw `blockIdx.x` assigns consecutive values to CTAs in the same cluster. Fix: divide `blockIdx.x` by `cluster_shape_m`.\n\n### Cross-CTA vs per-CTA pipelines\n\nPipelines where CTA 1's threads remotely arrive on CTA 0's barriers need cluster-sized cooperative group counts. Pipelines that are purely local to each CTA keep per-CTA counts.\n\n### Softmax masking offset\n\nCausal mask row positions must account for the CTA's position within the cluster. Multiply `m_block` by `cta_group_size` when computing mask coordinates.\n"
  },
  {
    "path": "AI/RACECHECK_TMA_HAZARD.md",
    "content": "# compute-sanitizer racecheck hazard with `cp.async.bulk`\n\n## Summary\n\n`compute-sanitizer --tool=racecheck` reports false-positive shared-memory race\nhazards when `cp.async.bulk` (raw-address TMA) is used in a cross-warp\nproducer/consumer pipeline inside a dynamic loop. The same pattern with\n`cp.async.bulk.tensor` (descriptor-based TMA) reports **zero hazards**.\n\nThe fix for the flash backward kernel is to switch the LSE/dPsum copies from\n`CopyBulkG2SOp` (`cp.async.bulk`) to `CopyBulkTensorTileG2SOp`\n(`cp.async.bulk.tensor`) using `cpasync.make_tiled_tma_atom`.\n\n## Affected code\n\n`flash_attn/cute/flash_bwd_sm100.py` — the SM100 backward attention kernel.\n\nOnly **LSE** and **dPsum** buffers are affected because they are the only\nTMA-loaded buffers consumed by thread-level shared memory reads (`lds`).\nQ/K/V/dO are consumed by UMMA hardware instructions, which do not generate\nthread-level `lds` and therefore never trigger racecheck.\n\n## Root cause\n\nracecheck instruments every shared memory access and checks for conflicting\naccesses lacking a recognized happens-before relationship.\n\n**`cp.async.bulk` (raw address):** the sanitizer attributes the smem write to\nthe issuing thread (thread 0 of warp 0 via `elect_one`). When warp 1 issues\n`ld.shared.b32` from the same addresses, the sanitizer searches for a\nhappens-before edge. The only sync is `mbarrier.try_wait.parity` on warp 1\npaired with `mbarrier::complete_tx::bytes` completion from the hardware. The\nsanitizer does not model this as happens-before across warps in a dynamic loop.\n\n**`cp.async.bulk.tensor` (TMA descriptor):** the TMA engine is a separate\nhardware unit. 
The sanitizer does not attribute the smem write to any thread.\nNo writer thread means no hazard pair, so no race is reported.\n\n### Instruction comparison\n\n| Variant | PTX | racecheck |\n|---------|-----|-----------|\n| Raw (cta scope) | `cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes` | **hazard** |\n| Raw (cluster scope) | `cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes` | **hazard** |\n| Descriptor 1D | `cp.async.bulk.tensor.1d.shared::cta.global.tile.mbarrier::complete_tx::bytes` | clean |\n| Descriptor 2D | `cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes` | clean |\n\n### `--racecheck-memcpy-async=no` does not help\n\nThis flag controls the older `cp.async` (sm80) instruction family, not\n`cp.async.bulk`. The hazard persists with `--racecheck-memcpy-async=no`.\n\n## Proof that it is a false positive\n\n1. **Data correctness** — all variants produce bit-identical results.\n2. **Single-warp test** — one warp does both TMA write and thread read in the\n   same loop; racecheck reports zero hazards with the same mbarrier sync.\n3. **Unrolled loop** — fully unrolling (`unroll_full=True`) reports zero\n   hazards; racecheck tracks mbarrier within straight-line code but not across\n   a dynamic branch back-edge between warps.\n4. **Named barrier** — adding `bar.sync` per iteration between producer and\n   consumer warps eliminates the hazard; the sync is correct, racecheck just\n   needs a primitive it recognizes.\n5. 
**Descriptor TMA** — switching to `cp.async.bulk.tensor` with identical\n   pipeline code eliminates the hazard; the mbarrier protocol is correct.\n\n## Minimal reproducers\n\n### `AI/` (preferred, cleaner)\n\n| File | Copy instruction | Result |\n|------|-----------------|--------|\n| `racecheck_repro_1d_bulk.py` | `cp.async.bulk` (raw address) | **1 error** |\n| `racecheck_repro_1d_tensor.py` | `cp.async.bulk.tensor.1d` (TMA descriptor) | **0 hazards** |\n\nBoth are ~75-line self-contained kernels: 2 warps, 4 blocks, 2-stage double\nbuffering with `PipelineTmaAsync`. Identical pipeline protocol — only the copy\ninstruction differs.\n\n```bash\npython AI/racecheck_repro_1d_bulk.py                                              # correctness\nCUTE_DSL_LINEINFO=1 compute-sanitizer --tool=racecheck python AI/racecheck_repro_1d_bulk.py   # 1 error\ncompute-sanitizer --tool=racecheck python AI/racecheck_repro_1d_tensor.py         # 0 hazards\n```\n\n### `benchmarks/` (earlier, more variants)\n\n| File | What it tests | Result |\n|------|--------------|--------|\n| `racecheck_false_positive_repro.py` | `cp.async.bulk` + mbarrier in cross-warp loop | 1 error |\n| `racecheck_1d_raw_ptx.py` | Inline PTX `cp.async.bulk.shared::cta.global` | 1 error |\n| `racecheck_tma2d_repro.py` | `cp.async.bulk.tensor.2d` via `make_tiled_tma_atom` | 0 hazards |\n| `racecheck_tma1d_descriptor.py` | `cp.async.bulk.tensor.1d` via `make_tiled_tma_atom` | 0 hazards |\n\n## PTX-level analysis\n\nDumped PTX for both `AI/` reproducers (`CUTE_DSL_KEEP_PTX=1`). 
The generated\ncode is byte-for-byte identical except for the single copy instruction:\n\n```\n# racecheck_repro_1d_bulk.py  (HAZARD)\ncp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes\n    [%r42], [%rd12], %r43, [%r6+-16];\n\n# racecheck_repro_1d_tensor.py  (CLEAN)\ncp.async.bulk.tensor.1d.shared::cta.global.tile.mbarrier::complete_tx::bytes.L2::cache_hint\n    [%r43], [%rd1, {%r71}], [%r6+-16], %rd8;\n```\n\nAll mbarrier operations (init, `fence.mbarrier_init.release.cluster`,\n`arrive.expect_tx`, `try_wait.parity`, `arrive.release`,\n`fence.proxy.async.shared::cta`, `bar.warp.sync`) are identical.\n\n### racecheck error output\n\n```\nError: Race reported between Write access at ...+0x430 in racecheck_repro_1d_bulk.py:46\n    and Read access at ...+0x770 in racecheck_repro_1d_bulk.py:55 [248 hazards]\n    and Read access at ...+0x7a0 in racecheck_repro_1d_bulk.py:55 [248 hazards]\n    and Read access at ...+0x7d0 in racecheck_repro_1d_bulk.py:55 [248 hazards]\n    and Read access at ...+0x800 in racecheck_repro_1d_bulk.py:55 [248 hazards]\n```\n\n- **Write** (0x430) = line 46: `cute.copy(atom, src, s, mbar_ptr=...)` — the\n  `cp.async.bulk` instruction\n- **Read** (0x770–0x800) = line 55: `dst[...] = s[...]` — four `ld.shared.b32`\n  in the consumer warp\n\n## Fix\n\nChange `copy_stats` in the load function from:\n\n```python\ncopy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32)\ncopy_stats = partial(cute.copy, copy_atom_stats)\n```\n\nto a descriptor-based TMA using `cpasync.make_tiled_tma_atom` with\n`CopyBulkTensorTileG2SOp`. 
This generates `cp.async.bulk.tensor.1d` instead of\n`cp.async.bulk`; racecheck does not attribute the descriptor-based copy's smem\nwrite to any thread, so it reports no hazard.\n\nThe pipeline protocol (mbarrier init, arrive_expect_tx, try_wait_parity,\nconsumer_release) remains identical.\n\n## Backup\n\n`flash_attn/cute/flash_bwd_sm100_gmem_fix.py` contains a working but slower\nfix where compute warps read LSE/dPsum directly from global memory, bypassing\nthe TMA smem pipeline entirely.\n\n## Investigation timeline\n\n1. Observed 2 racecheck errors on LSE and dPsum in `flash_bwd_sm100.py`.\n   Q/K/V/dO clean.\n2. Noticed Q/K/V/dO use UMMA consumers (no thread `lds`) while LSE/dPsum use\n   thread-level `autovec_copy` from smem — explains why only LSE/dPsum trigger.\n3. Built minimal 2-warp pipeline kernel reproducing the hazard.\n4. Single-warp version clean — same mbarrier, same addresses.\n5. Fully-unrolled version clean — racecheck tracks mbarrier within\n   straight-line code.\n6. `bar.sync` per iteration fixes it — racecheck needs a sync it recognizes\n   across the loop back-edge.\n7. `cp.async.bulk.tensor.2d` clean — different instruction, same pipeline.\n8. `cp.async.bulk.tensor.1d` clean — issue is raw vs descriptor, not\n   dimensionality.\n9. Raw inline PTX `cp.async.bulk.shared::cta.global` also triggers — not a\n   CuTe DSL abstraction issue.\n10. Dumped PTX for both `AI/` reproducers — confirmed byte-identical code\n    except for the copy instruction. Sanitizer attributes smem write to\n    issuing thread for `cp.async.bulk` but not for `cp.async.bulk.tensor`.\n11. Confirmed `--racecheck-memcpy-async=no` does not suppress the hazard —\n    flag targets older `cp.async`, not `cp.async.bulk`.\n"
  },
  {
    "path": "AI/SM90_BLOCK_SIZE_TUNING.md",
    "content": "# SM90 Block Size Tuning Guide\n\nHow to choose tile sizes and MMA configurations for FlashAttention on Hopper (SM90).\n\n## Tool\n\nUse `flash_attn/cute/sm90_config_search.py` to enumerate feasible configs:\n\n```bash\n# Both fwd and bwd\npython flash_attn/cute/sm90_config_search.py --headdim 128\n\n# Forward only\npython flash_attn/cute/sm90_config_search.py --mode fwd --headdim 192-128\n\n# Backward only, custom tile choices\npython flash_attn/cute/sm90_config_search.py --mode bwd --headdim 192 --tile-m 64,80 --tile-n 64,96\n```\n\n## Hardware Constraints (H100)\n\n- **SMEM**: 228 KB total. We reserve ~3 KB for LSE, dPsum, and mbarriers, leaving **224 KB** for tensor buffers.\n- **Registers**: Controlled via `setmaxnreg`. Budget per MMA warp group:\n  - 2 WG: 240 regs/thread, minus 24 overhead = **216 usable**\n  - 3 WG: 160 regs/thread, minus 32 overhead = **128 usable**\n- **GMMA atom**: Always M=64. The effective M dimension (after swap) must be divisible by 64. N dimension must be divisible by `atom_layout_n * 8`.\n\n## Architecture: Warp Groups\n\nEach SM90 backward kernel has `num_wg + 1` warp groups (128 threads each):\n- **WG0** (producer): TMA loads for Q, K, V, dO, LSE, dPsum\n- **WG1** (producer): dQaccum store (TMA reduce-add to gmem)\n- **WG2..WG(num_wg)** (MMA consumers): All GEMMs\n\nFor forward: `num_wg` MMA WGs + 1 producer WG. `tile_m = num_wg * 64` (no swap).\n\n## Key Decisions\n\n### 1. Number of Warp Groups (num_wg)\n\n| num_wg | tile_m (fwd) | Threads | Reg budget | Best for |\n|--------|-------------|---------|------------|----------|\n| 2 | 128 | 384 | 216/thread | hdim <= 128 |\n| 3 | 192 | 512 | 128/thread | hdim 129-192 |\n\nMore WGs = larger tile_m = better M-direction parallelism, but tighter register budget and higher smem usage.\n\n### 2. swap_AB\n\nEach MMA can optionally swap its A and B operands. 
This transposes the output tile, exchanging which dimension maps to M (must be divisible by 64) and which maps to N.\n\n**When to swap:**\n- If the natural M dimension isn't divisible by 64 but N is (e.g., tile_m=80 for SdP)\n- To change which operand is in registers vs shared memory\n\n**Forward**: No swap needed since tile_m = num_wg * 64 is always divisible by 64.\n\n**Backward** (5 MMAs):\n- **SdP** (S=Q@K^T, dP=dO@V^T): output (tile_m, tile_n). Swap if tile_m % 64 != 0.\n- **dKV** (dK=dS^T@Q, dV=P^T@dO): output (tile_n, hdim/hdimv). Swap if tile_n % 64 != 0 but hdim % 64 == 0.\n- **dQ** (dQ=dS@K): output (tile_m, hdim). Swap if tile_m % 64 != 0 but hdim % 64 == 0.\n\n### 3. AtomLayout\n\nThe `atom_layout` distributes WGs across the M and N dimensions of an MMA output. With `num_wg` MMA WGs and `atom_layout_m = A`:\n- M direction: A warp groups, each handling M/A rows\n- N direction: num_wg/A warp groups, each handling N/(num_wg/A) columns\n\nAfter swap, the atom layout is also swapped.\n\n**Impact on smem traffic**: More WGs in the N direction (`wg_n` larger) means each instruction reads a smaller B slice, but more instructions total read overlapping A slices. Fewer WGs in N (`wg_n` smaller) means fewer instructions but each reads a larger B slice. Typically **smaller wg_n = less total smem traffic**.\n\n### 4. mma_dkv_is_rs (Register-Source for dKV)\n\nWhen `AtomLayoutMSdP == 1 && AtomLayoutNdKV == num_wg && SdP_swapAB && !dKV_swapAB`, the P and dS matrices can be kept in registers and fed directly as the A operand of dV and dK GEMMs. This:\n- **Eliminates sP from smem** (saves tile_m * tile_n * 2 bytes)\n- **Eliminates P R2S store** from smem traffic\n- **Eliminates A operand reads** for dK and dV GEMMs\n\nThis is a significant optimization — always preferred when the conditions are met.\n\n### 5. 
Pipeline Staging\n\n**Forward**:\n- Q: 1 stage (loaded once per m_block tile)\n- K, V: 2 stages (double-buffered, pipelined with TMA)\n- O: overlaps with Q in smem (reuses same buffer at epilogue)\n\n**Backward**:\n- Q: always 2 stages (double-buffered)\n- dO: 2 stages if smem allows (matches Q pipeline), else 1 stage\n- PdS: 1 stage\n- K, V: persistent in smem (loaded once per n_block)\n\n## Register Accounting\n\nAccumulator registers per thread per WG = `M * N / (num_wg * 128)`, where M x N is the output tile.\n\n**Forward peak registers**:\n- With WG overlap: `regs_S + regs_P + regs_O` (S, P in bf16, O all live)\n- Without overlap: `regs_S + regs_O` (S and O alternate, P reuses S regs)\n\nWhere `regs_P = regs_S / 2` (bf16 vs f32).\n\n**Backward peak registers**:\n- `max(2 * regs_SdP, regs_dQ) + regs_dK + regs_dV`\n- S and dP accumulators are both live (S needed for softmax while dP computes)\n- dQ reuses S+dP register space after they're consumed\n- dK and dV accumulate across m_block iterations\n\n## SMEM Accounting\n\nSum of tensor buffers in bytes (ignoring alignment padding, which is small). Sizes below are per stage; the stage counts appear only in the sum formulas:\n\n**Forward**: `max(sQ, sO) + sK*2 + sV*2 + sP`\n- sQ = tile_m * hdim * 2\n- sK = tile_n * hdim * 2 (per stage, 2 stages)\n- sV = tile_n * hdimv * 2 (per stage, 2 stages)\n- sO = tile_m * hdimv * 2 (overlaps with sQ)\n- sP = tile_m * tile_n * 2 (0 if RS)\n\n**Backward**: `sQ*2 + sK + sV + sdO*dO_stage + sP + sdS + sdQaccum`\n- sQ = tile_m * hdim * 2 (per stage, 2 stages)\n- sK = tile_n * hdim * 2\n- sV = tile_n * hdimv * 2\n- sdO = tile_m * hdimv * 2 (per stage, dO_stage stages)\n- sP = tile_m * tile_n * 2 (0 if mma_dkv_is_rs)\n- sdS = tile_m * tile_n * 2\n- sdQaccum = tile_m * hdim * 4 (f32)\n\n## SMEM Traffic\n\nPer-iteration smem bandwidth consumed. Each GMMA instruction reads:\n- **A operand**: 64 * K_red * 2 bytes (0 if register-source)\n- **B operand**: (N_eff / wg_n) * K_red * 2 bytes\n\nTotal instructions = (M_eff / 64) * wg_n. 
Each instruction independently reads A and B from smem.\n\nAdditional traffic: R2S stores for P, dS (bf16), dQ smem store + TMA load (f32).\n\n**Traffic per block** (traffic / (tile_m * tile_n)) normalizes across tile sizes for comparison. Lower is better.\n\n## Example Configs\n\n### hdim=128 (Forward)\nBest: tile_m=128, tile_n=192, RS, 2 WG. 224K smem, 9.3 tr/blk.\n\n### hdim=128 (Backward, non-causal)\nC++ FA3 config: tile_m=80, tile_n=128, SdP_swap=T, dKV_swap=F, dQ_swap=T, aSdP=1, adKV=2. mma_dkv_is_rs=True. 204K smem, 208 regs, 39.6 tr/blk.\n\n### hdim=192 (Backward)\n3 WG, tile_m=64, tile_n=96, SdP_swap=F, dKV_swap=T, adKV=1 or 3. 216K smem, 128 regs. This is the only feasible tile_n > 64 for hdim=192 due to register pressure.\n\n### hdim=192, hdimv=128 (DeepSeek shape)\nWith 3 WG: need AtomLayoutNdKV=3 (since hdimv=128 not divisible by 3). tile_n=96, 212K smem.\nWith 2 WG: tile_n=112 feasible at 210K smem, or tile_n=64 at 168K smem.\n"
  },
  {
    "path": "AI/SM90_R2P_MASKING_SASS.md",
    "content": "# SM90 FWD R2P Masking — SASS Investigation\n\n## SASS Instruction Counts (hdim=128, seqlen=113, tile_n=128)\n\nWith tile_n=128, SM90 has 32 accumulator elements per row (1 chunk of 32).\n\n### Non-causal (seqlen-only masking)\n\n| Metric | Old (no R2P) | New (R2P) | Delta |\n|--------|-------------|-----------|-------|\n| **Total instructions** | 3104 | 3072 | **-32 (-1%)** |\n| R2P | 0 | 4 | +4 |\n| FSEL | 70 | 70 | 0 |\n| ISETP | 55 | 22 | **-33** |\n| SHF | 69 | 73 | +4 |\n| LOP3 | 51 | 56 | +5 |\n\nR2P replaces 33 ISETP (integer set-predicate) instructions with 4 R2P + a few LOP3/SHF. Net savings: 32 instructions. The 4 R2P instructions each convert one byte of a 32-bit bitmask into 7 predicates, covering all 32 elements (4 × 8 bits = 32).\n\n### Causal\n\n| Metric | Old (no R2P) | New (R2P) | Delta |\n|--------|-------------|-----------|-------|\n| **Total instructions** | 5008 | 4857 | **-151 (-3%)** |\n| R2P | 0 | 24 | +24 |\n| FSEL | 200 | 200 | 0 |\n| ISETP | 225 | 22 | **-203** |\n| SHF | 104 | 105 | +1 |\n| LOP3 | 81 | 105 | +24 |\n\nMuch larger savings. The causal kernel applies masking per-row (each row has a different col_limit), so it has many more masking operations. 24 R2P instructions replace 203 ISETP instructions, saving 151 total.\n\n### Local (sliding window, wl=64 wr=0)\n\n| Metric | Old (no R2P) | New (R2P) | Delta |\n|--------|-------------|-----------|-------|\n| **Total instructions** | 7296 | 6217 | **-1079 (-15%)** |\n| R2P | 0 | 32 | +32 |\n| FSEL | 522 | 266 | **-256** |\n| ISETP | 554 | 22 | **-532** |\n| SHF | 115 | 73 | -42 |\n| LOP3 | 96 | 56 | -40 |\n\nDramatic savings. Local masking has two bounds (left + right) per row, doubling the masking work. 
R2P eliminates 532 ISETP and 256 FSEL instructions, saving 1079 total (15% of kernel).\n\n## How R2P Works in SASS\n\nThe compiler generates this pattern:\n\n```\nSHF.R.U32.HI R9, RZ, R9, R16    ; shift to create bitmask\nR2P PR, R9, 0x7f                  ; byte 0 → predicates P0-P6\nFSEL R15, R36, -INF, P6           ; apply P6: keep or mask to -inf\nR2P PR, R9.B1, 0x7f              ; byte 1 → predicates P0-P6\nFSEL R52, R52, -INF, P6           ; apply P6\nR2P PR, R9.B2, 0x7f              ; byte 2\n...\nR2P PR, R9.B3, 0x7f              ; byte 3\n```\n\nEach `R2P` converts 7 bits of a register byte into 7 predicate registers simultaneously (1 instruction instead of 7 `ISETP`). The subsequent `FSEL` instructions use these predicates for conditional masking.\n\n### Handling the leftover bits (32 is not divisible by 7)\n\nThe `0x7f` immediate tells R2P to map bits 0-6 of each byte to P0-P6, but bit 7 (the MSB of each byte) is not covered. For 32 elements across 4 bytes, that's 4 leftover elements (bits 7, 15, 23, 31). The compiler handles these with separate `LOP3.LUT` or `ISETP` instructions:\n\n```\nR2P PR, R12,     0x7f           ; bits 0-6   → P0-P6  (7 elements)\n  14× FSEL using P0-P6           ; apply to 7 cols × 2 rows\nLOP3.LUT P0, RZ, R12, 0x80, ... ; test bit 7  (1 element)\n  2× FSEL using P0\n\nR2P PR, R12.B1,  0x7f           ; bits 8-14  → P0-P6  (7 elements)\n  14× FSEL using P0-P6\nLOP3.LUT P1, RZ, R12, 0x8000, ..; test bit 15 (1 element)\n  2× FSEL using P1\n\nR2P PR, R12.B2,  0x7f           ; bits 16-22 → P0-P6  (7 elements)\n  14× FSEL using P0-P6\nLOP3.LUT P0, RZ, R12, 0x800000,..; test bit 23 (1 element)\n  2× FSEL using P0\n\nR2P PR, R12.B3,  0x7f           ; bits 24-30 → P0-P6  (7 elements)\n  14× FSEL using P0-P6\nISETP.GT P0, R12, -1            ; test bit 31 (sign bit) (1 element)\n  2× FSEL using P0\n```\n\nTotal: 4×7 = 28 elements via R2P + 4 elements via LOP3/ISETP = 32. 
Each R2P replaces 7 ISETP with 1 instruction, so net savings is `(7-1) × 4 = 24` predicate-generation instructions per mask application. Additionally, ptxas can overlap R2P with FSEL since they write to separate predicate registers.\n\n## Performance Impact\n\n| Case | Old (ms) | New (ms) | Speedup |\n|------|----------|----------|---------|\n| Causal hdim=64 s=8192 | 2.463 | 2.473 | ~0% |\n| Causal hdim=128 s=8192 | 1.937 | 1.944 | ~0% |\n| Local hdim=64 s=8192 | 0.394 | 0.346 | **+14%** |\n| Local hdim=128 s=8192 | 0.237 | 0.222 | **+7%** |\n| Non-causal hdim=128 s=4096 | 1.742 | 1.728 | ~1% |\n\nCausal sees no perf gain despite fewer instructions because masking is a tiny fraction of total work (dominated by WGMMA). Local sees significant gains because the sliding window has many partially-masked blocks where masking overhead matters more.\n"
  },
  {
    "path": "AI/VARLEN_PREPROCESS_TILE_BUG.md",
    "content": "# Varlen Preprocess Tile Mismatch Bug\n\n## Summary\n\n`SeqlenInfo.create` in `flash_bwd_preprocess.py` defaulted `tile=128`, but the backward kernel uses `tile_m=m_block_size` (e.g. 64 for causal SM90). This caused the preprocess to zero dq_accum and write lse_log2/dpsum at wrong padded offsets for all batches after batch 0.\n\n## How padded_offset works\n\nFor varlen, buffers like dq_accum are laid out with tile-aligned gaps between sequences:\n\n```\npadded_offset_q = ((offset_q + batch_idx * tile_m) // tile_m) * tile_m\n```\n\nThe gap size depends on `tile_m`. With `tile_m=64` vs `tile_m=128`, batch 1 at `offset_q=128` gets:\n- tile=64:  padded_offset = ((128 + 64) // 64) * 64  = **192**\n- tile=128: padded_offset = ((128 + 128) // 128) * 128 = **256**\n\nThe preprocess was zeroing at 256, the backward was writing at 192.\n\n## Symptoms\n\n- Tests pass in isolation (torch.empty gets clean memory)\n- Tests fail when run in sequence (CUDA memory caching reuses NaN-polluted memory)\n- dq_accum valid positions contain NaN after backward kernel\n- `torch.zeros` for dq_accum masks the bug (zeroes everywhere, including the \"right\" offsets)\n- compute-sanitizer shows 0 errors (addresses are valid, just wrong offsets within the buffer)\n\n## Fix\n\n```python\n# flash_bwd_preprocess.py line 216\n# Before:\nseqlen = SeqlenInfo.create(batch_idx, mO.shape[1], mCuSeqlensQ, mSeqUsedQ)\n# After:\nseqlen = SeqlenInfo.create(batch_idx, mO.shape[1], mCuSeqlensQ, mSeqUsedQ, tile=self.tile_m)\n```\n\n## Lesson\n\nAny code computing `padded_offset` for varlen buffers must use the same tile size as the kernel that allocated and accesses those buffers. The `SeqlenInfo.create` default `tile=128` is a trap when `m_block_size != 128`.\n"
  },
  {
    "path": "AI/racecheck_repro_1d_bulk.py",
    "content": "\"\"\"Minimal reproducer: cp.async.bulk (raw address) triggers racecheck hazard.\n\nWarp 0 loads via cp.async.bulk, warp 1 reads from smem after mbarrier wait.\nPipeline is correctly synchronized but racecheck reports 1 error.\n\n  python AI/racecheck_repro_1d_bulk.py                                    # correctness\n  CUTE_DSL_LINEINFO=1 compute-sanitizer --tool=racecheck python AI/racecheck_repro_1d_bulk.py # 1 error\n\"\"\"\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass.cute.nvgpu import cpasync\nfrom cutlass.cute.runtime import from_dlpack\nfrom cutlass import Float32, Int32\nimport cutlass.pipeline\nfrom cutlass.pipeline.sm90 import PipelineTmaAsync, make_pipeline_state\nimport cuda.bindings.driver as cuda\nimport torch\n\nN_BLKS, TILE = 4, 128\nN_STG = 2\n\n\n@cute.kernel\ndef kernel(g_src: cute.Tensor, g_dst: cute.Tensor):\n    smem = cutlass.utils.SmemAllocator()\n    s = smem.allocate_tensor(Float32, cute.make_layout((TILE, N_STG)), byte_alignment=128)\n    s_mbar = smem.allocate_tensor(cutlass.Int64, cute.make_layout(2 * N_STG), byte_alignment=8)\n    tidx, _, _ = cute.arch.thread_idx()\n    warp, lane = tidx // 32, tidx % 32\n\n    pipe = PipelineTmaAsync.create(\n        barrier_storage=s_mbar.iterator, num_stages=N_STG,\n        producer_group=cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, 1),\n        consumer_group=cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, 1),\n        tx_count=TILE * 4, defer_sync=False,\n    )\n    src = cute.local_tile(g_src, (TILE,), (None,))\n    dst = cute.local_tile(g_dst, (TILE,), (None,))\n\n    if warp == 0:\n        ps = make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, N_STG)\n        for blk in cutlass.range(N_BLKS, unroll=1):\n            pipe.producer_acquire(ps)\n            atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32)\n            with cute.arch.elect_one():\n                cute.copy(atom, src[None, blk], s[None, 
ps.index],\n                          mbar_ptr=pipe.producer_get_barrier(ps))\n            ps.advance()\n        pipe.producer_tail(ps)\n    if warp == 1:\n        cs = make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, N_STG)\n        for blk in cutlass.range(N_BLKS, unroll=1):\n            pipe.consumer_wait(cs)\n            for i in cutlass.range_constexpr(TILE // 32):\n                dst[lane + i * 32, blk] = s[lane + i * 32, cs.index]\n            cute.arch.fence_view_async_shared()\n            cute.arch.sync_warp()  # Need sync_warp as only 1 thread will signal in consumer_release\n            pipe.consumer_release(cs)\n            cs.advance()\n\n\n@cute.jit\ndef go(g_src, g_dst, stream):\n    kernel(g_src, g_dst).launch(grid=[1, 1, 1], block=[64, 1, 1], smem=4096, stream=stream)\n\n\nif __name__ == \"__main__\":\n    src = torch.arange(TILE * N_BLKS, device=\"cuda\", dtype=torch.float32)\n    dst = torch.zeros_like(src)\n    go(from_dlpack(src, assumed_align=16), from_dlpack(dst, assumed_align=16),\n       cuda.CUstream(torch.cuda.current_stream().cuda_stream))\n    torch.cuda.synchronize()\n    assert torch.equal(src, dst), f\"FAIL: max diff={torch.abs(src - dst).max().item()}\"\n    print(\"PASS\")\n"
  },
  {
    "path": "AI/racecheck_repro_1d_tensor.py",
    "content": "\"\"\"Minimal reproducer: cp.async.bulk.tensor.1d (descriptor TMA) passes racecheck.\n\nSame pipeline as racecheck_repro_1d_bulk.py but uses make_tiled_tma_atom to\ncreate a TMA descriptor, which generates cp.async.bulk.tensor.1d PTX.\n\n  python AI/racecheck_repro_1d_tensor.py                                    # correctness\n  CUTE_DSL_LINEINFO=1 compute-sanitizer --tool=racecheck python AI/racecheck_repro_1d_tensor.py # 0 hazards\n\"\"\"\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass.cute.nvgpu import cpasync\nfrom cutlass.cute.runtime import from_dlpack\nfrom cutlass import Float32, Int32\nimport cutlass.pipeline\nfrom cutlass.pipeline.sm90 import PipelineTmaAsync, make_pipeline_state\nimport cuda.bindings.driver as cuda\nimport torch\n\nN_BLKS, TILE = 4, 128\nN_STG = 2\n\n\n@cute.kernel\ndef kernel(g_dst: cute.Tensor, tma_atom: cute.CopyAtom, tma_tensor: cute.Tensor):\n    smem = cutlass.utils.SmemAllocator()\n    s = smem.allocate_tensor(Float32, cute.make_layout((TILE, N_STG)), byte_alignment=128)\n    s_mbar = smem.allocate_tensor(cutlass.Int64, cute.make_layout(2 * N_STG), byte_alignment=8)\n    tidx, _, _ = cute.arch.thread_idx()\n    warp, lane = tidx // 32, tidx % 32\n\n    pipe = PipelineTmaAsync.create(\n        barrier_storage=s_mbar.iterator, num_stages=N_STG,\n        producer_group=cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, 1),\n        consumer_group=cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, 1),\n        tx_count=TILE * 4, defer_sync=False,\n    )\n    tma_s, tma_g = cpasync.tma_partition(\n        tma_atom, Int32(0), cute.make_layout(1),\n        cute.group_modes(s, 0, 1),\n        cute.group_modes(cute.local_tile(tma_tensor, (TILE,), (None,)), 0, 1),\n    )\n    dst = cute.local_tile(g_dst, (TILE,), (None,))\n\n    if warp == 0:\n        with cute.arch.elect_one():\n            cpasync.prefetch_descriptor(tma_atom)\n    if warp == 0:\n        ps = 
make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, N_STG)\n        for blk in cutlass.range(N_BLKS, unroll=1):\n            pipe.producer_acquire(ps)\n            cute.copy(tma_atom, tma_g[None, blk], tma_s[None, ps.index],\n                      tma_bar_ptr=pipe.producer_get_barrier(ps))\n            ps.advance()\n        pipe.producer_tail(ps)\n    if warp == 1:\n        cs = make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, N_STG)\n        for blk in cutlass.range(N_BLKS, unroll=1):\n            pipe.consumer_wait(cs)\n            for i in cutlass.range_constexpr(TILE // 32):\n                dst[lane + i * 32, blk] = s[lane + i * 32, cs.index]\n            cute.arch.fence_view_async_shared()\n            cute.arch.sync_warp()  # Need sync_warp as only 1 thread will signal in consumer_release\n            pipe.consumer_release(cs)\n            cs.advance()\n\n\n@cute.jit\ndef go(g_src, g_dst, stream):\n    tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(\n        cpasync.CopyBulkTensorTileG2SOp(), g_src, cute.make_layout(TILE), (TILE,),\n    )\n    kernel(g_dst, tma_atom, tma_tensor).launch(\n        grid=[1, 1, 1], block=[64, 1, 1], smem=4096, stream=stream,\n    )\n\n\nif __name__ == \"__main__\":\n    src = torch.arange(TILE * N_BLKS, device=\"cuda\", dtype=torch.float32)\n    dst = torch.zeros_like(src)\n    go(from_dlpack(src, assumed_align=16), from_dlpack(dst, assumed_align=16),\n       cuda.CUstream(torch.cuda.current_stream().cuda_stream))\n    torch.cuda.synchronize()\n    assert torch.equal(src, dst), f\"FAIL: max diff={torch.abs(src - dst).max().item()}\"\n    print(\"PASS\")\n"
  },
  {
    "path": "AUTHORS",
    "content": "Tri Dao, trid@cs.stanford.edu"
  },
  {
    "path": "CLAUDE.md",
    "content": "# CLAUDE.md\n\nThis file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.\n\n## Project Overview\n\nFlashAttention-4 (FA4) — fast, memory-efficient exact attention kernels written in Python using CuTeDSL (NVIDIA CUTLASS DSL). Kernels are compiled to PTX/CUBIN at runtime. Targets Hopper (SM90) and Blackwell (SM100/SM110) GPUs. Package name: `flash-attn-4`.\n\nThe repository also contains older generations (FA2 in top-level `csrc/`, FA3 in `hopper/`) but active development is on FA4 in `flash_attn/cute/`.\n\n## Build & Install\n\n```bash\npip install flash-attn-4\n# or dev install:\npip install -e \"flash_attn/cute[dev]\"\n```\n\nDependencies: `nvidia-cutlass-dsl>=4.4.1`, `torch`, `einops`, `apache-tvm-ffi`, `quack-kernels>=0.2.10`.\n\n## Running Tests\n\n```bash\npytest tests/cute/test_flash_attn.py\npytest tests/cute/test_flash_attn.py -k \"test_flash_attn_output\" -x  # single test\npytest tests/cute/test_flash_attn_varlen.py\npytest tests/cute/test_mask_mod.py\npytest tests/cute/test_score_mod.py\npytest tests/cute/test_block_sparsity.py\n```\n\n### Fast two-pass testing\n\nCompilation dominates test time. 
The fast workflow separates compilation (parallel, no GPU needed) from execution (uses cached binaries):\n\n```bash\n# Pass 1: compile all kernels in parallel using FakeTensorMode (no GPU memory allocation)\nFLASH_ATTENTION_FAKE_TENSOR=1 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -n 64 -x tests/cute/test_flash_attn.py\n\n# Pass 2: run tests using cached compiled kernels\nFLASH_ATTENTION_FAKE_TENSOR=0 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -x tests/cute/test_flash_attn.py\n```\n\n- `FLASH_ATTENTION_FAKE_TENSOR=1` — uses PyTorch FakeTensorMode to compile kernels without allocating GPU memory or running them.\n- `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` — enables persistent disk cache at `/tmp/${USER}/flash_attention_cute_dsl_cache/`.\n- `-n 64` — pytest-xdist parallel workers (only useful in the compilation pass).\n\nTests are parametrized over dtype (fp16/bf16), head dimension (64, 96, 128), sequence length, causal/non-causal, and MHA/GQA/MQA.\n\nIf you get OOM errors running tests or benchmarks, use `nvidia-smi` to find a free GPU and select it with `CUDA_VISIBLE_DEVICES=<id>`.\n\n## Linting\n\nPre-commit uses ruff on `flash_attn/cute/` files. 
Large kernel files (`flash_bwd.py`, `flash_fwd.py`, `flash_fwd_sm100.py`, `interface.py`) are excluded from auto-formatting.\n\n```bash\nruff check flash_attn/cute/ --fix\nruff format flash_attn/cute/\n```\n\n## Code Architecture\n\n### Public API (`flash_attn/cute/interface.py`)\n\nTwo entry points exported from `flash_attn/cute/__init__.py`:\n- `flash_attn_func(q, k, v, ...)` — standard attention\n- `flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, ...)` — variable-length\n\nKey parameters: `causal`, `window_size_left/right`, `softmax_scale`, `softcap`, `score_mod`, `mask_mod`, `block_sparse_tensors`, `num_splits`, `pack_gqa`, `m_block_size`, `n_block_size`, `num_threads`.\n\nTensor layout: `(batch, seqlen, num_heads, head_dim)`, last dim contiguous, 16-byte aligned.\n\n### Forward Kernels\n\n- `flash_fwd.py` — `FlashAttentionForwardSm90`: Hopper forward. No SplitKV or paged KV.\n- `flash_fwd_sm100.py` — `FlashAttentionForwardSm100`: Blackwell forward. Full features including SplitKV, paged KV cache, persistent kernels, 2CTA instructions.\n- `flash_fwd_combine.py` — `FlashAttentionForwardCombine`: merges SplitKV partial results.\n\n### Backward Kernels\n\n- `flash_bwd.py` — `FlashAttentionBackwardSm80`: Ampere backward (base).\n- `flash_bwd_sm90.py` — `FlashAttentionBackwardSm90`: Hopper backward.\n- `flash_bwd_sm100.py` — `FlashAttentionBackwardSm100`: Blackwell backward with 2CTA and block sparse support.\n- `flash_bwd_preprocess.py` / `flash_bwd_postprocess.py` — auxiliary backward kernels.\n\n### Core Abstractions\n\n- `softmax.py` — Online softmax with row_max/row_sum tracking, score modifier support.\n- `mask.py` — `AttentionMask`: causal, local/sliding window, block sparse, mask_mod application.\n- `block_info.py` — `BlockInfo`: tile dimensions, n/m block range computation for causal/local masking.\n- `seqlen_info.py` — `SeqlenInfoQK`: sequence length and offset tracking for varlen.\n- `pipeline.py` — `PipelineStateSimple`: circular buffer 
index/phase management for pipelined loads.\n- `tile_scheduler.py` — Tile scheduling strategies (single tile, varlen-aware, persistent).\n- `copy_utils.py` — Type-converting copies, shared-to-register loads, TMA copy atoms.\n- `named_barrier.py` — Named barrier enums for warp synchronization.\n\n### Architecture-Specific Helpers\n\n- `hopper_helpers.py` — SM90 warp-group GEMM, shared memory layout creation, fence/commit/wait.\n- `blackwell_helpers.py` — SM100 UMMA-based GEMM, PTX-optimized paths, 2CTA support.\n- `mma_sm100_desc.py` — Hardware MMA descriptor enums (formats, saturation, scaling).\n\n### Other Components\n\n- `pack_gqa.py` — Packs multiple Q heads per KV head for efficient GQA.\n- `paged_kv.py` — `PagedKVManager`: paged KV cache with TMA support.\n- `fast_math.py` — exp2 polynomial coefficients, softcap score_mod creation.\n- `utils.py` — Hash functions for compile cache keys, warp reductions, predicates.\n- `cache_utils.py` — JIT compilation cache management.\n- `cute_dsl_utils.py` — Patched `cute.compile` that optionally dumps SASS.\n\n### Compilation & Caching\n\nKernels are JIT-compiled. Cache key includes dtype, head_dim, causal, mask/score_mod hashes, architecture, block sizes. 
Caching levels: in-memory LRU + optional disk cache via `get_jit_cache()`.\n\nEnv vars: `CUTE_CUBIN_PATH` (dump CUBIN/SASS), `CUTE_DSL_KEEP_PTX=1` (inspect PTX), `CUTE_DSL_PTXAS_PATH` (custom ptxas).\n\n## Key Patterns\n\n- Compile-time constants use `cutlass.Constexpr[type]` for kernel specialization.\n- Score/mask modifiers are user-defined `@cute.jit` callables injected into the kernel at compile time.\n- Forward execution: load Q tile → loop over K/V blocks (pipelined) → online softmax accumulation → store O and LSE.\n- 2CTA instructions (SM100, hdim=128): both CTAs in a cluster coordinate via shared mbarriers; tx_count must be multiplied by `cta_group_size`.\n\n## Debugging GPU Kernels\n\nSee `AI/DEBUG_2CTA.md` for kernel hang/deadlock debugging (printf bisection, pipeline barrier analysis, 2CTA pitfalls). See `AI/RACECHECK_TMA_HAZARD.md` for `compute-sanitizer` false positives with `cp.async.bulk`.\n\nKey tools:\n- `cute.printf` with thread guards (`tidx % 32 == 0`, `elect_one()`) for targeted output\n- `compute-sanitizer --tool=racecheck` (beware false positives with raw TMA)\n- `CUTE_DSL_KEEP_PTX=1` and `CUTE_DSL_LINEINFO=1` for PTX inspection and sanitizer source mapping\n"
  },
  {
    "path": "LICENSE",
    "content": "BSD 3-Clause License\n\nCopyright (c) 2022, the respective contributors, as shown by the AUTHORS file.\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\n* Neither the name of the copyright holder nor the names of its\n  contributors may be used to endorse or promote products derived from\n  this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "recursive-include csrc *.cu\nrecursive-include csrc *.h\nrecursive-include csrc *.cuh\nrecursive-include csrc *.cpp\nrecursive-include csrc *.hpp\nrecursive-include csrc *.py\n\nrecursive-include flash_attn *.cu\nrecursive-include flash_attn *.h\nrecursive-include flash_attn *.cuh\nrecursive-include flash_attn *.cpp\nrecursive-include flash_attn *.hpp\n"
  },
  {
    "path": "Makefile",
    "content": "\nclean_dist:\n\trm -rf dist/*\n\ncreate_dist: clean_dist\n\tpython setup.py sdist\n\nupload_package: create_dist\n\ttwine upload dist/*\n"
  },
  {
    "path": "README.md",
    "content": "# FlashAttention\nThis repository provides the official implementation of FlashAttention and\nFlashAttention-2 from the\nfollowing papers.\n\n**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**  \nTri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré  \nPaper: https://arxiv.org/abs/2205.14135  \nIEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.\n![FlashAttention](assets/flashattn_banner.jpg)\n\n**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**  \nTri Dao\n\nPaper: https://tridao.me/publications/flash2/flash2.pdf\n\n![FlashAttention-2](assets/flashattention_logo.png)\n\n\n## Usage\n\nWe've been very happy to see FlashAttention being widely adopted in such a short\ntime after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md)\ncontains a partial list of places where FlashAttention is being used.\n\nFlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).\nPlease cite and credit FlashAttention if you use it.\n\n\n## FlashAttention-3 beta release\nFlashAttention-3 is optimized for Hopper GPUs (e.g. H100). 
\n\nBlogpost: https://tridao.me/blog/2024/flash3/\n\nPaper: https://tridao.me/publications/flash3/flash3.pdf\n\n![FlashAttention-3 speedup on H100 80GB SXM5 with FP16](assets/flash3_fp16_fwd.png)\n\nThis is a beta release for testing / benchmarking before we integrate it with\nthe rest of the repo.\n\nCurrently released:\n- FP16 / BF16 forward and backward, FP8 forward\n\nRequirements: H100 / H800 GPU, CUDA >= 12.3.\n\nWe highly recommend CUDA 12.8 for best performance.\n\nTo install:\n```sh\ncd hopper\npython setup.py install\n```\nTo run the test:\n```sh\nexport PYTHONPATH=$PWD\npytest -q -s test_flash_attn.py\n```\nOnce the package is installed, you can import it as follows:\n```python\nimport flash_attn_interface\nflash_attn_interface.flash_attn_func()\n```\n\n## FlashAttention-4 (CuTeDSL)\n\nFlashAttention-4 is written in CuTeDSL and optimized for Hopper and Blackwell GPUs (e.g. H100, B200).\n\nTo install:\n```sh\npip install flash-attn-4\n```\n\nOnce installed, you can use it as follows:\n```python\nfrom flash_attn.cute import flash_attn_func\n\nout = flash_attn_func(q, k, v, causal=True)\n```\n\n## Installation and features\n**Requirements:**\n- CUDA toolkit or ROCm toolkit\n- PyTorch 2.2 and above.\n- `packaging` Python package (`pip install packaging`)\n- `psutil` Python package (`pip install psutil`)\n- `ninja` Python package (`pip install ninja`) *\n- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via a GitHub issue.\n\n\\* Make sure that `ninja` is installed and that it works correctly (e.g. `ninja\n--version` then `echo $?` should return exit code 0). If not (sometimes `ninja\n--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall\n`ninja` (`pip uninstall -y ninja && pip install ninja`). 
Without `ninja`,\ncompiling can take a very long time (2h) since it does not use multiple CPU\ncores. With `ninja` compiling takes 3-5 minutes on a 64-core machine using CUDA toolkit.\n\n**To install:**\n```sh\npip install flash-attn --no-build-isolation\n```\nAlternatively you can compile from source:\n```sh\npython setup.py install\n```\n\nIf your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might\nrun too many parallel compilation jobs that could exhaust the amount of RAM. To\nlimit the number of parallel compilation jobs, you can set the environment\nvariable `MAX_JOBS`:\n```sh\nMAX_JOBS=4 pip install flash-attn --no-build-isolation\n```\n\n**Interface:** `src/flash_attention_interface.py`\n\n### NVIDIA CUDA Support\n**Requirements:**\n- CUDA 12.0 and above.\n\nWe recommend the\n[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)\ncontainer from Nvidia, which has all the required tools to install FlashAttention.\n\nFlashAttention-2 with CUDA currently supports:\n1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing\n   GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing\n   GPUs for now.\n2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).\n3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.\n\n### AMD ROCm Support\nROCm version has two backends. There is [composable_kernel](https://github.com/ROCm/composable_kernel) (ck) which is the default backend and a [Triton](https://github.com/triton-lang/triton) backend. 
They provide an implementation of FlashAttention-2.\n\n**Requirements:**\n- ROCm 6.0 and above.\n\nWe recommend the\n[PyTorch](https://hub.docker.com/r/rocm/pytorch)\ncontainer from ROCm, which has all the required tools to install FlashAttention.\n\n#### Composable Kernel Backend\nFlashAttention-2 ROCm CK backend currently supports:\n1. MI200x, MI250x, MI300x, and MI355x GPUs.\n2. Datatypes fp16 and bf16.\n3. Head dimensions up to 256 for both the forward and backward passes.\n\n#### Triton Backend\nThe Triton implementation of [Flash Attention](https://tridao.me/publications/flash2/flash2.pdf) supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16, bf16, and fp32 datatypes. It provides forward and backward passes with causal masking, variable sequence lengths, arbitrary Q/KV sequence lengths and head sizes, MQA/GQA, dropout, rotary embeddings, ALiBi, paged attention, and FP8 (via the Flash Attention v3 interface). Sliding window attention is currently a work in progress.\n\nThe Triton backend kernels are provided by the [aiter](https://github.com/ROCm/aiter) package, included as a git submodule at `third_party/aiter` and automatically installed during setup.\n\nTo install, first get PyTorch for ROCm from https://pytorch.org/get-started/locally/, then install Flash Attention:\n```sh\ncd flash-attention\nFLASH_ATTENTION_TRITON_AMD_ENABLE=\"TRUE\" pip install --no-build-isolation .\n```\n\nTo use a specific aiter commit (e.g., for testing or development):\n```sh\ncd flash-attention\ncd third_party/aiter && git fetch origin && git checkout <commit-sha> && cd ../..\nFLASH_ATTENTION_TRITON_AMD_ENABLE=\"TRUE\" pip install --no-build-isolation .\n```\n\nTo run the tests (note: full suite takes hours):\n```sh\nFLASH_ATTENTION_TRITON_AMD_ENABLE=\"TRUE\" pytest tests/test_flash_attn_triton_amd.py\n```\n\nThe Triton backend uses a default kernel configuration optimized for determinism and reasonable performance across workloads. 
For peak throughput, enable `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=\"TRUE\"` to search for optimal settings, which incurs a one-time warmup cost.\n\nAlternatively, if _not_ autotuning, `FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON` may be used to set a single Triton config overriding the hardcoded defaults for `attn_fwd`. E.g.\n```sh\nFLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON='{\"BLOCK_M\":128,\"BLOCK_N\":64,\"waves_per_eu\":1,\"PRE_LOAD_V\":false,\"num_stages\":1,\"num_warps\":8}'\n```\n\nFor a quick start with Docker:\n```dockerfile\nFROM rocm/pytorch:latest\n\nWORKDIR /workspace\n\n# build flash attention with triton backend\nRUN git clone https://github.com/Dao-AILab/flash-attention &&\\ \n    cd flash-attention &&\\\n    FLASH_ATTENTION_TRITON_AMD_ENABLE=\"TRUE\" pip install --no-build-isolation .\n\n# set working dir\nWORKDIR /workspace/flash-attention\n\n# set env variable to use triton backend\nENV FLASH_ATTENTION_TRITON_AMD_ENABLE=\"TRUE\"\n```\n\nBuild and run:\n```sh\ndocker build -t flash-attn-triton .\ndocker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri flash-attn-triton\n```\n\n## How to use FlashAttention\n\nThe main functions implement scaled dot product attention (softmax(Q @ K^T *\nsoftmax_scale) @ V):\n```python\nfrom flash_attn import flash_attn_qkvpacked_func, flash_attn_func\n```\n\n```python\nflash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,\n                          window_size=(-1, -1), alibi_slopes=None, deterministic=False):\n\"\"\"dropout_p should be set to 0.0 during evaluation\nIf Q, K, V are already stacked into 1 tensor, this function will be faster than\ncalling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation\nof the gradients of Q, K, V.\nIf window_size != (-1, -1), implements sliding window local attention. 
Query at position i\nwill only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.\nArguments:\n    qkv: (batch_size, seqlen, 3, nheads, headdim)\n    dropout_p: float. Dropout probability.\n    softmax_scale: float. The scaling of QK^T before applying softmax.\n        Defaults to 1 / sqrt(headdim).\n    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n    window_size: (left, right). If not (-1, -1), implements sliding window local attention.\n    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to\n        the attention score of query i and key j.\n    deterministic: bool. Whether to use the deterministic implementation of the backward pass,\n        which is slightly slower and uses more memory. The forward pass is always deterministic.\nReturn:\n    out: (batch_size, seqlen, nheads, headdim).\n\"\"\"\n```\n\n```python\nflash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,\n                window_size=(-1, -1), alibi_slopes=None, deterministic=False):\n\"\"\"dropout_p should be set to 0.0 during evaluation\nSupports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads\nthan Q. Note that the number of heads in Q must be divisible by the number of heads in KV.\nFor example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head\n0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.\nIf window_size != (-1, -1), implements sliding window local attention. Query at position i\nwill only attend to keys between\n[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.\n\nArguments:\n    q: (batch_size, seqlen, nheads, headdim)\n    k: (batch_size, seqlen, nheads_k, headdim)\n    v: (batch_size, seqlen, nheads_k, headdim)\n    dropout_p: float. Dropout probability.\n    softmax_scale: float. 
The scaling of QK^T before applying softmax.\n        Default to 1 / sqrt(headdim).\n    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n    window_size: (left, right). If not (-1, -1), implements sliding window local attention.\n    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of\n        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)\n        is added to the attention score of query i and key j.\n    deterministic: bool. Whether to use the deterministic implementation of the backward pass,\n        which is slightly slower and uses more memory. The forward pass is always deterministic.\nReturn:\n    out: (batch_size, seqlen, nheads, headdim).\n\"\"\"\n```\n\n```python\ndef flash_attn_with_kvcache(\n    q,\n    k_cache,\n    v_cache,\n    k=None,\n    v=None,\n    rotary_cos=None,\n    rotary_sin=None,\n    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,\n    cache_batch_idx: Optional[torch.Tensor] = None,\n    block_table: Optional[torch.Tensor] = None,\n    softmax_scale=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite context window\n    rotary_interleaved=True,\n    alibi_slopes=None,\n):\n    \"\"\"\n    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from\n    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from\n    the previous step, and update them with the new keys/values from the current step, and do\n    attention with the updated cache, all in 1 kernel.\n\n    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.\n    For example, the KV cache could be pre-allocated with the max sequence length, and you can use\n    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.\n\n    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. 
The key @k will be\n    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.\n    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos\n    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.\n    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at\n    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).\n\n    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.\n\n    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads\n    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.\n    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head\n    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.\n\n    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.\n    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:\n        1 1 1 1 0\n        1 1 1 1 1\n    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:\n        0 0\n        0 0\n        0 0\n        1 0\n        1 1\n    If the row of the mask is all zero, the output will be zero.\n\n    If window_size != (-1, -1), implements sliding window local attention. Query at position i\n    will only attend to keys between\n    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.\n\n    Note: Does not support backward pass.\n\n    Arguments:\n        q: (batch_size, seqlen, nheads, headdim)\n        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,\n            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. 
paged KV cache)\n            page_block_size must be a multiple of 256.\n        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,\n            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)\n        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate\n            k with k_cache, starting at the indices specified by cache_seqlens.\n        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.\n        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding\n            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.\n        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.\n        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the\n            KV cache.\n        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.\n        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.\n            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].\n            If the indices are not distinct, and k and v are provided, the values updated in the cache\n                 might come from any of the duplicate indices.\n        softmax_scale: float. The scaling of QK^T before applying softmax.\n            Default to 1 / sqrt(headdim).\n        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n        window_size: (left, right). If not (-1, -1), implements sliding window local attention.\n        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.\n            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. 
If False,\n            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1\n            (i.e. GPT-NeoX style).\n        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of\n            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)\n            is added to the attention score of query i and key j.\n\n    Return:\n        out: (batch_size, seqlen, nheads, headdim).\n    \"\"\"\n```\n\nTo see how these functions are used in a multi-head attention layer (which\nincludes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).\n\n### Using with 🤗 Kernels\n\nIf your hardware environment belongs to any of the above-mentioned, you can also use the [`kernels` library](https://github.com/huggingface/kernels)\nto use Flash Attention 2 and 3 right away.\n\n```py\n# pip install kernels\n\nfrom kernels import get_kernel\n\n# FA2\nfa_module = get_kernel(\"kernels-community/flash-attn2\", version=1)\nflash_attn_func = fa_module.flash_attn_func\n\n# FA3\nfa3_module = get_kernel(\"kernels-community/flash-attn3\", version=1)\nflash_attn_func = fa3_module.flash_attn_func\n```\n\n## Changelog\n\n### 2.0: Complete rewrite, 2x faster\nUpgrading from FlashAttention (1.x) to FlashAttention-2\n\nThese functions have been renamed:\n- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`\n- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func`\n- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func`\n\nIf the inputs have the same sequence lengths in the same batch, it is simpler\nand faster to use these functions:\n```python\nflash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)\n```\n```python\nflash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)\n```\n### 2.1: Change behavior of causal flag\n\nIf seqlen_q != seqlen_k and causal=True, the causal mask is aligned to 
the\nbottom right corner of the attention matrix, instead of the top-left corner.\n\nFor example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 =\nmasked out) is:  \nv2.0:  \n    1 0 0 0 0  \n    1 1 0 0 0  \nv2.1:  \n    1 1 1 1 0  \n    1 1 1 1 1  \n\nIf seqlen_q = 5 and seqlen_k = 2, the causal mask is:  \nv2.0:  \n    1 0  \n    1 1  \n    1 1  \n    1 1  \n    1 1  \nv2.1:  \n    0 0  \n    0 0  \n    0 0  \n    1 0  \n    1 1  \nIf the row of the mask is all zero, the output will be zero.\n\n### 2.2: Optimize for inference\n\nOptimize for inference (iterative decoding) when query has very small sequence\nlength (e.g., query sequence length = 1). The bottleneck here is to load KV\ncache as fast as possible, and we split the loading across different thread\nblocks, with a separate kernel to combine results.\n\nSee the function `flash_attn_with_kvcache` with more features for inference\n(perform rotary embedding, updating KV cache inplace).\n\nThanks to the xformers team, and in particular Daniel Haziza, for this\ncollaboration.\n\n### 2.3: Local (i.e., sliding window) attention\n\nImplement sliding window attention (i.e., local attention). Thanks to [Mistral\nAI](https://mistral.ai/) and in particular Timothée Lacroix for this\ncontribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.\n\n### 2.4: ALiBi (attention with linear bias), deterministic backward pass.\n\nImplement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.\n\nImplement deterministic backward pass. 
Thanks to engineers from [Meituan](https://www.meituan.com) for this contribution.\n\n### 2.5: Paged KV cache.\n\nSupport paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).\nThanks to @beginlner for this contribution.\n\n### 2.6: Softcapping.\n\nSupport attention with softcapping, as used in Gemma-2 and Grok models.\nThanks to @Narsil and @lucidrains for this contribution.\n\n### 2.7: Compatibility with torch compile\n\nThanks to @ani300 for this contribution.\n\n## Performance\n\nWe present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).\n\nWe currently have benchmarks for these GPUs:\n* [A100](#a100)\n* [H100](#h100)\n<!-- * [RTX 3090](#rtx-3090) -->\n<!-- * [T4](#t4) -->\n\n### A100\n\nWe display FlashAttention speedup using these parameters:\n* Head dimension 64 or 128, hidden dimension 2048 (i.e. 
either 32 or 16 heads).\n* Sequence length 512, 1k, 2k, 4k, 8k, 16k.\n* Batch size set to 16k / seqlen.\n\n#### Speedup\n\n![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png)\n\n#### Memory\n\n![FlashAttention memory](assets/flashattn_memory.jpg)\n\nWe show memory savings in this graph (note that memory footprint is the same no matter if you use dropout or masking).\nMemory savings are proportional to sequence length -- since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length.\nWe see 10X memory savings at sequence length 2K, and 20X at 4K.\nAs a result, FlashAttention can scale to much longer sequence lengths.\n\n### H100\n\n![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png)\n\n## Full model code and training script\n\nWe have released the full GPT model\n[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py).\nWe also provide optimized implementations of other layers (e.g., MLP, LayerNorm,\ncross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x\ncompared to the baseline implementation from Huggingface, reaching up to 225\nTFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need\nany activation checkpointing).\n\nWe also include a training\n[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to\ntrain GPT2 on Openwebtext and GPT3 on The Pile.\n\n## Triton implementation of FlashAttention\n\nPhil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:\nhttps://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py\n\nAs Triton is a higher-level language than CUDA, it might be easier to understand\nand experiment with. 
The notation in the Triton implementation is also closer\nto what's used in our paper.\n\nWe also have an experimental implementation in Triton that supports attention\nbias (e.g. ALiBi):\nhttps://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py\n\n\n## Tests\nWe test that FlashAttention produces the same output and gradient as a reference\nimplementation, up to some numerical tolerance. In particular, we check that the\nmaximum numerical error of FlashAttention is at most twice the numerical error\nof a baseline implementation in PyTorch (for different head dimensions, input\ndtype, sequence length, causal / non-causal).\n\nTo run the tests:\n```sh\npytest -q -s tests/test_flash_attn.py\n```\n## When you encounter issues\n\nThis new release of FlashAttention-2 has been tested on several GPT-style\nmodels, mostly on A100 GPUs.\n\nIf you encounter bugs, please open a GitHub Issue!\n\n## Tests\nTo run the tests:\n```sh\npytest tests/test_flash_attn_ck.py\n```\n\n## Citation\nIf you use this codebase, or otherwise found our work valuable, please cite:\n```\n@inproceedings{dao2022flashattention,\n  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},\n  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\\'e}, Christopher},\n  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},\n  year={2022}\n}\n@inproceedings{dao2023flashattention2,\n  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},\n  author={Dao, Tri},\n  booktitle={International Conference on Learning Representations (ICLR)},\n  year={2024}\n}\n```\n"
  },
  {
    "path": "benchmarks/bench_sm90.py",
    "content": "#!/usr/bin/env python\n\"\"\"Unified SM90 benchmark for forward and backward passes.\n\nUsage:\n    # Default: bench fwd+bwd for hdim 64,96,128 at seqlen 8192\n    python benchmarks/bench_sm90.py\n\n    # Forward only, specific hdims\n    python benchmarks/bench_sm90.py --direction fwd --hdim 64,96\n\n    # Backward only\n    python benchmarks/bench_sm90.py --direction bwd --hdim 128\n\n    # Custom seqlens and batch size\n    python benchmarks/bench_sm90.py --seqlen 1024,2048,4096,8192 --batch 0\n\n    # Sweep tile sizes for fwd\n    python benchmarks/bench_sm90.py --sweep-tiles --hdim 96\n\n    # Sweep tile sizes for fwd (all hdims including 192, 256)\n    python benchmarks/bench_sm90.py --sweep-tiles --hdim 64,96,128,192,256\n\n    # Sweep RS/overlap variants\n    python benchmarks/bench_sm90.py --sweep-rs-overlap --hdim 64,96\n\n    # Compare old vs new configs\n    python benchmarks/bench_sm90.py --compare-configs\n\n    # Sweep backward optimizations (V_in_regs, mma_dkv_is_rs, pipeline sharing)\n    python benchmarks/bench_sm90.py --sweep-bwd-opts --hdim 64,128\n\n    # Causal only, more reps and warmup\n    python benchmarks/bench_sm90.py --causal-only --rep 50 --warmup 10\n\"\"\"\nimport argparse\nimport time\n\nimport torch\nimport torch.nn.functional as F\nfrom flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd\n\n\n# ── Helpers ────────────────────────────────────────────────────────────────\n\ndef parse_int_k(s):\n    \"\"\"Parse an integer with optional k/K suffix, e.g. '8k' -> 8192.\"\"\"\n    s = s.strip().lower()\n    if s.endswith(\"k\"):\n        return int(s[:-1]) * 1024\n    return int(s)\n\n\ndef csv_ints(s):\n    \"\"\"Parse comma-separated integers with optional k suffix, e.g. '512,1k,2k'.\"\"\"\n    return [parse_int_k(x) for x in s.split(\",\")]\n\n\ndef parse_headdims(s):\n    \"\"\"Parse comma-separated headdim specs. 
Each entry is hdim or hdim-hdim_v.\n\n    Examples:\n        '128'           -> [(128, 128)]\n        '192-128'       -> [(192, 128)]\n        '64,128,192'    -> [(64, 64), (128, 128), (192, 192)]\n        '64,128,192-128,192' -> [(64, 64), (128, 128), (192, 128), (192, 192)]\n    \"\"\"\n    result = []\n    for item in s.split(\",\"):\n        if \"-\" in item:\n            parts = item.split(\"-\")\n            result.append((int(parts[0]), int(parts[1])))\n        else:\n            hdim = int(item)\n            result.append((hdim, hdim))\n    return result\n\n\ndef nheads_for_hdim(h):\n    return 32 if h <= 64 else (16 if h <= 192 else 8)\n\n\ndef fwd_flops(batch, nheads, seqlen, hdim, hdim_v=None, causal=False):\n    if hdim_v is None:\n        hdim_v = hdim\n    avg_seqlen = seqlen / 2 if causal else seqlen\n    return batch * nheads * 2 * seqlen * avg_seqlen * (hdim + hdim_v)\n\n\ndef bwd_flops(batch, nheads, seqlen, hdim, causal=False, hdim_v=None):\n    return 2.5 * fwd_flops(batch, nheads, seqlen, hdim, hdim_v=hdim_v, causal=causal)\n\n\ndef get_causals(args):\n    if args.causal_only:\n        return [True]\n    if args.non_causal_only:\n        return [False]\n    return [False, True]\n\n\ndef auto_batch(seqlen, batch_arg, total_tokens=32768):\n    return batch_arg if batch_arg > 0 else max(1, total_tokens // seqlen)\n\n\n# ── Core bench functions ──────────────────────────────────────────────────\n\ndef bench_fwd(batch, seqlen, nheads, hdim, causal, tile_m=None, tile_n=None,\n              mma_pv_is_rs=None, intra_wg_overlap=None, check_correctness=True,\n              warmup=5, rep=30, hdim_v=None):\n    \"\"\"Benchmark forward pass. 
Returns (ms, tflops, max_diff_or_error).\"\"\"\n    if hdim_v is None:\n        hdim_v = hdim\n    q = torch.randn(batch, seqlen, nheads, hdim, dtype=torch.bfloat16, device=\"cuda\")\n    k = torch.randn(batch, seqlen, nheads, hdim, dtype=torch.bfloat16, device=\"cuda\")\n    v = torch.randn(batch, seqlen, nheads, hdim_v, dtype=torch.bfloat16, device=\"cuda\")\n    kwargs = dict(softmax_scale=hdim ** -0.5, causal=causal)\n    if tile_m is not None and tile_n is not None:\n        kwargs[\"tile_mn\"] = (tile_m, tile_n)\n    if mma_pv_is_rs is not None:\n        kwargs[\"mma_pv_is_rs\"] = mma_pv_is_rs\n    if intra_wg_overlap is not None:\n        kwargs[\"intra_wg_overlap\"] = intra_wg_overlap\n\n    try:\n        out, _lse = _flash_attn_fwd(q, k, v, **kwargs)\n    except Exception as e:\n        return None, None, str(e)[:80]\n\n    max_diff = None\n    if check_correctness:\n        q_ref = q.transpose(1, 2).float()\n        k_ref = k.transpose(1, 2).float()\n        v_ref = v.transpose(1, 2).float()\n        out_ref = F.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=causal)\n        out_ref = out_ref.transpose(1, 2).to(torch.bfloat16)\n        max_diff = (out.float() - out_ref.float()).abs().max().item()\n\n    for _ in range(warmup):\n        _flash_attn_fwd(q, k, v, **kwargs)\n\n    torch.cuda.synchronize()\n    start = torch.cuda.Event(enable_timing=True)\n    end = torch.cuda.Event(enable_timing=True)\n    start.record()\n    for _ in range(rep):\n        _flash_attn_fwd(q, k, v, **kwargs)\n    end.record()\n    torch.cuda.synchronize()\n    ms = start.elapsed_time(end) / rep\n    tflops = fwd_flops(batch, nheads, seqlen, hdim, hdim_v=hdim_v, causal=causal) / ms / 1e9\n    return ms, tflops, max_diff\n\n\ndef bench_bwd(batch, seqlen, nheads, hdim, causal, warmup=5, rep=30, hdim_v=None, **bwd_kwargs):\n    \"\"\"Benchmark backward pass. 
Returns (ms, tflops, None_or_error).\"\"\"\n    if hdim_v is None:\n        hdim_v = hdim\n    q = torch.randn(batch, seqlen, nheads, hdim, device=\"cuda\", dtype=torch.bfloat16)\n    k = torch.randn(batch, seqlen, nheads, hdim, device=\"cuda\", dtype=torch.bfloat16)\n    v = torch.randn(batch, seqlen, nheads, hdim_v, device=\"cuda\", dtype=torch.bfloat16)\n    softmax_scale = hdim ** -0.5\n    try:\n        out, lse = _flash_attn_fwd(q, k, v, softmax_scale=softmax_scale, causal=causal,\n                                    return_lse=True)\n    except Exception as e:\n        return None, None, str(e)[:80]\n    dout = torch.randn_like(out)\n\n    def fn():\n        _flash_attn_bwd(q, k, v, out, dout, lse, softmax_scale=softmax_scale,\n                         causal=causal, **bwd_kwargs)\n\n    try:\n        fn()  # compile\n    except Exception as e:\n        return None, None, str(e)[:80]\n    for _ in range(warmup):\n        fn()\n\n    torch.cuda.synchronize()\n    start = torch.cuda.Event(enable_timing=True)\n    end = torch.cuda.Event(enable_timing=True)\n    start.record()\n    for _ in range(rep):\n        fn()\n    end.record()\n    torch.cuda.synchronize()\n    ms = start.elapsed_time(end) / rep\n    tflops = bwd_flops(batch, nheads, seqlen, hdim, causal, hdim_v=hdim_v) / ms / 1e9\n    return ms, tflops, None\n\n\n# ── Preset configs ────────────────────────────────────────────────────────\n\n# (tile_m, tile_n, mma_pv_is_rs, intra_wg_overlap)\nTILE_SWEEP_CONFIGS = {\n    64: [\n        (192, 192, False, True),\n        (192, 192, True, True),\n        (192, 128, True, True),\n        (192, 128, False, True),\n        (128, 128, True, True),\n        (128, 192, True, True),\n        (192, 96, True, True),\n        (192, 96, False, True),\n    ],\n    96: [\n        (192, 144, False, True),\n        (192, 144, True, True),\n        (192, 128, False, True),\n        (192, 128, True, True),\n        (192, 96, False, True),\n        (192, 96, True, True),\n    
    (128, 128, True, True),\n        (128, 128, False, True),\n    ],\n    128: [\n        (128, 128, True, True),\n        (128, 128, False, True),\n        (128, 96, True, True),\n        (128, 96, False, True),\n        (128, 160, True, True),\n        (128, 176, True, True),\n        (128, 192, True, True),\n    ],\n    192: [\n        (128, 64, True, True),\n        (128, 80, True, True),\n        (128, 96, True, True),\n        (128, 112, True, True),\n        (128, 128, True, True),\n    ],\n    256: [\n        (128, 48, True, True),\n        (128, 64, True, True),\n        (128, 80, True, True),\n        (128, 96, True, True),\n    ],\n}\n\nRS_OVERLAP_COMBOS = [\n    (True, True, \"RS+OL\"),\n    (True, False, \"RS+noOL\"),\n    (False, True, \"noRS+OL\"),\n    (False, False, \"noRS+noOL\"),\n]\n\nCOMPARE_CONFIGS = [\n    # (hdim, causal, (old_tile_m, old_tile_n, old_rs, old_ol), (new...))\n    (64, False, (192, 128, True, True), (192, 128, True, True)),\n    (64, True, (192, 128, True, True), (192, 128, True, True)),\n    (96, False, (192, 96, True, True), (192, 144, False, True)),\n    (96, True, (192, 96, True, True), (192, 128, False, True)),\n]\n\n\ndef _get_default_bwd_config(headdim, causal=False):\n    \"\"\"Default SM90 backward config for a given headdim.\"\"\"\n    if headdim <= 128:\n        return dict(\n            m_block_size=64 if causal else 80,\n            n_block_size=128,\n            num_stages_Q=2,\n            num_stages_dO=2,\n            SdP_swapAB=True,\n            dKV_swapAB=False,\n            dQ_swapAB=not causal,\n            AtomLayoutMSdP=1,\n            AtomLayoutNdKV=2,\n            AtomLayoutMdQ=1,\n            num_threads=384,\n        )\n    elif headdim <= 192:\n        return dict(\n            m_block_size=64,\n            n_block_size=96,\n            num_stages_Q=1,\n            num_stages_dO=1,\n            SdP_swapAB=False,\n            dKV_swapAB=True,\n            dQ_swapAB=True,\n            
AtomLayoutMSdP=1,\n            AtomLayoutNdKV=1,\n            AtomLayoutMdQ=1,\n            num_threads=512,\n        )\n    else:\n        return dict(\n            m_block_size=64,\n            n_block_size=64,\n            num_stages_Q=1,\n            num_stages_dO=1,\n            SdP_swapAB=False,\n            dKV_swapAB=False,\n            dQ_swapAB=False,\n            AtomLayoutMSdP=1,\n            AtomLayoutNdKV=1,\n            AtomLayoutMdQ=1,\n            num_threads=384,\n        )\n\n\n# Maps optimization name -> function(headdim, causal) -> dict[label, kwargs] or None\nBWD_OPT_CONFIGS = {\n    \"V_in_regs\": lambda hdim, causal: (\n        None if hdim > 128 else {\n            \"baseline (V_in_regs=False)\": {**_get_default_bwd_config(hdim, causal), \"V_in_regs\": False},\n            \"optimized (V_in_regs=True)\": {**_get_default_bwd_config(hdim, causal), \"V_in_regs\": True},\n        }\n    ),\n    \"mma_dkv_is_rs\": lambda hdim, causal: (\n        None if hdim > 128 else {\n            \"baseline (AtomLayoutNdKV=1)\": {**_get_default_bwd_config(hdim, causal), \"AtomLayoutNdKV\": 1},\n            \"optimized (AtomLayoutNdKV=2)\": {**_get_default_bwd_config(hdim, causal), \"AtomLayoutNdKV\": 2},\n        }\n    ),\n    \"Q_dO_pipeline_sharing\": lambda hdim, causal: (\n        None if hdim > 128 else {\n            \"baseline (dO_stage=1, separate)\": {**_get_default_bwd_config(hdim, causal), \"num_stages_dO\": 1},\n            \"optimized (dO_stage=2, shared)\": {**_get_default_bwd_config(hdim, causal), \"num_stages_dO\": 2},\n        }\n    ),\n    \"tile_m\": lambda hdim, causal: (\n        None if hdim > 128 or causal else {\n            \"tile_m=64\": {**_get_default_bwd_config(hdim, causal), \"m_block_size\": 64},\n            \"tile_m=80\": {**_get_default_bwd_config(hdim, causal), \"m_block_size\": 80},\n        }\n    ),\n}\n\n\n# ── Run modes ─────────────────────────────────────────────────────────────\n\ndef run_default(args):\n    
\"\"\"Standard fwd/bwd benchmark across hdims.\"\"\"\n    directions = [args.direction] if args.direction != \"both\" else [\"fwd\", \"bwd\"]\n\n    for direction in directions:\n        print(f\"\\n{'=' * 80}\")\n        print(f\"  SM90 {direction.upper()}  (rep={args.rep})\")\n        print(f\"{'=' * 80}\")\n        cols = f\"{'hdim':>5} {'hdim_v':>6} {'causal':>6} {'batch':>5} {'seqlen':>6} {'ms':>8} {'TFLOPS':>8}\"\n        if direction == \"fwd\":\n            cols += f\" {'max_diff':>10}\"\n        print(cols)\n        print(\"-\" * 80)\n\n        for hdim, hdim_v in args.hdim:\n            nheads = nheads_for_hdim(hdim)\n            for seqlen in args.seqlen:\n                batch = auto_batch(seqlen, args.batch)\n                for causal in get_causals(args):\n                    if direction == \"fwd\":\n                        ms, tflops, diff = bench_fwd(batch, seqlen, nheads, hdim, causal, warmup=args.warmup, rep=args.rep, hdim_v=hdim_v)\n                    else:\n                        ms, tflops, diff = bench_bwd(batch, seqlen, nheads, hdim, causal, warmup=args.warmup, rep=args.rep, hdim_v=hdim_v)\n\n                    if ms is not None:\n                        line = f\"{hdim:>5} {hdim_v:>6} {str(causal):>6} {batch:>5} {seqlen:>6} {ms:>8.3f} {tflops:>8.1f}\"\n                        if diff is not None:\n                            line += f\" {diff:>10.6f}\"\n                        print(line)\n                    else:\n                        print(f\"{hdim:>5} {hdim_v:>6} {str(causal):>6} {batch:>5} {seqlen:>6} {'FAIL':>8} {'':>8} {diff}\")\n\n\ndef run_sweep_tiles(args):\n    \"\"\"Sweep tile sizes for fwd across seqlens.\"\"\"\n    seqlens = args.seqlen\n\n    for hdim, hdim_v in args.hdim:\n        nheads = nheads_for_hdim(hdim)\n        configs = TILE_SWEEP_CONFIGS.get(hdim, [])\n        if not configs:\n            print(f\"No tile sweep configs for hdim={hdim}, skipping\")\n            continue\n\n        for causal in 
get_causals(args):\n            header = f\"{'hdim':>5} {'causal':>6} {'tile_m':>6} {'tile_n':>6} {'pv_rs':>5} {'ol':>5}\"\n            for sl in seqlens:\n                header += f\" {'s=' + str(sl):>8}\"\n            print(header)\n            print(\"=\" * len(header))\n\n            for tile_m, tile_n, rs, ol in configs:\n                row = f\"{hdim:>5} {str(causal):>6} {tile_m:>6} {tile_n:>6} {str(rs):>5} {str(ol):>5}\"\n                for sl in seqlens:\n                    batch = auto_batch(sl, args.batch)\n                    ms, tflops, diff = bench_fwd(batch, sl, nheads, hdim, causal,\n                                                 tile_m, tile_n, rs, ol,\n                                                 check_correctness=False, warmup=args.warmup, rep=args.rep, hdim_v=hdim_v)\n                    row += f\" {tflops:>8.1f}\" if tflops else f\" {'FAIL':>8}\"\n                print(row)\n            print()\n\n\ndef run_sweep_rs_overlap(args):\n    \"\"\"Sweep RS and intra-WG-overlap combinations for fwd.\"\"\"\n    seqlens = args.seqlen\n    tile_for_hdim = {64: (192, 128), 96: (192, 128), 128: (128, 128)}\n\n    for hdim, hdim_v in args.hdim:\n        nheads = nheads_for_hdim(hdim)\n        tile_m, tile_n = tile_for_hdim.get(hdim, (128, 128))\n\n        for causal in get_causals(args):\n            c_str = \"causal\" if causal else \"non-causal\"\n            header = f\"{'Config':<30} {'RS/OL':<12}\"\n            for sl in seqlens:\n                header += f\" {'s=' + str(sl):>8}\"\n            print(header)\n            print(\"=\" * len(header))\n\n            for rs, ol, rs_label in RS_OVERLAP_COMBOS:\n                label = f\"hdim{hdim} {c_str} {tile_m}x{tile_n}\"\n                row = f\"{label:<30} {rs_label:<12}\"\n                for sl in seqlens:\n                    batch = auto_batch(sl, args.batch)\n                    ms, tflops, diff = bench_fwd(batch, sl, nheads, hdim, causal,\n                                                
 tile_m, tile_n, rs, ol,\n                                                 check_correctness=False, warmup=args.warmup, rep=args.rep, hdim_v=hdim_v)\n                    row += f\" {tflops:>8.1f}\" if tflops else f\" {'FAIL':>8}\"\n                print(row)\n            print()\n\n\ndef run_compare_configs(args):\n    \"\"\"Compare old vs new tile configs for fwd.\"\"\"\n    seqlens = args.seqlen\n\n    header = f\"{'Config':<50}\"\n    for sl in seqlens:\n        header += f\" {'s=' + str(sl):>8}\"\n    print(header)\n    print(\"=\" * len(header))\n\n    for hdim, causal, old, new in COMPARE_CONFIGS:\n        nheads = nheads_for_hdim(hdim)\n        c_str = \"causal\" if causal else \"non-causal\"\n        for label_prefix, cfg in [(\"OLD\", old), (\"NEW\", new)]:\n            label = f\"hdim{hdim} {c_str:<11} {label_prefix}  {cfg[0]}x{cfg[1]} RS={cfg[2]} OL={cfg[3]}\"\n            row = f\"{label:<50}\"\n            for sl in seqlens:\n                batch = auto_batch(sl, args.batch)\n                ms, tflops, diff = bench_fwd(batch, sl, nheads, hdim, causal, *cfg,\n                                             check_correctness=False, warmup=args.warmup, rep=args.rep)\n                row += f\" {tflops:>8.1f}\" if tflops else f\" {'FAIL':>8}\"\n            print(row)\n        print(\"-\" * len(header))\n\n\ndef run_sweep_bwd_opts(args):\n    \"\"\"Sweep backward kernel optimizations (V_in_regs, mma_dkv_is_rs, etc.).\"\"\"\n    seqlens = args.seqlen\n\n    for opt_name, get_configs_fn in BWD_OPT_CONFIGS.items():\n        for causal in get_causals(args):\n            c_str = \"causal\" if causal else \"non-causal\"\n            has_any = False\n\n            for hdim, hdim_v in args.hdim:\n                configs = get_configs_fn(hdim, causal)\n                if configs is None:\n                    continue\n                if not has_any:\n                    print(f\"\\n{'=' * 70}\")\n                    print(f\"BWD Optimization: {opt_name} 
({c_str})\")\n                    print(f\"{'=' * 70}\")\n                    has_any = True\n\n                nheads = nheads_for_hdim(hdim)\n                print(f\"\\n  hdim={hdim}:\")\n                for sl in seqlens:\n                    batch = auto_batch(sl, args.batch)\n                    f = bwd_flops(batch, nheads, sl, hdim, causal, hdim_v=hdim_v)\n                    if len(seqlens) > 1:\n                        print(f\"    seqlen={sl}, batch={batch}:\")\n                    for label, kwargs in configs.items():\n                        ms, tflops, err = bench_bwd(batch, sl, nheads, hdim, causal,\n                                                     warmup=args.warmup, rep=args.rep, hdim_v=hdim_v, **kwargs)\n                        if ms is not None:\n                            print(f\"    {label:40s}: {ms:6.2f} ms  ({tflops:6.1f} TFLOPS)\")\n                        else:\n                            print(f\"    {label:40s}: FAIL  {err}\")\n\n\n# ── Main ──────────────────────────────────────────────────────────────────\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Unified SM90 attention benchmark\",\n        formatter_class=argparse.RawDescriptionHelpFormatter,\n        epilog=__doc__,\n    )\n    parser.add_argument(\"--direction\", choices=[\"fwd\", \"bwd\", \"both\"], default=\"both\",\n                        help=\"Benchmark direction (default: both)\")\n    parser.add_argument(\"--hdim\", type=parse_headdims, default=[(64, 64), (96, 96), (128, 128)],\n                        help=\"Head dims, comma-separated. Each is hdim or hdim-hdim_v. E.g. 
64,128,192-128\")\n    parser.add_argument(\"--seqlen\", type=csv_ints, default=[8192],\n                        help=\"Sequence lengths, comma-separated (default: 8192)\")\n    parser.add_argument(\"--batch\", type=int, default=0,\n                        help=\"Batch size (0 = auto ~32k tokens)\")\n    parser.add_argument(\"--warmup\", type=int, default=5,\n                        help=\"Warmup iterations (default: 5)\")\n    parser.add_argument(\"--rep\", type=int, default=30,\n                        help=\"Repetitions per benchmark (default: 30)\")\n    parser.add_argument(\"--causal-only\", action=\"store_true\")\n    parser.add_argument(\"--non-causal-only\", action=\"store_true\")\n\n    mode = parser.add_mutually_exclusive_group()\n    mode.add_argument(\"--sweep-tiles\", action=\"store_true\",\n                      help=\"Sweep fwd tile sizes\")\n    mode.add_argument(\"--sweep-rs-overlap\", action=\"store_true\",\n                      help=\"Sweep fwd RS/overlap combos\")\n    mode.add_argument(\"--compare-configs\", action=\"store_true\",\n                      help=\"Compare old vs new fwd tile configs\")\n    mode.add_argument(\"--sweep-bwd-opts\", action=\"store_true\",\n                      help=\"Sweep bwd optimizations (V_in_regs, mma_dkv_is_rs, etc.)\")\n\n    args = parser.parse_args()\n    torch.manual_seed(0)\n\n    if args.sweep_tiles:\n        run_sweep_tiles(args)\n    elif args.sweep_rs_overlap:\n        run_sweep_rs_overlap(args)\n    elif args.compare_configs:\n        run_compare_configs(args)\n    elif args.sweep_bwd_opts:\n        run_sweep_bwd_opts(args)\n    else:\n        run_default(args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "benchmarks/benchmark_alibi.py",
    "content": "# Copyright (c) 2024, Sanghun Cho, Tri Dao.\n\nimport pickle\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom einops import rearrange, repeat\nfrom flash_attn.layers.rotary import apply_rotary_emb\n\nfrom flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward\nfrom flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined\n\nfrom flash_attn import flash_attn_qkvpacked_func, flash_attn_func\n\ntry:\n    import xformers.ops as xops\nexcept ImportError:\n    xops = None\n\n\ndef generate_cos_sin(seqlen, rotary_dim, device, dtype):\n    assert rotary_dim % 2 == 0\n    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi\n    cos = torch.cos(angle).to(dtype=dtype)\n    sin = torch.sin(angle).to(dtype=dtype)\n    return cos, sin\n\n\ndef flash_rotary(q, k, v, cos, sin, causal=False):\n    # corrected by @tridao comments\n    q = apply_rotary_emb(\n        q, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True\n    )\n    k = apply_rotary_emb(\n        k, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True\n    )\n\n    return flash_attn_func(q, k, v, causal=causal)\n\n\ndef attn_bias_from_alibi_slopes(\n    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False\n):\n    batch, nheads = slopes.shape\n    device = slopes.device\n    slopes = rearrange(slopes, \"b h -> b h 1 1\")\n    if causal:\n        return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes\n    else:\n        row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n        col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n        sk = (\n            seqlen_k\n            if key_padding_mask is None\n            else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n        )\n        sq = (\n            seqlen_q\n            if query_padding_mask 
is None\n            else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n        )\n        relative_pos = torch.abs(row_idx + sk - sq - col_idx)\n        return -slopes * relative_pos.to(dtype=slopes.dtype)\n\n\ndef flops(batch, seqlen, headdim, nheads, causal, mode=\"fwd\"):\n    assert mode in [\"fwd\", \"bwd\", \"fwd_bwd\"]\n    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)\n    return f if mode == \"fwd\" else (2.5 * f if mode == \"bwd\" else 3.5 * f)\n\n\ndef efficiency(flop, time):\n    return (flop / time / 10**12) if not math.isnan(time) else 0.0\n\n\ndef attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None):\n    \"\"\"\n    Arguments:\n        q, k, v: (batch_size, seqlen, nheads, head_dim)\n        dropout_p: float\n        attn_bias: (batch_size, nheads, seqlen, seqlen) or (1, nheads, seqlen, seqlen)\n    Output:\n        output: (batch_size, seqlen, nheads, head_dim)\n    \"\"\"\n    batch_size, seqlen, nheads, d = q.shape\n    q = rearrange(q, 'b t h d -> (b h) t d')\n    k = rearrange(k, 'b s h d -> (b h) d s')\n    softmax_scale = 1.0 / math.sqrt(d)\n    # Preallocate attn_weights for `baddbmm`\n    if attn_bias is not None:\n        scores = rearrange(attn_bias, 'b h t s -> (b h) t s')\n    else:\n        scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device)\n    scores = rearrange(torch.baddbmm(scores, q, k, beta=1.0, alpha=softmax_scale),\n                       '(b h) t s -> b h t s', h=nheads)\n    if causal:\n        # \"triu_tril_cuda_template\" not implemented for 'BFloat16'\n        # So we have to construct the mask in float\n        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)\n        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)\n        scores = scores + causal_mask.to(dtype=scores.dtype)\n    attention = torch.softmax(scores, dim=-1)\n    attention_drop = 
F.dropout(attention, dropout_p)\n    output = torch.einsum('bhts,bshd->bthd', attention_drop , v)\n    return output.to(dtype=q.dtype)\n\n\ndef time_fwd_bwd(func, *args, **kwargs):\n    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)\n    return time_f[1].mean, time_b[1].mean\n\n\nrepeats = 30\ndevice = 'cuda'\ndtype = torch.float16\n\nbs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]\ncausal_vals = [False, True]\nheaddim_vals = [64, 128]\ndim = 2048\ndropout_p = 0.0\n\nmethods = ([\"fa2_alibi\", \"torch\"]\n           + ([\"xformers\"] if xops is not None else [])\n           + [\"sdpa\"]\n           + [\"fa2_baseline\"]\n           + [\"fa2_rotary\"])\n\ntime_f = {}\ntime_b = {}\ntime_f_b = {}\nspeed_f = {}\nspeed_b = {}\nspeed_f_b = {}\nfor causal in causal_vals:\n    for headdim in headdim_vals:\n        for batch_size, seqlen in bs_seqlen_vals:\n            config = (causal, headdim, batch_size, seqlen)\n            nheads = dim // headdim\n            q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,\n                                    requires_grad=True) for _ in range(3)]\n            # alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n            alibi_slopes = torch.rand(1, nheads, device=device, dtype=torch.float32) * 0.3\n            attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal).to(dtype)\n            attn_bias = repeat(attn_bias, \"1 ... 
-> b ...\", b=batch_size)\n            f, b = time_fwd_bwd(\n                flash_attn_func,\n                q, k, v,\n                dropout_p,\n                causal=causal,\n                # alibi_slopes=alibi_slopes,\n                alibi_slopes=None,\n                repeats=repeats,\n                verbose=False\n            )\n            time_f[config, \"fa2_baseline\"] = f\n            time_b[config, \"fa2_baseline\"] = b\n\n            q = q.detach().requires_grad_(True)\n            k = k.detach().requires_grad_(True)\n            v = v.detach().requires_grad_(True)\n            f, b = time_fwd_bwd(\n                flash_attn_func,\n                q, k, v,\n                dropout_p,\n                causal=causal,\n                alibi_slopes=rearrange(alibi_slopes, \"1 h -> h\"),\n                # alibi_slopes=None,\n                repeats=repeats,\n                verbose=False\n            )\n            time_f[config, \"fa2_alibi\"] = f\n            time_b[config, \"fa2_alibi\"] = b\n\n            try:\n                q = q.detach().requires_grad_(True)\n                k = k.detach().requires_grad_(True)\n                v = v.detach().requires_grad_(True)\n                f, b = time_fwd_bwd(\n                    attention_pytorch,\n                    q, k, v,\n                    dropout_p,\n                    causal=causal,\n                    attn_bias=attn_bias,\n                    repeats=repeats,\n                    verbose=False\n                )\n            except Exception:  # Skip if OOM (a bare except would also swallow KeyboardInterrupt)\n                f, b = float('nan'), float('nan')\n            time_f[config, \"torch\"] = f\n            time_b[config, \"torch\"] = b\n\n            # F.sdpa doesn't currently (torch 2.1) dispatch to flash-attn but just to be safe\n            with torch.backends.cuda.sdp_kernel(enable_flash=False):\n                q_pt = q.detach().requires_grad_(True).transpose(1, 2)\n                k_pt = k.detach().requires_grad_(True).transpose(1, 
2)\n                v_pt = v.detach().requires_grad_(True).transpose(1, 2)\n                f, b = time_fwd_bwd(\n                    F.scaled_dot_product_attention,\n                    q_pt, k_pt, v_pt,\n                    attn_mask=attn_bias,\n                    dropout_p=dropout_p,\n                    is_causal=causal,\n                    repeats=repeats,\n                    verbose=False\n                )\n                time_f[config, \"sdpa\"] = f\n                time_b[config, \"sdpa\"] = b\n\n            if xops is not None:\n                q = q.detach().requires_grad_(True)\n                k = k.detach().requires_grad_(True)\n                v = v.detach().requires_grad_(True)\n                if causal:\n                    attn_bias_xops = xops.LowerTriangularMask().add_bias(attn_bias.expand(-1, -1, seqlen, -1).to(dtype=q.dtype))\n                    # NotImplementedError: No operator found for `memory_efficient_attention_backward` with inputs:\n                    # `flshattB@v2.3.6` is not supported because:\n                    #     attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>\n                    # `cutlassB` is not supported because:\n                    #     attn_bias type is <class 'xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias'>\n                    attn_bias_xops = attn_bias_xops.materialize((batch_size, nheads, seqlen, seqlen), dtype=q.dtype, device=device)\n                else:\n                    attn_bias_xops = attn_bias.to(dtype=q.dtype)\n                f, b = time_fwd_bwd(\n                    xops.memory_efficient_attention,\n                    q, k, v,\n                    attn_bias_xops,\n                    dropout_p,\n                    repeats=repeats,\n                    verbose=False\n                )\n                time_f[config, \"xformers\"] = f\n                time_b[config, \"xformers\"] = b\n\n            q = 
q.detach().requires_grad_(True)\n            k = k.detach().requires_grad_(True)\n            v = v.detach().requires_grad_(True)\n            cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)\n            f, b = time_fwd_bwd(\n                flash_rotary,\n                q, k, v,\n                cos, sin,\n                causal,\n                repeats=repeats,\n                verbose=False\n            )\n            time_f[config, \"fa2_rotary\"] = f\n            time_b[config, \"fa2_rotary\"] = b\n\n            print(f\"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###\")\n            csv_output = \"\"\n            csv_output += f\"{causal},{headdim},{batch_size},{seqlen},\"\n            for method in methods:\n                time_f_b[config, method] = time_f[config, method] + time_b[config, method]\n                speed_f[config, method] = efficiency(\n                    flops(batch_size, seqlen, headdim, nheads, causal, mode=\"fwd\"),\n                    time_f[config, method]\n                )\n                speed_b[config, method] = efficiency(\n                    flops(batch_size, seqlen, headdim, nheads, causal, mode=\"bwd\"),\n                    time_b[config, method]\n                )\n                speed_f_b[config, method] = efficiency(\n                    flops(batch_size, seqlen, headdim, nheads, causal, mode=\"fwd_bwd\"),\n                    time_f_b[config, method]\n                )\n                print(\n                    f\"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, \"\n                    f\"bwd: {speed_b[config, method]:.2f} TFLOPs/s, \"\n                    f\"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s\"\n                )\n                csv_output += f\"{speed_f[config, method]:.2f},{speed_b[config, method]:.2f},{speed_f_b[config, method]:.2f},\"\n            print(csv_output)\n"
  },
  {
    "path": "benchmarks/benchmark_attn.py",
    "content": "import argparse\nimport time\nimport torch\n\ntry:\n    import cudnn\nexcept ImportError:\n    cudnn = None\n\nfrom einops import rearrange\n\nfrom flash_attn.cute.bench_utils import (\n    flops, attention_ref,\n    cudnn_fwd_setup, cudnn_bwd_setup,\n)\n\ntry:\n    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func\nexcept ImportError:\n    flash_attn_func = None\n    flash_attn_varlen_func = None\ntry:\n    from flash_attn.cute.interface import flash_attn_func as flash_attn_func_python\n    from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python\nexcept ImportError:\n    flash_attn_func_python = None\n    flash_attn_varlen_func_python = None\ntry:\n    from flash_attn_interface import flash_attn_func as flash_attn_func_v3\n    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3\nexcept ImportError:\n    flash_attn_func_v3 = None\n    flash_attn_varlen_func_v3 = None\n\nif torch.cuda.get_device_capability()[0] != 9:\n    flash_attn_func_v3 = None\n\nfrom triton.testing import do_bench\n\n\n# ── Autograd backward helper ────────────────────────────────────────────────\n\ndef _make_bwd_fn(fwd_fn, g, inputs):\n    \"\"\"Run fwd once, return a closure that benchmarks backward.\n\n    Args:\n        fwd_fn: zero-arg callable that runs the forward pass (with autograd).\n        g: gradient tensor (b, seqlen, nheads, headdim_v).\n        inputs: list of input tensors whose .grad should be cleared each iteration.\n    \"\"\"\n    out = fwd_fn()\n    if isinstance(out, tuple):\n        out = out[0]\n    g_match = g[:out.shape[0]] if g.shape[0] != out.shape[0] else g  # handle varlen\n    def bwd_fn():\n        for x in inputs:\n            x.grad = None\n        out.backward(g_match, retain_graph=True)\n    return bwd_fn\n\n\n# ── Backend definitions ─────────────────────────────────────────────────────\n# Each setup_* function takes a context dict and 
returns (fwd_fn, bwd_fn).\n# Either can be None if the backend doesn't support that direction for the\n# given config.  fwd_fn / bwd_fn are zero-arg callables suitable for do_bench.\n\ndef setup_standard(ctx):\n    if ctx[\"dtype\"] == torch.float8_e4m3fn:\n        return None, None\n    q, k, v, g, causal = ctx[\"q\"], ctx[\"k\"], ctx[\"v\"], ctx[\"g\"], ctx[\"causal\"]\n    fwd_fn = lambda: attention_ref(q, k, v, causal=causal)\n    bwd_fn = _make_bwd_fn(fwd_fn, g, [q, k, v]) if ctx[\"has_backward\"] else None\n    return fwd_fn, bwd_fn\n\n\ndef setup_fa2(ctx):\n    if flash_attn_func is None or ctx[\"dtype\"] == torch.float8_e4m3fn:\n        return None, None\n    if ctx[\"headdim\"] != ctx[\"headdim_v\"]:\n        return None, None\n    q, k, v, g, causal = ctx[\"q\"], ctx[\"k\"], ctx[\"v\"], ctx[\"g\"], ctx[\"causal\"]\n    dropout_p, window_size_fa, softcap = ctx[\"dropout_p\"], ctx[\"window_size_fa\"], ctx[\"softcap\"]\n    deterministic = ctx[\"deterministic\"]\n    if ctx[\"varlen\"]:\n        qu, ku, vu = ctx[\"q_unpad\"], ctx[\"k_unpad\"], ctx[\"v_unpad\"]\n        csq, csk, sq, sk = ctx[\"cu_seqlens_q\"], ctx[\"cu_seqlens_k\"], ctx[\"seqlen_q\"], ctx[\"seqlen\"]\n        fwd_fn = lambda: flash_attn_varlen_func(qu, ku, vu, csq, csk, sq, sk, dropout_p, causal=causal, window_size=window_size_fa, softcap=softcap)\n        bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func(qu, ku, vu, csq, csk, sq, sk, dropout_p, causal=causal, window_size=window_size_fa, softcap=softcap, deterministic=deterministic), g, [qu, ku, vu]) if ctx[\"has_backward\"] else None\n    else:\n        fwd_fn = lambda: flash_attn_func(q, k, v, dropout_p, causal=causal, window_size=window_size_fa, softcap=softcap)\n        bwd_fn = _make_bwd_fn(lambda: flash_attn_func(q, k, v, dropout_p, causal=causal, window_size=window_size_fa, softcap=softcap, deterministic=deterministic), g, [q, k, v]) if ctx[\"has_backward\"] else None\n    return fwd_fn, bwd_fn\n\n\ndef setup_cudnn(ctx):\n    if 
cudnn is None or ctx[\"headdim\"] > 256 or ctx[\"dtype\"] == torch.float8_e4m3fn:\n        return None, None\n    q, k, v, g, causal = ctx[\"q\"], ctx[\"k\"], ctx[\"v\"], ctx[\"g\"], ctx[\"causal\"]\n    window_size_left = ctx[\"window_size\"][0]\n    # cuDNN expects (batch, nheads, seqlen, headdim) layout\n    qt, kt, vt, gt = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), g.transpose(1, 2)\n    fwd_fn, o_gpu, lse_gpu = cudnn_fwd_setup(qt, kt, vt, causal=causal, window_size_left=window_size_left)\n    bwd_fn = None\n    if ctx[\"has_backward\"]:\n        fwd_fn()  # populate o and lse for bwd graph\n        bwd_fn = cudnn_bwd_setup(qt, kt, vt, o_gpu, gt, lse_gpu, causal=causal, window_size_left=window_size_left)\n    return fwd_fn, bwd_fn\n\n\ndef setup_fa3(ctx):\n    if flash_attn_func_v3 is None:\n        return None, None\n    q, k, v, g, causal = ctx[\"q\"], ctx[\"k\"], ctx[\"v\"], ctx[\"g\"], ctx[\"causal\"]\n    window_size_fa, softcap = ctx[\"window_size_fa\"], ctx[\"softcap\"]\n    num_splits, pack_gqa, deterministic = ctx[\"num_splits\"], ctx[\"pack_gqa\"], ctx[\"deterministic\"]\n    k_use = ctx.get(\"k_paged\", k) if ctx[\"page_size\"] is not None else k\n    v_use = ctx.get(\"v_paged\", v) if ctx[\"page_size\"] is not None else v\n    if ctx[\"varlen\"]:\n        qu, ku, vu = ctx[\"q_unpad\"], ctx[\"k_unpad\"], ctx[\"v_unpad\"]\n        csq, csk, sq, sk = ctx[\"cu_seqlens_q\"], ctx[\"cu_seqlens_k\"], ctx[\"seqlen_q\"], ctx[\"seqlen\"]\n        fwd_fn = lambda: flash_attn_varlen_func_v3(qu, ku, vu, csq, csk, sq, sk, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)\n    else:\n        fwd_fn = lambda: flash_attn_func_v3(q, k_use, v_use, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)\n    # FA3 bwd only supports headdim == headdim_v and non-fp8\n    bwd_fn = None\n    if ctx[\"has_backward\"] and ctx[\"dtype\"] != torch.float8_e4m3fn and 
ctx[\"headdim\"] == ctx[\"headdim_v\"]:\n        if ctx[\"varlen\"]:\n            bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func_v3(qu, ku, vu, csq, csk, sq, sk, causal=causal, window_size=ctx[\"window_size\"], softcap=softcap, deterministic=deterministic), g, [qu, ku, vu])\n        else:\n            bwd_fn = _make_bwd_fn(lambda: flash_attn_func_v3(q, k, v, causal=causal, softcap=softcap), g, [q, k, v])\n    return fwd_fn, bwd_fn\n\n\ndef setup_fa4(ctx):\n    if flash_attn_func_python is None:\n        return None, None\n    q, k, v, g, causal = ctx[\"q\"], ctx[\"k\"], ctx[\"v\"], ctx[\"g\"], ctx[\"causal\"]\n    window_size, softcap = ctx[\"window_size\"], ctx[\"softcap\"]\n    pack_gqa, deterministic = ctx[\"pack_gqa\"], ctx[\"deterministic\"]\n    sinks = ctx[\"sinks\"]\n    k_use = ctx.get(\"k_paged\", k) if ctx[\"page_size\"] is not None else k\n    v_use = ctx.get(\"v_paged\", v) if ctx[\"page_size\"] is not None else v\n    if ctx[\"varlen\"]:\n        qu = ctx[\"q_unpad\"]\n        ku = ctx.get(\"k_paged\", ctx[\"k_unpad\"]) if ctx[\"page_size\"] is not None else ctx[\"k_unpad\"]\n        vu = ctx.get(\"v_paged\", ctx[\"v_unpad\"]) if ctx[\"page_size\"] is not None else ctx[\"v_unpad\"]\n        csq, csk = ctx[\"cu_seqlens_q\"], ctx[\"cu_seqlens_k\"]\n        pt = ctx[\"page_table\"]\n        fwd_fn = lambda: flash_attn_varlen_func_python(qu, ku, vu, csq, csk, page_table=pt, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa)\n    else:\n        fwd_fn = lambda: flash_attn_func_python(q, k_use, v_use, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa)\n    bwd_fn = None\n    if ctx[\"has_backward\"] and ctx[\"dtype\"] != torch.float8_e4m3fn:\n        if ctx[\"varlen\"]:\n            qu, ku, vu = ctx[\"q_unpad\"], ctx[\"k_unpad\"], ctx[\"v_unpad\"]\n            csq, csk = ctx[\"cu_seqlens_q\"], ctx[\"cu_seqlens_k\"]\n            bwd_fn = _make_bwd_fn(lambda: 
flash_attn_varlen_func_python(qu, ku, vu, csq, csk, causal=causal, softcap=softcap, deterministic=deterministic), g, [qu, ku, vu])\n        else:\n            bwd_fn = _make_bwd_fn(lambda: flash_attn_func_python(q, k, v, causal=causal, softcap=softcap, deterministic=deterministic), g, [q, k, v])\n    return fwd_fn, bwd_fn\n\n\n# Ordered list of (display_name, cli_name, setup_fn)\nBACKENDS = [\n    (\"Standard\", \"standard\", setup_standard),\n    (\"FA2\",      \"fa2\",      setup_fa2),\n    (\"cuDNN\",    \"cudnn\",    setup_cudnn),\n    (\"FA3\",      \"fa3\",      setup_fa3),\n    (\"FA4\",      \"fa4\",      setup_fa4),\n]\n\n\ndef parse_int_k(s):\n    \"\"\"Parse an integer with optional k/K suffix, e.g. '8k' -> 8192.\"\"\"\n    s = s.strip().lower()\n    if s.endswith(\"k\"):\n        return int(s[:-1]) * 1024\n    return int(s)\n\n\ndef csv_ints(s):\n    \"\"\"Parse comma-separated integers with optional k suffix, e.g. '512,1k,2k'.\"\"\"\n    return [parse_int_k(x) for x in s.split(\",\")]\n\n\ndef parse_headdims(s):\n    \"\"\"Parse comma-separated headdim specs. Each entry is hdim or hdim-hdim_v.\n\n    Examples:\n        '128'           -> [(128, 128)]\n        '192-128'       -> [(192, 128)]\n        '64,128,192'    -> [(64, 64), (128, 128), (192, 192)]\n        '64,128,192-128,192' -> [(64, 64), (128, 128), (192, 128), (192, 192)]\n    \"\"\"\n    result = []\n    for item in s.split(\",\"):\n        if \"-\" in item:\n            parts = item.split(\"-\")\n            result.append((int(parts[0]), int(parts[1])))\n        else:\n            hdim = int(item)\n            result.append((hdim, hdim))\n    return result\n\n\ndef csv_strs(s):\n    \"\"\"Parse comma-separated strings, e.g. 
'fa3,fa4'.\"\"\"\n    return [x.strip() for x in s.split(\",\")]\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Benchmark FlashAttention')\n    parser.add_argument('--headdim', type=parse_headdims, default=[(128, 128)],\n                        help='Head dim(s), comma-separated. Each is hdim or hdim-hdim_v. E.g. 64,128,192-128')\n    parser.add_argument('--fwd', action='store_true', help='Run forward only')\n    parser.add_argument('--bwd', action='store_true', help='Run backward only')\n    parser.add_argument('--varlen', action='store_true', default=False)\n    parser.add_argument('--causal', type=str.lower, choices=['true', 'false', 'both'], default='both',\n                        help='Causal mode (default: both)')\n    parser.add_argument('--seqlen', type=csv_ints, default=[8192],\n                        help='Sequence length(s), comma-separated. Supports k suffix, e.g. 1k,2k,8k')\n    parser.add_argument('--total-seqlen', type=parse_int_k, default='32k',\n                        help='Total sequence length for batch sizing (default: 32k)')\n    parser.add_argument('--batch-size', type=int, default=None,\n                        help='Batch size (default: total_seqlen // seqlen)')\n    parser.add_argument('--deterministic', action='store_true', default=False)\n    parser.add_argument('--nheads', type=int, default=None,\n                        help='Number of Q heads (default: 32 for hdim<=64, 16 for hdim<=192, 8 for hdim>192)')\n    parser.add_argument('--nheads-kv', type=int, default=None,\n                        help='Number of KV heads (default: nheads)')\n    parser.add_argument('--gqa-ratio', type=int, default=None,\n                        help='GQA ratio (nheads // nheads_kv). 
Ignored if --nheads-kv is set.')\n    parser.add_argument('--backend', type=csv_strs, default=['all'],\n                        help='Which backends to benchmark, comma-separated (choices: all,standard,fa2,fa3,fa4,cudnn)')\n    parser.add_argument('--warmup', type=int, default=5,\n                        help='Warmup iterations (default: 5)')\n    parser.add_argument('--rep', type=int, default=10,\n                        help='Repetitions per benchmark (default: 10)')\n    return parser.parse_args()\n\n\ndef main():\n    args = parse_args()\n\n    headdim_pairs = args.headdim  # list of (hdim, hdim_v) tuples\n\n    # Parse fwd/bwd: if neither specified, do fwd only\n    has_forward = args.fwd or not args.bwd\n    has_backward = args.bwd\n\n    # Parse causal\n    if args.causal == 'true':\n        causal_vals = [True]\n    elif args.causal == 'false':\n        causal_vals = [False]\n    else:\n        causal_vals = [False, True]\n\n    seqlen_list = args.seqlen\n    varlen = args.varlen\n\n    # Filter backends to those requested and available\n    enabled = set(args.backend)\n    if 'all' in enabled:\n        enabled = {cli for _, cli, _ in BACKENDS}\n    active_backends = [(name, cli, fn) for name, cli, fn in BACKENDS if cli in enabled]\n\n    # Parameters\n    torch.manual_seed(0)\n    dropout_p = 0.0\n    dtype = torch.bfloat16\n    dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    device = 'cuda'\n    page_size = None\n    softcap = 0.0\n    deterministic = args.deterministic\n    warmup, rep = args.warmup, args.rep\n\n    time_f = {}\n    time_b = {}\n\n    for headdim, headdim_v in headdim_pairs:\n        nheads = args.nheads if args.nheads is not None else (32 if headdim <= 64 else 16 if headdim <= 192 else 8)\n        if args.nheads_kv is not None:\n            nheads_kv = args.nheads_kv\n        elif args.gqa_ratio is not None:\n            nheads_kv = nheads // args.gqa_ratio\n        else:\n            nheads_kv = nheads\n      
  has_qv = headdim == 64 and headdim_v == 512\n        sinks = None\n\n        num_splits = 0\n        window_size = (None, None)\n        window_size_fa = (-1, -1)\n        pack_gqa = None\n\n        for seqlen in seqlen_list:\n            batch_size = args.batch_size if args.batch_size is not None else max(1, args.total_seqlen // seqlen)\n            seqlen_q = seqlen\n\n            q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward)\n            k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward)\n            v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=has_backward)\n            q, k, v = [x.detach().to(dtype).requires_grad_(has_backward) for x in [q, k, v]]\n            g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen)\n\n            # Varlen tensors\n            q_unpad = k_unpad = v_unpad = cu_seqlens_q = cu_seqlens_k = None\n            if varlen:\n                q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), \"b s h d -> (b s) h d\").requires_grad_(has_backward) for x in [q, k, v]]\n                cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q\n                cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen if page_size is None else None\n\n            # Paged KV tensors\n            k_paged = v_paged = page_table = None\n            if page_size is not None:\n                assert seqlen % page_size == 0\n                k_paged, v_paged = [rearrange(x, \"b (n p) h d -> (b n) p h d\", p=page_size) for x in [k, v]]\n                page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32),\n                                       \"(b s) -> b s\", s=seqlen // page_size)\n\n            for causal in 
causal_vals:\n                cfg = (headdim, headdim_v, causal, seqlen, batch_size, nheads)\n\n                # Build context dict shared by all backends\n                ctx = dict(\n                    q=q, k=k, v=v, g=g, causal=causal,\n                    headdim=headdim, headdim_v=headdim_v, dtype=dtype,\n                    has_backward=has_backward,\n                    varlen=varlen, q_unpad=q_unpad, k_unpad=k_unpad, v_unpad=v_unpad,\n                    cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,\n                    seqlen_q=seqlen_q, seqlen=seqlen,\n                    page_size=page_size, k_paged=k_paged, v_paged=v_paged, page_table=page_table,\n                    dropout_p=dropout_p, window_size=window_size, window_size_fa=window_size_fa,\n                    softcap=softcap, deterministic=deterministic,\n                    num_splits=num_splits, pack_gqa=pack_gqa, sinks=sinks,\n                )\n\n                for display_name, cli_name, setup_fn in active_backends:\n                    fwd_fn, bwd_fn = setup_fn(ctx)\n                    if fwd_fn is not None and has_forward:\n                        time.sleep(1.0)\n                        print(f\"Benchmarking {display_name} fwd, hdim={headdim}, seqlen={seqlen}, causal={causal}\")\n                        # do_bench reports milliseconds; store seconds for the TFLOPS math below\n                        secs = do_bench(fwd_fn, warmup=warmup, rep=rep) * 1e-3\n                        time_f[cfg, display_name] = secs\n                    if bwd_fn is not None and has_backward:\n                        time.sleep(1.0)\n                        print(f\"Benchmarking {display_name} bwd, hdim={headdim}, seqlen={seqlen}, causal={causal}\")\n                        secs = do_bench(bwd_fn, warmup=warmup, rep=rep) * 1e-3\n                        time_b[cfg, display_name] = secs\n\n    # ── Print results table ──────────────────────────────────────────────────\n    backend_names = [name for name, _, _ in BACKENDS]\n    shown_backends = [b for b in backend_names if any(b == k[1] for k in list(time_f) 
+ list(time_b))]\n\n    if not shown_backends:\n        return\n\n    col_w = 16\n\n    for direction, times, flops_mult in [(\"FWD\", time_f, 1.0), (\"BWD\", time_b, 2.5)]:\n        if not times:\n            continue\n        configs = sorted(set(k[0] for k in times))\n        if not configs:\n            continue\n\n        header = f\"{'hdim':>9} {'causal':>6} {'batch':>5} {'seqlen':>6}\"\n        for b in shown_backends:\n            header += f\" {b:>{col_w}}\"\n        print(f\"\\n{'=' * len(header)}\")\n        print(f\"  {direction} (ms / TFLOPS)\")\n        print(f\"{'=' * len(header)}\")\n        print(header)\n        print(\"-\" * len(header))\n\n        for cfg in configs:\n            headdim, headdim_v, causal, seqlen, batch_size, nheads = cfg\n            nFLOPS = flops(batch_size, nheads, seqlen, seqlen, headdim, headdim_v, causal=causal)\n            hdim_str = str(headdim) if headdim == headdim_v else f\"{headdim}-{headdim_v}\"\n            row = f\"{hdim_str:>9} {str(causal):>6} {batch_size:>5} {seqlen:>6}\"\n            for b in shown_backends:\n                t = times.get((cfg, b))\n                if t is not None:\n                    tflops = flops_mult * nFLOPS / t * 1e-12\n                    ms = t * 1e3\n                    cell = f\"{ms:.2f}/{tflops:.0f}\"\n                    row += f\" {cell:>{col_w}}\"\n                else:\n                    row += f\" {'—':>{col_w}}\"\n            print(row)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "benchmarks/benchmark_causal.py",
    "content": "from functools import partial\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom einops import rearrange, repeat\n\n# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler\nfrom flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler\nfrom flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func\n# # from flash_attn.triton.fused_attention import attention as attention\n# from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func\n# from flash_attn.flash_attn_triton_og import attention as attention_og\n\n# from triton.ops.flash_attention import attention as attention_triton\n\nfrom flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func\n\ndef attention_pytorch(qkv, dropout_p=0.0, causal=True):\n    \"\"\"\n    Arguments:\n        qkv: (batch_size, seqlen, 3, nheads, head_dim)\n        dropout_p: float\n    Output:\n        output: (batch_size, seqlen, nheads, head_dim)\n    \"\"\"\n    batch_size, seqlen, _, nheads, d = qkv.shape\n    q, k, v = qkv.unbind(dim=2)\n    q = rearrange(q, 'b t h d -> (b h) t d')\n    k = rearrange(k, 'b s h d -> (b h) d s')\n    softmax_scale = 1.0 / math.sqrt(d)\n    # Preallocate attn_weights for `baddbmm`\n    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)\n    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),\n                       '(b h) t s -> b h t s', h=nheads)\n    if causal:\n        # \"triu_tril_cuda_template\" not implemented for 'BFloat16'\n        # So we have to construct the mask in float\n        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)\n        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel 
I guess)\n        scores = scores + causal_mask.to(dtype=scores.dtype)\n    attention = torch.softmax(scores, dim=-1)\n    attention_drop = F.dropout(attention, dropout_p)\n    output = torch.einsum('bhts,bshd->bthd', attention_drop , v)\n    return output.to(dtype=qkv.dtype)\n\n\ntorch.manual_seed(0)\nrepeats = 30\nbatch_size = 8\nseqlen = 2048\nnheads = 12\nheaddim = 128\n# nheads = 24\n# headdim = 64\n# batch_size = 64\n# seqlen = 512\n# nheads = 8\n# headdim = 128\ndropout_p = 0.0\ncausal = True\ndtype = torch.float16\ndevice = 'cuda'\n\nqkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,\n                  requires_grad=True)\ncu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,\n                          device=qkv.device)\n\nqkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)\n# benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,\n#               cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')\n# pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,\n#                  cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)\nbenchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')\npytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False)\n\n# for dropout_p in [0.1, 0.0]:\n#     for causal in [False, True]:\n#         print(f\"### {dropout_p = }, {causal = } ###\")\n#         pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)\n\n\n# nheads_k = 2\n# q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)\n# kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,\n#                  requires_grad=True)\n# if fav2_kvpacked_func is not None:\n#     benchmark_all(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, 
repeats=repeats, desc='Fav2')\n#     pytorch_profiler(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, backward=True)\n\n# dropout_p = 0.0\n# causal = False\n# benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,\n#               repeats=repeats, desc='PyTorch Attention')\n\n# benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')\n# pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)\n\n# q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,\n#                        requires_grad=True) for _ in range(3)]\n# benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')\n# # pytorch_profiler(attention, q, k, v, 1.0, backward=True)\n\n# from src.ops.fftconv import fftconv_func\n\n# dim = nheads * headdim\n# u = torch.randn(batch_size, dim, seqlen, device=device, dtype=dtype, requires_grad=True)\n# k = torch.randn(dim, seqlen, device=device, requires_grad=True)\n# D = torch.randn(dim, device=device, requires_grad=True)\n# benchmark_all(fftconv_func, u, k, D, repeats=repeats, desc='FFTConv')\n# pytorch_profiler(fftconv_func, u, k, D, backward=True)\n# pytorch_profiler(torch.fft.rfft, u.float())\n\nflops = 4 * batch_size * seqlen ** 2 * nheads * headdim\nideal_a100_time = flops / 312 / 1e9\nprint(f\"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms\")\nexit(0)\n\n\ndef time_fwd_bwd(func, *args, **kwargs):\n    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)\n    return time_f[1].mean, time_b[1].mean\n\nbs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]\ncausal_vals = [False, True]\nheaddim_vals = [64, 128]\ndim = 2048\ndropout_p = 0.0\n\ntime_f = {}\ntime_b = {}\nfor causal in causal_vals:\n    for headdim in headdim_vals:\n        for batch_size, seqlen in bs_seqlen_vals:\n            nheads = dim // headdim\n            qkv = 
torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,\n                              requires_grad=True)\n            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,\n                                    device=qkv.device)\n            qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)\n            f, b = time_fwd_bwd(\n                flash_attn_varlen_qkvpacked_func, qkv_unpad, cu_seqlens, seqlen, dropout_p,\n                causal=causal, repeats=repeats, verbose=False\n            )\n            time_f[(causal, headdim, batch_size, seqlen), \"Flash\"] = f\n            time_b[(causal, headdim, batch_size, seqlen), \"Flash\"] = b\n\n            qkv = qkv.detach().requires_grad_(True)\n            f, b = time_fwd_bwd(\n                flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False\n            )\n            time_f[(causal, headdim, batch_size, seqlen), \"Flash2\"] = f\n            time_b[(causal, headdim, batch_size, seqlen), \"Flash2\"] = b\n\n            # q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,\n            #                        requires_grad=True) for _ in range(3)]\n            # # Try both values of sequence_parallel and pick the faster one\n            # f, b = time_fwd_bwd(\n            #     attention_triton, q, k, v, causal, headdim**(-0.5),\n            #     False, repeats=repeats, verbose=False\n            # )\n            # _, b0 = time_fwd_bwd(\n            #     attention_triton, q, k, v, causal, headdim**(-0.5),\n            #     True, repeats=repeats, verbose=False\n            # )\n            # time_f[(causal, headdim, batch_size, seqlen), \"Triton\"] = f\n            # time_b[(causal, headdim, batch_size, seqlen), \"Triton\"] = min(b, b0)\n\n            if seqlen <= 8 * 1024:\n                qkv = qkv.detach().requires_grad_(True)\n                f, b = 
time_fwd_bwd(\n                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False\n                )\n            else:\n                f, b = float('nan'), float('nan')\n            time_f[(causal, headdim, batch_size, seqlen), \"Pytorch\"] = f\n            time_b[(causal, headdim, batch_size, seqlen), \"Pytorch\"] = b\n\n            # q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,\n            #                        requires_grad=True) for _ in range(3)]\n            # import xformers.ops as xops\n            # f, b = time_fwd_bwd(\n            #     xops.memory_efficient_attention, q, k, v,\n            #     attn_bias=xops.LowerTriangularMask() if causal else None,\n            #     op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)\n            # )\n            # time_f[(causal, headdim, batch_size, seqlen), \"xformers\"] = f\n            # time_b[(causal, headdim, batch_size, seqlen), \"xformers\"] = b\n\n\nimport pickle\nwith open('flash2_attn_time_h100.plk', 'wb') as fp:\n    pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)\n"
  },
  {
    "path": "benchmarks/benchmark_flash_attention.py",
    "content": "# Install the newest triton version with\n# pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"\nimport pickle\nimport math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom einops import rearrange, repeat\n\nfrom flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward\nfrom flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined\n\nfrom flash_attn import flash_attn_qkvpacked_func\n\ntry:\n    from triton.ops.flash_attention import attention as attention_triton\nexcept ImportError:\n    attention_triton = None\n\ntry:\n    import xformers.ops as xops\nexcept ImportError:\n    xops = None\n\n\ndef flops(batch, seqlen, headdim, nheads, causal, mode=\"fwd\"):\n    assert mode in [\"fwd\", \"bwd\", \"fwd_bwd\"]\n    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)\n    return f if mode == \"fwd\" else (2.5 * f if mode == \"bwd\" else 3.5 * f)\n\ndef efficiency(flop, time):\n    return (flop / time / 10**12) if not math.isnan(time) else 0.0\n\n\ndef attention_pytorch(qkv, dropout_p=0.0, causal=True):\n    \"\"\"\n    Arguments:\n        qkv: (batch_size, seqlen, 3, nheads, head_dim)\n        dropout_p: float\n    Output:\n        output: (batch_size, seqlen, nheads, head_dim)\n    \"\"\"\n    batch_size, seqlen, _, nheads, d = qkv.shape\n    q, k, v = qkv.unbind(dim=2)\n    q = rearrange(q, 'b t h d -> (b h) t d')\n    k = rearrange(k, 'b s h d -> (b h) d s')\n    softmax_scale = 1.0 / math.sqrt(d)\n    # Preallocate attn_weights for `baddbmm`\n    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)\n    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),\n                       '(b h) t s -> b h t s', h=nheads)\n    if causal:\n        # \"triu_tril_cuda_template\" not implemented for 'BFloat16'\n        # So we have to construct the mask in float\n        
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)\n        # Adding is faster than masked_fill_\n        scores = scores + causal_mask.to(dtype=scores.dtype)\n    attention = torch.softmax(scores, dim=-1)\n    attention_drop = F.dropout(attention, dropout_p)\n    output = torch.einsum('bhts,bshd->bthd', attention_drop , v)\n    return output.to(dtype=qkv.dtype)\n\n\ndef time_fwd_bwd(func, *args, **kwargs):\n    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)\n    return time_f[1].mean, time_b[1].mean\n\n\nrepeats = 30\ndevice = 'cuda'\ndtype = torch.float16\n\nbs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]\ncausal_vals = [False, True]\nheaddim_vals = [64, 128]\ndim = 2048\ndropout_p = 0.0\n\nmethods = ([\"Flash2\", \"Pytorch\"]\n           + ([\"Triton\"] if attention_triton is not None else [])\n           + ([\"xformers.c\"] if xops is not None else [])\n           + ([\"xformers.f\"] if xops is not None else []))\n\ntime_f = {}\ntime_b = {}\ntime_f_b = {}\nspeed_f = {}\nspeed_b = {}\nspeed_f_b = {}\n\nfor causal in causal_vals:\n    for headdim in headdim_vals:\n        for batch_size, seqlen in bs_seqlen_vals:\n            config = (causal, headdim, batch_size, seqlen)\n            nheads = dim // headdim\n\n            # FlashAttention 2\n            if \"Flash2\" in methods:\n                qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,\n                                  device=device, dtype=dtype, requires_grad=True)\n                f, b = time_fwd_bwd(\n                    flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal,\n                    repeats=repeats, verbose=False\n                )\n                time_f[config, \"Flash2\"] = f\n                time_b[config, \"Flash2\"] = b\n\n            # PyTorch baseline\n            if \"Pytorch\" in methods:\n                try:\n                    # fresh tensor avoids grad-history reuse issues\n 
                   qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,\n                                      device=device, dtype=dtype, requires_grad=True)\n                    f, b = time_fwd_bwd(\n                        attention_pytorch, qkv, dropout_p, causal=causal,\n                        repeats=repeats, verbose=False\n                    )\n                except Exception:\n                    f, b = float('nan'), float('nan')\n                time_f[config, \"Pytorch\"] = f\n                time_b[config, \"Pytorch\"] = b\n\n            # Triton\n            if \"Triton\" in methods and attention_triton is not None:\n                q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim,\n                                       device=device, dtype=dtype, requires_grad=True) for _ in range(3)]\n                # Try both values of sequence_parallel and pick the faster backward\n                try:\n                    f, b = time_fwd_bwd(\n                        attention_triton, q, k, v, causal, headdim**(-0.5),\n                        False, repeats=repeats, verbose=False\n                    )\n                except Exception:\n                    f, b = float('nan'), float('inf')\n                try:\n                    _, b0 = time_fwd_bwd(\n                        attention_triton, q, k, v, causal, headdim**(-0.5),\n                        True, repeats=repeats, verbose=False\n                    )\n                except Exception:\n                    b0 = float('inf')\n                time_f[config, \"Triton\"] = f\n                time_b[config, \"Triton\"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')\n\n            # xFormers CUTLASS\n            if \"xformers.c\" in methods and xops is not None:\n                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim,\n                                       device=device, dtype=dtype, requires_grad=True) for _ in range(3)]\n                f, b = 
time_fwd_bwd(\n                    xops.memory_efficient_attention, q, k, v,\n                    attn_bias=xops.LowerTriangularMask() if causal else None,\n                    op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)\n                )\n                time_f[config, \"xformers.c\"] = f\n                time_b[config, \"xformers.c\"] = b\n\n            # xFormers Flash\n            if \"xformers.f\" in methods and xops is not None:\n                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim,\n                                       device=device, dtype=dtype, requires_grad=True) for _ in range(3)]\n                f, b = time_fwd_bwd(\n                    xops.memory_efficient_attention, q, k, v,\n                    attn_bias=xops.LowerTriangularMask() if causal else None,\n                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)\n                )\n                time_f[config, \"xformers.f\"] = f\n                time_b[config, \"xformers.f\"] = b\n\n            # Report\n            print(f\"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###\")\n            for method in methods:\n                if (config, method) not in time_f or (config, method) not in time_b:\n                    continue\n                time_f_b[config, method] = time_f[config, method] + time_b[config, method]\n                speed_f[config, method] = efficiency(\n                    flops(batch_size, seqlen, headdim, nheads, causal, mode=\"fwd\"),\n                    time_f[config, method]\n                )\n                speed_b[config, method] = efficiency(\n                    flops(batch_size, seqlen, headdim, nheads, causal, mode=\"bwd\"),\n                    time_b[config, method]\n                )\n                speed_f_b[config, method] = efficiency(\n                    flops(batch_size, seqlen, headdim, nheads, causal, mode=\"fwd_bwd\"),\n                    time_f_b[config, method]\n             
   )\n                print(\n                    f\"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, \"\n                    f\"bwd: {speed_b[config, method]:.2f} TFLOPs/s, \"\n                    f\"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s\"\n                )\n\n# with open('flash2_attn_time.plk', 'wb') as fp:\n#     pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)"
  },
  {
    "path": "benchmarks/benchmark_gemm.py",
    "content": "import time\nimport torch\nimport torch.utils.benchmark as benchmark\n\nfrom triton.testing import do_bench\n\nif torch.version.cuda:\n    backendBLAS = \"cuBLAS\"\nelif torch.version.hip:\n    backendBLAS = \"hipBLAS\"\n\ndef benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs):\n    \"\"\"Use Pytorch Benchmark on the forward pass of an arbitrary function.\"\"\"\n    if verbose:\n        print(desc, '- Forward pass')\n    t = benchmark.Timer(\n            stmt='fn(*inputs, **kwinputs)',\n            globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs},\n            num_threads=torch.get_num_threads(),\n            )\n    m = t.timeit(repeats)\n    if verbose:\n        print(m)\n    return t, m\n\n\ntorch.manual_seed(0)\nrepeats = 30\ndtype = torch.bfloat16\ndevice = 'cuda'\nverbose = False\nm, n = 8192, 8192\n\ntflops_matmul = {}\ntflops_matmul1 = {}\nfor k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]:\n    a = torch.randn(m, k, device=device, dtype=dtype)\n    b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)\n    nFLOPS_matmul = 2 * m * n * k\n    time.sleep(2)  # to reduce power throttling\n    timing = benchmark_forward(torch.matmul, a, b, desc=backendBLAS, verbose=verbose, repeats=repeats)[1]\n    tflops_matmul[k] = nFLOPS_matmul / timing.mean * 1e-12\n    print(f'[torch.utils.benchmark] {backendBLAS}, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul[k]:.1f} TFLOPS')\n    time.sleep(2)  # to reduce power throttling\n    ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=repeats)\n    tflops_matmul1[k] = nFLOPS_matmul / ms * 1e-9\n    print(f'[triton.test.do_bench]  {backendBLAS}, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1[k]:.1f} TFLOPS')\n"
  },
  {
    "path": "csrc/flash_attn/flash_api.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.\n#include <torch/python.h>\n#include <torch/nn/functional.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <c10/cuda/CUDAStream.h>\n#include <ATen/cuda/CUDAGeneratorImpl.h>  // For at::Generator and at::PhiloxCudaState\n#include \"philox_unpack.cuh\"  // For at::cuda::philox::unpack\n\n#include <cutlass/numeric_types.h>\n\n#include \"namespace_config.h\"\n#include \"hardware_info.h\"\n#include \"flash.h\"\n#include \"static_switch.h\"\n\n#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x \" must be on CUDA\")\n#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x \" must have shape (\" #__VA_ARGS__ \")\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n\nnamespace FLASH_NAMESPACE {\n\nvoid set_params_fprop(Flash_fwd_params &params,\n                      // sizes\n                      const size_t b,\n                      const size_t seqlen_q,\n                      const size_t seqlen_k,\n                      const size_t seqlen_q_rounded,\n                      const size_t seqlen_k_rounded,\n                      const size_t h,\n                      const size_t h_k,\n                      const size_t d,\n                      const size_t d_rounded,\n                      // device pointers\n                      const at::Tensor q,\n                      const at::Tensor k,\n                      const at::Tensor v,\n                      at::Tensor out,\n                      void *cu_seqlens_q_d,\n                      void *cu_seqlens_k_d,\n                      void *seqused_k,\n                      void *p_d,\n                      void *softmax_lse_d,\n          
            float p_dropout,\n                      float softmax_scale,\n                      int window_size_left,\n                      int window_size_right,\n                      const float softcap,\n                      bool seqlenq_ngroups_swapped=false,\n                      const bool unpadded_lse=false) {\n\n    // Reset the parameters\n    params = {};\n\n    params.is_bf16 = q.dtype() == torch::kBFloat16;\n\n    // Set the pointers and strides.\n    params.q_ptr = q.data_ptr();\n    params.k_ptr = k.data_ptr();\n    params.v_ptr = v.data_ptr();\n    // All strides are in elements, not bytes.\n    params.q_row_stride = q.stride(-3);\n    params.k_row_stride = k.stride(-3);\n    params.v_row_stride = v.stride(-3);\n    params.q_head_stride = q.stride(-2);\n    params.k_head_stride = k.stride(-2);\n    params.v_head_stride = v.stride(-2);\n    params.o_ptr = out.data_ptr();\n    params.o_row_stride = out.stride(-3);\n    params.o_head_stride = out.stride(-2);\n\n    if (cu_seqlens_q_d == nullptr) {\n        params.q_batch_stride = q.stride(0);\n        params.k_batch_stride = k.stride(0);\n        params.v_batch_stride = v.stride(0);\n        params.o_batch_stride = out.stride(0);\n        if (seqlenq_ngroups_swapped) {\n             params.q_batch_stride *= seqlen_q;\n             params.o_batch_stride *= seqlen_q;\n        }\n    }\n\n    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);\n    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);\n    params.seqused_k = static_cast<int *>(seqused_k);\n\n    // P = softmax(QK^T)\n    params.p_ptr = p_d;\n\n    // Softmax sum\n    params.softmax_lse_ptr = softmax_lse_d;\n\n    // Set the dimensions.\n    params.b = b;\n    params.h = h;\n    params.h_k = h_k;\n    params.h_h_k_ratio = h / h_k;\n    params.seqlen_q = seqlen_q;\n    params.seqlen_k = seqlen_k;\n    params.seqlen_q_rounded = seqlen_q_rounded;\n    params.seqlen_k_rounded = seqlen_k_rounded;\n    params.d = d;\n    
params.d_rounded = d_rounded;\n\n    // Set the different scale values.\n    #ifdef FLASHATTENTION_DISABLE_SOFTCAP\n        TORCH_CHECK(softcap <= 0.0, \"This flash attention build does not support softcap.\");\n    #endif\n    if (softcap > 0.0) {\n        params.softcap = softmax_scale / softcap;\n        params.scale_softmax = softcap;\n        params.scale_softmax_log2 = softcap * M_LOG2E;\n    } else{\n        // Remove potential NaN\n        params.softcap = 0.0;\n        params.scale_softmax = softmax_scale;\n        params.scale_softmax_log2 = softmax_scale * M_LOG2E;\n    }\n\n    // Set this to probability of keeping an element to simplify things.\n    params.p_dropout = 1.f - p_dropout;\n    // Convert p from float to int so we don't have to convert the random uint to float to compare.\n    // [Minor] We want to round down since when we do the comparison we use <= instead of <\n    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));\n    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));\n    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));\n    params.rp_dropout = 1.f / params.p_dropout;\n    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;\n    TORCH_CHECK(p_dropout < 1.f);\n    #ifdef FLASHATTENTION_DISABLE_DROPOUT\n        TORCH_CHECK(p_dropout == 0.0f, \"This flash attention build does not support dropout.\");\n    #endif\n\n    // Causal is the special case where window_size_right == 0 and window_size_left < 0.\n    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.\n    params.is_causal = window_size_left < 0 && window_size_right == 0;\n\n    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }\n    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }\n    params.window_size_left = window_size_left;\n    
params.window_size_right = window_size_right;\n\n    #ifdef FLASHATTENTION_DISABLE_LOCAL\n        TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),\n            \"This flash attention build does not support local attention.\");\n    #endif\n\n    params.is_seqlens_k_cumulative = true;\n\n    #ifdef FLASHATTENTION_DISABLE_UNEVEN_K\n        TORCH_CHECK(d == d_rounded, \"This flash attention build does not support headdim not being a multiple of 32.\");\n    #endif\n\n    params.unpadded_lse = unpadded_lse;\n    params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;\n}\n\nvoid set_params_dgrad(Flash_bwd_params &params,\n                      // sizes\n                      const size_t b,\n                      const size_t seqlen_q,\n                      const size_t seqlen_k,\n                      const size_t seqlen_q_rounded,\n                      const size_t seqlen_k_rounded,\n                      const size_t h,\n                      const size_t h_k,\n                      const size_t d,\n                      const size_t d_rounded,\n                      // device pointers\n                      const at::Tensor q,\n                      const at::Tensor k,\n                      const at::Tensor v,\n                      const at::Tensor out,\n                      const at::Tensor dout,\n                      at::Tensor dq,\n                      at::Tensor dk,\n                      at::Tensor dv,\n                      void *cu_seqlens_q_d,\n                      void *cu_seqlens_k_d,\n                      void *dq_accum_d,\n                      void *dk_accum_d,\n                      void *dv_accum_d,\n                      void *softmax_lse_d,\n                      void *dsoftmax_sum_d,\n                      float p_dropout,\n                      float softmax_scale,\n                      int window_size_left,\n                      int window_size_right,\n                      const float softcap,\n       
               bool deterministic,\n                      const bool unpadded_lse) {\n\n    set_params_fprop(params,\n                     b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,\n                     q, k, v, out,\n                     cu_seqlens_q_d,\n                     cu_seqlens_k_d,\n                     nullptr,\n                     nullptr,\n                     softmax_lse_d,\n                     p_dropout,\n                     softmax_scale,\n                     window_size_left,\n                     window_size_right,\n                     softcap,\n                     false, // seqlenq_ngroups_swapped\n                     unpadded_lse);\n\n    // Set the pointers and strides.\n    params.do_ptr = dout.data_ptr();\n    params.do_row_stride = dout.stride(-3);\n    params.do_head_stride = dout.stride(-2);\n    params.dq_ptr = dq.data_ptr();\n    params.dk_ptr = dk.data_ptr();\n    params.dv_ptr = dv.data_ptr();\n    params.dq_row_stride = dq.stride(-3);\n    params.dk_row_stride = dk.stride(-3);\n    params.dv_row_stride = dv.stride(-3);\n    params.dq_head_stride = dq.stride(-2);\n    params.dk_head_stride = dk.stride(-2);\n    params.dv_head_stride = dv.stride(-2);\n\n    if (cu_seqlens_q_d == nullptr) {\n        params.do_batch_stride = dout.stride(0);\n        params.dq_batch_stride = dq.stride(0);\n        params.dk_batch_stride = dk.stride(0);\n        params.dv_batch_stride = dv.stride(0);\n    }\n\n    params.dq_accum_ptr = dq_accum_d;\n    params.dk_accum_ptr = dk_accum_d;\n    params.dv_accum_ptr = dv_accum_d;\n\n    // Softmax sum\n    params.dsoftmax_sum = dsoftmax_sum_d;\n\n    params.deterministic = deterministic;\n}\n\nvoid run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {\n    FP16_SWITCH(!params.is_bf16, [&] {\n        HEADDIM_SWITCH(params.d, [&] {\n            BOOL_SWITCH(params.is_causal, Is_causal, [&] {\n                if (params.num_splits <= 
1 && !force_split_kernel) {  // num_splits is 0 when it has not been set\n                    run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);\n                } else {\n                    run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);\n                }\n            });\n        });\n    });\n}\n\n// Find the number of splits that maximizes the occupancy. For example, if we have\n// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is\n// better than having 3 splits (efficiency = 0.67). However, we also don't want too many\n// splits as that would incur more HBM reads/writes.\n// So we find the best efficiency, then find the smallest number of splits that gets 85%\n// of the best efficiency.\ninline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {\n    // If we have enough to almost fill the SMs, then just use 1 split\n    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }\n    max_splits = std::min({max_splits, num_SMs, num_n_blocks});\n    float max_efficiency = 0.f;\n    std::vector<float> efficiency;\n    efficiency.reserve(max_splits);\n    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };\n    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,\n    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks\n    // (i.e. 
it's 11 splits anyway).\n    // So we check if the number of blocks per split is the same as the previous num_splits.\n    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {\n        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);\n    };\n    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {\n        if (!is_split_eligible(num_splits)) {\n            efficiency.push_back(0.f);\n        } else {\n            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;\n            float eff = n_waves / ceil(n_waves);\n            // printf(\"num_splits = %d, eff = %f\\n\", num_splits, eff);\n            if (eff > max_efficiency) { max_efficiency = eff; }\n            efficiency.push_back(eff);\n        }\n    }\n    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {\n        if (!is_split_eligible(num_splits)) { continue; }\n        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {\n            // printf(\"num_splits chosen = %d\\n\", num_splits);\n            return num_splits;\n        }\n    }\n    return 1;\n}\n\nstd::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,\n    const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,\n    const int head_size_rounded, const float p_dropout,\n    const int num_splits, const int num_sm, struct c10::TensorOptions opts) {\n\n    // This needs to match with run_mha_fwd_splitkv_dispatch\n    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 
128 : 64);\n    const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;\n    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.\n    // In any case we don't expect seqlen_q to be larger than 64 for inference.\n    const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;\n    params.num_splits = num_splits;\n    at::Tensor softmax_lse_accum;\n    at::Tensor out_accum;\n\n    if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout\n        if (num_splits < 1) {\n            // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.\n            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);\n        }\n        if (params.num_splits > 1) {\n            softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));\n            out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));\n            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();\n            params.oaccum_ptr = out_accum.data_ptr();\n        }\n        TORCH_CHECK(params.num_splits <= 128, \"num_splits > 128 not supported\");\n    }\n\n    return std::make_tuple(softmax_lse_accum, out_accum);\n}\n\nvoid set_params_alibi(Flash_fwd_params &params, std::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){\n#ifdef FLASHATTENTION_DISABLE_ALIBI\n    TORCH_CHECK(!alibi_slopes_.has_value(), \"This flash attention build does not support alibi.\");\n    params.alibi_slopes_ptr = nullptr;\n#else\n    if (alibi_slopes_.has_value()) {\n        auto alibi_slopes = alibi_slopes_.value();\n        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, \"ALiBi slopes must have dtype fp32\");\n        CHECK_DEVICE(alibi_slopes);\n        TORCH_CHECK(alibi_slopes.stride(-1) == 1, \"ALiBi slopes tensor must have contiguous last 
dimension\");\n        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));\n        params.alibi_slopes_ptr = alibi_slopes.data_ptr();\n        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;\n    } else {\n        params.alibi_slopes_ptr = nullptr;\n    }\n#endif\n}\n\nstd::vector<at::Tensor>\nmha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)\n        const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)\n        const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)\n        std::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)\n        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads\n        const float p_dropout,\n        const float softmax_scale,\n        bool is_causal,\n        int window_size_left,\n        int window_size_right,\n        const float softcap,\n        const bool return_softmax,\n        std::optional<at::Generator> gen_) {\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    at::cuda::CUDAGuard device_guard{q.device()};\n\n    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());\n    bool is_sm8x_min = cc_major >= 8;\n    TORCH_CHECK(is_sm8x_min, \"FlashAttention only supports Ampere GPUs or newer.\");\n\n    auto q_dtype = q.dtype();\n    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,\n                \"FlashAttention only support fp16 and bf16 data type\");\n    TORCH_CHECK(k.dtype() == q_dtype, \"query and key must have the same dtype\");\n    TORCH_CHECK(v.dtype() == q_dtype, \"query and value must have the same dtype\");\n\n    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);\n\n    TORCH_CHECK(q.stride(-1) == 1, 
\"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(k.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(v.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n\n    const auto sizes = q.sizes();\n\n    const int batch_size = sizes[0];\n    int seqlen_q = sizes[1];\n    int num_heads = sizes[2];\n    const int head_size = sizes[3];\n    const int seqlen_k = k.size(1);\n    const int num_heads_k = k.size(2);\n    TORCH_CHECK(batch_size > 0, \"batch size must be positive\");\n    TORCH_CHECK(head_size <= 256, \"FlashAttention forward only supports head dimension at most 256\");\n    TORCH_CHECK(head_size % 8 == 0, \"query, key, value, and out_ must have a head_size that is a multiple of 8\");\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n\n    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, \"Softcapping does not support dropout for now\"); }\n\n    if (window_size_left >= seqlen_k) { window_size_left = -1; }\n    if (window_size_right >= seqlen_k) { window_size_right = -1; }\n\n    // causal=true is the same as causal=false in this case\n    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }\n    if (is_causal) { window_size_right = 0; }\n\n    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case\n    // H/t Daniel Haziza\n    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();\n    const int ngroups = num_heads / num_heads_k;\n    if (seqlenq_ngroups_swapped) {\n        q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);\n        seqlen_q = ngroups;\n        num_heads = num_heads_k;\n    }\n\n    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);\n    
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);\n    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);\n\n    at::Tensor out;\n    if (out_.has_value()) {\n        out = out_.value();\n        TORCH_CHECK(out.dtype() == q_dtype, \"Output must have the same dtype as inputs\");\n        CHECK_DEVICE(out);\n        TORCH_CHECK(out.stride(-1) == 1, \"Output tensor must have contiguous last dimension\");\n        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size);\n        if (seqlenq_ngroups_swapped) {\n            out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);\n        }\n    } else {\n        out = torch::empty_like(q);\n    }\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);\n    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);\n    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);\n\n    auto opts = q.options();\n\n    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));\n    at::Tensor p;\n    // Only return softmax if there's dropout to reduce compilation time\n    if (return_softmax) {\n        TORCH_CHECK(p_dropout > 0.0f, \"return_softmax is only supported when p_dropout > 0.0\");\n        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);\n    }\n    else {\n        p = torch::empty({ 0 }, opts);\n    }\n\n    Flash_fwd_params params;\n    set_params_fprop(params,\n                     batch_size,\n                     seqlen_q, seqlen_k,\n                     seqlen_q_rounded, seqlen_k_rounded,\n                     num_heads, num_heads_k,\n                     head_size, head_size_rounded,\n                     q, k, v, out,\n                     /*cu_seqlens_q_d=*/nullptr,\n                     /*cu_seqlens_k_d=*/nullptr,\n                     /*seqused_k=*/nullptr,\n          
           return_softmax ? p.data_ptr() : nullptr,\n                     softmax_lse.data_ptr(),\n                     p_dropout,\n                     softmax_scale,\n                     window_size_left,\n                     window_size_right,\n                     softcap\n                     );\n\n    // Keep references to these tensors to extend their lifetime\n    at::Tensor softmax_lse_accum, out_accum;\n    std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(\n        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,\n        head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);\n\n    // number of times random will be generated per thread, to offset philox counter in thc random\n    // state\n    // We use a custom RNG that increases the offset by batch_size * nheads * 32.\n    int64_t counter_offset = params.b * params.h * 32;\n    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);\n    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));\n    // Forward kernel will populate memory with the seed and offset.\n    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());\n\n    if (p_dropout > 0.0)  {\n        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(\n            gen_, at::cuda::detail::getDefaultCUDAGenerator());\n        // See Note [Acquire lock when using random generators]\n        std::lock_guard<std::mutex> lock(gen->mutex_);\n        params.philox_args = gen->philox_cuda_state(counter_offset);\n    }\n\n    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);\n\n    if (seqlen_k > 0) {\n        auto stream = at::cuda::getCurrentCUDAStream().stream();\n        run_mha_fwd(params, stream);\n    } else {\n        // If seqlen_k == 0, then we have an empty tensor. 
We need to set the output to 0.\n        out.zero_();\n        softmax_lse.fill_(std::numeric_limits<float>::infinity());\n    }\n\n    if (seqlenq_ngroups_swapped) {\n        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});\n        q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});\n        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});\n    }\n    return {out, softmax_lse, p, rng_state};\n}\n\nstd::vector<at::Tensor>\nmha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \\sum_{i=0}^{b} s_i\n               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.\n               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.\n               std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \\sum_{i=0}^{b} s_i\n               const at::Tensor &cu_seqlens_q,  // b+1\n               const at::Tensor &cu_seqlens_k,  // b+1\n               std::optional<at::Tensor> &seqused_k, // b. 
If given, only this many elements of each batch element's keys are used.\n               std::optional<const at::Tensor> &leftpad_k_, // batch_size\n               std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq\n               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads\n               int max_seqlen_q,\n               const int max_seqlen_k,\n               const float p_dropout,\n               const float softmax_scale,\n               const bool zero_tensors,\n               bool is_causal,\n               int window_size_left,\n               int window_size_right,\n               const float softcap,\n               const bool return_softmax,\n               std::optional<at::Generator> gen_) {\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    at::cuda::CUDAGuard device_guard{q.device()};\n\n    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());\n    bool is_sm8x_min = cc_major >= 8;\n    TORCH_CHECK(is_sm8x_min, \"FlashAttention only supports Ampere GPUs or newer.\");\n\n    auto q_dtype = q.dtype();\n    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,\n                \"FlashAttention only support fp16 and bf16 data type\");\n    TORCH_CHECK(k.dtype() == q_dtype, \"query and key must have the same dtype\");\n    TORCH_CHECK(v.dtype() == q_dtype, \"query and value must have the same dtype\");\n    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, \"cu_seqlens_q must have dtype int32\");\n    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, \"cu_seqlens_k must have dtype int32\");\n\n    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);\n    CHECK_DEVICE(cu_seqlens_q);\n    CHECK_DEVICE(cu_seqlens_k);\n\n    at::Tensor block_table;\n    const bool paged_KV = block_table_.has_value();\n    if (paged_KV) {\n        block_table = block_table_.value();\n        CHECK_DEVICE(block_table);\n        TORCH_CHECK(block_table.dtype() == 
torch::kInt32, \"block_table must have dtype torch.int32\");\n        TORCH_CHECK(block_table.stride(-1) == 1, \"block_table must have contiguous last dimension\");\n    }\n\n    TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(k.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(v.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    CHECK_CONTIGUOUS(cu_seqlens_q);\n    CHECK_CONTIGUOUS(cu_seqlens_k);\n\n    const auto sizes = q.sizes();\n\n    const int batch_size = cu_seqlens_q.numel() - 1;\n    int num_heads = sizes[1];\n    const int head_size = sizes[2];\n    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);\n\n    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, \"Softcapping does not support dropout for now\"); }\n\n    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);\n    const int num_blocks = !paged_KV ? 0 : k.size(0);\n    const int page_block_size = !paged_KV ? 
1 : k.size(1);\n    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, \"Paged KV cache block size must be divisible by 256\");\n\n    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case\n    if (is_causal) { window_size_right = 0; }\n\n    void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();\n\n    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case\n    // H/t Daniel Haziza\n    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();\n    const int ngroups = num_heads / num_heads_k;\n    if (seqlenq_ngroups_swapped) {\n        q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});\n        max_seqlen_q = ngroups;\n        num_heads = num_heads_k;\n        cu_seqlens_q_d = nullptr;\n    }\n\n    const int total_q = q.sizes()[0];\n\n    TORCH_CHECK(batch_size > 0, \"batch size must be positive\");\n    TORCH_CHECK(head_size <= 256, \"FlashAttention forward only supports head dimension at most 256\");\n    TORCH_CHECK(head_size % 8 == 0, \"query, key, value, and out_ must have a head_size that is a multiple of 8\");\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n\n    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }\n    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }\n\n    CHECK_SHAPE(q, total_q, num_heads, head_size);\n    if (!paged_KV) {\n        const int total_k = k.size(0);\n        CHECK_SHAPE(k, total_k, num_heads_k, head_size);\n        CHECK_SHAPE(v, total_k, num_heads_k, head_size);\n    } else {\n        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);\n        CHECK_SHAPE(v, 
num_blocks, page_block_size, num_heads_k, head_size);\n        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);\n    }\n\n    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);\n    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);\n    if (seqused_k.has_value()){\n        auto seqused_k_ = seqused_k.value();\n        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, \"seqused_k must have dtype int32\");\n        TORCH_CHECK(seqused_k_.is_cuda(), \"seqused_k must be on CUDA device\");\n        TORCH_CHECK(seqused_k_.is_contiguous(), \"seqused_k must be contiguous\");\n        CHECK_SHAPE(seqused_k_, batch_size);\n    }\n\n    at::Tensor out;\n    if (out_.has_value()) {\n        out = out_.value();\n        TORCH_CHECK(out.dtype() == q_dtype, \"Output must have the same dtype as inputs\");\n        CHECK_DEVICE(out);\n        TORCH_CHECK(out.stride(-1) == 1, \"Output tensor must have contiguous last dimension\");\n        CHECK_SHAPE(out, sizes[0], sizes[1], head_size);\n        if (seqlenq_ngroups_swapped) {\n            out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});\n        }\n    } else {\n        out = torch::empty_like(q);\n    }\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64);\n    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);\n    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);\n\n    auto opts = q.options();\n    auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));\n    at::Tensor p;\n    // Only return softmax if there's dropout to reduce compilation time\n    if (return_softmax) {\n        TORCH_CHECK(p_dropout > 0.0f, \"return_softmax is only supported when p_dropout > 0.0\");\n        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);\n    }\n    else {\n        p = torch::empty({ 0 }, opts);\n    }\n\n    if (zero_tensors) {\n        out.zero_();\n        softmax_lse.fill_(-std::numeric_limits<float>::infinity());\n        if (return_softmax) {p.zero_();}\n    }\n\n    Flash_fwd_params params;\n    set_params_fprop(params,\n                     batch_size,\n                     max_seqlen_q, max_seqlen_k,\n                     seqlen_q_rounded, seqlen_k_rounded,\n                     num_heads, num_heads_k,\n                     head_size, head_size_rounded,\n                     q, k, v, out,\n                     cu_seqlens_q_d,\n                     cu_seqlens_k.data_ptr(),\n                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,\n                     return_softmax ? 
p.data_ptr() : nullptr,\n                     softmax_lse.data_ptr(),\n                     p_dropout,\n                     softmax_scale,\n                     window_size_left,\n                     window_size_right,\n                     softcap,\n                     seqlenq_ngroups_swapped,\n                     /*unpadded_lse*/true);\n    params.total_q = total_q;\n\n    if (paged_KV) {\n        params.block_table = block_table.data_ptr<int>();\n        params.block_table_batch_stride = block_table.stride(0);\n        params.k_batch_stride = k.stride(0);\n        params.v_batch_stride = v.stride(0);\n    }\n    params.page_block_size = page_block_size;\n    // Keep references to these tensors to extend their lifetime\n    at::Tensor softmax_lse_accum, out_accum;\n    if (seqlenq_ngroups_swapped) {\n        // Only apply split-k for decoding\n        std::tie(softmax_lse_accum, out_accum) =\n            set_params_splitkv(params, batch_size, num_heads, head_size,\n                               max_seqlen_k, max_seqlen_q, head_size_rounded,\n                               p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);\n    }\n\n    if (leftpad_k_.has_value()) {\n        auto leftpad_k = leftpad_k_.value();\n        TORCH_CHECK(!paged_KV, \"We don't support Paged KV and leftpad_k running at the same time yet\");\n        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, \"leftpad_k must have dtype int32\");\n        CHECK_DEVICE(leftpad_k);\n        CHECK_CONTIGUOUS(leftpad_k);\n        CHECK_SHAPE(leftpad_k, batch_size);\n        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());\n    }\n\n    // number of times random will be generated per thread, to offset philox counter in thc random\n    // state\n    // We use a custom RNG that increases the offset by batch_size * nheads * 32.\n    int64_t counter_offset = params.b * params.h * 32;\n    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);\n    
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));\n    // Forward kernel will populate memory with the seed and offset.\n    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());\n\n    if (p_dropout > 0.0)  {\n        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(\n            gen_, at::cuda::detail::getDefaultCUDAGenerator());\n        // See Note [Acquire lock when using random generators]\n        std::lock_guard<std::mutex> lock(gen->mutex_);\n        params.philox_args = gen->philox_cuda_state(counter_offset);\n    }\n\n    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);\n\n    if (max_seqlen_k > 0) {\n        auto stream = at::cuda::getCurrentCUDAStream().stream();\n        run_mha_fwd(params, stream, paged_KV);\n    } else {\n        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.\n        out.zero_();\n        softmax_lse.fill_(std::numeric_limits<float>::infinity());\n    }\n\n    if (seqlenq_ngroups_swapped) {\n        int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};\n        int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};\n        out = out.reshape(size_before).transpose(1, 2).reshape(size_after);\n        q = q.reshape(size_before).transpose(1, 2).reshape(size_after);\n        softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});\n    }\n\n    return {out, softmax_lse, p, rng_state};\n}\n\nvoid run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {\n    FP16_SWITCH(!params.is_bf16, [&] {\n        HEADDIM_SWITCH(params.d, [&] {\n            BOOL_SWITCH(params.is_causal, Is_causal, [&] {\n                run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);\n            });\n        });\n    });\n}\n\nstd::vector<at::Tensor>\nmha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads x multiple_of(head_size_og, 8)\n        const at::Tensor &q,   // 
batch_size x seqlen_q x num_heads x head_size\n        const at::Tensor &k,   // batch_size x seqlen_k x num_heads_k x head_size\n        const at::Tensor &v,   // batch_size x seqlen_k x num_heads_k x head_size\n        const at::Tensor &out,   // batch_size x seqlen_q x num_heads x head_size\n        const at::Tensor &softmax_lse,     // b x h x seqlen_q\n        std::optional<at::Tensor> &dq_,   // batch_size x seqlen_q x num_heads x head_size\n        std::optional<at::Tensor> &dk_,   // batch_size x seqlen_k x num_heads_k x head_size\n        std::optional<at::Tensor> &dv_,   // batch_size x seqlen_k x num_heads_k x head_size\n        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads\n        const float p_dropout,         // probability to drop\n        const float softmax_scale,\n        const bool is_causal,\n        int window_size_left,\n        int window_size_right,\n        const float softcap,\n        const bool deterministic,\n        std::optional<at::Generator> gen_,\n        std::optional<at::Tensor> &rng_state) {\n\n    #ifdef FLASHATTENTION_DISABLE_BACKWARD\n        TORCH_CHECK(false, \"This flash attention build does not support backward.\");\n    #endif\n    if (is_causal) { window_size_right = 0; }\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    at::cuda::CUDAGuard device_guard{q.device()};\n\n    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());\n    bool is_sm8x_min = cc_major >= 8;\n    TORCH_CHECK(is_sm8x_min, \"FlashAttention only supports Ampere GPUs or newer.\");\n\n    bool is_dropout = p_dropout > 0.0;\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    auto q_dtype = q.dtype();\n    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,\n                \"FlashAttention only support fp16 and bf16 data type\");\n    TORCH_CHECK(k.dtype() == q_dtype, \"query and key must have the same dtype\");\n    TORCH_CHECK(v.dtype() == 
q_dtype, \"query and value must have the same dtype\");\n    TORCH_CHECK(out.dtype() == q_dtype, \"query and out must have the same dtype\");\n    TORCH_CHECK(dout.dtype() == q_dtype, \"query and dout must have the same dtype\");\n\n    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);\n    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);\n\n    TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(k.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(v.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(out.stride(-1) == 1, \"out tensor must have contiguous last dimension\");\n    TORCH_CHECK(dout.stride(-1) == 1, \"dout tensor must have contiguous last dimension\");\n\n    const auto sizes = q.sizes();\n\n    const int batch_size = sizes[0];\n    const int seqlen_q = sizes[1];\n    const int num_heads = sizes[2];\n    const int head_size = sizes[3];\n    const int seqlen_k = k.size(1);\n    const int num_heads_k = k.size(2);\n    TORCH_CHECK(batch_size > 0, \"batch size must be positive\");\n    TORCH_CHECK(head_size % 8 == 0, \"head_size should be a multiple of 8\");\n    TORCH_CHECK(head_size <= 256, \"FlashAttention backward only supports head dimension at most 256\");\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64);\n    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);\n    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);\n\n    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, \"Softcapping does not support dropout for now\"); }\n\n    if (window_size_left >= seqlen_k) { window_size_left = -1; }\n    if (window_size_right >= seqlen_k) { window_size_right = -1; }\n\n    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);\n    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);\n    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);\n    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);\n    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);\n\n    at::Tensor dq, dk, dv;\n    if (dq_.has_value()) {\n        dq = dq_.value();\n        TORCH_CHECK(dq.dtype() == q_dtype, \"dq must have the same dtype as q\");\n        CHECK_DEVICE(dq);\n        TORCH_CHECK(dq.stride(-1) == 1, \"dq must have contiguous last dimension\");\n        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);\n    } else {\n        dq = torch::empty_like(q);\n    }\n    if (dk_.has_value()) {\n        dk = dk_.value();\n        TORCH_CHECK(dk.dtype() == q_dtype, \"dk must have the same dtype as q\");\n        CHECK_DEVICE(dk);\n        TORCH_CHECK(dk.stride(-1) == 1, \"dk must have contiguous last dimension\");\n        CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);\n    } else {\n        dk = torch::empty_like(k);\n    }\n    if (dv_.has_value()) {\n        dv = dv_.value();\n        TORCH_CHECK(dv.dtype() == q_dtype, \"dv must have the same dtype as q\");\n        CHECK_DEVICE(dv);\n        TORCH_CHECK(dv.stride(-1) == 1, \"dv must have contiguous last dimension\");\n        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);\n    } else {\n        dv = torch::empty_like(v);\n    }\n\n    // bool loop = seqlen_k > blocksize_c;\n    // TODO: change later, for now set to true for 
simplicity\n    bool loop = true;\n\n    auto opts = q.options();\n    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));\n    at::Tensor dq_accum;\n    at::Tensor dk_accum, dv_accum;\n    if (loop) {\n        if (!deterministic) {\n            dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));\n        } else {\n            const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);\n            dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));\n        }\n        // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));\n        // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));\n    }\n\n    at::Tensor dk_expanded, dv_expanded;\n    if (num_heads_k != num_heads) {  // MQA / GQA\n        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);\n        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);\n    } else {\n        dk_expanded = dk;\n        dv_expanded = dv;\n    }\n\n    Flash_bwd_params params;\n\n    set_params_dgrad(params,\n                     batch_size,\n                     seqlen_q, seqlen_k,\n                     seqlen_q_rounded, seqlen_k_rounded,\n                     num_heads, num_heads_k,\n                     head_size, head_size_rounded,\n                     q, k, v, out,\n                     dout, dq, dk_expanded, dv_expanded,\n                     nullptr,\n                     nullptr,\n                     loop ? dq_accum.data_ptr() : nullptr,\n                     // loop ? dk_accum.data_ptr() : nullptr,\n                     // loop ? 
dv_accum.data_ptr() : nullptr,\n                     nullptr,\n                     nullptr,\n                     softmax_lse.data_ptr(),\n                     softmax_d.data_ptr(),\n                     p_dropout,\n                     softmax_scale,\n                     window_size_left,\n                     window_size_right,\n                     softcap,\n                     deterministic,\n                     /*unpadded_lse*/false);\n    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);\n\n    auto launch = &run_mha_bwd;\n\n    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(\n        gen_, at::cuda::detail::getDefaultCUDAGenerator());\n\n    // We use a custom RNG that increases the offset by batch_size * nheads * 32.\n    int64_t counter_offset = params.b * params.h * 32;\n\n    if ( rng_state.has_value() ) {\n        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());\n    } else if( is_dropout ) {\n        // See Note [Acquire lock when using random generators]\n        std::lock_guard<std::mutex> lock(gen->mutex_);\n        params.philox_args = gen->philox_cuda_state(counter_offset);\n        auto seeds = at::cuda::philox::unpack(params.philox_args);\n        params.rng_state[0] = std::get<0>(seeds);\n        params.rng_state[1] = std::get<1>(seeds);\n    }\n\n    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);\n\n    if (seqlen_q > 0) {\n        launch(params, stream);\n    } else {\n        // If seqlen_q == 0, then we have an empty tensor. 
We need to set the output to 0.\n        dk_expanded.zero_();\n        dv_expanded.zero_();\n        softmax_d.zero_();\n    }\n\n    // For MQA/GQA we need to sum dK and dV across the groups\n    if (num_heads_k != num_heads) {\n        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});\n        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});\n    }\n\n    return { dq, dk, dv, softmax_d };\n}\n\nstd::vector<at::Tensor>\nmha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size\n               const at::Tensor &q,   // total_q x num_heads x head_size, total_q := \\sum_{i=0}^{b} s_i\n               const at::Tensor &k,   // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i\n               const at::Tensor &v,   // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i\n               const at::Tensor &out,   // total_q x num_heads x head_size\n               const at::Tensor &softmax_lse,    // h x total_q, softmax logsumexp\n               std::optional<at::Tensor> &dq_,   // total_q x num_heads x head_size, total_q := \\sum_{i=0}^{b} s_i\n               std::optional<at::Tensor> &dk_,   // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i\n               std::optional<at::Tensor> &dv_,   // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i\n               const at::Tensor &cu_seqlens_q,  // b+1\n               const at::Tensor &cu_seqlens_k,  // b+1\n               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads\n               const int max_seqlen_q,\n               const int max_seqlen_k,          // max sequence length to choose the kernel\n               const float p_dropout,         // probability to drop\n               const float softmax_scale,\n               const bool zero_tensors,\n               const bool 
is_causal,\n               int window_size_left,\n               int window_size_right,\n               const float softcap,\n               const bool deterministic,\n               std::optional<at::Generator> gen_,\n               std::optional<at::Tensor> &rng_state) {\n\n    #ifdef FLASHATTENTION_DISABLE_BACKWARD\n        TORCH_CHECK(false, \"This flash attention build does not support backward.\");\n    #endif\n    if (is_causal) { window_size_right = 0; }\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    at::cuda::CUDAGuard device_guard{q.device()};\n\n    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());\n    bool is_sm8x_min = cc_major >= 8;\n    TORCH_CHECK(is_sm8x_min, \"FlashAttention only supports Ampere GPUs or newer.\");\n\n    bool is_dropout = p_dropout > 0.0;\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    auto q_dtype = q.dtype();\n    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,\n                \"FlashAttention only support fp16 and bf16 data type\");\n    TORCH_CHECK(k.dtype() == q_dtype, \"query and key must have the same dtype\");\n    TORCH_CHECK(v.dtype() == q_dtype, \"query and value must have the same dtype\");\n    TORCH_CHECK(out.dtype() == q_dtype, \"query and out must have the same dtype\");\n    TORCH_CHECK(dout.dtype() == q_dtype, \"query and dout must have the same dtype\");\n    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, \"cu_seqlens_q must have dtype int32\");\n    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, \"cu_seqlens_k must have dtype int32\");\n\n    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);\n    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);\n    CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);\n\n    TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(k.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    
TORCH_CHECK(v.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(out.stride(-1) == 1, \"out tensor must have contiguous last dimension\");\n    TORCH_CHECK(dout.stride(-1) == 1, \"dout tensor must have contiguous last dimension\");\n    CHECK_CONTIGUOUS(cu_seqlens_q);\n    CHECK_CONTIGUOUS(cu_seqlens_k);\n\n    const auto sizes = q.sizes();\n\n    const int total_q = sizes[0];\n    const int batch_size = cu_seqlens_q.numel() - 1;\n    const int num_heads = sizes[1];\n    const int head_size = sizes[2];\n    const int total_k = k.size(0);\n    const int num_heads_k = k.size(1);\n    TORCH_CHECK(batch_size > 0, \"batch size must be positive\");\n    TORCH_CHECK(head_size % 8 == 0, \"head_size should be a multiple of 8\");\n    TORCH_CHECK(head_size <= 256, \"FlashAttention backward only supports head dimension at most 256\");\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, \"Softcapping does not support dropout for now\"); }\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64);\n    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);\n    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);\n\n    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }\n    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }\n\n    CHECK_SHAPE(q, total_q, num_heads, head_size);\n    CHECK_SHAPE(k, total_k, num_heads_k, head_size);\n    CHECK_SHAPE(v, total_k, num_heads_k, head_size);\n    CHECK_SHAPE(out, total_q, num_heads, head_size);\n    CHECK_SHAPE(dout, total_q, num_heads, head_size);\n    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);\n    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);\n\n    at::Tensor dq, dk, dv;\n    if (dq_.has_value()) {\n        dq = dq_.value();\n        TORCH_CHECK(dq.dtype() == q_dtype, \"dq must have the same dtype as q\");\n        CHECK_DEVICE(dq);\n        TORCH_CHECK(dq.stride(-1) == 1, \"dq must have contiguous last dimension\");\n        CHECK_SHAPE(dq, total_q, num_heads, head_size);\n    } else {\n        dq = torch::empty_like(q);\n    }\n    if (dk_.has_value()) {\n        dk = dk_.value();\n        TORCH_CHECK(dk.dtype() == q_dtype, \"dk must have the same dtype as q\");\n        CHECK_DEVICE(dk);\n        TORCH_CHECK(dk.stride(-1) == 1, \"dk must have contiguous last dimension\");\n        CHECK_SHAPE(dk, total_k, num_heads_k, head_size);\n    } else {\n        dk = torch::empty_like(k);\n    }\n    if (dv_.has_value()) {\n        dv = dv_.value();\n        TORCH_CHECK(dv.dtype() == q_dtype, \"dv must have the same dtype as q\");\n        CHECK_DEVICE(dv);\n        TORCH_CHECK(dv.stride(-1) == 1, \"dv must have contiguous last dimension\");\n        CHECK_SHAPE(dv, total_k, num_heads_k, head_size);\n    } else {\n        dv = torch::empty_like(v);\n    }\n\n    // bool loop = max_seqlen_k > blocksize_c;\n    // TODO: change later, for now set to true for simplicity\n    bool loop = true;\n\n    auto opts = q.options();\n    auto softmax_d = 
torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));\n    at::Tensor dq_accum;\n    if (loop) {\n        // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)\n        // because that would be too large if there is a very long sequence and the rest of the sequences are short.\n        // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).\n        // Note that 128 is the max block size on the seqlen_q dimension.\n        // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to\n        // cu_seqlens[i + 1] + 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will\n        // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally\n        // allowed to do. So we won't have to do any bound checking, and performance should stay the same.\n        // Same holds for softmax_d, since LSE is stored in unpadded format.\n        if (!deterministic) {\n            dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));\n        } else {\n            const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);\n            dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));\n        }\n    }\n\n    at::Tensor dk_expanded, dv_expanded;\n    if (num_heads_k != num_heads) {  // MQA / GQA\n        dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);\n        dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);\n    } else {\n        dk_expanded = dk;\n        dv_expanded = dv;\n    }\n\n    if( zero_tensors ) {\n        dq.zero_();\n        dk_expanded.zero_();\n        dv_expanded.zero_();\n        softmax_d.zero_();\n    }\n\n    Flash_bwd_params params;\n\n    
set_params_dgrad(params,\n                     batch_size,\n                     max_seqlen_q, max_seqlen_k,\n                     seqlen_q_rounded, seqlen_k_rounded,\n                     num_heads, num_heads_k,\n                     head_size, head_size_rounded,\n                     q, k, v, out,\n                     dout, dq, dk_expanded, dv_expanded,\n                     cu_seqlens_q.data_ptr(),\n                     cu_seqlens_k.data_ptr(),\n                     loop ? dq_accum.data_ptr() : nullptr,\n                     nullptr,\n                     nullptr,\n                     softmax_lse.data_ptr(),\n                     softmax_d.data_ptr(),\n                     p_dropout,\n                     softmax_scale,\n                     window_size_left,\n                     window_size_right,\n                     softcap,\n                     deterministic,\n                     /*unpadded_lse*/true);\n    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);\n    params.total_q = total_q;\n\n    auto launch = &run_mha_bwd;\n\n    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(\n        gen_, at::cuda::detail::getDefaultCUDAGenerator());\n\n    // We use a custom RNG that increases the offset by batch_size * nheads * 32.\n    int64_t counter_offset = params.b * params.h * 32;\n\n    if ( rng_state.has_value() ) {\n        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());\n    } else if( is_dropout ) {\n        // See Note [Acquire lock when using random generators]\n        std::lock_guard<std::mutex> lock(gen->mutex_);\n        params.philox_args = gen->philox_cuda_state(counter_offset);\n        auto seeds = at::cuda::philox::unpack(params.philox_args);\n        params.rng_state[0] = std::get<0>(seeds);\n        params.rng_state[1] = std::get<1>(seeds);\n    }\n\n    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);\n\n    if (max_seqlen_q > 0) {\n        launch(params, 
stream);\n    } else {\n        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.\n        dk_expanded.zero_();\n        dv_expanded.zero_();\n        softmax_d.zero_();\n    }\n\n    // For MQA/GQA we need to sum dK and dV across the groups\n    if (num_heads_k != num_heads) {\n        at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});\n        at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});\n    }\n\n    return { dq, dk, dv, softmax_d };\n}\n\nstd::vector<at::Tensor>\nmha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size\n                const at::Tensor &kcache,            // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.\n                const at::Tensor &vcache,            // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.\n                std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size\n                std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size\n                std::optional<const at::Tensor> &seqlens_k_, // batch_size\n                std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)\n                std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)\n                std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache\n                std::optional<const at::Tensor> &leftpad_k_, // batch_size\n                std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq\n                std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads\n                
std::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size\n                const float softmax_scale,\n                bool is_causal,\n                int window_size_left,\n                int window_size_right,\n                const float softcap,\n                bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2\n                int num_splits\n                ) {\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    at::cuda::CUDAGuard device_guard{q.device()};\n\n    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());\n    bool is_sm8x_min = cc_major >= 8;\n    TORCH_CHECK(is_sm8x_min, \"FlashAttention only supports Ampere GPUs or newer.\");\n\n    auto q_dtype = q.dtype();\n    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,\n                \"FlashAttention only support fp16 and bf16 data type\");\n    TORCH_CHECK(kcache.dtype() == q_dtype, \"query and key must have the same dtype\");\n    TORCH_CHECK(vcache.dtype() == q_dtype, \"query and value must have the same dtype\");\n\n    CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);\n\n    TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(kcache.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(vcache.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n\n    at::Tensor block_table;\n    const bool paged_KV = block_table_.has_value();\n    if (paged_KV) {\n        TORCH_CHECK(!cache_batch_idx_.has_value(), \"Paged KVcache does not support cache_batch_idx\");\n        block_table = block_table_.value();\n        CHECK_DEVICE(block_table);\n        TORCH_CHECK(block_table.dtype() == torch::kInt32, \"block_table must have dtype torch.int32\");\n        TORCH_CHECK(block_table.stride(-1) == 1, \"block_table must have contiguous last 
dimension\");\n    }\n\n    const auto sizes = q.sizes();\n\n    const int batch_size = sizes[0];\n    int seqlen_q = sizes[1];\n    int num_heads = sizes[2];\n    const int head_size_og = sizes[3];\n\n    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);\n    const int num_blocks = !paged_KV ? 0 : kcache.size(0);\n    const int page_block_size = !paged_KV ? 1 : kcache.size(1);\n    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, \"Paged KV cache block size must be divisible by 256\");\n    const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;\n    const int num_heads_k = kcache.size(2);\n    const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;\n    TORCH_CHECK(batch_size > 0, \"batch size must be positive\");\n    TORCH_CHECK(head_size_og <= 256, \"FlashAttention forward only supports head dimension at most 256\");\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n\n    // causal=true is the same as causal=false in this case\n    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }\n    if (is_causal) { window_size_right = 0; }\n\n    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case\n    // H/t Daniel Haziza\n    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();\n    if (seqlenq_ngroups_swapped) {\n        const int ngroups = num_heads / num_heads_k;\n        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);\n        seqlen_q = ngroups;\n        num_heads = num_heads_k;\n    }\n\n    if (window_size_left >= seqlen_k) { window_size_left = -1; }\n    if (window_size_right >= seqlen_k) { window_size_right = -1; }\n\n    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);\n    if 
(!paged_KV) {\n        CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);\n        CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);\n    } else {\n        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);\n        CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);\n        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);\n    }\n\n    at::Tensor q_padded, kcache_padded, vcache_padded;\n    if (head_size_og % 8 != 0) {\n        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));\n        kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));\n        vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));\n    } else {\n        q_padded = q;\n        kcache_padded = kcache;\n        vcache_padded = vcache;\n    }\n\n    at::Tensor out;\n    if (out_.has_value()) {\n        out = out_.value();\n        TORCH_CHECK(out.dtype() == q_dtype, \"Output must have the same dtype as inputs\");\n        CHECK_DEVICE(out);\n        TORCH_CHECK(out.stride(-1) == 1, \"Output tensor must have contiguous last dimension\");\n        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);\n        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }\n    } else {\n        out = torch::empty_like(q_padded);\n    }\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int head_size = round_multiple(head_size_og, 8);\n    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64);\n    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);\n    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);\n\n    auto opts = q.options();\n\n    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));\n\n    Flash_fwd_params params;\n    set_params_fprop(params,\n                     batch_size,\n                     seqlen_q, seqlen_k,\n                     seqlen_q_rounded, seqlen_k_rounded,\n                     num_heads, num_heads_k,\n                     head_size, head_size_rounded,\n                     q_padded, kcache_padded, vcache_padded, out,\n                     /*cu_seqlens_q_d=*/nullptr,\n                     /*cu_seqlens_k_d=*/nullptr,\n                     /*seqused_k=*/nullptr,\n                     /*p_d=*/nullptr,\n                     softmax_lse.data_ptr(),\n                     /*p_dropout=*/0.f,\n                     softmax_scale,\n                     window_size_left,\n                     window_size_right,\n                     softcap\n                     );\n\n    at::Tensor k, v, k_padded, v_padded;\n    if (k_.has_value()) {\n        TORCH_CHECK(v_.has_value(), \"If key is supplied, value must also be passed in\");\n        TORCH_CHECK(seqlens_k_.has_value(), \"If key is supplied, seqlens_k must also be passed in\");\n        TORCH_CHECK(seqlen_q <= seqlen_k, \"If key is supplied, it must have seqlen <= the seqlen of the KV cache\");\n        k = k_.value();\n        v = v_.value();\n        TORCH_CHECK(k.dtype() == q_dtype, \"Key must have the same dtype as query\");\n        TORCH_CHECK(v.dtype() == q_dtype, \"Value must have the same dtype as query\");\n        CHECK_DEVICE(k); CHECK_DEVICE(v);\n        TORCH_CHECK(k.stride(-1) == 1, \"Key tensor must have contiguous last dimension\");\n        TORCH_CHECK(v.stride(-1) == 1, \"Value tensor must have contiguous last dimension\");\n        int seqlen_knew = k.size(1);\n        CHECK_SHAPE(k, batch_size, 
seqlen_knew, num_heads_k, head_size_og);\n        CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);\n        if (head_size_og % 8 != 0) {\n            k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));\n            v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));\n        } else {\n            k_padded = k;\n            v_padded = v;\n        }\n        params.seqlen_knew = seqlen_knew;\n        params.knew_ptr = k_padded.data_ptr();\n        params.vnew_ptr = v_padded.data_ptr();\n        // All strides are in elements, not bytes.\n        params.knew_batch_stride = k_padded.stride(0);\n        params.vnew_batch_stride = v_padded.stride(0);\n        params.knew_row_stride = k_padded.stride(-3);\n        params.vnew_row_stride = v_padded.stride(-3);\n        params.knew_head_stride = k_padded.stride(-2);\n        params.vnew_head_stride = v_padded.stride(-2);\n    }\n\n    if (seqlens_k_.has_value()) {\n        auto seqlens_k = seqlens_k_.value();\n        TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, \"seqlens_k must have dtype int32\");\n        CHECK_DEVICE(seqlens_k);\n        CHECK_CONTIGUOUS(seqlens_k);\n        CHECK_SHAPE(seqlens_k, batch_size);\n        params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());\n    }\n    params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());\n    if (leftpad_k_.has_value()) {\n        TORCH_CHECK(!paged_KV, \"We don't support Paged KV and leftpad_k running at the same time yet\");\n        auto leftpad_k = leftpad_k_.value();\n        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, \"leftpad_k must have dtype int32\");\n        CHECK_DEVICE(leftpad_k);\n        CHECK_CONTIGUOUS(leftpad_k);\n        CHECK_SHAPE(leftpad_k, batch_size);\n        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());\n    }\n\n    if (rotary_cos_.has_value()) {\n        
TORCH_CHECK(k_.has_value(), \"If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided\");\n        auto rotary_cos = rotary_cos_.value();\n        CHECK_DEVICE(rotary_cos);\n        params.rotary_dim = rotary_cos.size(1) * 2;\n        TORCH_CHECK(params.rotary_dim <= head_size, \"rotary_dim must be <= headdim\");\n        TORCH_CHECK(params.rotary_dim % 16 == 0, \"Only rotary dimensions divisible by 16 are currently supported\");\n        const int seqlen_ro = rotary_cos.size(0);\n        TORCH_CHECK(seqlen_ro >= seqlen_k, \"cos/sin seqlen must be at least the seqlen of KV cache\");\n        CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);\n        CHECK_CONTIGUOUS(rotary_cos);\n        TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, \"rotary_cos must have the same dtype as query\");\n\n        TORCH_CHECK(rotary_sin_.has_value(), \"If rotary cos is provided, rotary sin must also be provided\");\n        auto rotary_sin = rotary_sin_.value();\n        CHECK_DEVICE(rotary_sin);\n        CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);\n        CHECK_CONTIGUOUS(rotary_sin);\n        TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, \"rotary_sin must have the same dtype as query\");\n        params.rotary_cos_ptr = rotary_cos.data_ptr();\n        params.rotary_sin_ptr = rotary_sin.data_ptr();\n        params.is_rotary_interleaved = is_rotary_interleaved;\n    } else {\n        params.rotary_dim = 0;\n    }\n\n    if (cache_batch_idx_.has_value()) {\n        auto cache_batch_idx = cache_batch_idx_.value();\n        CHECK_DEVICE(cache_batch_idx);\n        CHECK_CONTIGUOUS(cache_batch_idx);\n        TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, \"cache_batch_idx must have dtype int32\");\n        params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());\n    }\n\n    // Keep references to these tensors to extend their lifetime\n    at::Tensor softmax_lse_accum, out_accum;\n    
std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(\n        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,\n        head_size_rounded, /*dropout*/ 0.f, num_splits, get_num_sm(get_current_device()), opts);\n\n    if (paged_KV) {\n        params.block_table = block_table.data_ptr<int>();\n        params.block_table_batch_stride = block_table.stride(0);\n    }\n    params.page_block_size = page_block_size;\n\n\n    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);\n\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n    // Only the split kernel supports appending to the KV cache, indexing into the cache with cache_batch_idx,\n    // or a paged KV cache\n    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);\n\n    if (head_size_og % 8 != 0) {\n        out = out.index({\"...\", torch::indexing::Slice(torch::indexing::None, head_size_og)});\n        if (out_.has_value()) { out_.value().copy_(out); }\n        if (k_.has_value()) {\n            // It's expensive to copy the KV cache here for the case where head size is not divisible by 8,\n            // but we don't expect to get this case in practice. 
This is just so that the code works for that case.\n            kcache.copy_(kcache_padded.index({\"...\", torch::indexing::Slice(torch::indexing::None, head_size_og)}));\n            vcache.copy_(vcache_padded.index({\"...\", torch::indexing::Slice(torch::indexing::None, head_size_og)}));\n        }\n    }\n\n    if (seqlenq_ngroups_swapped) {\n        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});\n        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});\n    }\n    return {out, softmax_lse};\n}\n} // namespace FLASH_NAMESPACE\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.doc() = \"FlashAttention\";\n    m.def(\"fwd\", &FLASH_NAMESPACE::mha_fwd, \"Forward pass\");\n    m.def(\"varlen_fwd\", &FLASH_NAMESPACE::mha_varlen_fwd, \"Forward pass (variable length)\");\n    m.def(\"bwd\", &FLASH_NAMESPACE::mha_bwd, \"Backward pass\");\n    m.def(\"varlen_bwd\", &FLASH_NAMESPACE::mha_varlen_bwd, \"Backward pass (variable length)\");\n    m.def(\"fwd_kvcache\", &FLASH_NAMESPACE::mha_fwd_kvcache, \"Forward pass, with KV-cache\");\n}\n"
  },
  {
    "path": "csrc/flash_attn/src/alibi.h",
    "content": "#include <cmath>\n\n#include \"namespace_config.h\"\n#include <cute/tensor.hpp>\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n\n#include \"utils.h\"\n\nnamespace FLASH_NAMESPACE {\n\nusing namespace cute;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool Is_causal>\nstruct Alibi {\n\n    const float alibi_slope;\n    const int max_seqlen_k, max_seqlen_q;\n\n    __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)\n        : alibi_slope(alibi_slope)\n        , max_seqlen_k(max_seqlen_k)\n        , max_seqlen_q(max_seqlen_q) {\n    };\n\n\n    template <typename Engine, typename Layout>\n    __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,\n                                      const int col_idx_offset_,\n                                      const int row_idx_offset,\n                                      const int warp_row_stride) {\n        // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))\n        static_assert(Layout::rank == 2, \"Only support 2D Tensor\");\n        const int lane_id = threadIdx.x % 32;\n        const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;\n        if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows\n            #pragma unroll\n            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {\n                const int col_idx_base = col_idx_offset + nj * 8;\n                #pragma unroll\n                for (int j = 0; j < size<1, 0>(tensor); ++j) {\n                    const int col_idx = col_idx_base + j;\n                    #pragma unroll\n                    for (int mi = 0; mi < size<0>(tensor); ++mi) {\n                        tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;\n                    }\n                }\n            }\n        } else {  // Bias depends on both row_idx and col_idx\n         
   #pragma unroll\n            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {\n                const int row_idx_base = row_idx_offset + mi * warp_row_stride;\n                #pragma unroll\n                for (int i = 0; i < size<0, 0>(tensor); ++i) {\n                    const int row_idx = row_idx_base + i * 8;\n                    #pragma unroll\n                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {\n                        const int col_idx_base = col_idx_offset + nj * 8;\n                        #pragma unroll\n                        for (int j = 0; j < size<1, 0>(tensor); ++j) {\n                            const int col_idx = col_idx_base + j;\n                            tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n};\n\n}  // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn/src/block_info.h",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"namespace_config.h\"\nnamespace FLASH_NAMESPACE {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<bool Varlen=true>\nstruct BlockInfo {\n\n    template<typename Params>\n    __device__ BlockInfo(const Params &params, const int bidb)\n        : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])\n        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])\n        , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)\n        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].\n        // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.\n        , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])\n        , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)\n        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))\n        {\n        }\n\n    template <typename index_t>\n    __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {\n        return sum_s_q == -1 ? 
bidb * batch_stride : uint32_t(sum_s_q) * row_stride;\n    }\n\n    template <typename index_t>\n    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {\n        return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;\n    }\n\n    const int sum_s_q;\n    const int sum_s_k;\n    const int actual_seqlen_q;\n    // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.\n    const int leftpad_k;\n    const int seqlen_k_cache;\n    const int actual_seqlen_k;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn/src/dropout.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"namespace_config.h\"\n#include \"philox.cuh\"\n#include \"utils.h\"\n\nnamespace FLASH_NAMESPACE {\n\nstruct Dropout {\n\n    const unsigned long long seed, offset;\n    const uint8_t p_dropout_in_uint8_t;\n\n    __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,\n                              const uint8_t p_dropout_in_uint8_t,\n                              const int bid, const int hid, const int tid, const int nheads)\n            : seed(seed)\n            , offset(offset + (bid * nheads + hid) * 32 + tid % 32)\n            , p_dropout_in_uint8_t(p_dropout_in_uint8_t) {\n    }\n\n    template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>\n    __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,\n                                         int block_row_start, int block_col_start, int block_row_stride) {\n        // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)\n        Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout()));\n        using T = typename Engine::value_type;\n        auto encode_dropout = [](bool keep, T val) {\n            return keep ? val : (encode_dropout_in_sign_bit ? 
-val : T(0));\n        };\n        static_assert(decltype(size<2>(tensor))::value % 2 == 0);\n        const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);\n        const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);\n        // if (cute::thread0()) { printf(\"threshold2 = 0x%x\\n\", p_dropout_8bit_in_uint32_t); }\n        #pragma unroll\n        for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {\n            uint2 rowcol = make_uint2(block_row_start, block_col_start);\n            #pragma unroll\n            for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {\n                // if (cute::thread(32, 0)) { printf(\"m = %d, n = %d, row = %d, col = %d\\n\", m, n, int(rowcol.x), int(rowcol.y));}\n                uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);\n                // if (cute::thread0()) { printf(\"philox = %u, %d, %d, %d\\n\", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}\n                uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);\n                // Special implementation for 16-bit types: we duplicate the threshold to the\n                // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction\n                // to get a mask. 
The low 16 bits of the mask will be either 0xffff or 0x0000,\n                // and the high 16 bits will be either 0xffff or 0x0000, depending on whether\n                // the random value is less than the threshold.\n                // We then do a bit-wise AND between the mask and the original value (in 32-bit).\n                // We're exploiting the fact that floating point comparison is equivalent to integer\n                // comparison, since we're comparing unsigned integers whose top 8-bits are zero.\n                if (!encode_dropout_in_sign_bit\n                    && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {\n                    uint16_t rnd_16[16];\n                    #pragma unroll\n                    for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }\n                    uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);\n                    #pragma unroll\n                    for (int j = 0; j < 2; j++) {\n                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));\n                        // if (cute::thread0()) { printf(\"random = 0x%x, 0x%x, 0x%x, 0x%x\\n\", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }\n                        // if (cute::thread0()) { printf(\"tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\\n\", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }\n                        #pragma unroll\n                        for (int i = 0; i < 4; i++) {\n                            uint32_t mask;\n                            asm volatile(\"set.le.u32.f16x2 %0, %1, %2;\\n\" : \"=r\"(mask) : \"r\"(rnd_32[j * 4 + i]), \"r\"(p_dropout_8bit_in_uint32_t));\n                            tensor_uint32(i) &= mask;\n                        }\n                        // if (cute::thread0()) { printf(\"tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\\n\", tensor_uint32(0), tensor_uint32(1), 
tensor_uint32(2), tensor_uint32(3)); }\n                    }\n                } else {\n                    #pragma unroll\n                    for (int j = 0; j < 2; j++) {\n                        #pragma unroll\n                        for (int i = 0; i < 8; i++) {\n                            tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));\n                        }\n                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));\n                        // if (cute::thread0()) { printf(\"tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\\n\", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }\n                    }\n                }\n                // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {\n                // //     printf(\"n = %d, ph  Philox: %u, %u, %u, %u\\n\", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);\n                // // }\n            }\n        }\n    }\n\n};\n\n} // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn/src/flash.h",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"namespace_config.h\"\n\n#include <cuda.h>\n#include <vector>\n\n#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState\n\nnamespace FLASH_NAMESPACE {\nconstexpr int TOTAL_DIM = 0;\nconstexpr int H_DIM = 1;\nconstexpr int D_DIM = 2;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Qkv_params {\n    using index_t = int64_t;\n    // The QKV matrices.\n    void *__restrict__ q_ptr;\n    void *__restrict__ k_ptr;\n    void *__restrict__ v_ptr;\n\n    // The stride between rows of the Q, K and V matrices.\n    index_t q_batch_stride;\n    index_t k_batch_stride;\n    index_t v_batch_stride;\n    index_t q_row_stride;\n    index_t k_row_stride;\n    index_t v_row_stride;\n    index_t q_head_stride;\n    index_t k_head_stride;\n    index_t v_head_stride;\n\n    // The number of heads.\n    int h, h_k;\n    // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be\n    // different from nheads (query).\n    int h_h_k_ratio; // precompute h / h_k,\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Flash_fwd_params : public Qkv_params {\n\n    // The O matrix (output).\n    void * __restrict__ o_ptr;\n    void * __restrict__ oaccum_ptr;\n\n    // The stride between rows of O.\n    index_t o_batch_stride;\n    index_t o_row_stride;\n    index_t o_head_stride;\n\n    // The pointer to the P matrix.\n    void * __restrict__ p_ptr;\n\n    // The pointer to the softmax sum.\n    void * __restrict__ softmax_lse_ptr;\n    void * __restrict__ softmax_lseaccum_ptr;\n\n    // The dimensions.\n    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, 
seqlen_k_rounded, d_rounded, rotary_dim, total_q;\n\n    // The scaling factors for the kernel.\n    float scale_softmax;\n    float scale_softmax_log2;\n\n    // array of length b+1 holding starting offset of each sequence.\n    int * __restrict__ cu_seqlens_q;\n    int * __restrict__ cu_seqlens_k;\n    int * __restrict__ leftpad_k;\n\n    // If provided, the actual length of each k sequence.\n    int * __restrict__ seqused_k;\n\n    int *__restrict__ blockmask;\n\n    // The K_new and V_new matrices.\n    void * __restrict__ knew_ptr;\n    void * __restrict__ vnew_ptr;\n\n    // The batch, row, and head strides of the K_new and V_new matrices.\n    index_t knew_batch_stride;\n    index_t vnew_batch_stride;\n    index_t knew_row_stride;\n    index_t vnew_row_stride;\n    index_t knew_head_stride;\n    index_t vnew_head_stride;\n\n    // The cos and sin matrices for rotary embedding.\n    void * __restrict__ rotary_cos_ptr;\n    void * __restrict__ rotary_sin_ptr;\n\n    // The indices to index into the KV cache.\n    int * __restrict__ cache_batch_idx;\n\n    // Paged KV cache\n    int * __restrict__ block_table;\n    index_t block_table_batch_stride;\n    int page_block_size;\n\n    // The dropout probability (probability of keeping an activation).\n    float p_dropout;\n    // uint32_t p_dropout_in_uint;\n    // uint16_t p_dropout_in_uint16_t;\n    uint8_t p_dropout_in_uint8_t;\n\n    // Scale factor of 1 / (1 - p_dropout).\n    float rp_dropout;\n    float scale_softmax_rp_dropout;\n\n    // Local window size\n    int window_size_left, window_size_right;\n    float softcap;\n\n    // Random state.\n    at::PhiloxCudaState philox_args;\n\n    // Pointer to the RNG seed (idx 0) and offset (idx 1).\n    uint64_t * rng_state;\n\n    bool is_bf16;\n    bool is_causal;\n\n    // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].\n    // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.\n    bool 
is_seqlens_k_cumulative;\n\n    bool is_rotary_interleaved;\n\n    int num_splits;  // For split-KV version\n\n    void * __restrict__ alibi_slopes_ptr;\n    index_t alibi_slopes_batch_stride;\n\n    bool unpadded_lse;  // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].\n    bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Flash_bwd_params : public Flash_fwd_params {\n\n    // The dO and dQKV matrices.\n    void *__restrict__ do_ptr;\n    void *__restrict__ dq_ptr;\n    void *__restrict__ dk_ptr;\n    void *__restrict__ dv_ptr;\n\n    // To accumulate dQ, dK and dV\n    void *__restrict__ dq_accum_ptr;\n    void *__restrict__ dk_accum_ptr;\n    void *__restrict__ dv_accum_ptr;\n\n    // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q dimension\n    // void *__restrict__ dk_accum_ptr;\n    // void *__restrict__ dv_accum_ptr;\n\n    // The stride between rows of the dO, dQ, dK and dV matrices.\n    // TD [2022-04-16]: We're using 32-bit indexing to save registers.\n    // The code probably won't work for arrays larger than 2GB.\n    index_t do_batch_stride;\n    index_t do_row_stride;\n    index_t do_head_stride;\n    index_t dq_batch_stride;\n    index_t dk_batch_stride;\n    index_t dv_batch_stride;\n    index_t dq_row_stride;\n    index_t dk_row_stride;\n    index_t dv_row_stride;\n    index_t dq_head_stride;\n    index_t dk_head_stride;\n    index_t dv_head_stride;\n\n    // The pointer to the softmax d sum.\n    void *__restrict__ dsoftmax_sum;\n\n    bool deterministic;\n    index_t dq_accum_split_stride;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, 
cudaStream_t stream);\ntemplate<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);\n\ntemplate<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);\n\n}  // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<cutlass::bfloat16_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<cutlass::bfloat16_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<cutlass::half_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<cutlass::half_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<cutlass::bfloat16_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<cutlass::bfloat16_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<cutlass::half_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<cutlass::half_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<cutlass::bfloat16_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<cutlass::bfloat16_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::half_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<cutlass::half_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<cutlass::half_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::bfloat16_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim32<cutlass::bfloat16_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::bfloat16_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim32<cutlass::bfloat16_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::half_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim32<cutlass::half_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::half_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim32<cutlass::half_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<cutlass::bfloat16_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<cutlass::bfloat16_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<cutlass::half_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<cutlass::half_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<cutlass::bfloat16_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<cutlass::bfloat16_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<cutlass::half_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_bwd_<cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<cutlass::half_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_kernel.h",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"namespace_config.h\"\n#include <cute/tensor.hpp>\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n\n#include \"block_info.h\"\n#include \"kernel_traits.h\"\n#include \"utils.h\"\n#include \"softmax.h\"\n#include \"mask.h\"\n#include \"dropout.h\"\n\n#include \"alibi.h\"\n\nnamespace FLASH_NAMESPACE {\n\nusing namespace cute;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int MMA_N,\n          class... Args,\n          class TiledMMA>\nCUTE_HOST_DEVICE\nauto\nmake_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,\n                                  TiledMMA           const& tiled_mma) {\n    constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;\n    constexpr int TileShape_K = decltype(tiled_mma.template tile_size_mnk<2>())::value;\n    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;\n    constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;\n    // Divide by 2 because right now we always use 2 for the ValLayout\n    constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;\n    constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;\n    // This gives the correct layout, idk why.\n    // auto t = make_tile(Layout<Shape<Shape<_8, _2>, _2>,\n    //                           Stride<Stride<_1, _64>, _8> >{},\n    // auto t = make_tile(Layout<Shape<_8, _2, _2>,\n    //                           Stride<_1, _64, _8> >{},\n    auto t = make_tile(Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>,   // (8, 2, 2) or (8, 4, 2)\n                              Stride<_1, Int<MMAStride_N>, _8> >{},       // (1, 64, 8) or (1, 32, 8)\n            
           make_layout(Int<TileShape_K>{}));\n    // if (cute::thread0()) {printf(\"make_tiled_copy_B_warpcontiguousN \"); print(t); printf(\"\\n\");  }\n    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int MMA_N,\n          class... Args,\n          class TiledMMA>\nCUTE_HOST_DEVICE\nauto\nmake_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,\n                                  TiledMMA           const& tiled_mma) {\n    constexpr int TileShape_M = decltype(tiled_mma.template tile_size_mnk<0>())::value;\n    constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;\n    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;\n    constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;\n    // Divide by 2 because right now we always use 2 for the ValLayout\n    constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;\n    constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;\n    auto t = make_tile(make_layout(Int<TileShape_M>{}),\n                       Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>,   // (8, 2, 2) or (8, 4, 2)\n                              Stride<_1, Int<MMAStride_N>, _8> >{});       // (1, 64, 8) or (1, 32, 8)\n    // if (cute::thread0()) {printf(\"make_tiled_copy_C_warpcontiguousN \"); print(t); printf(\"\\n\");  }\n    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>\ninline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {\n\n    
using Element = typename Kernel_traits::Element;\n    using ElementAccum = typename Kernel_traits::ElementAccum;\n    using index_t = typename Kernel_traits::index_t;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    constexpr int kBlockM = Kernel_traits::kBlockM;\n    constexpr int kBlockN = Kernel_traits::kBlockN;\n    constexpr int kHeadDim = Kernel_traits::kHeadDim;\n    constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value;\n    constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;\n    constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;\n\n    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);\n    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;\n\n    int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);\n    if (Is_local) {\n        m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));\n    }\n\n    const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)\n        + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;\n    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)\n        + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;\n    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)\n        + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;\n    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)\n        + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;\n    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, 
params.o_row_stride, bidb)\n        + (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride;\n    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)\n        + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;\n    const index_t dq_accum_batch_stride = static_cast<index_t>(params.seqlen_q_rounded) * params.h * params.d_rounded;\n    const index_t dq_accum_row_stride = static_cast<index_t>(params.h) * params.d_rounded;\n    const index_t row_offset_dq_accum = binfo.q_offset(dq_accum_batch_stride, dq_accum_row_stride, bidb)\n        + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * dq_accum_row_stride + bidh * params.d_rounded\n        // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.\n        + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);\n    const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM;\n    // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d\n    const index_t row_offset_dpsum = (params.unpadded_lse? 
bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM;\n\n    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),\n                            Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                            make_stride(params.q_row_stride, _1{}));\n    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),\n                            Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                            make_stride(params.k_row_stride, _1{}));\n    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),\n                            Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                            make_stride(params.v_row_stride, _1{}));\n    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),\n                             Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                             make_stride(params.do_row_stride, _1{}));\n    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),\n                            Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                            make_stride(params.o_row_stride, _1{}));\n    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),\n                             Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                             make_stride(params.dq_row_stride, _1{}));\n    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),\n                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                                  make_stride(params.h * params.d_rounded, _1{}));\n    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum 
*>(params.softmax_lse_ptr) + row_offset_lse),\n                              Shape<Int<kBlockM>>{}, Stride<_1>{});\n    Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),\n                                Shape<Int<kBlockM>>{}, Stride<_1>{});\n\n    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),\n                            typename Kernel_traits::SmemLayoutQdO{});\n    Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});\n    Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});\n    // Double buffer for sQ\n    Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{});\n    Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});\n    Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(),\n                                                typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});\n    Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{});\n    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});\n    Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{});\n    Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{});\n    Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? 
sV.data() + size(sV) : sK.data() + size(sK),\n                             typename Kernel_traits::SmemLayoutPdS{});\n    Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{});\n    Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});\n    Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{});\n    Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{});\n    Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});\n    // sP and sdQ share the same memory so be careful\n    Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});\n\n    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;\n    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);\n    using GmemTiledCopydO = std::conditional_t<\n        Is_first,\n        typename Kernel_traits::GmemTiledCopydO,\n        typename Kernel_traits::GmemTiledCopyQKV\n    >;\n    GmemTiledCopydO gmem_tiled_copy_dO;\n    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);\n    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;\n    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);\n    using GmemLayoutAtomdQaccum = std::conditional_t<\n        !Seq_parallel,\n        typename Kernel_traits::GmemTiledCopydQaccum,\n        typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd\n    >;\n    GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum;\n    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);\n\n    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);\n    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);\n    Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);\n    Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO);\n    Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);\n    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  
// (KCPY, KCPY_N, KCPY_K)\n    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);\n    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)\n    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);\n    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)\n    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);\n    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);\n    // if (cute::thread0()) { print(tdQgdQaccum.layout()); printf(\"\\n\"); }\n    // __syncthreads();\n    // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx < 64) {\n    //     printf(\"tidx = %d, tdQgdQaccum = 0x%p\\n\", tidx, tdQgdQaccum.data());\n    // }\n\n    typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;\n    auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);\n    Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ);         // (MMA,MMA_N,MMA_K)\n    Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK);         // (MMA,MMA_N,MMA_K)\n    Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO);      // (MMA,MMA_N,MMA_K)\n    Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV);        // (MMA,MMA_N,MMA_K)\n\n    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;\n    auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx);\n    Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N)\n    Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle);   // (MMA, MMA_K, MMA_N)\n    Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle);   // (MMA, MMA_N, MMA_N)\n    Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N)\n\n    typename Kernel_traits::TiledMmadQ tiled_mma_dq;\n    auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx);\n    Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS);                      // (MMA, MMA_N, MMA_N)\n    Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle);    // (MMA, MMA_K, MMA_N)\n\n 
   Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K\n    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K\n\n    //\n    // Copy Atom retiling\n    //\n\n    auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);\n    auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);\n    Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);\n    Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);\n\n    // auto smem_thr_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx);\n    auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);\n    auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx);\n    Tensor tSsK = smem_thr_copy_KV.partition_S(sK);\n    // if (cute::thread(0, 0) && n_block == 0) { printf(\"sK layout: \"); print(sK.layout()); printf(\"\\n\"); }\n    // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf(\"\\n\"); }\n    Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);\n\n    // Partition sP and sdS to match the accumulator partitioning\n    // This has to be tiled_mma_sdp, not tiled_mma_dkv\n    // auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx);\n    auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);\n    auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);\n    Tensor tPsP = smem_thr_copy_PdS.partition_D(sP);      // ((Atom,AtomNum),PIPE_M,PIPE_N)\n    // if (cute::thread(0, 0) && n_block == 0) { printf(\"sP layout: \"); print(sP.layout()); printf(\"\\n\"); }\n    // if (cute::thread(0, 0) && n_block == 0) { print(tPsP.layout()); printf(\"\\n\"); }\n    // if 
(n_block == 0 && blockIdx.x == 0 && blockIdx.y == 0 && tidx < 64) {\n    //     printf(\"tidx=%d, tPsP = 0x%p\\n\", tidx, tPsP.data());\n    // }\n    Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS);   // ((Atom,AtomNum),PIPE_M,PIPE_N)\n\n    auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);\n    auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx);\n    Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);\n    Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);\n\n    auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);\n    auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx);\n    Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt);\n    Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt);\n\n    auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq);\n    auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx);\n    Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS);\n\n    auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq);\n    auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx);\n    Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt);\n\n    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);\n    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);\n    Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)\n\n    //\n    // PREDICATES\n    //\n\n    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)\n    Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ);\n    Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV);\n\n 
   // Allocate predicate tensors for k\n    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));\n    Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));\n\n    // Set predicates for k bounds\n    if (!Is_even_K) {\n        #pragma unroll\n        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }\n        #pragma unroll\n        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }\n    }\n\n    // Prologue\n\n    // We'll advance gdQ and gdQaccum before the 1st read/write.\n    tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;\n    tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;\n\n    int m_block = m_block_max - 1;\n    int m_block_min = (!Is_causal && !Is_local)\n        ? 0\n        : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);\n    // If not local, we're guaranteed that m_block_min <= m_block:\n    // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,\n    // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.\n    // So m_block_min <= (actual_seqlen_q - 1) / kBlockM.\n    // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.\n    // So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM.\n    // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.\n    // However, if local, then it is possible to have some blocks of K & V not attending to any query.\n    // We might need to exit early and write 0 to dK and dV for those blocks.\n    // Otherwise we get a wrong result for the case where we don't enter the for loop.\n    // And we might read OOB elements from gQ and gdO.\n    // This also covers the case where actual_seqlen_q == 0\n    if ((Is_local || !Is_even_MN) && m_block < 
m_block_min) {\n        const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)\n          + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;\n        const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)\n          + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;\n        Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),\n                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                                 make_stride(params.dk_row_stride, _1{}));\n        Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),\n                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                                 make_stride(params.dv_row_stride, _1{}));\n        typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;\n        auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);\n        Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);\n        Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);\n        Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));\n        Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));\n        clear(tdKrdK);\n        clear(tdVrdV);\n        Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)\n        Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);\n        Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));\n        #pragma unroll\n        for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }\n        // Clear_OOB_K must be false since we don't want to write zeros to gmem\n        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n            gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, 
binfo.actual_seqlen_k - n_block * kBlockN\n        );\n        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n            gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN\n        );\n        return;\n    }\n\n    if (Double_buffer && m_block % 2 == 1) {  // Double buffer for sQ\n        tQsQ.data() = tQsQ.data() + size(sQ);\n        tSsQ.data() = tSsQ.data() + size(sQ);\n        tdKsQt.data() = tdKsQt.data() + size(sQ);\n    }\n\n    if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); }\n\n    if (Kernel_traits::Is_V_in_regs) {\n        // Clear the smem tiles to account for predicated off loads\n        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(\n            gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN\n        );\n        FLASH_NAMESPACE::cp_async_fence();\n    }\n\n    Tensor tdOrdO = make_fragment_like(tdOgdO);\n    Tensor tdOrO = make_fragment_like(tdOgO);\n    if (!Is_first) {\n        // Clear the smem tiles to account for predicated off loads\n        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(\n            gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM\n        );\n    } else {\n        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(\n            gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM\n        );\n        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(\n            gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM\n        );\n    }\n    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(\n        gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM\n    );\n\n    Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, 
Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)\n    Tensor taccScS = thr_mma_sdp.partition_C(caccS);                           // (MMA,MMA_N,MMA_N)\n    static_assert(decltype(size<0>(taccScS))::value == 4);\n    // Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices.\n    Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);\n    Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});\n    #pragma unroll\n    for (int mi = 0; mi < size(lse); ++mi) {\n        const int row = get<0>(taccScS_row(mi));\n        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;\n    }\n    // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,\n    // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply\n    // with V (which would be zero), we're fine. However, with ALiBi, we might modify these\n    // scores, and probs can become NaN. 
Instead if we set LSE = inf for OOB rows, probs are always 0.\n\n    // Tensor tKrK = make_fragment_like(tKsK);\n    // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);\n    // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);\n    // // if (cute::thread(1, 0)) { print(tKrK); }\n\n    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(\n        gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN\n    );\n    if (!Kernel_traits::Is_V_in_regs) {\n        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(\n            gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN\n        );\n    }\n    FLASH_NAMESPACE::cp_async_fence();\n\n    // if (cute::thread0()) { print(tdOgdO.layout()); printf(\"\\n\"); print(tdOrdO); print(tdOrO); }\n    if (Is_first) {\n        cute::copy(tdOrdO, tdOsdO);\n        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,\n                                                    Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);\n    }\n\n    if (Kernel_traits::Is_V_in_regs) {\n        cute::cp_async_wait<1>();\n        __syncthreads();\n        Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV);\n        CUTE_STATIC_ASSERT_V(size<1>(tdPsV) == size<1>(tdPrV_copy_view));            // M\n        cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);\n    }\n\n    FLASH_NAMESPACE::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,\n                           bidb, bidh, tidx, params.h);\n\n    clear(acc_dv);\n    clear(acc_dk);\n\n    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 
0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;\n    FLASH_NAMESPACE::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);\n\n    for (; m_block >= m_block_min; --m_block) {\n        Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)\n        clear(acc_s);\n        cute::cp_async_wait<0>();\n        __syncthreads();\n\n        Tensor dP_sum = make_fragment_like(lse);\n        #pragma unroll\n        for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }\n\n        // if (cute::thread0()) { print(sK); }\n        // Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK);\n        // #pragma unroll\n        // for (int k = 0; k < size<2>(tSrK_copy_view); ++k) {\n        //     cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));\n        // }\n        // if (cute::thread0()) { print(tSrK); }\n        FLASH_NAMESPACE::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,\n                    smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);\n\n        if constexpr (Is_softcap) {\n            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);\n        }\n\n        // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))\n        Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout()));\n        // if (cute::thread(32, 0)) { print(scores); }\n\n        // Softcapping - calculating dTanh and scaling dS later with it\n        [[maybe_unused]] Tensor dtanh = make_tensor_like(scores);\n        if constexpr (Is_softcap) {\n            FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap);\n        }\n\n        // Alibi\n        if (Has_alibi) {\n            alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 
16,\n                              m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);\n        }\n\n        // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond\n        // actual_seqlen_k, because acc_s would be some finite value for those indices.\n        // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,\n        // so the result would still be correct.\n        // However, it's possible that the values in acc_s are so large that they overflow\n        // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.\n        // So we need to mask out the elements beyond actual_seqlen_k.\n        if (!Is_causal && !Is_local) {\n            if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {\n                FLASH_NAMESPACE::apply_mask(scores, binfo.actual_seqlen_k,\n                                  n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);\n            }\n        } else if (Is_causal) {\n            // Putting this causal masking right after acc_s is *much* slower for some reason.\n            // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short\n            // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.\n            // But we still want to mask out elements beyond actual_seqlen_k.\n            if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k\n                || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {\n                FLASH_NAMESPACE::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,\n                                         binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),\n                                         binfo.actual_seqlen_q,\n                                         // 
binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,\n                                         AtomLayoutMS * 16);\n            }\n        } else if (Is_local) {\n            if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right\n                || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left\n                || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {\n                FLASH_NAMESPACE::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,\n                                        binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),\n                                        binfo.actual_seqlen_q, AtomLayoutMS * 16,\n                                        params.window_size_left, params.window_size_right);\n            }\n\n        }\n\n        // if (cute::thread(32, 0)) { print(scores); }\n        // Compute the exponential value.\n        FLASH_NAMESPACE::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);\n        if constexpr (Is_dropout) {\n            int warp_id = tidx / 32;\n            int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;\n            // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32\n            static_assert(MMA_N_SdP % 2 == 0);\n            int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);\n            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(\n                acc_s, block_row_idx, block_col_idx, AtomLayoutMS\n            );\n        }\n        // Convert scores from fp32 to fp16/bf16\n        Tensor rP = !Is_dropout\n            ? 
FLASH_NAMESPACE::convert_type<Element>(acc_s)\n            : FLASH_NAMESPACE::convert_type_relu<Element>(acc_s);\n        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)\n        // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8.\n        Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaSdP>(rP.layout()));\n        Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP);     // ((Atom,AtomNum), MMA_N, MMA_N)\n        cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);\n        // if (cute::thread0()) { print(tPaP); }\n        // __syncthreads();\n        // if (cute::thread0()) { print(sP); }\n\n        Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)\n        CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s));                     // MMA\n        CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s));                     // MMA\n        CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s));                     // MMA\n\n        clear(acc_dp);\n        // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dp.layout()));\n        // #pragma unroll\n        // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) {\n        //     #pragma unroll\n        //     for (int ni = 0; ni < size<1>(acc_dp_reshaped); ++ni) {\n        //         acc_dp_reshaped(mi, ni) = -dP_sum(mi);\n        //     }\n        // }\n\n        // if (cute::thread0()) { print(dP_sum); }\n\n        FLASH_NAMESPACE::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(\n            acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,\n            smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV\n        );\n\n        // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))\n        Tensor dS = make_tensor(acc_dp.data(), 
scores.layout());\n        auto pointwise_mult = [](float p, float dp, float d) {\n            return p * (!Is_dropout || p >= 0 ? dp - d : d);\n        };\n        #pragma unroll\n        for (int mi = 0; mi < size<0>(dS); ++mi) {\n            #pragma unroll\n            for (int ni = 0; ni < size<1>(dS); ++ni) {\n                float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));\n                if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }\n                dS(mi, ni) = scaled_ds;\n            }\n        }\n        // if (cute::thread0()) { print(dS); }\n\n        Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K\n        tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded));\n        if (Is_first || Seq_parallel) {\n            clear(acc_dq);\n        } else {\n            // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum\n            Tensor acc_dq_reshaped = make_tensor(acc_dq.data(),\n                                                 make_layout(get<0>(acc_dq.layout()),\n                                                             get<2>(acc_dq.layout()),\n                                                             get<1>(acc_dq.layout())));\n            cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped);\n        }\n\n        if (Double_buffer && m_block > m_block_min) {\n            // Double buffer for sQ\n            const int sQ_offset = m_block % 2 == 0 ? 
size(sQ) : -size(sQ);\n            tQsQ.data() = tQsQ.data() + sQ_offset;\n            tSsQ.data() = tSsQ.data() + sQ_offset;\n            // Advance gQ\n            tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));\n            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);\n            FLASH_NAMESPACE::cp_async_fence();\n        }\n\n        Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());\n        // Convert dS from fp32 to fp16\n        Tensor tdSrdS = FLASH_NAMESPACE::convert_type<Element>(dS_reshaped);\n        // if (cute::thread0()) { print(tPrP); }\n        Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS);                                          // ((Atom,AtomNum), MMA_N, MMA_N)\n        cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);\n        __syncthreads();\n\n        // Layout p_l = tPrP.layout();\n        // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));\n        // FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);\n        // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());\n        // FLASH_NAMESPACE::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);\n        FLASH_NAMESPACE::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,\n                    smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);\n        // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }\n        // if (cute::thread0()) { print(acc_dv); }\n\n        __syncthreads(); // Need syncthreads since we're writing to the same sdO location\n\n        if (m_block > m_block_min) {\n            // Advance gdO\n            tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));\n            if (Is_first) {\n                tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));\n        
        FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);\n                FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);\n            } else {\n                FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ);\n                FLASH_NAMESPACE::cp_async_fence();\n            }\n        }\n\n        FLASH_NAMESPACE::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,\n                    smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);\n        // if (cute::thread0()) { print(acc_dq); }\n\n        if (m_block > m_block_min) {\n            gLSE.data() = gLSE.data() + (-int(kBlockM));\n            #pragma unroll\n            for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }\n            gdPsum.data() = gdPsum.data() + (-int(kBlockM));\n        }\n\n        if (!Is_last) {\n            // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum\n            Tensor acc_dq_reshaped = make_tensor(acc_dq.data(),\n                                                 make_layout(get<0>(acc_dq.layout()),\n                                                             get<2>(acc_dq.layout()),\n                                                             get<1>(acc_dq.layout())));\n            if (!Seq_parallel) {\n                cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum);\n            } else {\n                // if (cute::thread0()) { print(acc_dq.layout()); printf(\"\\n\"); print(acc_dq_reshaped.layout()); printf(\"\\n\"); print(tdQgdQaccum.layout()); printf(\"\\n\"); }\n                CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));\n                #pragma unroll\n                for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }\n            }\n        } else {\n            #pragma 
unroll\n            for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }\n            // Convert acc_dq from fp32 to fp16\n            Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);\n            Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)\n            cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);\n        }\n\n        FLASH_NAMESPACE::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,\n                    smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);\n        // if (cute::thread0()) { print(acc_dk); }\n        if (Double_buffer) {  // Double buffer for sQ\n            tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));\n        }\n        if (!Double_buffer && m_block > m_block_min) {\n            __syncthreads();\n            // Advance gQ\n            tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));\n            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);\n            FLASH_NAMESPACE::cp_async_fence();\n        }\n\n        if (Is_first && m_block > m_block_min) {\n            cute::copy(tdOrdO, tdOsdO);\n            dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,\n                                                        Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);\n        }\n\n        if (Is_last) {\n            __syncthreads();\n            Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));\n            cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);\n            tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride));\n            Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n            Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);\n            #pragma unroll\n            for 
(int m = 0; m < size<1>(tdQgdQ); ++m) {\n                if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {\n                    cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _));\n                }\n            }\n        }\n\n    }\n\n    // Epilogue\n\n    if (Is_dropout) {\n        #pragma unroll\n        for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; }\n    }\n    #pragma unroll\n    for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; }\n\n    // Convert acc_dv from fp32 to fp16\n    Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);\n    Tensor rdV = FLASH_NAMESPACE::convert_type<Element>(acc_dv);\n\n    Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{});  // (SMEM_N, SMEM_K)\n    Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)\n\n    // Partition sdV and sdK to match the accumulator partitioning\n    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);\n    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);\n    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);       // ((Atom,AtomNum), MMA_N, MMA_N)\n    Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);   // ((Atom,AtomNum),PIPE_M,PIPE_N)\n    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);       // ((Atom,AtomNum), MMA_N, MMA_N)\n    Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);    // ((Atom,AtomNum),PIPE_M,PIPE_N)\n\n    // We need syncthreads here since we're writing to the same location as sK and sV.\n    // Without syncthreads, some thread might modify the location of sK while another thread\n    // is reading it for dQ gemm, leading to a race condition.\n    // If Is_last, there's already a __syncthreads() at the end of the loop.\n    if (!Is_last) { __syncthreads(); }\n\n    
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);\n    cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);\n\n    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)\n       + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;\n    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)\n       + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;\n    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),\n                             Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                             make_stride(params.dk_row_stride, _1{}));\n    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),\n                             Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                             make_stride(params.dv_row_stride, _1{}));\n\n    typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;\n    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);\n    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);   // ((Atom,AtomNum),ATOM_M,ATOM_N)\n    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);\n    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);   // ((Atom,AtomNum),ATOM_M,ATOM_N)\n    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);\n\n    __syncthreads();\n    Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));\n    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);\n    Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));\n    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);\n    Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)\n    Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);\n    Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));\n    #pragma unroll\n    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = 
get<1>(tdKVcdKV(0, 0, k)) < params.d; }\n    // Clear_OOB_K must be false since we don't want to write zeros to gmem\n    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN\n    );\n    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN\n    );\n\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K, typename Params>\ninline __device__ void compute_dq_dk_dv(const Params &params) {\n\n    // The block index for the batch.\n    const int bidb = blockIdx.x;\n    // const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.y;\n    // const int bidh = blockIdx.z;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;\n    if (n_block_max == 1) {\n        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);\n    } else {\n        // Iterating backward from n_block_max - 1 to 0 might save 1 register\n        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);\n        for (int n_block = n_block_max - 2; n_block > 0; n_block--) {\n            compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);\n        }\n        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, 
Is_even_K, false, true>(params, bidb, bidh, 0);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>\ninline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {\n\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.z;\n\n    // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.\n    for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {\n        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n} // namespace flash\n"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_launch_template.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"namespace_config.h\"\n#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK\n\n#include \"static_switch.h\"\n#include \"hardware_info.h\"\n#include \"flash.h\"\n#include \"flash_bwd_preprocess_kernel.h\"\n#include \"flash_bwd_kernel.h\"\n\nnamespace FLASH_NAMESPACE {\n\n// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n#define ARCH_SUPPORTS_FLASH\n#define KERNEL_PARAM_MODIFIER __grid_constant__\n#else\n#define KERNEL_PARAM_MODIFIER\n#endif\n\n// Define a macro for unsupported architecture handling to centralize the error message\n#define FLASH_UNSUPPORTED_ARCH printf(\"FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!\");\n\n// Use a macro to clean up kernel definitions\n#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) 
\\\ntemplate<typename Kernel_traits, __VA_ARGS__> \\\n__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)\n\nDEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {\n    #if defined(ARCH_SUPPORTS_FLASH)\n       FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);\n    #else\n        FLASH_UNSUPPORTED_ARCH\n    #endif\n}\n\nDEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {\n    #if defined(ARCH_SUPPORTS_FLASH)\n        static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false\n        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);\n    #else\n        FLASH_UNSUPPORTED_ARCH\n    #endif\n}\n\n\ntemplate<bool Clear_dQaccum=true, typename Kernel_traits>\n__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {\n    FLASH_NAMESPACE::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);\n}\n\ntemplate<typename Kernel_traits>\n__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {\n    FLASH_NAMESPACE::clear_dKVaccum<Kernel_traits>(params);\n}\n\ntemplate<typename Kernel_traits>\n__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {\n    FLASH_NAMESPACE::convert_dQ<Kernel_traits>(params, nsplits);\n}\n\ntemplate<typename Kernel_traits>\n__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {\n    FLASH_NAMESPACE::convert_dKV<Kernel_traits>(params);\n}\n\ntemplate<typename Kernel_traits, bool Is_dropout, bool Is_causal>\nvoid run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {\n    const int 
num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;\n    dim3 grid_m(num_m_block, params.b, params.h);\n    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;\n    int gridDimx = num_n_block;\n    if (params.deterministic) {\n        int num_sm = get_num_sm(get_current_device());\n        gridDimx = (num_sm + params.b * params.h - 1) / (params.b * params.h);\n    }\n    dim3 grid_n(gridDimx, params.b, params.h);\n\n    if (!params.deterministic) {\n        flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);\n    } else {\n        flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);\n    }\n    C10_CUDA_KERNEL_LAUNCH_CHECK();\n\n    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not\n    // a multiple of kBlockN, we'll need to apply mask in the loop.\n    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;\n    const bool is_even_K = params.d == Kernel_traits::kHeadDim;\n    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;\n    // printf(\"smem_size_dq_dk_dv = %d\\n\", smem_size_dq_dk_dv);\n    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {\n        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {\n            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {\n                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {\n                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {\n                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.\n                        // If head dim > 128, set IsEvenMNConst to false to reduce number of 
templates\n                        // If Is_local, set Is_causal to false\n                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !Has_alibi && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !Has_alibi, Is_softcap>;\n                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;\n                        if (smem_size_dq_dk_dv >= 48 * 1024)  {\n                            C10_CUDA_CHECK(cudaFuncSetAttribute(\n                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));\n                        }\n                        kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);\n                        C10_CUDA_KERNEL_LAUNCH_CHECK();\n                    });\n                });\n            });\n        });\n    });\n\n    auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;\n    if (Kernel_traits::kSmemdQSize >= 48 * 1024)  {\n        C10_CUDA_CHECK(cudaFuncSetAttribute(\n            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));\n    }\n    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 
1 : gridDimx);\n    C10_CUDA_KERNEL_LAUNCH_CHECK();\n}\n\ntemplate<typename Kernel_traits, bool Is_dropout, bool Is_causal>\nvoid run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {\n#ifndef FLASHATTENTION_DISABLE_BACKWARD\n    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout, Is_causal>(params, stream);\n#endif\n}\n\ntemplate<typename T, bool Is_causal>\nvoid run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {\n    constexpr static int Headdim = 32;\n    int device;\n    cudaGetDevice(&device);\n    int max_smem_per_block;\n    cudaError status_ = cudaDeviceGetAttribute(\n        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);\n    if (status_ != cudaSuccess) {\n      C10_CUDA_CHECK(status_);\n    }\n    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {\n        if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB\n            if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers\n                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);\n            } else {\n                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n            }\n        } else {  // 96 KB\n            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);\n        }\n    });\n}\n\ntemplate<typename T, bool Is_causal>\nvoid run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {\n    constexpr static int Headdim = 64;\n    int device;\n    cudaGetDevice(&device);\n    int max_smem_per_block;\n    cudaError status_ = cudaDeviceGetAttribute(\n        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);\n    if (status_ != cudaSuccess) {\n      C10_CUDA_CHECK(status_);\n    }\n    // 
printf(\"max_smem_per_block = %d\\n\", max_smem_per_block);\n    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {\n        // Changing AtomLayoutMdQ from 2 to 4 takes the same time\n        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);\n        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);\n        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);\n        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);\n        // This is slightly faster. We want to split M more so we need fewer registers to store LSE.\n        if (max_smem_per_block >= 144 * 1024) {\n            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n            // This has a lot of register spilling\n            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);\n        } else {\n            // if (params.h == params.h_k) {\n                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);\n            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);\n                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);\n                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);\n            // } else {\n            // }\n        }\n    });\n    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);\n    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, 
false, T>>(params, stream);\n    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);\n    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);\n    // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times\n    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);\n    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);\n    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);\n\n    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);\n}\n\ntemplate<typename T, bool Is_causal>\nvoid run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {\n    constexpr static int Headdim = 96;\n    int device;\n    cudaGetDevice(&device);\n    int max_smem_per_block;\n    cudaError status_ = cudaDeviceGetAttribute(\n        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);\n    if (status_ != cudaSuccess) {\n      C10_CUDA_CHECK(status_);\n    }\n    // printf(\"max_smem_per_block = %d\\n\", max_smem_per_block);\n    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {\n        if (max_smem_per_block >= 116 * 1024) {\n            if constexpr(!Is_dropout) {  // 92KB\n                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);\n            } else {  // 116 KB\n                // This is faster for dropout since we don't have many registers to spare\n                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n            }\n        } else {\n            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, 
stream);\n        }\n    });\n}\n\ntemplate<typename T, bool Is_causal>\nvoid run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {\n    constexpr static int Headdim = 128;\n    int device;\n    cudaGetDevice(&device);\n    int max_smem_per_block;\n    cudaError status_ = cudaDeviceGetAttribute(\n        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);\n    if (status_ != cudaSuccess) {\n      C10_CUDA_CHECK(status_);\n    }\n    // printf(\"max_smem_per_block = %d\\n\", max_smem_per_block);\n    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {\n        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);\n        // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).\n        // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.\n        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);\n        if (max_smem_per_block >= 144 * 1024) {\n            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);\n            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);\n            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);\n            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);\n            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);\n            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);\n        } else {\n            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, 
false, false, T>, Is_dropout>(params, stream);\n            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout, Is_causal>(params, stream);\n        }\n        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);\n\n        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);\n    });\n}\n\ntemplate<typename T, bool Is_causal>\nvoid run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {\n    constexpr static int Headdim = 192;\n    int device;\n    cudaGetDevice(&device);\n    int max_smem_per_block;\n    cudaError status_ = cudaDeviceGetAttribute(\n        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);\n    if (status_ != cudaSuccess) {\n      C10_CUDA_CHECK(status_);\n    }\n    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {\n        if (max_smem_per_block >= 136 * 1024) {\n            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);\n        } else {\n            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout, Is_causal>(params, stream);\n        }\n    });\n}\n\ntemplate<typename T, bool Is_causal>\nvoid run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {\n    constexpr static int Headdim = 256;\n    int device;\n    cudaGetDevice(&device);\n    int max_smem_per_block;\n    cudaError status_ = cudaDeviceGetAttribute(\n        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);\n    if (status_ != cudaSuccess) {\n      C10_CUDA_CHECK(status_);\n    }\n    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {\n        if (max_smem_per_block >= 176 * 1024) {  // H100\n            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);\n        
} else if (max_smem_per_block >= 144 * 1024) {  // A100, we don't do double buffering to save smem\n            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout, Is_causal>(params, stream);\n        } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.\n            if constexpr (!Is_dropout) {\n                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false, Is_causal>(params, stream);\n            }\n        }\n    });\n}\n\n} // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn/src/flash_bwd_preprocess_kernel.h",
    "content": "/***************************************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"namespace_config.h\"\n#include <cute/tensor.hpp>\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n\n#include \"block_info.h\"\n#include \"kernel_traits.h\"\n#include \"utils.h\"\n\nnamespace FLASH_NAMESPACE {\n\nusing namespace cute;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>\ninline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,\n                                Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {\n    static_assert(Layout0::rank == 3, \"Only support 3D Tensor\");\n    static_assert(Layout1::rank == 1, \"Only support 1D Tensor\");\n    CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());\n    // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)\n    // The last coordinate is the \"page\".\n    Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),\n                                                             make_layout(get<0>(do_.layout()),\n                                                                         get<2>(do_.layout()))));\n    Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());\n    Tensor do_fp32 = FLASH_NAMESPACE::convert_type<float>(do_reshaped);\n    Tensor o_fp32 = FLASH_NAMESPACE::convert_type<float>(o_reshaped);\n    #pragma unroll\n    for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {\n        float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);\n        #pragma unroll\n        for (int ni = 1; ni < 
size<1>(do_reshaped); ni++) {\n            dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);\n        }\n        FLASH_NAMESPACE::SumOp<float> sum_op;\n        dP_sum_cur = FLASH_NAMESPACE::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;\n        if (threadIdx.x % THREADS_PER_ROW == 0) {\n            dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.\n// This is used in the case where we want to parallelize the backward across seqlen_k.\ntemplate<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>\ninline __device__ void compute_dot_do_o(const Params &params) {\n    using Element = typename Kernel_traits::Element;\n    using ElementAccum = typename Kernel_traits::ElementAccum;\n    using index_t = typename Kernel_traits::index_t;\n\n    const int m_block = blockIdx.x;\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.z;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    constexpr int kBlockM = Kernel_traits::kBlockM;\n    constexpr int kHeadDim = Kernel_traits::kHeadDim;\n\n    const BlockInfo binfo(params, bidb);\n    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;\n\n    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)\n        + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;\n    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)\n        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;\n    const index_t dq_accum_batch_stride = static_cast<index_t>(params.seqlen_q_rounded) * params.h * params.d_rounded;\n    const index_t 
dq_accum_row_stride = static_cast<index_t>(params.h) * params.d_rounded;\n    const index_t row_offset_dq_accum = binfo.q_offset(dq_accum_batch_stride, dq_accum_row_stride, bidb)\n        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * dq_accum_row_stride + bidh * params.d_rounded;\n    // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d\n    const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;\n\n    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),\n                             Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                             make_stride(params.do_row_stride, _1{}));\n    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),\n                            Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                            make_stride(params.o_row_stride, _1{}));\n    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),\n                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                                  make_stride(params.h * params.d_rounded, _1{}));\n    Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),\n                                Shape<Int<kBlockM>>{}, Stride<_1>{});\n\n    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;\n    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);\n    // TODO: careful, we're zeroing out dQaccum with type float4, but when\n    // we do atomicAdds, we use type float. The layouts are different. 
Check this.\n    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;\n    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);\n\n    Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);\n    Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);\n    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);\n\n    Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n    Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);\n\n    // Allocate predicate tensors for k\n    Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));\n    // Set predicates for k bounds\n    #pragma unroll\n    for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;}\n\n    Tensor tdOrdO = make_fragment_like(tdOgdO);\n    Tensor tdOrO = make_fragment_like(tdOgO);\n    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(\n        gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM\n    );\n    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(\n        gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM\n    );\n    // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final\n    // results (dQ and dK) by 1/p_dropout. 
So we need to keep dP_sum scaled down by p_dropout here,\n    // so that (dP - dP_sum) is on the same scale.\n    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,\n                                                Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);\n    if (Clear_dQaccum) {\n        // We're actually not zero'ing out all of dQaccum, but only the part that we're going to\n        // do atomicAdds on.\n        Tensor zero = make_fragment_like(tdQgdQaccum);\n        clear(zero);\n        cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Kernel_traits, typename Params>\ninline __device__ void clear_dKVaccum(const Params &params) {\n    using ElementAccum = typename Kernel_traits::ElementAccum;\n    using index_t = typename Kernel_traits::index_t;\n\n    const int n_block = blockIdx.x;\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.z;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    constexpr int kBlockN = Kernel_traits::kBlockN;\n    constexpr int kHeadDim = Kernel_traits::kHeadDim;\n\n    const BlockInfo binfo(params, bidb);\n    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;\n\n    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;\n\n    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),\n                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});\n    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),\n                                  Shape<Int<kBlockN>, Int<kHeadDim>>{}, 
Stride<Int<kHeadDim>, _1>{});\n\n    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;\n    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);\n    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);\n    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);\n    Tensor zero = make_fragment_like(tdKgdKaccum);\n    clear(zero);\n    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);\n    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// Convert dQ from dQaccum (in float) to fp16/bf16.\n// This is used in the case where we want to parallelize the backward across seqlen_k.\ntemplate<typename Kernel_traits, typename Params>\ninline __device__ void convert_dQ(const Params &params, const int nsplits) {\n    using Element = typename Kernel_traits::Element;\n    using ElementAccum = typename Kernel_traits::ElementAccum;\n    using index_t = typename Kernel_traits::index_t;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    const int m_block = blockIdx.x;\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.z;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    constexpr int kBlockM = Kernel_traits::kBlockM;\n    constexpr int kHeadDim = Kernel_traits::kHeadDim;\n\n    const BlockInfo binfo(params, bidb);\n    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;\n\n    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)\n        + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;\n    const index_t dq_accum_batch_stride = static_cast<index_t>(params.seqlen_q_rounded) * params.h * params.d_rounded;\n    const index_t dq_accum_row_stride = static_cast<index_t>(params.h) * 
params.d_rounded;\n    const index_t row_offset_dq_accum = binfo.q_offset(dq_accum_batch_stride, dq_accum_row_stride, bidb)\n        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * dq_accum_row_stride + bidh * params.d_rounded;\n\n    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),\n                             Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                             make_stride(params.dq_row_stride, _1{}));\n    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),\n                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                                  make_stride(params.h * params.d_rounded, _1{}));\n\n    Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),\n                             typename Kernel_traits::SmemLayoutdQ{});\n\n    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;\n    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);\n    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;\n    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);\n\n    typename Kernel_traits::TiledMmadQ tiled_mma_dq;\n    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);\n    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);\n    Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)\n\n    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)\n    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);\n    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);\n\n    Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K\n    CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));\n\n    Tensor 
tdQrdQaccum = make_fragment_like(tdQgdQaccum);\n    clear(acc_dq);\n    for (int s = 0; s < nsplits; ++s) {\n        cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);\n        #pragma unroll\n        for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }\n        tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;\n    }\n    #pragma unroll\n    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }\n    // Convert acc_dq from fp32 to fp16\n    Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);\n    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)\n    cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);\n    __syncthreads();\n    Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));\n    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);\n\n    Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);\n    Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));\n    #pragma unroll\n    for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }\n    // Clear_OOB_K must be false since we don't want to write zeros to gmem\n    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM\n    );\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.\n// This is used in the case where we want to parallelize the backward across seqlen_q.\ntemplate<typename Kernel_traits, typename Params>\ninline __device__ void convert_dKV(const Params &params) {\n    using Element = typename Kernel_traits::Element;\n    using ElementAccum 
= typename Kernel_traits::ElementAccum;\n    using index_t = typename Kernel_traits::index_t;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    const int n_block = blockIdx.x;\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.z;\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    constexpr int kBlockN = Kernel_traits::kBlockN;\n    constexpr int kHeadDim = Kernel_traits::kHeadDim;\n\n    const BlockInfo binfo(params, bidb);\n    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;\n\n    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)\n        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;\n    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)\n        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;\n    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded\n                                          + n_block * kBlockN) * params.d_rounded;\n\n    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),\n                             Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                             make_stride(params.dk_row_stride, _1{}));\n    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),\n                             Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                             make_stride(params.dv_row_stride, _1{}));\n    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),\n                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                                  Stride<Int<kHeadDim>, _1>{});\n    Tensor gdVaccum = 
make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),\n                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                                  Stride<Int<kHeadDim>, _1>{});\n\n    Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),\n                             typename Kernel_traits::SmemLayoutdKV{});\n    Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)\n\n    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;\n    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);\n    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;\n    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);\n\n    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;\n    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);\n    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);\n    Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);  // ((Atom,AtomNum),PIPE_M,PIPE_N)\n    Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);  // ((Atom,AtomNum),PIPE_M,PIPE_N)\n\n    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);    // ((Atom,AtomNum),ATOM_M,ATOM_N)\n    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);\n    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);    // ((Atom,AtomNum),ATOM_M,ATOM_N)\n    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);\n    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);\n    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);\n\n    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K\n    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K\n    CUTE_STATIC_ASSERT_V(size(acc_dk) == 
size(tdKgdKaccum));\n    CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));\n\n    Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);\n    Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);\n    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);\n    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);\n    #pragma unroll\n    for (int i = 0; i < size(acc_dk); ++i) {\n        acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;\n    }\n    #pragma unroll\n    for (int i = 0; i < size(acc_dv); ++i) {\n        acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;\n    }\n    // Convert acc_dk from fp32 to fp16\n    Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);\n    Tensor rdV = FLASH_NAMESPACE::convert_type<Element>(acc_dv);\n    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)\n    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)\n    cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);\n    cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);\n    __syncthreads();\n    Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));\n    Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));\n    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);\n    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);\n\n    Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n    Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);\n    Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));\n    #pragma unroll\n    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }\n    // Clear_OOB_K must be false since we don't want to write zeros to gmem\n    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - 
n_block * kBlockN\n    );\n    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN\n    );\n}\n\n} // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim128<cutlass::bfloat16_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim128<cutlass::half_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim192<cutlass::bfloat16_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim192<cutlass::bfloat16_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim192<cutlass::half_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim192<cutlass::half_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim256<cutlass::bfloat16_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim256<cutlass::bfloat16_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim256<cutlass::half_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim256<cutlass::half_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim32<cutlass::bfloat16_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim32<cutlass::bfloat16_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim32<cutlass::half_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim32<cutlass::half_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim64<cutlass::bfloat16_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim64<cutlass::bfloat16_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim64<cutlass::half_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim64<cutlass::half_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim96<cutlass::bfloat16_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim96<cutlass::bfloat16_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim96<cutlass::half_t, true>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate<>\nvoid run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {\n    run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream);\n}\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_kernel.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"namespace_config.h\"\n#include \"philox_unpack.cuh\" // For at::cuda::philox::unpack\n\n#include <cute/tensor.hpp>\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n\n#include \"block_info.h\"\n#include \"kernel_traits.h\"\n#include \"utils.h\"\n#include \"softmax.h\"\n#include \"mask.h\"\n#include \"dropout.h\"\n#include \"rotary.h\"\n\nnamespace FLASH_NAMESPACE {\n\nusing namespace cute;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>\n__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {\n        // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.\n        // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.\n        // Otherwise, it's written as (h, b, seqlen_q).\n        const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;\n        auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;\n        auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);\n\n        auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);\n        auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (\n            params.unpadded_lse ? 
make_stride(params.h * params.total_q, params.total_q, 1) :  make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)\n            );\n\n        auto lse_layout = make_layout(lse_shape, lse_stride);\n        Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);\n        auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);\n        return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));\n}\n\n\ntemplate<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>\ninline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {\n\n    using Element = typename Kernel_traits::Element;\n    using ElementAccum = typename Kernel_traits::ElementAccum;\n    using index_t = typename Kernel_traits::index_t;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    constexpr int kBlockM = Kernel_traits::kBlockM;\n    constexpr int kBlockN = Kernel_traits::kBlockN;\n    constexpr int kHeadDim = Kernel_traits::kHeadDim;\n    constexpr int kNWarps = Kernel_traits::kNWarps;\n\n    auto seed_offset = at::cuda::philox::unpack(params.philox_args);\n    FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,\n                           bidb, bidh, tidx, params.h);\n\n    // Save seed and offset for backward, before any early exiting. 
Otherwise the 0-th thread block might\n    // exit early and no one saves the rng states.\n    if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {\n        params.rng_state[0] = std::get<0>(seed_offset);\n        params.rng_state[1] = std::get<1>(seed_offset);\n    }\n\n    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);\n    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;\n\n    const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);\n    int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);\n    if (Is_causal || Is_local) {\n        n_block_max = std::min(n_block_max,\n                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));\n        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {\n        //     printf(\"m_block = %d, n_block_max = %d\\n\", m_block, n_block_max);\n        // }\n    }\n    // We exit early and write 0 to gO and gLSE. 
This also covers the case where actual_seqlen_k == 0.\n    // Otherwise we might read OOB elements from gK and gV.\n    if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {\n        Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)\n                                              + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),\n                                make_shape(binfo.actual_seqlen_q, params.h, params.d),\n                                make_stride(params.o_row_stride, params.o_head_stride, _1{}));\n        Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                              make_coord(m_block, 0));  // (kBlockM, kHeadDim)\n\n        Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);\n\n        typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;\n        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);\n        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);\n        Tensor tOrO = make_tensor<Element>(shape(tOgO));\n        clear(tOrO);\n        // Construct identity layout for sO\n        Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n        // Repeat the partitioning with identity layouts\n        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);\n        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));\n        if (!Is_even_K) {\n            #pragma unroll\n            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }\n        }\n        // Clear_OOB_K must be false since we don't want to write zeros to gmem\n        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM\n        );\n        #pragma unroll\n        for (int m = 0; 
m < size<1>(tOgO); ++m) {\n            const int row = get<0>(tOcO(0, m, 0));\n            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }\n        }\n        return;\n    }\n    // if (tidx == 0) { printf(\"m_block = %d, n_block_min = %d, n_block_max = %d\\n\", m_block, n_block_min, n_block_max); }\n\n    // We iterate over the blocks in reverse order. This is because the last block is the only one\n    // that needs masking when we read K and V from global memory. Moreover, iterating in reverse\n    // might save us 1 register (we just need n_block instead of both n_block and n_block_max).\n\n    const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded\n        + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;\n\n    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)\n                                          + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),\n                            make_shape(binfo.actual_seqlen_q, params.h, params.d),\n                            make_stride(params.q_row_stride, params.q_head_stride, _1{}));\n    Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)\n    Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)\n                                          + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),\n                            make_shape(binfo.actual_seqlen_k, params.h_k, params.d),\n                            make_stride(params.k_row_stride, params.k_head_stride, _1{}));\n    Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                           make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)\n    Tensor mV = 
make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)\n                                          + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),\n                            make_shape(binfo.actual_seqlen_k, params.h_k, params.d),\n                            make_stride(params.v_row_stride, params.v_head_stride, _1{}));\n    Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                           make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)\n    Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),\n                            Shape<Int<kBlockM>, Int<kBlockN>>{},\n                            make_stride(params.seqlen_k_rounded, _1{}));\n\n    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),\n                            typename Kernel_traits::SmemLayoutQ{});\n    // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;\n    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 
0 : size(sQ)),\n                            typename Kernel_traits::SmemLayoutKV{});\n    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});\n    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});\n    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});\n\n    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;\n    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);\n\n    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);\n    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);\n    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K, nblocksN)\n    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);\n    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K, nblocksN)\n    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);\n\n    typename Kernel_traits::TiledMma tiled_mma;\n    auto thr_mma = tiled_mma.get_thread_slice(tidx);\n    Tensor tSrQ  = thr_mma.partition_fragment_A(sQ);                           // (MMA,MMA_M,MMA_K)\n    Tensor tSrK  = thr_mma.partition_fragment_B(sK);                           // (MMA,MMA_N,MMA_K)\n    Tensor tOrVt  = thr_mma.partition_fragment_B(sVtNoSwizzle);                // (MMA, MMA_K,MMA_N)\n\n    Tensor tSgS  = thr_mma.partition_C(gP);\n\n    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K\n\n    //\n    // Copy Atom retiling\n    //\n\n    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);\n    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);\n    // if (cute::thread0()) {smem_thr_copy_Q.print_all();}\n    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);\n    // if (cute::thread0()) {print(tSsQ.layout()); printf(\"\\n\");}\n\n    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);\n 
   auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);\n    Tensor tSsK = smem_thr_copy_K.partition_S(sK);\n\n    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);\n    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);\n    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);\n\n    //\n    // PREDICATES\n    //\n\n    // // Allocate predicate tensors for m and n\n    // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});\n    // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});\n\n    // Construct identity layout for sQ and sK\n    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)\n    // Tensor tScQ = thr_mma.partition_A(cQ);                           // (MMA,MMA_M,MMA_K)\n    // if (cute::thread0()) {\n    //     print(tScQ.layout()); printf(\"\\n\");\n    //     for (int i = 0; i < size(tScQ); ++i) {\n    //         printf(\"%d \", get<0>(tScQ(i)));\n    //     }\n    //     printf(\"\\n\");\n    //     for (int i = 0; i < size(tScQ); ++i) {\n    //         printf(\"%d \", get<1>(tScQ(i)));\n    //     }\n    //     printf(\"\\n\");\n    // }\n\n    // Repeat the partitioning with identity layouts\n    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)\n    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);   // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)\n\n    // Allocate predicate tensors for k\n    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));\n    Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));\n\n    // Set predicates for k bounds\n    if (!Is_even_K) {\n        #pragma unroll\n        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < 
params.d; }\n        #pragma unroll\n        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }\n    }\n\n    // Prologue\n\n    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs\n    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,\n                                       binfo.actual_seqlen_q - m_block * kBlockM);\n    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }\n\n    // // if (cute::thread(1, 0)) { print(tQsQ); }\n    // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});\n    // // if (cute::thread0()) { print(sQNoSwizzle); }\n\n    if (Kernel_traits::Share_Q_K_smem) {\n        FLASH_NAMESPACE::cp_async_wait<0>();\n        __syncthreads();\n        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);\n        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M\n        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);\n        __syncthreads();\n    }\n\n    int n_block = n_block_max - 1;\n    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.\n    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,\n                                       binfo.actual_seqlen_k - n_block * kBlockN);\n    cute::cp_async_fence();\n    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }\n    // __syncthreads();\n\n    if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {\n        FLASH_NAMESPACE::cp_async_wait<1>();\n        __syncthreads();\n        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);\n        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M\n        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);\n    }\n\n    clear(acc_o);\n\n    
FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;\n\n    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;\n    FLASH_NAMESPACE::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);\n\n    // For performance reasons, we separate out two kinds of iterations:\n    // those that need masking on S, and those that don't.\n    // We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.\n    // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.\n    // We will have at least 1 \"masking\" iteration.\n\n    // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to\n    // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.\n    constexpr int n_masking_steps = (!Is_causal && !Is_local)\n        ? 1\n        : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);\n    #pragma unroll\n    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {\n        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)\n        clear(acc_s);\n        FLASH_NAMESPACE::cp_async_wait<0>();\n        __syncthreads();\n\n        // Advance gV\n        if (masking_step > 0) {\n            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);\n        } else {\n            // Clear the smem tiles to account for predicated off loads\n            FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(\n                gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN\n            );\n        }\n        cute::cp_async_fence();\n\n        FLASH_NAMESPACE::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(\n            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,\n            smem_thr_copy_Q, smem_thr_copy_K\n        );\n        // if (cute::thread0()) { print(acc_s); }\n        if constexpr (Is_softcap){\n            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);\n        }\n\n        mask.template apply_mask<Is_causal, Is_even_MN>(\n            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16\n        );\n\n        FLASH_NAMESPACE::cp_async_wait<0>();\n        __syncthreads();\n        if (n_block > n_block_min) {\n            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);\n            // This cp_async_fence needs to be in the if block, otherwise the synchronization\n            // isn't right and we get race conditions.\n            cute::cp_async_fence();\n        }\n\n        // TODO: 
when we have key_padding_mask we'll need to Check_inf\n        masking_step == 0\n            ? softmax.template softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)\n            : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);\n\n        // Convert acc_s from fp32 to fp16/bf16\n        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);\n        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;\n        int block_col_idx = n_block * (kBlockN / 32);\n        if (Return_softmax) {\n            Tensor rP_drop = make_fragment_like(rP);\n            cute::copy(rP, rP_drop);\n            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(\n                rP_drop, block_row_idx, block_col_idx, kNWarps\n            );\n            cute::copy(rP_drop, tSgS);\n            tSgS.data() = tSgS.data() + (-kBlockN);\n        }\n        if (Is_dropout) {\n            dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);\n        }\n\n        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)\n        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.\n        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));\n        // if (cute::thread0()) { print(tOrP); }\n        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);\n        // if (cute::thread0()) { print(scores); }\n\n        // This check is at the end of the loop since we always have at least 1 iteration\n        if (n_masking_steps > 1 && n_block <= n_block_min) {\n            --n_block;\n            break;\n        }\n    }\n\n    // These are the iterations where we don't need masking on S\n    for (; n_block >= n_block_min; --n_block) {\n        Tensor acc_s = 
partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)\n        clear(acc_s);\n        FLASH_NAMESPACE::cp_async_wait<0>();\n        __syncthreads();\n        FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);\n        cute::cp_async_fence();\n\n        FLASH_NAMESPACE::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(\n            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,\n            smem_thr_copy_Q, smem_thr_copy_K\n        );\n        if constexpr (Is_softcap){\n            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);\n        }\n\n        FLASH_NAMESPACE::cp_async_wait<0>();\n        __syncthreads();\n        if (n_block > n_block_min) {\n            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);\n            // This cp_async_fence needs to be in the if block, otherwise the synchronization\n            // isn't right and we get race conditions.\n            cute::cp_async_fence();\n        }\n\n        mask.template apply_mask</*Causal_mask=*/false>(\n            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16\n        );\n\n        softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);\n\n        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);\n        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;\n        int block_col_idx = n_block * (kBlockN / 32);\n        if (Return_softmax) {\n            Tensor rP_drop = make_fragment_like(rP);\n            cute::copy(rP, rP_drop);\n            dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(\n                rP_drop, block_row_idx, block_col_idx, kNWarps\n            );\n            cute::copy(rP_drop, tSgS);\n            
tSgS.data() = tSgS.data() + (-kBlockN);\n        }\n        if (Is_dropout) {\n            dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);\n        }\n\n        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)\n        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.\n        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));\n        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);\n    }\n\n    // Epilogue\n\n    Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);\n\n    // Convert acc_o from fp32 to fp16/bf16\n    Tensor rO = FLASH_NAMESPACE::convert_type<Element>(acc_o);\n    Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)\n    // Partition sO to match the accumulator partitioning\n    auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);\n    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);\n    Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)\n    Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)\n\n    // sO has the same size as sQ, so we don't need to sync here.\n    if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }\n\n    cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);\n\n    Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)\n                                          + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),\n                            make_shape(binfo.actual_seqlen_q, params.h, params.d),\n                            make_stride(params.o_row_stride, params.o_head_stride, _1{}));\n    Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                 
          make_coord(m_block, 0));  // (kBlockM, kHeadDim)\n    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);\n\n    typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;\n    auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);\n    Tensor tOsO = gmem_thr_copy_O.partition_S(sO);        // ((Atom,AtomNum),ATOM_M,ATOM_N)\n    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);\n\n    __syncthreads();\n\n    Tensor tOrO = make_tensor<Element>(shape(tOgO));\n    cute::copy(gmem_tiled_copy_O, tOsO, tOrO);\n\n    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n    Tensor taccOcO = thr_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)\n    static_assert(decltype(size<0>(taccOcO))::value == 4);\n    // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.\n    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);\n    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M\n    if (get<1>(taccOcO_row(0)) == 0) {\n        #pragma unroll\n        for (int mi = 0; mi < size(lse); ++mi) {\n            const int row = get<0>(taccOcO_row(mi));\n            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }\n        }\n    }\n\n    // Construct identity layout for sO\n    Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n    // Repeat the partitioning with identity layouts\n    Tensor tOcO = gmem_thr_copy_O.partition_D(cO);                           // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)\n    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));\n    if (!Is_even_K) {\n        #pragma unroll\n        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }\n    }\n    // Clear_OOB_K must be false since we don't want to write zeros to 
gmem\n    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM\n    );\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>\ninline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {\n\n    using Element = typename Kernel_traits::Element;\n    using ElementAccum = typename Kernel_traits::ElementAccum;\n    using index_t = typename Kernel_traits::index_t;\n\n    // Shared memory.\n    extern __shared__ char smem_[];\n\n    // The thread index.\n    const int tidx = threadIdx.x;\n\n    constexpr int kBlockM = Kernel_traits::kBlockM;\n    constexpr int kBlockN = Kernel_traits::kBlockN;\n    constexpr int kHeadDim = Kernel_traits::kHeadDim;\n    constexpr int kNWarps = Kernel_traits::kNWarps;\n\n    using GmemTiledCopyO = std::conditional_t<\n        !Split,\n        typename Kernel_traits::GmemTiledCopyO,\n        typename Kernel_traits::GmemTiledCopyOaccum\n    >;\n    using ElementO = std::conditional_t<!Split, Element, ElementAccum>;\n\n    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);\n    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf(\"Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\\n\", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }\n    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf(\"params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\\n\", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 
0 : params.seqlen_knew)); }\n    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;\n\n    const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;\n    const int n_block_min = !Is_local\n        ? n_split_idx * n_blocks_per_split\n        : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);\n    int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);\n    if (Is_causal || Is_local) {\n        n_block_max = std::min(n_block_max,\n                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));\n    }\n    if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0\n        // We exit early and write 0 to gOaccum and -inf to gLSEaccum.\n        // Otherwise we might read OOB elements from gK and gV,\n        // or get wrong results when we combine gOaccum from different blocks.\n        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)\n            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;\n        const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q\n            + m_block * kBlockM) * params.d_rounded;\n        const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;\n        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),\n                                      Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                                     make_stride(Split ? 
kHeadDim : params.o_row_stride, _1{}));\n        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),\n                                      Shape<Int<kBlockM>>{}, Stride<_1>{});\n\n        GmemTiledCopyO gmem_tiled_copy_Oaccum;\n        auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);\n        Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);\n        Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));\n        clear(tOrOaccum);\n        // Construct identity layout for sO\n        Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n        // Repeat the partitioning with identity layouts\n        Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);\n        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));\n        if (!Is_even_K) {\n            #pragma unroll\n            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }\n        }\n        // Clear_OOB_K must be false since we don't want to write zeros to gmem\n        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n            gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM\n        );\n        #pragma unroll\n        for (int m = 0; m < size<1>(tOgOaccum); ++m) {\n            const int row = get<0>(tOcO(0, m, 0));\n            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }\n        }\n        return;\n    }\n\n    // We iterate over the blocks in reverse order. This is because the last block is the only one\n    // that needs masking when we read K and V from global memory. 
Moreover, iterating in reverse\n    // might save us 1 register (we just need n_block instead of both n_block and n_block_max).\n\n    // We move K and V to the last block.\n    const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];\n    const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;\n    const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;\n    const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;\n    const index_t row_offset_k = block_table == nullptr\n        ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)\n          + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride\n        : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;\n    const index_t row_offset_v = block_table == nullptr\n        ? 
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)\n          + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride\n        : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;\n\n    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),\n                            make_shape(binfo.actual_seqlen_q, params.h, params.d),\n                            make_stride(params.q_row_stride, params.q_head_stride, _1{}));\n    Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)\n    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),\n                            Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                            make_stride(params.k_row_stride, _1{}));\n    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf(\"k_ptr = %p, row_offset_k = %d, gK_ptr = %p\\n\", params.k_ptr, row_offset_k, gK.data()); }\n    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),\n                            Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                            make_stride(params.v_row_stride, _1{}));\n\n    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),\n                            typename Kernel_traits::SmemLayoutQ{});\n    Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});\n    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});\n    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});\n    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename 
Kernel_traits::SmemLayoutVtransposedNoSwizzle{});\n\n    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;\n    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);\n\n    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);\n    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);\n    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)\n    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);\n    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)\n    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);\n\n    typename Kernel_traits::TiledMma tiled_mma;\n    auto thr_mma = tiled_mma.get_thread_slice(tidx);\n    Tensor tSrQ  = thr_mma.partition_fragment_A(sQ);                           // (MMA,MMA_M,MMA_K)\n    Tensor tSrK  = thr_mma.partition_fragment_B(sK);                           // (MMA,MMA_N,MMA_K)\n    Tensor tOrVt  = thr_mma.partition_fragment_B(sVtNoSwizzle);                // (MMA, MMA_K,MMA_N)\n\n    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K\n\n    //\n    // Copy Atom retiling\n    //\n\n    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);\n    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);\n    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);\n\n    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);\n    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);\n    Tensor tSsK = smem_thr_copy_K.partition_S(sK);\n\n    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);\n    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);\n    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);\n\n    // PREDICATES\n    //\n\n    // // Allocate predicate tensors for m and n\n    // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), 
Stride<_1,_0>{});\n    // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});\n\n    // Construct identity layout for sQ and sK\n    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)\n\n    // Repeat the partitioning with identity layouts\n    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)\n    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);   // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)\n\n    // Allocate predicate tensors for k\n    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));\n    Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));\n\n    // Set predicates for k bounds\n    if (!Is_even_K) {\n        #pragma unroll\n        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }\n        #pragma unroll\n        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }\n    }\n\n    // Prologue\n\n    // Copy from Knew to K, optionally apply rotary embedding.\n    typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;\n    auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);\n    typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;\n    auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);\n    if constexpr (Append_KV) {\n        // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to\n        // gmem. 
Technically it's a race condition, but they all write the same content anyway, and it's safe.\n        // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.\n        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);\n        Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),\n                                  Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},\n                                  make_stride(params.rotary_dim / 2, _1{}));\n        Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),\n                                  Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},\n                                  make_stride(params.rotary_dim / 2, _1{}));\n        Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),\n                                      Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                                      make_stride(params.rotary_dim / 2, _1{}));\n        Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),\n                                      Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                                      make_stride(params.rotary_dim / 2, _1{}));\n        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);\n        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);\n        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);\n        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);\n        // if (cute::thread(0, 0)) { printf(\"rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\\n\", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }\n        // if 
(cute::thread(8, 0)) { print_tensor(gCos); }\n        // if (cute::thread(0, 0)) { print_tensor(tRgCos); }\n\n        // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)\n        const index_t row_offset_knew = bidb * params.knew_batch_stride\n            + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;\n        // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)\n        const index_t row_offset_vnew = bidb * params.vnew_batch_stride\n            + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;\n        // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew \"line up\". When we access them,\n        // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].\n        // This maps to accessing the first 64 rows of knew_ptr.\n        Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)\n                                                + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),\n                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                                  make_stride(params.knew_row_stride, _1{}));\n        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf(\"knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\\n\", params.knew_ptr, row_offset_knew, gKnew.data()); }\n        Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)\n                                                + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),\n                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},\n                                  make_stride(params.vnew_row_stride, _1{}));\n        Tensor tKgKnew = 
gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)\n        Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)\n\n        const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);\n        auto tKgK_data = tKgK.data();\n        auto tVgV_data = tVgV.data();\n        for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {\n            FLASH_NAMESPACE::copy_w_min_idx<Is_even_K>(\n                tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN\n            );\n            tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));\n            if (params.rotary_dim == 0) {\n                FLASH_NAMESPACE::copy_w_min_idx<Is_even_K>(\n                    tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN\n                );\n            } else {\n                if (params.is_rotary_interleaved) {\n                    // Don't clear OOB_K because we're writing to global memory\n                    FLASH_NAMESPACE::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(\n                        tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,\n                        binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim\n                    );\n                    tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));\n                    tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));\n                } else {\n                    // Don't clear OOB_K because we're writing to global memory\n                    FLASH_NAMESPACE::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(\n                        tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,\n                        
binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim\n                    );\n                    tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));\n                    tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));\n\n                }\n            }\n            tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));\n            if (block_table == nullptr) {\n                tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));\n                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));\n            } else {\n                if (n_block > n_block_copy_min) {\n                    const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;\n                    const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;\n                    const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;\n                    const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;\n                    const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];\n                    const int offset_diff = block_table_offset_next - block_table_offset_cur;\n                    tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;\n                    tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;\n                }\n            }\n        }\n        // Need this before we can read in K again, so that we'll see the updated K values.\n        __syncthreads();\n        tKgK.data() = tKgK_data;\n        tVgV.data() = tVgV_data;\n    }\n\n    // Read Q from gmem to smem, optionally apply rotary embedding.\n    if (!Append_KV || params.rotary_dim == 0) {\n        // We don't need to 
clear the sQ smem tiles since we'll only write out the valid outputs\n        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,\n                                           binfo.actual_seqlen_q - m_block * kBlockM);\n    } else {\n        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);\n        // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.\n        // We do this by setting the row stride of gCos / gSin to 0.\n        Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),\n                                  Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},\n                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));\n        Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),\n                                  Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},\n                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));\n        Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),\n                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));\n        Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),\n                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                                  make_stride(Is_causal || Is_local ? 
params.rotary_dim / 2 : 0, _1{}));\n        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);\n        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);\n        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);\n        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);\n        if (params.is_rotary_interleaved) {\n            FLASH_NAMESPACE::copy_rotary_interleaved<Is_even_K>(\n                tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,\n                0, params.d, params.rotary_dim\n            );\n        } else {\n            FLASH_NAMESPACE::copy_rotary_contiguous<Is_even_K>(\n                tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,\n                0, params.d, params.rotary_dim\n            );\n        }\n    }\n\n    int n_block = n_block_max - 1;\n    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.\n    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,\n                                       binfo.actual_seqlen_k - n_block * kBlockN);\n    cute::cp_async_fence();\n\n    // FLASH_NAMESPACE::cp_async_wait<0>();\n    // __syncthreads();\n    // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }\n    // __syncthreads();\n\n    clear(acc_o);\n\n    FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;\n\n    const float alibi_slope = !Has_alibi ? 
0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;\n    FLASH_NAMESPACE::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);\n\n    // For performance reasons, we separate out two kinds of iterations:\n    // those that need masking on S, and those that don't.\n    // We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.\n    // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.\n    // We will have at least 1 \"masking\" iteration.\n\n    // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to\n    // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.\n    constexpr int n_masking_steps = (!Is_causal && !Is_local)\n        ? 1\n        : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);\n    #pragma unroll\n    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {\n        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)\n        clear(acc_s);\n        FLASH_NAMESPACE::cp_async_wait<0>();\n        __syncthreads();\n\n        // Advance gV\n        if (masking_step > 0) {\n            if (block_table == nullptr) {\n                tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));\n            } else {\n                const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;\n                const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;\n                const int block_table_idx_next = n_block * kBlockN / params.page_block_size;\n                const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * 
params.page_block_size;\n                tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;\n            }\n            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);\n        } else {\n            // Clear the smem tiles to account for predicated off loads\n            FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(\n                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN\n            );\n        }\n        cute::cp_async_fence();\n\n        FLASH_NAMESPACE::gemm(\n            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,\n            smem_thr_copy_Q, smem_thr_copy_K\n        );\n        // if (cute::thread0()) { print(acc_s); }\n        if constexpr (Is_softcap){\n            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);\n        }\n\n\n        mask.template apply_mask<Is_causal, Is_even_MN>(\n            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16\n        );\n\n        FLASH_NAMESPACE::cp_async_wait<0>();\n        __syncthreads();\n        // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }\n        // __syncthreads();\n\n        if (n_block > n_block_min) {\n            // Advance gK\n            if (block_table == nullptr) {\n                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));\n            } else {\n                const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;\n                const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;\n                const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;\n                const int 
block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;\n                tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;\n            }\n            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);\n            // This cp_async_fence needs to be in the if block, otherwise the synchronization\n            // isn't right and we get race conditions.\n            cute::cp_async_fence();\n        }\n\n        // We have key_padding_mask so we'll need to Check_inf\n        masking_step == 0\n            ? softmax.template softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2)\n            : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);\n        // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }\n\n        // Convert acc_s from fp32 to fp16/bf16\n        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);\n        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)\n        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.\n        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));\n\n        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);\n\n        // This check is at the end of the loop since we always have at least 1 iteration\n        if (n_masking_steps > 1 && n_block <= n_block_min) {\n            --n_block;\n            break;\n        }\n    }\n\n    // These are the iterations where we don't need masking on S\n    for (; n_block >= n_block_min; 
--n_block) {\n        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)\n        clear(acc_s);\n        FLASH_NAMESPACE::cp_async_wait<0>();\n        __syncthreads();\n        // Advance gV\n        if (block_table == nullptr) {\n            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));\n        } else {\n            const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;\n            const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;\n            const int block_table_idx_next = n_block * kBlockN / params.page_block_size;\n            const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;\n            tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;\n        }\n        FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);\n        cute::cp_async_fence();\n\n        FLASH_NAMESPACE::gemm(\n            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,\n            smem_thr_copy_Q, smem_thr_copy_K\n        );\n        if constexpr (Is_softcap){\n            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);\n        }\n\n        FLASH_NAMESPACE::cp_async_wait<0>();\n        __syncthreads();\n        if (n_block > n_block_min) {\n            // Advance gK\n            if (block_table == nullptr) {\n                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));\n            } else {\n                const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;\n                const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;\n                const 
int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;\n                const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;\n                tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;\n            }\n            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);\n            // This cp_async_fence needs to be in the if block, otherwise the synchronization\n            // isn't right and we get race conditions.\n            cute::cp_async_fence();\n        }\n\n        mask.template apply_mask</*Causal_mask=*/false>(\n            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16\n        );\n        softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);\n\n        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);\n        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)\n        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.\n        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));\n\n        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);\n    }\n\n    // Epilogue\n\n    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);\n    // if (cute::thread0()) { print(lse); }\n\n    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)\n    // Partition sO to match the accumulator partitioning\n    using SmemTiledCopyO = std::conditional_t<\n        !Split,\n        
typename Kernel_traits::SmemCopyAtomO,\n        typename Kernel_traits::SmemCopyAtomOaccum\n    >;\n    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);\n    auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);\n    Tensor rO = FLASH_NAMESPACE::convert_type<ElementO>(acc_o);\n    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)\n    Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);     // ((Atom,AtomNum),PIPE_M,PIPE_N)\n\n    // sOaccum is larger than sQ, so we need to syncthreads here\n    // TODO: allocate enough smem for sOaccum\n    if constexpr (Split) { __syncthreads(); }\n\n    cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);\n\n    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)\n        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;\n    const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q\n                                         + m_block * kBlockM) * params.d_rounded;\n    const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?\n            ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)\n        ) + m_block * kBlockM;\n\n    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),\n                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                                 make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));\n    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? 
params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),\n                                   Shape<Int<kBlockM>>{}, Stride<_1>{});\n    // if (tidx == 0) { printf(\"row_offset_o = %d, bidh = %d, gOaccum = %p\\n\", row_offset_o, bidh, gOaccum.data()); }\n\n    GmemTiledCopyO gmem_tiled_copy_Oaccum;\n    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);\n    Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);        // ((Atom,AtomNum),ATOM_M,ATOM_N)\n    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);\n\n    __syncthreads();\n\n    Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));\n    cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);\n\n    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n    Tensor taccOcO = thr_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)\n    static_assert(decltype(size<0>(taccOcO))::value == 4);\n    // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.\n    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);\n    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M\n    if (get<1>(taccOcO_row(0)) == 0) {\n        #pragma unroll\n        for (int mi = 0; mi < size(lse); ++mi) {\n            const int row = get<0>(taccOcO_row(mi));\n            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }\n        }\n    }\n\n    // Construct identity layout for sO\n    Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n    // Repeat the partitioning with identity layouts\n    Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);                           // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)\n    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));\n    if (!Is_even_K) {\n        #pragma 
unroll\n        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }\n    }\n    // Clear_OOB_K must be false since we don't want to write zeros to gmem\n    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n        gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM\n    );\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>\ninline __device__ void compute_attn(const Params &params) {\n    const int m_block = blockIdx.x;\n    // The block index for the batch.\n    const int bidb = blockIdx.y;\n    // The block index for the head.\n    const int bidh = blockIdx.z;\n\n    // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting\n    // them to have the same number of threads or have to traverse the attention matrix\n    // in the same order.\n    // In the Philox RNG, we use the offset to store the batch, head, and the lane id\n    // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within\n    // the attention matrix. 
This way, as long as we have the batch, head, and the location of\n    // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.\n\n    FLASH_NAMESPACE::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>\ninline __device__ void compute_attn_splitkv(const Params &params) {\n    const int m_block = blockIdx.x;\n    // The block index for the batch.\n    const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;\n    // The block index for the head.\n    const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;\n    const int n_split_idx = Split ? blockIdx.y : 0;\n    const int num_n_splits = Split ? 
gridDim.y : 1;\n    FLASH_NAMESPACE::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>\ninline __device__ void combine_attn_seqk_parallel(const Params &params) {\n    using Element = typename Kernel_traits::Element;\n    using ElementAccum = typename Kernel_traits::ElementAccum;\n    using index_t = typename Kernel_traits::index_t;\n    constexpr int kMaxSplits = 1 << Log_max_splits;\n    constexpr int kHeadDim = Kernel_traits::kHeadDim;\n    constexpr int kNThreads = Kernel_traits::kNThreads;\n\n    static_assert(kMaxSplits <= 128, \"kMaxSplits must be <= 128\");\n    static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, \"kBlockM must be 4, 8, 16 or 32\");\n    static_assert(kNThreads == 128, \"We assume that each block has 128 threads\");\n\n    // Shared memory.\n    // kBlockM + 1 instead of kBlockM to reduce bank conflicts.\n    __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1];\n\n    // The thread and block index.\n    const int tidx = threadIdx.x;\n    const int bidx = blockIdx.x;\n\n    const index_t lse_size = params.b * params.h * params.seqlen_q;\n\n    const index_t row_offset_lse = bidx * kBlockM;\n    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),\n                                   Shape<Int<kMaxSplits>, Int<kBlockM>>{},\n                                   make_stride(lse_size, _1{}));\n\n    // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.\n    // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.\n    Tensor 
gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),\n                              Shape<Int<kBlockM>>{}, Stride<_1>{});\n\n    // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.\n    Layout flat_layout = make_layout(lse_size);\n    Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));\n    auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);\n    Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);\n    Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));\n\n    Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);\n\n    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;\n\n    // Read the LSE values from gmem and store them in shared memory, then transpose them.\n    constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;\n    #pragma unroll\n    for (int l = 0; l < kNLsePerThread; ++l) {\n        const int row = l * kRowsPerLoadLSE + tidx / kBlockM;\n        const int col = tidx % kBlockM;\n        ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? 
gLSEaccum(row, col) : -INFINITY;\n        if (row < kMaxSplits) { sLSE[row][col] = lse; }\n        // if (bidx == 0 && tidx < 32) { printf(\"tidx = %d, row = %d, col = %d, lse = %f\\n\", tidx, row, col, lse); }\n    }\n    // if (bidx == 1 && tidx < 32) { printf(\"tidx = %d, row_offset_lse = %d, lse = %f\\n\", tidx, row_offset_lse, lse_accum(0)); }\n    __syncthreads();\n    Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});\n    constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);\n    // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits\n    // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,\n    // kBlockM rows, so each time we load we can load 128 / kBlockM rows).\n    // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;\n    // static_assert(kThreadsPerSplit <= 32);\n    static_assert(kRowsPerLoadTranspose <= 32);\n    static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);\n    #pragma unroll\n    for (int l = 0; l < kNLsePerThread; ++l) {\n        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;\n        const int col = tidx / kRowsPerLoadTranspose;\n        lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;\n        // if (bidx == 0 && tidx < 32) { printf(\"tidx = %d, row = %d, col = %d, lse = %f\\n\", tidx, row, col, lse_accum(l)); }\n    }\n\n    // Compute the logsumexp of the LSE along the split dimension.\n    ElementAccum lse_max = lse_accum(0);\n    #pragma unroll\n    for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }\n    MaxOp<float> max_op;\n    lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);\n    lse_max = lse_max == -INFINITY ? 
0.0f : lse_max;  // In case all local LSEs are -inf\n    float lse_sum = expf(lse_accum(0) - lse_max);\n    #pragma unroll\n    for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }\n    SumOp<float> sum_op;\n    lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);\n    // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise\n    // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.\n    ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;\n    // if (bidx == 0 && tidx < 32) { printf(\"tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\\n\", tidx, lse_accum(0), lse_max, lse_logsum); }\n    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {\n        if (params.unpadded_lse) {\n            const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;\n            if (lse_offset < lse_size) {\n                gLSE_unpadded(lse_offset) = lse_logsum;\n            }\n        } else {\n            gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;\n        }\n    }\n    // Store the scales exp(lse - lse_logsum) in shared memory.\n    #pragma unroll\n    for (int l = 0; l < kNLsePerThread; ++l) {\n        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;\n        const int col = tidx / kRowsPerLoadTranspose;\n        if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); }\n    }\n    __syncthreads();\n\n    const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;\n    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),\n                                 Shape<Int<kBlockM>, Int<kHeadDim>>{},\n                                 Stride<Int<kHeadDim>, _1>{});\n    constexpr int kBlockN = kNThreads / kBlockM;\n    using 
GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;\n    using GmemTiledCopyOaccum = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},\n                        GmemLayoutAtomOaccum{},\n                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store\n    GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;\n    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);\n    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);\n    Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));\n    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));\n    clear(tOrO);\n\n    // Predicates\n    Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});\n    // Repeat the partitioning with identity layouts\n    Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);\n    Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));\n    if (!Is_even_K) {\n        #pragma unroll\n        for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }\n    }\n    // Load Oaccum in then scale and accumulate to O\n    for (int split = 0; split < params.num_splits; ++split) {\n        FLASH_NAMESPACE::copy</*Is_even_MN=*/false, Is_even_K>(\n            gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM\n        );\n        #pragma unroll\n        for (int m = 0; m < size<1>(tOrOaccum); ++m) {\n            int row = get<0>(tOcOaccum(0, m, 0));\n            ElementAccum lse_scale = sLSE[split][row];\n            #pragma unroll\n            for (int k = 0; k < size<2>(tOrOaccum); ++k) {\n                #pragma unroll\n                for (int i = 0; i < size<0>(tOrOaccum); ++i) {\n                    tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);\n                }\n            }\n        // if 
(cute::thread0()) { printf(\"lse_scale = %f, %f\\n\", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); }\n        }\n        tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;\n    }\n    // if (cute::thread0()) { print_tensor(tOrO); }\n\n    Tensor rO = FLASH_NAMESPACE::convert_type<Element>(tOrO);\n    // Write to gO\n    #pragma unroll\n    for (int m = 0; m < size<1>(rO); ++m) {\n        const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));\n        if (idx < params.b * params.h * params.seqlen_q) {\n            const int batch_idx = idx / (params.h * params.seqlen_q);\n            const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;\n            // The index to the rows of Q\n            const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;\n            auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride\n                + head_idx * params.o_head_stride + row * params.o_row_stride;\n            #pragma unroll\n            for (int k = 0; k < size<2>(rO); ++k) {\n                if (Is_even_K || tOpOaccum(k)) {\n                    const int col = get<1>(tOcOaccum(0, m, k));\n                    Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),\n                                            Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});\n                    // TODO: Should check if this is using vectorized store, but it seems pretty fast\n                    copy(rO(_, m, k), gO);\n                    // if (bidx == 0 && tidx == 0) { printf(\"tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\\n\", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }\n                    // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k);\n                }\n            }\n        }\n    }\n}\n\n} // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_launch_template.h",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n#include \"namespace_config.h\"\n#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK\n\n#include \"static_switch.h\"\n#include \"hardware_info.h\"\n#include \"flash.h\"\n#include \"flash_fwd_kernel.h\"\n\nnamespace FLASH_NAMESPACE {\n\n// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n#define ARCH_SUPPORTS_FLASH\n#define KERNEL_PARAM_MODIFIER __grid_constant__\n#else\n#define KERNEL_PARAM_MODIFIER\n#endif\n\n// Define a macro for unsupported architecture handling to centralize the error message\n#define FLASH_UNSUPPORTED_ARCH printf(\"FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!\");\n\n// Use a macro to clean up kernel definitions\n#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) 
\\\ntemplate<typename Kernel_traits, __VA_ARGS__> \\\n__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)\n\nDEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {\n    #if defined(ARCH_SUPPORTS_FLASH)\n        static_assert(!(Is_causal && Is_local)); // Enforce constraints\n        FLASH_NAMESPACE::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);\n    #else\n        FLASH_UNSUPPORTED_ARCH\n    #endif\n}\n\nDEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {\n    #if defined(ARCH_SUPPORTS_FLASH)\n        FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);\n    #else\n        FLASH_UNSUPPORTED_ARCH\n    #endif\n}\n\nDEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {\n    static_assert(Log_max_splits >= 1);\n    FLASH_NAMESPACE::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);\n}\n\ntemplate<typename Kernel_traits, bool Is_dropout, bool Is_causal>\nvoid run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {\n    constexpr size_t smem_size = Kernel_traits::kSmemSize;\n    // printf(\"smem_size = %d\\n\", smem_size);\n\n    // Work-around for gcc 7. 
It doesn't like nested BOOL_SWITCH.\n    // https://github.com/kokkos/kokkos-kernels/issues/349\n    // https://github.com/HazyResearch/flash-attention/issues/21\n\n    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;\n    dim3 grid(num_m_block, params.b, params.h);\n    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;\n    const bool is_even_K = params.d == Kernel_traits::kHeadDim;\n    const bool return_softmax = params.p_ptr != nullptr;\n    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {\n        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {\n            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {\n                BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {\n                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {\n                        SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {\n                            // Will only return softmax if dropout, to reduce compilation time.\n                            // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.\n                            // If return_softmax, set IsEvenMNConst to false to reduce number of templates\n                            // If head dim > 128, set IsEvenMNConst to false to reduce number of templates\n                            // If Is_local, set Is_causal to false\n                            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !Has_alibi && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst && !Has_alibi, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;\n                            
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;\n                            // printf(\"IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\\n\", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));\n                            // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;\n                            if (smem_size >= 48 * 1024) {\n                                C10_CUDA_CHECK(cudaFuncSetAttribute(\n                                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n                            }\n                            // int ctas_per_sm;\n                            // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n                            //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);\n                            // printf(\"smem_size = %d, CTAs per SM = %d\\n\", int(smem_size), ctas_per_sm);\n                            kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);\n                            C10_CUDA_KERNEL_LAUNCH_CHECK();\n                        });\n                    });\n                });\n            });\n        });\n    });\n}\n\ntemplate<typename Kernel_traits, bool Is_causal>\nvoid run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {\n    static_assert(!Kernel_traits::Is_Q_in_regs, \"SplitKV implementation does not support Is_Q_in_regs\");\n    static_assert(!Kernel_traits::Share_Q_K_smem, \"SplitKV implementation does not support Share_Q_K_smem\");\n    constexpr size_t smem_size = Kernel_traits::kSmemSize;\n    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;\n    dim3 grid(num_m_block, params.num_splits > 1 ? 
params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);\n    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;\n    const bool is_even_K = params.d == Kernel_traits::kHeadDim;\n    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {\n        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {\n            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {\n                BOOL_SWITCH(params.num_splits > 1, Split, [&] {\n                    BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {\n                        ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {\n                            SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {\n                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.\n                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.\n                                // If Is_local, set Is_causal to false\n                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && !Has_alibi && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !Has_alibi, Is_softcap, Split, Append_KV>;\n                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;\n                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;\n                                if (smem_size >= 48 * 1024) {\n                                    C10_CUDA_CHECK(cudaFuncSetAttribute(\n                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 
smem_size));\n                                }\n                                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);\n                                C10_CUDA_KERNEL_LAUNCH_CHECK();\n                            });\n                        });\n                    });\n                });\n            });\n        });\n    });\n    if (params.num_splits > 1) {\n        // We want kBlockM to be as small as possible for more parallelism.\n        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.\n        // If headdim is divisible by 64, then we set kBlockM = 8, etc.\n        constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);\n        dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);\n        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {\n            if (params.num_splits <= 2) {\n                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);\n            } else if (params.num_splits <= 4) {\n                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);\n            } else if (params.num_splits <= 8) {\n                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);\n            } else if (params.num_splits <= 16) {\n                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);\n            } else if (params.num_splits <= 32) {\n                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);\n            } else if (params.num_splits <= 
64) {\n                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);\n            } else if (params.num_splits <= 128) {\n                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);\n            }\n            C10_CUDA_KERNEL_LAUNCH_CHECK();\n        });\n    }\n}\n\ntemplate<typename T, int Headdim, bool Is_causal>\nvoid run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {\n    constexpr static int kBlockM = 64;  // Fixed for all head dimensions\n    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,\n    // and for headdim 192 with block size 64 x 128.\n    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);\n    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);\n}\n\ntemplate<typename T, bool Is_causal>\nvoid run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {\n    constexpr static int Headdim = 32;\n    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {\n        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n    });\n}\n\ntemplate<typename T, bool Is_causal>\nvoid run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {\n    constexpr static int Headdim = 64;\n    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {\n        if constexpr(!Is_dropout) {\n            // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower\n            // Using block size (64 x 256) is 27% slower for seqlen=2k\n            // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling\n            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, 
stream);\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);\n        } else {\n            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n        }\n    });\n}\n\ntemplate<typename T, bool Is_causal>\nvoid run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {\n    constexpr static int Headdim = 96;\n    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());\n    bool is_sm8x = cc_major == 8 && cc_minor > 0;\n    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {\n        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),\n        if (is_sm8x) {\n            if constexpr(!Is_causal) {\n                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n            } else {\n                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n            }\n        } else {\n            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n        }\n        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);\n        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, 
true, true, T>, Is_dropout, Is_causal>(params, stream);\n        // These two are always slower\n        // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);\n        // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);\n    });\n}\n\ntemplate<typename T, bool Is_causal>\nvoid run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {\n    constexpr static int Headdim = 128;\n    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());\n    bool is_sm8x = cc_major == 8 && cc_minor > 0;\n    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {\n        if constexpr(!Is_dropout) {\n            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),\n            // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.\n            if (is_sm8x) {\n                if constexpr(!Is_causal) {\n                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n                } else {\n                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n                }\n            } else {\n                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n            }\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n            // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, 
false, T>, Is_dropout, Is_causal>(params, stream);\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);\n            // 1st ones are good for H100, A100\n            // 2nd one is good for A6000 bc we get slightly better occupancy\n        } else {\n            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);\n            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);\n        }\n    });\n}\n\ntemplate<typename T, bool Is_causal>\nvoid run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {\n    constexpr static int Headdim = 192;\n    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {\n        if constexpr(!Is_dropout) {\n            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);\n        } else {\n            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n        }\n        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);\n        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);\n        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);\n        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);\n    
});\n}\n\ntemplate<typename T, bool Is_causal>\nvoid run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {\n    constexpr static int Headdim = 256;\n    int device;\n    C10_CUDA_CHECK(cudaGetDevice(&device));\n    int max_smem_per_sm, max_smem_per_block;\n    // Check each query on its own so a failure in the first call is not overwritten by the second.\n    C10_CUDA_CHECK(cudaDeviceGetAttribute(\n        &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device));\n    C10_CUDA_CHECK(cudaDeviceGetAttribute(\n        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));\n    // printf(\"max_smem_per_sm = %d, max_smem_per_block = %d\\n\", max_smem_per_sm, max_smem_per_block);\n    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {\n        // For A100, we want to run with 128 x 64 (128KB smem).\n        // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.\n        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {\n            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);\n        } else {\n            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n        }\n        // 64 KB\n        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);\n        // 96 KB\n        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);\n    });\n}\n}  // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n#include \"namespace_config.h\"\n#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {\n\ntemplate void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);\n\n} // namespace FLASH_NAMESPACE"
  },
  {
    "path": "csrc/flash_attn/src/generate_kernels.py",
    "content": "import argparse\nimport itertools\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import List, Optional\n\nDTYPE_MAP = {\n    \"fp16\": \"cutlass::half_t\",\n    \"bf16\": \"cutlass::bfloat16_t\",\n}\n\nSM = [80]  # Sm80 kernels support up to\nHEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256]\nIS_CAUSAL = [\"false\", \"true\"]\nNAMESPACE_INCLUDE = '#include \"namespace_config.h\"\\n'\n\ndef get_fwd_template() -> str:\n    return NAMESPACE_INCLUDE + \"\"\"#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {{\n\ntemplate<>\nvoid run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{\n    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);\n}}\n\n}} // namespace FLASH_NAMESPACE\"\"\"\n\ndef get_fwd_split_template() -> str:\n    return NAMESPACE_INCLUDE + \"\"\"#include \"flash_fwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {{\n\ntemplate void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);\n\n}} // namespace FLASH_NAMESPACE\"\"\"\n\ndef get_bwd_template() -> str:\n    return NAMESPACE_INCLUDE + \"\"\"#include \"flash_bwd_launch_template.h\"\n\nnamespace FLASH_NAMESPACE {{\n\ntemplate<>\nvoid run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params &params, cudaStream_t stream) {{\n    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);\n}}\n\n}} // namespace FLASH_NAMESPACE\"\"\"\n\n@dataclass\nclass Kernel:\n    sm: int\n    dtype: str\n    head_dim: int\n    is_causal: bool\n    direction: str\n\n    @property\n    def template(self) -> str:\n        template_funcs = {\n            \"fwd\": get_fwd_template,\n            \"bwd\": get_bwd_template,\n            \"fwd_split\": get_fwd_split_template\n        }\n        template_func = template_funcs[self.direction]\n        return template_func().format(\n            DTYPE=DTYPE_MAP[self.dtype],\n            
HEAD_DIM=self.head_dim,\n            IS_CAUSAL=self.is_causal\n        )\n\n    @property\n    def filename(self) -> str:\n        return f\"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu\"\n\ndef get_all_kernels() -> List[Kernel]:\n    for direction in [\"fwd\", \"fwd_split\", \"bwd\"]:\n        for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):\n            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)\n\ndef write_kernel(kernel: Kernel, autogen_dir: Path) -> None:\n    prelude = \"\"\"// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\\n\"\"\"\n    content = prelude + kernel.template\n    (autogen_dir / kernel.filename).write_text(content)\n\ndef main(output_dir: Optional[str]) -> None:\n    if output_dir is None:\n        output_dir = Path(__file__).parent\n    else:\n        output_dir = Path(output_dir)\n\n    for kernel in get_all_kernels():\n        write_kernel(kernel, output_dir)\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        prog=\"generate_kernels\",\n        description=\"Generate the flash_attention kernels template instantiations\",\n    )\n    parser.add_argument(\n        \"-o\",\n        \"--output_dir\",\n        required=False,\n        help=\"Where to generate the kernels; \"\n        \"defaults to the directory containing this script\",\n    )\n    args = parser.parse_args()\n    main(args.output_dir)\n"
  },
  {
    "path": "csrc/flash_attn/src/hardware_info.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <tuple>\n\n#if !defined(__CUDACC_RTC__)\n#include \"cuda_runtime.h\"\n#endif\n\n#define CHECK_CUDA(call)                                                       \\\n  do {                                                                         \\\n    cudaError_t status_ = call;                                                \\\n    if (status_ != cudaSuccess) {                                              \\\n      fprintf(stderr, \"CUDA error (%s:%d): %s\\n\", __FILE__, __LINE__,          \\\n              cudaGetErrorString(status_));                                    \\\n      exit(1);                                                                 \\\n    }                                                                          \\\n  } while (0)\n\n\ninline int get_current_device() {\n    int device;\n    CHECK_CUDA(cudaGetDevice(&device));\n    return device;\n}\n\ninline std::tuple<int, int> get_compute_capability(int device) {\n    int capability_major, capability_minor;\n    CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device));\n    CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device));\n    return {capability_major, capability_minor};\n}\n\ninline int get_num_sm(int device) {\n    int multiprocessor_count;\n    CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));\n    return multiprocessor_count;\n}\n"
  },
  {
    "path": "csrc/flash_attn/src/kernel_traits.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cute/tensor.hpp\"\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/layout/layout.h\"\n#include <cutlass/numeric_types.h>\n\nusing namespace cute;\n\ntemplate<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>\nstruct Flash_kernel_traits {\n\n#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 800\n    using Element = elem_type;\n    static constexpr bool Has_cp_async = true;\n#else\n    using Element = cutlass::half_t;\n    static constexpr bool Has_cp_async = false;\n#endif\n\n    using ElementAccum = float;\n    using index_t = int64_t;\n\n#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 800\n    using MMA_Atom_Arch = std::conditional_t<\n        std::is_same_v<elem_type, cutlass::half_t>,\n        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,\n        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>\n    >;\n#else\n    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;\n#endif\n\n#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 750\n    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;\n    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;\n#else\n    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;\n    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;\n#endif\n};\n\n// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true\ntemplate<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,\n         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >\nstruct Flash_fwd_kernel_traits : public Base {\n    using Element = typename Base::Element;\n    using ElementAccum = typename Base::ElementAccum;\n    
using index_t = typename Base::index_t;\n    static constexpr bool Has_cp_async = Base::Has_cp_async;\n    using SmemCopyAtom = typename Base::SmemCopyAtom;\n    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;\n\n    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;\n    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;\n\n    // The number of threads.\n    static constexpr int kNWarps = kNWarps_;\n    static constexpr int kNThreads = kNWarps * 32;\n\n    static constexpr int kBlockM = kBlockM_;\n    static constexpr int kBlockN = kBlockN_;\n    static constexpr int kHeadDim = kHeadDim_;\n    static_assert(kHeadDim % 32 == 0);\n    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;\n    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);\n    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;\n\n    using TiledMma = TiledMMA<\n        typename Base::MMA_Atom_Arch,\n        Layout<Shape<Int<kNWarps>,_1,_1>>,  // 4x1x1 or 8x1x1 thread group\n        Tile<Int<16 * kNWarps>, _16, _16>>;\n\n    using SmemLayoutAtomQ = decltype(\n        composition(Swizzle<kSwizzle, 3, 3>{},\n                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128\n                    Layout<Shape<_8, Int<kBlockKSmem>>,\n                           Stride<Int<kBlockKSmem>, _1>>{}));\n    using SmemLayoutQ = decltype(tile_to_shape(\n        SmemLayoutAtomQ{},\n        Shape<Int<kBlockM>, Int<kHeadDim>>{}));\n\n    using SmemLayoutKV = decltype(tile_to_shape(\n        SmemLayoutAtomQ{},\n        Shape<Int<kBlockN>, Int<kHeadDim>>{}));\n\n    // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434\n    using SmemLayoutVtransposed = decltype(\n        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));\n    using SmemLayoutVtransposedNoSwizzle 
= decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));\n\n    using SmemLayoutAtomO = decltype(\n        composition(Swizzle<kSwizzle, 3, 3>{},\n                    Layout<Shape<Int<8>, Int<kBlockKSmem>>,\n                           Stride<Int<kBlockKSmem>, _1>>{}));\n    using SmemLayoutO = decltype(tile_to_shape(\n        SmemLayoutAtomO{},\n        Shape<Int<kBlockM>, Int<kHeadDim>>{}));\n    using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;\n    using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;\n\n    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);\n    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);\n    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;\n\n    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);\n    static_assert(kHeadDim % kGmemElemsPerLoad == 0, \"kHeadDim must be a multiple of kGmemElemsPerLoad\");\n    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.\n    // For example, for d=128, smem is split into 2 \"pages\", each page takes care of columns\n    // 0-63 and 64-127. 
If we have 16 threads per row for gmem read, when we write to smem,\n    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,\n    // to the same banks.\n    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;\n    static_assert(kNThreads % kGmemThreadsPerRow == 0, \"kNThreads must be a multiple of kGmemThreadsPerRow\");\n    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                                  Stride<Int<kGmemThreadsPerRow>, _1>>;\n\n    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading\n    // from the same address by the same threadblock. This is slightly faster.\n    using Gmem_copy_struct = std::conditional_t<\n        Has_cp_async,\n        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,\n        AutoVectorizingCopyWithAssumedAlignment<128>\n    >;\n    using GmemTiledCopyQKV = decltype(\n        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read\n    using GmemTiledCopyO = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store\n\n    using GmemLayoutAtomOaccum = std::conditional_t<\n        kBlockKSmem == 32,\n        Layout<Shape <_16, _8>,  // Thread layout, 8 threads per row\n               Stride< _8, _1>>,\n        Layout<Shape <_8, _16>,  // Thread layout, 16 threads per row\n               Stride< _16, _1>>\n    >;\n    using GmemTiledCopyOaccum = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},\n                        GmemLayoutAtomOaccum{},\n                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per 
store\n    using GmemLayoutAtomRotcossin = GmemLayoutAtom;\n    using GmemTiledCopyRotcossin = decltype(\n        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},\n                        GmemLayoutAtomRotcossin{},\n                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per load\n    using GmemTiledCopyRotcossinCont = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        GmemLayoutAtomRotcossin{},\n                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per load\n};\n\n// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.\n// No_double_buffer is another option to reduce smem usage, but will slow things down.\ntemplate<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,\n         int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,\n         bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,\n         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >\nstruct Flash_bwd_kernel_traits : public Base {\n    using Element = typename Base::Element;\n    using ElementAccum = typename Base::ElementAccum;\n    using index_t = typename Base::index_t;\n    static constexpr bool Has_cp_async = Base::Has_cp_async;\n    using SmemCopyAtom = typename Base::SmemCopyAtom;\n    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;\n\n    static constexpr bool Is_V_in_regs = Is_V_in_regs_;\n    static constexpr bool No_double_buffer = No_double_buffer_;\n\n    // The number of threads.\n    static constexpr int kNWarps = kNWarps_;\n    static constexpr int kNThreads = kNWarps * 32;\n\n    static constexpr int kBlockM = kBlockM_;\n    static constexpr int kBlockN = kBlockN_;\n    static constexpr int kHeadDim = kHeadDim_;\n    static_assert(kHeadDim % 32 == 0);\n    static constexpr int kBlockKSmem = kHeadDim % 64 == 
0 ? 64 : 32;\n    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);\n    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;\n\n    static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;\n    static_assert(kNWarps % AtomLayoutMSdP == 0);\n    static_assert(kNWarps % AtomLayoutNdKV == 0);\n    static_assert(kNWarps % AtomLayoutMdQ == 0);\n\n    using TiledMmaSdP = TiledMMA<\n        typename Base::MMA_Atom_Arch,\n        Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,\n        Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;\n\n    using TiledMmadKV = TiledMMA<\n        typename Base::MMA_Atom_Arch,\n        Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,\n        Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;\n\n    using TiledMmadQ = TiledMMA<\n        typename Base::MMA_Atom_Arch,\n        Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>,  // 2x4x1 or 4x2x1 thread group\n        Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;\n\n    using SmemLayoutAtomQdO = decltype(\n        composition(Swizzle<kSwizzle, 3, 3>{},\n                    Layout<Shape<_8, Int<kBlockKSmem>>,\n                           Stride<Int<kBlockKSmem>, _1>>{}));\n    using SmemLayoutQdO = decltype(tile_to_shape(\n        SmemLayoutAtomQdO{},\n        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));\n\n    using SmemLayoutAtomKV = decltype(\n        composition(Swizzle<kSwizzle, 3, 3>{},\n                    Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,\n                           Stride<Int<kBlockKSmem>, _1>>{}));\n    using SmemLayoutKV = decltype(tile_to_shape(\n        // SmemLayoutAtomQdO{},\n        SmemLayoutAtomKV{},\n        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));\n\n    using SmemLayoutKtransposed = decltype(\n        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, 
Int<kBlockN>>{}, GenRowMajor{})));\n    using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));\n\n    // TODO: generalize to other values of kBlockN\n    // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2\n    // static constexpr int kPBlockN = kBlockN;\n    // Temporarily disabling this for hdim 256 on sm86 and sm89\n    // static_assert(kBlockN >= 64);\n    static_assert(kBlockN >= 32);\n    // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.\n    static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;\n    static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);\n    // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);\n    static constexpr int kSwizzlePdS = 3;\n    using SmemLayoutAtomPdS = decltype(\n        composition(Swizzle<kSwizzlePdS, 3, 3>{},\n                    Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,\n                           Stride<Int<kPBlockN>, _1>>{}));\n    using SmemLayoutPdS = decltype(tile_to_shape(\n        SmemLayoutAtomPdS{},\n        make_shape(Int<kBlockM>{}, Int<kBlockN>{})));\n    using SmemLayoutPdStransposed = decltype(\n        composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));\n    using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));\n\n    using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;\n\n    using SmemLayoutQdOtransposed = decltype(\n        composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));\n    using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));\n\n    using SmemLayoutAtomdKV = decltype(\n        composition(Swizzle<kSwizzle, 3, 3>{},\n                    Layout<Shape<_8, Int<kBlockKSmem>>,\n                           Stride<Int<kBlockKSmem>, 
_1>>{}));\n    using SmemLayoutdKV = decltype(tile_to_shape(\n        SmemLayoutAtomdKV{},\n        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));\n    using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;\n\n    using SmemLayoutAtomdQ = decltype(\n        composition(Swizzle<kSwizzle, 3, 3>{},\n                    Layout<Shape<_8, Int<kBlockKSmem>>,\n                           Stride<Int<kBlockKSmem>, _1>>{}));\n    using SmemLayoutdQ = decltype(tile_to_shape(\n        SmemLayoutAtomdQ{},\n        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));\n    using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;\n\n    // Double buffer for sQ\n    static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);\n    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);\n    static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);\n    static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);\n    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);\n    static constexpr int kSmemSize = kSmemQdOSize\n        + (!Is_V_in_regs\n           ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)\n           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));\n    static constexpr int kSmemSize1colblock = kSmemQdOSize\n        + (!Is_V_in_regs\n           ? 
kSmemKVSize + kSmemdSSize + kSmemPSize\n           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));\n\n    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);\n    static_assert(kHeadDim % kGmemElemsPerLoad == 0, \"kHeadDim must be a multiple of kGmemElemsPerLoad\");\n    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem\n    // to affect speed in practice.\n    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;\n    static_assert(kNThreads % kGmemThreadsPerRow == 0, \"kNThreads must be a multiple of kGmemThreadsPerRow\");\n    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                                  Stride<Int<kGmemThreadsPerRow>, _1>>;\n\n    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading\n    // from the same address by the same threadblock. This is slightly faster.\n    using Gmem_copy_struct = std::conditional_t<\n        Has_cp_async,\n        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,\n        AutoVectorizingCopyWithAssumedAlignment<128>\n    >;\n    using GmemTiledCopyQKV = decltype(\n        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read\n    using GmemTiledCopydO = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store\n    using GmemTiledCopydKV = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store\n    using GmemTiledCopydQ = decltype(\n        
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store\n    using GmemLayoutAtomdQaccum = std::conditional_t<\n        kBlockKSmem == 32,\n        Layout<Shape <_32, _8>,  // Thread layout, 8 threads per row\n               Stride< _8, _1>>,\n        Layout<Shape <_16, _16>,  // Thread layout, 16 threads per row\n               Stride< _16, _1>>\n    >;\n    using GmemTiledCopydQaccum = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},\n                        GmemLayoutAtomdQaccum{},\n                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store\n\n    using GmemTiledCopydQaccumAtomicAdd = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},\n                        Layout<Shape <_8, _32>,  // Thread layout, 8 threads per row\n                               Stride<_32, _1>>{},\n                        Layout<Shape < _1, _1>>{}));  // Val layout, 1 val per store\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "csrc/flash_attn/src/mask.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n#include \"namespace_config.h\"\n\n#include <cute/tensor.hpp>\n\nnamespace FLASH_NAMESPACE {\n\nusing namespace cute;\n\ntemplate <typename Engine, typename Layout>\n__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,\n                                  const int col_idx_offset_ = 0) {\n    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))\n    static_assert(Layout::rank == 2, \"Only support 2D Tensor\");\n    const int lane_id = threadIdx.x % 32;\n    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;\n    #pragma unroll\n    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {\n        const int col_idx_base = col_idx_offset + nj * 8;\n        #pragma unroll\n        for (int j = 0; j < size<1, 0>(tensor); ++j) {\n            const int col_idx = col_idx_base + j;\n            if (col_idx >= max_seqlen_k) {\n                // Without the \"make_coord\" we get wrong results\n                #pragma unroll\n                for (int mi = 0; mi < size<0>(tensor); ++mi) {\n                    tensor(mi, make_coord(j, nj)) = -INFINITY;\n                }\n            }\n        }\n    }\n}\n\ntemplate <bool HasWSLeft=true, typename Engine, typename Layout>\n__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,\n                                        const int max_seqlen_k, const int row_idx_offset,\n                                        const int max_seqlen_q, const int warp_row_stride,\n                                        const int window_size_left, const int window_size_right) {\n    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))\n    static_assert(Layout::rank == 2, \"Only support 2D Tensor\");\n    const int 
lane_id = threadIdx.x % 32;\n    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;\n    #pragma unroll\n    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {\n        const int row_idx_base = row_idx_offset + mi * warp_row_stride;\n        #pragma unroll\n        for (int i = 0; i < size<0, 0>(tensor); ++i) {\n            const int row_idx = row_idx_base + i * 8;\n            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);\n            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);\n            #pragma unroll\n            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {\n                const int col_idx_base = col_idx_offset + nj * 8;\n                #pragma unroll\n                for (int j = 0; j < size<1, 0>(tensor); ++j) {\n                    const int col_idx = col_idx_base + j;\n                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {\n                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;\n                    }\n                }\n            }\n            // if (cute::thread0()) {\n            //     printf(\"mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\\n\", mi, i, row_idx, max_seqlen_k);\n            //     print(tensor(make_coord(i, mi), _));\n            //     // print(tensor(_, j + nj * size<1, 0>(tensor)));\n            // }\n        }\n    }\n}\n\ntemplate <typename Engine, typename Layout>\n__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,\n                                         const int max_seqlen_k, const int row_idx_offset,\n                                         const int max_seqlen_q, const int warp_row_stride) {\n    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0\n    
apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,\n                                          max_seqlen_q, warp_row_stride, -1, 0);\n}\n\ntemplate <typename Engine0, typename Layout0, typename Engine1, typename Layout1>\n__forceinline__ __device__ void apply_mask_causal_w_idx(\n    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,\n    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)\n{\n    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))\n    static_assert(Layout0::rank == 2, \"Only support 2D Tensor\");\n    static_assert(Layout1::rank == 2, \"Only support 2D Tensor\");\n    CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));\n    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));\n    #pragma unroll\n    for (int mi = 0; mi < size<0>(tensor); ++mi) {\n        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));\n        #pragma unroll\n        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {\n            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {\n                tensor(mi, ni) = -INFINITY;\n            }\n        }\n        // if (cute::thread0()) {\n        //     printf(\"ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\\n\", ni, j, col_idx, max_seqlen_k);\n        //     print(tensor(_, make_coord(j, ni)));\n        //     // print(tensor(_, j + ni * size<1, 0>(tensor)));\n        // }\n    }\n}\n\ntemplate <bool Is_causal, bool Is_local, bool Has_alibi>\nstruct Mask {\n\n    const int max_seqlen_k, max_seqlen_q;\n    const int window_size_left, window_size_right;\n    const float alibi_slope;\n\n    __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,\n                                    const int window_size_left, const int window_size_right,\n                                    const float alibi_slope=0.f)\n        : 
max_seqlen_k(max_seqlen_k)\n        , max_seqlen_q(max_seqlen_q)\n        , window_size_left(window_size_left)\n        , window_size_right(window_size_right)\n        , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {\n    };\n\n    // Causal_mask: whether this particular iteration needs causal masking\n    template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>\n    __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,\n                                               const int col_idx_offset_,\n                                               const int row_idx_offset,\n                                               const int warp_row_stride) {\n        static_assert(!(Causal_mask && Is_local), \"Cannot be both causal and local\");\n        static_assert(Layout::rank == 3, \"Only support 3D Tensor\");\n        static_assert(decltype(size<0>(tensor_))::value == 4, \"First dimension must be 4\");\n        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;\n        // if (cute::thread0()) { printf(\"Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\\n\", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }\n        if constexpr (Need_masking) {\n            // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))\n            Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout()));\n            // Do we need both row and column indices, or just column indices?\n            static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;\n            const int lane_id = threadIdx.x % 32;\n            const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;\n            if constexpr (Col_idx_only) {\n                #pragma unroll\n                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {\n                    const int 
col_idx_base = col_idx_offset + nj * 8;\n                    #pragma unroll\n                    for (int j = 0; j < size<1, 0>(tensor); ++j) {\n                        const int col_idx = col_idx_base + j;\n                        #pragma unroll\n                        for (int mi = 0; mi < size<0>(tensor); ++mi) {\n                            // No causal, no local\n                            if constexpr (Has_alibi) {\n                                tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;\n                            }\n                            if constexpr (!Is_even_MN) {\n                                if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }\n                            }\n                        }\n                    }\n                }\n            } else {\n                #pragma unroll\n                for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {\n                    const int row_idx_base = row_idx_offset + mi * warp_row_stride;\n                    #pragma unroll\n                    for (int i = 0; i < size<0, 0>(tensor); ++i) {\n                        const int row_idx = row_idx_base + i * 8;\n                        const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);\n                        const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);\n                        #pragma unroll\n                        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {\n                            const int col_idx_base = col_idx_offset + nj * 8;\n                            #pragma unroll\n                            for (int j = 0; j < size<1, 0>(tensor); ++j) {\n                                const int col_idx = col_idx_base + j;\n                                if constexpr (Has_alibi) {\n                                    if constexpr (Is_causal) {\n                            
            tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;\n                                    } else {\n                                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);\n\n                                    }\n                                }\n                                if constexpr (Causal_mask) {\n                                    if (col_idx >= col_idx_limit_right) {\n                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;\n                                    }\n                                }\n                                if constexpr (Is_local) {\n                                    if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {\n                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;\n                                    }\n                                }\n                                if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {\n                                    // Causal and Local already handle MN masking\n                                    if (col_idx >= max_seqlen_k) {\n                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;\n                                    }\n                                }\n                            }\n                        }\n                    }\n                }\n            }\n        }\n    };\n\n};\n\n} // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn/src/namespace_config.h",
    "content": "/**\n * @file flash_namespace_config.h\n * @brief Configuration file for Flash namespace management and isolation\n *\n * This header provides configuration macros for managing the Flash namespace\n * across a codebase. It allows for flexible namespace naming and provides\n * utilities for namespace declaration and scoping.\n *\n * Usage Examples:\n *\n * 1. Basic namespace wrapping:\n * @code\n *   BEGIN_FLASH_NAMESPACE\n *   class FlashDevice {\n *     // Implementation\n *   };\n *   END_FLASH_NAMESPACE\n * @endcode\n *\n * 2. Accessing types within the namespace:\n * @code\n *   FLASH_NAMESPACE_ALIAS(FlashDevice) device;\n * @endcode\n *\n * 3. Defining content within namespace scope:\n * @code\n *   FLASH_NAMESPACE_SCOPE(\n *     struct Configuration {\n *       uint32_t size;\n *       bool enabled;\n *     };\n *   )\n * @endcode\n *\n * 4. Custom namespace name:\n * @code\n *   #define FLASH_NAMESPACE custom_flash\n *   #include \"flash_namespace_config.h\"\n * @endcode\n *\n * Configuration:\n * - The default namespace is 'flash' if FLASH_NAMESPACE is not defined\n * - Define FLASH_NAMESPACE before including this header to customize the\n * namespace name\n *\n * Best Practices:\n * - Include this header in all files that need access to the Flash namespace\n *\n */\n#pragma once\n\n#ifndef FLASH_NAMESPACE_CONFIG_H\n#define FLASH_NAMESPACE_CONFIG_H\n\n// Set default namespace to flash\n#ifndef FLASH_NAMESPACE\n#define FLASH_NAMESPACE flash\n#endif\n\n#define FLASH_NAMESPACE_ALIAS(name) FLASH_NAMESPACE::name\n\n#define FLASH_NAMESPACE_SCOPE(content)                                         \\\n  namespace FLASH_NAMESPACE {                                                  \\\n  content                                                                      \\\n  }\n\n#endif // FLASH_NAMESPACE_CONFIG_H\n"
  },
  {
    "path": "csrc/flash_attn/src/philox.cuh",
    "content": "// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h\n#pragma once\n// Philox CUDA.\n\n#include \"namespace_config.h\"\n\nnamespace FLASH_NAMESPACE {\n\nstruct ull2 {\n    unsigned long long x;\n    unsigned long long y;\n};\n\n__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {\n    uint2 *res;\n    unsigned long long tmp;\n    asm (\"mul.wide.u32 %0, %1, %2;\\n\\t\"\n          : \"=l\"(tmp)\n          : \"r\"(a), \"r\"(b));\n    res = (uint2*)(&tmp);\n    return *res;\n}\n\n__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {\n    constexpr unsigned long kPhiloxSA = 0xD2511F53;\n    constexpr unsigned long kPhiloxSB = 0xCD9E8D57;\n    uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);\n    uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);\n    uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};\n    return ret;\n}\n\n__forceinline__ __device__ uint4 philox(unsigned long long seed,\n                               unsigned long long subsequence,\n                               unsigned long long offset) {\n    constexpr unsigned long kPhilox10A = 0x9E3779B9;\n    constexpr unsigned long kPhilox10B = 0xBB67AE85;\n    uint2 key = reinterpret_cast<uint2&>(seed);\n    uint4 counter;\n    ull2 *tmp = reinterpret_cast<ull2*>(&counter);\n    tmp->x = offset;\n    tmp->y = subsequence;\n    #pragma unroll\n    for (int i = 0; i < 6; i++) {\n        counter = philox_single_round(counter, key);\n        key.x += (kPhilox10A);\n        key.y += (kPhilox10B);\n    }\n    uint4 output = philox_single_round(counter, key);\n    return output;\n}\n\n} // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn/src/philox_unpack.cuh",
    "content": "// This is purely so that it works with torch 2.1. For torch 2.2+ we can include ATen/cuda/PhiloxUtils.cuh\n\n#pragma once\n#include <ATen/cuda/detail/UnpackRaw.cuh>\n"
  },
  {
    "path": "csrc/flash_attn/src/rotary.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cute/tensor.hpp>\n\n#include \"namespace_config.h\"\n#include \"utils.h\"\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nnamespace FLASH_NAMESPACE {\n\nusing namespace cute;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool Is_even_K=true, bool Clear_OOB_K=true,\n          typename Engine0, typename Layout0, typename Engine1, typename Layout1,\n          typename Engine2, typename Layout2, typename Engine3, typename Layout3>\n__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,\n                                               Tensor<Engine1, Layout1> &D,\n                                               Tensor<Engine2, Layout2> const &Cos,\n                                               Tensor<Engine2, Layout2> const &Sin,\n                                               Tensor<Engine3, Layout3> const &identity_MN,\n                                               const int max_MN, const int min_MN,\n                                               const int dim, const int rotary_dim) {\n    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});\n    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});\n    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA\n    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M\n    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K\n    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                     // MMA_M\n    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                     // MMA_K\n    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                     // MMA_M\n   
 CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                     // MMA_K\n    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));                     // MMA_K\n    static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);\n    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32\n    Tensor rCos = make_fragment_like(Cos);\n    Tensor rSin = make_fragment_like(Sin);\n    Tensor rS = make_fragment_like(S);\n    #pragma unroll\n    for (int m = 0; m < size<1>(S); ++m) {\n        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {\n            #pragma unroll\n            for (int k = 0; k < size<2>(S); ++k) {\n                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {\n                    cute::copy(S(_, m, k), rS(_, m, k));\n                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {\n                        cute::copy(Cos(_, m, k), rCos(_, m, k));\n                        cute::copy(Sin(_, m, k), rSin(_, m, k));\n                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));\n                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));\n                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));\n                        #pragma unroll\n                        for (int i = 0; i < size<0>(rS) / 2; ++i) {\n                            float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);\n                            float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);\n                            S_fp32(2 * i) = real;\n                            S_fp32(2 * i + 1) = imag;\n                        }\n                        // Idk but I need to copy for the convert_type to work\n                        Tensor S_fp32_copy = make_fragment_like(S_fp32);\n                        cute::copy(S_fp32, S_fp32_copy);\n                        using T = 
typename Engine0::value_type;\n                        Tensor S_og_type = convert_type<T>(S_fp32_copy);\n                        cute::copy(S_og_type, rS(_, m, k));\n                    }\n                    cute::copy(rS(_, m, k), D(_, m, k));\n                } else if (Clear_OOB_K) {\n                    cute::clear(D(_, m, k));\n                }\n            }\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool Is_even_K=true, bool Clear_OOB_K=true,\n          typename Engine0, typename Layout0, typename Engine1, typename Layout1,\n          typename Engine2, typename Layout2, typename Engine3, typename Layout3>\n__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,\n                                              Tensor<Engine1, Layout1> &D,\n                                              Tensor<Engine2, Layout2> const &Cos,\n                                              Tensor<Engine2, Layout2> const &Sin,\n                                              Tensor<Engine3, Layout3> const &identity_MN,\n                                              const int max_MN, const int min_MN,\n                                              const int dim, const int rotary_dim) {\n    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});\n    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});\n    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA\n    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M\n    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K\n    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                     // MMA_M\n    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                     // MMA_K\n    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                     // MMA_M\n    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                     // MMA_K\n    
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos));                     // MMA\n    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));\n    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32\n    Tensor rCos = make_fragment_like(Cos);\n    Tensor rSin = make_fragment_like(Sin);\n    Tensor rS = make_fragment_like(S);\n    Tensor rS_other = make_fragment_like(rS(_, 0, 0));\n    #pragma unroll\n    for (int m = 0; m < size<1>(S); ++m) {\n        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {\n            #pragma unroll\n            for (int k = 0; k < size<2>(S); ++k) {\n                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {\n                    cute::copy(S(_, m, k), rS(_, m, k));\n                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {\n                        const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;\n                        Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());\n                        cute::copy(gS_other, rS_other);\n                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }\n                        Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());\n                        Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 
0 : -rotary_dim / 2), Sin(_, m, k).layout());\n                        cute::copy(gCos, rCos(_, m, k));\n                        cute::copy(gSin, rSin(_, m, k));\n                        // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }\n                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));\n                        Tensor S_other_fp32 = convert_type<float>(rS_other);\n                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));\n                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));\n                        #pragma unroll\n                        for (int i = 0; i < size<0>(rS); ++i) {\n                            S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));\n                        }\n                        // Idk but I need to copy for the convert_type to work\n                        Tensor S_fp32_copy = make_fragment_like(S_fp32);\n                        cute::copy(S_fp32, S_fp32_copy);\n                        using T = typename Engine0::value_type;\n                        Tensor S_og_type = convert_type<T>(S_fp32_copy);\n                        cute::copy(S_og_type, rS(_, m, k));\n                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); }\n                    }\n                    cute::copy(rS(_, m, k), D(_, m, k));\n                } else if (Clear_OOB_K) {\n                    cute::clear(D(_, m, k));\n                }\n            }\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn/src/softmax.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cmath>\n\n#include <cute/tensor.hpp>\n\n#include <cutlass/numeric_types.h>\n\n#include \"namespace_config.h\"\n#include \"philox.cuh\"\n#include \"utils.h\"\n\nnamespace FLASH_NAMESPACE {\n\nusing namespace cute;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>\n__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {\n    static_assert(Layout0::rank == 2, \"Only support 2D Tensor\");\n    static_assert(Layout1::rank == 1, \"Only support 1D Tensor\");\n    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));\n    #pragma unroll\n    for (int mi = 0; mi < size<0>(tensor); mi++) {\n        summary(mi) = zero_init ? 
tensor(mi, 0) : op(summary(mi), tensor(mi, 0));\n        #pragma unroll\n        for (int ni = 1; ni < size<1>(tensor); ni++) {\n            summary(mi) = op(summary(mi), tensor(mi, ni));\n        }\n    }\n}\n\ntemplate<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>\n__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {\n    CUTE_STATIC_ASSERT_V(size(dst) == size(src));\n    #pragma unroll\n    for (int i = 0; i < size(dst); i++){\n        dst(i) = Allreduce<4>::run(src(i), op);\n    }\n}\n\ntemplate<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>\n__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {\n    thread_reduce_<zero_init>(tensor, summary, op);\n    quad_allreduce_(summary, summary, op);\n}\n\ntemplate<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>\n__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){\n    MaxOp<float> max_op;\n    reduce_<zero_init>(tensor, max, max_op);\n}\n\ntemplate<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>\n__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){\n    SumOp<float> sum_op;\n    thread_reduce_<zero_init>(tensor, sum, sum_op);\n}\n\n// Apply the exp to all the elements.\ntemplate <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>\n__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {\n    static_assert(Layout0::rank == 2, \"Only support 2D Tensor\");\n    static_assert(Layout1::rank == 1, \"Only support 1D Tensor\");\n    
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));\n    #pragma unroll\n    for (int mi = 0; mi < size<0>(tensor); ++mi) {\n        // If max is -inf, then all elements must have been -inf (possibly due to masking).\n        // We don't want (-inf - (-inf)) since that would give NaN.\n        // If we don't have float around M_LOG2E the multiplication is done in fp64.\n        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));\n        #pragma unroll\n        for (int ni = 0; ni < size<1>(tensor); ++ni)  {\n            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -\n            // max * log_2(e)). This allows the compiler to use the ffma\n            // instruction instead of fadd and fmul separately.\n            // The following macro will disable the use of fma.\n            // See: https://github.com/pytorch/pytorch/issues/121558 for more details\n            // This macro is set in PyTorch and not FlashAttention\n            #ifdef UNFUSE_FMA\n                tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);\n            #else\n                tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);\n            #endif\n        }\n    }\n}\n\n// Compute the row max, apply exp2 to all the elements, and compute the row sum.\ntemplate <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>\n__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {\n    static_assert(Layout0::rank == 2, \"Only support 2D Tensor\");\n    static_assert(Layout1::rank == 1, \"Only support 1D Tensor\");\n    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));\n    #pragma unroll\n    for (int mi = 0; mi < size<0>(tensor); ++mi) {\n        MaxOp<float> max_op;\n        max(mi) = zero_init ? 
tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));\n        #pragma unroll\n        for (int ni = 1; ni < size<1>(tensor); ni++) {\n            max(mi) = max_op(max(mi), tensor(mi, ni));\n        }\n        max(mi) = Allreduce<4>::run(max(mi), max_op);\n        // If max is -inf, then all elements must have been -inf (possibly due to masking).\n        // We don't want (-inf - (-inf)) since that would give NaN.\n        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;\n        sum(mi) = 0;\n        #pragma unroll\n        for (int ni = 0; ni < size<1>(tensor); ++ni)  {\n            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -\n            // max * log_2(e)) This allows the compiler to use the ffma\n            // instruction instead of fadd and fmul separately.\n            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);\n            sum(mi) += tensor(mi, ni);\n        }\n        SumOp<float> sum_op;\n        sum(mi) = Allreduce<4>::run(sum(mi), sum_op);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int kNRows>\nstruct Softmax {\n\n    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));\n    TensorT row_max, row_sum;\n\n    __forceinline__ __device__ Softmax() {};\n\n    template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>\n    __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {\n        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))\n        Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout()));\n        static_assert(decltype(size<0>(scores))::value == kNRows);\n        if (Is_first) {\n            FLASH_NAMESPACE::template reduce_max</*zero_init=*/true>(scores, row_max);\n            FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, 
softmax_scale_log2);\n            FLASH_NAMESPACE::reduce_sum</*zero_init=*/true>(scores, row_sum);\n        } else {\n            Tensor scores_max_prev = make_fragment_like(row_max);\n            cute::copy(row_max, scores_max_prev);\n            FLASH_NAMESPACE::template reduce_max</*zero_init=*/false>(scores, row_max);\n            // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))\n            Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout()));\n            static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);\n            #pragma unroll\n            for (int mi = 0; mi < size(row_max); ++mi) {\n                float scores_max_cur = !Check_inf\n                    ? row_max(mi)\n                    : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));\n                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);\n                row_sum(mi) *= scores_scale;\n                #pragma unroll\n                for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }\n            }\n            FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2);\n            // We don't do the reduce across threads here since we don't need to use the row_sum.\n            // We do that reduce at the end when we need to normalize the softmax.\n            FLASH_NAMESPACE::reduce_sum</*zero_init=*/false>(scores, row_sum);\n        }\n    };\n\n    template<bool Is_dropout=false, bool Split=false, typename Tensor0>\n    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {\n        SumOp<float> sum_op;\n        quad_allreduce_(row_sum, row_sum, sum_op);\n        TensorT lse = make_fragment_like(row_sum);\n        Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout()));\n        
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);\n        #pragma unroll\n        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {\n            float sum = row_sum(mi);\n            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;\n            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);\n            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;\n            #pragma unroll\n            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }\n        }\n        return lse;\n    };\n};\n\n}  // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn/src/static_switch.h",
    "content": "// Inspired by\n// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h\n// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h\n\n#pragma once\n\n/// @param COND       - a boolean expression to switch by\n/// @param CONST_NAME - a name given for the constexpr bool variable.\n/// @param ...       - code to execute for true and false\n///\n/// Usage:\n/// ```\n/// BOOL_SWITCH(flag, BoolConst, [&] {\n///     some_function<BoolConst>(...);\n/// });\n/// ```\n\n#define BOOL_SWITCH(COND, CONST_NAME, ...)      \\\n  [&] {                                         \\\n    if (COND) {                                 \\\n      constexpr static bool CONST_NAME = true;  \\\n      return __VA_ARGS__();                     \\\n    } else {                                    \\\n      constexpr static bool CONST_NAME = false; \\\n      return __VA_ARGS__();                     \\\n    }                                           \\\n  }()\n\n#ifdef FLASHATTENTION_DISABLE_DROPOUT\n  #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \\\n  [&] {                                         \\\n    constexpr static bool CONST_NAME = false;   \\\n    return __VA_ARGS__();                       \\\n  }()\n#else\n  #define DROPOUT_SWITCH BOOL_SWITCH\n#endif\n\n#ifdef FLASHATTENTION_DISABLE_ALIBI\n  #define ALIBI_SWITCH(COND, CONST_NAME, ...)   \\\n  [&] {                                         \\\n    constexpr static bool CONST_NAME = false;   \\\n    return __VA_ARGS__();                       \\\n  }()\n#else\n  #define ALIBI_SWITCH BOOL_SWITCH\n#endif\n\n#ifdef FLASHATTENTION_DISABLE_UNEVEN_K\n  #define EVENK_SWITCH(COND, CONST_NAME, ...)   
\\\n  [&] {                                         \\\n    constexpr static bool CONST_NAME = true;    \\\n    return __VA_ARGS__();                       \\\n  }()\n#else\n  #define EVENK_SWITCH BOOL_SWITCH\n#endif\n\n#ifdef FLASHATTENTION_DISABLE_SOFTCAP\n  #define SOFTCAP_SWITCH(COND, CONST_NAME, ...)   \\\n  [&] {                                         \\\n    constexpr static bool CONST_NAME = false;    \\\n    return __VA_ARGS__();                       \\\n  }()\n#else\n  #define SOFTCAP_SWITCH BOOL_SWITCH\n#endif\n\n#ifdef FLASHATTENTION_DISABLE_LOCAL\n  #define LOCAL_SWITCH(COND, CONST_NAME, ...)   \\\n  [&] {                                         \\\n    constexpr static bool CONST_NAME = false;    \\\n    return __VA_ARGS__();                       \\\n  }()\n#else\n  #define LOCAL_SWITCH BOOL_SWITCH\n#endif\n\n#define FP16_SWITCH(COND, ...)               \\\n  [&] {                                      \\\n    if (COND) {                              \\\n      using elem_type = cutlass::half_t;     \\\n      return __VA_ARGS__();                  \\\n    } else {                                 \\\n      using elem_type = cutlass::bfloat16_t; \\\n      return __VA_ARGS__();                  \\\n    }                                        \\\n  }()\n\n#define HEADDIM_SWITCH(HEADDIM, ...)   
\\\n  [&] {                                    \\\n    if (HEADDIM <= 32) {                   \\\n      constexpr static int kHeadDim = 32;  \\\n      return __VA_ARGS__();                \\\n    } else if (HEADDIM <= 64) {            \\\n      constexpr static int kHeadDim = 64;  \\\n      return __VA_ARGS__();                \\\n    } else if (HEADDIM <= 96) {            \\\n      constexpr static int kHeadDim = 96;  \\\n      return __VA_ARGS__();                \\\n    } else if (HEADDIM <= 128) {           \\\n      constexpr static int kHeadDim = 128; \\\n      return __VA_ARGS__();                \\\n    } else if (HEADDIM <= 192) {           \\\n      constexpr static int kHeadDim = 192; \\\n      return __VA_ARGS__();                \\\n    } else if (HEADDIM <= 256) {           \\\n      constexpr static int kHeadDim = 256; \\\n      return __VA_ARGS__();                \\\n    }                                      \\\n  }()\n"
  },
  {
    "path": "csrc/flash_attn/src/utils.h",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <assert.h>\n#include <stdint.h>\n#include <stdlib.h>\n\n#include <cuda_fp16.h>\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n#include <cuda_bf16.h>\n#endif\n\n#include <cute/tensor.hpp>\n\n#include <cutlass/array.h>\n#include <cutlass/cutlass.h>\n#include <cutlass/numeric_conversion.h>\n#include <cutlass/numeric_types.h>\n\n#include \"namespace_config.h\"\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nnamespace FLASH_NAMESPACE {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T>\n__forceinline__ __device__ uint32_t relu2(const uint32_t x);\n\ntemplate<>\n__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {\n    uint32_t res;\n    const uint32_t zero = 0u;\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n    asm volatile(\"max.f16x2 %0, %1, %2;\\n\" : \"=r\"(res) : \"r\"(x), \"r\"(zero));\n#else\n    asm volatile( \\\n        \"{\\n\" \\\n        \"\\t .reg .f16x2 sela;\\n\" \\\n        \"\\t set.gtu.u32.f16x2 sela, %1, %2;\\n\" \\\n        \"\\t and.b32 %0, sela, %1;\\n\" \n        \"}\\n\" : \"=r\"(res) : \"r\"(x), \"r\"(zero));\n#endif\n    return res;\n}\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\ntemplate<>\n__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {\n    uint32_t res;\n    const uint32_t zero = 0u;\n    asm volatile(\"max.bf16x2 %0, %1, %2;\\n\" : \"=r\"(res) : \"r\"(x), \"r\"(zero));\n    return res;\n}\n#endif\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n\ntemplate<typename 
T>\n__forceinline__ __device__ uint32_t convert_relu2(const float2 x);\n\ntemplate<>\n__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {\n    uint32_t res;\n    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);\n    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);\n    asm volatile(\"cvt.rn.relu.f16x2.f32 %0, %1, %2;\\n\" : \"=r\"(res) : \"r\"(b), \"r\"(a));\n    return res;\n}\n\ntemplate<>\n__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {\n    uint32_t res;\n    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);\n    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);\n    asm volatile(\"cvt.rn.relu.bf16x2.f32 %0, %1, %2;\\n\" : \"=r\"(res) : \"r\"(b), \"r\"(a));\n    return res;\n}\n\n#endif\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T>\nstruct MaxOp {\n__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? 
x : y; }\n};\n\ntemplate <>\nstruct MaxOp<float> {\n// This is slightly faster\n__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T>\nstruct SumOp {\n__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int THREADS>\nstruct Allreduce {\n    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);\n    template<typename T, typename Operator>\n    static __device__ __forceinline__ T run(T x, Operator &op) {\n        constexpr int OFFSET = THREADS / 2;\n        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));\n        return Allreduce<OFFSET>::run(x, op);\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Allreduce<2> {\ntemplate<typename T, typename Operator> \nstatic __device__ __forceinline__ T run(T x, Operator &op) {\n    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));\n    return x;\n}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,\n         typename Tensor2, typename Tensor3, typename Tensor4,\n         typename TiledMma, typename TiledCopyA, typename TiledCopyB,\n         typename ThrCopyA, typename ThrCopyB>\n__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,\n                            Tensor4 const& tCsB, TiledMma tiled_mma,\n                            TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,\n                            ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {\n    
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M\n    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N\n    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                     // MMA_K\n    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);\n    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));            // M\n    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);\n    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N\n    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }\n    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }\n    #pragma unroll\n    for (int i = 0; i < size<2>(tCrA); ++i) {\n        if (i < size<2>(tCrA) - 1) {\n            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }\n            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }\n        }\n        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,\n         typename TiledMma, typename TiledCopy, typename ThrCopy>\n__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,\n                               TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,\n                               ThrCopy smem_thr_copy_B) {\n    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M\n    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N\n    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                     // MMA_K\n    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);\n    
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N\n    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));\n    #pragma unroll\n    for (int i = 0; i < size<2>(tCrA); ++i) {\n        if (i < size<2>(tCrA) - 1) {\n            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));\n        }\n        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))\ntemplate<typename Layout>\n__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {\n    static_assert(decltype(size<0>(acc_layout))::value == 4);\n    static_assert(decltype(rank(acc_layout))::value == 3);\n    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)\n    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)\n// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.\ntemplate<typename MMA_traits, typename Layout>\n__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {\n    using X = Underscore;\n    static_assert(decltype(size<0>(acc_layout))::value == 4);\n    static_assert(decltype(rank(acc_layout))::value == 3);\n    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});\n    static_assert(mma_shape_K == 8 || mma_shape_K == 16);\n    if constexpr (mma_shape_K == 8) {\n        return acc_layout;\n    } else {\n        auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))\n        return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));\n    
}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)\ntemplate<typename Layout>\n__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {\n    using X = Underscore;\n    static_assert(decltype(size<0>(acc_layout))::value == 4);\n    static_assert(decltype(rank(acc_layout))::value == 3);\n    auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))\n    return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename To_type, typename Engine, typename Layout>\n__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {\n    using From_type = typename Engine::value_type;\n    constexpr int numel = decltype(size(tensor))::value;\n    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;\n    // HACK: this requires tensor to be \"contiguous\"\n    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));\n    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Engine, typename Layout>\n__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {\n    constexpr int numel = decltype(size(tensor))::value;\n    static_assert(numel % 2 == 0);\n    using value_t = typename Engine::value_type;\n    // HACK: this requires tensor to be \"contiguous\"\n    Tensor tensor_uint32 = recast<uint32_t>(tensor);\n    #pragma unroll\n    for (int i = 0; i < size(tensor_uint32); ++i) {\n        tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));\n    
}\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction\ntemplate <typename To_type, typename Engine, typename Layout>\n__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {\n    using From_type = typename Engine::value_type;\n    static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);\n    static_assert(std::is_same_v<float, From_type>);\n    constexpr int numel = decltype(size(tensor))::value;\n    static_assert(numel % 2 == 0);\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n    // HACK: this requires tensor to be \"contiguous\"\n    Tensor tensor_float2 = recast<float2>(tensor);\n    Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());\n    #pragma unroll\n    for (int i = 0; i < size(out_uint32); ++i) {\n        out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));\n    }\n    Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());\n#else\n    Tensor out = FLASH_NAMESPACE::convert_type<To_type>(tensor);\n    FLASH_NAMESPACE::relu_(out);\n#endif\n    return out;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// Blocks until all but N previous cp.async.commit_group operations have committed.\n// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all\n// (which is equivalent to commit_group then wait_group 0).\n// Instead we just call cp.async.wait_group 0, which is slightly faster.\n// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113\ntemplate <int N>\nCUTE_HOST_DEVICE\nvoid cp_async_wait() {\n#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)\n    asm volatile(\"cp.async.wait_group %0;\\n\" :: 
\"n\"(N));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,\n          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,\n          typename Engine2, typename Layout2, typename Engine3, typename Layout3>\n__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,\n                            Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,\n                            Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {\n    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});\n    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});\n    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA\n    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M\n    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K\n    // There's no case where !Clear_OOB_K && Clear_OOB_MN\n    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));\n    #pragma unroll\n    for (int m = 0; m < size<1>(S); ++m) {\n        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {\n            #pragma unroll\n            for (int k = 0; k < size<2>(S); ++k) {\n                if (Is_even_K || predicate_K(k)) {\n                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));\n                } else if (Clear_OOB_K) {\n                    cute::clear(D(_, m, k));\n                }\n            }\n        } else if (Clear_OOB_MN) {\n            cute::clear(D(_, m, _));\n        }\n    }\n    // TD [2023-04-13]: Strange that the code below can cause race condition.\n    // I think it's because the copies are under an if statement.\n    // if (Is_even_K) {\n    //     #pragma unroll\n    //     for (int m = 0; m < size<1>(S); ++m) {\n    //         if (Is_even_MN || 
get<0>(identity_MN(0, m, 0)) < max_MN) {\n    //             copy(tiled_copy, S(_, m, _), D(_, m, _));\n    //         } else if (Clear_OOB_MN) {\n    //             clear(D(_, m, _));\n    //         }\n    //     }\n    // } else {  // It's slightly faster in this case if iterate over K first\n    //     #pragma unroll\n    //     for (int k = 0; k < size<2>(S); ++k) {\n    //         if (predicate_K(k)) {\n    //             #pragma unroll\n    //             for (int m = 0; m < size<1>(S); ++m) {\n    //                 if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {\n    //                     copy(tiled_copy, S(_, m, k), D(_, m, k));\n    //                 } else if (Clear_OOB_MN) {\n    //                     clear(D(_, m, k));\n    //                 }\n    //             }\n    //         } else if (Clear_OOB_K) {  // There's no case where !Clear_OOB_K && Clear_OOB_MN\n    //             if (Clear_OOB_MN || Is_even_MN) {\n    //                 clear(D(_, _, k));\n    //             } else {\n    //                 #pragma unroll\n    //                 for (int m = 0; m < size<1>(S); ++m) {\n    //                     if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {\n    //                         clear(D(_, m, k));\n    //                     }\n    //                 }\n    //             }\n    //         }\n    //     }\n    // }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool Is_even_K=true,\n          typename Engine0, typename Layout0, typename Engine1, typename Layout1,\n          typename Engine2, typename Layout2, typename Engine3, typename Layout3>\n__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,\n                                      Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,\n                                      Tensor<Engine3, Layout3> const &predicate_K,\n                     
                 const int max_MN=0, const int min_MN=0) {\n    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});\n    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});\n    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA\n    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M\n    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K\n    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf(\"blockIdx.y = %d, max_MN = %d, min_MN = %d\\n\", blockIdx.y, max_MN, min_MN); }\n    #pragma unroll\n    for (int m = 0; m < size<1>(S); ++m) {\n        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf(\"blockIdx.y = %d, m = %d\\n\", blockIdx.y, get<0>(identity_MN(0, m, 0))); }\n        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {\n            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf(\"Inner loop, blockIdx.y = %d, m = %d\\n\", blockIdx.y, get<0>(identity_MN(0, m, 0))); }\n            #pragma unroll\n            for (int k = 0; k < size<2>(S); ++k) {\n                if (Is_even_K || predicate_K(k)) {\n                    cute::copy(S(_, m, k), D(_, m, k));\n                }\n            }\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Engine, typename Layout>\n__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){\n    #pragma unroll\n    for (int i = 0; i < size(tensor); ++i) {\n        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);\n    }\n}\n\ntemplate <typename Engine0, typename Layout0, typename Engine1, typename Layout1>\n__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){\n    #pragma unroll\n    for (int i = 0; i < size(src_tensor); ++i) {\n        dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * 
softcap;\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace FLASH_NAMESPACE\n"
  },
  {
    "path": "csrc/flash_attn_ck/flash_api.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#include \"flash_common.hpp\"\n\nstd::vector<at::Tensor>\nmha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)\n        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)\n        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)\n        std::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)\n        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads\n        const float p_dropout,\n        const float softmax_scale,\n        bool is_causal,\n        int window_size_left,\n        int window_size_right,\n        const float softcap,\n        const bool return_softmax,\n        std::optional<at::Generator> gen_);\n\nstd::vector<at::Tensor>\nmha_varlen_fwd(at::Tensor &q,                               // total_q x num_heads x head_size, total_q := \\sum_{i=0}^{b} s_i\n               const at::Tensor &k,                         // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.\n               const at::Tensor &v,                         // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.\n               std::optional<at::Tensor> &out_,             // total_q x num_heads x head_size, total_k := \\sum_{i=0}^{b} s_i\n               const at::Tensor &cu_seqlens_q,              // b+1\n               const at::Tensor &cu_seqlens_k,              // b+1\n               std::optional<at::Tensor> 
&seqused_k,        // b. If given, only this many elements of each batch element's keys are used.\n               std::optional<const at::Tensor> &leftpad_k_, // batch_size\n               std::optional<at::Tensor> &block_table_,     // batch_size x max_num_blocks_per_seq\n               std::optional<at::Tensor> &alibi_slopes_,    // num_heads or b x num_heads\n               int max_seqlen_q,\n               const int max_seqlen_k,\n               const float p_dropout,\n               const float softmax_scale,\n               const bool zero_tensors,\n               bool is_causal,\n               int window_size_left,\n               int window_size_right,\n               const float softcap,\n               const bool return_softmax,\n               std::optional<at::Generator> gen_);\n\nstd::vector<at::Tensor>\nmha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)\n        const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size\n        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size\n        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size\n        const at::Tensor &out,                    // batch_size x seqlen_q x num_heads x head_size\n        const at::Tensor &softmax_lse,            // b x h x seqlen_q\n        std::optional<at::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size\n        std::optional<at::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size\n        std::optional<at::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size\n        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads\n        const float p_dropout,                    // probability to drop\n        const float softmax_scale,\n        const bool is_causal,\n        int window_size_left,\n        
int window_size_right,\n        const float softcap,\n        const bool deterministic,\n        std::optional<at::Generator> gen_,\n        std::optional<at::Tensor> &rng_state);\n\nstd::vector<at::Tensor>\nmha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads x head_size\n               const at::Tensor &q,                      // total_q x num_heads x head_size, total_q := \\sum_{i=0}^{b} s_i\n               const at::Tensor &k,                      // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i\n               const at::Tensor &v,                      // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i\n               const at::Tensor &out,                    // total_q x num_heads x head_size\n               const at::Tensor &softmax_lse,            // b x h x s   softmax logsumexp\n               std::optional<at::Tensor> &dq_,           // total_q x num_heads x head_size, total_q := \\sum_{i=0}^{b} s_i\n               std::optional<at::Tensor> &dk_,           // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i\n               std::optional<at::Tensor> &dv_,           // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i\n               const at::Tensor &cu_seqlens_q,           // b+1\n               const at::Tensor &cu_seqlens_k,           // b+1\n               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads\n               const int max_seqlen_q,\n               const int max_seqlen_k, // max sequence length to choose the kernel\n               const float p_dropout,  // probability to drop\n               const float softmax_scale,\n               const bool zero_tensors,\n               const bool is_causal,\n               int window_size_left,\n               int window_size_right,\n               const float softcap,\n               const bool deterministic,\n               std::optional<at::Generator> gen_,\n               
std::optional<at::Tensor> &rng_state);\n\nstd::vector<at::Tensor>\nmha_fwd_kvcache(at::Tensor &q,                                     // batch_size x seqlen_q x num_heads x head_size\n                const at::Tensor &kcache,                          // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.\n                const at::Tensor &vcache,                          // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.\n                std::optional<const at::Tensor> &k_,               // batch_size x seqlen_knew x num_heads_k x head_size\n                std::optional<const at::Tensor> &v_,               // batch_size x seqlen_knew x num_heads_k x head_size\n                std::optional<const at::Tensor> &seqlens_k_,       // batch_size\n                std::optional<const at::Tensor> &rotary_cos_,      // seqlen_ro x (rotary_dim / 2)\n                std::optional<const at::Tensor> &rotary_sin_,      // seqlen_ro x (rotary_dim / 2)\n                std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache\n                std::optional<const at::Tensor> &leftpad_k_,       // batch_size\n                std::optional<at::Tensor> &block_table_,           // batch_size x max_num_blocks_per_seq\n                std::optional<at::Tensor> &alibi_slopes_,          // num_heads or batch_size x num_heads\n                std::optional<at::Tensor> &out_,                   // batch_size x seqlen_q x num_heads x head_size\n                const float softmax_scale,\n                bool is_causal,\n                int window_size_left,\n                int window_size_right,\n                const float softcap,\n                bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2\n                int 
num_splits);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m)\n{\n        m.doc() = \"FlashAttention\";\n        m.def(\"fwd\", &mha_fwd, \"Forward pass\");\n        m.def(\"varlen_fwd\", &mha_varlen_fwd, \"Forward pass (variable length)\");\n        m.def(\"bwd\", &mha_bwd, \"Backward pass\");\n        m.def(\"varlen_bwd\", &mha_varlen_bwd, \"Backward pass (variable length)\");\n        m.def(\"fwd_kvcache\", &mha_fwd_kvcache, \"Forward pass, with KV-cache\");\n}\n"
  },
  {
    "path": "csrc/flash_attn_ck/flash_common.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#include \"flash_common.hpp\"\n\nnamespace flash {\nint override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)\n{\n    int device;\n    auto status = hipGetDevice(&device);\n    if(status != hipSuccess)\n        return num_splits;\n\n    hipDeviceProp_t props{};\n    status = hipGetDeviceProperties(&props, device);\n    if(status != hipSuccess)\n        return num_splits;\n\n    // TODO - tile size should match the TileFmhaShape, hardcode for now\n    const int kM0 = 128;\n    const int kN1 = hdim_v;\n\n    const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;\n    const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;\n\n    if(num_splits < 1 && p_drop == 0.0f)\n        return num_splits_heuristic_ck(\n            batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);\n\n    return num_splits;\n}\n\n} // namespace flash\n"
  },
  {
    "path": "csrc/flash_attn_ck/flash_common.hpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.\n#include <torch/python.h>\n#include <torch/nn/functional.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n\n#ifdef OLD_GENERATOR_PATH\n#include <ATen/CUDAGeneratorImpl.h>\n#else\n#include <ATen/cuda/CUDAGeneratorImpl.h>\n#endif\n\n\n#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x \" must be on CUDA\")\n#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x \" must have shape (\" #__VA_ARGS__ \")\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n\nnamespace flash {\ninline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state)\n{\n    // Imitate from PyTorch\n    // https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17\n    if (arg.captured_) {\n        rng_state[0] = static_cast<uint64_t>(*arg.seed_.ptr);\n        rng_state[1] = static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_);\n    } else {\n        rng_state[0] = arg.seed_.val;\n        rng_state[1] = arg.offset_.val;\n    }\n}\n\ninline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {\n    // If we have enough to almost fill the SMs, then just use 1 split\n    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }\n    max_splits = std::min({max_splits, num_SMs, num_n_blocks});\n    float max_efficiency = 0.f;\n    std::vector<float> efficiency;\n    efficiency.reserve(max_splits);\n    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };\n    // Some splits are not eligible. 
For example, if we have 64 blocks and choose 11 splits,\n    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks\n    // (i.e. it's 11 splits anyway).\n    // So we check if the number of blocks per split is the same as the previous num_splits.\n    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {\n        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);\n    };\n    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {\n        if (!is_split_eligible(num_splits)) {\n            efficiency.push_back(0.f);\n        } else {\n            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;\n            float eff = n_waves / ceil(n_waves);\n            // printf(\"num_splits = %d, eff = %f\\n\", num_splits, eff);\n            if (eff > max_efficiency) { max_efficiency = eff; }\n            efficiency.push_back(eff);\n        }\n    }\n    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {\n        if (!is_split_eligible(num_splits)) { continue; }\n        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {\n            // printf(\"num_splits chosen = %d\\n\", num_splits);\n            return num_splits;\n        }\n    }\n    return 1;\n}\n\nint override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);\n\n} // namespace flash\n"
  },
  {
    "path": "csrc/flash_attn_ck/mha_bwd.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#include \"flash_common.hpp\"\n\n#include \"fmha_bwd.hpp\"\n#include \"mask.hpp\"\n\nfmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,\n                                       std::string dtype,\n                                       int seqlen_q,\n                                       int seqlen_k,\n                                       int batch,\n                                       int head_size,\n                                       int nhead_q,\n                                       int nhead_k,\n                                       bool has_dropout,\n                                       bool enable_alibi,\n                                       bool deterministic)\n{\n    return fmha_bwd_traits{seqlen_q,\n                           seqlen_k,\n                           batch,\n                           seqlen_q, // max_seqlen_q\n                           seqlen_k, // max_seqlen_k\n                           head_size, // hdim_q\n                           head_size, // hdim_k\n                           nhead_q,\n                           nhead_k,\n                           dtype,\n                           false, // is_group_mode\n                           mask.type,\n                           enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias,\n                           false,    // has_dbias\n                           has_dropout,\n                           false, // s_randval\n                           deterministic};\n}\n\nfmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,\n                                   // sizes\n                                   const int b,\n                                   const int seqlen_q,\n                                   const int seqlen_k,\n                                   const int h,\n                                   const int h_k,\n                                   const int hdim,\n                                   // device pointers\n                                   const at::Tensor q,\n                                   const at::Tensor k,\n                                   const at::Tensor v,\n                                   std::optional<at::Tensor> &alibi_slopes_,\n                                   const at::Tensor out,\n                                   const at::Tensor softmax_lse,\n                                   const at::Tensor dout,\n                                   at::Tensor dq_acc,\n                                   at::Tensor d,\n                                   at::Tensor dq,\n                                   at::Tensor dk,\n                                   at::Tensor dv,\n                                   float softmax_scale,\n                                   float p_dropout,\n                                   std::pair<uint64_t*, uint64_t*> drop_seed_offset)\n{\n    // q: (batch_size, seqlen_q, nheads, hdim)\n    ck_tile::index_t batch_stride_q = q.stride(0);\n    ck_tile::index_t stride_q = q.stride(1);\n    ck_tile::index_t nhead_stride_q = q.stride(2);\n\n    // k: (batch_size, seqlen_k, nheads_k, hdim)\n    ck_tile::index_t batch_stride_k = k.stride(0);\n    ck_tile::index_t stride_k = k.stride(1);\n    ck_tile::index_t nhead_stride_k = k.stride(2);\n\n    // v: 
(batch_size, seqlen_k, nheads_k, hdim)\n    ck_tile::index_t batch_stride_v = v.stride(0);\n    ck_tile::index_t stride_v = v.stride(1);\n    ck_tile::index_t nhead_stride_v = v.stride(2);\n\n    // o: (batch_size, seqlen_q, nheads, hdim)\n    ck_tile::index_t batch_stride_o = out.stride(0);\n    ck_tile::index_t stride_o = out.stride(1);\n    ck_tile::index_t nhead_stride_o = out.stride(2);\n\n    // lse: (batch_size, nheads, seqlen_q)\n    ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);\n    ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);\n\n    // do: (batch_size, seqlen_q, nheads, hdim)\n    ck_tile::index_t batch_stride_do = dout.stride(0);\n    ck_tile::index_t stride_do = dout.stride(1);\n    ck_tile::index_t nhead_stride_do = dout.stride(2);\n\n    // d: (batch_size, nheads, seqlen_q)\n    // CK assumes d shares the same strides as lse\n\n    // dq: (batch_size, seqlen_q, nheads, hdim)\n    ck_tile::index_t batch_stride_dq = dq.stride(0);\n    ck_tile::index_t stride_dq = dq.stride(1);\n    ck_tile::index_t nhead_stride_dq = dq.stride(2);\n\n    // dk_expanded: (batch_size, seqlen_k, nheads, hdim)\n    ck_tile::index_t batch_stride_dk = dk.stride(0);\n    ck_tile::index_t stride_dk = dk.stride(1);\n    ck_tile::index_t nhead_stride_dk = dk.stride(2);\n\n    // dv_expanded: (batch_size, seqlen_k, nheads, hdim)\n    ck_tile::index_t batch_stride_dv = dv.stride(0);\n    ck_tile::index_t stride_dv = dv.stride(1);\n    ck_tile::index_t nhead_stride_dv = dv.stride(2);\n\n    // dq_acc: (batch_size, nheads, split, seqlen_q, hdim)\n    ck_tile::long_index_t batch_stride_dq_acc = dq_acc.stride(0);\n    ck_tile::long_index_t nhead_stride_dq_acc = dq_acc.stride(1);\n    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(2);\n    ck_tile::index_t stride_dq_acc = dq_acc.stride(3);\n\n    float p_undrop = 1.0 - p_dropout;\n\n    void *alibi_slopes_ptr = nullptr;\n    ck_tile::index_t stride_alibi_slopes = 0;\n\n    if (alibi_slopes_.has_value()) 
{\n        auto alibi_slopes = alibi_slopes_.value();\n        CHECK_DEVICE(alibi_slopes);\n        TORCH_CHECK(alibi_slopes.stride(-1) == 1, \"ALiBi slopes tensor must have contiguous last dimension\");\n        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));\n        alibi_slopes_ptr = alibi_slopes.data_ptr();\n        // alibi_slopes:(batch_size, nheads) or (nhead)\n        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;\n    }\n\n    return fmha_bwd_args{q.data_ptr(),\n                         k.data_ptr(),\n                         v.data_ptr(),\n                         alibi_slopes_ptr, // bias\n                         out.data_ptr(),\n                         softmax_lse.data_ptr(),\n                         dout.data_ptr(),\n                         d.data_ptr(),\n                         nullptr, // rand_val\n                         dq.data_ptr(),\n                         dk.data_ptr(),\n                         dv.data_ptr(),\n                         nullptr, // dbias\n                         dq_acc.data_ptr(), // dq_acc\n                         nullptr, // seqstart_q_ptr\n                         nullptr, // seqstart_k_ptr\n                         nullptr, // seqlen_q_ptr\n                         nullptr, // seqlen_k_ptr\n                         nullptr, // cu_seqlen_q_ptr\n                         nullptr, // cu_seqlen_k_ptr\n                         seqlen_q,\n                         seqlen_k,\n                         b,\n                         seqlen_q, // max_seqlen_q\n                         seqlen_k, // max_seqlen_k\n                         hdim, // hdim_q\n                         hdim, // hdim_v\n                         h, // nhead\n                         h_k, // nhead_k\n                         softmax_scale,\n                         stride_q,\n                         stride_k,\n                         stride_v,\n                
         stride_alibi_slopes,\n                         stride_o,\n                         0, // stride_randval\n                         stride_do,\n                         stride_dq_acc,\n                         stride_dq,\n                         stride_dk,\n                         stride_dv,\n                         0, // stride_dbias, FA without bias\n                         nhead_stride_q,\n                         nhead_stride_k,\n                         nhead_stride_v,\n                         0, // nhead_stride_bias, FA without bias\n                         nhead_stride_o,\n                         0, // nhead_stride_randval\n                         nhead_stride_do,\n                         nhead_stride_lse,\n                         nhead_stride_dq_acc,\n                         nhead_stride_dq,\n                         nhead_stride_dk,\n                         nhead_stride_dv,\n                         0, // nhead_stride_dbias, FA without dbias\n                         batch_stride_q,\n                         batch_stride_k,\n                         batch_stride_v,\n                         0  , // batch_stride_bias, FA without bias\n                         batch_stride_o,\n                         0, // batch_stride_randval\n                         batch_stride_do,\n                         batch_stride_lse,\n                         batch_stride_dq_acc,\n                         batch_stride_dq,\n                         batch_stride_dk,\n                         batch_stride_dv,\n                         0  , // batch_stride_dbias, FA without dbias\n                         split_stride_dq_acc,\n                         mask.left,\n                         mask.right,\n                         static_cast<ck_tile::index_t>(mask.type),\n                         p_dropout,\n                         p_undrop,\n                         drop_seed_offset};\n}\n\nstd::vector<at::Tensor>\nmha_bwd(const at::Tensor &dout,                   // 
batch_size x seqlen_q x num_heads, x multiple_of(head_size, 8)\n        const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size\n        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size\n        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size\n        const at::Tensor &out,                    // batch_size x seqlen_q x num_heads x head_size\n        const at::Tensor &softmax_lse,            // b x h x seqlen_q\n        std::optional<at::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size\n        std::optional<at::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size\n        std::optional<at::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size\n        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads\n        const float p_dropout,                    // probability to drop\n        const float softmax_scale,\n        const bool is_causal,\n        int window_size_left,\n        int window_size_right,\n        const float /*softcap*/,\n        const bool deterministic,\n        std::optional<at::Generator> gen_,\n        std::optional<at::Tensor> &rng_state_)\n{\n#ifdef FLASHATTENTION_DISABLE_BACKWARD\n    TORCH_CHECK(false, \"This flash attention build does not support backward.\");\n#endif\n    if (is_causal) { window_size_right = 0; }\n\n    const bool is_dropout = p_dropout > 0.0;\n#ifdef HIPIFY_V2\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n#else\n    auto stream = at::cuda::getCurrentHIPStream().stream();\n#endif\n\n    auto q_dtype = q.dtype();\n    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,\n                \"FlashAttention only support fp16 and bf16 data type\");\n\n    TORCH_CHECK(k.dtype() == q_dtype, \"query and key must have the same dtype\");\n    TORCH_CHECK(v.dtype() == q_dtype, \"query 
and value must have the same dtype\");\n    TORCH_CHECK(out.dtype() == q_dtype, \"query and out must have the same dtype\");\n    TORCH_CHECK(dout.dtype() == q_dtype, \"query and dout must have the same dtype\");\n\n    const std::string q_dtype_str = q_dtype == torch::kFloat16 ? \"fp16\" : \"bf16\";\n\n    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);\n    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);\n\n    TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(k.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(v.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(out.stride(-1) == 1, \"out tensor must have contiguous last dimension\");\n    TORCH_CHECK(dout.stride(-1) == 1, \"dout tensor must have contiguous last dimension\");\n\n    const auto sizes = q.sizes();\n\n    const int batch_size = sizes[0];\n    const int seqlen_q = sizes[1];\n    const int num_heads = sizes[2];\n    const int head_size = sizes[3];\n    const int seqlen_k = k.size(1);\n    const int num_heads_k = k.size(2);\n    TORCH_CHECK(batch_size > 0, \"batch size must be positive\");\n    TORCH_CHECK(head_size % 8 == 0, \"head_size should be a multiple of 8\");\n    TORCH_CHECK(head_size <= 256, \"CK FlashAttention backward only supports head dimension at most 256\");\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n\n    if (window_size_left >= seqlen_k) { window_size_left = -1; }\n    if (window_size_right >= seqlen_k) { window_size_right = -1; }\n\n    mask_info mask;\n    if (is_causal) {\n        std::string mask_identify = \"b:\" + std::to_string(window_size_left) + \",\" + \"0\";\n        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal\n    }\n    else if (window_size_left == -1 && window_size_right == -1) {\n        mask = mask_info::decode(\"0\", 
seqlen_q, seqlen_k); // no mask\n    }\n    else {\n        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.\n        std::string mask_identify = \"b:\" + std::to_string(window_size_left) + \",\" + std::to_string(window_size_right);\n        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local\n    }\n\n    // q, k, v, out were already padded in mha_fwd\n    // dq_, dk_, dv_ are also padded tensors\n    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);\n    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);\n    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);\n    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);\n    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);\n\n    at::Tensor dq, dk, dv;\n    if (dq_.has_value()) {\n        dq = dq_.value();\n        TORCH_CHECK(dq.dtype() == q_dtype, \"dq must have the same dtype as q\");\n        CHECK_DEVICE(dq);\n        TORCH_CHECK(dq.stride(-1) == 1, \"dq must have contiguous last dimension\");\n        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);\n    } else {\n        dq = torch::empty_like(q);\n    }\n    if (dk_.has_value()) {\n        dk = dk_.value();\n        TORCH_CHECK(dk.dtype() == q_dtype, \"dk must have the same dtype as q\");\n        CHECK_DEVICE(dk);\n        TORCH_CHECK(dk.stride(-1) == 1, \"dk must have contiguous last dimension\");\n        CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);\n    } else {\n        dk = torch::empty_like(k);\n    }\n    if (dv_.has_value()) {\n        dv = dv_.value();\n        TORCH_CHECK(dv.dtype() == q_dtype, \"dv must have the same dtype as q\");\n        CHECK_DEVICE(dv);\n        TORCH_CHECK(dv.stride(-1) == 1, \"dv must have contiguous last dimension\");\n        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);\n    } else {\n        dv = torch::empty_like(v);\n    }\n\n    const auto traits = get_ck_fmha_bwd_traits(\n        
mask,\n        q_dtype_str,\n        seqlen_q,\n        seqlen_k,\n        batch_size,\n        head_size,\n        num_heads,\n        num_heads_k,\n        is_dropout,\n        alibi_slopes_.has_value(),\n        deterministic);\n    fmha_bwd_launcher launcher(traits);\n    const ck_tile::index_t nsplits = launcher.dq_acc_splits;\n\n    at::cuda::CUDAGuard device_guard{q.device()};\n\n    auto opts = q.options();\n    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));\n    at::Tensor dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size}, opts.dtype(at::kFloat));\n\n    at::Tensor dk_expanded, dv_expanded;\n    if (num_heads_k != num_heads) {  // MQA / GQA\n        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);\n        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);\n    } else {\n        dk_expanded = dk;\n        dv_expanded = dv;\n    }\n\n    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(\n        gen_, at::cuda::detail::getDefaultCUDAGenerator());\n\n    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();\n    at::Tensor rng_state;\n\n    if (rng_state_.has_value()) {\n        rng_state = rng_state_.value();\n    } else if(is_dropout) {\n        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));\n        // See Note [Acquire lock when using random generators]\n        std::lock_guard<std::mutex> lock(gen->mutex_);\n        auto philox_args = gen->philox_cuda_state(counter_offset);\n        hipLaunchKernelGGL(\n            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,\n            philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));\n    }\n\n    if (seqlen_q > 0) {\n        auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());\n        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);\n        ck_tile::stream_config 
stream_config{stream};\n\n        auto args =\n            get_ck_fmha_bwd_args(\n                mask,\n                batch_size,\n                seqlen_q,\n                seqlen_k,\n                num_heads,\n                num_heads_k,\n                head_size,\n                q,\n                k,\n                v,\n                alibi_slopes_,\n                out,\n                softmax_lse,\n                dout,\n                dq_accum,\n                softmax_d,\n                dq,\n                dk_expanded,\n                dv_expanded,\n                softmax_scale,\n                p_dropout,\n                drop_seed_offset);\n\n        float t = fmha_bwd(traits, args, stream_config);\n        TORCH_CHECK(t >= 0, \"invalid argument for fmha_bwd\");\n    } else {\n        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.\n        dk_expanded.zero_();\n        dv_expanded.zero_();\n        softmax_d.zero_();\n    }\n\n    // For MQA/GQA we need to sum dK and dV across the groups\n    if (num_heads_k != num_heads) {\n        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});\n        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});\n    }\n\n    return { dq, dk, dv, softmax_d };\n}\n"
  },
  {
    "path": "csrc/flash_attn_ck/mha_fwd.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#include \"flash_common.hpp\"\n\n#include \"fmha_fwd.hpp\"\n#include \"mask.hpp\"\n\nfmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,\n                                       std::string dtype,\n                                       int head_size,\n                                       bool has_dropout,\n                                       bool has_lse,\n                                       bool enable_alibi)\n{\n    return fmha_fwd_traits{head_size,\n                           head_size,\n                           dtype,\n                           false, // is_group_mode\n                           true,  // is_v_rowmajor\n                           false, // has_logits_soft_cap\n                           mask.type,\n                           enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias,\n                           has_lse,\n                           has_dropout,\n                           quant_scale_enum::no_scale}; // qscale_type\n}\n\nfmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,\n                                   bool has_dropout_randval,\n                                   const mask_info &mask,\n                                   // sizes\n                                   const int b,\n                                   const int seqlen_q,\n                                   const int seqlen_k,\n                                   const int h,\n                                   const int h_k,\n                                   const int d,\n                                   // device pointers\n                                   const at::Tensor q,\n                                   const at::Tensor k,\n                                   const at::Tensor v,\n                                   std::optional<at::Tensor> &alibi_slopes_,\n                                   at::Tensor out,\n                                   at::Tensor softmax_lse,\n                                   at::Tensor dropout_randval,\n                                   float softmax_scale,\n                                   float p_dropout,\n                                   std::pair<uint64_t*, uint64_t*> drop_seed_offset)\n{\n    // q: (batch_size, seqlen_q, nheads, d)\n    // k: (batch_size, seqlen_k, nheads_k, d)\n    // v: (batch_size, seqlen_k, nheads_k, d)\n    // o: (batch_size, seqlen_q, nheads, d)\n\n    // alibi_slopes:(batch_size, nheads) or (nhead)\n    // lse: (batch_size, nheads, seqlen_q)\n    // randval: (batch_size, nheads, seqlen_q, seqlen_k)\n\n    ck_tile::index_t stride_q = q.stride(1);\n    ck_tile::index_t stride_k = k.stride(1);\n    ck_tile::index_t stride_v = v.stride(1);\n    ck_tile::index_t stride_o = out.stride(1);\n    ck_tile::index_t stride_randval = has_dropout_randval ? 
dropout_randval.stride(2) : 0;\n\n    ck_tile::index_t nhead_stride_q = q.stride(2);\n    ck_tile::index_t nhead_stride_k = k.stride(2);\n    ck_tile::index_t nhead_stride_v = v.stride(2);\n    ck_tile::index_t nhead_stride_o = out.stride(2);\n    ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;\n    ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;\n\n    ck_tile::index_t batch_stride_q = q.stride(0);\n    ck_tile::index_t batch_stride_k = k.stride(0);\n    ck_tile::index_t batch_stride_v = v.stride(0);\n    ck_tile::index_t batch_stride_o = out.stride(0);\n\n    ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;\n    ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;\n\n    void *alibi_slopes_ptr = nullptr;\n    ck_tile::index_t stride_alibi_slopes = 0;\n\n    if (alibi_slopes_.has_value()) {\n        auto alibi_slopes = alibi_slopes_.value();\n        CHECK_DEVICE(alibi_slopes);\n        TORCH_CHECK(alibi_slopes.stride(-1) == 1, \"ALiBi slopes tensor must have contiguous last dimension\");\n        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));\n        alibi_slopes_ptr = alibi_slopes.data_ptr();\n        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;\n    }\n\n    return fmha_fwd_args{q.data_ptr(),\n                         k.data_ptr(),\n                         v.data_ptr(),\n                         alibi_slopes_ptr, // bias\n                         nullptr, // q_descale_ptr\n                         nullptr, // k_descale_ptr\n                         nullptr, // v_descale_ptr\n                         has_dropout_randval ? dropout_randval.data_ptr() : nullptr,\n                         has_lse ? 
softmax_lse.data_ptr() : nullptr,\n                         out.data_ptr(),\n                         nullptr, // seqstart_q_ptr\n                         nullptr, // seqstart_k_ptr\n                         nullptr, // seqlen_q_ptr\n                         nullptr, // seqlen_k_ptr\n                         nullptr, // cu_seqlen_q_ptr\n                         nullptr, // cu_seqlen_k_ptr\n                         nullptr, // block_scale_seqstart_q_ptr\n                         nullptr, // block_scale_seqstart_k_ptr\n                         nullptr, // seqstart_v_scale_ptr\n                         nullptr, // sink_ptr\n                         seqlen_q,\n                         seqlen_k,\n                         b,\n                         seqlen_q,      // max_seqlen_q\n                         d,             // hdim_q\n                         d,             // hdim_v\n                         h,             // nhead\n                         h_k,           // nhead_k\n                         softmax_scale, // scale_s\n                         0.0f,          // logits_soft_cap\n                         stride_q,\n                         stride_k,\n                         stride_v,\n                         stride_alibi_slopes,\n                         stride_randval,\n                         stride_o,\n                         0, // stride_q_descale\n                         0, // stride_k_descale\n                         0, // stride_v_descale\n                         nhead_stride_q,\n                         nhead_stride_k,\n                         nhead_stride_v,\n                         0, // nhead_stride_bias, FA without bias\n                         nhead_stride_randval,\n                         nhead_stride_lse,\n                         nhead_stride_o,\n                         0, // nhead_stride_q_descale\n                         0, // nhead_stride_k_descale\n                         0, // nhead_stride_v_descale\n                         
batch_stride_q,\n                         batch_stride_k,\n                         batch_stride_v,\n                         0, // batch_stride_bias, FA without bias\n                         batch_stride_randval,\n                         batch_stride_lse,\n                         batch_stride_o,\n                         0, // batch_stride_q_descale\n                         0, // batch_stride_k_descale\n                         0, // batch_stride_v_descale\n                         mask.left,\n                         mask.right,\n                         0, // sink_size\n                         static_cast<ck_tile::index_t>(mask.type),\n                         0, // min_seqlen_q\n                         p_dropout,\n                         has_dropout_randval,\n                         drop_seed_offset,\n                         0,     // block_scale_size_q\n                         0};    // block_scale_size_kv\n}\n\nstd::vector<at::Tensor>\nmha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)\n        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)\n        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)\n        std::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)\n        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads\n        const float p_dropout,\n        const float softmax_scale,\n        bool is_causal,\n        int window_size_left,\n        int window_size_right,\n        const float /*softcap*/,\n        const bool return_dropout_randval,\n        std::optional<at::Generator> gen_)\n{\n    auto q_dtype = q.dtype();\n    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,\n                \"FlashAttention only support fp16 and bf16 data 
type\");\n\n    TORCH_CHECK(k.dtype() == q_dtype, \"query and key must have the same dtype\");\n    TORCH_CHECK(v.dtype() == q_dtype, \"query and value must have the same dtype\");\n\n    std::string q_dtype_str = q_dtype == torch::kFloat16 ? \"fp16\" : \"bf16\";\n\n    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);\n\n    TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(k.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(v.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n\n    const auto sizes = q.sizes();\n\n    const int batch_size = sizes[0];\n    int seqlen_q = sizes[1];\n    int num_heads = sizes[2];\n    const int head_size = sizes[3];\n    const int seqlen_k = k.size(1);\n    const int num_heads_k = k.size(2);\n    TORCH_CHECK(batch_size > 0, \"batch size must be positive\");\n    TORCH_CHECK(head_size <= 256, \"CK only supports head dimension at most 256\");\n    TORCH_CHECK(head_size % 8 == 0, \"query, key, value, and out_ must have a head_size that is a multiple of 8\");\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n\n    if (window_size_left >= seqlen_k) { window_size_left = -1; }\n    if (window_size_right >= seqlen_k) { window_size_right = -1; }\n\n    // causal=true is the same as causal=false in this case\n    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }\n\n    mask_info mask;\n    if (is_causal) {\n        // Causal is the special case where window_size_right == 0 and window_size_left < 0.\n        window_size_right = 0;\n        std::string mask_identify = \"b:\" + std::to_string(window_size_left) + \",\" + \"0\";\n        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal\n    }\n    else if (window_size_left == -1 && window_size_right == -1) {\n        mask = mask_info::decode(\"0\", seqlen_q, seqlen_k); // 
no mask\n    }\n    else {\n        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.\n        std::string mask_identify = \"b:\" + std::to_string(window_size_left) + \",\" + std::to_string(window_size_right);\n        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local\n    }\n\n    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case\n    // H/t Daniel Haziza\n    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();\n    const int ngroups = num_heads / num_heads_k;\n    if (seqlenq_ngroups_swapped) {\n        q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);\n        seqlen_q = ngroups;\n        num_heads = num_heads_k;\n    }\n\n    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);\n    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);\n    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);\n\n    at::Tensor out;\n    if (out_.has_value()) {\n        out = out_.value();\n        TORCH_CHECK(out.dtype() == q_dtype, \"Output must have the same dtype as inputs\");\n        CHECK_DEVICE(out);\n        TORCH_CHECK(out.stride(-1) == 1, \"Output tensor must have contiguous last dimension\");\n        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size);\n        if (seqlenq_ngroups_swapped) {\n            out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);\n        }\n    }\n    else {\n        out = torch::empty_like(q);\n    }\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    at::cuda::CUDAGuard device_guard{q.device()};\n\n    auto opts = q.options();\n    bool has_lse = true;\n    bool has_dropout = p_dropout > 0.0f;\n\n    at::Tensor softmax_lse;\n    // TODO - check gradient, only training 
require lse\n    softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(torch::kFloat32));\n\n    at::Tensor p;\n    if (return_dropout_randval) {\n        TORCH_CHECK(has_dropout, \"return_dropout_randval require p_dropout > 0\");\n        p = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(torch::kUInt8));\n    }\n    else {\n        p = torch::empty({ 0 }, opts);\n    }\n\n    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();\n    auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));\n    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());\n\n    if (p_dropout > 0.0)  {\n        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(\n            gen_, at::cuda::detail::getDefaultCUDAGenerator());\n        // See Note [Acquire lock when using random generators]\n        std::lock_guard<std::mutex> lock(gen->mutex_);\n        auto philox_args = gen->philox_cuda_state(counter_offset);\n        hipLaunchKernelGGL(\n            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);\n    }\n\n    if (seqlen_k > 0) {\n        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);\n#ifdef HIPIFY_V2\n        auto stream = at::cuda::getCurrentCUDAStream().stream();\n#else\n        auto stream = at::cuda::getCurrentHIPStream().stream();\n#endif\n        ck_tile::stream_config stream_config{stream};\n\n        auto traits =\n            get_ck_fmha_fwd_traits(\n                mask,\n                q_dtype_str,\n                head_size,\n                has_dropout,\n                has_lse,\n                alibi_slopes_.has_value());\n\n        auto args =\n            get_ck_fmha_fwd_args(\n                has_lse,\n                return_dropout_randval,\n                mask,\n                batch_size,\n                seqlen_q,\n                seqlen_k,\n                num_heads,\n                num_heads_k,\n    
            head_size,\n                q,\n                k,\n                v,\n                alibi_slopes_,\n                out,\n                softmax_lse,\n                p,\n                softmax_scale,\n                p_dropout,\n                drop_seed_offset);\n\n        float t = fmha_fwd(traits, args, stream_config);\n        TORCH_CHECK(t >= 0, \"invalid argument for fmha_fwd\");\n    }\n    else {\n        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.\n        out.zero_();\n        softmax_lse.fill_(std::numeric_limits<float>::infinity());\n    }\n\n    if (seqlenq_ngroups_swapped) {\n        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});\n        q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});\n        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});\n    }\n    return {out, softmax_lse, p, rng_state};\n}\n"
  },
  {
    "path": "csrc/flash_attn_ck/mha_fwd_kvcache.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#include \"flash_common.hpp\"\n\n#include \"fmha_fwd.hpp\"\n#include \"rotary.hpp\"\n\nfmha_fwd_appendkv_traits get_ck_fmha_fwd_appendkv_traits(std::string dtype,\n                                                        int head_size,\n                                                        int rotary_dim,\n                                                        bool is_rotary_interleaved)\n{\n    rope_enum rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved\n                                                                   : rope_enum::half_rotated)\n                                          : rope_enum::none);\n\n    return fmha_fwd_appendkv_traits{head_size,\n                                    head_size,\n                                    dtype,\n                                    true,  // is_v_rowmajor\n                                    rope_type};\n}\n\nfmha_fwd_splitkv_traits get_ck_fmha_fwd_splitkv_traits(const mask_info &mask,\n                                                       std::string dtype,\n                                                       int head_size,\n                                                       bool has_lse,\n                                                       bool enable_alibi)\n{\n    return fmha_fwd_splitkv_traits{head_size,\n                                   head_size,\n                                   dtype,\n                                   false,  // is_group_mode\n                                   true,   // is_v_rowmajor\n                                   false,  // has_logits_soft_cap\n                                   mask.type,\n                                   enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias,\n                                   has_lse,\n                                   false,  // do_fp8_static_quant\n                                   false}; // has_sink\n}\n\nfmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b,\n                                                     const int seqlen_q,\n                                                     const int seqlen_knew,\n                                                     const int h,\n                                                     const int h_k,\n                                                     const int d,\n                                                     const int rotary_dim,\n                                                     const bool has_mask,\n                                                     const int page_block_size,\n                                                     // device pointers\n                                                     const at::Tensor q,\n                                                     const at::Tensor kcache,\n                                                     const at::Tensor vcache,\n                                                     const at::Tensor knew,\n                                                     const at::Tensor vnew,\n                                                     std::optional<const at::Tensor> &seqlens_k_,\n                                                     std::optional<const at::Tensor> &rotary_cos_,\n                                                     std::optional<const at::Tensor> &rotary_sin_,\n                                                     std::optional<const at::Tensor> &cache_batch_idx_,\n                                                     std::optional<at::Tensor> &block_table_)\n{\n    // q: (batch_size, seqlen_q, nheads, d)\n    // kcache: (batch_size_c, seqlen_k, nheads_k, d) or (num_blocks, page_block_size, nheads_k, d)\n    // vcache: 
(batch_size_c, seqlen_k, nheads_k, d) or (num_blocks, page_block_size, nheads_k, d)\n    // knew: (batch_size, seqlen_knew, nheads_k, d)\n    // vnew: (batch_size, seqlen_knew, nheads_k, d)\n\n    // seqlens_k: (batch_size)\n    // rotary_cos: (seqlen_ro, rotary_dim / 2)\n    // rotary_sin: (seqlen_ro, rotary_dim / 2)\n    // block_table: (batch_size, max_num_blocks_per_seq)\n\n    fmha_fwd_appendkv_args args;\n    args.q_ptr = q.data_ptr();\n    args.k_ptr = kcache.data_ptr();\n    args.knew_ptr = knew.data_ptr();\n    args.v_ptr = vcache.data_ptr();\n    args.vnew_ptr = vnew.data_ptr();\n    args.seqlen_k_ptr = seqlens_k_.has_value() ? seqlens_k_.value().data_ptr() : nullptr;\n\n    args.seqlen_q = seqlen_q;\n    args.seqlen_knew = seqlen_knew;\n    args.batch = b;\n    args.hdim_q = d;\n    args.hdim_v = d;\n    args.nhead_q = h;\n    args.nhead_k = h_k;\n\n    args.rotary_cos_ptr = rotary_cos_.has_value() ? rotary_cos_.value().data_ptr() : nullptr;\n    args.rotary_sin_ptr = rotary_sin_.has_value() ? 
rotary_sin_.value().data_ptr() : nullptr;\n    args.rotary_dim = rotary_dim;\n    args.has_mask = has_mask;\n\n    if (block_table_.has_value())\n    {\n        auto block_table = block_table_.value();\n        args.block_table_ptr = block_table.data_ptr();\n        args.batch_stride_block_table = block_table.stride(0);\n        args.page_block_size = page_block_size;\n    }\n    else\n    {\n        args.block_table_ptr = nullptr;\n        args.batch_stride_block_table = 0;\n        args.page_block_size = 0;\n    }\n\n    args.cache_batch_idx = cache_batch_idx_.has_value() ?\n        reinterpret_cast<int *>(cache_batch_idx_.value().data_ptr()) : nullptr;\n\n    args.batch_stride_q = q.stride(0);\n    args.stride_q = q.stride(1);\n    args.nhead_stride_q = q.stride(2);\n\n    args.batch_stride_k = kcache.stride(0);\n    args.stride_k = kcache.stride(1);\n    args.nhead_stride_k = kcache.stride(2);\n\n    args.batch_stride_knew = knew.stride(0);\n    args.stride_knew = knew.stride(1);\n    args.nhead_stride_knew = knew.stride(2);\n\n    args.batch_stride_v = vcache.stride(0);\n    args.stride_v = vcache.stride(1);\n    args.nhead_stride_v = vcache.stride(2);\n\n    args.batch_stride_vnew = vnew.stride(0);\n    args.stride_vnew = vnew.stride(1);\n    args.nhead_stride_vnew = vnew.stride(2);\n\n    return args;\n}\n\nfmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,\n                                                   const mask_info &mask,\n                                                   const int b,\n                                                   const int seqlen_q,\n                                                   const int seqlen_k,\n                                                   const int h,\n                                                   const int h_k,\n                                                   const int d,\n                                                   const int page_block_size,\n                                      
             const int num_splits,\n                                                   float softmax_scale,\n                                                   // device pointers\n                                                   const at::Tensor q,\n                                                   const at::Tensor k,\n                                                   const at::Tensor v,\n                                                   const at::Tensor seqlens_k,\n                                                   std::optional<const at::Tensor> &cache_batch_idx_,\n                                                   std::optional<at::Tensor> &block_table_,\n                                                   std::optional<at::Tensor> &alibi_slopes_,\n                                                   at::Tensor out,\n                                                   at::Tensor lse,\n                                                   at::Tensor lse_acc,\n                                                   at::Tensor out_acc)\n{\n    // q: (batch_size, seqlen_q, nheads, d)\n    // k: (batch_size, seqlen_k, nheads_k, d)\n    // v: (batch_size, seqlen_k, nheads_k, d)\n    // o: (batch_size, seqlen_q, nheads, d)\n\n    // alibi_slopes:(batch_size, nheads) or (nhead)\n    // lse: (batch_size, nheads, seqlen_q)\n    // lse_acc: (split, batch_size, nheads, seqlen_q)\n    // o_acc: (split, batch_size, nheads, seqlen_q, d)\n\n    fmha_fwd_splitkv_args args;\n    args.q_ptr = q.data_ptr();\n    args.k_ptr = k.data_ptr();\n    args.v_ptr = v.data_ptr();\n    args.bias_ptr = nullptr;\n    args.lse_acc_ptr = lse_acc.data_ptr();\n    args.o_acc_ptr = out_acc.data_ptr();\n    args.lse_ptr = nullptr;\n    args.o_ptr = out.data_ptr();\n    args.sink_ptr = nullptr;\n\n    if (block_table_.has_value())\n    {\n        auto block_table = block_table_.value();\n        args.block_table_ptr = block_table.data_ptr();\n        args.batch_stride_block_table = block_table.stride(0);\n   
     args.page_block_size = page_block_size;\n    }\n    else\n    {\n        args.block_table_ptr = nullptr;\n        args.batch_stride_block_table = 0;\n        args.page_block_size = 0;\n    }\n\n    args.cache_batch_idx = cache_batch_idx_.has_value() ? cache_batch_idx_.value().data_ptr() : nullptr;\n\n    args.seqstart_q_ptr = nullptr;\n    args.seqstart_k_ptr = nullptr;\n    args.seqlen_k_ptr = seqlens_k.data_ptr();\n\n    args.seqlen_q = seqlen_q;\n    args.seqlen_k = seqlen_k;\n    args.batch = b;\n    args.max_seqlen_q = seqlen_q;\n    args.hdim_q = d;\n    args.hdim_v = d;\n    args.nhead_q = h;\n    args.nhead_k = h_k;\n    args.num_splits = num_splits;\n\n    args.scale_s = softmax_scale;\n    args.scale_p = 1;\n    args.scale_o = 1;\n\n    args.batch_stride_q = q.stride(0);\n    args.stride_q = q.stride(1);\n    args.nhead_stride_q = q.stride(2);\n\n    args.batch_stride_k = k.stride(0);\n    args.stride_k = k.stride(1);\n    args.nhead_stride_k = k.stride(2);\n\n    args.batch_stride_v = v.stride(0);\n    args.stride_v = v.stride(1);\n    args.nhead_stride_v = v.stride(2);\n\n    args.batch_stride_o = out.stride(0);\n    args.stride_o = out.stride(1);\n    args.nhead_stride_o = out.stride(2);\n\n    args.batch_stride_bias = 0;\n    args.stride_bias = 0;\n    args.nhead_stride_bias = 0;\n\n    args.batch_stride_lse = 0;\n    args.nhead_stride_lse = 0;\n\n    args.split_stride_lse_acc = lse_acc.stride(0);\n    args.batch_stride_lse_acc = lse_acc.stride(1);\n    args.nhead_stride_lse_acc = lse_acc.stride(2);\n\n    args.split_stride_o_acc = out_acc.stride(0);\n    args.batch_stride_o_acc = out_acc.stride(1);\n    args.nhead_stride_o_acc = out_acc.stride(2);\n    args.stride_o_acc = out_acc.stride(3);\n\n    if (has_lse) {\n        args.lse_ptr = lse.data_ptr();\n        args.batch_stride_lse = lse.stride(0);\n        args.nhead_stride_lse = lse.stride(1);\n    }\n\n    if (alibi_slopes_.has_value()) {\n        auto alibi_slopes = alibi_slopes_.value();\n  
      CHECK_DEVICE(alibi_slopes);\n        TORCH_CHECK(alibi_slopes.stride(-1) == 1, \"ALiBi slopes tensor must have contiguous last dimension\");\n        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));\n        args.bias_ptr = alibi_slopes.data_ptr();\n        args.stride_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;\n    }\n\n    args.window_size_left = mask.left;\n    args.window_size_right = mask.right;\n    args.sink_size = 0;\n    args.mask_type = static_cast<ck_tile::index_t>(mask.type);\n\n    return args;\n}\n\nstd::vector<at::Tensor>\nmha_fwd_kvcache(at::Tensor &q,                                      // batch_size x seqlen_q x num_heads x head_size\n                const at::Tensor &kcache,                           // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.\n                const at::Tensor &vcache,                           // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.\n                std::optional<const at::Tensor> &k_,                // batch_size x seqlen_knew x num_heads_k x head_size\n                std::optional<const at::Tensor> &v_,                // batch_size x seqlen_knew x num_heads_k x head_size\n                std::optional<const at::Tensor> &seqlens_k_,        // batch_size\n                std::optional<const at::Tensor> &rotary_cos_,       // seqlen_ro x (rotary_dim / 2)\n                std::optional<const at::Tensor> &rotary_sin_,       // seqlen_ro x (rotary_dim / 2)\n                std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache\n                std::optional<const at::Tensor> & /*leftpad_k_*/,   // batch_size\n                std::optional<at::Tensor> &block_table_,            // batch_size x max_num_blocks_per_seq\n       
         std::optional<at::Tensor> &alibi_slopes_,           // num_heads or batch_size x num_heads\n                std::optional<at::Tensor> &out_,                    // batch_size x seqlen_q x num_heads x head_size\n                const float softmax_scale,\n                bool is_causal,\n                int window_size_left,\n                int window_size_right,\n                const float /*softcap*/,\n                bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2\n                int num_splits)\n{\n    auto q_dtype = q.dtype();\n    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,\n                \"FlashAttention only support fp16 and bf16 data type\");\n\n    TORCH_CHECK(kcache.dtype() == q_dtype, \"query and key must have the same dtype\");\n    TORCH_CHECK(vcache.dtype() == q_dtype, \"query and value must have the same dtype\");\n    std::string q_dtype_str = q_dtype == torch::kFloat16 ? \"fp16\" : \"bf16\";\n\n    CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);\n\n    TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(kcache.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(vcache.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n\n    at::Tensor block_table;\n    const bool paged_KV = block_table_.has_value();\n    if (paged_KV) {\n        TORCH_CHECK(!cache_batch_idx_.has_value(), \"Paged KVcache does not support cache_batch_idx\");\n        block_table = block_table_.value();\n        CHECK_DEVICE(block_table);\n        TORCH_CHECK(block_table.dtype() == torch::kInt32, \"block_table must have dtype torch.int32\");\n        TORCH_CHECK(block_table.stride(-1) == 1, \"block_table must have contiguous last dimension\");\n    }\n\n    const auto sizes = q.sizes();\n\n    const int batch_size = sizes[0];\n    int seqlen_q = sizes[1];\n    int 
num_heads = sizes[2];\n    const int head_size_og = sizes[3];\n\n    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);\n    const int num_blocks = !paged_KV ? 0 : kcache.size(0);\n    const int page_block_size = !paged_KV ? 1 : kcache.size(1);\n    TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, \"Paged KV cache block size must be divisible by 128\");\n    const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;\n    const int num_heads_k = kcache.size(2);\n    const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;\n    TORCH_CHECK(batch_size > 0, \"batch size must be positive\");\n    TORCH_CHECK(head_size_og <= 256, \"FlashAttention forward only supports head dimension at most 256\");\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n\n    // causal=true is the same as causal=false in this case\n    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }\n    if (is_causal) { window_size_right = 0; }\n\n    mask_info mask;\n    if (is_causal) {\n        // Causal is the special case where window_size_right == 0 and window_size_left < 0.\n        window_size_right = 0;\n        std::string mask_identify = \"b:\" + std::to_string(window_size_left) + \",\" + \"0\";\n        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal\n    }\n    else if (window_size_left == -1 && window_size_right == -1) {\n        mask = mask_info::decode(\"0\", seqlen_q, seqlen_k); // no mask\n    }\n    else {\n        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.\n        std::string mask_identify = \"b:\" + std::to_string(window_size_left) + \",\" + std::to_string(window_size_right);\n        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local\n    }\n\n    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in 
this case\n    // H/t Daniel Haziza\n    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();\n    if (seqlenq_ngroups_swapped) {\n        const int ngroups = num_heads / num_heads_k;\n        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);\n        seqlen_q = ngroups;\n        num_heads = num_heads_k;\n    }\n\n    if (window_size_left >= seqlen_k) { window_size_left = -1; }\n    if (window_size_right >= seqlen_k) { window_size_right = -1; }\n\n    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);\n    if (!paged_KV) {\n        CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);\n        CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);\n    } else {\n        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);\n        CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);\n        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);\n    }\n\n    at::Tensor q_padded, kcache_padded, vcache_padded;\n    if (head_size_og % 8 != 0) {\n        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));\n        kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));\n        vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));\n    } else {\n        q_padded = q;\n        kcache_padded = kcache;\n        vcache_padded = vcache;\n    }\n\n    at::Tensor out;\n    if (out_.has_value()) {\n        out = out_.value();\n        TORCH_CHECK(out.dtype() == q_dtype, \"Output must have the same dtype as inputs\");\n        CHECK_DEVICE(out);\n        TORCH_CHECK(out.stride(-1) == 1, \"Output tensor must have contiguous last 
dimension\");\n        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);\n        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }\n    } else {\n        out = torch::empty_like(q_padded);\n    }\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int head_size_8x = round_multiple(head_size_og, 8);\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    at::cuda::CUDAGuard device_guard{q.device()};\n\n    auto opts = q.options();\n\n    // TODO - check gradient, only training require lse\n    bool has_lse = true;\n    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));\n\n    int seqlen_knew = 0;\n    at::Tensor k, v, k_padded, v_padded;\n    if (k_.has_value()) {\n        TORCH_CHECK(v_.has_value(), \"If key is supplied, value must also be passed in\");\n        TORCH_CHECK(seqlens_k_.has_value(), \"If key is supplied, seqlens_k must also be passed in\");\n        TORCH_CHECK(seqlen_q <= seqlen_k, \"If key is supplied, it must have seqlen <= the seqlen of the KV cache\");\n        k = k_.value();\n        v = v_.value();\n        TORCH_CHECK(k.dtype() == q_dtype, \"Key must have the same dtype as query\");\n        TORCH_CHECK(v.dtype() == q_dtype, \"Value must have the same dtype as query\");\n        CHECK_DEVICE(k); CHECK_DEVICE(v);\n        TORCH_CHECK(k.stride(-1) == 1, \"Key tensor must have contiguous last dimension\");\n        TORCH_CHECK(v.stride(-1) == 1, \"Value tensor must have contiguous last dimension\");\n        seqlen_knew = k.size(1);\n        CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);\n        CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);\n        if (head_size_og % 8 != 0) {\n            k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));\n            v_padded = torch::nn::functional::pad(v, 
torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));\n        } else {\n            k_padded = k;\n            v_padded = v;\n        }\n    }\n\n    if (seqlens_k_.has_value()) {\n        auto seqlens_k = seqlens_k_.value();\n        TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, \"seqlens_k must have dtype int32\");\n        CHECK_DEVICE(seqlens_k);\n        CHECK_CONTIGUOUS(seqlens_k);\n        CHECK_SHAPE(seqlens_k, batch_size);\n    }\n\n    int rotary_dim = 0;\n    if (rotary_cos_.has_value()) {\n        TORCH_CHECK(k_.has_value(), \"If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided\");\n        auto rotary_cos = rotary_cos_.value();\n        CHECK_DEVICE(rotary_cos);\n        rotary_dim = rotary_cos.size(1) * 2;\n        TORCH_CHECK(rotary_dim <= head_size_og, \"rotary_dim must be <= headdim\");\n        TORCH_CHECK(rotary_dim % 16 == 0, \"Only rotary dimensions divisible by 16 are currently supported\");\n        const int seqlen_ro = rotary_cos.size(0);\n        TORCH_CHECK(seqlen_ro >= seqlen_k, \"cos/sin seqlen must be at least the seqlen of KV cache\");\n        CHECK_SHAPE(rotary_cos, seqlen_ro, rotary_dim / 2);\n        CHECK_CONTIGUOUS(rotary_cos);\n        TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, \"rotary_cos must have the same dtype as query\");\n\n        TORCH_CHECK(rotary_sin_.has_value(), \"If rotary cos is provided, rotary sin must also be provided\");\n        auto rotary_sin = rotary_sin_.value();\n        CHECK_DEVICE(rotary_sin);\n        CHECK_SHAPE(rotary_sin, seqlen_ro, rotary_dim / 2);\n        CHECK_CONTIGUOUS(rotary_sin);\n        TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, \"rotary_sin must have the same dtype as query\");\n    }\n\n\n    if (cache_batch_idx_.has_value()) {\n        auto cache_batch_idx = cache_batch_idx_.value();\n        CHECK_DEVICE(cache_batch_idx);\n        CHECK_CONTIGUOUS(cache_batch_idx);\n        
TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, \"cache_batch_idx must have dtype int32\");\n    }\n\n    num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, seqlen_q, head_size_8x, 0, num_splits);\n    TORCH_CHECK(num_splits > 0, \"num_splits should be greater than 0\");\n    TORCH_CHECK(num_splits <= 128, \"num_splits greater than 128 is not supported\");\n\n    // Keep references to these tensors to extend their lifetime\n    auto softmax_lse_accum = torch::empty({num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));\n    auto out_accum = torch::empty({num_splits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat));\n\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n    ck_tile::stream_config stream_config{stream};\n\n    if (seqlen_knew > 0 || rotary_dim > 0) {\n        auto appendkv_traits =\n            get_ck_fmha_fwd_appendkv_traits(q_dtype_str, head_size_8x, rotary_dim, is_rotary_interleaved);\n\n        auto appendkv_args =\n            get_ck_fmha_fwd_appendkv_args(\n                batch_size,\n                seqlen_q,\n                seqlen_knew,\n                num_heads,\n                num_heads_k,\n                head_size_8x,\n                rotary_dim,\n                mask.type != mask_enum::no_mask,\n                page_block_size,\n                q_padded,\n                kcache_padded,\n                vcache_padded,\n                k_padded,\n                v_padded,\n                seqlens_k_,\n                rotary_cos_,\n                rotary_sin_,\n                cache_batch_idx_,\n                block_table_);\n\n        fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);\n    }\n\n    // seqlens_k_ is the seqlen of kvcache. 
We need to add seqlen_knew to it before attention\n    auto append_seqlens_k = torch::empty({batch_size}, opts.dtype(torch::kInt32));\n    if (seqlens_k_.has_value())\n        append_seqlens_k = seqlens_k_.value() + seqlen_knew;\n    else\n        append_seqlens_k.fill_(seqlen_knew);\n\n    // we use splitkv even when num_splits == 1, because fmha_fwd() does not support seqlen_k_ in batch mode\n    auto splitkv_traits =\n        get_ck_fmha_fwd_splitkv_traits(mask, q_dtype_str, head_size_8x, has_lse, alibi_slopes_.has_value());\n\n    auto splitkv_args =\n        get_ck_fmha_fwd_splitkv_args(\n            has_lse,\n            mask,\n            batch_size,\n            seqlen_q,\n            seqlen_k,\n            num_heads,\n            num_heads_k,\n            head_size_8x,\n            page_block_size,\n            num_splits,\n            softmax_scale,\n            q_padded,\n            kcache_padded,\n            vcache_padded,\n            append_seqlens_k,\n            cache_batch_idx_,\n            block_table_,\n            alibi_slopes_,\n            out,\n            softmax_lse,\n            softmax_lse_accum,\n            out_accum);\n\n    fmha_fwd_splitkv(splitkv_traits, splitkv_args, stream_config);\n\n    if (head_size_og % 8 != 0) {\n        out = out.index({\"...\", torch::indexing::Slice(torch::indexing::None, head_size_og)});\n        if (out_.has_value()) { out_.value().copy_(out); }\n        if (k_.has_value()) {\n            // It's expensive to copy the KV cache here for the case where head size is not divisible by 8,\n            // but we don't expect to get this case in practice. 
This is just so that the code works for that case.\n            kcache.copy_(kcache_padded.index({\"...\", torch::indexing::Slice(torch::indexing::None, head_size_og)}));\n            vcache.copy_(vcache_padded.index({\"...\", torch::indexing::Slice(torch::indexing::None, head_size_og)}));\n        }\n    }\n\n    if (seqlenq_ngroups_swapped) {\n        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});\n        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});\n    }\n    return {out, softmax_lse};\n}\n"
  },
  {
    "path": "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#include \"flash_common.hpp\"\n\n#include \"fmha_bwd.hpp\"\n#include \"mask.hpp\"\n\nfmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,\n                                              std::string dtype,\n                                              int seqlen_q,\n                                              int seqlen_k,\n                                              int batch,\n                                              int max_seqlen_q,\n                                              int max_seqlen_k,\n                                              int head_size,\n                                              int nhead_q,\n                                              int nhead_k,\n                                              bool has_dropout,\n                                              bool enable_alibi,\n                                              bool deterministic)\n{\n    return fmha_bwd_traits{seqlen_q,\n                           seqlen_k,\n                           batch,\n                           max_seqlen_q,\n                           max_seqlen_k,\n                           head_size, // hdim_q\n                           head_size, // hdim_k\n                           nhead_q,\n                           nhead_k,\n                           dtype,\n                           true, // is_group_mode\n                           mask.type,\n                           enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias,\n                           false,    // has_dbias\n                           has_dropout,\n                           false, // s_randval\n                           deterministic};\n}\nfmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,\n                                          // sizes\n                                          const int b,\n                                          const int max_seqlen_q,\n                                          const int max_seqlen_k,\n                                          const int h,\n                                          const int h_k,\n                                          const int hdim,\n                                          // device pointers\n                                          const at::Tensor q,\n                                          const at::Tensor k,\n                                          const at::Tensor v,\n                                          const at::Tensor seqlens_q,\n                                          const at::Tensor seqlens_k,\n                                          std::optional<at::Tensor> &alibi_slopes_,\n                                          const at::Tensor out,\n                                          const at::Tensor softmax_lse,\n                                          const at::Tensor dout,\n                                          at::Tensor dq_acc,\n                                          at::Tensor d,\n                                          at::Tensor dq,\n                                          at::Tensor dk,\n                                          at::Tensor dv,\n                                          float softmax_scale,\n                                          float p_dropout,\n                                          std::pair<uint64_t*, uint64_t*> drop_seed_offset)\n{\n    ck_tile::index_t total_q = q.size(0);\n    ck_tile::index_t total_k = k.size(0);\n\n    // 
q: (total_q, nheads, hdim)\n    ck_tile::index_t batch_stride_q = 0;\n    ck_tile::index_t stride_q = q.stride(0);\n    ck_tile::index_t nhead_stride_q = q.stride(1);\n\n    // k: (total_k, nheads_k, hdim)\n    ck_tile::index_t batch_stride_k = 0;\n    ck_tile::index_t stride_k = k.stride(0);\n    ck_tile::index_t nhead_stride_k = k.stride(1);\n\n    // v: (total_k, nheads_k, hdim)\n    ck_tile::index_t batch_stride_v = 0;\n    ck_tile::index_t stride_v = v.stride(0);\n    ck_tile::index_t nhead_stride_v = v.stride(1);\n\n    // o: (total_q, nheads, hdim)\n    ck_tile::index_t batch_stride_o = 0;\n    ck_tile::index_t stride_o = out.stride(0);\n    ck_tile::index_t nhead_stride_o = out.stride(1);\n\n    // lse: (nheads, total_q)\n    ck_tile::index_t batch_stride_lse = 0;\n    ck_tile::index_t nhead_stride_lse = softmax_lse.stride(0);\n\n    // do: (total_q, nheads, hdim)\n    ck_tile::index_t batch_stride_do = 0;\n    ck_tile::index_t stride_do = dout.stride(0);\n    ck_tile::index_t nhead_stride_do = dout.stride(1);\n\n    // d: (batch_size, nheads, max_seqlen_q)\n    // CK assumes d shares the same strides as lse\n\n    // dq: (total_q, nheads, hdim)\n    ck_tile::index_t batch_stride_dq = 0;\n    ck_tile::index_t stride_dq = dq.stride(0);\n    ck_tile::index_t nhead_stride_dq = dq.stride(1);\n\n\n    // dk_expanded: (total_k, nheads, hdim)\n    ck_tile::index_t batch_stride_dk = 0;\n    ck_tile::index_t stride_dk = dk.stride(0);\n    ck_tile::index_t nhead_stride_dk = dk.stride(1);\n\n    // dv_expanded: (total_k, nheads, hdim)\n    ck_tile::index_t batch_stride_dv = 0;\n    ck_tile::index_t stride_dv = dv.stride(0);\n    ck_tile::index_t nhead_stride_dv = dv.stride(1);\n\n    // dq_acc: (nheads, split, total_q, hdim)\n    ck_tile::long_index_t batch_stride_dq_acc = 0;\n    ck_tile::long_index_t nhead_stride_dq_acc = dq_acc.stride(0);\n    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(1);\n    ck_tile::index_t stride_dq_acc = dq_acc.stride(2);\n\n    
float p_undrop = 1.0 - p_dropout;\n\n    void *alibi_slopes_ptr = nullptr;\n    ck_tile::index_t stride_alibi_slopes = 0;\n\n    if (alibi_slopes_.has_value()) {\n        auto alibi_slopes = alibi_slopes_.value();\n        CHECK_DEVICE(alibi_slopes);\n        TORCH_CHECK(alibi_slopes.stride(-1) == 1, \"ALiBi slopes tensor must have contiguous last dimension\");\n        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));\n        alibi_slopes_ptr = alibi_slopes.data_ptr();\n        // alibi_slopes:(batch_size, nheads) or (nhead)\n        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;\n    }\n\n    return fmha_bwd_args{q.data_ptr(),\n                         k.data_ptr(),\n                         v.data_ptr(),\n                         alibi_slopes_ptr, // bias\n                         out.data_ptr(),\n                         softmax_lse.data_ptr(),\n                         dout.data_ptr(),\n                         d.data_ptr(),\n                         nullptr, // rand_val\n                         dq.data_ptr(),\n                         dk.data_ptr(),\n                         dv.data_ptr(),\n                         nullptr, // dbias\n                         dq_acc.data_ptr(), // dq_acc\n                         seqlens_q.data_ptr(), // seqstart_q_ptr\n                         seqlens_k.data_ptr(), // seqstart_k_ptr\n                         nullptr, // seqlen_q_ptr\n                         nullptr, // seqlen_k_ptr\n                         nullptr, // cu_seqlen_q_ptr\n                         nullptr, // cu_seqlen_k_ptr\n                         total_q,\n                         total_k,\n                         b,\n                         max_seqlen_q, // max_seqlen_q\n                         max_seqlen_k, // max_seqlen_k\n                         hdim, // hdim_q\n                         hdim, // hdim_v\n                         h, // nhead\n               
          h_k, // nhead_k\n                         softmax_scale,\n                         stride_q,\n                         stride_k,\n                         stride_v,\n                         stride_alibi_slopes,\n                         stride_o,\n                         0, // stride_randval\n                         stride_do,\n                         stride_dq_acc,\n                         stride_dq,\n                         stride_dk,\n                         stride_dv,\n                         0, // stride_dbias, FA without bias\n                         nhead_stride_q,\n                         nhead_stride_k,\n                         nhead_stride_v,\n                         0, // nhead_stride_bias, FA without bias\n                         nhead_stride_o,\n                         0, // nhead_stride_randval\n                         nhead_stride_do,\n                         nhead_stride_lse,\n                         nhead_stride_dq_acc,\n                         nhead_stride_dq,\n                         nhead_stride_dk,\n                         nhead_stride_dv,\n                         0, // nhead_stride_dbias, FA without dbias\n                         batch_stride_q,\n                         batch_stride_k,\n                         batch_stride_v,\n                         0  , // batch_stride_bias, FA without bias\n                         batch_stride_o,\n                         0, // batch_stride_randval\n                         batch_stride_do,\n                         batch_stride_lse,\n                         batch_stride_dq_acc,\n                         batch_stride_dq,\n                         batch_stride_dk,\n                         batch_stride_dv,\n                         0  , // batch_stride_dbias, FA without dbias\n                         split_stride_dq_acc,\n                         mask.left,\n                         mask.right,\n                         static_cast<ck_tile::index_t>(mask.type),\n         
                p_dropout,\n                         p_undrop,\n                         drop_seed_offset};\n}\n\nstd::vector<at::Tensor>\nmha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads x head_size\n               const at::Tensor &q,                      // total_q x num_heads x head_size, total_q := \\sum_{i=0}^{b} s_i\n               const at::Tensor &k,                      // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i\n               const at::Tensor &v,                      // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i\n               const at::Tensor &out,                    // total_q x num_heads x head_size\n               const at::Tensor &softmax_lse,            // b x h x s   softmax logsumexp\n               std::optional<at::Tensor> &dq_,           // total_q x num_heads x head_size, total_q := \\sum_{i=0}^{b} s_i\n               std::optional<at::Tensor> &dk_,           // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i\n               std::optional<at::Tensor> &dv_,           // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i\n               const at::Tensor &cu_seqlens_q,           // b+1\n               const at::Tensor &cu_seqlens_k,           // b+1\n               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads\n               const int max_seqlen_q,\n               const int max_seqlen_k, // max sequence length to choose the kernel\n               const float p_dropout,  // probability to drop\n               const float softmax_scale,\n               const bool zero_tensors,\n               const bool is_causal,\n               int window_size_left,\n               int window_size_right,\n               const float /*softcap*/,\n               const bool deterministic,\n               std::optional<at::Generator> gen_,\n               std::optional<at::Tensor> &rng_state_)\n{\n#ifdef 
FLASHATTENTION_DISABLE_BACKWARD\n    TORCH_CHECK(false, \"This flash attention build does not support backward.\");\n#endif\n    if (is_causal) { window_size_right = 0; }\n\n    const bool is_dropout = p_dropout > 0.0;\n    auto stream = at::cuda::getCurrentCUDAStream().stream();\n\n    auto q_dtype = q.dtype();\n    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,\n                \"FlashAttention only supports fp16 and bf16 data types\");\n\n    TORCH_CHECK(k.dtype() == q_dtype, \"query and key must have the same dtype\");\n    TORCH_CHECK(v.dtype() == q_dtype, \"query and value must have the same dtype\");\n    TORCH_CHECK(out.dtype() == q_dtype, \"query and out must have the same dtype\");\n    TORCH_CHECK(dout.dtype() == q_dtype, \"query and dout must have the same dtype\");\n    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, \"cu_seqlens_q must have dtype int32\");\n    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, \"cu_seqlens_k must have dtype int32\");\n\n    const std::string q_dtype_str = q_dtype == torch::kFloat16 ? 
\"fp16\" : \"bf16\";\n\n    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);\n    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);\n    CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);\n\n    TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(k.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(v.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(out.stride(-1) == 1, \"out tensor must have contiguous last dimension\");\n    TORCH_CHECK(dout.stride(-1) == 1, \"dout tensor must have contiguous last dimension\");\n    CHECK_CONTIGUOUS(cu_seqlens_q);\n    CHECK_CONTIGUOUS(cu_seqlens_k);\n\n    const auto sizes = q.sizes();\n\n    const int total_q = sizes[0];\n    const int batch_size = cu_seqlens_q.numel() - 1;\n    const int num_heads = sizes[1];\n    const int head_size = sizes[2];\n    const int total_k = k.size(0);\n    const int num_heads_k = k.size(1);\n    TORCH_CHECK(batch_size > 0, \"batch size must be positive\");\n    TORCH_CHECK(head_size % 8 == 0, \"head_size should be a multiple of 8\");\n    TORCH_CHECK(head_size <= 256, \"CK FlashAttention backward only supports head dimension at most 256\");\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n\n    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }\n    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }\n\n    mask_info mask;\n    if (is_causal) {\n        std::string mask_identify = \"b:\" + std::to_string(window_size_left) + \",\" + \"0\";\n        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal\n    }\n    else if (window_size_left == -1 && window_size_right == -1) {\n        mask = mask_info::decode(\"0\", max_seqlen_q, max_seqlen_k); // no mask\n    }\n    else {\n        // Local is the more general case where 
window_size_right >= 0 or window_size_left >= 0.\n        std::string mask_identify = \"b:\" + std::to_string(window_size_left) + \",\" + std::to_string(window_size_right);\n        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local\n    }\n\n    // q, k, v, out have been padded in mha_fwd\n    // dq_, dk_, dv_ are also padded tensors\n    CHECK_SHAPE(q, total_q, num_heads, head_size);\n    CHECK_SHAPE(k, total_k, num_heads_k, head_size);\n    CHECK_SHAPE(v, total_k, num_heads_k, head_size);\n    CHECK_SHAPE(out, total_q, num_heads, head_size);\n    CHECK_SHAPE(dout, total_q, num_heads, head_size);\n    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);\n    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);\n\n    at::Tensor dq, dk, dv;\n    if (dq_.has_value()) {\n        dq = dq_.value();\n        TORCH_CHECK(dq.dtype() == q_dtype, \"dq must have the same dtype as q\");\n        CHECK_DEVICE(dq);\n        TORCH_CHECK(dq.stride(-1) == 1, \"dq must have contiguous last dimension\");\n        CHECK_SHAPE(dq, total_q, num_heads, head_size);\n    } else {\n        dq = torch::empty_like(q);\n    }\n    if (dk_.has_value()) {\n        dk = dk_.value();\n        TORCH_CHECK(dk.dtype() == q_dtype, \"dk must have the same dtype as q\");\n        CHECK_DEVICE(dk);\n        TORCH_CHECK(dk.stride(-1) == 1, \"dk must have contiguous last dimension\");\n        CHECK_SHAPE(dk, total_k, num_heads_k, head_size);\n    } else {\n        dk = torch::empty_like(k);\n    }\n    if (dv_.has_value()) {\n        dv = dv_.value();\n        TORCH_CHECK(dv.dtype() == q_dtype, \"dv must have the same dtype as q\");\n        CHECK_DEVICE(dv);\n        TORCH_CHECK(dv.stride(-1) == 1, \"dv must have contiguous last dimension\");\n        CHECK_SHAPE(dv, total_k, num_heads_k, head_size);\n    } else {\n        dv = torch::empty_like(v);\n    }\n\n    const auto traits = get_ck_fmha_varlen_bwd_traits(\n        mask,\n        q_dtype_str,\n        total_q,\n        total_k,\n        
batch_size,\n        max_seqlen_q,\n        max_seqlen_k,\n        head_size,\n        num_heads,\n        num_heads_k,\n        is_dropout,\n        alibi_slopes_.has_value(),\n        deterministic);\n    fmha_bwd_launcher launcher(traits);\n    const ck_tile::index_t nsplits = launcher.dq_acc_splits;\n\n    at::cuda::CUDAGuard device_guard{q.device()};\n\n    auto opts = q.options();\n    auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));\n    at::Tensor dq_accum  = torch::zeros({num_heads, nsplits, total_q, head_size}, opts.dtype(at::kFloat));\n\n    at::Tensor dk_expanded, dv_expanded;\n    if (num_heads_k != num_heads) {  // MQA / GQA\n        dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);\n        dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);\n    } else {\n        dk_expanded = dk;\n        dv_expanded = dv;\n    }\n\n    if(zero_tensors) {\n        dq.zero_();\n        dk_expanded.zero_();\n        dv_expanded.zero_();\n        softmax_d.zero_();\n    }\n\n    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(\n        gen_, at::cuda::detail::getDefaultCUDAGenerator());\n\n    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();\n    at::Tensor rng_state;\n\n    if (rng_state_.has_value()) {\n        rng_state = rng_state_.value();\n    } else if(is_dropout) {\n        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));\n        // See Note [Acquire lock when using random generators]\n        std::lock_guard<std::mutex> lock(gen->mutex_);\n        auto philox_args = gen->philox_cuda_state(counter_offset);\n        hipLaunchKernelGGL(\n            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,\n            philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));\n    } else {\n        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));\n    }\n\n    if (max_seqlen_q > 0) {\n        auto rng_state_ptr = 
reinterpret_cast<uint64_t*>(rng_state.data_ptr());\n        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);\n        ck_tile::stream_config stream_config{stream};\n\n        auto args =\n            get_ck_fmha_varlen_bwd_args(\n                mask,\n                batch_size,\n                max_seqlen_q,\n                max_seqlen_k,\n                num_heads,\n                num_heads_k,\n                head_size,\n                q,\n                k,\n                v,\n                cu_seqlens_q,\n                cu_seqlens_k,\n                alibi_slopes_,\n                out,\n                softmax_lse,\n                dout,\n                dq_accum,\n                softmax_d,\n                dq,\n                dk_expanded,\n                dv_expanded,\n                softmax_scale,\n                p_dropout,\n                drop_seed_offset);\n\n        float t = fmha_bwd(traits, args, stream_config);\n        TORCH_CHECK(t >= 0, \"invalid argument for fmha_bwd\");\n    } else {\n        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.\n        dk_expanded.zero_();\n        dv_expanded.zero_();\n        softmax_d.zero_();\n    }\n\n    // For MQA/GQA we need to sum dK and dV across the groups\n    if (num_heads_k != num_heads) {\n        at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});\n        at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});\n    }\n\n    return { dq, dk, dv, softmax_d };\n}"
  },
  {
    "path": "csrc/flash_attn_ck/mha_varlen_fwd.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#include \"flash_common.hpp\"\n\n#include \"fmha_fwd.hpp\"\n#include \"mask.hpp\"\n\nfmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,\n                                              std::string dtype,\n                                              int head_size,\n                                              bool has_dropout,\n                                              bool has_lse,\n                                              bool enable_alibi)\n{\n    return fmha_fwd_traits{head_size,\n                           head_size,\n                           dtype,\n                           true,  // is_group_mode\n                           true,  // is_v_rowmajor\n                           false, // has_logits_soft_cap\n                           mask.type,\n                           enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias,\n                           has_lse,\n                           has_dropout,\n                           quant_scale_enum::no_scale}; // qscale_type\n}\n\nfmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &mask,\n                                                              std::string dtype,\n                                                              int head_size,\n                                                              bool has_lse,\n                                                              bool enable_alibi)\n{\n    return fmha_fwd_splitkv_traits{head_size,\n                                   head_size,\n                                   dtype,\n                                   true,  // is_group_mode\n                                   true,  // is_v_rowmajor\n                                   false, // has_logits_soft_cap\n                                   mask.type,\n                                   enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias,\n                                   has_lse,\n                                   false,  // do_fp8_static_quant\n                                   false}; // has_sink\n}\n\nfmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,\n                                          bool has_dropout_randval,\n                                          const mask_info &mask,\n                                          // sizes\n                                          const int b,\n                                          const int max_seqlen_q,\n                                          const int h,\n                                          const int h_k,\n                                          const int d,\n                                          // device pointers\n                                          const at::Tensor q,\n                                          const at::Tensor k,\n                                          const at::Tensor v,\n                                          const at::Tensor seqlens_q,\n                                          const at::Tensor seqlens_k,\n                                          std::optional<at::Tensor> &alibi_slopes_,\n                                          at::Tensor out,\n                                          at::Tensor softmax_lse,\n                                          at::Tensor dropout_randval,\n                                          float softmax_scale,\n                                          float p_dropout,\n                                          std::pair<uint64_t*, uint64_t*> drop_seed_offset)\n{\n    // q: (total_q, nheads, d)\n    // k: (total_k, nheads_k, d)\n    // v: (total_k, nheads_k, d)\n    // o: (total_q, nheads, d)\n\n    // alibi_slopes:(batch, nheads) or (nhead)\n    // lse: (nheads, total_q)\n    // randval: (nheads, total_q, max_seqlen_k)\n\n    ck_tile::index_t total_q = q.size(0);\n    ck_tile::index_t total_k = k.size(0);\n\n   
 ck_tile::index_t stride_q = q.stride(0);\n    ck_tile::index_t stride_k = k.stride(0);\n    ck_tile::index_t stride_v = v.stride(0);\n    ck_tile::index_t stride_o = out.stride(0);\n    ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;\n\n    ck_tile::index_t nhead_stride_q = q.stride(1);\n    ck_tile::index_t nhead_stride_k = k.stride(1);\n    ck_tile::index_t nhead_stride_v = v.stride(1);\n    ck_tile::index_t nhead_stride_o = out.stride(1);\n    ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(0) : 0;\n    ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;\n\n    ck_tile::index_t batch_stride_q = 0;\n    ck_tile::index_t batch_stride_k = 0;\n    ck_tile::index_t batch_stride_v = 0;\n    ck_tile::index_t batch_stride_o = 0;\n    ck_tile::index_t batch_stride_lse = 0;\n    ck_tile::index_t batch_stride_randval = 0;\n\n    void *alibi_slopes_ptr = nullptr;\n    ck_tile::index_t stride_alibi_slopes = 0;\n\n    if (alibi_slopes_.has_value()) {\n        auto alibi_slopes = alibi_slopes_.value();\n        CHECK_DEVICE(alibi_slopes);\n        TORCH_CHECK(alibi_slopes.stride(-1) == 1, \"ALiBi slopes tensor must have contiguous last dimension\");\n        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));\n        alibi_slopes_ptr = alibi_slopes.data_ptr();\n        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;\n    }\n\n    return fmha_fwd_args{q.data_ptr(),\n                         k.data_ptr(),\n                         v.data_ptr(),\n                         alibi_slopes_ptr, // bias\n                         nullptr, // q_descale_ptr\n                         nullptr, // k_descale_ptr\n                         nullptr, // v_descale_ptr\n                         has_dropout_randval ? dropout_randval.data_ptr() : nullptr,\n                         has_lse ? 
softmax_lse.data_ptr() : nullptr,\n                         out.data_ptr(),\n                         seqlens_q.data_ptr(), // seqstart_q_ptr\n                         seqlens_k.data_ptr(), // seqstart_k_ptr\n                         nullptr,              // seqlen_q_ptr\n                         nullptr,              // seqlen_k_ptr\n                         nullptr,              // cu_seqlen_q_ptr\n                         nullptr,              // cu_seqlen_kv_ptr\n                         nullptr,              // block_scale_seqstart_q_ptr\n                         nullptr,              // block_scale_seqstart_k_ptr\n                         nullptr,              // seqstart_v_scale_ptr\n                         nullptr,              // sink_ptr\n                         total_q,\n                         total_k,\n                         b,\n                         max_seqlen_q,\n                         d,             // hdim_q\n                         d,             // hdim_v\n                         h,             // nhead\n                         h_k,           // nhead_k\n                         softmax_scale, // scale_s\n                         0.0f,          // logits_soft_cap\n                         stride_q,\n                         stride_k,\n                         stride_v,\n                         stride_alibi_slopes,\n                         stride_randval,\n                         stride_o,\n                         0, // stride_q_descale\n                         0, // stride_k_descale\n                         0, // stride_v_descale\n                         nhead_stride_q,\n                         nhead_stride_k,\n                         nhead_stride_v,\n                         0, // nhead_stride_bias, FA without bias\n                         nhead_stride_randval,\n                         nhead_stride_lse,\n                         nhead_stride_o,\n                         0, // nhead_stride_q_descale\n                       
  0, // nhead_stride_k_descale\n                         0, // nhead_stride_v_descale\n                         batch_stride_q,\n                         batch_stride_k,\n                         batch_stride_v,\n                         0, // batch_stride_bias, FA without bias\n                         batch_stride_randval,\n                         batch_stride_lse,\n                         batch_stride_o,\n                         0, // batch_stride_q_descale\n                         0, // batch_stride_k_descale\n                         0, // batch_stride_v_descale\n                         mask.left,\n                         mask.right,\n                         0, // sink_size\n                         static_cast<ck_tile::index_t>(mask.type),\n                         0, // min_seqlen_q\n                         p_dropout,\n                         has_dropout_randval,\n                         drop_seed_offset,\n                         0,     // block_scale_size_q\n                         0};    // block_scale_size_kv\n}\n\nfmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,\n                                                          const mask_info &mask,\n                                                          const int b,\n                                                          const int max_seqlen_q,\n                                                          const int h,\n                                                          const int h_k,\n                                                          const int d,\n                                                          const int page_block_size,\n                                                          const int num_splits,\n                                                          float softmax_scale,\n                                                          // device pointers\n                                                          const at::Tensor q,\n               
                                           const at::Tensor k,\n                                                          const at::Tensor v,\n                                                          const at::Tensor seqlens_q,\n                                                          const at::Tensor seqlens_k,\n                                                          std::optional<at::Tensor> &block_table_,\n                                                          std::optional<at::Tensor> &alibi_slopes_,\n                                                          at::Tensor out,\n                                                          at::Tensor lse,\n                                                          at::Tensor lse_acc,\n                                                          at::Tensor out_acc)\n{\n    // q: (total_q, nheads, d)\n    // k: (num_blocks, page_block_size, num_heads_k, d)\n    // v: (num_blocks, page_block_size, num_heads_k, d)\n    // o: (total_q, nheads, d)\n\n    // alibi_slopes:(batch_size, nheads) or (nhead)\n    // lse: (nheads, total_q)\n    // lse_acc: (nheads, split, total_q)\n    // o_acc: (nheads, split, total_q, d)\n    // block_table: (batch_size, max_num_blocks_per_seq)\n\n    fmha_fwd_splitkv_args args;\n    args.q_ptr = q.data_ptr();\n    args.k_ptr = k.data_ptr();\n    args.v_ptr = v.data_ptr();\n    args.bias_ptr = nullptr;\n    args.lse_acc_ptr = lse_acc.data_ptr();\n    args.o_acc_ptr = out_acc.data_ptr();\n    args.lse_ptr = nullptr;\n    args.o_ptr = out.data_ptr();\n    args.sink_ptr = nullptr;\n\n    if (block_table_.has_value())\n    {\n        auto block_table = block_table_.value();\n        args.block_table_ptr = block_table.data_ptr();\n        args.batch_stride_block_table = block_table.stride(0);\n        args.page_block_size = page_block_size;\n    }\n    else\n    {\n        args.block_table_ptr = nullptr;\n        args.batch_stride_block_table = 0;\n        args.page_block_size = 0;\n    }\n\n    
args.is_gappy = false;\n    args.cache_batch_idx = nullptr;\n\n    args.seqstart_q_ptr = seqlens_q.data_ptr();\n    args.seqstart_k_ptr = seqlens_k.data_ptr();\n    args.seqlen_k_ptr = nullptr;\n\n    args.batch = b;\n    args.max_seqlen_q = max_seqlen_q;\n    args.hdim_q = d;\n    args.hdim_v = d;\n    args.nhead_q = h;\n    args.nhead_k = h_k;\n    args.num_splits = num_splits;\n\n    args.scale_s = softmax_scale;\n    args.scale_p = 1;\n    args.scale_o = 1;\n\n    args.batch_stride_q = 0;\n    args.stride_q = q.stride(0);\n    args.nhead_stride_q = q.stride(1);\n\n    args.batch_stride_k = k.stride(0);\n    args.stride_k = k.stride(1);\n    args.nhead_stride_k = k.stride(2);\n\n    args.batch_stride_v = v.stride(0);\n    args.stride_v = v.stride(1);\n    args.nhead_stride_v = v.stride(2);\n\n    args.batch_stride_o = 0;\n    args.stride_o = out.stride(0);\n    args.nhead_stride_o = out.stride(1);\n\n    args.batch_stride_bias = 0;\n    args.stride_bias = 0;\n    args.nhead_stride_bias = 0;\n\n    args.batch_stride_lse = 0;\n    args.nhead_stride_lse = 0;\n\n    args.batch_stride_lse_acc = 0;\n    args.nhead_stride_lse_acc = lse_acc.stride(0);\n    args.split_stride_lse_acc = lse_acc.stride(1);\n\n    args.batch_stride_o_acc = 0;\n    args.nhead_stride_o_acc = out_acc.stride(0);\n    args.split_stride_o_acc = out_acc.stride(1);\n    args.stride_o_acc = out_acc.stride(2);\n\n    if (has_lse) {\n        args.lse_ptr = lse.data_ptr();\n        args.batch_stride_lse = 0;\n        args.nhead_stride_lse = lse.stride(0);\n    }\n\n    if (alibi_slopes_.has_value()) {\n        auto alibi_slopes = alibi_slopes_.value();\n        CHECK_DEVICE(alibi_slopes);\n        TORCH_CHECK(alibi_slopes.stride(-1) == 1, \"ALiBi slopes tensor must have contiguous last dimension\");\n        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));\n        args.bias_ptr = alibi_slopes.data_ptr();\n        args.stride_bias = 
alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;\n    }\n\n    args.window_size_left = mask.left;\n    args.window_size_right = mask.right;\n    args.sink_size = 0;\n    args.mask_type = static_cast<ck_tile::index_t>(mask.type);\n\n    return args;\n}\n\nstd::vector<at::Tensor>\nmha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_size, total_q := \\sum_{i=0}^{b} s_i\n               const at::Tensor &k,             // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.\n               const at::Tensor &v,             // total_k x num_heads_k x head_size, total_k := \\sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.\n               std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \\sum_{i=0}^{b} s_i\n               const at::Tensor &cu_seqlens_q,  // b+1\n               const at::Tensor &cu_seqlens_k,  // b+1\n               std::optional<at::Tensor> & /*seqused_k*/,\n               std::optional<const at::Tensor> &/*leftpad_k_*/, // batch_size\n               std::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq\n               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads\n               int max_seqlen_q,\n               const int max_seqlen_k,\n               const float p_dropout,\n               const float softmax_scale,\n               const bool zero_tensors,\n               bool is_causal,\n               int window_size_left,\n               int window_size_right,\n               const float /*softcap*/,\n               const bool return_dropout_randval,\n               std::optional<at::Generator> gen_)\n{\n    auto q_dtype = q.dtype();\n    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,\n                \"FlashAttention only support fp16 and bf16 data type\");\n\n    
TORCH_CHECK(k.dtype() == q_dtype, \"query and key must have the same dtype\");\n    TORCH_CHECK(v.dtype() == q_dtype, \"query and value must have the same dtype\");\n    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, \"cu_seqlens_q must have dtype int32\");\n    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, \"cu_seqlens_k must have dtype int32\");\n\n    std::string q_dtype_str = q_dtype == torch::kFloat16 ? \"fp16\" : \"bf16\";\n\n    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);\n    CHECK_DEVICE(cu_seqlens_q);\n    CHECK_DEVICE(cu_seqlens_k);\n\n    at::Tensor block_table;\n    const bool paged_KV = block_table_.has_value();\n    if (paged_KV) {\n        block_table = block_table_.value();\n        CHECK_DEVICE(block_table);\n        TORCH_CHECK(block_table.dtype() == torch::kInt32, \"block_table must have dtype torch.int32\");\n        TORCH_CHECK(block_table.stride(-1) == 1, \"block_table must have contiguous last dimension\");\n    }\n\n    TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(k.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(v.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    CHECK_CONTIGUOUS(cu_seqlens_q);\n    CHECK_CONTIGUOUS(cu_seqlens_k);\n\n    const auto sizes = q.sizes();\n\n    const int batch_size = cu_seqlens_q.numel() - 1;\n    int num_heads = sizes[1];\n    const int head_size = sizes[2];\n    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);\n\n    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);\n    const int num_blocks = !paged_KV ? 0 : k.size(0);\n    const int page_block_size = !paged_KV ? 
1 : k.size(1);\n    TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, \"Paged KV cache block size must be divisible by 128\");\n\n    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case\n\n    // TODO\n    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case\n    // H/t Daniel Haziza\n\n    const int total_q = q.size(0);\n\n    TORCH_CHECK(batch_size > 0, \"batch size must be positive\");\n    TORCH_CHECK(head_size <= 256, \"CK only supports head dimension at most 256\");\n    TORCH_CHECK(head_size % 8 == 0, \"query, key, value, and out_ must have a head_size that is a multiple of 8\");\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n\n    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }\n    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }\n\n    mask_info mask;\n\n    if (is_causal) {\n        // Causal is the special case where window_size_right == 0 and window_size_left < 0.\n        window_size_right = 0;\n        std::string mask_identify = \"b:\" + std::to_string(window_size_left) + \",\" + \"0\";\n        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal\n    }\n    else if (window_size_left == -1 && window_size_right == -1) {\n        mask = mask_info::decode(\"0\", max_seqlen_q, max_seqlen_k); // no mask\n    }\n    else {\n        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.\n        std::string mask_identify = \"b:\" + std::to_string(window_size_left) + \",\" + std::to_string(window_size_right);\n        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local\n    }\n\n    CHECK_SHAPE(q, total_q, num_heads, head_size);\n    if (!paged_KV) {\n        const int total_k = k.size(0);\n        CHECK_SHAPE(k, total_k, num_heads_k, 
head_size);\n        CHECK_SHAPE(v, total_k, num_heads_k, head_size);\n    } else {\n        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);\n        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);\n        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);\n    }\n\n    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);\n    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);\n    at::Tensor out;\n    if (out_.has_value()) {\n        out = out_.value();\n        TORCH_CHECK(out.dtype() == q_dtype, \"Output must have the same dtype as inputs\");\n        CHECK_DEVICE(out);\n        TORCH_CHECK(out.stride(-1) == 1, \"Output tensor must have contiguous last dimension\");\n        CHECK_SHAPE(out, total_q, num_heads, head_size);\n    }\n    else {\n        out = torch::empty_like(q);\n    }\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    at::cuda::CUDAGuard device_guard{q.device()};\n\n    auto opts = q.options();\n    bool has_lse = true;\n    bool has_dropout = p_dropout > 0.0f;\n    if (has_dropout)\n        TORCH_CHECK(!paged_KV, \"Paged KV does not support dropout\");\n\n    at::Tensor softmax_lse;\n    // TODO - check gradient, only training require lse\n    softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(torch::kFloat32));\n\n    at::Tensor p;\n    if (return_dropout_randval) {\n        TORCH_CHECK(has_dropout, \"return_dropout_randval require p_dropout > 0\");\n        p = torch::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(torch::kUInt8));\n    }\n    else {\n        p = torch::empty({ 0 }, opts);\n    }\n\n    if (zero_tensors)\n    {\n        out.zero_();\n        softmax_lse.fill_(-std::numeric_limits<float>::infinity());\n        if (return_dropout_randval) {p.zero_();}\n    }\n\n    int num_splits = 0;\n    num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, 0, num_splits);\n    TORCH_CHECK(num_splits > 0, \"num_splits should 
greater than 0\");\n    TORCH_CHECK(num_splits <= 128, \"num_splits greater than 128 is not supported\");\n\n    auto softmax_lse_accum = torch::empty({num_heads, num_splits, total_q}, opts.dtype(at::kFloat));\n    auto out_accum = torch::empty({num_heads, num_splits, total_q, head_size}, opts.dtype(at::kFloat));\n\n    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();\n    auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));\n    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());\n\n    if (p_dropout > 0.0)  {\n        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(\n            gen_, at::cuda::detail::getDefaultCUDAGenerator());\n        // See Note [Acquire lock when using random generators]\n        std::lock_guard<std::mutex> lock(gen->mutex_);\n        auto philox_args = gen->philox_cuda_state(counter_offset);\n        hipLaunchKernelGGL(\n            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);\n    }\n\n    if (max_seqlen_k > 0) {\n#ifdef HIPIFY_V2\n        auto stream = at::cuda::getCurrentCUDAStream().stream();\n#else\n        auto stream = at::cuda::getCurrentHIPStream().stream();\n#endif\n        ck_tile::stream_config stream_config{stream};\n\n        if (paged_KV)\n        {\n            auto traits =\n                get_ck_fmha_varlen_fwd_splitkv_traits(\n                    mask,\n                    q_dtype_str,\n                    head_size,\n                    has_lse,\n                    alibi_slopes_.has_value());\n\n            auto args =\n                get_ck_fmha_varlen_fwd_splitkv_args(\n                    has_lse,\n                    mask,\n                    batch_size,\n                    max_seqlen_q,\n                    num_heads,\n                    num_heads_k,\n                    head_size,\n                    page_block_size,\n                    num_splits,\n                    softmax_scale,\n              
      q,\n                    k,\n                    v,\n                    cu_seqlens_q,\n                    cu_seqlens_k,\n                    block_table_,\n                    alibi_slopes_,\n                    out,\n                    softmax_lse,\n                    softmax_lse_accum,\n                    out_accum);\n\n            float t = fmha_fwd_splitkv(traits, args, stream_config);\n            TORCH_CHECK(t >= 0, \"invalid argument for fmha_fwd_splitkv\");\n        }\n        else\n        {\n            auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);\n\n            auto traits =\n                get_ck_fmha_varlen_fwd_traits(\n                    mask,\n                    q_dtype_str,\n                    head_size,\n                    has_dropout,\n                    has_lse,\n                    alibi_slopes_.has_value());\n\n            auto args =\n                get_ck_fmha_varlen_fwd_args(\n                    has_lse,\n                    return_dropout_randval,\n                    mask,\n                    batch_size,\n                    max_seqlen_q,\n                    num_heads,\n                    num_heads_k,\n                    head_size,\n                    q,\n                    k,\n                    v,\n                    cu_seqlens_q,\n                    cu_seqlens_k,\n                    alibi_slopes_,\n                    out,\n                    softmax_lse,\n                    p,\n                    softmax_scale,\n                    p_dropout,\n                    drop_seed_offset);\n\n            float t = fmha_fwd(traits, args, stream_config);\n            TORCH_CHECK(t >= 0, \"invalid argument for fmha_fwd\");\n        }\n    }\n    else {\n        // If seqlen_k == 0, then we have an empty tensor. 
We need to set the output to 0.\n        out.zero_();\n        softmax_lse.fill_(std::numeric_limits<float>::infinity());\n    }\n\n    return {out, softmax_lse, p, rng_state};\n}\n"
  },
  {
    "path": "csrc/fused_dense_lib/README.md",
    "content": "This CUDA extension implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu\n(forward and backward), adapted from Apex's\n[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense).\nThis version adds bfloat16 support.\n\nFor best performance, use CUDA >= 11.8. cuBLAS versions shipped with earlier\nCUDA releases do not have the best matmul + bias + gelu performance for bfloat16.\n\nIt has only been tested on A100s.\n\n```sh\ncd csrc/fused_dense_lib && pip install .\n```\n"
  },
  {
    "path": "csrc/fused_dense_lib/fused_dense.cpp",
    "content": "// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense.cpp\n// We make it work for bfloat16\n#include <torch/extension.h>\n#include <torch/torch.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <vector>\n\n#include <stdio.h>\n\n#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x \" must have shape (\" #__VA_ARGS__ \")\")\n\n// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h\n// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...)                                \\\n  switch (TYPE) {                                                              \\\n  case at::ScalarType::Half: {                                                 \\\n    using scalar_t = at::Half;                                                 \\\n    __VA_ARGS__();                                                             \\\n    break;                                                                     \\\n  }                                                                            \\\n  case at::ScalarType::BFloat16: {                                             \\\n    using scalar_t = at::BFloat16;                                             \\\n    __VA_ARGS__();                                                             \\\n    break;                                                                     \\\n  }                                                                            \\\n  default:                                                                     \\\n    AT_ERROR(#NAME, \" not implemented for '\", toString(TYPE), \"'\");            \\\n  }\n\ntemplate <typename T>\nint linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize);\n\ntemplate <typename T>\nint 
linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize);\n\ntemplate <typename T>\nint bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize);\n\nstd::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {\n\n  int64_t batch_size = input.size(0);\n  int64_t in_features = input.size(1);\n  int64_t out_features = d_output.size(1);\n\n  TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);\n  TORCH_CHECK(input.dtype() == d_output.dtype());\n  TORCH_CHECK(input.is_cuda());\n  TORCH_CHECK(d_output.is_cuda());\n  TORCH_CHECK(input.is_contiguous());\n  TORCH_CHECK(d_output.is_contiguous());\n  CHECK_SHAPE(input, batch_size, in_features);\n  CHECK_SHAPE(d_output, batch_size, out_features);\n\n  // Otherwise the kernel will be launched from cuda:0 device\n  at::cuda::CUDAGuard device_guard{input.device()};\n\n  // create output/workspace tensor\n  auto opts = input.options();\n  auto d_weight = at::empty({out_features, in_features}, opts);\n  at::Tensor d_bias;\n  if (has_d_bias) {\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600\n    d_bias = d_output.view({-1, out_features}).sum(0, false);\n#else\n    d_bias = at::empty({out_features}, opts);\n#endif\n  }\n  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.\n  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs\n  // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91\n  size_t workspaceSize = 1024 * 1024 * 
(at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);\n  auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));\n\n  DISPATCH_HALF_AND_BF16(input.scalar_type(), \"linear_bias_wgrad\", [&] {\n    auto result = linear_bias_wgrad_cuda<scalar_t>(\n        input.data_ptr<scalar_t>(),\n        d_output.data_ptr<scalar_t>(),\n        in_features,\n        batch_size,\n        out_features,\n        d_weight.data_ptr<scalar_t>(),\n        has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,\n        (void*) (lt_workspace.data_ptr()),\n        workspaceSize);\n    TORCH_CHECK(result == 0, \"linear_bias_wgrad failed.\");\n  });\n\n  return {d_weight, d_bias};\n}\n\nstd::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,\n                                           std::optional<at::Tensor> bias_,\n                                           bool is_gelu, bool save_pre_act, int heuristic) {\n\n  int64_t batch_size = input.size(0);\n  int64_t in_features = input.size(1);\n  int64_t out_features = weight.size(0);\n\n  TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);\n  TORCH_CHECK(input.dtype() == weight.dtype());\n  TORCH_CHECK(input.is_cuda());\n  TORCH_CHECK(weight.is_cuda());\n  TORCH_CHECK(input.is_contiguous());\n  TORCH_CHECK(weight.is_contiguous());\n  CHECK_SHAPE(input, batch_size, in_features);\n  CHECK_SHAPE(weight, out_features, in_features);\n  if (bias_.has_value()) {\n    auto bias = bias_.value();\n    TORCH_CHECK(bias.dtype() == input.dtype());\n    TORCH_CHECK(bias.is_cuda());\n    TORCH_CHECK(bias.is_contiguous());\n    CHECK_SHAPE(bias, out_features);\n  }\n\n  // Otherwise the kernel will be launched from cuda:0 device\n  at::cuda::CUDAGuard device_guard{input.device()};\n\n  // create output/workspace tensor\n  auto opts = input.options();\n  auto output = at::empty({batch_size, out_features}, opts);\n  at::Tensor pre_act;\n  // If ReLU, cuBlasLT 
stores a bit-mask (1 bit per element)\n  if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},\n                                          is_gelu ? opts : opts.dtype(torch::kUInt8)); }\n  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.\n  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs\n  // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91\n  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);\n  auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));\n\n  DISPATCH_HALF_AND_BF16(input.scalar_type(), \"linear_act_forward\", [&] {\n    auto result = linear_act_forward_cuda<scalar_t>(\n        input.data_ptr<scalar_t>(),\n        weight.data_ptr<scalar_t>(),\n        bias_.has_value()? bias_.value().data_ptr<scalar_t>() : nullptr,\n        in_features,\n        batch_size,\n        out_features,\n        is_gelu,\n        heuristic,\n        output.data_ptr<scalar_t>(),\n        save_pre_act ? pre_act.data_ptr() : nullptr,\n        (void*) (lt_workspace.data_ptr()),\n        workspaceSize);\n    TORCH_CHECK(result == 0, \"linear_act_forward failed.\");\n  });\n\n  std::vector<at::Tensor> result = {output};\n  if (save_pre_act) { result.push_back(pre_act); };\n  return result;\n}\n\nstd::vector<at::Tensor> bias_act_linear_dgrad_bgrad(\n  at::Tensor weight, at::Tensor d_output, at::Tensor pre_act, bool is_gelu, int heuristic\n) {\n\n  int64_t batch_size = d_output.size(0);\n  int64_t out_features = d_output.size(1);\n  int64_t in_features = weight.size(1);\n\n  TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);\n  TORCH_CHECK(weight.dtype() == d_output.dtype());\n  TORCH_CHECK(is_gelu ? 
(pre_act.dtype() == weight.dtype()) : (pre_act.dtype() == torch::kUInt8));\n  TORCH_CHECK(weight.is_cuda());\n  TORCH_CHECK(d_output.is_cuda());\n  TORCH_CHECK(pre_act.is_cuda());\n  TORCH_CHECK(weight.is_contiguous());\n  TORCH_CHECK(d_output.is_contiguous());\n  TORCH_CHECK(pre_act.is_contiguous());\n  CHECK_SHAPE(weight, out_features, in_features);\n  CHECK_SHAPE(d_output, batch_size, out_features);\n  // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)\n  CHECK_SHAPE(pre_act, batch_size, is_gelu ? in_features : in_features / 8);\n\n  // Otherwise the kernel will be launched from cuda:0 device\n  at::cuda::CUDAGuard device_guard{weight.device()};\n\n  // create output/workspace tensor\n  auto opts = weight.options();\n  auto d_bias = at::empty({in_features}, opts);\n  auto d_input = at::empty({batch_size, in_features}, opts);\n  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.\n  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs\n  // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91\n  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 
32 : 4);\n  auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));\n\n  DISPATCH_HALF_AND_BF16(weight.scalar_type(), \"bias_act_linear_dgrad_bgrad\", [&] {\n    auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(\n        weight.data_ptr<scalar_t>(),\n        d_output.data_ptr<scalar_t>(),\n        pre_act.data_ptr(),\n        in_features,\n        batch_size,\n        out_features,\n        is_gelu,\n        heuristic,\n        d_input.data_ptr<scalar_t>(),\n        d_bias.data_ptr<scalar_t>(),\n        (void*) (lt_workspace.data_ptr()),\n        workspaceSize);\n    TORCH_CHECK(result == 0, \"bias_act_linear_dgrad_bgrad failed.\");\n  });\n\n  return {d_input, d_bias};\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"linear_bias_wgrad\", &linear_bias_wgrad, \"linear bias wgrad\");\n  m.def(\"linear_act_forward\", &linear_act_forward, \"linear gelu/relu forward\");\n  m.def(\"bias_act_linear_dgrad_bgrad\", &bias_act_linear_dgrad_bgrad, \"bias gelu/relu linear dgrad bgrad\");\n}\n"
  },
  {
    "path": "csrc/fused_dense_lib/fused_dense_cuda.cu",
    "content": "// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense_cuda.cu\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <assert.h>\n#include <stdio.h>\n#include <stdlib.h>\n#include <string.h>\n#include <torch/torch.h>\n\n/* Includes, cuda */\n#include <cublas_v2.h>\n#include <cuda_runtime.h>\n\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000\n#include <cublasLt.h>\n#endif\n\n// FP16 Tensor core wrapper around cublas GEMMEx\ncublasStatus_t gemm_bias(\n    cublasHandle_t handle,\n    cublasOperation_t transa,\n    cublasOperation_t transb,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    const float* alpha,\n    const at::Half* A,\n    int64_t lda,\n    const at::Half* B,\n    int64_t ldb,\n    const float* beta,\n    at::Half* C,\n    int64_t ldc) {\n  return cublasGemmEx(\n      handle,\n      transa,\n      transb,\n      m,\n      n,\n      k,\n      alpha,\n      A,\n      CUDA_R_16F,\n      lda,\n      B,\n      CUDA_R_16F,\n      ldb,\n      beta,\n      C,\n      CUDA_R_16F,\n      ldc,\n      CUDA_R_32F,\n      CUBLAS_GEMM_DEFAULT_TENSOR_OP);\n}\n\n// BF16 Tensor core wrapper around cublas GEMMEx\ncublasStatus_t gemm_bias(\n    cublasHandle_t handle,\n    cublasOperation_t transa,\n    cublasOperation_t transb,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    const float* alpha,\n    const at::BFloat16* A,\n    int64_t lda,\n    const at::BFloat16* B,\n    int64_t ldb,\n    const float* beta,\n    at::BFloat16* C,\n    int64_t ldc) {\n  return cublasGemmEx(\n      handle,\n      transa,\n      transb,\n      m,\n      n,\n      k,\n      alpha,\n      A,\n      CUDA_R_16BF,\n      lda,\n      B,\n      CUDA_R_16BF,\n      ldb,\n      beta,\n      C,\n      CUDA_R_16BF,\n      ldc,\n      CUDA_R_32F,\n      CUBLAS_GEMM_DEFAULT_TENSOR_OP);\n}\n\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600\n\ntemplate <typename Dtype>\nint gemm_bias_act_lt(\n    cublasOperation_t transa,\n    
cublasOperation_t transb,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    float alpha,\n    const Dtype* A,\n    int64_t lda,\n    const Dtype* B,\n    int64_t ldb,\n    const Dtype* bias,\n    Dtype* C,\n    int64_t ldc,\n    void* pre_act,\n    bool is_gelu,\n    int heuristic,\n    void *lt_workspace,\n    size_t workspaceSize\n    ) {\n  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,\n                \"gemm_bias_act_lt only supports fp16 and bf16\");\n  bool save_pre_act = pre_act != nullptr;\n  float beta = 0.0;\n  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? CUDA_R_16F : CUDA_R_16BF;\n\n  cublasLtHandle_t ltHandle =\n    reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());\n\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults                             = 0;\n  constexpr int requestedAlgoCount = 5;\n  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0};\n  // constexpr int requestedAlgoCount = 1;\n  // cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = is_gelu\n      ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU)\n      : (save_pre_act ? 
CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU);\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (save_pre_act) {\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act));\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));\n  }\n\n  if (bias != nullptr) {\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));\n    if (status != CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n    epilogue = is_gelu\n        ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS)\n        : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS : CUBLASLT_EPILOGUE_RELU_BIAS);\n  } else {\n    epilogue = is_gelu\n        ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU)\n        : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU);\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status = cublasLtMatrixLayoutInit(\n    &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? 
k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(\n    &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(\n    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 
32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(\n    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);\n    // ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle,\n                          &operationDesc,\n                          &alpha,\n                          A,\n                          &Adesc,\n                          B,\n                          &Bdesc,\n                          &beta,\n                          C,\n                          &Cdesc,\n                          C,\n                          &Cdesc,\n                          // &heuristicResult.algo,\n                          // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos\n                          &heuristicResult[heuristic].algo,\n                          // NULL,\n                          lt_workspace,\n                          workspaceSize,\n                          at::cuda::getCurrentCUDAStream());\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1;\n}\n\ntemplate int gemm_bias_act_lt(\n    cublasOperation_t transa,\n    cublasOperation_t transb,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    float alpha,\n    const at::Half* A,\n    int64_t lda,\n    const at::Half* B,\n    int64_t ldb,\n    const at::Half* bias,\n    at::Half* C,\n    int64_t ldc,\n    void* pre_act,\n    bool is_gelu,\n    int heuristic,\n    void *lt_workspace,\n    size_t workspaceSize);\n\ntemplate int gemm_bias_act_lt(\n    cublasOperation_t transa,\n    cublasOperation_t transb,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    float alpha,\n    const at::BFloat16* A,\n    int64_t lda,\n    const at::BFloat16* B,\n    int64_t ldb,\n    const at::BFloat16* bias,\n    at::BFloat16* C,\n    int64_t ldc,\n    void* pre_act,\n    bool is_gelu,\n    int heuristic,\n    void *lt_workspace,\n    size_t workspaceSize);\n\ntemplate <typename Dtype>\nint gemm_bgradb_lt(\n    cublasOperation_t transa,\n    cublasOperation_t transb,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    float alpha,\n    const Dtype* A,\n    int64_t lda,\n    const Dtype* B,\n    int64_t ldb,\n    Dtype* C,\n    int64_t ldc,\n    Dtype* bgrad,\n    void *lt_workspace,\n    size_t workspaceSize) {\n  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,\n                \"gemm_bgradb_lt only supports fp16 and bf16\");\n  float beta = 0.0;\n  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? 
CUDA_R_16F : CUDA_R_16BF;\n\n  cublasLtHandle_t ltHandle =\n    reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());\n\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults                             = 0;\n  cublasLtMatmulHeuristicResult_t heuristicResult = {};\n  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (bgrad != nullptr) {\n    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));\n    if (status != CUBLAS_STATUS_SUCCESS) {\n      goto CLEANUP;\n    }\n      epilogue = CUBLASLT_EPILOGUE_BGRADB;\n  }\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. Not setting any extra attributes.\n  status = cublasLtMatrixLayoutInit(\n    &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(\n    &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? 
n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(\n    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(\n    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle,\n                          &operationDesc,\n                          &alpha,\n                          A,\n                          &Adesc,\n                          B,\n                          &Bdesc,\n                          &beta,\n                          C,\n                          &Cdesc,\n                          C,\n                          &Cdesc,\n                          //&heuristicResult.algo,\n                          NULL,\n                          lt_workspace,\n                          workspaceSize,\n                          
at::cuda::getCurrentCUDAStream());\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;\n}\n\n\ntemplate int gemm_bgradb_lt(\n    cublasOperation_t transa,\n    cublasOperation_t transb,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    float alpha,\n    const at::Half* A,\n    int64_t lda,\n    const at::Half* B,\n    int64_t ldb,\n    at::Half* C,\n    int64_t ldc,\n    at::Half* bgrad,\n    void *lt_workspace,\n    size_t workspaceSize);\n\ntemplate int gemm_bgradb_lt(\n    cublasOperation_t transa,\n    cublasOperation_t transb,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    float alpha,\n    const at::BFloat16* A,\n    int64_t lda,\n    const at::BFloat16* B,\n    int64_t ldb,\n    at::BFloat16* C,\n    int64_t ldc,\n    at::BFloat16* bgrad,\n    void *lt_workspace,\n    size_t workspaceSize);\n\ntemplate <typename Dtype>\nint gemm_dact_bgradb_lt(\n    cublasOperation_t transa,\n    cublasOperation_t transb,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    float alpha,\n    const Dtype* A,\n    int64_t lda,\n    const Dtype* B,\n    int64_t ldb,\n    const void* pre_act,\n    Dtype* C,\n    int64_t ldc,\n    Dtype* bgrad,\n    bool is_gelu,\n    int heuristic,\n    void *lt_workspace,\n    size_t workspaceSize) {\n  static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,\n                \"gemm_dact_bgradb_lt only supports fp16 and bf16\");\n  float beta = 0.0;\n  cudaDataType_t abcType = std::is_same<Dtype, at::Half>::value ? 
CUDA_R_16F : CUDA_R_16BF;\n\n  cublasLtHandle_t ltHandle =\n    reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());\n\n  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;\n\n  cublasLtMatmulDescOpaque_t operationDesc = {};\n  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};\n  cublasLtMatmulPreferenceOpaque_t preference = {};\n\n  int returnedResults                             = 0;\n  constexpr int requestedAlgoCount = 5;\n  cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0};\n  cublasLtEpilogue_t epilogue = is_gelu ? CUBLASLT_EPILOGUE_DGELU_BGRAD : CUBLASLT_EPILOGUE_DRELU_BGRAD;\n\n  // Create operation descriptor; see cublasLtMatmulDescAttributes_t\n  // for details about defaults; here we just set the transforms for\n  // A and B.\n  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));\n\n  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));\n  if (status != CUBLAS_STATUS_SUCCESS) {\n    goto CLEANUP;\n  }\n\n  // Create matrix descriptors. 
Not setting any extra attributes.\n  status = cublasLtMatrixLayoutInit(\n    &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(\n    &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // Create preference handle; In general, extra attributes can be\n  // used here to disable tensor ops or to make sure algo selected\n  // will work with badly aligned A, B, C. However, for simplicity\n  // here we assume A,B,C are always well aligned (e.g., directly\n  // come from cudaMalloc)\n  status = cublasLtMatmulPreferenceInit(&preference);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n  status = cublasLtMatmulPreferenceSetAttribute(\n    &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  // We just need the best available heuristic to try and run matmul.\n  // There is no guarantee that this will work. For example, if A is\n  // badly aligned, you can request more (e.g. 
32) algos and try to\n  // run them one by one until something works.\n  status = cublasLtMatmulAlgoGetHeuristic(\n    ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults);\n  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;\n\n  if (returnedResults == 0) {\n    status = CUBLAS_STATUS_NOT_SUPPORTED;\n    goto CLEANUP;\n  }\n  status = cublasLtMatmul(ltHandle,\n                          &operationDesc,\n                          &alpha,\n                          A,\n                          &Adesc,\n                          B,\n                          &Bdesc,\n                          &beta,\n                          C,\n                          &Cdesc,\n                          C,\n                          &Cdesc,\n                          //&heuristicResult.algo,\n                          &heuristicResult[heuristic].algo,\n                          // NULL,\n                          lt_workspace,\n                          workspaceSize,\n                          at::cuda::getCurrentCUDAStream());\n\nCLEANUP:\n  // Descriptors are no longer needed as all GPU work was already\n  // enqueued.\n  return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1;\n}\n\ntemplate int gemm_dact_bgradb_lt(\n    cublasOperation_t transa,\n    cublasOperation_t transb,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    float alpha,\n    const at::Half* A,\n    int64_t lda,\n    const at::Half* B,\n    int64_t ldb,\n    const void* pre_act,\n    at::Half* C,\n    int64_t ldc,\n    at::Half* bgrad,\n    bool is_gelu,\n    int heuristic,\n    void *lt_workspace,\n    size_t workspaceSize);\n\ntemplate int gemm_dact_bgradb_lt(\n    cublasOperation_t transa,\n    cublasOperation_t transb,\n    int64_t m,\n    int64_t n,\n    int64_t k,\n    float alpha,\n    const at::BFloat16* A,\n    int64_t lda,\n    const at::BFloat16* B,\n    int64_t ldb,\n    const void* pre_act,\n    at::BFloat16* C,\n    int64_t ldc,\n    at::BFloat16* bgrad,\n    bool is_gelu,\n    int heuristic,\n    void *lt_workspace,\n    size_t workspaceSize);\n\n#endif\n\ntemplate <typename T>\nint linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize) {\n    const float alpha          = 1.0;\n    const float beta_zero      = 0.0;\n    int status = 1;\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600\n    status = gemm_bgradb_lt(\n    // (cublasLtHandle_t)handle,\n    CUBLAS_OP_N,\n    CUBLAS_OP_T,\n    in_features,\n    out_features,\n    batch_size,\n    alpha,\n    input,\n    in_features,\n    d_output,\n    out_features,\n    d_weight,\n    in_features,\n    d_bias,\n    lt_workspace,\n    workspaceSize);\n#endif\n\n    if (status != 0){\n        cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();\n        status = gemm_bias(\n          handle,\n          CUBLAS_OP_N,\n          CUBLAS_OP_T,\n          in_features,\n          out_features,\n          batch_size,\n          &alpha,\n          input,\n          in_features,\n          d_output,\n          out_features,\n          &beta_zero,\n          
d_weight,\n          in_features);\n        // TD [2023-01-17]: I can't call Pytorch's gemm for now, due to linking error\n        // https://discuss.pytorch.org/t/how-can-i-use-the-function-at-gemm-float/95341\n        // at::cuda::blas::gemm<T>(\n        //   'N',\n        //   'T',\n        //   in_features,\n        //   out_features,\n        //   batch_size,\n        //   alpha,\n        //   input,\n        //   in_features,\n        //   d_output,\n        //   out_features,\n        //   beta_zero,\n        //   d_weight,\n        //   in_features);\n    }\n\n    return status;\n}\n\ntemplate <typename T>\nint linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize) {\n    int status = 1;\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600\n    status = gemm_bias_act_lt(\n    CUBLAS_OP_T,\n    CUBLAS_OP_N,\n    out_features,\n    batch_size,\n    in_features,\n    /*alpha=*/1.0,\n    weight,\n    in_features,\n    input,\n    in_features,\n    bias,\n    output,\n    out_features,\n    pre_act,\n    is_gelu,\n    heuristic,\n    lt_workspace,\n    workspaceSize);\n    return status;\n#else\n    return 1;\n#endif\n}\n\ntemplate <typename T>\nint bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize) {\n    const float alpha          = 1.0;\n    int status = 1;\n#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600\n    status = gemm_dact_bgradb_lt(\n    CUBLAS_OP_N,\n    CUBLAS_OP_N,\n    in_features,\n    batch_size,\n    out_features,\n    alpha,\n    weight,\n    in_features,\n    d_output,\n    out_features,\n    pre_act,\n    d_input,\n    in_features,\n    d_bias,\n    is_gelu,\n    
heuristic,\n    lt_workspace,\n    workspaceSize);\n#endif\n    return status;\n\n}\n\ntemplate int linear_bias_wgrad_cuda<at::Half>(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);\ntemplate int linear_bias_wgrad_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);\n\ntemplate int linear_act_forward_cuda<at::Half>(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act, void *lt_workspace, size_t workspaceSize);\ntemplate int linear_act_forward_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act, void *lt_workspace, size_t workspaceSize);\n\ntemplate int bias_act_linear_dgrad_bgrad_cuda<at::Half>(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);\ntemplate int bias_act_linear_dgrad_bgrad_cuda<at::BFloat16>(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);"
  },
  {
    "path": "csrc/fused_dense_lib/setup.py",
    "content": "import os\nimport subprocess\nfrom packaging.version import parse, Version\n\nimport torch\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME\n\n\ndef get_cuda_bare_metal_version(cuda_dir):\n    # Parse the CUDA toolkit version from the `nvcc -V` output.\n    raw_output = subprocess.check_output([cuda_dir + \"/bin/nvcc\", \"-V\"], universal_newlines=True)\n    output = raw_output.split()\n    release_idx = output.index(\"release\") + 1\n    bare_metal_version = parse(output[release_idx].split(\",\")[0])\n\n    return raw_output, bare_metal_version\n\n\ndef append_nvcc_threads(nvcc_extra_args):\n    # Pass --threads to nvcc for CUDA >= 11.2; defaults to 4 unless NVCC_THREADS is set.\n    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)\n    if bare_metal_version >= Version(\"11.2\"):\n        nvcc_threads = os.getenv(\"NVCC_THREADS\") or \"4\"\n        return nvcc_extra_args + [\"--threads\", nvcc_threads]\n    return nvcc_extra_args\n\n\nsetup(\n    name='fused_dense_lib',\n    ext_modules=[\n        CUDAExtension(\n            name='fused_dense_lib',\n            sources=['fused_dense.cpp', 'fused_dense_cuda.cu'],\n            extra_compile_args={\n                'cxx': ['-O3'],\n                'nvcc': append_nvcc_threads(['-O3']),\n            },\n        )\n    ],\n    cmdclass={\n        'build_ext': BuildExtension,\n    },\n)\n"
  },
  {
    "path": "csrc/layer_norm/README.md",
    "content": "This CUDA extension implements fused dropout + residual + LayerNorm, building on\nApex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).\nMajor changes:\n- Add dropout and residual.\n- Make it work for both pre-norm and post-norm architectures.\n- Support more hidden dimensions (all dimensions divisible by 8, up to 8192).\n- Implement RMSNorm as an option.\n- Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).\n\nIf you want to use it for hidden dimensions larger than 8192, please file an issue.\n\nThis extension has only been tested on A100s.\n\n```sh\ncd csrc/layer_norm && pip install .\n```\n\nAs of 2024-01-05, this extension is no longer used in the FlashAttention repo.\nWe've instead switched to a Triton-based\n[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).\n"
  },
  {
    "path": "csrc/layer_norm/ln.h",
    "content": "#pragma once\n\n#include <cstdint>\n#include <functional>\n#include <unordered_map>\n#include <cuda_fp16.h>\n#include <cuda_bf16.h>\n\n#ifdef OLD_GENERATOR_PATH\n#include <ATen/CUDAGeneratorImpl.h>\n#else\n#include <ATen/cuda/CUDAGeneratorImpl.h>\n#endif\n\nnamespace layer_norm {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Params>\nstruct LaunchParams{\n\n    size_t elts_per_thread;\n    size_t workspace_bytes;\n    size_t barrier_size;\n\n    cudaDeviceProp * props;\n\n    cudaStream_t stream;\n\n    Params params;\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct ParamsBase {\n    ParamsBase()\n        : ctas_per_col(0)\n        , rows(0)\n        , cols(0)\n        , x(nullptr)\n        , mu(nullptr)\n        , rs(nullptr)\n        , gamma(nullptr)\n        , gamma1(nullptr)\n        , rowscale(nullptr)\n        , colscale(nullptr)\n        , dropout_keep_p(1.f)\n        , dropout_scale(1.f)\n        , is_rms_norm(false)\n        , workspace(nullptr)\n        , barrier(nullptr)\n    {\n    }\n\n    // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.\n    int ctas_per_col;\n\n    // Input is interpreted as matrix. 
We normalize across columns.\n    int rows;\n    int cols;\n\n    // Common data pointers.\n    void *x0;\n    void *x1;\n    void *residual;\n    void *x;\n    void *dmask;\n    void *dmask1;\n    void *mu;\n    void *rs;\n    void *gamma;\n    void *gamma1;\n    void *rowscale;\n    void *colscale;\n    void *x0_subset;\n    void *z_subset;\n\n    float inverse_cols;\n\n    float dropout_keep_p;\n    float dropout_scale;\n    float rowscale_const;\n\n    bool is_rms_norm;\n\n    // Multi-CTA workspace in gmem.\n    void *workspace;\n\n    // Multi-CTA sync barriers in gmem.\n    int *barrier;\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct FwdParams : public ParamsBase {\n    FwdParams()\n        : ParamsBase()\n        , z(nullptr)\n        , z1(nullptr)\n        , beta(nullptr)\n        , beta1(nullptr)\n        , epsilon(0.f)\n    {\n    }\n\n    // Output of LN FWD.\n    void *z;\n    void *z1;\n    void *beta;\n    void *beta1;\n    float epsilon;\n\n    // Random state.\n    at::PhiloxCudaState philox_args;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct BwdParams : public ParamsBase {\n    BwdParams()\n        : ParamsBase()\n        , dz(nullptr)\n        , dz1(nullptr)\n        , dx(nullptr)\n        , dbeta_part(nullptr)\n        , dgamma_part(nullptr)\n        , dbeta1_part(nullptr)\n        , dgamma1_part(nullptr)\n        , dcolscale_part(nullptr)\n        , dx0(nullptr)\n        , dx1(nullptr)\n        , dresidual(nullptr)\n        , dbeta(nullptr)\n        , dgamma(nullptr)\n        , dbeta1(nullptr)\n        , dgamma1(nullptr)\n        , dcolscale(nullptr)\n    {\n    }\n\n    // Input: gradient wrt. 
LN FWD output.\n    void *dz;\n    void *dz1;\n    // Input: gradient wrt residual.\n    void *dx;\n\n    // Workspace for Wgrad pre-reduction.\n    void *dbeta_part;\n    void *dgamma_part;\n    void *dbeta1_part;\n    void *dgamma1_part;\n    void *dcolscale_part;\n\n    // Output: Dgrad.\n    void *dx0;\n    void *dx1;\n    void *dresidual;\n    // Output: Wgrad.\n    void *dbeta;\n    void *dgamma;\n    void *dbeta1;\n    void *dgamma1;\n    void *dcolscale;\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nusing FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;\nusing BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;\nusing FunctionKey = uint64_t;\nusing FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;\nusing BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;\n\nextern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;\nextern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nusing fp32 = float;\nusing fp16 = half;\nusing bf16 = nv_bfloat16;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T>\nstruct TypeId{};\n\ntemplate<>\nstruct TypeId<fp16>{\n    constexpr static uint32_t Value = 0;\n};\n\ntemplate<>\nstruct TypeId<bf16>{\n    constexpr static uint32_t Value = 1;\n};\n\ntemplate<>\nstruct TypeId<fp32>{\n    constexpr static uint32_t Value = 2;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T, int S>\nstruct Type2Key{\n    constexpr static uint32_t Value = TypeId<T>::Value << S;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T>\nstruct WeightType2Key : public Type2Key<T, 0>{};\n\ntemplate<typename 
T>\nstruct InputType2Key : public Type2Key<T, 2>{};\n\ntemplate<typename T>\nstruct ResidualType2Key : public Type2Key<T, 4>{};\n\ntemplate<typename T>\nstruct OutputType2Key : public Type2Key<T, 6>{};\n\ntemplate<typename T>\nstruct ComputeType2Key : public Type2Key<T, 8>{};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename W, typename I, typename R, typename O, typename C>\nstruct Types2Key{\n    constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;\n    constexpr static inline uint64_t get(const uint64_t hidden_size){\n        constexpr uint64_t type_key = Value;\n        return (type_key << 32) | hidden_size;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>\nstruct FwdRegistrar{\n    FwdRegistrar(FwdFunction f){\n        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);\n        FWD_FUNCS.insert({ key, f });\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>\nstruct BwdRegistrar{\n    BwdRegistrar(BwdFunction f){\n        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);\n        BWD_FUNCS.insert({ key, f });\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>\nstruct FwdParallelRegistrar{\n    FwdParallelRegistrar(FwdFunction f){\n        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);\n        PARALLEL_FWD_FUNCS.insert({ key, f });\n    
}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>\nstruct BwdParallelRegistrar{\n    BwdParallelRegistrar(BwdFunction f){\n        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);\n        PARALLEL_BWD_FUNCS.insert({ key, f });\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace layer_norm\n"
  },
  {
    "path": "csrc/layer_norm/ln_api.cpp",
    "content": "#include <torch/extension.h>\n#include \"ATen/cuda/CUDAContext.h\"\n#include <c10/cuda/CUDAGuard.h>\n\n#include \"ln.h\"\n\n/*\n\nSupported Type combinations:\n\ninput  residual   compute   weights   output\n============================================\nfp32     fp32      fp32      fp32      fp32\nfp16     fp32      fp32      fp32      fp16\nfp16     fp16      fp32      fp32      fp16\nbf16     fp32      fp32      fp32      bf16\nbf16     bf16      fp32      fp32      bf16\nfp16     fp16      fp32      fp16      fp16\nbf16     bf16      fp32      bf16      bf16\n\nRemarks:\nOutput type = Input type\nCompute always in FP32\n\n*/\n\nnamespace layer_norm {\n\n// Create registries and provide runtime versions of config hash functions.\n\nFwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;\nBwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nuint32_t get_type_id(torch::Dtype dtype){\n    if( dtype == torch::kFloat16 ) {\n        return TypeId<fp16>::Value;\n    } else if( dtype == torch::kBFloat16 ) {\n        return TypeId<bf16>::Value;\n    } else if( dtype == torch::kFloat32 ) {\n        return TypeId<fp32>::Value;\n    } else {\n        TORCH_CHECK(false, \"Type not supported: \", dtype);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nuint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {\n    using namespace layer_norm;\n    uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8);\n    uint64_t launcher_key = (type_key << 32) | hidden_size;\n    return launcher_key;\n}\n\n}  // namespace 
layer_norm\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nlayer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {\n    auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));\n    if( iter != layer_norm::FWD_FUNCS.end() ) {\n        return iter->second;\n    } else {\n        TORCH_CHECK(false, \"FWD: Unsupported hidden_size or types: \", hidden_size, wtype, itype, rtype, otype, ctype);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nlayer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {\n    auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));\n    if( iter != layer_norm::BWD_FUNCS.end() ) {\n        return iter->second;\n    } else {\n        TORCH_CHECK(false, \"BWD: Unsupported hidden_size or types: \", hidden_size, wtype, itype, rtype, otype, ctype);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nlayer_norm::FwdFunction & get_parallel_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {\n    auto iter = layer_norm::PARALLEL_FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));\n    if( iter != layer_norm::PARALLEL_FWD_FUNCS.end() ) {\n        return iter->second;\n    } else {\n        TORCH_CHECK(false, \"FWD: Unsupported hidden_size or types: \", hidden_size, wtype, itype, rtype, otype, ctype);\n    
}\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nlayer_norm::BwdFunction & get_parallel_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {\n    auto iter = layer_norm::PARALLEL_BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));\n    if( iter != layer_norm::PARALLEL_BWD_FUNCS.end() ) {\n        return iter->second;\n    } else {\n        TORCH_CHECK(false, \"BWD: Unsupported hidden_size or types: \", hidden_size, wtype, itype, rtype, otype, ctype);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstd::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input: BxSxhidden_size\n                                           std::optional<const at::Tensor> &residual_,  // Residual: BxSxhidden_size\n                                           const at::Tensor &gamma,   // hidden_size\n                                           std::optional<const at::Tensor> &beta_,   // hidden_size\n                                           std::optional<const at::Tensor> &rowscale_,      // BxS\n                                           std::optional<const at::Tensor> &colscale_,      // hidden_size\n                                           std::optional<const at::Tensor> &x0_subset_,      // BxS\n                                           std::optional<const at::Tensor> &z_subset_,      // BxS\n                                           const float dropout_p,\n                                           const float epsilon,\n                                           const float rowscale_const,\n                                           const int64_t z_numrows,\n                                           std::optional<at::Generator> gen_,\n                                           bool residual_in_fp32=false,\n            
                               bool is_rms_norm=false\n) {\n    auto itype = x0.scalar_type();\n    auto rtype = residual_.has_value()\n        ? residual_.value().scalar_type()\n        : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());\n    auto wtype = gamma.scalar_type();\n    auto otype = itype;\n    auto ctype = torch::kFloat32;\n    auto mtype = torch::kUInt8;\n\n    TORCH_CHECK(x0.is_cuda());\n    TORCH_CHECK(gamma.is_cuda());\n\n    TORCH_CHECK(x0.is_contiguous());\n    // c10::IntArrayRef does not own the storage, so we need to construct a vector.\n    // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because\n    // blah is then deallocated.\n    std::vector<int64_t> sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)};\n    auto sizes = c10::IntArrayRef(sizes_vec);\n    TORCH_CHECK(x0.dim() == 2);\n    TORCH_CHECK(sizes.size() == 2);\n\n    const int rows = sizes[0];\n    const int cols = sizes[1];\n    auto hidden_size = gamma.numel();\n    TORCH_CHECK(hidden_size == cols);\n\n    if (beta_.has_value()) {\n        auto beta = beta_.value();\n        TORCH_CHECK(beta.dtype() == wtype);\n        TORCH_CHECK(beta.is_cuda());\n        TORCH_CHECK(beta.is_contiguous());\n        TORCH_CHECK(beta.sizes() == gamma.sizes());\n    }\n\n    if (residual_.has_value()) {\n        auto residual = residual_.value();\n        TORCH_CHECK(residual.is_cuda());\n        TORCH_CHECK(residual.is_contiguous());\n        TORCH_CHECK(residual.sizes() == sizes);\n    }\n\n    if (rowscale_.has_value()) {\n        auto rowscale = rowscale_.value();\n        TORCH_CHECK(rowscale.is_cuda());\n        TORCH_CHECK(rowscale.is_contiguous());\n        TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});\n        TORCH_CHECK(rowscale.dtype() == itype);\n    }\n\n    if (colscale_.has_value()) {\n        auto colscale = colscale_.value();\n        TORCH_CHECK(colscale.is_cuda());\n        
TORCH_CHECK(colscale.is_contiguous());\n        TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});\n        TORCH_CHECK(colscale.dtype() == wtype);\n    }\n\n    if (x0_subset_.has_value()) {\n        auto x0_subset = x0_subset_.value();\n        TORCH_CHECK(x0_subset.is_cuda());\n        TORCH_CHECK(x0_subset.is_contiguous());\n        TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});\n        TORCH_CHECK(x0_subset.dtype() == torch::kInt32);\n\n        TORCH_CHECK(z_subset_.has_value());\n        auto z_subset = z_subset_.value();\n        TORCH_CHECK(z_subset.is_cuda());\n        TORCH_CHECK(z_subset.is_contiguous());\n        TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});\n        TORCH_CHECK(z_subset.dtype() == torch::kInt32);\n    }\n\n    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));\n    TORCH_CHECK(epsilon >= 0.f);\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    at::cuda::CUDAGuard device_guard{x0.device()};\n\n    auto opts = x0.options();\n\n    bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);\n    at::Tensor x;\n    if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }\n    at::Tensor dmask;\n    if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); };\n    auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype));\n\n    auto mu = torch::empty({ rows }, opts.dtype(ctype));\n    auto rsigma = torch::empty({ rows }, opts.dtype(ctype));\n\n    layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;\n\n    launch_params.props = at::cuda::getCurrentDeviceProperties();\n    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();\n    TORCH_CHECK(dropout_p < 1.f);\n    launch_params.params.dropout_keep_p = 1.f - dropout_p;\n    launch_params.params.residual = residual_.has_value() ? 
residual_.value().data_ptr() : nullptr;\n    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;\n    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;\n    launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;\n    launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;\n\n    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(\n        gen_, at::cuda::detail::getDefaultCUDAGenerator());\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);\n    // Request the kernel launcher.\n    auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));\n\n    // Set the kernel runtime parameters.\n    layer_norm::FwdParams &params = launch_params.params;\n    params.rows = rows;\n    params.cols = cols;\n    params.x0 = x0.data_ptr();\n    params.x = save_x ? x.data_ptr() : nullptr;\n    params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr;\n    params.mu = mu.data_ptr();\n    params.rs = rsigma.data_ptr();\n    params.gamma = gamma.data_ptr();\n    params.beta = beta_.has_value() ? 
beta_.value().data_ptr() : nullptr;\n    params.z = z.data_ptr();\n    params.epsilon = epsilon;\n    params.dropout_scale = 1.f / (1.f - dropout_p);\n    params.inverse_cols = 1.f / float(params.cols);\n    params.rowscale_const = rowscale_const;\n    params.is_rms_norm = is_rms_norm;\n\n    // Query the kernel-specific launch parameters.\n    launcher(launch_params, true);\n\n    at::Tensor workspace, barrier;\n\n    if (dropout_p > 0.f) {\n        // number of times random will be generated per thread, to offset philox counter in thc random\n        // state\n        int64_t counter_offset = launch_params.elts_per_thread;\n\n        // See Note [Acquire lock when using random generators]\n        {\n            std::lock_guard<std::mutex> lock(gen->mutex_);\n            params.philox_args = gen->philox_cuda_state(counter_offset);\n        }\n    }\n\n    if( launch_params.barrier_size > 0 ) {\n        auto options = x0.options();\n        barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));\n        workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));\n        params.workspace = workspace.data_ptr();\n        params.barrier = barrier.data_ptr<int>();\n    }\n\n    // Launch the kernel.\n    launcher(launch_params, false);\n\n    return { z, x, dmask, mu, rsigma };\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstd::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidden_size\n                                           std::optional<const at::Tensor> &dx_,     // BxSxhidden_size\n                                           const at::Tensor &x,      // BxSxhidden_size\n                                           std::optional<const at::Tensor> &x0_,     // BxSxhidden_size\n                                           std::optional<const at::Tensor> &dmask_,  // BxSxhidden_size\n                                           
const at::Tensor &mu,     // BxS, FP32!\n                                           const at::Tensor &rsigma, // BxS, FP32!\n                                           const at::Tensor &gamma,   // hidden_size\n                                           std::optional<const at::Tensor> &rowscale_,      // BxS\n                                           std::optional<const at::Tensor> &colscale_,      // hidden_size\n                                           std::optional<const at::Tensor> &x0_subset_,      // BxS\n                                           std::optional<const at::Tensor> &z_subset_,      // BxS\n                                           const float dropout_p,\n                                           const float rowscale_const,\n                                           const int64_t x0_numrows,\n                                           const bool has_residual,\n                                           bool is_rms_norm=false\n) {\n\n    auto itype = dz.scalar_type();\n    auto rtype = x.scalar_type();\n    auto wtype = gamma.scalar_type();\n    auto otype = itype;\n    auto ctype = torch::kFloat32;\n    auto mtype = torch::kUInt8;\n\n    if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }\n\n    TORCH_CHECK(dz.dtype() == otype);\n    TORCH_CHECK(mu.dtype() == ctype);\n    TORCH_CHECK(rsigma.dtype() == ctype);\n\n    TORCH_CHECK(x.is_cuda());\n    TORCH_CHECK(dz.is_cuda());\n    TORCH_CHECK(mu.is_cuda());\n    TORCH_CHECK(rsigma.is_cuda());\n    TORCH_CHECK(gamma.is_cuda());\n\n    TORCH_CHECK(x.is_contiguous());\n    TORCH_CHECK(dz.is_contiguous());\n\n    auto sizes = x.sizes();\n    TORCH_CHECK(sizes.size() == 2);\n    auto rows = sizes[0];\n    auto cols = sizes[1];\n    TORCH_CHECK(dz.dim() == 2);\n    TORCH_CHECK(dz.size(1) == cols);\n    auto hidden_size = gamma.numel();\n    TORCH_CHECK(hidden_size == cols);\n\n    // c10::IntArrayRef does not own the storage, so we need to construct a vector.\n    // Otherwise just 
constructing IntArrayRef({blah}) will cause uninitialized memory because\n    // blah is then deallocated.\n    std::vector<int64_t> x0_sizes_vec {!x0_subset_.has_value() ? rows : x0_numrows, cols};\n    auto x0_sizes = c10::IntArrayRef(x0_sizes_vec);\n\n    if (dx_.has_value()) {\n        auto dx = dx_.value();\n        TORCH_CHECK(dx.dtype() == rtype);\n        TORCH_CHECK(dx.is_cuda());\n        TORCH_CHECK(dx.is_contiguous());\n        TORCH_CHECK(dx.sizes() == sizes);\n    }\n\n    if (dmask_.has_value()) {\n        auto dmask = dmask_.value();\n        TORCH_CHECK(dmask.dtype() == mtype);\n        TORCH_CHECK(dmask.is_cuda());\n        TORCH_CHECK(dmask.is_contiguous());\n        TORCH_CHECK(dmask.sizes() == x0_sizes);\n    }\n\n    if (rowscale_.has_value()) {\n        auto rowscale = rowscale_.value();\n        TORCH_CHECK(rowscale.is_cuda());\n        TORCH_CHECK(rowscale.is_contiguous());\n        TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});\n        TORCH_CHECK(rowscale.dtype() == itype);\n    }\n\n    if (colscale_.has_value()) {\n        auto colscale = colscale_.value();\n        TORCH_CHECK(colscale.is_cuda());\n        TORCH_CHECK(colscale.is_contiguous());\n        TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});\n        TORCH_CHECK(colscale.dtype() == wtype);\n\n        TORCH_CHECK(x0_.has_value());\n        auto x0 = x0_.value();\n        TORCH_CHECK(x0.is_cuda());\n        TORCH_CHECK(x0.is_contiguous());\n        TORCH_CHECK(x0.sizes() == x0_sizes);\n        TORCH_CHECK(x0.dtype() == itype);\n    }\n\n    if (x0_subset_.has_value()) {\n        auto x0_subset = x0_subset_.value();\n        TORCH_CHECK(x0_subset.is_cuda());\n        TORCH_CHECK(x0_subset.is_contiguous());\n        TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});\n        TORCH_CHECK(x0_subset.dtype() == torch::kInt32);\n\n        TORCH_CHECK(z_subset_.has_value());\n        auto z_subset = z_subset_.value();\n        
TORCH_CHECK(z_subset.is_cuda());\n        TORCH_CHECK(z_subset.is_contiguous());\n        TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});\n        TORCH_CHECK(z_subset.dtype() == torch::kInt32);\n    }\n\n    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));\n\n    TORCH_CHECK(mu.numel() == rows);\n    TORCH_CHECK(mu.sizes() == rsigma.sizes());\n\n    TORCH_CHECK(gamma.numel() == cols);\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    at::cuda::CUDAGuard device_guard{dz.device()};\n\n    auto opts = x.options();\n\n    auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));\n    at::Tensor dresidual;\n    if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }\n    auto dgamma = torch::empty_like(gamma);\n    auto dbeta = torch::empty_like(gamma);\n    at::Tensor dcolscale;\n    if (colscale_.has_value()) {\n        dcolscale = torch::empty_like(colscale_.value());\n    }\n\n    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;\n    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();\n    launch_params.props = at::cuda::getCurrentDeviceProperties();\n    TORCH_CHECK(dropout_p < 1.f);\n    launch_params.params.dropout_keep_p = 1.f - dropout_p;\n    launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;\n    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;\n    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;\n    launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;\n    launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 
512 : 1024);\n    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));\n\n    launcher(launch_params, true);\n\n    auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));\n    auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));\n    at::Tensor dcolscale_part;\n    if (colscale_.has_value()) {\n        dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));\n    }\n    at::Tensor workspace, barrier;\n\n    layer_norm::BwdParams &params = launch_params.params;\n    params.rows = rows;\n    params.cols = cols;\n    params.x = x.data_ptr();\n    params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;\n    params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;\n    params.mu = mu.data_ptr();\n    params.rs = rsigma.data_ptr();\n    params.gamma = gamma.data_ptr();\n    params.dz = dz.data_ptr();\n    params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;\n    params.dx0 = dx0.data_ptr();\n    params.dbeta = dbeta.data_ptr();\n    params.dgamma = dgamma.data_ptr();\n    params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;\n    params.dbeta_part = dbeta_part.data_ptr();\n    params.dgamma_part = dgamma_part.data_ptr();\n    params.dcolscale_part = colscale_.has_value() ? 
dcolscale_part.data_ptr() : nullptr;\n    params.dropout_scale = 1.f / (1.f - dropout_p);\n    params.inverse_cols = 1.f / float(params.cols);\n    params.rowscale_const = rowscale_const;\n    params.is_rms_norm = is_rms_norm;\n\n    if( launch_params.barrier_size > 0 ) {\n        // TODO Any way to avoid this?\n        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));\n        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));\n        params.workspace = workspace.data_ptr();\n        params.barrier = barrier.data_ptr<int>();\n    }\n\n    launcher(launch_params, false);\n\n    std::vector<at::Tensor> result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };\n    if (colscale_.has_value()) {\n        result.push_back(dcolscale);\n        result.push_back(dcolscale_part);\n    }\n    return result;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstd::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(\n    const at::Tensor &x0,      // Input: BxSxhidden_size\n    std::optional<const at::Tensor> &x1_,      // Input: BxSxhidden_size\n    std::optional<const at::Tensor> &residual_,  // Residual: BxSxhidden_size\n    const at::Tensor &gamma0,   // hidden_size\n    std::optional<const at::Tensor> &beta0_,   // hidden_size\n    std::optional<const at::Tensor> &gamma1_,   // hidden_size\n    std::optional<const at::Tensor> &beta1_,   // hidden_size\n    const float dropout_p,\n    const float epsilon,\n    std::optional<at::Generator> gen_,\n    bool residual_in_fp32=false,\n    bool is_rms_norm=false\n) {\n    auto itype = x0.scalar_type();\n    auto rtype = residual_.has_value()\n        ? residual_.value().scalar_type()\n        : (residual_in_fp32 ? 
torch::kFloat32 : x0.scalar_type());\n    auto wtype = gamma0.scalar_type();\n    auto otype = itype;\n    auto ctype = torch::kFloat32;\n    auto mtype = torch::kUInt8;\n\n    TORCH_CHECK(x0.is_cuda());\n    TORCH_CHECK(gamma0.is_cuda());\n\n    TORCH_CHECK(x0.is_contiguous());\n    const auto sizes = x0.sizes();\n    TORCH_CHECK(x0.dim() == 2);\n\n    const int rows = sizes[0];\n    const int cols = sizes[1];\n    auto hidden_size = gamma0.numel();\n    TORCH_CHECK(hidden_size == cols);\n\n    if (x1_.has_value()) {\n        auto x1 = x1_.value();\n        TORCH_CHECK(x1.is_cuda());\n        TORCH_CHECK(x1.is_contiguous());\n        TORCH_CHECK(x1.sizes() == sizes);\n    }\n\n    if (residual_.has_value()) {\n        auto residual = residual_.value();\n        TORCH_CHECK(residual.is_cuda());\n        TORCH_CHECK(residual.is_contiguous());\n        TORCH_CHECK(residual.sizes() == sizes);\n    }\n\n    if (beta0_.has_value()) {\n        auto beta0 = beta0_.value();\n        TORCH_CHECK(beta0.dtype() == wtype);\n        TORCH_CHECK(beta0.is_cuda());\n        TORCH_CHECK(beta0.is_contiguous());\n        TORCH_CHECK(beta0.sizes() == gamma0.sizes());\n    }\n\n    if (gamma1_.has_value()) {\n        auto gamma1 = gamma1_.value();\n        TORCH_CHECK(gamma1.dtype() == wtype);\n        TORCH_CHECK(gamma1.is_cuda());\n        TORCH_CHECK(gamma1.is_contiguous());\n        TORCH_CHECK(gamma1.sizes() == gamma0.sizes());\n    }\n\n    if (beta1_.has_value()) {\n        auto beta1 = beta1_.value();\n        TORCH_CHECK(beta1.dtype() == wtype);\n        TORCH_CHECK(beta1.is_cuda());\n        TORCH_CHECK(beta1.is_contiguous());\n        TORCH_CHECK(beta1.sizes() == gamma0.sizes());\n    }\n\n    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));\n    TORCH_CHECK(epsilon >= 0.f);\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    at::cuda::CUDAGuard device_guard{x0.device()};\n\n    auto opts = x0.options();\n\n    bool save_x = 
residual_.has_value() || x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);\n    at::Tensor x;\n    if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }\n    at::Tensor dmask0, dmask1;\n    if (dropout_p > 0.f) {\n        dmask0 = torch::empty(x0.sizes(), opts.dtype(mtype));\n        if (x1_.has_value()) { dmask1 = torch::empty(x0.sizes(), opts.dtype(mtype)); }\n    };\n    auto z0 = torch::empty(sizes, opts.dtype(otype));\n    at::Tensor z1;\n    if (gamma1_.has_value()) { z1 = torch::empty(sizes, opts.dtype(otype)); }\n\n    auto mu = torch::empty({ rows }, opts.dtype(ctype));\n    auto rsigma = torch::empty({ rows }, opts.dtype(ctype));\n\n    layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;\n\n    launch_params.props = at::cuda::getCurrentDeviceProperties();\n    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();\n    TORCH_CHECK(dropout_p < 1.f);\n    launch_params.params.dropout_keep_p = 1.f - dropout_p;\n    launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;\n\n    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(\n        gen_, at::cuda::detail::getDefaultCUDAGenerator());\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);\n    // Request the kernel launcher.\n    auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));\n\n    // Set the kernel runtime parameters.\n    layer_norm::FwdParams &params = launch_params.params;\n    params.rows = rows;\n    params.cols = cols;\n    params.x0 = x0.data_ptr();\n    params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;\n    params.x = save_x ? x.data_ptr() : nullptr;\n    params.dmask = dropout_p > 0.f ? dmask0.data_ptr() : nullptr;\n    params.dmask1 = (dropout_p > 0.f && x1_.has_value()) ? 
dmask1.data_ptr() : nullptr;\n    params.mu = mu.data_ptr();\n    params.rs = rsigma.data_ptr();\n    params.gamma = gamma0.data_ptr();\n    params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;\n    params.beta = beta0_.has_value() ? beta0_.value().data_ptr() : nullptr;\n    params.beta1 = beta1_.has_value() ? beta1_.value().data_ptr() : nullptr;\n    params.z = z0.data_ptr();\n    params.z1 = gamma1_.has_value() ? z1.data_ptr() : nullptr;\n    params.epsilon = epsilon;\n    params.dropout_scale = 1.f / (1.f - dropout_p);\n    params.inverse_cols = 1.f / float(params.cols);\n    params.is_rms_norm = is_rms_norm;\n\n    // Query the kernel-specific launch parameters.\n    launcher(launch_params, true);\n\n    at::Tensor workspace, barrier;\n\n    if (dropout_p > 0.f) {\n        // number of times random will be generated per thread, to offset philox counter in thc random\n        // state\n        int64_t counter_offset = 2 * launch_params.elts_per_thread;\n\n        // See Note [Acquire lock when using random generators]\n        {\n            std::lock_guard<std::mutex> lock(gen->mutex_);\n            params.philox_args = gen->philox_cuda_state(counter_offset);\n        }\n    }\n\n    if( launch_params.barrier_size > 0 ) {\n        auto options = x0.options();\n        barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));\n        workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));\n        params.workspace = workspace.data_ptr();\n        params.barrier = barrier.data_ptr<int>();\n    }\n\n    // Launch the kernel.\n    launcher(launch_params, false);\n\n    return { z0, z1, x, dmask0, dmask1, mu, rsigma };\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstd::vector<at::Tensor> dropout_add_ln_parallel_residual_bwd(\n    const at::Tensor &dz0,     // BxSxhidden_size\n    std::optional<const at::Tensor> &dz1_,     // 
BxSxhidden_size\n    std::optional<const at::Tensor> &dx_,     // BxSxhidden_size\n    const at::Tensor &x,      // BxSxhidden_size\n    std::optional<const at::Tensor> &dmask0_,  // BxSxhidden_size\n    std::optional<const at::Tensor> &dmask1_,  // BxSxhidden_size\n    const at::Tensor &mu,     // BxS, FP32!\n    const at::Tensor &rsigma, // BxS, FP32!\n    const at::Tensor &gamma0,   // hidden_size\n    std::optional<const at::Tensor> &gamma1_,   // hidden_size\n    const float dropout_p,\n    const bool has_x1,\n    const bool has_residual,\n    bool is_rms_norm=false\n) {\n\n    auto itype = dz0.scalar_type();\n    auto rtype = x.scalar_type();\n    auto wtype = gamma0.scalar_type();\n    auto otype = itype;\n    auto ctype = torch::kFloat32;\n    auto mtype = torch::kUInt8;\n\n    if (dropout_p > 0.f) { TORCH_CHECK(dmask0_.has_value()); }\n\n    TORCH_CHECK(dz0.dtype() == otype);\n    TORCH_CHECK(mu.dtype() == ctype);\n    TORCH_CHECK(rsigma.dtype() == ctype);\n\n    TORCH_CHECK(x.is_cuda());\n    TORCH_CHECK(dz0.is_cuda());\n    TORCH_CHECK(mu.is_cuda());\n    TORCH_CHECK(rsigma.is_cuda());\n    TORCH_CHECK(gamma0.is_cuda());\n\n    TORCH_CHECK(x.is_contiguous());\n    TORCH_CHECK(dz0.is_contiguous());\n\n    auto sizes = x.sizes();\n    TORCH_CHECK(sizes.size() == 2);\n    auto rows = sizes[0];\n    auto cols = sizes[1];\n    TORCH_CHECK(dz0.dim() == 2);\n    TORCH_CHECK(dz0.size(1) == cols);\n    auto hidden_size = gamma0.numel();\n    TORCH_CHECK(hidden_size == cols);\n\n    if (dz1_.has_value()) {\n        auto dz1 = dz1_.value();\n        TORCH_CHECK(dz1.dtype() == otype);\n        TORCH_CHECK(dz1.is_cuda());\n        TORCH_CHECK(dz1.is_contiguous());\n        TORCH_CHECK(dz1.sizes() == sizes);\n\n        TORCH_CHECK(gamma1_.has_value());\n        auto gamma1 = gamma1_.value();\n        TORCH_CHECK(gamma1.dtype() == wtype);\n        TORCH_CHECK(gamma1.is_cuda());\n        TORCH_CHECK(gamma1.is_contiguous());\n      
  TORCH_CHECK(gamma1.sizes() == gamma0.sizes());\n    }\n\n    if (dx_.has_value()) {\n        auto dx = dx_.value();\n        TORCH_CHECK(dx.dtype() == rtype);\n        TORCH_CHECK(dx.is_cuda());\n        TORCH_CHECK(dx.is_contiguous());\n        TORCH_CHECK(dx.sizes() == sizes);\n    }\n\n    if (dmask0_.has_value()) {\n        auto dmask0 = dmask0_.value();\n        TORCH_CHECK(dmask0.dtype() == mtype);\n        TORCH_CHECK(dmask0.is_cuda());\n        TORCH_CHECK(dmask0.is_contiguous());\n        TORCH_CHECK(dmask0.sizes() == sizes);\n\n        if (has_x1) {\n            TORCH_CHECK(dmask1_.has_value());\n            auto dmask1 = dmask1_.value();\n            TORCH_CHECK(dmask1.dtype() == mtype);\n            TORCH_CHECK(dmask1.is_cuda());\n            TORCH_CHECK(dmask1.is_contiguous());\n            TORCH_CHECK(dmask1.sizes() == sizes);\n        }\n    }\n\n    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));\n\n    TORCH_CHECK(mu.numel() == rows);\n    TORCH_CHECK(mu.sizes() == rsigma.sizes());\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    at::cuda::CUDAGuard device_guard{dz0.device()};\n\n    auto opts = x.options();\n\n    auto dx0 = torch::empty(sizes, opts.dtype(itype));\n    at::Tensor dx1;\n    if (has_x1) { dx1 = torch::empty(sizes, opts.dtype(itype)); }\n    at::Tensor dresidual;\n    if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }\n    auto dgamma0 = torch::empty_like(gamma0);\n    auto dbeta0 = torch::empty_like(gamma0);\n    at::Tensor dgamma1, dbeta1;\n    if (gamma1_.has_value()) {\n        dgamma1 = torch::empty_like(gamma0);\n        dbeta1 = torch::empty_like(gamma0);\n    }\n\n    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;\n    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();\n    launch_params.props = at::cuda::getCurrentDeviceProperties();\n    TORCH_CHECK(dropout_p < 1.f);\n    launch_params.params.dropout_keep_p = 1.f - dropout_p;\n    
launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);\n    auto launcher = get_parallel_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));\n\n    launcher(launch_params, true);\n\n    auto dgamma0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));\n    auto dbeta0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));\n    at::Tensor dgamma1_part, dbeta1_part;\n    if (gamma1_.has_value()) {\n        dgamma1_part = torch::zeros_like(dgamma0_part);\n        dbeta1_part = torch::zeros_like(dbeta0_part);\n    }\n    at::Tensor workspace, barrier;\n\n    layer_norm::BwdParams &params = launch_params.params;\n    params.rows = rows;\n    params.cols = cols;\n    params.x = x.data_ptr();\n    params.dmask = dropout_p > 0.f ? dmask0_.value().data_ptr() : nullptr;\n    params.dmask1 = (dropout_p > 0.f && has_x1) ? dmask1_.value().data_ptr() : nullptr;\n    params.mu = mu.data_ptr();\n    params.rs = rsigma.data_ptr();\n    params.gamma = gamma0.data_ptr();\n    params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;\n    params.dz = dz0.data_ptr();\n    params.dz1 = dz1_.has_value() ? dz1_.value().data_ptr() : nullptr;\n    params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;\n    params.dx0 = dx0.data_ptr();\n    params.dx1 = has_x1 ? dx1.data_ptr() : nullptr;\n    params.dbeta = dbeta0.data_ptr();\n    params.dgamma = dgamma0.data_ptr();\n    params.dbeta1 = gamma1_.has_value() ? dbeta1.data_ptr() : nullptr;\n    params.dgamma1 = gamma1_.has_value() ? dgamma1.data_ptr() : nullptr;\n    params.dbeta_part = dbeta0_part.data_ptr();\n    params.dgamma_part = dgamma0_part.data_ptr();\n    params.dbeta1_part = gamma1_.has_value() ? 
dbeta1_part.data_ptr() : nullptr;\n    params.dgamma1_part = gamma1_.has_value() ? dgamma1_part.data_ptr() : nullptr;\n    params.dropout_scale = 1.f / (1.f - dropout_p);\n    params.inverse_cols = 1.f / float(params.cols);\n    params.is_rms_norm = is_rms_norm;\n\n    if( launch_params.barrier_size > 0 ) {\n        // TODO Any way to avoid this?\n        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));\n        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));\n        params.workspace = workspace.data_ptr();\n        params.barrier = barrier.data_ptr<int>();\n    }\n\n    launcher(launch_params, false);\n\n    std::vector<at::Tensor> result = { dx0, dx1, dresidual, dgamma0, dbeta0, dgamma1, dbeta1, dgamma0_part, dbeta0_part, dgamma1_part, dbeta1_part };\n    return result;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.doc() = \"CUDA DropoutAddLayerNorm\";\n    m.def(\"dropout_add_ln_fwd\", &dropout_add_ln_fwd, \"Run Dropout + Add + LayerNorm forward kernel\",\n          py::arg(\"x0\"), py::arg(\"residual\"), py::arg(\"gamma\"), py::arg(\"beta_\"),\n          py::arg(\"rowscale_\"), py::arg(\"colscale_\"), py::arg(\"x0_subset_\"), py::arg(\"z_subset_\"),\n          py::arg(\"dropout_p\"), py::arg(\"epsilon\"), py::arg(\"rowscale_const\"), py::arg(\"z_numrows\"),\n          py::arg(\"gen_\"), py::arg(\"residual_in_fp32\")=false, py::arg(\"is_rms_norm\")=false);\n    m.def(\"dropout_add_ln_bwd\", &dropout_add_ln_bwd, \"Run Dropout + Add + LayerNorm backward kernel\",\n          py::arg(\"dz\"), py::arg(\"dx_\"), py::arg(\"x\"), py::arg(\"x0_\"), py::arg(\"dmask_\"), py::arg(\"mu\"),\n          py::arg(\"rsigma\"), py::arg(\"gamma\"), py::arg(\"rowscale_\"), py::arg(\"colscale_\"),\n          py::arg(\"x0_subset_\"), py::arg(\"z_subset_\"), py::arg(\"dropout_p\"), 
py::arg(\"rowscale_const\"),\n          py::arg(\"x0_numrows\"), py::arg(\"has_residual\"), py::arg(\"is_rms_norm\")=false);\n    m.def(\"dropout_add_ln_parallel_residual_fwd\", &dropout_add_ln_parallel_residual_fwd, \"Run Dropout + Add + LayerNorm parallel residual forward kernel\",\n          py::arg(\"x0\"), py::arg(\"x1_\"), py::arg(\"residual\"), py::arg(\"gamma0\"), py::arg(\"beta0_\"),\n          py::arg(\"gamma1_\"), py::arg(\"beta1_\"), py::arg(\"dropout_p\"), py::arg(\"epsilon\"),\n          py::arg(\"gen_\"), py::arg(\"residual_in_fp32\")=false, py::arg(\"is_rms_norm\")=false);\n    m.def(\"dropout_add_ln_parallel_residual_bwd\", &dropout_add_ln_parallel_residual_bwd, \"Run Dropout + Add + LayerNorm parallel residual backward kernel\",\n          py::arg(\"dz0\"), py::arg(\"dz1_\"), py::arg(\"dx_\"), py::arg(\"x\"), py::arg(\"dmask0_\"),\n          py::arg(\"dmask1_\"), py::arg(\"mu\"), py::arg(\"rsigma\"), py::arg(\"gamma0\"), py::arg(\"gamma1_\"),\n          py::arg(\"dropout_p\"), py::arg(\"has_x1\"), py::arg(\"has_residual\"), py::arg(\"is_rms_norm\")=false);\n}\n"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_1024.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER(  1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_1280.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER(  1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_1536.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_2048.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_256.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER(  256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_2560.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_3072.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_4096.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_512.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER(  512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_5120.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_6144.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_7168.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);\nREGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);\nREGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);\nREGISTER_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);\nREGISTER_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);\nREGISTER_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);\nREGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);\nREGISTER_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_768.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_BWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_8192.cu",
    "content": "#include \"ln_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_bwd_kernels.cuh",
    "content": "#pragma once\n\n#include \"ln.h\"\n#include \"ln_utils.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"static_switch.h\"\n\nnamespace layer_norm {\n\ntemplate<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>\n__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) \nvoid ln_bwd_kernel(layer_norm::BwdParams params) {\n\n    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };\n    enum { WARPS_M = Ktraits::WARPS_M };\n    enum { WARPS_N = Ktraits::WARPS_N };\n    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };\n    enum { COLS = Ktraits::COLS };\n    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };\n    enum { LDGS = Ktraits::LDGS };\n    enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };\n    enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };\n    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };\n\n    using input_t = typename Ktraits::input_t;\n    using compute_t = typename Ktraits::compute_t;\n    using index_t = typename Ktraits::index_t;\n    using mask_t = typename Ktraits::mask_t;\n    using Ivec = typename Ktraits::Ivec;\n    using Rvec = typename Ktraits::Rvec;\n    using Ovec = typename Ktraits::Ovec;\n    using Wvec = typename Ktraits::Wvec;\n    using Cvec = typename Ktraits::Cvec;\n    using Mvec = typename Ktraits::Mvec;\n    using Reducer = typename Ktraits::Reducer;\n    using reduce_t = typename Reducer::Type;\n\n    extern __shared__ char smem_[];\n\n    const bool has_residual = params.dresidual != nullptr;\n    const bool prenorm = params.dx != nullptr;\n\n    const index_t tidx = threadIdx.x;\n    const index_t bidn = blockIdx.x % CTAS_PER_ROW;\n    const index_t bidm = blockIdx.x / CTAS_PER_ROW;\n    const index_t lane = tidx % THREADS_PER_WARP;\n    const index_t warp = tidx / THREADS_PER_WARP;\n    const index_t warp_m = warp / Ktraits::WARPS_N;\n    const index_t warp_n = warp % Ktraits::WARPS_N;\n    const index_t tid_r = warp_n * THREADS_PER_WARP + lane;\n\n    const index_t r = bidm * 
Ktraits::ROWS_PER_CTA + warp_m;\n    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;\n\n    static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);\n\n    const input_t *rowscale = static_cast<input_t *>(params.rowscale);\n    const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);\n    const index_t *z_subset = static_cast<index_t *>(params.z_subset);\n\n    Cvec dzy_sum[LDGS];\n    Cvec dz_sum[LDGS];\n    Cvec dcolscale_sum[LDGS];\n\n    memset(dzy_sum, 0, sizeof(dzy_sum));\n    memset(dz_sum, 0, sizeof(dz_sum));\n    if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); }\n\n    compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);\n    char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;\n\n    Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);\n\n    Sum<reduce_t> sum;\n\n    const index_t num_valid_ldgs =\n        ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;\n\n    Wvec gamma[LDGS];\n    Wvec colscale[LDGS];\n    index_t idx = c;\n    #pragma unroll\n    for( int it = 0; it < LDGS; it++ ) {\n        if (Is_even_cols || (it < num_valid_ldgs)) {\n            gamma[it].load_from(params.gamma, idx);\n            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }\n            idx += Ktraits::VEC_COLS_PER_LDG;\n        }\n    }\n    // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the\n    // last blocks with syncthreads!\n    // grid stride over rows\n    #pragma unroll 1\n    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {\n        const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];\n        const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];\n        const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 
1.0f : compute_t(rowscale[row])) : params.rowscale_const;\n        const int row_z = !Has_subset ? row + 1 : z_subset[row];\n        const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];\n        const bool load_dz = !Has_subset || row_z > 0;\n        const bool save_dx0 = !Has_subset || row_x0 > 0;\n        Mvec dmask[LDGS];\n        Rvec dx[LDGS];\n        compute_t dy[LDGS * NUM_ELTS];\n        compute_t y[LDGS * NUM_ELTS];\n        compute_t mdy_local = 0.f;\n        compute_t mdyy_local = 0.f;\n        // If dz is not loaded, then dy should be 0 and we don't care about the value of y.\n        if (load_dz) {\n            index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;\n            index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);\n            index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);\n            #pragma unroll\n            for( int it = 0; it < LDGS; it++ ) {\n                if (Is_even_cols || (it < num_valid_ldgs)) {\n                    Rvec x;\n                    Ovec dz;\n                    dz.load_from(params.dz, !Has_subset ? idx_x : idx_z);\n                    if (prenorm) { dx[it].load_from(params.dx, idx_x); }\n                    x.load_from(params.x, idx_x);\n                    if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }\n                    idx_x += Ktraits::VEC_COLS_PER_LDG;\n                    idx_z += Ktraits::VEC_COLS_PER_LDG;\n                    idx_x0 += Ktraits::VEC_COLS_PER_LDG;\n                    #pragma unroll\n                    for( int jt = 0; jt < NUM_ELTS; jt++ ) {\n                        compute_t x_tmp = x.data.elt[jt];\n                        compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? 
mu_r : 0.f));\n                        compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);\n                        compute_t dz_tmp = dz.data.elt[jt];\n\n                        mdy_local += dy_tmp;\n                        mdyy_local += dy_tmp * y_tmp;\n\n                        dy[it * NUM_ELTS + jt] = dy_tmp;\n                        y[it * NUM_ELTS + jt] = y_tmp;\n\n                        dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;\n                        dz_sum[it].data.elt[jt] += dz_tmp;\n                    }\n                }\n            }\n        } else {\n            index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;\n            index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);\n            #pragma unroll\n            for( int it = 0; it < LDGS; it++ ) {\n                if (Is_even_cols || (it < num_valid_ldgs)) {\n                    if (prenorm) { dx[it].load_from(params.dx, idx_x); }\n                    if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); }\n                    idx_x += Ktraits::VEC_COLS_PER_LDG;\n                    idx_x0 += Ktraits::VEC_COLS_PER_LDG;\n                }\n            }\n        }\n\n        reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);\n        mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;\n        mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;\n\n        index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;\n        index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? 
(row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);\n        #pragma unroll\n        for( int it = 0; it < LDGS; it++ ) {\n            if (Is_even_cols || (it < num_valid_ldgs)) {\n                Ivec dx0;\n                Rvec dresidual;\n                Ivec x0;\n                if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }\n                #pragma unroll\n                for( int jt = 0; jt < NUM_ELTS; jt++ ) {\n                    compute_t dx_tmp_res;\n                    if (load_dz) {\n                        compute_t dy_tmp = dy[it * NUM_ELTS + jt];\n                        compute_t y_tmp = y[it * NUM_ELTS + jt];\n                        compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));\n                        dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;\n                    } else {\n                        dx_tmp_res = prenorm ? compute_t(dx[it].data.elt[jt]) : 0.f;\n                    }\n                    if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }\n                    if (save_dx0) {\n                        compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;\n                        if (Is_dropout) {\n                            dx0_tmp_res *= params.dropout_scale;\n                            if (Has_colscale) {\n                                dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f;\n                                dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f;\n                            } else {\n                                dx0.data.elt[jt] = dmask[it].data.elt[jt] ? 
dx0_tmp_res : 0.f;\n                            }\n                        } else {\n                            if (Has_colscale) {\n                                dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]);\n                                dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]);\n                            } else {\n                                dx0.data.elt[jt] = dx0_tmp_res;\n                            }\n                        }\n                    }\n                }\n                if (has_residual) { dresidual.store_to(params.dresidual, idx_x); }\n                if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); }\n                idx_x += Ktraits::VEC_COLS_PER_LDG;\n                idx_x0 += Ktraits::VEC_COLS_PER_LDG;\n            }\n        }\n\n    }  // end: grid stride loop\n\n    if( WARPS_M == 1 ) {\n        idx = r * params.cols / Ktraits::ELTS_PER_LDG + c;\n        #pragma unroll\n        for( int it = 0; it < LDGS; it++ ) {\n            if (Is_even_cols || (it < num_valid_ldgs)) {\n                dz_sum[it].store_to(params.dbeta_part, idx);\n                dzy_sum[it].store_to(params.dgamma_part, idx);\n                if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); }\n                idx += Ktraits::VEC_COLS_PER_LDG;\n            }\n        }\n    } else {\n        static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, \"Multiple rows per CTA not supported for Multi-CTA.\");\n        // Finalize reduction of part dgamma and dbeta for this CTA\n        // by reducing over the rows held across the WARPS_M warps\n\n        // Assumption: blockSize divides hidden size.\n        enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };\n        static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, \"\");\n\n        idx = warp_m * Ktraits::VEC_COLS + tid_r;\n        #pragma unroll\n        for( int it = 0; it < LDGS; it++ ) {\n   
         dz_sum[it].store_to(smem_wgrad, idx);\n            idx += THREADS_PER_ROW;\n        }\n        __syncthreads();\n        compute_t cta_dz_sum[NUM_RES];\n        memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);\n        for( int it = 0; it < ROWS_PER_CTA; it++ ) {\n            for( int jt = 0; jt < NUM_RES; jt++ ) {\n                cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n            }\n        }\n        __syncthreads();\n\n        idx = warp_m * Ktraits::VEC_COLS + tid_r;\n        #pragma unroll\n        for( int it = 0; it < LDGS; it++ ) {\n            dzy_sum[it].store_to(smem_wgrad, idx);\n            idx += THREADS_PER_ROW;\n        }\n        __syncthreads();\n        compute_t cta_dzy_sum[NUM_RES];\n        memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);\n        for( int it = 0; it < ROWS_PER_CTA; it++ ) {\n            for( int jt = 0; jt < NUM_RES; jt++ ) {\n                cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n            }\n        }\n\n        compute_t cta_dcolscale_sum[NUM_RES];\n        if (Has_colscale) {\n            __syncthreads();\n            idx = warp_m * Ktraits::VEC_COLS + tid_r;\n            #pragma unroll\n            for( int it = 0; it < LDGS; it++ ) {\n                dcolscale_sum[it].store_to(smem_wgrad, idx);\n                idx += THREADS_PER_ROW;\n            }\n            __syncthreads();\n            memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES);\n            for( int it = 0; it < ROWS_PER_CTA; it++ ) {\n                for( int jt = 0; jt < NUM_RES; jt++ ) {\n                    cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n                }\n            }\n        }\n\n        const index_t num_valid_writes\n            = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;\n        compute_t *dgamma_part = static_cast<compute_t 
*>(params.dgamma_part) + bidm * params.cols + tidx;\n        compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;\n        compute_t *dcolscale_part = Has_colscale ? static_cast<compute_t *>(params.dcolscale_part) + bidm * params.cols + tidx : nullptr;\n        for( int jt = 0; jt < NUM_RES; jt++ ) {\n            if (Is_even_cols || (jt < num_valid_writes)) {\n                *dgamma_part = cta_dzy_sum[jt];\n                dgamma_part += Ktraits::THREADS_PER_CTA;\n                *dbeta_part = cta_dz_sum[jt];\n                dbeta_part += Ktraits::THREADS_PER_CTA;\n                if (Has_colscale) {\n                    *dcolscale_part = cta_dcolscale_sum[jt];\n                    dcolscale_part += Ktraits::THREADS_PER_CTA;\n                }\n            }\n        }\n\n    }\n}\n\ntemplate<typename Kernel_traits, bool Has_colscale, bool Is_even_cols>\n__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)\nvoid ln_bwd_finalize_kernel(BwdParams params)\n{\n\n    using compute_t = typename Kernel_traits::compute_t;\n    using weight_t = typename Kernel_traits::weight_t;\n    using index_t = typename Kernel_traits::index_t;\n    using Reducer = typename Kernel_traits::Reducer;\n    using reduce_t = typename Reducer::Type;\n\n    Sum<reduce_t> sum;\n    enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };\n    enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };\n\n    __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];\n\n    constexpr uint32_t bidm = 0;\n\n    const uint32_t bidn = blockIdx.x;\n    const uint32_t tidx = threadIdx.x;\n    const uint32_t warp = tidx / THREADS_PER_WARP;\n    const uint32_t lane = tidx % THREADS_PER_WARP;\n\n    Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);\n\n    const uint32_t c = bidn * THREADS_PER_WARP + lane;\n    const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;\n    constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;\n    for( 
uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {\n        // Each thread sums over NUM_ELT columns.\n        Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local, dcolscale_local;\n        memset(&dgamma_local, 0, sizeof(dgamma_local));\n        memset(&dbeta_local, 0, sizeof(dbeta_local));\n        if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }\n        if (Is_even_cols || col < params.cols) {\n            for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {\n                index_t idx = row * params.cols + col;\n\n                Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part, dcolscale_part;\n                dbeta_part.load_from(params.dbeta_part, idx);\n                dgamma_part.load_from(params.dgamma_part, idx);\n                if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); }\n                #pragma unroll\n                for( int it = 0; it < NUM_ELT; it++ ) {\n                    dgamma_local.data.elt[it] += dgamma_part.data.elt[it];\n                    dbeta_local.data.elt[it] += dbeta_part.data.elt[it];\n                    if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; }\n                }\n            }\n        }\n        void * smem_gamma = smem_;\n        void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];\n        void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];\n\n        const int write_row = warp;\n        const int write_col = lane ^ write_row;\n        const int write_idx = write_row * THREADS_PER_WARP + write_col;\n\n        dgamma_local.store_to(smem_gamma, write_idx);\n        dbeta_local.store_to(smem_beta, write_idx);\n        if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); }\n\n        __syncthreads();\n\n        // It would be probably safe to reuse the first row of smem_beta and 
smem_gamma\n        void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE];\n        void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];\n        void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];\n\n\n        // More than one iter iff ROWS_PER_CTA < 32.\n        for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {\n            const int read_row = lane;\n            const int read_col = w ^ read_row;\n            const int read_idx = read_row * THREADS_PER_WARP + read_col;\n\n            memset(&dbeta_local, 0, sizeof(dbeta_local));\n            memset(&dgamma_local, 0, sizeof(dgamma_local));\n            if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); }\n\n            // Load beta and gamma transposed \n            if(read_row < Kernel_traits::ROWS_PER_CTA){\n                dbeta_local.load_from(smem_beta, read_idx);\n                dgamma_local.load_from(smem_gamma, read_idx);\n                if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); }\n            }\n\n            // Call reducer on the loaded value(s) and convert.\n            #pragma unroll\n            for( int it = 0; it < NUM_ELT; it++ ) {\n                compute_t b_i = dbeta_local.data.elt[it];\n                compute_t g_i = dgamma_local.data.elt[it];\n                b_i = reducer.allreduce(b_i, sum);\n                g_i = reducer.allreduce(g_i, sum);\n\n                dgamma_local.data.elt[it] = g_i;\n                dbeta_local.data.elt[it] = b_i;\n                if (Has_colscale) {\n                    compute_t cs_i = dcolscale_local.data.elt[it];\n                    cs_i = reducer.allreduce(cs_i, sum);\n                    dcolscale_local.data.elt[it] = cs_i;\n                }\n            }\n\n       
     // Leader stores the result at the current column.\n            if(lane == 0){\n                dgamma_local.store_to(smem_gamma_out, w);\n                dbeta_local.store_to(smem_beta_out, w);\n                if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); }\n            }\n\n        }\n\n        // All writes done.\n        __syncthreads();\n\n        // Pack and store: 2-wide stores with half the threads.\n        if (Is_even_cols || col_out * 2 < params.cols) {\n            if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {\n\n                using src_t = typename TypeToVec2<compute_t>::Type;\n                using dst_t = typename TypeToVec2<weight_t>::Type;\n                Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2, dcolscale_vec2;\n                Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2, dcolscale_out2;\n\n                dgamma_vec2.load_from(smem_gamma_out, lane);\n                dbeta_vec2.load_from(smem_beta_out, lane);\n                if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); }\n                #pragma unroll\n                for( int it = 0; it < NUM_ELT; it++ ) {\n                    dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);\n                    dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);\n                    if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dcolscale_vec2.data.elt[it]); }\n                }\n                dgamma_out2.store_to(params.dgamma, col_out);\n                dbeta_out2.store_to(params.dbeta, col_out);\n                if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); }\n            }\n        }\n    }\n}\n}  // namespace layer_norm\n\nusing namespace layer_norm;\n\ntemplate<\n    typename weight_t,\n    typename input_t,\n    typename residual_t,\n    typename output_t,\n    typename compute_t,\n 
   typename index_t,\n    int HIDDEN_SIZE,\n    int CTAS_PER_ROW,\n    int WARPS_M,\n    int WARPS_N,\n    int BYTES_PER_LDG_MAIN,\n    int BYTES_PER_LDG_FINAL\n>\nvoid launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params){\n\n    using Kernel_traits = Kernel_traits<weight_t,\n                                        input_t,\n                                        residual_t,\n                                        output_t,\n                                        compute_t,\n                                        index_t,\n                                        HIDDEN_SIZE,\n                                        CTAS_PER_ROW,\n                                        WARPS_M,\n                                        WARPS_N,\n                                        BYTES_PER_LDG_MAIN\n                                        >;\n    bool is_dropout = launch_params.params.dropout_keep_p < 1.f;\n    bool has_colscale = launch_params.params.colscale != nullptr;\n    bool has_subset = launch_params.params.x0_subset != nullptr;\n    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;\n    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {\n        BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {\n            BOOL_SWITCH(has_subset, HasSubsetConst, [&] {\n                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {\n                    auto kernel = &ln_bwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;\n                    if( configure_params ) {\n                        int ctas_per_sm;\n                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));\n                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;\n                        launch_params.barrier_size = 0;\n         
               launch_params.workspace_bytes = 0;\n                        if(Kernel_traits::CTAS_PER_ROW > 1) {\n                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;\n                            launch_params.workspace_bytes = launch_params.params.ctas_per_col\n                                                          * Kernel_traits::WARPS_M\n                                                          * Kernel_traits::CTAS_PER_ROW\n                                                          * sizeof(typename Kernel_traits::reduce_t)\n                                                          * 2;\n                        }\n                        return;\n                    }\n\n                    if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {\n                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));\n                    }\n                    auto stream = launch_params.stream;\n                    auto ctas_per_col = launch_params.params.ctas_per_col;\n\n                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {\n                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);\n                    } else {\n                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);\n                        dim3 block(Kernel_traits::THREADS_PER_CTA);\n                        void *params_ = (void *)&launch_params.params;\n                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);\n                    }\n\n                    using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,\n                                                                              weight_t,\n                                                                              input_t,\n                    
                                                          residual_t,\n                                                                              output_t,\n                                                                              compute_t,\n                                                                              index_t,\n                                                                              HasColscaleConst,\n                                                                              32 * 32,  // THREADS_PER_CTA\n                                                                              BYTES_PER_LDG_FINAL>;\n\n                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, HasColscaleConst, IsEvenColsConst>;\n                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);\n                });\n            });\n        });\n    });\n}\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_1024.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_1280.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_1536.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_2048.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_256.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER(  256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_2560.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_3072.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_4096.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_512.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER(  512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_5120.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_6144.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_7168.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_768.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_FWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_8192.cu",
    "content": "#include \"ln_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);\nREGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_fwd_kernels.cuh",
    "content": "#pragma once\n\n#ifdef OLD_GENERATOR_PATH\n#include <ATen/CUDAGeneratorImpl.h>\n#else\n#include <ATen/cuda/CUDAGeneratorImpl.h>\n#endif\n\n#include <ATen/cuda/detail/UnpackRaw.cuh>  // For at::cuda::philox::unpack\n#include <curand_kernel.h>\n\n#include \"ln.h\"\n#include \"ln_utils.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"static_switch.h\"\n\nnamespace layer_norm {\n\ntemplate<typename Ktraits, bool Is_dropout, bool Has_colscale, bool Has_subset, bool Is_even_cols>\n__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) \nvoid ln_fwd_kernel(FwdParams params) {\n\n    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };\n    enum { WARPS_N = Ktraits::WARPS_N };\n    enum { WARPS_M = Ktraits::WARPS_M };\n    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };\n    enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };\n    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };\n    enum { LDGS = Ktraits::LDGS };\n    enum { NUM_ELTS = Ktraits::NUM_ELTS };\n    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };\n\n    using input_t = typename Ktraits::input_t;\n    using residual_t = typename Ktraits::residual_t;\n    using output_t = typename Ktraits::output_t;\n    using index_t = typename Ktraits::index_t;\n    using compute_t = typename Ktraits::compute_t;\n    using mask_t = typename Ktraits::mask_t;\n    using Ivec = typename Ktraits::Ivec;\n    using Rvec = typename Ktraits::Rvec;\n    using Ovec = typename Ktraits::Ovec;\n    using Wvec = typename Ktraits::Wvec;\n    using Cvec = typename Ktraits::Cvec;\n    using Mvec = typename Ktraits::Mvec;\n\n    using Stats = typename Ktraits::Stats;\n    using stats_t = typename Stats::stats_t;\n\n    const bool has_residual = params.residual != nullptr;\n    const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same<input_t, residual_t>::value);\n\n    extern __shared__ char smem_[];\n\n    const index_t tidx = threadIdx.x;\n    const 
index_t bidn = blockIdx.x % CTAS_PER_ROW;\n    const index_t bidm = blockIdx.x / CTAS_PER_ROW;\n    const index_t lane = tidx % THREADS_PER_WARP;\n    const index_t warp = tidx / THREADS_PER_WARP;\n    const index_t warp_m = warp / WARPS_N;\n    const index_t warp_n = warp % WARPS_N;\n\n    const index_t r = bidm * ROWS_PER_CTA + warp_m;\n    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;\n\n    Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);\n\n    compute_t *mu_ptr = static_cast<compute_t *>(params.mu);\n    compute_t *rs_ptr = static_cast<compute_t *>(params.rs);\n\n    const input_t *rowscale = static_cast<input_t *>(params.rowscale);\n    const index_t *x0_subset = static_cast<index_t *>(params.x0_subset);\n    const index_t *z_subset = static_cast<index_t *>(params.z_subset);\n\n    // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu\n    curandStatePhilox4_32_10_t state;\n    if (Is_dropout) {\n        auto seeds = at::cuda::philox::unpack(params.philox_args);\n        const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;\n        curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);\n    }\n\n    const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;\n\n    Wvec gamma[LDGS];\n    Wvec beta[LDGS];\n    Wvec colscale[LDGS];\n    index_t idx = c;\n    #pragma unroll\n    for( int it = 0; it < LDGS; it++ ) {\n        if (Is_even_cols || (it < num_valid_ldgs)) {\n            gamma[it].load_from(params.gamma, idx);\n            if (params.beta != nullptr) {\n                beta[it].load_from(params.beta, idx);\n            } else {\n                beta[it].zero_();\n            }\n            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }\n            idx += VEC_COLS_PER_LDG;\n        }\n    }\n\n    for( int row = r; row < params.rows; row += params.ctas_per_col * 
ROWS_PER_CTA ) {\n        const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;\n        const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];\n        const int row_z = !Has_subset ? row + 1 : z_subset[row];\n        const bool load_x0 = !Has_subset || row_x0 > 0;\n        index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;\n        index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);\n        compute_t xf[LDGS * NUM_ELTS];\n        #pragma unroll\n        for( int it = 0; it < LDGS; it++ ) {\n            if (Is_even_cols || (it < num_valid_ldgs)) {\n                Ivec x0;\n                Rvec residual;\n                Rvec x;\n                Mvec dmask;\n                if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }\n                if (has_residual) { residual.load_from(params.residual, idx_x); }\n                #pragma unroll\n                for( int jt = 0; jt < NUM_ELTS; jt++ ) {\n                    // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use\n                    // the more efficient curand_uniform4.\n                    compute_t x_ij;\n                    if (load_x0) {\n                        mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;\n                        if (Is_dropout) { dmask.data.elt[jt] = keep; }\n                        compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;\n                        x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;\n                        if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }\n                        x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;\n                    } else {\n                        x_ij = has_residual ? 
compute_t(residual.data.elt[jt]) : 0.f;\n                    }\n                    if (save_x) { x.data.elt[jt] = x_ij; }\n                    xf[it * NUM_ELTS + jt] = x_ij;\n                }\n                if (save_x) { x.store_to(params.x, idx_x); }\n                if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); }\n                idx_x += VEC_COLS_PER_LDG;\n                idx_x0 += VEC_COLS_PER_LDG;\n            }\n        }\n\n        static_assert(CTAS_PER_ROW == 1, \"Don't support multiple CTAs per row for now\");\n        const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;\n        const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;\n        const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;\n        auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {\n            // Need to convert to int, otherwise the subtraction will wrap around.\n            const index_t valid_partial_vecs_in_warp =\n                std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),\n                        int(THREADS_PER_WARP));\n            return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;\n        };\n        stats_t s = stats.template compute<Is_even_cols>(\n            xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS\n        );\n\n        compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);\n        compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);\n\n        if( bidn == 0 && warp_n == 0 && lane == 0 ) {\n            mu_ptr[row] = mu;\n        }\n\n        compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 
0.f : mu * mu));\n\n        if( bidn == 0 && warp_n == 0 && lane == 0 ) {\n            rs_ptr[row] = rs;\n        }\n\n        const bool save_z = !Has_subset || row_z > 0;\n        if (save_z) {\n            index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c;\n            #pragma unroll\n            for( int it = 0; it < LDGS; it++ ) {\n                if (Is_even_cols || (it < num_valid_ldgs)) {\n                    Ovec z;\n                    #pragma unroll\n                    for( int jt = 0; jt < NUM_ELTS; jt++ ) {\n                        compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));\n                        compute_t g_ij = gamma[it].data.elt[jt];\n                        compute_t b_ij = beta[it].data.elt[jt];\n                        z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);\n                    }\n                    z.store_to(params.z, idx_z);\n                    idx_z += VEC_COLS_PER_LDG;\n                }\n            }\n        }\n\n    }\n}\n\n}  // namespace layer_norm\n\nusing namespace layer_norm;\n\ntemplate<\n    typename weight_t,\n    typename input_t,\n    typename residual_t,\n    typename output_t,\n    typename compute_t,\n    typename index_t,\n    int HIDDEN_SIZE,\n    int CTAS_PER_ROW,\n    int WARPS_M,\n    int WARPS_N,\n    int BYTES_PER_LDG\n>\nvoid launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){\n\n    using Kernel_traits = Kernel_traits<weight_t,\n                                        input_t,\n                                        residual_t,\n                                        output_t,\n                                        compute_t,\n                                        index_t,\n                                        HIDDEN_SIZE,\n                                        CTAS_PER_ROW,\n                                        WARPS_M,\n                                        WARPS_N,\n  
                                      BYTES_PER_LDG\n                                        >;\n    bool has_colscale = launch_params.params.colscale != nullptr;\n    bool has_subset = launch_params.params.x0_subset != nullptr;\n    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;\n    BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {\n        BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {\n            BOOL_SWITCH(has_subset, HasSubsetConst, [&] {\n                    BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {\n                        auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;\n                    if( configure_params ) {\n                        int ctas_per_sm;\n                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));\n                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;\n                        const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;\n                        launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;\n                        launch_params.barrier_size = 0;\n                        launch_params.workspace_bytes = 0;\n                        if(Kernel_traits::CTAS_PER_ROW > 1) {\n                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;\n                            launch_params.workspace_bytes = launch_params.params.ctas_per_col\n                                                          * Kernel_traits::WARPS_M\n                                                          * Kernel_traits::CTAS_PER_ROW\n                            
                              * sizeof(typename Kernel_traits::Stats::stats_t)\n                                                          * 2;\n                        }\n                        return;\n                    }\n\n                    if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {\n                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));\n                    }\n                    auto stream = launch_params.stream;\n                    auto ctas_per_col = launch_params.params.ctas_per_col;\n\n                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {\n                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);\n                    } else {\n                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);\n                        dim3 block(Kernel_traits::THREADS_PER_CTA);\n                        void *params_ = (void *)&launch_params.params;\n                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);\n                    }\n                });\n            });\n        });\n    });\n}\n"
  },
  {
    "path": "csrc/layer_norm/ln_kernel_traits.h",
    "content": "#pragma once\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nnamespace layer_norm {\ntemplate<\n    uint32_t HIDDEN_SIZE_,\n    typename weight_t_,\n    typename input_t_,\n    typename residual_t_,\n    typename output_t_,\n    typename compute_t_,\n    typename index_t_,\n    uint32_t THREADS_PER_CTA_\n>\nstruct Kernel_traits_base {\n\n    using weight_t = weight_t_;\n    using input_t = input_t_;\n    using residual_t = residual_t_;\n    using output_t = output_t_;\n    using compute_t = compute_t_;\n    using index_t = index_t_;\n\n    enum { HIDDEN_SIZE = HIDDEN_SIZE_ };\n    enum { THREADS_PER_CTA = THREADS_PER_CTA_ };\n    enum { THREADS_PER_WARP = 32 };\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<\n    uint32_t HIDDEN_SIZE_,\n    typename weight_t_,\n    typename input_t_,\n    typename residual_t_,\n    typename output_t_,\n    typename compute_t_,\n    typename index_t_,\n    bool Has_colscale,\n    uint32_t THREADS_PER_CTA_,\n    uint32_t BYTES_PER_LDG_,\n    typename Base = Kernel_traits_base<HIDDEN_SIZE_,\n                                        weight_t_,\n                                        input_t_,\n                                        residual_t_,\n                                        output_t_,\n                                        compute_t_,\n                                        index_t_,\n                                        THREADS_PER_CTA_>\n>\nstruct Kernel_traits_finalize : public Base {\n    enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };\n    static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);\n    // Bytes per global load from the input. 
\n    enum { BYTES_PER_LDG = BYTES_PER_LDG_ };\n    // Number of elements fetched by a global load.\n    enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };\n    // Bytes per global store of the weights.\n    enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };\n    static_assert(sizeof(BYTES_PER_LDG) == 4, \"Conflict-free smem transpose only implemented for 4B compute type!\");\n    static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, \"We assume one warp per row!\");\n    // The total number of BYTES_PER_LDG-wide words in a hidden vector.\n    enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };\n    static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));\n\n    // Shared memory size to transpose the CTA result.\n    enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };\n    // Shared memory size to coalesce the CTA result.\n    enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };\n    // Shared memory requirement per CTA. \n    static constexpr int NUM_FACTORS = Has_colscale ? 
3 : 2;\n    enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };\n\n    // The type of the reducer.\n    using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;\n\n    // Condition for the whole CTA to participate in syncthreads.\n    static_assert(COLS % Base::THREADS_PER_WARP == 0);\n    enum { CTAS = COLS / Base::THREADS_PER_WARP };\n}; \n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n\ntemplate<\n    typename weight_t_,\n    typename input_t_,\n    typename residual_t_,\n    typename output_t_,\n    typename compute_t_,\n    typename index_t_,\n    uint32_t HIDDEN_SIZE_, \n    uint32_t CTAS_PER_ROW_, \n    uint32_t WARPS_M_, \n    uint32_t WARPS_N_, \n    uint32_t BYTES_PER_LDG_ = 16,\n    typename Base = Kernel_traits_base<\n        HIDDEN_SIZE_,\n        weight_t_, \n        input_t_,\n        residual_t_,\n        output_t_, \n        compute_t_, \n        index_t_, \n        WARPS_M_*WARPS_N_*THREADS_PER_WARP\n        >\n>\nstruct Kernel_traits : public Base {\n\n    using input_t = typename Base::input_t;\n    using residual_t = typename Base::residual_t;\n    using weight_t = typename Base::weight_t;\n    using compute_t = typename Base::compute_t;\n    using output_t = typename Base::output_t;\n    using index_t = typename Base::index_t;\n    // using mask_t = unsigned char;\n    using mask_t = bool;\n\n    enum { CTAS_PER_ROW = CTAS_PER_ROW_ };\n    enum { WARPS_M = WARPS_M_ };\n    enum { WARPS_N = WARPS_N_ };\n    enum { COLS = HIDDEN_SIZE_ };\n    enum { HIDDEN_SIZE = HIDDEN_SIZE_ };\n    enum { BYTES_PER_LDG = BYTES_PER_LDG_ };\n    enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };\n\n    enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };\n    enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };\n    enum { ROWS_PER_CTA = WARPS_M };\n\n    enum { BYTES_PER_ROW = COLS * sizeof(input_t) };\n    enum { BYTES_PER_ROW_PER_CTA = 
THREADS_PER_ROW * BYTES_PER_LDG };\n    // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed\n    enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };\n    static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);\n\n    using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;\n    using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>; \n\n    enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };\n    enum { SMEM_BYTES = SMEM_BYTES_DGRAD  + SMEM_BYTES_WGRAD };\n\n    using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;\n    using Rvec = layer_norm::Vec<residual_t, NUM_ELTS>;\n    using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;\n    using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;\n    using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;\n    using Mvec = layer_norm::Vec<mask_t, NUM_ELTS>;\n    enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };\n\n    // Assume that each thread can handle the same number of elements in the output and weights as in the input.\n    static_assert(sizeof(input_t) == sizeof(output_t));\n    static_assert(sizeof(input_t) <= sizeof(residual_t));\n    // The number of columns fetched per load from input: one per thread.\n    enum { VEC_COLS_PER_LDG =  CTAS_PER_ROW * THREADS_PER_ROW };\n    // The total number of vectorized loads/stores per hidden vector.\n    enum { VEC_COLS = COLS / ELTS_PER_LDG };\n    // The number of loads per thread for the input.\n    enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };\n    static_assert(LDGS * VEC_COLS_PER_LDG  == VEC_COLS);\n    //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, \"\");\n\n    using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;\n    enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace layer_norm\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_1024.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_1280.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_1536.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_2048.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_256.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_2560.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_3072.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_4096.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\n// Use 8 warps otherwise there's a lot of register spilling\n\nREGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_512.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_5120.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\n// Use 8 warps otherwise there's a lot of register spilling\n\nREGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_6144.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_7168.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_768.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_bwd_8192.cu",
    "content": "#include \"ln_parallel_residual_bwd_kernels.cuh\"\n\n// Create backward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL\n\nREGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);\nREGISTER_PARALLEL_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_1024.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_1280.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_1536.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_2048.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_256.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER(  256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_2560.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_3072.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_4096.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_512.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER(  512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_5120.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_6144.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_7168.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_768.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER(  768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER(  768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_fwd_8192.cu",
    "content": "#include \"ln_parallel_residual_fwd_kernels.cuh\"\n\n// Create forward launch function and register. Macro signature:\n//  HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG\n\nREGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);\nREGISTER_PARALLEL_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_residual_bwd_kernels.cuh",
    "content": "#pragma once\n\n#include \"ln.h\"\n#include \"ln_utils.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"static_switch.h\"\n#include \"ln_bwd_kernels.cuh\"\n\nnamespace layer_norm {\n\ntemplate<typename Ktraits, bool Is_dropout, bool Tied_norm, bool Is_even_cols>\n__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) \nvoid ln_parallel_residual_bwd_kernel(layer_norm::BwdParams params) {\n\n    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };\n    enum { WARPS_M = Ktraits::WARPS_M };\n    enum { WARPS_N = Ktraits::WARPS_N };\n    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };\n    enum { COLS = Ktraits::COLS };\n    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };\n    enum { LDGS = Ktraits::LDGS };\n    enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };\n    enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };\n    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };\n\n    using input_t = typename Ktraits::input_t;\n    using compute_t = typename Ktraits::compute_t;\n    using index_t = typename Ktraits::index_t;\n    using mask_t = typename Ktraits::mask_t;\n    using Ivec = typename Ktraits::Ivec;\n    using Rvec = typename Ktraits::Rvec;\n    using Ovec = typename Ktraits::Ovec;\n    using Wvec = typename Ktraits::Wvec;\n    using Cvec = typename Ktraits::Cvec;\n    using Mvec = typename Ktraits::Mvec;\n    using Reducer = typename Ktraits::Reducer;\n    using reduce_t = typename Reducer::Type;\n\n    extern __shared__ char smem_[];\n\n    const bool has_residual = params.dresidual != nullptr;\n    const bool has_x1 = params.dx1 != nullptr;\n    const bool prenorm = params.dx != nullptr;\n\n    const index_t tidx = threadIdx.x;\n    const index_t bidn = blockIdx.x % CTAS_PER_ROW;\n    const index_t bidm = blockIdx.x / CTAS_PER_ROW;\n    const index_t lane = tidx % THREADS_PER_WARP;\n    const index_t warp = tidx / THREADS_PER_WARP;\n    const index_t warp_m = warp / Ktraits::WARPS_N;\n    const index_t warp_n = warp % Ktraits::WARPS_N;\n    const 
index_t tid_r = warp_n * THREADS_PER_WARP + lane;\n\n    const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;\n    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;\n\n    static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);\n\n    Cvec dz0y_sum[LDGS];\n    Cvec dz0_sum[LDGS];\n    Cvec dz1y_sum[LDGS];\n    Cvec dz1_sum[LDGS];\n\n    memset(dz0y_sum, 0, sizeof(dz0y_sum));\n    memset(dz0_sum, 0, sizeof(dz0_sum));\n    if (!Tied_norm) {\n        memset(dz1y_sum, 0, sizeof(dz1y_sum));\n        memset(dz1_sum, 0, sizeof(dz1_sum));\n    }\n\n    compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);\n    char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;\n\n    Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);\n\n    Sum<reduce_t> sum;\n\n    const index_t num_valid_ldgs =\n        ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;\n\n    Wvec gamma0[LDGS];\n    Wvec gamma1[LDGS];\n    index_t idx = c;\n    #pragma unroll\n    for( int it = 0; it < LDGS; it++ ) {\n        if (Is_even_cols || (it < num_valid_ldgs)) {\n            gamma0[it].load_from(params.gamma, idx);\n            if (!Tied_norm) { gamma1[it].load_from(params.gamma1, idx); }\n            idx += Ktraits::VEC_COLS_PER_LDG;\n        }\n    }\n    // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the\n    // last blocks with syncthreads!\n    // grid stride over rows\n    #pragma unroll 1\n    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {\n        const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];\n        const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];\n        Mvec dmask0[LDGS], dmask1[LDGS];\n        Rvec dx[LDGS];\n        compute_t dy[LDGS * NUM_ELTS];\n        compute_t y[LDGS * NUM_ELTS];\n        compute_t mdy_local = 0.f;\n        compute_t mdyy_local = 0.f;\n   
     index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;\n        #pragma unroll\n        for( int it = 0; it < LDGS; it++ ) {\n            if (Is_even_cols || (it < num_valid_ldgs)) {\n                Rvec x;\n                Ovec dz0, dz1;\n                dz0.load_from(params.dz, idx);\n                if (!Tied_norm) { dz1.load_from(params.dz1, idx); }\n                if (prenorm) { dx[it].load_from(params.dx, idx); }\n                x.load_from(params.x, idx);\n                if (Is_dropout) {\n                    dmask0[it].load_from(params.dmask, idx);\n                    if (has_x1) { dmask1[it].load_from(params.dmask1, idx); }\n                }\n                idx += Ktraits::VEC_COLS_PER_LDG;\n                #pragma unroll\n                for( int jt = 0; jt < NUM_ELTS; jt++ ) {\n                    compute_t x_tmp = x.data.elt[jt];\n                    compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? mu_r : 0.f));\n                    compute_t dy_tmp = compute_t(gamma0[it].data.elt[jt]) * compute_t(dz0.data.elt[jt]);\n                    if (!Tied_norm) {\n                        dy_tmp += compute_t(gamma1[it].data.elt[jt]) * compute_t(dz1.data.elt[jt]);\n                    }\n                    compute_t dz0_tmp = dz0.data.elt[jt];\n                    compute_t dz1_tmp;\n                    if (!Tied_norm) { dz1_tmp = dz1.data.elt[jt]; }\n\n                    mdy_local += dy_tmp;\n                    mdyy_local += dy_tmp * y_tmp;\n\n                    dy[it * NUM_ELTS + jt] = dy_tmp;\n                    y[it * NUM_ELTS + jt] = y_tmp;\n\n                    dz0y_sum[it].data.elt[jt] += dz0_tmp * y_tmp;\n                    dz0_sum[it].data.elt[jt] += dz0_tmp;\n                    if (!Tied_norm) {\n                        dz1y_sum[it].data.elt[jt] += dz1_tmp * y_tmp;\n                        dz1_sum[it].data.elt[jt] += dz1_tmp;\n                    }\n                }\n            }\n        }\n\n        reduce_t result = 
reducer.allreduce({mdy_local, mdyy_local}, sum);\n        mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;\n        mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;\n\n        idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;\n        #pragma unroll\n        for( int it = 0; it < LDGS; it++ ) {\n            if (Is_even_cols || (it < num_valid_ldgs)) {\n                Ivec dx0, dx1;\n                Rvec dresidual;\n                #pragma unroll\n                for( int jt = 0; jt < NUM_ELTS; jt++ ) {\n                    compute_t dx_tmp_res;\n                    compute_t dy_tmp = dy[it * NUM_ELTS + jt];\n                    compute_t y_tmp = y[it * NUM_ELTS + jt];\n                    compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f)));\n                    dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;\n                    if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; }\n                    if (Is_dropout) {\n                        dx0.data.elt[jt] = dmask0[it].data.elt[jt] ? dx_tmp_res * params.dropout_scale : 0.f;\n                        if (has_x1) { dx1.data.elt[jt] = dmask1[it].data.elt[jt] ? 
dx_tmp_res * params.dropout_scale : 0.f; }\n                    } else {\n                        dx0.data.elt[jt] = dx_tmp_res;\n                        if (has_x1) { dx1.data.elt[jt] = dx_tmp_res; }\n                    }\n                }\n                if (has_residual) { dresidual.store_to(params.dresidual, idx); }\n                dx0.store_to(params.dx0, idx);\n                if (has_x1) { dx1.store_to(params.dx1, idx); }\n                idx += Ktraits::VEC_COLS_PER_LDG;\n            }\n        }\n\n    }  // end: grid stride loop\n\n    if( WARPS_M == 1 ) {\n        idx = r * params.cols / Ktraits::ELTS_PER_LDG + c;\n        #pragma unroll\n        for( int it = 0; it < LDGS; it++ ) {\n            if (Is_even_cols || (it < num_valid_ldgs)) {\n                dz0_sum[it].store_to(params.dbeta_part, idx);\n                dz0y_sum[it].store_to(params.dgamma_part, idx);\n                if (!Tied_norm) {\n                    dz1_sum[it].store_to(params.dbeta1_part, idx);\n                    dz1y_sum[it].store_to(params.dgamma1_part, idx);\n                }\n                idx += Ktraits::VEC_COLS_PER_LDG;\n            }\n        }\n    } else {\n        static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, \"Multiple rows per CTA not supported for Multi-CTA.\");\n        // Finalize reduction of part dgamma and dbeta for this CTA\n        // by reducing over the rows held across the WARPS_M warps\n\n        // Assumption: blockSize divides hidden size.\n        enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };\n        static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, \"\");\n\n        idx = warp_m * Ktraits::VEC_COLS + tid_r;\n        #pragma unroll\n        for( int it = 0; it < LDGS; it++ ) {\n            dz0_sum[it].store_to(smem_wgrad, idx);\n            idx += THREADS_PER_ROW;\n        }\n        __syncthreads();\n        compute_t cta_dz0_sum[NUM_RES];\n        memset(cta_dz0_sum, 0, sizeof(compute_t) * NUM_RES);\n        for( int 
it = 0; it < ROWS_PER_CTA; it++ ) {\n            for( int jt = 0; jt < NUM_RES; jt++ ) {\n                cta_dz0_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n            }\n        }\n        __syncthreads();\n\n        idx = warp_m * Ktraits::VEC_COLS + tid_r;\n        #pragma unroll\n        for( int it = 0; it < LDGS; it++ ) {\n            dz0y_sum[it].store_to(smem_wgrad, idx);\n            idx += THREADS_PER_ROW;\n        }\n        __syncthreads();\n        compute_t cta_dz0y_sum[NUM_RES];\n        memset(cta_dz0y_sum, 0, sizeof(compute_t) * NUM_RES);\n        for( int it = 0; it < ROWS_PER_CTA; it++ ) {\n            for( int jt = 0; jt < NUM_RES; jt++ ) {\n                cta_dz0y_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n            }\n        }\n\n        compute_t cta_dz1_sum[NUM_RES], cta_dz1y_sum[NUM_RES];\n        if (!Tied_norm) {\n            __syncthreads();\n            idx = warp_m * Ktraits::VEC_COLS + tid_r;\n            #pragma unroll\n            for( int it = 0; it < LDGS; it++ ) {\n                dz1_sum[it].store_to(smem_wgrad, idx);\n                idx += THREADS_PER_ROW;\n            }\n            __syncthreads();\n            memset(cta_dz1_sum, 0, sizeof(compute_t) * NUM_RES);\n            for( int it = 0; it < ROWS_PER_CTA; it++ ) {\n                for( int jt = 0; jt < NUM_RES; jt++ ) {\n                    cta_dz1_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n                }\n            }\n            __syncthreads();\n            idx = warp_m * Ktraits::VEC_COLS + tid_r;\n            #pragma unroll\n            for( int it = 0; it < LDGS; it++ ) {\n                dz1y_sum[it].store_to(smem_wgrad, idx);\n                idx += THREADS_PER_ROW;\n            }\n            __syncthreads();\n            memset(cta_dz1y_sum, 0, sizeof(compute_t) * NUM_RES);\n            for( int it = 0; it < ROWS_PER_CTA; it++ ) {\n                for( int 
jt = 0; jt < NUM_RES; jt++ ) {\n                    cta_dz1y_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];\n                }\n            }\n        }\n\n        const index_t num_valid_writes\n            = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;\n        compute_t *dgamma0_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;\n        compute_t *dbeta0_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;\n        compute_t *dgamma1_part = !Tied_norm ? static_cast<compute_t *>(params.dgamma1_part) + bidm * params.cols + tidx : nullptr;\n        compute_t *dbeta1_part = !Tied_norm ? static_cast<compute_t *>(params.dbeta1_part) + bidm * params.cols + tidx : nullptr;\n        for( int jt = 0; jt < NUM_RES; jt++ ) {\n            if (Is_even_cols || (jt < num_valid_writes)) {\n                *dgamma0_part = cta_dz0y_sum[jt];\n                dgamma0_part += Ktraits::THREADS_PER_CTA;\n                *dbeta0_part = cta_dz0_sum[jt];\n                dbeta0_part += Ktraits::THREADS_PER_CTA;\n                if (!Tied_norm) {\n                    *dgamma1_part = cta_dz1y_sum[jt];\n                    dgamma1_part += Ktraits::THREADS_PER_CTA;\n                    *dbeta1_part = cta_dz1_sum[jt];\n                    dbeta1_part += Ktraits::THREADS_PER_CTA;\n                }\n            }\n        }\n\n    }\n}\n\ntemplate<typename Kernel_traits, bool Is_even_cols>\n__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)\nvoid ln_parallel_residual_bwd_finalize_kernel(BwdParams params)\n{\n\n    using compute_t = typename Kernel_traits::compute_t;\n    using weight_t = typename Kernel_traits::weight_t;\n    using index_t = typename Kernel_traits::index_t;\n    using Reducer = typename Kernel_traits::Reducer;\n    using reduce_t = typename Reducer::Type;\n\n    Sum<reduce_t> sum;\n    enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };\n    enum { 
THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };\n\n    // Multiplying by 2 since we have both gamma0 and gamma1\n    __shared__ char smem_[2 * Kernel_traits::SMEM_BYTES_PER_CTA];\n\n    constexpr uint32_t bidm = 0;\n\n    const uint32_t bidn = blockIdx.x;\n    const uint32_t tidx = threadIdx.x;\n    const uint32_t warp = tidx / THREADS_PER_WARP;\n    const uint32_t lane = tidx % THREADS_PER_WARP;\n\n    Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);\n\n    const uint32_t c = bidn * THREADS_PER_WARP + lane;\n    const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;\n    constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;\n    for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {\n        // Each thread sums over NUM_ELT columns.\n        Vec<compute_t, NUM_ELT> dbeta0_local, dgamma0_local, dbeta1_local, dgamma1_local;\n        memset(&dgamma0_local, 0, sizeof(dgamma0_local));\n        memset(&dbeta0_local, 0, sizeof(dbeta0_local));\n        memset(&dgamma1_local, 0, sizeof(dgamma1_local));\n        memset(&dbeta1_local, 0, sizeof(dbeta1_local));\n        if (Is_even_cols || col < params.cols) {\n            for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {\n                index_t idx = row * params.cols + col;\n\n                Vec<compute_t, NUM_ELT> dbeta0_part, dgamma0_part, dbeta1_part, dgamma1_part;\n                dbeta0_part.load_from(params.dbeta_part, idx);\n                dgamma0_part.load_from(params.dgamma_part, idx);\n                dbeta1_part.load_from(params.dbeta1_part, idx);\n                dgamma1_part.load_from(params.dgamma1_part, idx);\n                #pragma unroll\n                for( int it = 0; it < NUM_ELT; it++ ) {\n                    dgamma0_local.data.elt[it] += dgamma0_part.data.elt[it];\n                    dbeta0_local.data.elt[it] += dbeta0_part.data.elt[it];\n                    
dgamma1_local.data.elt[it] += dgamma1_part.data.elt[it];\n                    dbeta1_local.data.elt[it] += dbeta1_part.data.elt[it];\n                }\n            }\n        }\n        void * smem_gamma0 = smem_;\n        void * smem_beta0 = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];\n        void * smem_gamma1 = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];\n        void * smem_beta1 = &smem_[3 * Kernel_traits::SMEM_BYTES_TRANSPOSE];\n\n        const int write_row = warp;\n        const int write_col = lane ^ write_row;\n        const int write_idx = write_row * THREADS_PER_WARP + write_col;\n\n        dgamma0_local.store_to(smem_gamma0, write_idx);\n        dbeta0_local.store_to(smem_beta0, write_idx);\n        dgamma1_local.store_to(smem_gamma1, write_idx);\n        dbeta1_local.store_to(smem_beta1, write_idx);\n\n        __syncthreads();\n\n        // It would probably be safe to reuse the first row of smem_beta0 and smem_gamma0\n        void * smem_gamma0_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE];\n        void * smem_beta0_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];\n        void * smem_gamma1_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT];\n        void * smem_beta1_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + 3 * Kernel_traits::SMEM_BYTES_OUTPUT];\n\n        // More than one iter iff ROWS_PER_CTA < 32.\n        for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {\n            const int read_row = lane;\n            const int read_col = w ^ read_row;\n            const int read_idx = read_row * THREADS_PER_WARP + read_col;\n\n            memset(&dbeta0_local, 0, sizeof(dbeta0_local));\n            memset(&dgamma0_local, 0, sizeof(dgamma0_local));\n            memset(&dbeta1_local, 0, sizeof(dbeta1_local));\n            memset(&dgamma1_local, 0, sizeof(dgamma1_local));\n\n            // Load beta and gamma 
transposed\n            if(read_row < Kernel_traits::ROWS_PER_CTA){\n                dbeta0_local.load_from(smem_beta0, read_idx);\n                dgamma0_local.load_from(smem_gamma0, read_idx);\n                dbeta1_local.load_from(smem_beta1, read_idx);\n                dgamma1_local.load_from(smem_gamma1, read_idx);\n            }\n\n            // Call reducer on the loaded value(s) and convert.\n            #pragma unroll\n            for( int it = 0; it < NUM_ELT; it++ ) {\n                compute_t b0_i = dbeta0_local.data.elt[it];\n                compute_t g0_i = dgamma0_local.data.elt[it];\n                compute_t b1_i = dbeta1_local.data.elt[it];\n                compute_t g1_i = dgamma1_local.data.elt[it];\n                b0_i = reducer.allreduce(b0_i, sum);\n                g0_i = reducer.allreduce(g0_i, sum);\n                b1_i = reducer.allreduce(b1_i, sum);\n                g1_i = reducer.allreduce(g1_i, sum);\n\n                dgamma0_local.data.elt[it] = g0_i;\n                dbeta0_local.data.elt[it] = b0_i;\n                dgamma1_local.data.elt[it] = g1_i;\n                dbeta1_local.data.elt[it] = b1_i;\n            }\n\n            // Leader stores the result at the current column.\n            if(lane == 0){\n                dgamma0_local.store_to(smem_gamma0_out, w);\n                dbeta0_local.store_to(smem_beta0_out, w);\n                dgamma1_local.store_to(smem_gamma1_out, w);\n                dbeta1_local.store_to(smem_beta1_out, w);\n            }\n\n        }\n\n        // All writes done.\n        __syncthreads();\n\n        // Pack and store: 2-wide stores with half the threads.\n        if (Is_even_cols || col_out * 2 < params.cols) {\n            if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {\n\n                using src_t = typename TypeToVec2<compute_t>::Type;\n                using dst_t = typename TypeToVec2<weight_t>::Type;\n                Vec<src_t, NUM_ELT> dbeta0_vec2, 
dgamma0_vec2, dbeta1_vec2, dgamma1_vec2;\n                Vec<dst_t, NUM_ELT> dbeta0_out2, dgamma0_out2, dbeta1_out2, dgamma1_out2;\n\n                dgamma0_vec2.load_from(smem_gamma0_out, lane);\n                dbeta0_vec2.load_from(smem_beta0_out, lane);\n                dgamma1_vec2.load_from(smem_gamma1_out, lane);\n                dbeta1_vec2.load_from(smem_beta1_out, lane);\n                #pragma unroll\n                for( int it = 0; it < NUM_ELT; it++ ) {\n                    dgamma0_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma0_vec2.data.elt[it]);\n                    dbeta0_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta0_vec2.data.elt[it]);\n                    dgamma1_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma1_vec2.data.elt[it]);\n                    dbeta1_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta1_vec2.data.elt[it]);\n                }\n                dgamma0_out2.store_to(params.dgamma, col_out);\n                dbeta0_out2.store_to(params.dbeta, col_out);\n                dgamma1_out2.store_to(params.dgamma1, col_out);\n                dbeta1_out2.store_to(params.dbeta1, col_out);\n            }\n        }\n    }\n}\n\n}  // namespace layer_norm\n\nusing namespace layer_norm;\n\ntemplate<\n    typename weight_t,\n    typename input_t,\n    typename residual_t,\n    typename output_t,\n    typename compute_t,\n    typename index_t,\n    int HIDDEN_SIZE,\n    int CTAS_PER_ROW,\n    int WARPS_M,\n    int WARPS_N,\n    int BYTES_PER_LDG_MAIN,\n    int BYTES_PER_LDG_FINAL\n>\nvoid launch_parallel_residual_(LaunchParams<BwdParams> &launch_params, const bool configure_params){\n\n    using Kernel_traits = Kernel_traits<weight_t,\n                                        input_t,\n                                        residual_t,\n                                        output_t,\n                                        compute_t,\n                                        index_t,\n           
                             HIDDEN_SIZE,\n                                        CTAS_PER_ROW,\n                                        WARPS_M,\n                                        WARPS_N,\n                                        BYTES_PER_LDG_MAIN\n                                        >;\n    bool is_dropout = launch_params.params.dropout_keep_p < 1.f;\n    bool tied_norm = launch_params.params.gamma1 == nullptr;\n    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;\n    BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {\n        BOOL_SWITCH(tied_norm, TiedNormConst, [&] {\n            BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {\n                auto kernel = &ln_parallel_residual_bwd_kernel<Kernel_traits, IsDropoutConst, TiedNormConst, IsEvenColsConst>;\n                if( configure_params ) {\n                    int ctas_per_sm;\n                    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n                        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));\n                    launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;\n                    launch_params.barrier_size = 0;\n                    launch_params.workspace_bytes = 0;\n                    if(Kernel_traits::CTAS_PER_ROW > 1) {\n                        launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;\n                        launch_params.workspace_bytes = launch_params.params.ctas_per_col\n                                                      * Kernel_traits::WARPS_M\n                                                      * Kernel_traits::CTAS_PER_ROW\n                                                      * sizeof(typename Kernel_traits::reduce_t)\n                                                      * 2;\n                    }\n                    return;\n                }\n\n                if( Kernel_traits::SMEM_BYTES >= 
48 * 1024 ) {\n                    CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));\n                }\n                auto stream = launch_params.stream;\n                auto ctas_per_col = launch_params.params.ctas_per_col;\n\n                if( Kernel_traits::CTAS_PER_ROW == 1 ) {\n                    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);\n                } else {\n                    dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);\n                    dim3 block(Kernel_traits::THREADS_PER_CTA);\n                    void *params_ = (void *)&launch_params.params;\n                    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);\n                }\n\n                using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,\n                                                                          weight_t,\n                                                                          input_t,\n                                                                          residual_t,\n                                                                          output_t,\n                                                                          compute_t,\n                                                                          index_t,\n                                                                          /*HasColscaleConst=*/false,\n                                                                          32 * 32,  // THREADS_PER_CTA\n                                                                          BYTES_PER_LDG_FINAL>;\n\n                auto kernel_f = !TiedNormConst\n                    ? 
&layer_norm::ln_parallel_residual_bwd_finalize_kernel<Kernel_traits_f, IsEvenColsConst>\n                    : &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, /*HasColscaleConst=*/false, IsEvenColsConst>;\n                kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);\n\n            });\n        });\n    });\n}\n"
  },
  {
    "path": "csrc/layer_norm/ln_parallel_residual_fwd_kernels.cuh",
    "content": "#pragma once\n\n#ifdef OLD_GENERATOR_PATH\n#include <ATen/CUDAGeneratorImpl.h>\n#else\n#include <ATen/cuda/CUDAGeneratorImpl.h>\n#endif\n\n#include <ATen/cuda/detail/UnpackRaw.cuh>  // For at::cuda::philox::unpack\n#include <curand_kernel.h>\n\n#include \"ln.h\"\n#include \"ln_utils.cuh\"\n#include \"ln_kernel_traits.h\"\n#include \"static_switch.h\"\n\nnamespace layer_norm {\n\ntemplate<typename Ktraits, bool Is_dropout, bool Tied_norm, bool Is_even_cols>\n__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) \nvoid ln_parallel_residual_fwd_kernel(FwdParams params) {\n\n    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };\n    enum { WARPS_N = Ktraits::WARPS_N };\n    enum { WARPS_M = Ktraits::WARPS_M };\n    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };\n    enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };\n    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };\n    enum { LDGS = Ktraits::LDGS };\n    enum { NUM_ELTS = Ktraits::NUM_ELTS };\n    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };\n\n    using input_t = typename Ktraits::input_t;\n    using residual_t = typename Ktraits::residual_t;\n    using output_t = typename Ktraits::output_t;\n    using index_t = typename Ktraits::index_t;\n    using compute_t = typename Ktraits::compute_t;\n    using mask_t = typename Ktraits::mask_t;\n    using Ivec = typename Ktraits::Ivec;\n    using Rvec = typename Ktraits::Rvec;\n    using Ovec = typename Ktraits::Ovec;\n    using Wvec = typename Ktraits::Wvec;\n    using Cvec = typename Ktraits::Cvec;\n    using Mvec = typename Ktraits::Mvec;\n\n    using Stats = typename Ktraits::Stats;\n    using stats_t = typename Stats::stats_t;\n\n    const bool has_residual = params.residual != nullptr;\n    const bool has_x1 = params.x1 != nullptr;\n    const bool save_x = has_residual || has_x1 || Is_dropout || !(std::is_same<input_t, residual_t>::value);\n\n    extern __shared__ char smem_[];\n\n    const index_t tidx = threadIdx.x;\n    const index_t bidn 
= blockIdx.x % CTAS_PER_ROW;\n    const index_t bidm = blockIdx.x / CTAS_PER_ROW;\n    const index_t lane = tidx % THREADS_PER_WARP;\n    const index_t warp = tidx / THREADS_PER_WARP;\n    const index_t warp_m = warp / WARPS_N;\n    const index_t warp_n = warp % WARPS_N;\n\n    const index_t r = bidm * ROWS_PER_CTA + warp_m;\n    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;\n\n    Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);\n\n    compute_t *mu_ptr = static_cast<compute_t *>(params.mu);\n    compute_t *rs_ptr = static_cast<compute_t *>(params.rs);\n\n    // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu\n    curandStatePhilox4_32_10_t state;\n    if (Is_dropout) {\n        auto seeds = at::cuda::philox::unpack(params.philox_args);\n        const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;\n        curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);\n    }\n\n    const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;\n\n    Wvec gamma0[LDGS];\n    Wvec beta0[LDGS];\n    Wvec gamma1[LDGS];\n    Wvec beta1[LDGS];\n    index_t idx = c;\n    #pragma unroll\n    for( int it = 0; it < LDGS; it++ ) {\n        if (Is_even_cols || (it < num_valid_ldgs)) {\n            gamma0[it].load_from(params.gamma, idx);\n            if (params.beta != nullptr) {\n                beta0[it].load_from(params.beta, idx);\n            } else {\n                beta0[it].zero_();\n            }\n            if (!Tied_norm) {\n                gamma1[it].load_from(params.gamma1, idx);\n                if (params.beta1 != nullptr) {\n                    beta1[it].load_from(params.beta1, idx);\n                } else {\n                    beta1[it].zero_();\n                }\n            }\n            idx += VEC_COLS_PER_LDG;\n        }\n    }\n\n    for( int row = r; row < params.rows; row += 
params.ctas_per_col * ROWS_PER_CTA ) {\n        index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;\n        compute_t xf[LDGS * NUM_ELTS];\n        #pragma unroll\n        for( int it = 0; it < LDGS; it++ ) {\n            if (Is_even_cols || (it < num_valid_ldgs)) {\n                Ivec x0;\n                Ivec x1;\n                Rvec residual;\n                Rvec x;\n                Mvec dmask0;\n                Mvec dmask1;\n                x0.load_from(params.x0, idx);\n                if (has_x1) { x1.load_from(params.x1, idx); }\n                if (has_residual) { residual.load_from(params.residual, idx); }\n                #pragma unroll\n                for( int jt = 0; jt < NUM_ELTS; jt++ ) {\n                    // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use\n                    // the more efficient curand_uniform4.\n                    compute_t x_ij;\n                    mask_t keep0 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;\n                    if (Is_dropout) { dmask0.data.elt[jt] = keep0; }\n                    compute_t x0_ij = compute_t(x0.data.elt[jt]);\n                    x0_ij = keep0 ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;\n                    if (has_x1) {\n                        mask_t keep1 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;\n                        if (Is_dropout) { dmask1.data.elt[jt] = keep1; }\n                        compute_t x1_ij = compute_t(x1.data.elt[jt]);\n                        x1_ij = keep1 ? (Is_dropout ? x1_ij * params.dropout_scale : x1_ij) : 0.0f;\n                        x_ij = has_residual ? x0_ij + x1_ij + compute_t(residual.data.elt[jt]) : x0_ij + x1_ij;\n                    } else {\n                        x_ij = has_residual ? 
x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;\n                    }\n                    if (save_x) { x.data.elt[jt] = x_ij; }\n                    xf[it * NUM_ELTS + jt] = x_ij;\n                }\n                if (save_x) { x.store_to(params.x, idx); }\n                if (Is_dropout) {\n                    dmask0.store_to(params.dmask, idx);\n                    if (has_x1) { dmask1.store_to(params.dmask1, idx); }\n                }\n                idx += VEC_COLS_PER_LDG;\n            }\n        }\n\n        static_assert(CTAS_PER_ROW == 1, \"Don't support multiple CTAs per row for now\");\n        const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;\n        const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;\n        const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;\n        auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {\n            // Need to convert to int, otherwise the subtraction will wrap around.\n            const index_t valid_partial_vecs_in_warp =\n                std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),\n                        int(THREADS_PER_WARP));\n            return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;\n        };\n        stats_t s = stats.template compute<Is_even_cols>(\n            xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS\n        );\n\n        compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);\n        compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);\n\n        if( bidn == 0 && warp_n == 0 && lane == 0 ) {\n            mu_ptr[row] = mu;\n        }\n\n        compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 
0.f : mu * mu));\n\n        if( bidn == 0 && warp_n == 0 && lane == 0 ) {\n            rs_ptr[row] = rs;\n        }\n\n        idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;\n        #pragma unroll\n        for( int it = 0; it < LDGS; it++ ) {\n            if (Is_even_cols || (it < num_valid_ldgs)) {\n                Ovec z0;\n                Ovec z1;\n                #pragma unroll\n                for( int jt = 0; jt < NUM_ELTS; jt++ ) {\n                    compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));\n                    compute_t g0_ij = gamma0[it].data.elt[jt];\n                    compute_t b0_ij = beta0[it].data.elt[jt];\n                    z0.data.elt[jt] = output_t(g0_ij * y_ij + b0_ij);\n                    if (!Tied_norm) {\n                        compute_t g1_ij = gamma1[it].data.elt[jt];\n                        compute_t b1_ij = beta1[it].data.elt[jt];\n                        z1.data.elt[jt] = output_t(g1_ij * y_ij + b1_ij);\n                    }\n                }\n                z0.store_to(params.z, idx);\n                if (!Tied_norm) { z1.store_to(params.z1, idx); }\n                idx += VEC_COLS_PER_LDG;\n            }\n        }\n\n    }\n}\n\n}  // namespace layer_norm\n\nusing namespace layer_norm;\n\ntemplate<\n    typename weight_t,\n    typename input_t,\n    typename residual_t,\n    typename output_t,\n    typename compute_t,\n    typename index_t,\n    int HIDDEN_SIZE,\n    int CTAS_PER_ROW,\n    int WARPS_M,\n    int WARPS_N,\n    int BYTES_PER_LDG\n>\nvoid launch_parallel_residual_(LaunchParams<FwdParams> &launch_params, const bool configure_params){\n\n    using Kernel_traits = Kernel_traits<weight_t,\n                                        input_t,\n                                        residual_t,\n                                        output_t,\n                                        compute_t,\n                                        index_t,\n                 
                       HIDDEN_SIZE,\n                                        CTAS_PER_ROW,\n                                        WARPS_M,\n                                        WARPS_N,\n                                        BYTES_PER_LDG\n                                        >;\n    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;\n    bool tied_norm = launch_params.params.gamma1 == nullptr;\n    BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {\n        BOOL_SWITCH(tied_norm, TiedNormConst, [&] {\n            BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {\n                auto kernel = &ln_parallel_residual_fwd_kernel<Kernel_traits, IsDropoutConst, TiedNormConst, IsEvenColsConst>;\n                if( configure_params ) {\n                    int ctas_per_sm;\n                    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(\n                        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));\n                    launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;\n                    const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;\n                    launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;\n                    launch_params.barrier_size = 0;\n                    launch_params.workspace_bytes = 0;\n                    if(Kernel_traits::CTAS_PER_ROW > 1) {\n                        launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;\n                        launch_params.workspace_bytes = launch_params.params.ctas_per_col\n                                                      * Kernel_traits::WARPS_M\n                                                      * Kernel_traits::CTAS_PER_ROW\n                                           
           * sizeof(typename Kernel_traits::Stats::stats_t)\n                                                      * 2;\n                    }\n                    return;\n                }\n\n                if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {\n                    CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));\n                }\n                auto stream = launch_params.stream;\n                auto ctas_per_col = launch_params.params.ctas_per_col;\n\n                if( Kernel_traits::CTAS_PER_ROW == 1 ) {\n                    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);\n                } else {\n                    dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);\n                    dim3 block(Kernel_traits::THREADS_PER_CTA);\n                    void *params_ = (void *)&launch_params.params;\n                    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);\n                }\n            });\n        });\n    });\n}\n"
  },
  {
    "path": "csrc/layer_norm/ln_utils.cuh",
    "content": "#pragma once\n\n#include <cassert>\n\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\n#include \"ln.h\"\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nconstexpr uint32_t THREADS_PER_WARP = 32;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline void check_cuda_(cudaError_t status, const char *file, int line) {\n    if( status != cudaSuccess ) {\n        fprintf(stderr, \"CUDA Error: %s %s %d\\n\", cudaGetErrorString(status), file, line);\n        exit(status);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define CHECK_CUDA(ans)                                                                                                        \\\n    { check_cuda_((ans), __FILE__, __LINE__); }\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define DIVUP(x, y) (((x) + ((y)-1)) / (y))\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG)                 \\\n    void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params,                      \\\n                                                                                const bool configure_params) {                               \\\n        launch_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>(                    \\\n            launch_params, configure_params);                                                                                                \\\n    }                                                                                                               
                         \\\n    static FwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \\\n        ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define REGISTER_BWD_LAUNCHER(                                                                                                                  \\\n    HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE)                      \\\n    void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params,                         \\\n                                                                                const bool configure_params) {                                  \\\n        launch_<WTYPE,                                                                                                                          \\\n                ITYPE,                                                                                                                          \\\n                RTYPE,                                                                                                                          \\\n                OTYPE,                                                                                                                          \\\n                CTYPE,                                                                                                                          \\\n                uint32_t,                                                                                                                       \\\n                HIDDEN_SIZE,                                                                                                                    \\\n           
     CTAS_PER_ROW,                                                                                                                   \\\n                WARPS_M,                                                                                                                        \\\n                WARPS_N,                                                                                                                        \\\n                BYTES_PER_LDG,                                                                                                                  \\\n                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params);                                                                       \\\n    }                                                                                                                                           \\\n    static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(    \\\n        ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG)                \\\n    void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params,            \\\n                                                                                const bool configure_params) {                                       \\\n        launch_parallel_residual_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>(          \\\n            launch_params, configure_params);                                                                                                        \\\n    }                  
                                                                                                                              \\\n    static FwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \\\n        ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n#define REGISTER_PARALLEL_BWD_LAUNCHER(                                                                                                              \\\n    HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE)                           \\\n    void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params,            \\\n                                                                                const bool configure_params) {                                       \\\n        launch_parallel_residual_<WTYPE,                                                                                                             \\\n                ITYPE,                                                                                                                               \\\n                RTYPE,                                                                                                                               \\\n                OTYPE,                                                                                                                               \\\n                CTYPE,                                                                                                                               \\\n                uint32_t,                                                                                                              
              \\\n                HIDDEN_SIZE,                                                                                                                         \\\n                CTAS_PER_ROW,                                                                                                                        \\\n                WARPS_M,                                                                                                                             \\\n                WARPS_N,                                                                                                                             \\\n                BYTES_PER_LDG,                                                                                                                       \\\n                BYTES_PER_LDG_FINALIZE>(launch_params, configure_params);                                                                            \\\n    }                                                                                                                                                \\\n    static BwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \\\n        ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ float2 operator+(const float2 & a, const float2 & b){\n    return {a.x + b.x, a.y + b.y};\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ninline __device__ void operator+=(float2 & a, const float2 & b){\n    a.x += b.x;\n    a.y += b.y;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T>\nstruct Sum {\n    inline __device__ Sum(){}\n    inline __device__ T operator()(const T &a, const T 
&b){\n        return a + b;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T>\ninline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){\n    return __shfl_xor_sync(uint32_t(-1), x, idx);\n}\n\ntemplate<>\ninline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx){\n    return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) };\n}\n\ntemplate<typename T>\ninline __device__ T warp_shuffle_down(const T & x, uint32_t idx){\n    return __shfl_down_sync(uint32_t(-1), x, idx);\n}\n\ntemplate<>\ninline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx){\n    return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) };\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nnamespace layer_norm {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct uint16 {\n    uint4 u;\n    uint4 v;\n    uint4 s;\n    uint4 t;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct uint8 {\n    uint4 u;\n    uint4 v;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int BYTES>\nstruct BytesToType {};\n\ntemplate<>\nstruct BytesToType<64> {\n    using Type = uint16;\n    static_assert(sizeof(Type) == 64);\n};\n\ntemplate<>\nstruct BytesToType<32> {\n    using Type = uint8;\n    static_assert(sizeof(Type) == 32);\n};\n\ntemplate<>\nstruct BytesToType<16> {\n    using Type = uint4;\n    static_assert(sizeof(Type) == 16);\n};\n\ntemplate<>\nstruct BytesToType<8> {\n    using Type = uint64_t;\n    static_assert(sizeof(Type) == 8);\n};\n\ntemplate<>\nstruct BytesToType<4> {\n    using Type = uint32_t;\n    static_assert(sizeof(Type) == 4);\n};\n\ntemplate<>\nstruct BytesToType<2> {\n    using 
Type = uint16_t;\n    static_assert(sizeof(Type) == 2);\n};\n\ntemplate<>\nstruct BytesToType<1> {\n    using Type = uint8_t;\n    static_assert(sizeof(Type) == 1);\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T>\nstruct TypeToVec2 {};\n\ntemplate<>\nstruct TypeToVec2<float> {\n    using Type = float2;\n};\n\ntemplate<>\nstruct TypeToVec2<half> {\n    using Type = half2;\n};\n\ntemplate<>\nstruct TypeToVec2<nv_bfloat16> {\n    using Type = nv_bfloat162;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int INDEX>\nstruct Get {\n    template<typename T, typename R>\n    static inline __device__ R of(const T &vec);\n};\n\ntemplate<>\ntemplate<typename T, typename R>\ninline __device__ R Get<0>::of(const T &vec) {\n    return vec.x;\n}\n\ntemplate<>\ntemplate<typename T, typename R>\ninline __device__ R Get<1>::of(const T &vec) {\n    return vec.y;\n}\n\ntemplate<>\ntemplate<typename T, typename R>\ninline __device__ R Get<2>::of(const T &vec) {\n    return vec.z;\n}\n\ntemplate<>\ntemplate<typename T, typename R>\ninline __device__ R Get<3>::of(const T &vec) {\n    return vec.w;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Src, typename Dst>\nstruct Converter{\n    static inline __device__ Dst convert(const Src &from) {\n        return Dst(from);\n    }\n};\n\ntemplate<>\nstruct Converter<float2, half2>{\n    static inline __device__ half2 convert(const float2 &x) {\n        return __float22half2_rn(x);\n    }\n};\n\ntemplate<>\nstruct Converter<float2, nv_bfloat162>{\n    static inline __device__ nv_bfloat162 convert(const float2 &x) {\n#if __CUDA_ARCH__ >= 800\n        return __float22bfloat162_rn(x);\n#else\n        union {\n            nv_bfloat162 raw;\n            nv_bfloat16 x;\n            nv_bfloat16 y;\n        } tmp;\n      
  // Write through raw: the union's x and y members alias the same bytes.\n        tmp.raw.x = __float2bfloat16_rn(x.x);\n        tmp.raw.y = __float2bfloat16_rn(x.y);\n        return tmp.raw;\n#endif\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T>\nstruct Zeros{\n    static inline __device__ T get() {\n        return T(0.f);\n    }\n};\n\ntemplate<> \nstruct Zeros<float2>{\n    static inline __device__ float2 get() {\n        return make_float2(0.f, 0.f);\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Elt_type, uint32_t NUM_ELT>\nstruct Vec {\n\n    enum { BYTES = NUM_ELT * sizeof(Elt_type) };\n\n    using Vec_type = typename BytesToType<BYTES>::Type;\n\n    using Alias_type = union {\n        Vec_type vec;\n        Elt_type elt[NUM_ELT];\n    };\n\n    Alias_type data;\n\n    template<typename S>\n    inline __device__ void to(Vec<S, NUM_ELT> &other) {\n        #pragma unroll\n        for( int it = 0; it < NUM_ELT; it++ ) {\n            other.data.elt[it] = S(this->data.elt[it]);\n        }\n    }\n\n    template<typename Op>\n    inline __device__ void assign(const Op &op) {\n        #pragma unroll\n        for( int it = 0; it < NUM_ELT; it++ ) {\n            this->data.elt[it] = op(it);\n        }\n    }\n\n    inline __device__ void zero_() {\n        #pragma unroll\n        for( int it = 0; it < NUM_ELT; it++ ) {\n            this->data.elt[it] = Elt_type(0.f);\n        }\n    }\n\n    inline __device__ void load_from(const void *base_ptr, const size_t idx) {\n        this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];\n    }\n\n    inline __device__ void store_to(void *base_ptr, const size_t idx) {\n        static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<uint32_t CTAS_PER_ROW>\nstruct InterCTASync {\n\n    template<typename 
Params>\n    inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn)\n        : phase_counter_(0)\n        , b0_(params.barrier + bidm) // The barrier for this group of CTAs.\n        , b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs.\n    {\n        // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!\n    }\n\n    inline __device__ void spin_wait_(int *barrier, int step, int expected) {\n        asm volatile(\"red.release.gpu.global.add.s32 [%0], %1;\" ::\"l\"(barrier), \"r\"(step));\n        for( int found = -1; found != expected; ) {\n            asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\" : \"=r\"(found) : \"l\"(barrier));\n        }\n    }\n\n    inline __device__ void sync(){\n        // ALL THREADS MUST ENTER!\n\n        // We switch barrier every iteration.\n        int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;\n        // We decrement every other iteration.\n        bool dec = phase_counter_ & 0x2;\n        int step = dec ? -1 : 1;\n        int expected = dec ? 
0 : CTAS_PER_ROW;\n        // There are only 4 phases: up/down for b0/b1.\n        phase_counter_ = (phase_counter_ + 1) & 0x3;\n\n        if( threadIdx.x == 0 ) {\n            spin_wait_(barrier, step, expected);\n        }\n        // CTA waits for thread 0\n        __syncthreads();\n    }\n\n    int phase_counter_;\n    int * b0_;\n    int * b1_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>\nstruct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {\n\n    using InterCTASync = InterCTASync<CTAS_PER_ROW>;\n    using Base = Reducer<T, 1, WARPS_M, WARPS_N>;\n    using Type = typename Base::Type;\n\n    enum { SMEM_BYTES = Base::SMEM_BYTES };\n\n    enum { WS_BARRIER_BYTES = 2 * sizeof(int) };\n    enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };\n\n    // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)\n    enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };\n\n    template<typename Params>\n    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)\n        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) \n        , inter_cta_(params, bidm, bidn)\n        , bidn_(bidn) // CTA id within the group.\n        , w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)\n        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)\n    {\n    }\n\n    template<typename Op>\n    inline __device__ T allreduce(T data, Op &op) {\n        data = Base::reduce(data, op);\n        // We switch workspace every iteration.\n        T *workspace = inter_cta_.phase_counter_ & 0x1 ? 
w1_ : w0_;\n\n        // Warp leaders 0 hold the CTA-local results.\n        if( this->warp_n_ == 0 && this->lane_ == 0 ) {\n            workspace[bidn_] = data;\n        }\n        inter_cta_.sync();\n        static_assert(CTAS_PER_ROW <= 32);\n        T total = Zeros<T>::get();\n        if(this->lane_ < CTAS_PER_ROW){\n            total = workspace[this->lane_];\n        }\n        total = Reducer<T, 1, 1, 1>::allreduce_(total, op);\n\n        return total;\n    }\n\n    InterCTASync inter_cta_;\n\n    T *w0_;\n    T *w1_;\n    int bidn_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T, uint32_t WARPS_M>\nstruct Reducer<T, 1, WARPS_M, 1> {\n\n    using Type = T;\n    enum { SMEM_BYTES = 0 };\n    enum { WORKSPACE_BYTES_PER_GROUP = 0 };\n\n    enum { THREADS_PER_WARP = 32 };\n\n    template<typename Params>\n    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) \n        : warp_n_(warp_n)\n        , lane_(lane)\n    {\n    }\n\n    template<typename Op>\n    static inline __device__ T allreduce_(T data, Op &op) {\n        #pragma unroll\n        for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {\n            data = op(data, warp_shuffle_xor(data, it));\n        }\n        return data;\n    }\n\n    template<typename Op>\n    inline __device__ T allreduce(T data, Op &op) {\n        return allreduce_(data, op);\n    }\n\n    template<typename Op>\n    inline __device__ T reduce(T data, Op &op){\n        // only lane 0 holds the result!\n        #pragma unroll\n        for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {\n            data = op(data, warp_shuffle_down(data, it));\n        }  \n        return data;\n    }\n    int warp_n_;\n    int lane_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T, uint32_t 
WARPS_M, uint32_t WARPS_N>\nstruct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {\n\n    using Base = Reducer<T, 1, WARPS_M, 1>;\n\n    using Type = T;\n\n    enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };\n    enum { WORKSPACE_BYTES_PER_GROUP = 0 };\n\n    enum { THREADS_PER_WARP = 32 };\n\n    template<typename Params>\n    inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) \n        : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) \n        , use0_(true)\n    {\n        smem0_ = &static_cast<T *>(smem)[warp_m * WARPS_N];\n        smem1_ = smem0_ + WARPS_M * WARPS_N;\n    }\n\n    template<typename Op>\n    inline __device__ T allreduce(T data, Op & op) {\n        T * smem = use0_ ? smem0_ : smem1_;\n        use0_ = !use0_;\n        data = Base::reduce(data, op);\n        if( this->lane_ == 0 ) {\n            smem[this->warp_n_] = data;\n        }\n        __syncthreads();\n        T out = Zeros<T>::get();\n        #pragma unroll\n        for( int it = 0; it < WARPS_N; it++ ) {\n            out = op(out, smem[it]);\n        }\n        return out;\n    }\n\n    template<typename Op>\n    inline __device__ T reduce(T data, Op &op) {\n        T * smem = use0_ ? 
smem0_ : smem1_;\n        use0_ = !use0_;\n        // only intra-CTA group leader holds the result!\n        data = Base::reduce(data, op);\n        if( this->lane_ == 0 ) {\n            smem[this->warp_n_] = data;\n        }\n        __syncthreads();\n        T out = Zeros<T>::get();\n        if( this->warp_n_ == 0 && this->lane_ == 0 ) {\n            #pragma unroll\n            for( int it = 0; it < WARPS_N; it++ ) {\n                out = op(out, smem[it]);\n            }\n        }\n        return out;\n    }\n\n    T * smem0_;\n    T * smem1_;\n    bool use0_;\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n \ntemplate<typename T, typename int_t>\ninline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){\n    //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)\n    const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);\n    \n    #pragma unroll\n    for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {\n        // Exchange\n        int_t n_b = warp_shuffle_down(n_a, step);\n        T m_b = warp_shuffle_down(m_a, step);\n        T m2_b = warp_shuffle_down(m2_a, step);\n\n        // Update\n        const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both.\n        const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(\n        const T delta = m_a - m_b;\n        const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;\n        const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;\n\n        n_a = n_ab;\n        m_a = m_ab;\n        m2_a = m2_ab;\n    }\n    // Intra-warp broadcast (only lane 0 has valid stats).\n    m_a = __shfl_sync(uint32_t(-1), m_a, 0);\n    m2_a = __shfl_sync(uint32_t(-1), m2_a, 
0);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>\nstruct Stats {\n    // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.\n\n    using InterCTASync = InterCTASync<CTAS_PER_ROW>;\n    using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;\n    using stats_t = typename BlockStats::stats_t;\n\n    enum { SMEM_BYTES = BlockStats::SMEM_BYTES };\n\n    template<typename Params>\n    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) \n        : inter_cta_(params, bidm, bidn)\n        , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)\n        , bidn_(bidn) // CTA id within the group.\n        , w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)\n        , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)\n        , warp_n_(warp_n)\n        , lane_(lane)\n    {\n    }\n\n    template<uint32_t N>\n    inline __device__ stats_t compute(const T (&elts)[N], const T rn) {\n        constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;\n        // TODO rn is not really needed here..\n        constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);\n        stats_t block_stats = block_stats_.compute(elts, block_rn);\n\n        stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? 
w1_ : w0_;\n\n        if( warp_n_ == 0 && lane_ == 0 ) {\n            workspace[bidn_] = block_stats;\n        }\n\n        // Wait for all CTAS_PER_ROW CTAS in the group to have written their result.\n        inter_cta_.sync();\n\n        T n = Zeros<T>::get();\n        T m = Zeros<T>::get();\n        T m2 = Zeros<T>::get();\n\n        // Assume CTA group size in N less than 32, such that we can finalize with a single warp.\n        static_assert(CTAS_PER_ROW <= 32);\n\n        // Every warp does the final reduction locally. \n        if( lane_ < CTAS_PER_ROW ) {\n            stats_t result = workspace[lane_];\n            n = ELTS_PER_ROW_PER_CTA;\n            m = layer_norm::Get<0>::of<stats_t, T>(result);\n            m2 = layer_norm::Get<1>::of<stats_t, T>(result);\n        }\n\n        warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);\n\n        return { m, m2 };\n    }\n\n    InterCTASync inter_cta_;\n    BlockStats block_stats_;\n\n    stats_t *w0_;\n    stats_t *w1_;\n    int bidn_;\n    int warp_n_;\n    int lane_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T, uint32_t WARPS_M, uint32_t WARPS_N>\nstruct Stats<T, 1, WARPS_M, WARPS_N> {\n\n    using WarpStats = Stats<T, 1, WARPS_M, 1>;\n    using stats_t = typename WarpStats::stats_t;\n\n    enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };\n\n    template<typename Params>\n    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) \n        : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)\n        , use0_(true)\n    {\n        smem0_ = static_cast<stats_t*>(smem) + warp_m * WARPS_N;\n        smem1_ = smem0_ + WARPS_M * WARPS_N;\n    }\n\n    template<bool Is_even_cols, uint32_t N, typename function_t>\n    inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,\n                                      
function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {\n        stats_t * smem = use0_ ? smem0_ : smem1_;\n        use0_ = !use0_;\n        // Compute warp local for all WARPS_N\n        const auto warp_n = warp_stats_.reducer_.warp_n_;\n        const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n));\n        stats_t warp_stats = warp_stats_.template compute<Is_even_cols>(\n            elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts\n        );\n\n        // Each warp leader stores its stats\n        const auto lane = warp_stats_.reducer_.lane_;\n        if( lane == 0 ) {\n            smem[warp_n] = warp_stats;\n        }\n        __syncthreads();\n\n        int n = 0;\n        T m = Zeros<T>::get();\n        T m2 = Zeros<T>::get();\n\n        // Assume that there are at most 32 warps, such that we can finalize with a single warp\n        static_assert(WARPS_N <= 32);\n        if(lane < WARPS_N){\n            stats_t result = smem[lane];\n            n = Is_even_cols ? 
N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane);\n            m = layer_norm::Get<0>::of<stats_t, T>(result);\n            m2 = layer_norm::Get<1>::of<stats_t, T>(result);\n        }\n\n        warp_chan_upd_dynamic(m, m2, n, WARPS_N);\n\n        return { m, m2 };\n    }\n    WarpStats warp_stats_;\n    stats_t * smem0_;\n    stats_t * smem1_;\n    bool use0_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T, uint32_t WARPS_M>\nstruct Stats<T, 1, WARPS_M, 1> {\n\n    using stats_t = typename TypeToVec2<T>::Type;\n    // The simple Warp reducer.\n    using Reducer = Reducer<T, 1, WARPS_M, 1>;\n\n    enum { SMEM_BYTES = 0 };\n\n    template<typename Params>\n    inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) \n        : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem)\n    {\n    }\n\n    template<bool Is_even_cols, uint32_t N, typename function_t>\n    inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,\n                                      // const int valid_elts_in_warp_ignored_, const int num_valid_elts = N) {\n                                      function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {\n\n        auto sum = Sum<T>();\n\n        T m = Zeros<T>::get();\n        #pragma unroll\n        for( int it = 0; it < N; it++ ) {\n            if (Is_even_cols || (it < num_valid_elts)) {\n                m += elts[it];\n            }\n        }\n        m = reducer_.allreduce(m, sum) * row_norm_factor;\n\n        T m2 = Zeros<T>::get();\n        #pragma unroll\n        for( int it = 0; it < N; it++ ) {\n            if (Is_even_cols || (it < num_valid_elts)) {\n                T diff = (elts[it] - m);\n                m2 += diff * diff;\n            }\n        }\n        m2 = reducer_.allreduce(m2, sum);\n\n        return {m, m2};\n    }\n\n    
Reducer reducer_;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n}  // namespace layer_norm\n"
  },
  {
    "path": "csrc/layer_norm/setup.py",
    "content": "# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py\nimport sys\nimport warnings\nimport os\nfrom packaging.version import parse, Version\n\nimport torch\nfrom torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME\nfrom setuptools import setup, find_packages\nimport subprocess\n\n# ninja build does not work unless include_dirs are abs path\nthis_dir = os.path.dirname(os.path.abspath(__file__))\n\n\ndef get_cuda_bare_metal_version(cuda_dir):\n    raw_output = subprocess.check_output([cuda_dir + \"/bin/nvcc\", \"-V\"], universal_newlines=True)\n    output = raw_output.split()\n    release_idx = output.index(\"release\") + 1\n    bare_metal_version = parse(output[release_idx].split(\",\")[0])\n\n    return raw_output, bare_metal_version\n\n\ndef check_cuda_torch_binary_vs_bare_metal(cuda_dir):\n    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)\n    torch_binary_version = parse(torch.version.cuda)\n\n    print(\"\\nCompiling cuda extensions with\")\n    print(raw_output + \"from \" + cuda_dir + \"/bin\\n\")\n\n    if (bare_metal_version != torch_binary_version):\n        raise RuntimeError(\n            \"Cuda extensions are being compiled with a version of Cuda that does \"\n            \"not match the version used to compile Pytorch binaries.  \"\n            \"Pytorch binaries were compiled with Cuda {}.\\n\".format(torch.version.cuda)\n            + \"In some cases, a minor-version mismatch will not cause later errors:  \"\n            \"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  \"\n            \"You can try commenting out this check (at your own risk).\"\n        )\n\n\ndef raise_if_cuda_home_none(global_option: str) -> None:\n    if CUDA_HOME is not None:\n        return\n    raise RuntimeError(\n        f\"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  
\"\n        \"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, \"\n        \"only images whose names contain 'devel' will provide nvcc.\"\n    )\n\n\ndef append_nvcc_threads(nvcc_extra_args):\n    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)\n    if bare_metal_version >= Version(\"11.2\"):\n        nvcc_threads = os.getenv(\"NVCC_THREADS\") or \"4\"\n        return nvcc_extra_args + [\"--threads\", nvcc_threads]\n    return nvcc_extra_args\n\n\nif not torch.cuda.is_available():\n    # https://github.com/NVIDIA/apex/issues/486\n    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),\n    # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).\n    print(\n        \"\\nWarning: Torch did not find available GPUs on this system.\\n\",\n        \"If your intention is to cross-compile, this is not an error.\\n\"\n        \"By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\\n\"\n        \"Volta (compute capability 7.0), Turing (compute capability 7.5),\\n\"\n        \"and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\\n\"\n        \"If you wish to cross-compile for a single specific architecture,\\n\"\n        'export TORCH_CUDA_ARCH_LIST=\"compute capability\" before running setup.py.\\n',\n    )\n    if os.environ.get(\"TORCH_CUDA_ARCH_LIST\", None) is None and CUDA_HOME is not None:\n        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)\n        if bare_metal_version >= Version(\"11.8\"):\n            os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0\"\n        elif bare_metal_version >= Version(\"11.1\"):\n            os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"6.0;6.1;6.2;7.0;7.5;8.0;8.6\"\n        elif bare_metal_version == Version(\"11.0\"):\n            
os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"6.0;6.1;6.2;7.0;7.5;8.0\"\n        else:\n            os.environ[\"TORCH_CUDA_ARCH_LIST\"] = \"6.0;6.1;6.2;7.0;7.5\"\n\n\nprint(\"\\n\\ntorch.__version__  = {}\\n\\n\".format(torch.__version__))\nTORCH_MAJOR = int(torch.__version__.split(\".\")[0])\nTORCH_MINOR = int(torch.__version__.split(\".\")[1])\n\ncmdclass = {}\next_modules = []\n\n# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h\n# See https://github.com/pytorch/pytorch/pull/70650\ngenerator_flag = []\ntorch_dir = torch.__path__[0]\nif os.path.exists(os.path.join(torch_dir, \"include\", \"ATen\", \"CUDAGeneratorImpl.h\")):\n    generator_flag = [\"-DOLD_GENERATOR_PATH\"]\n\nraise_if_cuda_home_none(\"--fast_layer_norm\")\n# Check, if CUDA11 is installed for compute capability 8.0\ncc_flag = []\n_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)\nif bare_metal_version < Version(\"11.0\"):\n    raise RuntimeError(\"dropout_layer_norm is only supported on CUDA 11 and above\")\ncc_flag.append(\"-gencode\")\ncc_flag.append(\"arch=compute_70,code=sm_70\")\ncc_flag.append(\"-gencode\")\ncc_flag.append(\"arch=compute_80,code=sm_80\")\nif bare_metal_version >= Version(\"11.8\"):\n    cc_flag.append(\"-gencode\")\n    cc_flag.append(\"arch=compute_90,code=sm_90\")\n\next_modules.append(\n    CUDAExtension(\n        name=\"dropout_layer_norm\",\n        sources=[\n            \"ln_api.cpp\",\n            \"ln_fwd_256.cu\",\n            \"ln_bwd_256.cu\",\n            \"ln_fwd_512.cu\",\n            \"ln_bwd_512.cu\",\n            \"ln_fwd_768.cu\",\n            \"ln_bwd_768.cu\",\n            \"ln_fwd_1024.cu\",\n            \"ln_bwd_1024.cu\",\n            \"ln_fwd_1280.cu\",\n            \"ln_bwd_1280.cu\",\n            \"ln_fwd_1536.cu\",\n            \"ln_bwd_1536.cu\",\n            \"ln_fwd_2048.cu\",\n            \"ln_bwd_2048.cu\",\n            \"ln_fwd_2560.cu\",\n            \"ln_bwd_2560.cu\",\n            
\"ln_fwd_3072.cu\",\n            \"ln_bwd_3072.cu\",\n            \"ln_fwd_4096.cu\",\n            \"ln_bwd_4096.cu\",\n            \"ln_fwd_5120.cu\",\n            \"ln_bwd_5120.cu\",\n            \"ln_fwd_6144.cu\",\n            \"ln_bwd_6144.cu\",\n            \"ln_fwd_7168.cu\",\n            \"ln_bwd_7168.cu\",\n            \"ln_fwd_8192.cu\",\n            \"ln_bwd_8192.cu\",\n            \"ln_parallel_fwd_256.cu\",\n            \"ln_parallel_bwd_256.cu\",\n            \"ln_parallel_fwd_512.cu\",\n            \"ln_parallel_bwd_512.cu\",\n            \"ln_parallel_fwd_768.cu\",\n            \"ln_parallel_bwd_768.cu\",\n            \"ln_parallel_fwd_1024.cu\",\n            \"ln_parallel_bwd_1024.cu\",\n            \"ln_parallel_fwd_1280.cu\",\n            \"ln_parallel_bwd_1280.cu\",\n            \"ln_parallel_fwd_1536.cu\",\n            \"ln_parallel_bwd_1536.cu\",\n            \"ln_parallel_fwd_2048.cu\",\n            \"ln_parallel_bwd_2048.cu\",\n            \"ln_parallel_fwd_2560.cu\",\n            \"ln_parallel_bwd_2560.cu\",\n            \"ln_parallel_fwd_3072.cu\",\n            \"ln_parallel_bwd_3072.cu\",\n            \"ln_parallel_fwd_4096.cu\",\n            \"ln_parallel_bwd_4096.cu\",\n            \"ln_parallel_fwd_5120.cu\",\n            \"ln_parallel_bwd_5120.cu\",\n            \"ln_parallel_fwd_6144.cu\",\n            \"ln_parallel_bwd_6144.cu\",\n            \"ln_parallel_fwd_7168.cu\",\n            \"ln_parallel_bwd_7168.cu\",\n            \"ln_parallel_fwd_8192.cu\",\n            \"ln_parallel_bwd_8192.cu\",\n        ],\n        extra_compile_args={\n            \"cxx\": [\"-O3\"] + generator_flag,\n            \"nvcc\": append_nvcc_threads(\n                [\n                    \"-O3\",\n                    \"-U__CUDA_NO_HALF_OPERATORS__\",\n                    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n                    \"-U__CUDA_NO_BFLOAT16_OPERATORS__\",\n                    \"-U__CUDA_NO_BFLOAT16_CONVERSIONS__\",\n                    
\"-U__CUDA_NO_BFLOAT162_OPERATORS__\",\n                    \"-U__CUDA_NO_BFLOAT162_CONVERSIONS__\",\n                    \"--expt-relaxed-constexpr\",\n                    \"--expt-extended-lambda\",\n                    \"--use_fast_math\",\n                ]\n                + generator_flag\n                + cc_flag\n            ),\n        },\n        include_dirs=[this_dir],\n    )\n)\n\nsetup(\n    name=\"dropout_layer_norm\",\n    version=\"0.1\",\n    description=\"Fused dropout + add + layer norm\",\n    ext_modules=ext_modules,\n    cmdclass={\"build_ext\": BuildExtension} if ext_modules else {},\n)\n"
  },
  {
    "path": "csrc/layer_norm/static_switch.h",
    "content": "// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h\n// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h\n\n#pragma once\n\n/// @param COND       - a boolean expression to switch by\n/// @param CONST_NAME - a name given for the constexpr bool variable.\n/// @param ...       - code to execute for true and false\n///\n/// Usage:\n/// ```\n/// BOOL_SWITCH(flag, BoolConst, [&] {\n///     some_function<BoolConst>(...);\n/// });\n/// ```\n#define BOOL_SWITCH(COND, CONST_NAME, ...)                                           \\\n    [&] {                                                                            \\\n        if (COND) {                                                                  \\\n            constexpr bool CONST_NAME = true;                                        \\\n            return __VA_ARGS__();                                                    \\\n        } else {                                                                     \\\n            constexpr bool CONST_NAME = false;                                       \\\n            return __VA_ARGS__();                                                    \\\n        }                                                                            \\\n    }()\n"
  },
  {
    "path": "examples/inference/README.md",
    "content": "# Example of LLM inference using FlashAttention\n\nExample script of using FlashAttention for inference coming soon.\n"
  },
  {
    "path": "flash_attn/__init__.py",
    "content": "from pkgutil import extend_path\n\n# look for every subdir with flash_attn base name such that fa2 and fa4 can be co-installed\n__path__ = extend_path(__path__, __name__)\n\n__version__ = \"2.8.4\"\n\nfrom flash_attn.flash_attn_interface import (\n    flash_attn_func,\n    flash_attn_kvpacked_func,\n    flash_attn_qkvpacked_func,\n    flash_attn_varlen_func,\n    flash_attn_varlen_kvpacked_func,\n    flash_attn_varlen_qkvpacked_func,\n    flash_attn_with_kvcache,\n)\n"
  },
  {
    "path": "flash_attn/bert_padding.py",
    "content": "# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\n\nclass IndexFirstAxis(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, indices):\n        ctx.save_for_backward(indices)\n        assert input.ndim >= 2\n        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]\n        second_dim = other_shape.numel()\n        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.\n        # return input[indices]\n        return torch.gather(\n            rearrange(input, \"b ... -> b (...)\"), 0, repeat(indices, \"z -> z d\", d=second_dim)\n        ).reshape(-1, *other_shape)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        (indices,) = ctx.saved_tensors\n        assert grad_output.ndim >= 2\n        other_shape = grad_output.shape[1:]\n        grad_output = rearrange(grad_output, \"b ... 
-> b (...)\")\n        grad_input = torch.zeros(\n            [ctx.first_axis_dim, grad_output.shape[1]],\n            device=grad_output.device,\n            dtype=grad_output.dtype,\n        )\n        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.\n        # grad_input[indices] = grad_output\n        grad_input.scatter_(0, repeat(indices, \"z -> z d\", d=grad_output.shape[1]), grad_output)\n        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None\n\n\nindex_first_axis = IndexFirstAxis.apply\n\n\nclass IndexPutFirstAxis(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, values, indices, first_axis_dim):\n        ctx.save_for_backward(indices)\n        assert indices.ndim == 1\n        assert values.ndim >= 2\n        output = torch.zeros(\n            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype\n        )\n        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.\n        output[indices] = values\n        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        (indices,) = ctx.saved_tensors\n        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.\n        grad_values = grad_output[indices]\n        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))\n        return grad_values, None, None\n\n\nindex_put_first_axis = IndexPutFirstAxis.apply\n\n\nclass IndexFirstAxisResidual(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, indices):\n        ctx.save_for_backward(indices)\n        assert input.ndim >= 2\n        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]\n        second_dim = other_shape.numel()\n        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.\n        output = 
input[indices]\n        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last\n        # memory format to channel_first. In other words, input might not be contiguous.\n        # If we don't detach, Pytorch complains about output being a view and is being modified inplace\n        return output, input.detach()\n\n    @staticmethod\n    def backward(ctx, grad_output, grad_residual):\n        (indices,) = ctx.saved_tensors\n        assert grad_output.ndim >= 2\n        other_shape = grad_output.shape[1:]\n        assert grad_residual.shape[1:] == other_shape\n        grad_input = grad_residual\n        # grad_input[indices] += grad_output\n        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))\n        indices = indices.expand_as(grad_output)\n        grad_input.scatter_add_(0, indices, grad_output)\n        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None\n\n\nindex_first_axis_residual = IndexFirstAxisResidual.apply\n\n\ndef unpad_input(hidden_states, attention_mask, unused_mask=None):\n    \"\"\"\n    Arguments:\n        hidden_states: (batch, seqlen, ...)\n        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.\n        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.\n    Return:\n        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.\n        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.\n        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.\n        max_seqlen_in_batch: int\n        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.\n    \"\"\"\n    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask\n    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)\n    
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the\n    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim\n    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to\n    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,\n    # so we write custom forward and backward to make it a bit faster.\n    return (\n        index_first_axis(rearrange(hidden_states, \"b s ... -> (b s) ...\"), indices),\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n        used_seqlens_in_batch, \n    )\n\n\ndef unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):\n    \"\"\"\n    Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. 
It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).\n    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).\n    \n    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:\n        ```\n        [\n          [2, 3, 0, 0, 0, 0],\n          [3, 2, 0, 0, 0, 0],\n          [6, 0, 0, 0, 0, 0]\n        ]\n        ```\n    , which refers to the 3D-attention mask:\n        ```\n        [\n          [\n            [1, 0, 0, 0, 0, 0],\n            [1, 1, 0, 0, 0, 0],\n            [0, 0, 1, 0, 0, 0],\n            [0, 0, 1, 1, 0, 0],\n            [0, 0, 1, 1, 1, 0],\n            [0, 0, 0, 0, 0, 1]\n          ],\n          [\n            [1, 0, 0, 0, 0, 0],\n            [1, 1, 0, 0, 0, 0],\n            [1, 1, 1, 0, 0, 0],\n            [0, 0, 0, 1, 0, 0],\n            [0, 0, 0, 1, 1, 0],\n            [0, 0, 0, 0, 0, 1]\n          ],\n          [\n            [1, 0, 0, 0, 0, 0],\n            [1, 1, 0, 0, 0, 0],\n            [1, 1, 1, 0, 0, 0],\n            [1, 1, 1, 1, 0, 0],\n            [1, 1, 1, 1, 1, 0],\n            [1, 1, 1, 1, 1, 1]\n          ]\n        ]\n        ```.\n\n    Arguments:\n        hidden_states: (batch, seqlen, ...)\n        attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) 
means length of concatenated sequence in b-th batch, and 0 means none.\n    Return:\n        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.\n        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.\n        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.\n        max_seqlen_in_batch: int\n    \"\"\"\n    length = attention_mask_in_length.sum(dim=-1)\n    seqlen = attention_mask_in_length.size(-1)\n    attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)\n    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()\n    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]\n    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the\n    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim\n    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to\n    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,\n    # so we write custom forward and backward to make it a bit faster.\n    return (\n        index_first_axis(rearrange(hidden_states, \"b s ... 
-> (b s) ...\"), indices),\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\ndef pad_input(hidden_states, indices, batch, seqlen):\n    \"\"\"\n    Arguments:\n        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.\n        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.\n        batch: int, batch size for the padded sequence.\n        seqlen: int, maximum sequence length for the padded sequence.\n    Return:\n        hidden_states: (batch, seqlen, ...)\n    \"\"\"\n    dim = hidden_states.shape[-1]\n    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)\n    # output[indices] = hidden_states\n    output = index_put_first_axis(hidden_states, indices, batch * seqlen)\n    return rearrange(output, \"(b s) ... -> b s ...\", b=batch)\n"
  },
  {
    "path": "flash_attn/cute/.flake8",
    "content": "[flake8]\nmax-line-length = 100\n# W503: line break before binary operator\nignore = E731, E741, F841, W503\n"
  },
  {
    "path": "flash_attn/cute/AUTHORS",
    "content": "Tri Dao\nJay Shah\nTed Zadouri\nMarkus Hoehnerbach\nVijay Thakkar\nTimmy Liu\nDriss Guessous\nReuben Stern"
  },
  {
    "path": "flash_attn/cute/LICENSE",
    "content": "BSD 3-Clause License\n\nCopyright (c) 2022, the respective contributors, as shown by the AUTHORS file.\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\n* Neither the name of the copyright holder nor the names of its\n  contributors may be used to endorse or promote products derived from\n  this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "flash_attn/cute/MANIFEST.in",
    "content": "global-exclude *.egg-info/*\nprune flash_attn_4.egg-info\nprune flash_attn.egg-info\nprune build\nprune dist\n"
  },
  {
    "path": "flash_attn/cute/README.md",
    "content": "# FlashAttention-4 (CuTeDSL)\n\nFlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper and Blackwell GPUs.\n\n## Installation\n\n```sh\npip install flash-attn-4\n```\n\n## Usage\n\n```python\nfrom flash_attn.cute import flash_attn_func, flash_attn_varlen_func\n\nout = flash_attn_func(q, k, v, causal=True)\n```\n\n## Development\n\n```sh\ngit clone https://github.com/Dao-AILab/flash-attention.git\ncd flash-attention\npip install -e \"flash_attn/cute[dev]\"\npytest tests/cute/\n```\n"
  },
  {
    "path": "flash_attn/cute/__init__.py",
    "content": "\"\"\"Flash Attention CUTE (CUDA Template Engine) implementation.\"\"\"\n\nfrom importlib.metadata import PackageNotFoundError, version\n\ntry:\n    __version__ = version(\"fa4\")\nexcept PackageNotFoundError:\n    __version__ = \"0.0.0\"\n\nimport cutlass.cute as cute\n\nfrom .interface import (\n    flash_attn_func,\n    flash_attn_varlen_func,\n)\n\nfrom flash_attn.cute.cute_dsl_utils import cute_compile_patched\n\n# Patch cute.compile to optionally dump SASS\ncute.compile = cute_compile_patched\n\n\n__all__ = [\n    \"flash_attn_func\",\n    \"flash_attn_varlen_func\",\n]\n"
  },
  {
    "path": "flash_attn/cute/ampere_helpers.py",
    "content": "# Copyright (c) 2025, Tri Dao.\nfrom typing import Type, Callable, Optional\n\nimport cutlass\nimport cutlass.cute as cute\n\n\ndef get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout:\n    dtype_byte = cutlass.const_expr(dtype.width // 8)\n    bytes_per_row = cutlass.const_expr(k_dim * dtype_byte)\n    smem_k_block_size = (\n        cutlass.const_expr(\n            128\n            if bytes_per_row % 128 == 0\n            else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16))\n        )\n        // dtype_byte\n    )\n    swizzle_bits = (\n        4\n        if smem_k_block_size == 128\n        else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1))\n    )\n    swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4)\n    return cute.make_composed_layout(\n        cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base),\n        0,\n        cute.make_ordered_layout(\n            (8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0)\n        ),\n    )\n\n\n@cute.jit\ndef gemm(\n    tiled_mma: cute.TiledMma,\n    acc: cute.Tensor,\n    tCrA: cute.Tensor,\n    tCrB: cute.Tensor,\n    tCsA: cute.Tensor,\n    tCsB: cute.Tensor,\n    smem_thr_copy_A: cute.TiledCopy,\n    smem_thr_copy_B: cute.TiledCopy,\n    hook_fn: Optional[Callable] = None,\n    A_in_regs: cutlass.Constexpr[bool] = False,\n    B_in_regs: cutlass.Constexpr[bool] = False,\n    swap_AB: cutlass.Constexpr[bool] = False,\n) -> None:\n    if cutlass.const_expr(swap_AB):\n        gemm(\n            tiled_mma,\n            acc,\n            tCrB,\n            tCrA,\n            tCsB,\n            tCsA,\n            smem_thr_copy_B,\n            smem_thr_copy_A,\n            hook_fn,\n            A_in_regs=B_in_regs,\n            B_in_regs=A_in_regs,\n            swap_AB=False,\n        )\n    else:\n        tCrA_copy_view = smem_thr_copy_A.retile(tCrA)\n        
tCrB_copy_view = smem_thr_copy_B.retile(tCrB)\n        if cutlass.const_expr(not A_in_regs):\n            cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0])\n        if cutlass.const_expr(not B_in_regs):\n            cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])\n        for k in cutlass.range_constexpr(cute.size(tCsA.shape[2])):\n            if k < cute.size(tCsA.shape[2]) - 1:\n                if cutlass.const_expr(not A_in_regs):\n                    cute.copy(\n                        smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1]\n                    )\n                if cutlass.const_expr(not B_in_regs):\n                    cute.copy(\n                        smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]\n                    )\n            cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)\n            if cutlass.const_expr(k == 0 and hook_fn is not None):\n                hook_fn()\n\n\n@cute.jit\ndef gemm_rs(\n    tiled_mma: cute.TiledMma,\n    acc: cute.Tensor,\n    tCrA: cute.Tensor,\n    tCrB: cute.Tensor,\n    tCsB: cute.Tensor,\n    smem_thr_copy_B: cute.TiledCopy,\n    hook_fn: Optional[Callable] = None,\n) -> None:\n    tCrB_copy_view = smem_thr_copy_B.retile(tCrB)\n    cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])\n    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):\n        if cutlass.const_expr(k < cute.size(tCrA.shape[2]) - 1):\n            cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1])\n        cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)\n        if cutlass.const_expr(k == 0 and hook_fn is not None):\n            hook_fn()\n"
  },
  {
    "path": "flash_attn/cute/barrier.py",
    "content": "import cutlass\nimport cutlass.cute as cute\nfrom cutlass import Int32\nfrom cutlass.cutlass_dsl import T, dsl_user_op\nfrom cutlass._mlir.dialects import llvm\n\n\n@dsl_user_op\ndef ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:\n    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()\n    state = llvm.inline_asm(\n        T.i32(),\n        [lock_ptr_i64],\n        \"ld.global.acquire.gpu.b32 $0, [$1];\",\n        \"=r,l\",\n        has_side_effects=True,\n        is_align_stack=False,\n        asm_dialect=llvm.AsmDialect.AD_ATT,\n    )\n    return cutlass.Int32(state)\n\n\n@dsl_user_op\ndef red_relaxed(\n    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None\n) -> None:\n    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()\n    llvm.inline_asm(\n        None,\n        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],\n        \"red.relaxed.gpu.global.add.s32 [$0], $1;\",\n        \"l,r\",\n        has_side_effects=True,\n        is_align_stack=False,\n        asm_dialect=llvm.AsmDialect.AD_ATT,\n    )\n\n\n@dsl_user_op\ndef red_release(\n    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None\n) -> None:\n    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()\n    llvm.inline_asm(\n        None,\n        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],\n        \"red.release.gpu.global.add.s32 [$0], $1;\",\n        \"l,r\",\n        has_side_effects=True,\n        is_align_stack=False,\n        asm_dialect=llvm.AsmDialect.AD_ATT,\n    )\n\n\n@cute.jit\ndef wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None:\n    flag_ptr = lock_ptr + flag_offset\n    if thread_idx == 0:\n        read_val = Int32(0)\n        while read_val != val:\n            read_val = ld_acquire(flag_ptr)\n\n\n@cute.jit\ndef arrive_inc(\n    lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: 
cutlass.Constexpr[Int32]\n) -> None:\n    flag_ptr = lock_ptr + flag_offset\n    if thread_idx == 0:\n        red_release(flag_ptr, val)\n        # red_relaxed(flag_ptr, val)\n"
  },
  {
    "path": "flash_attn/cute/bench_utils.py",
    "content": "\"\"\"Shared benchmark utilities: attention_ref, cuDNN helpers, flops calculation.\"\"\"\n\nimport math\nimport torch\n\ntry:\n    import cudnn\nexcept ImportError:\n    cudnn = None\n\n\n# ── FLOPS calculation ────────────────────────────────────────────────────────\n\n\ndef flops(\n    batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None)\n):\n    if causal:\n        avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2\n    else:\n        if window_size == (None, None):\n            avg_seqlen = seqlen_k\n        else:\n            row_idx = torch.arange(seqlen_q, device=\"cuda\")\n            col_left = (\n                torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))\n                if window_size[0] is not None\n                else torch.zeros_like(row_idx)\n            )\n            col_right = (\n                torch.minimum(\n                    row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)\n                )\n                if window_size[1] is not None\n                else torch.full_like(row_idx, seqlen_k - 1)\n            )\n            avg_seqlen = (col_right - col_left + 1).float().mean().item()\n    return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)\n\n\n# ── Reference attention ─────────────────────────────────────────────────────\n\n_attention_ref_mask_cache = {}\n\n\ndef attention_ref(q, k, v, causal=False):\n    \"\"\"Standard attention reference implementation.\n\n    Args:\n        q, k, v: (batch, seqlen, nheads, headdim) tensors.\n        causal: whether to apply causal mask.\n    \"\"\"\n    softmax_scale = 1.0 / math.sqrt(q.shape[-1])\n    scores = torch.einsum(\"bthd,bshd->bhts\", q * softmax_scale, k)\n    if causal:\n        if scores.shape[-2] not in _attention_ref_mask_cache:\n            mask = torch.tril(\n                torch.ones(scores.shape[-2:], device=scores.device, 
dtype=torch.bool), diagonal=0\n            )\n            _attention_ref_mask_cache[scores.shape[-2]] = mask\n        else:\n            mask = _attention_ref_mask_cache[scores.shape[-2]]\n        scores = scores.masked_fill(mask, float(\"-inf\"))\n    attn = torch.softmax(scores, dim=-1)\n    return torch.einsum(\"bhts,bshd->bthd\", attn, v)\n\n\n# ── cuDNN graph helpers ─────────────────────────────────────────────────────\n\n_TORCH_TO_CUDNN_DTYPE = {\n    torch.float16: \"HALF\",\n    torch.bfloat16: \"BFLOAT16\",\n    torch.float32: \"FLOAT\",\n    torch.int32: \"INT32\",\n    torch.int64: \"INT64\",\n}\n\n\ndef _build_cudnn_graph(io_dtype, tensors, build_fn):\n    \"\"\"Build a cuDNN graph.  Returns (graph, variant_pack, workspace).\"\"\"\n    assert cudnn is not None, \"cuDNN is not available\"\n    cudnn_dtype = getattr(cudnn.data_type, _TORCH_TO_CUDNN_DTYPE[io_dtype])\n    graph = cudnn.pygraph(\n        io_data_type=cudnn_dtype,\n        intermediate_data_type=cudnn.data_type.FLOAT,\n        compute_data_type=cudnn.data_type.FLOAT,\n    )\n    graph_tensors = {name: graph.tensor_like(t.detach()) for name, t in tensors.items()}\n    variant_pack = build_fn(graph, graph_tensors)\n    graph.validate()\n    graph.build_operation_graph()\n    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n    graph.check_support()\n    graph.build_plans()\n    workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n    return graph, variant_pack, workspace\n\n\ndef cudnn_fwd_setup(q, k, v, causal=False, window_size_left=None):\n    \"\"\"Build a cuDNN forward SDPA graph.\n\n    Args:\n        q, k, v: (batch, nheads, seqlen, headdim) tensors (cuDNN layout).\n        causal: whether to apply causal mask.\n        window_size_left: sliding window size (None for no window).\n\n    Returns:\n        (fwd_fn, o_gpu, stats_gpu) where fwd_fn is a zero-arg callable.\n    \"\"\"\n    b, nheads, seqlen_q, headdim = q.shape\n  
  headdim_v = v.shape[-1]\n    o_gpu = torch.empty(b, nheads, seqlen_q, headdim_v, dtype=q.dtype, device=q.device)\n    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)\n\n    def build(graph, gt):\n        o, stats = graph.sdpa(\n            name=\"sdpa\",\n            q=gt[\"q\"],\n            k=gt[\"k\"],\n            v=gt[\"v\"],\n            is_inference=False,\n            attn_scale=1.0 / math.sqrt(headdim),\n            use_causal_mask=causal or window_size_left is not None,\n            sliding_window_length=window_size_left\n            if window_size_left is not None and not causal\n            else None,\n        )\n        o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())\n        stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n        return {gt[\"q\"]: q, gt[\"k\"]: k, gt[\"v\"]: v, o: o_gpu, stats: stats_gpu}\n\n    graph, variant_pack, workspace = _build_cudnn_graph(q.dtype, {\"q\": q, \"k\": k, \"v\": v}, build)\n\n    def fwd_fn():\n        graph.execute(variant_pack, workspace)\n        return o_gpu\n\n    return fwd_fn, o_gpu, stats_gpu\n\n\ndef cudnn_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None):\n    \"\"\"Build a cuDNN backward SDPA graph.\n\n    Args:\n        q, k, v, o, g, lse: (batch, nheads, seqlen, dim) tensors (cuDNN layout).\n        causal: whether to apply causal mask.\n        window_size_left: sliding window size (None for no window).\n\n    Returns:\n        bwd_fn: zero-arg callable that returns (dq, dk, dv).\n    \"\"\"\n    headdim = q.shape[-1]\n    dq_gpu, dk_gpu, dv_gpu = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)\n\n    def build(graph, gt):\n        dq, dk, dv = graph.sdpa_backward(\n            name=\"sdpa_backward\",\n            q=gt[\"q\"],\n            k=gt[\"k\"],\n            v=gt[\"v\"],\n            o=gt[\"o\"],\n            dO=gt[\"g\"],\n            stats=gt[\"lse\"],\n            attn_scale=1.0 / 
math.sqrt(headdim),\n            use_causal_mask=causal or window_size_left is not None,\n            sliding_window_length=window_size_left\n            if window_size_left is not None and not causal\n            else None,\n            use_deterministic_algorithm=False,\n        )\n        dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride())\n        dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride())\n        dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride())\n        return {\n            gt[\"q\"]: q,\n            gt[\"k\"]: k,\n            gt[\"v\"]: v,\n            gt[\"o\"]: o,\n            gt[\"g\"]: g,\n            gt[\"lse\"]: lse,\n            dq: dq_gpu,\n            dk: dk_gpu,\n            dv: dv_gpu,\n        }\n\n    graph, variant_pack, workspace = _build_cudnn_graph(\n        q.dtype,\n        {\"q\": q, \"k\": k, \"v\": v, \"o\": o, \"g\": g, \"lse\": lse},\n        build,\n    )\n\n    def bwd_fn():\n        graph.execute(variant_pack, workspace)\n        return dq_gpu, dk_gpu, dv_gpu\n\n    return bwd_fn\n"
  },
  {
    "path": "flash_attn/cute/benchmark.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\"\"\"Useful functions for writing test code.\"\"\"\n\nimport torch\nimport torch.utils.benchmark as benchmark\n\n\ndef benchmark_forward(\n    fn, *inputs, repeats=10, desc=\"\", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs\n):\n    \"\"\"Use Pytorch Benchmark on the forward pass of an arbitrary function.\"\"\"\n    if verbose:\n        print(desc, \"- Forward pass\")\n\n    def amp_wrapper(*inputs, **kwinputs):\n        with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n            fn(*inputs, **kwinputs)\n\n    t = benchmark.Timer(\n        stmt=\"fn_amp(*inputs, **kwinputs)\",\n        globals={\"fn_amp\": amp_wrapper, \"inputs\": inputs, \"kwinputs\": kwinputs},\n        num_threads=torch.get_num_threads(),\n    )\n    m = t.timeit(repeats)\n    if verbose:\n        print(m)\n    return t, m\n\n\ndef benchmark_backward(\n    fn,\n    *inputs,\n    grad=None,\n    repeats=10,\n    desc=\"\",\n    verbose=True,\n    amp=False,\n    amp_dtype=torch.float16,\n    **kwinputs,\n):\n    \"\"\"Use Pytorch Benchmark on the backward pass of an arbitrary function.\"\"\"\n    if verbose:\n        print(desc, \"- Backward pass\")\n    with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n        y = fn(*inputs, **kwinputs)\n        if type(y) is tuple:\n            y = y[0]\n    if grad is None:\n        grad = torch.randn_like(y)\n    else:\n        if grad.shape != y.shape:\n            raise RuntimeError(\"Grad shape does not match output shape\")\n\n    def f(*inputs, y, grad):\n        # Set .grad to None to avoid extra operation of gradient accumulation\n        for x in inputs:\n            if isinstance(x, torch.Tensor):\n                x.grad = None\n        y.backward(grad, retain_graph=True)\n\n    t = benchmark.Timer(\n        stmt=\"f(*inputs, y=y, grad=grad)\",\n        globals={\"f\": f, \"inputs\": inputs, \"y\": y, \"grad\": grad},\n        
num_threads=torch.get_num_threads(),\n    )\n    m = t.timeit(repeats)\n    if verbose:\n        print(m)\n    return t, m\n\n\ndef benchmark_combined(\n    fn,\n    *inputs,\n    grad=None,\n    repeats=10,\n    desc=\"\",\n    verbose=True,\n    amp=False,\n    amp_dtype=torch.float16,\n    **kwinputs,\n):\n    \"\"\"Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.\"\"\"\n    if verbose:\n        print(desc, \"- Forward + Backward pass\")\n    with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n        y = fn(*inputs, **kwinputs)\n        if type(y) is tuple:\n            y = y[0]\n    if grad is None:\n        grad = torch.randn_like(y)\n    else:\n        if grad.shape != y.shape:\n            raise RuntimeError(\"Grad shape does not match output shape\")\n\n    def f(grad, *inputs, **kwinputs):\n        for x in inputs:\n            if isinstance(x, torch.Tensor):\n                x.grad = None\n        with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n            y = fn(*inputs, **kwinputs)\n            if type(y) is tuple:\n                y = y[0]\n        y.backward(grad, retain_graph=True)\n\n    t = benchmark.Timer(\n        stmt=\"f(grad, *inputs, **kwinputs)\",\n        globals={\"f\": f, \"fn\": fn, \"inputs\": inputs, \"grad\": grad, \"kwinputs\": kwinputs},\n        num_threads=torch.get_num_threads(),\n    )\n    m = t.timeit(repeats)\n    if verbose:\n        print(m)\n    return t, m\n\n\ndef benchmark_fwd_bwd(\n    fn,\n    *inputs,\n    grad=None,\n    repeats=10,\n    desc=\"\",\n    verbose=True,\n    amp=False,\n    amp_dtype=torch.float16,\n    **kwinputs,\n):\n    \"\"\"Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.\"\"\"\n    return (\n        benchmark_forward(\n            fn,\n            *inputs,\n            repeats=repeats,\n            desc=desc,\n            verbose=verbose,\n            amp=amp,\n            
amp_dtype=amp_dtype,\n            **kwinputs,\n        ),\n        benchmark_backward(\n            fn,\n            *inputs,\n            grad=grad,\n            repeats=repeats,\n            desc=desc,\n            verbose=verbose,\n            amp=amp,\n            amp_dtype=amp_dtype,\n            **kwinputs,\n        ),\n    )\n\n\ndef benchmark_all(\n    fn,\n    *inputs,\n    grad=None,\n    repeats=10,\n    desc=\"\",\n    verbose=True,\n    amp=False,\n    amp_dtype=torch.float16,\n    **kwinputs,\n):\n    \"\"\"Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.\"\"\"\n    return (\n        benchmark_forward(\n            fn,\n            *inputs,\n            repeats=repeats,\n            desc=desc,\n            verbose=verbose,\n            amp=amp,\n            amp_dtype=amp_dtype,\n            **kwinputs,\n        ),\n        benchmark_backward(\n            fn,\n            *inputs,\n            grad=grad,\n            repeats=repeats,\n            desc=desc,\n            verbose=verbose,\n            amp=amp,\n            amp_dtype=amp_dtype,\n            **kwinputs,\n        ),\n        benchmark_combined(\n            fn,\n            *inputs,\n            grad=grad,\n            repeats=repeats,\n            desc=desc,\n            verbose=verbose,\n            amp=amp,\n            amp_dtype=amp_dtype,\n            **kwinputs,\n        ),\n    )\n\n\ndef pytorch_profiler(\n    fn,\n    *inputs,\n    trace_filename=None,\n    backward=False,\n    amp=False,\n    amp_dtype=torch.float16,\n    cpu=False,\n    verbose=True,\n    **kwinputs,\n):\n    \"\"\"Wrap benchmark functions in Pytorch profiler to see CUDA information.\"\"\"\n    if backward:\n        with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n            out = fn(*inputs, **kwinputs)\n            if type(out) is tuple:\n                out = out[0]\n            g = torch.randn_like(out)\n    for _ in range(30):  # Warm up\n        if 
backward:\n            for x in inputs:\n                if isinstance(x, torch.Tensor):\n                    x.grad = None\n        with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n            out = fn(*inputs, **kwinputs)\n            if type(out) is tuple:\n                out = out[0]\n        # Backward should be done outside autocast\n        if backward:\n            out.backward(g, retain_graph=True)\n    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [\n        torch.profiler.ProfilerActivity.CUDA\n    ]\n    with torch.profiler.profile(\n        activities=activities,\n        record_shapes=True,\n        # profile_memory=True,\n        with_stack=True,\n    ) as prof:\n        if backward:\n            for x in inputs:\n                if isinstance(x, torch.Tensor):\n                    x.grad = None\n        with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n            out = fn(*inputs, **kwinputs)\n            if type(out) is tuple:\n                out = out[0]\n        if backward:\n            out.backward(g, retain_graph=True)\n    if verbose:\n        # print(prof.key_averages().table(sort_by=\"self_cuda_time_total\", row_limit=50))\n        print(prof.key_averages().table(row_limit=50))\n    if trace_filename is not None:\n        prof.export_chrome_trace(trace_filename)\n\n\ndef benchmark_memory(fn, *inputs, desc=\"\", verbose=True, **kwinputs):\n    \"\"\"Run fn once and return the peak CUDA memory allocated by tensors during the call.\"\"\"\n    torch.cuda.empty_cache()\n    torch.cuda.reset_peak_memory_stats()\n    torch.cuda.synchronize()\n    fn(*inputs, **kwinputs)\n    torch.cuda.synchronize()\n    # Note: the divisor is 2**20 * 1000 bytes (1000 MiB), which the message below labels as GB.\n    mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)\n    if verbose:\n        print(f\"{desc} max memory: {mem}GB\")\n    torch.cuda.empty_cache()\n    return mem\n"
  },
  {
    "path": "flash_attn/cute/blackwell_helpers.py",
    "content": "# Copyright (c) 2025, Tri Dao.\nfrom typing import Optional, Tuple\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Int32, Boolean, const_expr\nfrom cutlass.cute.nvgpu import tcgen05\nfrom cutlass._mlir.dialects import llvm\n\nimport flash_attn.cute.mma_sm100_desc as sm100_desc\n\n\n@cute.jit\ndef gemm_w_idx(\n    tiled_mma: cute.TiledMma,\n    acc: cute.Tensor,\n    tCrA: cute.Tensor,\n    tCrB: cute.Tensor,\n    A_idx: Optional[Int32] = None,\n    B_idx: Optional[Int32] = None,\n    zero_init: bool | Boolean = False,\n    swap_AB: bool = False,\n    num_unroll_groups: int = 1,\n) -> None:\n    if const_expr(swap_AB):\n        return gemm_w_idx(\n            tiled_mma, acc, tCrB, tCrA, B_idx, A_idx, zero_init=zero_init, swap_AB=False\n        )\n    else:\n        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]\n        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]\n\n        mma_atom = cute.make_mma_atom(tiled_mma.op)\n        for k in cutlass.range(\n            cute.size(tCrA.shape[2]), unroll=cute.size(tCrA.shape[2]) // num_unroll_groups\n        ):\n            mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)\n            cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc)\n\n\n@cute.jit\ndef gemm_ptx_w_idx(\n    tiled_mma: cute.TiledMma,\n    acc: cute.Tensor,\n    tCrA: cute.Tensor,\n    tCrB: cute.Tensor,\n    sA: Optional[cute.Tensor],\n    sB: cute.Tensor,\n    A_idx: Optional[Int32] = None,\n    B_idx: Optional[Int32] = None,\n    zero_init: bool | Boolean = False,\n    cta_group: int = 1,\n    **kwargs,\n) -> None:\n    rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]\n    rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]\n    sA_cur = None\n    if const_expr(sA is not None):\n        sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx]\n    sB_cur = sB if 
const_expr(B_idx is None) else sB[None, None, None, B_idx]\n    mma_atom = cute.make_mma_atom(tiled_mma.op)\n    acc_tmem_addr = acc.iterator.toint()\n    gemm_ptx_partial(\n        mma_atom.op,\n        acc_tmem_addr,\n        rA,\n        rB,\n        sA_cur,\n        sB_cur,\n        zero_init=zero_init,\n        cta_group=cta_group,\n        **kwargs,\n    )\n\n\n@cute.jit\ndef gemm(\n    tiled_mma: cute.TiledMma,\n    acc: cute.Tensor,\n    tCrA: cute.Tensor,\n    tCrB: cute.Tensor,\n    zero_init: bool | Boolean = False,\n) -> None:\n    mma_atom = cute.make_mma_atom(tiled_mma.op)\n    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):\n        mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)\n        cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)\n\n\ndef i64_to_i32x2(i: int) -> Tuple[int, int]:\n    \"\"\"Convert a 64-bit integer to a tuple of two 32-bit integers.\"\"\"\n    return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF\n\n\n@cute.jit\ndef gemm_ptx(\n    op: cute.nvgpu.tcgen05.mma.MmaOp,\n    acc: cute.Tensor,\n    tCrA: cute.Tensor,\n    tCrB: cute.Tensor,\n    sA: Optional[cute.Tensor],\n    sB: cute.Tensor,\n    zero_init: bool | Boolean = False,\n) -> None:\n    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM\n    if const_expr(not is_ts):\n        assert sA is not None, \"sA must be provided when a_src is not TMEM\"\n    sA_layout = sA.layout if sA is not None else None\n    sB_layout = sB.layout\n    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))\n    if const_expr(not is_ts):\n        sA_swizzle = sA.iterator.type.swizzle_type\n        smem_desc_base_a: int = const_expr(\n            sm100_desc.make_smem_desc_base(\n                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),\n                sA_swizzle,\n                sm100_desc.Major.K\n                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)\n                else 
sm100_desc.Major.MN,\n            )\n        )\n        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)\n        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)\n        smem_desc_a_hi = const_expr(smem_desc_a_hi)\n    else:\n        smem_desc_base_a = None\n        smem_desc_base_a_lo, smem_desc_a_hi = None, None\n    sB_swizzle = sB.iterator.type.swizzle_type\n    smem_desc_base_b: int = const_expr(\n        sm100_desc.make_smem_desc_base(\n            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),\n            sB_swizzle,\n            sm100_desc.Major.K\n            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)\n            else sm100_desc.Major.MN,\n        )\n    )\n    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)\n    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)\n    smem_desc_b_hi = const_expr(smem_desc_b_hi)\n\n    if const_expr(not is_ts):\n        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr(\n            sA[None, None, 0].iterator\n        )\n    else:\n        smem_desc_start_a_lo = None\n    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr(\n        sB[None, None, 0].iterator\n    )\n    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):\n        if const_expr(not is_ts):\n            smem_desc_a_lo = smem_desc_start_a_lo + (\n                (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4\n            )\n        smem_desc_b_lo = smem_desc_start_b_lo + (\n            (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4\n        )\n        # with cute.arch.elect_one():\n        #     cute.printf(\"smem_desc_a_lo = {}, smem_desc_b_lo = {}\", smem_desc_a_lo, smem_desc_b_lo)\n        #     cute.printf(\"smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}\", smem_desc_a_lo_correct, smem_desc_b_lo_correct)\n        with 
cute.arch.elect_one():\n            if const_expr(not is_ts):\n                llvm.inline_asm(\n                    None,\n                    [\n                        acc.iterator.toint().ir_value(),\n                        smem_desc_a_lo.ir_value(),\n                        smem_desc_b_lo.ir_value(),\n                        Int32(not zero_init or k != 0).ir_value(),\n                    ],\n                    \"{\\n\\t\"\n                    \".reg .pred p;\\n\\t\"\n                    \".reg .b64 smem_desc_a, smem_desc_b;\\n\\t\"\n                    \".reg .b32 idesc;\\n\\t\"\n                    f\"mov.b32 idesc, {hex(idesc)};\\n\\t\"\n                    f\"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\\n\\t\"\n                    f\"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\\n\\t\"\n                    \"setp.ne.b32 p, $3, 0;\\n\\t\"\n                    f\"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\\n\\t\"\n                    \"}\\n\",\n                    \"r,r,r,r\",\n                    has_side_effects=True,\n                    is_align_stack=False,\n                    asm_dialect=llvm.AsmDialect.AD_ATT,\n                )\n            else:\n                llvm.inline_asm(\n                    None,\n                    [\n                        acc.iterator.toint().ir_value(),\n                        tCrA[None, None, k].iterator.toint().ir_value(),\n                        smem_desc_b_lo.ir_value(),\n                        Int32(not zero_init or k != 0).ir_value(),\n                    ],\n                    \"{\\n\\t\"\n                    \".reg .pred p;\\n\\t\"\n                    \".reg .b64 smem_desc_b;\\n\\t\"\n                    f\"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\\n\\t\"\n                    \"setp.ne.b32 p, $3, 0;\\n\\t\"\n                    f\"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\\n\\t\"\n                    
\"}\\n\",\n                    \"r,r,r,r\",\n                    has_side_effects=True,\n                    is_align_stack=False,\n                    asm_dialect=llvm.AsmDialect.AD_ATT,\n                )\n\n\n@cute.jit\ndef gemm_ptx_loop(\n    op: cute.nvgpu.tcgen05.mma.MmaOp,\n    acc: cute.Tensor,\n    tCrA: cute.Tensor,\n    tCrB: cute.Tensor,\n    sA: Optional[cute.Tensor],\n    sB: cute.Tensor,\n    zero_init: bool | Boolean = False,\n) -> None:\n    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM\n    if const_expr(not is_ts):\n        assert sA is not None, \"sA must be provided when a_src is not TMEM\"\n    sA_layout = sA.layout if sA is not None else tCrA.layout\n    sB_layout = sB.layout\n    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))\n    if const_expr(not is_ts):\n        sA_swizzle = sA.iterator.type.swizzle_type\n        smem_desc_base_a: int = const_expr(\n            sm100_desc.make_smem_desc_base(\n                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),\n                sA_swizzle,\n                sm100_desc.Major.K\n                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)\n                else sm100_desc.Major.MN,\n            )\n        )\n        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)\n        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)\n        smem_desc_a_hi = const_expr(smem_desc_a_hi)\n    else:\n        smem_desc_base_a = None\n        smem_desc_base_a_lo, smem_desc_a_hi = None, None\n    sB_swizzle = sB.iterator.type.swizzle_type\n    smem_desc_base_b: int = const_expr(\n        sm100_desc.make_smem_desc_base(\n            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),\n            sB_swizzle,\n            sm100_desc.Major.K\n            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)\n            else sm100_desc.Major.MN,\n        )\n    )\n    smem_desc_base_b_lo, smem_desc_b_hi = 
i64_to_i32x2(smem_desc_base_b)\n    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)\n    smem_desc_b_hi = const_expr(smem_desc_b_hi)\n\n    if const_expr(not is_ts):\n        offset_a = [\n            (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4\n            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))\n        ]\n    else:\n        offset_a = [\n            cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32\n            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))\n        ]\n    offset_a_diff = [\n        offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))\n    ]\n    offset_b = [\n        (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4\n        for k in cutlass.range_constexpr(cute.size(tCrB.shape[2]))\n    ]\n    offset_b_diff = [\n        offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2]))\n    ]\n\n    if const_expr(not is_ts):\n        smem_desc_start_a_lo = Int32(\n            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)\n        )\n    else:\n        smem_desc_start_a_lo = None\n    smem_desc_start_b_lo = Int32(\n        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)\n    )\n    pred_str = \"p\" if isinstance(zero_init, Boolean) else \"0\" if zero_init else \"1\"\n    if const_expr(not is_ts):\n        llvm.inline_asm(\n            None,\n            [\n                acc.iterator.toint().ir_value(),\n                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),\n                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),\n                Int32(not zero_init).ir_value(),\n            ],\n            \"{\\n\\t\"\n            \".reg .pred leader_thread;\\n\\t\"\n            \".reg .pred p;\\n\\t\"\n            \".reg .b32 idesc;\\n\\t\"\n            \".reg .b32 
smem_desc_a_lo, smem_desc_b_lo;\\n\\t\"\n            \".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\\n\\t\"\n            \".reg .b64 smem_desc_a, smem_desc_b;\\n\\t\"\n            \"elect.sync _|leader_thread, -1;\\n\\t\"\n            f\"mov.b32 idesc, {hex(idesc)};\\n\\t\"\n            \"mov.b32 smem_desc_a_lo, $1;\\n\\t\"\n            \"mov.b32 smem_desc_b_lo, $2;\\n\\t\"\n            f\"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\\n\\t\"\n            f\"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\\n\\t\"\n            f\"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\\n\\t\"\n            f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n            \"setp.ne.b32 p, $3, 0;\\n\\t\"\n            f\"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\\n\\t\"\n            + \"\".join(\n                (\n                    f\"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\\n\\t\"\n                    f\"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\\n\\t\"\n                    f\"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\\n\\t\"\n                    f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n                    f\"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\\n\\t\"\n                )\n                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))\n            )\n            + \"}\\n\",\n            \"r,r,r,r\",\n            has_side_effects=True,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n    else:\n        llvm.inline_asm(\n            None,\n            [\n                acc.iterator.toint().ir_value(),\n                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),\n                Int32(smem_desc_start_b_lo).ir_value(),\n                Int32(not zero_init).ir_value(),\n      
      ],\n            \"{\\n\\t\"\n            \".reg .pred leader_thread;\\n\\t\"\n            \".reg .pred p;\\n\\t\"\n            \".reg .b32 idesc;\\n\\t\"\n            \".reg .b32 tmem_a;\\n\\t\"\n            \".reg .b32 smem_desc_b_lo;\\n\\t\"\n            \".reg .b32 smem_desc_b_hi;\\n\\t\"\n            \".reg .b64 smem_desc_b;\\n\\t\"\n            \"elect.sync _|leader_thread, -1;\\n\\t\"\n            f\"mov.b32 idesc, {hex(idesc)};\\n\\t\"\n            \"mov.b32 tmem_a, $1;\\n\\t\"\n            \"mov.b32 smem_desc_b_lo, $2;\\n\\t\"\n            f\"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\\n\\t\"\n            f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n            \"setp.ne.b32 p, $3, 0;\\n\\t\"\n            f\"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\\n\\t\"\n            + \"\".join(\n                (\n                    # f\"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\\n\\t\"\n                    f\"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\\n\\t\"\n                    f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n                    # f\"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\\n\\t\"\n                    f\"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\\n\\t\"\n                )\n                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))\n            )\n            + \"}\\n\",\n            \"r,r,r,r\",\n            has_side_effects=True,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n\n\n@cute.jit\ndef gemm_ptx_partial(\n    op: cute.nvgpu.tcgen05.mma.MmaOp,\n    acc_tmem_addr: Int32,\n    tCrA: cute.Tensor,\n    tCrB: cute.Tensor,\n    sA: Optional[cute.Tensor],\n    sB: cute.Tensor,\n    mbar_ptr: Optional[cutlass.Pointer] = 
None,\n    mbar_phase: Optional[Int32] = None,\n    split_arrive: Optional[int] = None,\n    zero_init: bool | Boolean = False,\n    # sA_offset: Int32 = 0,\n    # acc_offset: Int32 = 0,\n    tA_addr: Optional[Int32] = None,\n    cta_group: int = 1,\n) -> None:\n    # acc_tmem_addr += acc_offset\n    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM\n    if const_expr(not is_ts):\n        assert sA is not None, \"sA must be provided when a_src is not TMEM\"\n    sA_layout = sA.layout if sA is not None else tCrA.layout\n    sB_layout = sB.layout\n    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))\n    if const_expr(not is_ts):\n        sA_swizzle = sA.iterator.type.swizzle_type\n        smem_desc_base_a: int = const_expr(\n            sm100_desc.make_smem_desc_base(\n                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),\n                sA_swizzle,\n                sm100_desc.Major.K\n                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)\n                else sm100_desc.Major.MN,\n            )\n        )\n        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)\n        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)\n        smem_desc_a_hi = const_expr(smem_desc_a_hi)\n    else:\n        smem_desc_base_a = None\n        smem_desc_base_a_lo, smem_desc_a_hi = None, None\n    sB_swizzle = sB.iterator.type.swizzle_type\n    smem_desc_base_b: int = const_expr(\n        sm100_desc.make_smem_desc_base(\n            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),\n            sB_swizzle,\n            sm100_desc.Major.K\n            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)\n            else sm100_desc.Major.MN,\n        )\n    )\n    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)\n    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)\n    smem_desc_b_hi = const_expr(smem_desc_b_hi)\n\n    tCrA_layout = (\n 
       tCrA.layout\n        if const_expr(not is_ts)\n        else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout)\n    )\n    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))]\n    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]\n    offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))]\n    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]\n\n    if const_expr(not is_ts):\n        smem_desc_start_a_lo = Int32(\n            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)\n        )\n        # ) + sA_offset\n    else:\n        smem_desc_start_a_lo = None\n    smem_desc_start_b_lo = Int32(\n        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)\n    )\n    pred_str = \"p\" if isinstance(zero_init, Boolean) else \"0\" if zero_init else \"1\"\n    if const_expr(not is_ts):\n        assert mbar_ptr is None, \"mbar_ptr must be None when a_src is not TMEM\"\n        llvm.inline_asm(\n            None,\n            [\n                # acc.iterator.toint().ir_value(),\n                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),\n                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),\n                Int32(not zero_init).ir_value(),\n                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),\n            ],\n            \"{\\n\\t\"\n            \".reg .pred leader_thread;\\n\\t\"\n            \".reg .pred p;\\n\\t\"\n            \".reg .b32 idesc;\\n\\t\"\n            \".reg .b32 tmem_acc;\\n\\t\"\n            \".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\\n\\t\"\n            \".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\\n\\t\"\n            \".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\\n\\t\"\n            \".reg .b64 smem_desc_a, 
smem_desc_b;\\n\\t\"\n            \"elect.sync _|leader_thread, -1;\\n\\t\"\n            f\"mov.b32 idesc, {hex(idesc)};\\n\\t\"\n            # f\"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\\n\\t\"\n            f\"mov.b32 tmem_acc, $3;\\n\\t\"\n            \"mov.b32 smem_desc_a_lo_start, $0;\\n\\t\"\n            \"mov.b32 smem_desc_b_lo_start, $1;\\n\\t\"\n            f\"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\\n\\t\"\n            f\"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\\n\\t\"\n            f\"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\\n\\t\"\n            f\"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\\n\\t\"\n            \"setp.ne.b32 p, $2, 0;\\n\\t\"\n            f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\\n\\t\"\n            + \"\".join(\n                (\n                    # f\"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\\n\\t\"\n                    # f\"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\\n\\t\"\n                    f\"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\\n\\t\"\n                    f\"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\\n\\t\"\n                    f\"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\\n\\t\"\n                    f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n                    f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\\n\\t\"\n                )\n                for k in range(1, cute.size(tCrA.shape[2]))\n            )\n            + \"}\\n\",\n            # \"r,r,r\",\n            \"r,r,r,r\",\n            has_side_effects=True,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n    else:\n        # For TS gemm, somehow tCrA.iterator.toint() returns 0 
no matter what, so we need to\n        # explicitly pass in the tA_addr for correctness.\n        tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr\n        input_args = [\n            # Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(),\n            Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(),\n            Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),\n            Int32(not zero_init).ir_value(),\n            Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),\n        ]\n        if const_expr(mbar_ptr is not None):\n            assert mbar_phase is not None, \"mbar_phase must be provided when mbar_ptr is not None\"\n            assert split_arrive is not None, (\n                \"split_arrive must be provided when mbar_ptr is not None\"\n            )\n            split_arrive_idx = split_arrive // op.shape_mnk[2]\n            input_args.append(mbar_ptr.toint().ir_value())\n            input_args.append(Int32(mbar_phase).ir_value())\n            mbar_wait_str = (\n                \".reg .pred P1; \\n\\t\"\n                \"LAB_WAIT: \\n\\t\"\n                \"mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \\n\\t\"\n                \"@P1 bra DONE; \\n\\t\"\n                \"bra     LAB_WAIT; \\n\\t\"\n                \"DONE: \\n\\t\"\n            )\n        else:\n            mbar_wait_str = \"\"\n        llvm.inline_asm(\n            None,\n            # [\n            #     # acc.iterator.toint().ir_value(),\n            #     Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),\n            #     Int32(smem_desc_start_b_lo).ir_value(),\n            #     Int32(not zero_init).ir_value(),\n            # ],\n            input_args,\n            \"{\\n\\t\"\n            \".reg .pred leader_thread;\\n\\t\"\n            \".reg .pred p;\\n\\t\"\n            \".reg .b32 idesc;\\n\\t\"\n            \".reg .b32 tmem_acc;\\n\\t\"\n            
\".reg .b32 tmem_a;\\n\\t\"\n            \".reg .b32 smem_desc_b_lo_start;\\n\\t\"\n            \".reg .b32 smem_desc_b_lo;\\n\\t\"\n            \".reg .b32 smem_desc_b_hi;\\n\\t\"\n            \".reg .b64 smem_desc_b;\\n\\t\"\n            \"elect.sync _|leader_thread, -1;\\n\\t\"\n            f\"mov.b32 idesc, {hex(idesc)};\\n\\t\"\n            # f\"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\\n\\t\"\n            f\"mov.b32 tmem_acc, $3;\\n\\t\"\n            f\"mov.b32 tmem_a, $0;\\n\\t\"\n            f\"mov.b32 smem_desc_b_lo_start, $1;\\n\\t\"\n            f\"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\\n\\t\"\n            f\"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\\n\\t\"\n            \"setp.ne.b32 p, $2, 0;\\n\\t\"\n            f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\\n\\t\"\n            + \"\".join(\n                (\n                    # f\"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\\n\\t\"\n                    # f\"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\\n\\t\"\n                    f\"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\\n\\t\"\n                    f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n                    # f\"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\\n\\t\"\n                    f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\\n\\t\"\n                )\n                for k in range(\n                    1,\n                    cute.size(tCrA.shape[2]) if const_expr(mbar_ptr is None) else split_arrive_idx,\n                )\n            )\n            + mbar_wait_str\n            + (\n                \"\".join(\n                    (\n                        f\"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 
1])};\\n\\t\"\n                        f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n                        f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\\n\\t\"\n                    )\n                    for k in range(split_arrive_idx, cute.size(tCrA.shape[2]))\n                )\n                if const_expr(mbar_ptr is not None)\n                else \"\"\n            )\n            + \"}\\n\",\n            \"r,r,r,r\" if const_expr(mbar_ptr is None) else \"r,r,r,r,r,r\",\n            has_side_effects=True,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n\n\n@cute.jit\ndef gemm_ptx_partial1(\n    op: cute.nvgpu.tcgen05.mma.MmaOp,\n    acc_tmem_addr: cutlass.Constexpr[int],\n    tCrA: cute.Tensor,\n    tCrB: cute.Tensor,\n    sA_base_addr_for_desc: Int32,\n    sA_addr_offset_for_desc: cutlass.Constexpr[int],\n    sA_stage: Int32,\n    sB_base_addr_for_desc: Int32,\n    sB_addr_offset_for_desc: cutlass.Constexpr[int],\n    sB_stage: Int32,\n    sA_layout: Optional[cute.Layout],\n    sB_layout: Optional[cute.Layout],\n    sA_swizzle: Optional[cute.Swizzle],\n    sB_swizzle: cute.Swizzle,\n    zero_init: bool | Boolean = False,\n) -> None:\n    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM\n    if const_expr(not is_ts):\n        assert sA_layout is not None, \"sA_layout must be provided when a_src is not TMEM\"\n        assert sA_swizzle is not None, \"sA_swizzle must be provided when a_src is not TMEM\"\n    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))\n    if const_expr(not is_ts):\n        smem_desc_base_a: int = const_expr(\n            sm100_desc.make_smem_desc_base(\n                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),\n                sA_swizzle,\n                sm100_desc.Major.K\n                if const_expr(op.a_major_mode == 
cute.nvgpu.tcgen05.mma.OperandMajorMode.K)\n                else sm100_desc.Major.MN,\n            )\n        )\n        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)\n        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)\n        smem_desc_a_hi = const_expr(smem_desc_a_hi)\n    else:\n        smem_desc_base_a = None\n        smem_desc_base_a_lo, smem_desc_a_hi = None, None\n    smem_desc_base_b: int = const_expr(\n        sm100_desc.make_smem_desc_base(\n            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),\n            sB_swizzle,\n            sm100_desc.Major.K\n            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)\n            else sm100_desc.Major.MN,\n        )\n    )\n    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)\n    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)\n    smem_desc_b_hi = const_expr(smem_desc_b_hi)\n    mask = [Int32(0)] * 4\n\n    if const_expr(not is_ts):\n        offset_a = [\n            (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4\n            for k in range(cute.size(tCrA.shape[2]))\n        ]\n    else:\n        offset_a = [\n            cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32\n            for k in range(cute.size(tCrA.shape[2]))\n        ]\n    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]\n    offset_b = [\n        (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4\n        for k in range(cute.size(tCrB.shape[2]))\n    ]\n    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]\n\n    if const_expr(not is_ts):\n        # smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator))\n        smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo)\n    else:\n        smem_desc_start_a_lo = None\n    # smem_desc_start_b_lo = 
Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator))\n    smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo)\n    pred_str = \"p\" if isinstance(zero_init, Boolean) else \"0\" if zero_init else \"1\"\n    if const_expr(not is_ts):\n        llvm.inline_asm(\n            None,\n            [\n                # acc.iterator.toint().ir_value(),\n                # Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),\n                Int32(sA_base_addr_for_desc).ir_value(),\n                Int32(sA_stage).ir_value(),\n                # Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),\n                Int32(sB_base_addr_for_desc).ir_value(),\n                Int32(sB_stage).ir_value(),\n                Int32(not zero_init).ir_value(),\n                mask[0].ir_value(),\n                mask[1].ir_value(),\n                mask[2].ir_value(),\n                mask[3].ir_value(),\n            ],\n            \"{\\n\\t\"\n            \".reg .pred leader_thread;\\n\\t\"\n            \".reg .pred p;\\n\\t\"\n            \".reg .b32 idesc;\\n\\t\"\n            \".reg .b32 tmem_acc;\\n\\t\"\n            \".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\\n\\t\"\n            \".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\\n\\t\"\n            \".reg .b64 smem_desc_a, smem_desc_b;\\n\\t\"\n            \"elect.sync _|leader_thread, -1;\\n\\t\"\n            f\"mov.b32 idesc, {hex(idesc)};\\n\\t\"\n            f\"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\\n\\t\"\n            # \"mov.b32 smem_desc_a_lo, $0;\\n\\t\"\n            # f\"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\\n\\t\"\n            f\"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\\n\\t\"\n            # \"mov.b32 smem_desc_b_lo, $2;\\n\\t\"\n            f\"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\\n\\t\"\n            f\"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\\n\\t\"\n            
f\"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\\n\\t\"\n            f\"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\\n\\t\"\n            f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n            \"setp.ne.b32 p, $4, 0;\\n\\t\"\n            f\"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\\n\\t\"\n            + \"\".join(\n                (\n                    f\"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\\n\\t\"\n                    f\"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\\n\\t\"\n                    f\"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\\n\\t\"\n                    f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n                    f\"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\\n\\t\"\n                )\n                for k in range(1, cute.size(tCrA.shape[2]))\n            )\n            + \"}\\n\",\n            \"r,r,r,r,r,r,r,r,r\",\n            has_side_effects=True,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n    else:\n        llvm.inline_asm(\n            None,\n            [\n                # acc.iterator.toint().ir_value(),\n                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),\n                Int32(smem_desc_start_b_lo).ir_value(),\n                Int32(not zero_init).ir_value(),\n                mask[0].ir_value(),\n                mask[1].ir_value(),\n                mask[2].ir_value(),\n                mask[3].ir_value(),\n            ],\n            \"{\\n\\t\"\n            \".reg .pred leader_thread;\\n\\t\"\n            \".reg .pred p;\\n\\t\"\n            \".reg .b32 idesc;\\n\\t\"\n            \".reg .b32 tmem_a;\\n\\t\"\n            \".reg .b32 smem_desc_b_lo;\\n\\t\"\n            
\".reg .b32 smem_desc_b_hi;\\n\\t\"\n            \".reg .b64 smem_desc_b;\\n\\t\"\n            \"elect.sync _|leader_thread, -1;\\n\\t\"\n            f\"mov.b32 idesc, {hex(idesc)};\\n\\t\"\n            f\"mov.b32 tmem_a, $1;\\n\\t\"\n            f\"mov.b32 smem_desc_b_lo, $2;\\n\\t\"\n            f\"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\\n\\t\"\n            f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n            \"setp.ne.b32 p, $3, 0;\\n\\t\"\n            f\"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\\n\\t\"\n            + \"\".join(\n                (\n                    f\"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\\n\\t\"\n                    f\"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\\n\\t\"\n                    f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n                    f\"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\\n\\t\"\n                )\n                for k in range(1, cute.size(tCrA.shape[2]))\n            )\n            + \"}\\n\",\n            \"r,r,r,r,r,r,r,r\",\n            has_side_effects=True,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n\n\n@cute.jit\ndef gemm_ptx_precomputed(\n    acc_tmem_addr: Int32,\n    smem_desc_start_a: Int32,  # If TS, then this is the tmem start address for A\n    smem_desc_start_b: Int32,\n    idesc: int,\n    smem_desc_base_a: Optional[int],\n    smem_desc_base_b: int,\n    tCrA_layout: cute.Layout,\n    tCrB_layout: cute.Layout,\n    mbar_ptr: Optional[cutlass.Pointer] = None,\n    mbar_phase: Optional[Int32] = None,\n    zero_init: bool | Boolean = False,\n    cta_group: int = 1,\n) -> None:\n    # acc_tmem_addr += acc_offset\n    is_ts = const_expr(smem_desc_base_a is None)\n    num_k_tile = 
cute.size(tCrA_layout.shape[2])\n    if const_expr(not is_ts):\n        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)\n    else:\n        smem_desc_base_a_lo, smem_desc_a_hi = None, None\n    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)\n\n    tCrA_layout = (\n        tCrA_layout\n        if const_expr(not is_ts)\n        # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout)\n        # currently hard-coding the width to 16\n        else cute.recast_layout(32, 16, tCrA_layout)\n    )\n    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]\n    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, num_k_tile)]\n    offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)]\n    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, num_k_tile)]\n\n    smem_desc_start_a_lo = None\n    if const_expr(not is_ts):\n        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a)\n        # smem_desc_start_a_lo = smem_desc_start_a\n    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b)\n    pred_str = \"p\" if isinstance(zero_init, Boolean) else \"0\" if zero_init else \"1\"\n    if const_expr(not is_ts):\n        assert mbar_ptr is None, \"mbar_ptr must be None when a_src is not TMEM\"\n        llvm.inline_asm(\n            None,\n            [\n                # acc.iterator.toint().ir_value(),\n                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),\n                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),\n                Int32(not zero_init).ir_value(),\n                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),\n            ],\n            \"{\\n\\t\"\n            \".reg .pred leader_thread;\\n\\t\"\n            \".reg .pred p;\\n\\t\"\n            \".reg .b32 idesc;\\n\\t\"\n            \".reg .b32 tmem_acc;\\n\\t\"\n            \".reg 
.b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\\n\\t\"\n            \".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\\n\\t\"\n            \".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\\n\\t\"\n            \".reg .b64 smem_desc_a, smem_desc_b;\\n\\t\"\n            \"elect.sync _|leader_thread, -1;\\n\\t\"\n            f\"mov.b32 idesc, {hex(idesc)};\\n\\t\"\n            # f\"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\\n\\t\"\n            f\"mov.b32 tmem_acc, $3;\\n\\t\"\n            \"mov.b32 smem_desc_a_lo_start, $0;\\n\\t\"\n            \"mov.b32 smem_desc_b_lo_start, $1;\\n\\t\"\n            f\"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\\n\\t\"\n            f\"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\\n\\t\"\n            f\"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\\n\\t\"\n            f\"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\\n\\t\"\n            \"setp.ne.b32 p, $2, 0;\\n\\t\"\n            f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\\n\\t\"\n            + \"\".join(\n                (\n                    # f\"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\\n\\t\"\n                    # f\"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\\n\\t\"\n                    f\"add.s32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\\n\\t\"\n                    f\"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\\n\\t\"\n                    f\"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\\n\\t\"\n                    f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n                    f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\\n\\t\"\n                )\n                for k in range(1, num_k_tile)\n            )\n            + \"}\\n\",\n            # \"r,r,r\",\n            
\"r,r,r,r\",\n            has_side_effects=True,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n    else:\n        input_args = [\n            Int32(cute.arch.make_warp_uniform(smem_desc_start_a)).ir_value(),\n            Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),\n            Int32(not zero_init).ir_value(),\n            Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),\n        ]\n        if const_expr(mbar_ptr is not None):\n            assert mbar_phase is not None, \"mbar_phase must be provided when mbar_ptr is not None\"\n            input_args.append(mbar_ptr.toint().ir_value())\n            input_args.append(Int32(mbar_phase).ir_value())\n            mbar_wait_str = (\n                \".reg .pred P1; \\n\\t\"\n                \"LAB_WAIT: \\n\\t\"\n                \"mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \\n\\t\"\n                \"@P1 bra DONE; \\n\\t\"\n                \"bra     LAB_WAIT; \\n\\t\"\n                \"DONE: \\n\\t\"\n            )\n        else:\n            mbar_wait_str = \"\"\n        llvm.inline_asm(\n            None,\n            # [\n            #     # acc.iterator.toint().ir_value(),\n            #     Int32(tCrA_layout[None, None, 0].iterator.toint()).ir_value(),\n            #     Int32(smem_desc_start_b_lo).ir_value(),\n            #     Int32(not zero_init).ir_value(),\n            # ],\n            input_args,\n            \"{\\n\\t\"\n            \".reg .pred leader_thread;\\n\\t\"\n            \".reg .pred p;\\n\\t\"\n            \".reg .b32 idesc;\\n\\t\"\n            \".reg .b32 tmem_acc;\\n\\t\"\n            \".reg .b32 tmem_a;\\n\\t\"\n            \".reg .b32 smem_desc_b_lo_start;\\n\\t\"\n            \".reg .b32 smem_desc_b_lo;\\n\\t\"\n            \".reg .b32 smem_desc_b_hi;\\n\\t\"\n            \".reg .b64 smem_desc_b;\\n\\t\"\n            \"elect.sync _|leader_thread, -1;\\n\\t\"\n            f\"mov.b32 
idesc, {hex(idesc)};\\n\\t\"\n            # f\"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\\n\\t\"\n            f\"mov.b32 tmem_acc, $3;\\n\\t\"\n            f\"mov.b32 tmem_a, $0;\\n\\t\"\n            f\"mov.b32 smem_desc_b_lo_start, $1;\\n\\t\"\n            f\"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\\n\\t\"\n            f\"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\\n\\t\"\n            \"setp.ne.b32 p, $2, 0;\\n\\t\"\n            f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\\n\\t\"\n            + \"\".join(\n                (\n                    # f\"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\\n\\t\"\n                    # f\"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\\n\\t\"\n                    f\"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\\n\\t\"\n                    f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n                    # f\"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\\n\\t\"\n                    f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\\n\\t\"\n                )\n                for k in range(\n                    1,\n                    num_k_tile if const_expr(mbar_ptr is None) else num_k_tile // 4 * 3,\n                )\n            )\n            + mbar_wait_str\n            + (\n                \"\".join(\n                    (\n                        # f\"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\\n\\t\"\n                        f\"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\\n\\t\"\n                        f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n                        f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a 
+ {hex(offset_a[k])}], smem_desc_b, idesc, 1;\\n\\t\"\n                    )\n                    for k in range(num_k_tile // 4 * 3, num_k_tile)\n                )\n                if const_expr(mbar_ptr is not None)\n                else \"\"\n            )\n            + \"}\\n\",\n            \"r,r,r,r\" if const_expr(mbar_ptr is None) else \"r,r,r,r,r,r\",\n            has_side_effects=True,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n\n\n@cute.jit\ndef declare_ptx_smem_desc(\n    smem_desc_start_a: Int32,  # If TS, then this is the tmem start address for A\n    smem_desc_base_a: Optional[int],\n    tCrA_layout: cute.Layout,\n    var_name_prefix: str = \"smem_desc\",\n) -> None:\n    is_ts = const_expr(smem_desc_base_a is None)\n    num_k_tile = cute.size(tCrA_layout.shape[2])\n    smem_desc_base_a_lo, smem_desc_a_hi = None, None\n    if const_expr(not is_ts):\n        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)\n    tCrA_layout = (\n        tCrA_layout\n        if const_expr(not is_ts)\n        # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout)\n        # currently hard-coding the width to 16\n        else cute.recast_layout(32, 16, tCrA_layout)\n    )\n    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]\n    smem_desc_start_a_lo = None\n    if const_expr(not is_ts):\n        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a)\n    if const_expr(not is_ts):\n        llvm.inline_asm(\n            None,\n            [Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value()],\n            f\".reg .b32 {var_name_prefix}_lo;\\n\\t\"\n            f\".reg .b64 {var_name_prefix}_<{num_k_tile}>;\\n\\t\"\n            f\"mov.b64 {var_name_prefix}_0, {{$0, {hex(smem_desc_a_hi)}}};\\n\\t\"\n            + \"\".join(\n                (\n                    f\"add.s32 {var_name_prefix}_lo, $0, {hex(offset_a[k])};\\n\\t\"\n    
                f\"mov.b64 {var_name_prefix}_{k}, {{{var_name_prefix}_lo, {hex(smem_desc_a_hi)}}};\\n\\t\"\n                )\n                for k in range(1, num_k_tile)\n            ),\n            \"r\",\n            has_side_effects=True,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n\n\n@cute.jit\ndef declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = \"idesc\") -> None:\n    idesc = const_expr(sm100_desc.mma_op_to_idesc(op))\n    llvm.inline_asm(\n        None,\n        [],\n        f\".reg .b32 {var_name};\\n\\t\"  # noqa\n        f\"mov.b32 {var_name}, {hex(idesc)};\\n\\t\",\n        constraints=\"\",\n        has_side_effects=True,\n        is_align_stack=False,\n        asm_dialect=llvm.AsmDialect.AD_ATT,\n    )\n\n\n@cute.jit\ndef gemm_ptx_precomputed_varname(\n    acc_tmem_addr: Int32,\n    smem_desc_start_b: Int32,\n    # idesc: int,\n    smem_desc_base_b: int,\n    tCrB_layout: cute.Layout,\n    smem_var_name_prefix: str,\n    idesc_var_name: str,\n    smem_offset: int,\n    zero_init: bool | Boolean = False,\n    cta_group: int = 1,\n) -> None:\n    is_ts = False\n    num_k_tile = cute.size(tCrB_layout.shape[2])\n    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)\n    offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)]\n\n    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b)\n    pred_str = \"p\" if isinstance(zero_init, Boolean) else \"0\" if zero_init else \"1\"\n    if const_expr(not is_ts):\n        llvm.inline_asm(\n            None,\n            [\n                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),\n                Int32(not zero_init).ir_value(),\n                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),\n            ],\n            \"{\\n\\t\"\n            \".reg .pred leader_thread;\\n\\t\"\n            \".reg .pred p;\\n\\t\"\n            # \".reg .b32 
idesc;\\n\\t\"\n            \".reg .b32 tmem_acc;\\n\\t\"\n            \".reg .b32 smem_desc_b_lo_start;\\n\\t\"\n            \".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\\n\\t\"\n            \".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\\n\\t\"\n            # \".reg .b64 smem_desc_b;\\n\\t\"\n            f\".reg .b64 smem_desc_b_<{num_k_tile}>;\\n\\t\"\n            \"elect.sync _|leader_thread, -1;\\n\\t\"\n            # f\"mov.b32 idesc, {hex(idesc)};\\n\\t\"\n            # f\"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\\n\\t\"\n            f\"mov.b32 tmem_acc, $2;\\n\\t\"\n            \"mov.b32 smem_desc_b_lo_start, $0;\\n\\t\"\n            f\"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\\n\\t\"\n            f\"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_0;\\n\\t\"\n            f\"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\\n\\t\"\n            f\"mov.b64 {smem_var_name_prefix}_0, {{smem_desc_a_lo, smem_desc_a_hi}};\\n\\t\"\n            f\"mov.b64 smem_desc_b_0, {{smem_desc_b_lo_start, smem_desc_b_hi}};\\n\\t\"\n            + \"\".join(\n                (\n                    f\"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\\n\\t\"\n                    f\"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\\n\\t\"\n                    f\"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\\n\\t\"\n                    f\"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\\n\\t\"\n                    f\"mov.b64 smem_desc_b_{k}, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n                )\n                for k in range(1, num_k_tile)\n            )\n            + \"setp.ne.b32 p, $1, 0;\\n\\t\"\n            # f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\\n\\t\"\n            f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, 
smem_desc_b_0, {idesc_var_name}, {pred_str};\\n\\t\"\n            + \"\".join(\n                (\n                    # f\"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\\n\\t\"\n                    # f\"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\\n\\t\"\n                    # f\"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\\n\\t\"\n                    # f\"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\\n\\t\"\n                    # f\"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\\n\\t\"\n                    # f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\\n\\t\"\n                    # f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\\n\\t\"\n                    f\"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\\n\\t\"\n                )\n                for k in range(1, num_k_tile)\n            )\n            + \"}\\n\",\n            \"r,r,r\",\n            has_side_effects=True,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n"
  },
  {
    "path": "flash_attn/cute/block_info.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\nfrom typing import Tuple, Optional\nfrom dataclasses import dataclass\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Int32, const_expr\n\nfrom flash_attn.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK\n\n\n@dataclass(frozen=True)\nclass BlockInfo:\n    tile_m: cutlass.Constexpr[int]\n    tile_n: cutlass.Constexpr[int]\n    is_causal: cutlass.Constexpr[bool]\n    is_local: cutlass.Constexpr[bool] = False\n    is_split_kv: cutlass.Constexpr[bool] = False\n    window_size_left: Optional[Int32] = None\n    window_size_right: Optional[Int32] = None\n    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1\n\n    @cute.jit\n    def get_n_block_min_max(\n        self,\n        seqlen_info: SeqlenInfoQK,\n        m_block: Int32,\n        split_idx: Int32 = 0,\n        num_splits: Int32 = 1,\n    ) -> Tuple[Int32, Int32]:\n        n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)\n        if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):\n            m_idx_max = (m_block + 1) * self.tile_m\n            if const_expr(self.qhead_per_kvhead_packgqa > 1):\n                m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)\n            n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q\n            n_idx_right = n_idx if const_expr(self.is_causal) else n_idx + self.window_size_right\n            n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.tile_n))\n        n_block_min = 0\n        if const_expr(self.is_local and self.window_size_left is not None):\n            m_idx_min = m_block * self.tile_m\n            if const_expr(self.qhead_per_kvhead_packgqa > 1):\n                m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa\n            n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q\n            n_idx_left = n_idx - 
self.window_size_left\n            n_block_min = cutlass.max(n_idx_left // self.tile_n, 0)\n        if cutlass.const_expr(self.is_split_kv):\n            num_n_blocks_per_split = (\n                Int32(0)\n                if n_block_max <= n_block_min\n                else (n_block_max - n_block_min + num_splits - 1) // num_splits\n            )\n            n_block_min = n_block_min + split_idx * num_n_blocks_per_split\n            n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max)\n        return n_block_min, n_block_max\n\n    @cute.jit\n    def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]:\n        m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m)\n        m_block_min = 0\n        if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)):\n            n_idx_min = n_block * self.tile_n\n            m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k\n            m_idx_right = m_idx if const_expr(self.is_causal) else m_idx - self.window_size_right\n            m_block_min = max(m_block_min, m_idx_right // self.tile_m)\n        if const_expr(self.is_local and self.window_size_left is not None):\n            n_idx_max = (n_block + 1) * self.tile_n\n            m_idx = n_idx_max + seqlen_info.seqlen_q - seqlen_info.seqlen_k\n            m_idx_left = m_idx + self.window_size_left\n            m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m))\n        return m_block_min, m_block_max\n\n    @cute.jit\n    def get_n_block_k_new_min_max(\n        self,\n        seqlen_info: SeqlenInfoQKNewK,\n        m_block: Int32,\n        split_idx: Int32 = 0,\n        num_splits: Int32 = 1,\n    ) -> Tuple[Int32, Int32]:\n        \"\"\"Get the block range for new K tokens (append KV).\n\n        First computes the full n_block range via get_n_block_min_max, then maps\n        those blocks into the new-K index space by subtracting 
seqlen_k_og.\n        \"\"\"\n        n_block_min, n_block_max = self.get_n_block_min_max(\n            seqlen_info,\n            m_block,\n            split_idx,\n            num_splits,\n        )\n        idx_k_new_min = cutlass.max(n_block_min * self.tile_n - seqlen_info.seqlen_k_og, 0)\n        idx_k_new_max = cutlass.min(\n            n_block_max * self.tile_n - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new\n        )\n        n_block_new_min = idx_k_new_min // self.tile_n\n        n_block_new_max = (\n            cute.ceil_div(idx_k_new_max, self.tile_n)\n            if idx_k_new_max > idx_k_new_min\n            else n_block_new_min\n        )\n        return n_block_new_min, n_block_new_max\n\n    @cute.jit\n    def get_n_block_min_causal_local_mask(\n        self,\n        seqlen_info: SeqlenInfoQK,\n        m_block: Int32,\n        n_block_min: Int32,\n    ) -> Int32:\n        \"\"\"Return the n_block at which the separate causal- or local-masked iterations at the start stop.\"\"\"\n        m_idx_min = m_block * self.tile_m\n        if const_expr(self.qhead_per_kvhead_packgqa > 1):\n            m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa\n        n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q\n        n_idx_right = (\n            n_idx\n            if const_expr(not self.is_local or self.window_size_right is None)\n            else n_idx + self.window_size_right\n        )\n        return cutlass.max(n_block_min, n_idx_right // self.tile_n)\n\n    @cute.jit\n    def get_n_block_min_before_local_mask(\n        self,\n        seqlen_info: SeqlenInfoQK,\n        m_block: Int32,\n        n_block_min: Int32,\n    ) -> Int32:\n        \"\"\"Return the n_block at which the non-masked iterations stop when local masking runs as separate iterations at the end.\"\"\"\n        if const_expr(not self.is_local or self.window_size_left is None):\n            return n_block_min\n        else:\n            m_idx_max = (m_block + 1) * self.tile_m\n           
 if const_expr(self.qhead_per_kvhead_packgqa > 1):\n                m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)\n            n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q\n            n_idx_left = n_idx - self.window_size_left\n            return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.tile_n))\n"
  },
  {
    "path": "flash_attn/cute/block_sparse_utils.py",
    "content": "\"\"\"\nBlock-sparse runtime utilities for CUTE DSL kernels.\n\nThis module contains runtime execution functions for block-sparse attention kernels.\nThese utilities are used by CUTE DSL kernels to produce and consume block-sparse loads.\n\"\"\"\n\nfrom typing import Callable, Optional\nfrom functools import partial\nimport math\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Float32, Int32, const_expr\n\nfrom quack import copy_utils\n\n# Import data structures from block_sparsity\nfrom flash_attn.cute.block_sparsity import BlockSparseTensors\nfrom flash_attn.cute.named_barrier import NamedBarrierBwd\n\n\n# NOTE [SM100 block-sparse empty tiles: mbarrier contract]\n#\n# For block-sparse SM100 forward, a given (m_block, stage) Q tile can have zero active\n# KV blocks (total_block_cnt == 0). In that case there is no seqlen_kv iteration, so\n# the softmax warp-group has no row stats to publish.\n#\n# The correction warp-group seeds fully-masked-row stats and runs the usual correction\n# epilogue so output/LSE have well-defined values. 
Both warp-groups must still perform\n# the softmax<->correction mbarrier handshake so phases advance correctly across\n# empty->empty and empty->non-empty tile sequences.\n#\n# In the no-sink case, this corresponds to the usual fully-masked-row convention:\n# output is zero and LSE is -inf.\n#\n# Barrier contract (each is `mbar_ptr + <offset> + stage`):\n#\n# Producer/consumer pairs:\n# - `mbar_softmax_corr_full`    : softmax arrive        -> correction wait\n# - `mbar_softmax_corr_empty`   : correction arrive     -> softmax wait\n# - `mbar_P_full_O_rescaled`    : softmax arrive (+ correction arrive) -> MMA wait\n# - `mbar_P_full_2`             : softmax arrive        -> MMA wait\n# - `mbar_corr_epi_full_/empty` : correction <-> epilogue (only when epilogue is separate)\n#\n# Empty tile (`total_block_cnt == 0`):\n# - Softmax: skips the seqlen_kv softmax path entirely (no P stores, no `mbar_P_full_*`).\n#   It only arrives `mbar_softmax_corr_full` once per stage as a synthetic \"no work\" signal.\n#   At the `softmax_loop` level, softmax unconditionally waits `mbar_softmax_corr_empty`\n#   before each tile (when block-sparse) to drain a prior correction arrival and keep\n#   phases aligned across non-empty -> empty transitions.\n# - Correction: waits `mbar_softmax_corr_full`, seeds stats + runs `correction_epilogue(scale=0)`,\n#   and arrives `mbar_softmax_corr_empty` (and `mbar_corr_epi_full_/empty` when applicable).\n# - No `mbar_P_full_*` barriers are arrived (no P, no MMA O); only the softmax<->correction\n#   (and correction<->epilogue) handshakes advance phases.\n#\n# Non-empty tile:\n# - Softmax: runs `softmax_step` (produces P) and uses `mbar_softmax_corr_full/empty` to\n#   publish row_max (during seqlen_kv) and final row stats (once per tile), and to advance phases;\n#   arrives `mbar_P_full_*` when P is stored.\n# - Correction: waits `mbar_softmax_corr_full`, may rescale/release O, arrives `mbar_softmax_corr_empty`\n#   to ack/advance, and arrives 
`mbar_P_full_O_rescaled` when MMA can proceed.\n#\n# Backward (SM100):\n# - Empty KV tile: for a given `n_block`, `total_m_block_cnt == 0` means no Q tiles contribute.\n# - Both the load and compute loops guard all pipeline work on `process_tile`, so empty tiles\n#   skip producer/consumer operations entirely (no per-tile mbarrier phase handshake like forward).\n# - In the `not dKV_postprocess` path, dK/dV for empty KV tiles are explicitly written as zeros\n#   even when `process_tile == False` (see `flash_bwd_sm100.py` `should_zero_dKV`).\n\n\n@cute.jit\ndef load_block_list(\n    block_indices: cute.Tensor,\n    block_count,\n    first_block_preloaded: cutlass.Constexpr,\n    kv_producer_state,\n    load_K,\n    load_V,\n    pipeline_k,\n    pipeline_v,\n    intra_wg_overlap: cutlass.Constexpr,\n):\n    \"\"\"Iterate over the sparse blocks and load K, V into the pipeline.\n    For the intra_wg_overlap case, we overlap the loads of K and V. And this\n    means we need to pipeline the last V load from the partial block case,\n    with the loads for the full blocks. 
Set first_block_preloaded when the\n    caller has already issued the first K load for the list.\n\n    Q is loaded separately on its own mbarrier before this function is called.\n\n    Note:\n        We iterate along the block_n indices in reverse.\n\n    Returns:\n        Updated kv_producer_state after processing the block list.\n    \"\"\"\n    if block_count > 0:\n        if const_expr(not intra_wg_overlap):\n            for offset in cutlass.range(block_count):\n                n_block = block_indices[block_count - 1 - offset]\n                pipeline_k.producer_acquire(kv_producer_state)\n                load_K(src_idx=n_block, producer_state=kv_producer_state)\n                pipeline_v.producer_acquire(kv_producer_state)\n                load_V(src_idx=n_block, producer_state=kv_producer_state)\n                kv_producer_state.advance()\n        else:\n            n_block_first = block_indices[block_count - 1]\n            if const_expr(not first_block_preloaded):\n                pipeline_k.producer_acquire(kv_producer_state)\n                load_K(src_idx=n_block_first, producer_state=kv_producer_state)\n\n            for idx in cutlass.range(block_count - 1, unroll=1):\n                n_block_prev = block_indices[block_count - 1 - idx]\n                n_block = block_indices[block_count - 2 - idx]\n                kv_producer_state_prev = kv_producer_state.clone()\n                kv_producer_state.advance()\n                # Issue K for the next block before V for the previous block so the two loads overlap.\n                pipeline_k.producer_acquire(kv_producer_state)\n                load_K(src_idx=n_block, producer_state=kv_producer_state)\n                pipeline_v.producer_acquire(kv_producer_state_prev)\n                load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)\n\n    return kv_producer_state\n\n\n@cute.jit\ndef finish_overlap_v_load(\n    block_indices: cute.Tensor,\n    block_count,\n    load_V,\n    pipeline_v,\n    kv_producer_state,\n):\n    \"\"\"Load the final V block after overlapped K/V loads.\"\"\"\n 
   if block_count > 0:\n        n_block_last = block_indices[0]\n        pipeline_v.producer_acquire(kv_producer_state)\n        load_V(src_idx=n_block_last, producer_state=kv_producer_state)\n        kv_producer_state.advance()\n\n    return kv_producer_state\n\n\n@cute.jit\ndef sparse_tensor_m_block(\n    m_block,\n    qhead_per_kvhead: cutlass.Constexpr[int],\n    q_subtile_factor: cutlass.Constexpr[int],\n):\n    \"\"\"Map packed m_block indices to block-sparse tensor indices.\"\"\"\n    block = m_block\n    if const_expr(qhead_per_kvhead != 1):\n        block = block // qhead_per_kvhead\n    if const_expr(q_subtile_factor != 1):\n        block = block // q_subtile_factor\n    return block\n\n\n@cute.jit\ndef produce_block_sparse_loads(\n    blocksparse_tensors: BlockSparseTensors,\n    batch_idx,\n    head_idx,\n    m_block,\n    kv_producer_state,\n    load_K,\n    load_V,\n    pipeline_k,\n    pipeline_v,\n    intra_wg_overlap: cutlass.Constexpr,\n    qhead_per_kvhead: cutlass.Constexpr[int] = 1,\n    q_subtile_factor: cutlass.Constexpr[int] = 1,\n):\n    \"\"\"Iterate over the mask and full block lists for a single tile.\n\n    Q is loaded separately on its own mbarrier before this function is called.\n\n    The masked (partial) list may leave the last V load pending when intra-warp-group\n    overlap is enabled. The first full block must consume that pending V while\n    issuing its own K load on the next pipeline stage.\n\n    In the intra-wg-overlap path, the last masked block leaves its V copy in flight\n    while we advance the producer state to start the next full K. Either the full list\n    overlaps that pending V load, or, if no full blocks exist, we explicitly drain it.\n\n    Args:\n        qhead_per_kvhead: Pack-GQA factor. 
When > 1, m_block is in packed space and\n            must be converted to unpacked for sparse tensor indexing.\n    \"\"\"\n\n    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors\n\n    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)\n\n    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]\n    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]\n\n    if const_expr(full_block_cnt is not None):\n        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]\n        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]\n    else:\n        curr_full_block_cnt = Int32(0)\n        curr_full_block_idx = None\n\n    mask_empty = curr_mask_block_cnt == 0\n    full_empty = curr_full_block_cnt == 0\n\n    if mask_empty:\n        # No masked blocks: the full list owns the initial K load.\n        kv_producer_state = load_block_list(\n            curr_full_block_idx,\n            curr_full_block_cnt,\n            first_block_preloaded=False,\n            kv_producer_state=kv_producer_state,\n            load_K=load_K,\n            load_V=load_V,\n            pipeline_k=pipeline_k,\n            pipeline_v=pipeline_v,\n            intra_wg_overlap=intra_wg_overlap,\n        )\n\n        if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0:\n            kv_producer_state = finish_overlap_v_load(\n                curr_full_block_idx,\n                curr_full_block_cnt,\n                load_V,\n                pipeline_v,\n                kv_producer_state,\n            )\n    else:\n        # Masked blocks present. 
When overlap is disabled this fully drains the list.\n        kv_producer_state = load_block_list(\n            curr_mask_block_idx,\n            curr_mask_block_cnt,\n            first_block_preloaded=False,\n            kv_producer_state=kv_producer_state,\n            load_K=load_K,\n            load_V=load_V,\n            pipeline_k=pipeline_k,\n            pipeline_v=pipeline_v,\n            intra_wg_overlap=intra_wg_overlap,\n        )\n\n        if full_empty:\n            if const_expr(intra_wg_overlap):\n                kv_producer_state = finish_overlap_v_load(\n                    curr_mask_block_idx,\n                    curr_mask_block_cnt,\n                    load_V,\n                    pipeline_v,\n                    kv_producer_state,\n                )\n        else:\n            if const_expr(intra_wg_overlap):\n                # Bridge the masked list to the full list by overlapping the pending masked V\n                # with the first full K load.\n                n_block_mask_last = curr_mask_block_idx[0]\n                n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1]\n                kv_producer_state_prev = kv_producer_state.clone()\n                kv_producer_state.advance()\n                pipeline_k.producer_acquire(kv_producer_state)\n                load_K(src_idx=n_block_full_first, producer_state=kv_producer_state)\n                pipeline_v.producer_acquire(kv_producer_state_prev)\n                load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev)\n\n                kv_producer_state = load_block_list(\n                    curr_full_block_idx,\n                    curr_full_block_cnt,\n                    first_block_preloaded=True,\n                    kv_producer_state=kv_producer_state,\n                    load_K=load_K,\n                    load_V=load_V,\n                    pipeline_k=pipeline_k,\n                    pipeline_v=pipeline_v,\n                    
intra_wg_overlap=intra_wg_overlap,\n                )\n\n                kv_producer_state = finish_overlap_v_load(\n                    curr_full_block_idx,\n                    curr_full_block_cnt,\n                    load_V,\n                    pipeline_v,\n                    kv_producer_state,\n                )\n            else:\n                # Non-overlap path with both lists: run the full list normally.\n                kv_producer_state = load_block_list(\n                    curr_full_block_idx,\n                    curr_full_block_cnt,\n                    first_block_preloaded=False,\n                    kv_producer_state=kv_producer_state,\n                    load_K=load_K,\n                    load_V=load_V,\n                    pipeline_k=pipeline_k,\n                    pipeline_v=pipeline_v,\n                    intra_wg_overlap=intra_wg_overlap,\n                )\n\n    return kv_producer_state\n\n\n@cute.jit\ndef consume_block_sparse_loads(\n    blocksparse_tensors: BlockSparseTensors,\n    batch_idx,\n    head_idx,\n    m_block,\n    seqlen,\n    kv_consumer_state,\n    mma_pv_fn,\n    mma_one_n_block,\n    process_first_half_block,\n    process_last_half_block,\n    mask_fn,\n    score_mod_fn,\n    O_should_accumulate,\n    mask_mod,\n    fastdiv_mods,\n    intra_wg_overlap: cutlass.Constexpr,\n    warp_scheduler_barrier_sync: Callable,\n    warp_scheduler_barrier_arrive: Callable,\n    qhead_per_kvhead: cutlass.Constexpr[int] = 1,\n    q_subtile_factor: cutlass.Constexpr[int] = 1,\n):\n    \"\"\"Consume the mask and full block lists for a single tile on the consumer side.\n\n    Mirrors `produce_block_sparse_loads` so that the consumer pipeline uses\n    the same sparse tensor indexing.\n\n    Args:\n        qhead_per_kvhead: Pack-GQA factor. 
When > 1, m_block is in packed space and\n            must be converted to unpacked for sparse tensor indexing.\n    \"\"\"\n\n    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors\n\n    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)\n\n    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]\n    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]\n    curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]\n    curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]\n\n    processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0\n\n    if const_expr(not intra_wg_overlap):\n        if curr_mask_block_cnt > 0:\n            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]\n            warp_scheduler_barrier_sync()\n            kv_consumer_state = mma_one_n_block(\n                kv_consumer_state,\n                n_block=mask_n_block,\n                mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),\n                mask_fn=partial(\n                    mask_fn,\n                    mask_mod=mask_mod,\n                    mask_seqlen=True,\n                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,\n                ),\n                is_first_n_block=True,\n            )\n            O_should_accumulate = True\n            for i in cutlass.range(1, curr_mask_block_cnt):\n                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]\n                kv_consumer_state = mma_one_n_block(\n                    kv_consumer_state,\n                    n_block=mask_n_block,\n                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),\n                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),\n                    is_first_n_block=False,\n                )\n               
 O_should_accumulate = True\n            if curr_full_block_cnt == 0:\n                warp_scheduler_barrier_arrive()\n\n        if curr_full_block_cnt > 0:\n            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]\n            if curr_mask_block_cnt == 0:\n                warp_scheduler_barrier_sync()\n                kv_consumer_state = mma_one_n_block(\n                    kv_consumer_state,\n                    n_block=full_n_block,\n                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),\n                    mask_fn=partial(mask_fn, mask_seqlen=True),\n                    is_first_n_block=True,\n                )\n                O_should_accumulate = True\n                for i in cutlass.range(1, curr_full_block_cnt):\n                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]\n                    kv_consumer_state = mma_one_n_block(\n                        kv_consumer_state,\n                        n_block=full_n_block,\n                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),\n                        mask_fn=partial(mask_fn, mask_seqlen=False),\n                        is_first_n_block=False,\n                    )\n                    O_should_accumulate = True\n            else:\n                kv_consumer_state = mma_one_n_block(\n                    kv_consumer_state,\n                    n_block=full_n_block,\n                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),\n                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),\n                    is_first_n_block=False,\n                )\n                O_should_accumulate = True\n                for i in cutlass.range(1, curr_full_block_cnt):\n                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]\n                    kv_consumer_state = mma_one_n_block(\n                        kv_consumer_state,\n                       
 n_block=full_n_block,\n                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),\n                        mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),\n                        is_first_n_block=False,\n                    )\n                    O_should_accumulate = True\n            warp_scheduler_barrier_arrive()\n    else:\n        if curr_mask_block_cnt > 0:\n            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]\n            kv_consumer_state = process_first_half_block(\n                n_block=mask_n_block,\n                seqlen=seqlen,\n                kv_consumer_state=kv_consumer_state,\n                mask_fn=partial(\n                    mask_fn,\n                    mask_mod=mask_mod,\n                    mask_seqlen=True,\n                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,\n                ),\n                score_mod_fn=score_mod_fn,\n                is_first_block=True,\n            )\n            for i in cutlass.range(1, curr_mask_block_cnt):\n                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]\n                kv_consumer_state = mma_one_n_block(\n                    kv_consumer_state,\n                    n_block=mask_n_block,\n                    seqlen=seqlen,\n                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),\n                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),\n                )\n                O_should_accumulate = True\n\n        if curr_full_block_cnt > 0:\n            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]\n            if curr_mask_block_cnt == 0:\n                kv_consumer_state = process_first_half_block(\n                    n_block=full_n_block,\n                    seqlen=seqlen,\n                    kv_consumer_state=kv_consumer_state,\n                    mask_fn=partial(mask_fn, mask_mod=None, 
mask_seqlen=True),\n                    score_mod_fn=score_mod_fn,\n                    is_first_block=True,\n                )\n            else:\n                kv_consumer_state = mma_one_n_block(\n                    kv_consumer_state,\n                    n_block=full_n_block,\n                    seqlen=seqlen,\n                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),\n                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),\n                )\n                O_should_accumulate = True\n            for i in cutlass.range(1, curr_full_block_cnt):\n                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]\n                kv_consumer_state = mma_one_n_block(\n                    kv_consumer_state,\n                    n_block=full_n_block,\n                    seqlen=seqlen,\n                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),\n                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),\n                )\n                O_should_accumulate = True\n\n        if curr_mask_block_cnt + curr_full_block_cnt > 0:\n            kv_consumer_state = process_last_half_block(\n                kv_consumer_state=kv_consumer_state,\n                zero_init=not O_should_accumulate,\n            )\n            O_should_accumulate = True\n\n    return kv_consumer_state, O_should_accumulate, processed_any\n\n\n@cute.jit\ndef load_block_list_sm100(\n    block_indices: cute.Tensor,\n    block_count,\n    load_q_with_first: cutlass.Constexpr,\n    q_stage: cutlass.Constexpr,\n    kv_producer_state,\n    load_Q,\n    load_K,\n    load_V,\n    pipeline_kv,\n):\n    \"\"\"SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).\"\"\"\n    if block_count > 0:\n        # First iteration: load Q alongside K if requested\n        n_block_first = block_indices[block_count - 1]\n\n        if const_expr(load_q_with_first):\n            # SM100 
loads Q0 and optionally Q1\n            load_Q(block=0, stage=0)\n            if const_expr(q_stage == 2):\n                load_Q(block=1, stage=1)\n\n        # SM100 doesn't use producer_acquire for pipeline_kv in load path\n        # The pipeline barriers are handled inside load_KV\n        load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None)\n        kv_producer_state.advance()\n        load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None)\n        kv_producer_state.advance()\n\n        # Remaining blocks\n        for offset in cutlass.range(1, block_count):\n            n_block = block_indices[block_count - 1 - offset]\n            load_K(block=n_block, producer_state=kv_producer_state, page_idx=None)\n            kv_producer_state.advance()\n            load_V(block=n_block, producer_state=kv_producer_state, page_idx=None)\n            kv_producer_state.advance()\n\n    return kv_producer_state\n\n\n# SM100-specific tile processor using SM100 helpers\n@cute.jit\ndef produce_block_sparse_loads_sm100(\n    blocksparse_tensors: BlockSparseTensors,\n    batch_idx,\n    head_idx,\n    m_block,\n    kv_producer_state,\n    load_Q,\n    load_K,\n    load_V,\n    pipeline_kv,\n    q_stage: cutlass.Constexpr,\n    q_producer_phase: Int32,\n    qhead_per_kvhead: cutlass.Constexpr,\n    q_subtile_factor: cutlass.Constexpr,\n):\n    \"\"\"SM100 entry point for sparse block iteration.\n\n    SM100 uses PipelineTmaUmma which doesn't support extra_tx_count, so we use\n    simplified block processing that just calls producer_acquire without extras.\n\n    Args:\n        m_block: which tile of m we are processing\n        qhead_per_kvhead: Constexpr pack factor\n    \"\"\"\n    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)\n\n    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors\n\n    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, 
m_block_sparse]\n    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]\n\n    if const_expr(full_block_cnt is not None):\n        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]\n        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]\n    else:\n        curr_full_block_cnt = Int32(0)\n        curr_full_block_idx = None\n\n    mask_empty = curr_mask_block_cnt == 0\n    full_empty = curr_full_block_cnt == 0\n\n    q_phase_flipped = False\n\n    if mask_empty:\n        # No masked blocks: process full list with Q loading\n        kv_producer_state = load_block_list_sm100(\n            curr_full_block_idx,\n            curr_full_block_cnt,\n            load_q_with_first=True,\n            q_stage=q_stage,\n            kv_producer_state=kv_producer_state,\n            load_Q=load_Q,\n            load_K=load_K,\n            load_V=load_V,\n            pipeline_kv=pipeline_kv,\n        )\n        q_phase_flipped = not full_empty\n    else:\n        # Process masked blocks with Q loading\n        kv_producer_state = load_block_list_sm100(\n            curr_mask_block_idx,\n            curr_mask_block_cnt,\n            load_q_with_first=True,\n            q_stage=q_stage,\n            kv_producer_state=kv_producer_state,\n            load_Q=load_Q,\n            load_K=load_K,\n            load_V=load_V,\n            pipeline_kv=pipeline_kv,\n        )\n        q_phase_flipped = True\n\n        if not full_empty:\n            # Process full blocks without Q loading\n            kv_producer_state = load_block_list_sm100(\n                curr_full_block_idx,\n                curr_full_block_cnt,\n                load_q_with_first=False,\n                q_stage=q_stage,\n                kv_producer_state=kv_producer_state,\n                load_Q=load_Q,\n                load_K=load_K,\n                load_V=load_V,\n                pipeline_kv=pipeline_kv,\n            )\n\n    if 
q_phase_flipped:\n        q_producer_phase ^= 1\n\n    return kv_producer_state, q_producer_phase\n\n\n@cute.jit\ndef get_total_block_count(\n    blocksparse_tensors: BlockSparseTensors,\n    batch_idx,\n    head_idx,\n    m_block,\n    qhead_per_kvhead: cutlass.Constexpr,\n    q_subtile_factor: cutlass.Constexpr,\n):\n    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)\n\n    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors\n    if const_expr(full_block_cnt is not None):\n        return (\n            mask_block_cnt[batch_idx, head_idx, m_block_sparse]\n            + full_block_cnt[batch_idx, head_idx, m_block_sparse]\n        )\n    else:\n        return mask_block_cnt[batch_idx, head_idx, m_block_sparse]\n\n\n@cute.jit\ndef handle_block_sparse_empty_tile_correction_sm100(\n    tidx: Int32,\n    q_stage: cutlass.Constexpr,\n    m_block_size: cutlass.Constexpr,\n    qhead_per_kvhead,\n    pack_gqa: cutlass.Constexpr,\n    is_split_kv: cutlass.Constexpr,\n    learnable_sink,\n    mLSE,\n    seqlen,\n    m_block: Int32,\n    head_idx: Int32,\n    batch_idx: Int32,\n    split_idx: Int32,\n    sScale: cute.Tensor,\n    stats: list,\n    correction_epilogue: Callable,\n    thr_mma_pv: cute.core.ThrMma,\n    tOtO: cute.Tensor,\n    sO: cute.Tensor,\n    pipeline_sm_stats: cutlass.pipeline.PipelineAsync,\n    sm_stats_barrier: cutlass.pipeline.NamedBarrier,\n    pipeline_o_epi: cutlass.pipeline.PipelineAsync,\n    sm_stats_consumer_phase: Int32,\n    o_corr_consumer_phase: Int32,\n    corr_epi_producer_phase: Int32,\n    softmax_scale_log2: Float32,\n    mO_cur: Optional[cute.Tensor] = None,\n    gO: Optional[cute.Tensor] = None,\n    gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,\n):\n    \"\"\"Handle SM100 forward block-sparse tiles with no active KV blocks.\n\n    This path is taken when `total_block_cnt == 0`. 
The softmax warp-group still\n    arrives `mbar_softmax_corr_full` (synthetic \"no work\") so the correction\n    warp-group can:\n\n    - seed fully-masked-row stats (row_sum=1; row_max=-inf when tracked) for LSE\n    - run `correction_epilogue` with `scale=0` so the output tile is written as zeros\n      (independent of any prior tmem contents)\n    - wait on `mbar_softmax_corr_full` and arrive `mbar_softmax_corr_empty`\n      (and `mbar_corr_epi_*` when applicable) so phases stay aligned across tiles\n\n    This helper intentionally does not touch `mbar_P_full_*` since no P is produced.\n    See NOTE [SM100 block-sparse empty tiles: mbarrier contract].\n    \"\"\"\n    LOG2_E = Float32(math.log2(math.e))\n    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4\n\n    for stage in cutlass.range_constexpr(q_stage):\n        row_sum_value = Float32(1.0)\n        row_max_value = (\n            -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None\n        )\n        if const_expr(learnable_sink is not None):\n            sink_val = -Float32.inf\n            if const_expr(not pack_gqa):\n                sink_val = Float32(learnable_sink[head_idx])\n            elif tidx < m_block_size:\n                q_head_idx = (\n                    (q_stage * m_block + stage) * m_block_size + tidx\n                ) % qhead_per_kvhead + head_idx * qhead_per_kvhead\n                sink_val = Float32(learnable_sink[q_head_idx])\n            if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0):\n                if row_max_value == -Float32.inf:\n                    row_max_value = sink_val * (LOG2_E / softmax_scale_log2)\n                    row_sum_value = Float32(1.0)\n                else:\n                    row_sum_value = row_sum_value + cute.math.exp2(\n                        sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True\n                    )\n        if tidx < m_block_size:\n 
           scale_row_idx = tidx + stage * m_block_size\n            sScale[scale_row_idx] = row_sum_value\n            if const_expr(mLSE is not None or learnable_sink is not None):\n                sScale[scale_row_idx + q_stage * m_block_size] = row_max_value\n        acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value\n        stats[stage] = (row_sum_value, row_max_value, acc_flag)\n\n        # See NOTE [SM100 block-sparse empty tiles: mbarrier contract].\n        # pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase)\n        sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx)\n        pipeline_sm_stats.consumer_release_w_index(stage)\n\n        if const_expr(gmem_tiled_copy_O is None):\n            pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase)\n        correction_epilogue(\n            thr_mma_pv,\n            tOtO[None, None, None, stage],\n            tidx,\n            stage,\n            m_block,\n            seqlen.seqlen_q,\n            Float32(0.0),  # zero scale ensures empty tile writes zeros into staged outputs\n            sO[None, None, stage],\n            mO_cur,\n            gO[None, None, stage],\n            gmem_tiled_copy_O,\n        )\n        if const_expr(gmem_tiled_copy_O is None):\n            pipeline_o_epi.producer_commit_w_index(stage)\n\n    sm_stats_consumer_phase ^= 1\n    corr_epi_producer_phase ^= 1\n\n    return (\n        sm_stats_consumer_phase,\n        o_corr_consumer_phase,\n        corr_epi_producer_phase,\n    )\n\n\n@cute.jit\ndef softmax_block_sparse_sm100(\n    blocksparse_tensors: BlockSparseTensors,\n    batch_idx,\n    head_idx,\n    m_block,\n    softmax_step: Callable,\n    mask_fn: Callable,\n    mask_fn_none: Callable,\n    mma_si_consumer_phase: Int32,\n    si_corr_producer_phase: Int32,\n    s0_s1_sequence_phase: Int32,\n    pipeline_sm_stats: cutlass.pipeline.PipelineAsync,\n    sm_stats_barrier: 
cutlass.pipeline.NamedBarrier,\n    q_stage: cutlass.Constexpr,\n    stage_idx: Int32,\n    check_m_boundary: bool,\n    qhead_per_kvhead: cutlass.Constexpr,\n    q_subtile_factor: cutlass.Constexpr[int] = 1,\n):\n    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4\n    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)\n\n    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors\n\n    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]\n    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]\n\n    if const_expr(full_block_cnt is not None):\n        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]\n        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]\n    else:\n        curr_full_block_cnt = Int32(0)\n        curr_full_block_idx = None\n\n    total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt\n\n    if total_block_cnt == 0:\n        # See NOTE [SM100 block-sparse empty tiles: mbarrier contract].\n        # pipeline_sm_stats.producer_commit_w_index(stage_idx)\n        sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx)\n    else:\n        if curr_mask_block_cnt > 0:\n            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]\n            (\n                mma_si_consumer_phase,\n                si_corr_producer_phase,\n                s0_s1_sequence_phase,\n            ) = softmax_step(\n                mma_si_consumer_phase,\n                si_corr_producer_phase,\n                s0_s1_sequence_phase,\n                mask_n_block,\n                is_first=True,\n                mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary),\n            )\n            for i in cutlass.range(1, curr_mask_block_cnt):\n                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]\n                (\n         
           mma_si_consumer_phase,\n                    si_corr_producer_phase,\n                    s0_s1_sequence_phase,\n                ) = softmax_step(\n                    mma_si_consumer_phase,\n                    si_corr_producer_phase,\n                    s0_s1_sequence_phase,\n                    mask_n_block,\n                    mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary),\n                )\n\n        if curr_full_block_cnt > 0:\n            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]\n            if curr_mask_block_cnt == 0:\n                (\n                    mma_si_consumer_phase,\n                    si_corr_producer_phase,\n                    s0_s1_sequence_phase,\n                ) = softmax_step(\n                    mma_si_consumer_phase,\n                    si_corr_producer_phase,\n                    s0_s1_sequence_phase,\n                    full_n_block,\n                    is_first=True,\n                    mask_fn=partial(\n                        mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary\n                    ),\n                )\n            else:\n                (\n                    mma_si_consumer_phase,\n                    si_corr_producer_phase,\n                    s0_s1_sequence_phase,\n                ) = softmax_step(\n                    mma_si_consumer_phase,\n                    si_corr_producer_phase,\n                    s0_s1_sequence_phase,\n                    full_n_block,\n                    is_first=False,\n                    mask_fn=partial(\n                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary\n                    ),\n                )\n            for i in cutlass.range(1, curr_full_block_cnt):\n                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]\n                (\n                    mma_si_consumer_phase,\n                    si_corr_producer_phase,\n        
            s0_s1_sequence_phase,\n                ) = softmax_step(\n                    mma_si_consumer_phase,\n                    si_corr_producer_phase,\n                    s0_s1_sequence_phase,\n                    full_n_block,\n                    mask_fn=partial(\n                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary\n                    ),\n                )\n\n    return (\n        mma_si_consumer_phase,\n        si_corr_producer_phase,\n        s0_s1_sequence_phase,\n        total_block_cnt == 0,\n    )\n\n\n# =============================================================================\n# Backward-specific block-sparse helpers (SM100)\n# =============================================================================\n#\n# In backward, iteration is transposed compared to forward:\n# - Forward: outer loop over m_blocks (Q tiles), inner loop over n_blocks (KV tiles)\n# - Backward: outer loop over n_blocks (KV tiles), inner loop over m_blocks (Q tiles)\n#\n# The backward block-sparse tensors use \"Q direction\" indexing:\n# - q_block_cnt[batch, head, n_block] → count of m_blocks to process for this KV tile\n# - q_block_idx[batch, head, n_block, :] → indices of m_blocks to process\n#\n\n\n@cute.jit\ndef get_total_q_block_count_bwd(\n    blocksparse_tensors: BlockSparseTensors,\n    batch_idx,\n    head_idx,\n    n_block,\n    subtile_factor: cutlass.Constexpr = 1,\n    m_block_max: int = 0,\n):\n    \"\"\"Count total tile iterations for given n_block (KV tile) in backward.\"\"\"\n    q_block_cnt, _, full_block_cnt, _ = blocksparse_tensors\n    total = q_block_cnt[batch_idx, head_idx, n_block]\n    if const_expr(full_block_cnt is not None):\n        total = total + full_block_cnt[batch_idx, head_idx, n_block]\n    return total * subtile_factor\n\n\n@cute.jit\ndef produce_block_sparse_q_loads_bwd_sm100(\n    blocksparse_tensors: BlockSparseTensors,\n    batch_idx,\n    head_idx,\n    n_block,\n    # Pipeline states (will be 
returned after advancing)\n    producer_state_Q_LSE,\n    producer_state_dO_dPsum,\n    # Pipelines\n    pipeline_Q,\n    pipeline_LSE,\n    pipeline_dO,\n    pipeline_dPsum,\n    # Load functions\n    load_K,\n    load_V,\n    load_Q,\n    load_dO,\n    copy_stats,\n    # Global tensors for LSE/dPsum\n    gLSE,\n    sLSE,\n    gdPsum,\n    sdPsum,\n    # TMA copy bytes for extra_tx_count\n    tma_copy_bytes_K,\n    tma_copy_bytes_V,\n    # Flags for which loads to perform\n    should_load_Q: cutlass.Constexpr,\n    should_load_dO: cutlass.Constexpr,\n    # Subtiling factor and bounds\n    subtile_factor: cutlass.Constexpr = 1,\n    m_block_max: int = 0,\n):\n    \"\"\"SM100 backward block sparse loading with subtiling.\n\n    Returns updated (producer_state_Q_LSE, producer_state_dO_dPsum).\n    First iteration loads K/V alongside Q/dO; subsequent iterations load only Q/dO.\n    \"\"\"\n    (\n        curr_q_cnt,\n        curr_q_idx,\n        curr_full_cnt,\n        curr_full_idx,\n        loop_count,\n    ) = get_block_sparse_iteration_info_bwd(\n        blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max\n    )\n\n    for iter_idx in cutlass.range(loop_count, unroll=1):\n        m_block, _ = get_m_block_from_iter_bwd(\n            iter_idx,\n            curr_q_cnt,\n            curr_q_idx,\n            curr_full_cnt,\n            curr_full_idx,\n            subtile_factor,\n            m_block_max,\n        )\n        m_block_safe = m_block\n        if m_block_max > 0:\n            m_block_safe = cutlass.min(m_block, m_block_max - 1)\n\n        if iter_idx == 0:\n            # First block: load K/V alongside Q/dO\n            if const_expr(should_load_Q):\n                pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K)\n                load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))\n                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)\n                
pipeline_Q.producer_commit(producer_state_Q_LSE)\n                pipeline_LSE.producer_acquire(producer_state_Q_LSE)\n                with cute.arch.elect_one():\n                    copy_stats(\n                        gLSE[None, m_block_safe],\n                        sLSE[None, producer_state_Q_LSE.index],\n                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),\n                    )\n                producer_state_Q_LSE.advance()\n            if const_expr(should_load_dO):\n                pipeline_dO.producer_acquire(\n                    producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V\n                )\n                load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum))\n                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)\n                pipeline_dO.producer_commit(producer_state_dO_dPsum)\n                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)\n                with cute.arch.elect_one():\n                    copy_stats(\n                        gdPsum[None, m_block_safe],\n                        sdPsum[None, producer_state_dO_dPsum.index],\n                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),\n                    )\n                producer_state_dO_dPsum.advance()\n        else:\n            # Subsequent blocks: just load Q/dO (K/V already loaded)\n            if const_expr(should_load_Q):\n                pipeline_Q.producer_acquire(producer_state_Q_LSE)\n                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)\n                pipeline_Q.producer_commit(producer_state_Q_LSE)\n                pipeline_LSE.producer_acquire(producer_state_Q_LSE)\n                with cute.arch.elect_one():\n                    copy_stats(\n                        gLSE[None, m_block_safe],\n                        sLSE[None, producer_state_Q_LSE.index],\n                        
mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),\n                    )\n                producer_state_Q_LSE.advance()\n            if const_expr(should_load_dO):\n                pipeline_dO.producer_acquire(producer_state_dO_dPsum)\n                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)\n                pipeline_dO.producer_commit(producer_state_dO_dPsum)\n                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)\n                with cute.arch.elect_one():\n                    copy_stats(\n                        gdPsum[None, m_block_safe],\n                        sdPsum[None, producer_state_dO_dPsum.index],\n                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),\n                    )\n                producer_state_dO_dPsum.advance()\n\n    return producer_state_Q_LSE, producer_state_dO_dPsum\n\n\n@cute.jit\ndef get_block_sparse_iteration_info_bwd(\n    blocksparse_tensors: BlockSparseTensors,\n    batch_idx,\n    head_idx,\n    n_block,\n    subtile_factor: cutlass.Constexpr = 1,\n    m_block_max: int = 0,\n):\n    \"\"\"Extract block-sparse iteration info for backward pass.\n\n    Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count).\n    \"\"\"\n    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors\n    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]\n    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]\n\n    if const_expr(full_cnt is not None):\n        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]\n        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]\n    else:\n        curr_full_cnt = Int32(0)\n        curr_full_idx = None\n\n    sparse_block_count = curr_q_cnt\n    if const_expr(full_cnt is not None):\n        sparse_block_count = sparse_block_count + curr_full_cnt\n    total_count = sparse_block_count * subtile_factor\n\n    return curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, 
total_count\n\n\n@cute.jit\ndef get_m_block_from_iter_bwd(\n    iter_idx,\n    curr_q_cnt,\n    curr_q_idx: cute.Tensor,\n    curr_full_cnt,\n    curr_full_idx: Optional[cute.Tensor],\n    subtile_factor: cutlass.Constexpr = 1,\n    m_block_max: int = 0,\n):\n    \"\"\"Derive m_block index and is_full_block flag from iteration index.\n\n    Returns (m_block, is_full_block):\n        - m_block: The actual Q-tile block index\n        - is_full_block: True if this is a full block (no mask_mod needed)\n    \"\"\"\n    sparse_iter_idx = iter_idx // subtile_factor\n    subtile_offset = iter_idx % subtile_factor\n\n    sparse_m_block = Int32(0)\n    is_full_block = False\n    if const_expr(curr_full_idx is not None):\n        if sparse_iter_idx < curr_q_cnt:\n            sparse_m_block = curr_q_idx[sparse_iter_idx]\n        else:\n            sparse_m_block = curr_full_idx[sparse_iter_idx - curr_q_cnt]\n            is_full_block = True\n    else:\n        sparse_m_block = curr_q_idx[sparse_iter_idx]\n\n    return sparse_m_block * subtile_factor + subtile_offset, is_full_block\n\n\n@cute.jit\ndef _load_q_do_block_sm90(\n    m_block,\n    producer_state_Q,\n    producer_state_dO,\n    pipeline_Q,\n    pipeline_dO,\n    load_K,\n    load_V,\n    load_Q,\n    load_dO,\n    load_LSE,\n    load_dPsum,\n    tma_copy_bytes_K,\n    tma_copy_bytes_V,\n    Q_stage_eq_dO_stage: cutlass.Constexpr,\n    load_kv: bool,\n):\n    \"\"\"Load one Q/dO block, optionally loading K/V on first iteration.\"\"\"\n    if load_kv:\n        pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=tma_copy_bytes_K)\n        load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))\n    else:\n        pipeline_Q.producer_acquire(producer_state_Q)\n    load_Q(m_block, producer_state=producer_state_Q)\n    load_LSE(m_block, producer_state=producer_state_Q)\n\n    producer_state_dO_cur = (\n        producer_state_dO if const_expr(not Q_stage_eq_dO_stage) else producer_state_Q\n    )\n    
if load_kv:\n        pipeline_dO.producer_acquire(producer_state_dO_cur, extra_tx_count=tma_copy_bytes_V)\n        load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))\n    else:\n        pipeline_dO.producer_acquire(producer_state_dO_cur)\n    load_dO(m_block, producer_state=producer_state_dO_cur)\n    load_dPsum(m_block, producer_state=producer_state_dO_cur)\n\n    producer_state_Q.advance()\n    producer_state_dO.advance()\n    return producer_state_Q, producer_state_dO\n\n\n@cute.jit\ndef produce_block_sparse_q_loads_bwd_sm90(\n    blocksparse_tensors: BlockSparseTensors,\n    batch_idx,\n    head_idx,\n    n_block,\n    producer_state_Q,\n    producer_state_dO,\n    pipeline_Q,\n    pipeline_dO,\n    load_K,\n    load_V,\n    load_Q,\n    load_dO,\n    load_LSE,\n    load_dPsum,\n    tma_copy_bytes_K,\n    tma_copy_bytes_V,\n    Q_stage_eq_dO_stage: cutlass.Constexpr,\n    subtile_factor: cutlass.Constexpr,\n    m_block_max: int,\n):\n    \"\"\"SM90 backward block sparse loading with separate partial/full loops.\n\n    K/V are loaded with the first valid block. 
Iterates partial blocks first,\n    then full blocks, matching consumer order.\n\n    Returns updated (producer_state_Q, producer_state_dO).\n    \"\"\"\n    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors\n    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]\n    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]\n\n    if const_expr(full_cnt is not None):\n        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]\n        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]\n    else:\n        curr_full_cnt = Int32(0)\n        curr_full_idx = None\n\n    kv_loaded = False\n\n    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):\n        sparse_idx = iter_idx // subtile_factor\n        subtile_offset = iter_idx % subtile_factor\n        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset\n\n        if m_block < m_block_max:\n            producer_state_Q, producer_state_dO = _load_q_do_block_sm90(\n                m_block,\n                producer_state_Q,\n                producer_state_dO,\n                pipeline_Q,\n                pipeline_dO,\n                load_K,\n                load_V,\n                load_Q,\n                load_dO,\n                load_LSE,\n                load_dPsum,\n                tma_copy_bytes_K,\n                tma_copy_bytes_V,\n                Q_stage_eq_dO_stage,\n                load_kv=not kv_loaded,\n            )\n            kv_loaded = True\n\n    if const_expr(full_cnt is not None):\n        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):\n            sparse_idx = iter_idx // subtile_factor\n            subtile_offset = iter_idx % subtile_factor\n            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset\n\n            if m_block < m_block_max:\n                producer_state_Q, producer_state_dO = _load_q_do_block_sm90(\n                    m_block,\n                    producer_state_Q,\n                 
   producer_state_dO,\n                    pipeline_Q,\n                    pipeline_dO,\n                    load_K,\n                    load_V,\n                    load_Q,\n                    load_dO,\n                    load_LSE,\n                    load_dPsum,\n                    tma_copy_bytes_K,\n                    tma_copy_bytes_V,\n                    Q_stage_eq_dO_stage,\n                    load_kv=not kv_loaded,\n                )\n                kv_loaded = True\n\n    return producer_state_Q, producer_state_dO\n\n\n@cute.jit\ndef consume_block_sparse_mma_bwd_sm90(\n    blocksparse_tensors: BlockSparseTensors,\n    batch_idx,\n    head_idx,\n    n_block,\n    consumer_state_Q,\n    consumer_state_dO,\n    mma_one_m_block_fn,\n    mask,\n    mask_mod,\n    is_causal: cutlass.Constexpr,\n    is_local: cutlass.Constexpr,\n    thr_mma_SdP,\n    score_mod_fn=None,\n    score_mod_bwd_fn=None,\n    subtile_factor: cutlass.Constexpr = 1,\n    m_block_max: int = 0,\n    aux_tensors=None,\n    fastdiv_mods=(None, None),\n):\n    \"\"\"SM90 backward block sparse MMA consumption with separate partial/full loops.\n\n    Partial blocks are processed first (with mask_mod applied), then full blocks\n    (without mask_mod). 
This ensures mask_mod is only applied where needed.\n\n    Returns updated (consumer_state_Q, consumer_state_dO).\n    \"\"\"\n    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors\n    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]\n    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]\n\n    if const_expr(full_cnt is not None):\n        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]\n        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]\n    else:\n        curr_full_cnt = Int32(0)\n        curr_full_idx = None\n\n    dKV_accumulate = False\n\n    mask_fn_partial = partial(\n        mask.apply_mask,\n        batch_idx=batch_idx,\n        head_idx=head_idx,\n        n_block=n_block,\n        thr_mma=thr_mma_SdP,\n        mask_seqlen=True,\n        mask_causal=is_causal,\n        mask_local=is_local,\n        mask_mod=mask_mod,\n        aux_tensors=aux_tensors,\n        fastdiv_mods=fastdiv_mods,\n    )\n\n    mask_fn_full = partial(\n        mask.apply_mask,\n        batch_idx=batch_idx,\n        head_idx=head_idx,\n        n_block=n_block,\n        thr_mma=thr_mma_SdP,\n        mask_seqlen=True,\n        mask_causal=is_causal,\n        mask_local=is_local,\n        aux_tensors=aux_tensors,\n        fastdiv_mods=fastdiv_mods,\n    )\n\n    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):\n        sparse_idx = iter_idx // subtile_factor\n        subtile_offset = iter_idx % subtile_factor\n        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset\n\n        if m_block < m_block_max:\n            consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(\n                m_block,\n                consumer_state_Q,\n                consumer_state_dO,\n                mask_fn=mask_fn_partial,\n                score_mod_fn=score_mod_fn,\n                score_mod_bwd_fn=score_mod_bwd_fn,\n                dKV_accumulate=dKV_accumulate,\n            )\n            dKV_accumulate = True\n\n    if 
const_expr(full_cnt is not None):\n        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):\n            sparse_idx = iter_idx // subtile_factor\n            subtile_offset = iter_idx % subtile_factor\n            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset\n\n            if m_block < m_block_max:\n                consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(\n                    m_block,\n                    consumer_state_Q,\n                    consumer_state_dO,\n                    mask_fn=mask_fn_full,\n                    score_mod_fn=score_mod_fn,\n                    score_mod_bwd_fn=score_mod_bwd_fn,\n                    dKV_accumulate=dKV_accumulate,\n                )\n                dKV_accumulate = True\n\n    return consumer_state_Q, consumer_state_dO\n\n\n@cute.jit\ndef _store_one_dQaccum_sm90(\n    m_block,\n    sdQaccum: cute.Tensor,\n    gdQaccum: cute.Tensor,\n    num_mma_warp_groups: cutlass.Constexpr,\n    num_threads_per_warp_group: cutlass.Constexpr,\n    tma_copy_bytes_dQ,\n):\n    \"\"\"Store dQaccum for a single m_block.\"\"\"\n    for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):\n        cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True)\n        cute.arch.barrier_arrive(\n            barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,\n            number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,\n        )\n    for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):\n        cute.arch.barrier(\n            barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,\n            number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,\n        )\n        with cute.arch.elect_one():\n            copy_utils.cpasync_reduce_bulk_add_f32(\n                sdQaccum[None, warp_group_idx].iterator,\n                gdQaccum[(None, warp_group_idx), m_block].iterator,\n            
    tma_copy_bytes_dQ,\n            )\n        cute.arch.cp_async_bulk_commit_group()\n\n\n@cute.jit\ndef dQaccum_store_block_sparse_bwd_sm90(\n    blocksparse_tensors: BlockSparseTensors,\n    batch_idx,\n    head_idx,\n    n_block,\n    sdQaccum: cute.Tensor,\n    gdQaccum: cute.Tensor,\n    subtile_factor: cutlass.Constexpr,\n    m_block_max: int,\n    num_mma_warp_groups: cutlass.Constexpr,\n    num_threads_per_warp_group: cutlass.Constexpr,\n    tma_copy_bytes_dQ,\n):\n    \"\"\"SM90 backward block sparse dQaccum store with separate partial/full loops.\n\n    Iterates partial blocks first, then full blocks, matching producer/consumer order.\n    \"\"\"\n    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors\n    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]\n    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]\n\n    if const_expr(full_cnt is not None):\n        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]\n        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]\n    else:\n        curr_full_cnt = Int32(0)\n        curr_full_idx = None\n\n    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):\n        sparse_idx = iter_idx // subtile_factor\n        subtile_offset = iter_idx % subtile_factor\n        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset\n\n        if m_block < m_block_max:\n            _store_one_dQaccum_sm90(\n                m_block,\n                sdQaccum,\n                gdQaccum,\n                num_mma_warp_groups,\n                num_threads_per_warp_group,\n                tma_copy_bytes_dQ,\n            )\n\n    if const_expr(full_cnt is not None):\n        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):\n            sparse_idx = iter_idx // subtile_factor\n            subtile_offset = iter_idx % subtile_factor\n            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset\n\n            if m_block < m_block_max:\n    
            _store_one_dQaccum_sm90(\n                    m_block,\n                    sdQaccum,\n                    gdQaccum,\n                    num_mma_warp_groups,\n                    num_threads_per_warp_group,\n                    tma_copy_bytes_dQ,\n                )\n"
  },
  {
    "path": "flash_attn/cute/block_sparsity.py",
    "content": "\"\"\"\nBlock-sparsity utilities for FlexAttention\n\"\"\"\n\nfrom typing import Callable, NamedTuple, Tuple\n\nimport cutlass.cute as cute\nimport torch\n\nfrom flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor\n\n\ndef ceildiv(a: int, b: int) -> int:\n    return (a + b - 1) // b\n\n\nclass BlockSparseTensors(NamedTuple):\n    mask_block_cnt: cute.Tensor\n    mask_block_idx: cute.Tensor\n    full_block_cnt: cute.Tensor | None\n    full_block_idx: cute.Tensor | None\n\n    def __new_from_mlir_values__(self, values):\n        if len(values) == 2:\n            values = (*values, None, None)\n        return BlockSparseTensors(*values)\n\n\nclass BlockSparseTensorsTorch(NamedTuple):\n    mask_block_cnt: torch.Tensor\n    mask_block_idx: torch.Tensor\n    full_block_cnt: torch.Tensor | None = None\n    full_block_idx: torch.Tensor | None = None\n    block_size: tuple[int, int] | None = None\n\n\ndef _expand_sparsity_tensor(\n    tensor: torch.Tensor,\n    expected_shape: Tuple[int, ...],\n    tensor_name: str,\n    context: str | None,\n    hint: str | Callable[[], str] | None,\n) -> torch.Tensor:\n    \"\"\"Check if we need to expand the tensor to expected shape, and do so if possible.\"\"\"\n    needs_expand = tensor.shape != expected_shape\n    if not needs_expand:\n        return tensor\n    can_expand = all(map(lambda cur, tgt: cur == tgt or cur == 1, tensor.shape, expected_shape))\n    if not can_expand:\n        context_clause = f\" ({context})\" if context else \"\"\n        resolved_hint = hint() if callable(hint) else hint\n        hint_clause = f\" Hint: {resolved_hint}\" if resolved_hint else \"\"\n        raise ValueError(\n            f\"{tensor_name}{context_clause} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}.\"\n            f\"{hint_clause}\"\n        )\n    return tensor.expand(*expected_shape)\n\n\ndef _check_and_expand_block(\n    name: str,\n    cnt: torch.Tensor | None,\n    
idx: torch.Tensor | None,\n    expected_count_shape: Tuple[int, int, int],\n    expected_index_shape: Tuple[int, int, int, int],\n    context: str | None,\n    hint: str | Callable[[], str] | None,\n) -> Tuple[torch.Tensor | None, torch.Tensor | None]:\n    if (cnt is None) != (idx is None):\n        raise ValueError(\n            f\"{name}_block_cnt and {name}_block_idx must both be provided or both be None\"\n        )\n    if cnt is None or idx is None:\n        return None, None\n    if cnt.dtype != torch.int32 or idx.dtype != torch.int32:\n        raise ValueError(f\"{name}_block tensors must have dtype torch.int32\")\n    if cnt.device != idx.device:\n        raise ValueError(f\"{name}_block_cnt and {name}_block_idx must be on the same device\")\n    if not cnt.is_cuda or not idx.is_cuda:\n        raise ValueError(f\"{name}_block tensors must live on CUDA\")\n    expanded_cnt = _expand_sparsity_tensor(\n        cnt, expected_count_shape, f\"{name}_block_cnt\", context, hint\n    )\n    expanded_idx = _expand_sparsity_tensor(\n        idx, expected_index_shape, f\"{name}_block_idx\", context, hint\n    )\n    return expanded_cnt, expanded_idx\n\n\ndef get_block_sparse_expected_shapes(\n    batch_size: int,\n    num_head: int,\n    seqlen_q: int,\n    seqlen_k: int,\n    m_block_size: int,\n    n_block_size: int,\n    q_stage: int,\n) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:\n    \"\"\"Return (expected_count_shape, expected_index_shape) for block sparse normalization.\"\"\"\n    m_block_size_effective = q_stage * m_block_size\n    expected_m_blocks = ceildiv(seqlen_q, m_block_size_effective)\n    expected_n_blocks = ceildiv(seqlen_k, n_block_size)\n    expected_count_shape = (batch_size, num_head, expected_m_blocks)\n    expected_index_shape = (batch_size, num_head, expected_m_blocks, expected_n_blocks)\n    return expected_count_shape, expected_index_shape\n\n\ndef infer_block_sparse_expected_shapes(\n    tensors: BlockSparseTensorsTorch,\n   
 *,\n    batch_size: int,\n    num_head: int,\n    seqlen_q: int,\n    seqlen_k: int,\n    m_block_size: int,\n    n_block_size: int,\n    q_stage: int,\n    context: str,\n    sparse_block_size_q: int | None = None,\n    sparse_block_size_kv: int | None = None,\n) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int], int]:\n    \"\"\"Infer shapes and scaling for block-sparse tensors.\n\n    Expectations:\n    - mask_block_cnt is (B, H, M) and mask_block_idx is (B, H, M, N).\n    - Batch/head dims may be 1 for broadcast, or match the requested sizes.\n    - sparse_block_size_kv must match tile_n.\n    - sparse_block_size_q must be a multiple of q_stage * tile_m.\n    - If sparse_block_size_q is omitted and seqlen_q/num_m_blocks is ambiguous,\n      the caller must provide block_size to disambiguate. TODO will make this required in a future PR.\n    \"\"\"\n    base_m_block = q_stage * m_block_size\n    base_n_block = n_block_size\n    if sparse_block_size_kv is None:\n        sparse_block_size_kv = base_n_block\n    if sparse_block_size_kv != base_n_block:\n        raise ValueError(f\"Block sparse tensors{context} require BLOCK_SIZE_KV={base_n_block}.\")\n    if tensors.mask_block_idx is None:\n        raise ValueError(\"mask_block_cnt and mask_block_idx must be provided for block sparsity.\")\n    num_m_blocks = tensors.mask_block_idx.shape[2]\n\n    if sparse_block_size_q is None:\n        min_block_size = ceildiv(seqlen_q, num_m_blocks)\n        if num_m_blocks == 1:\n            max_block_size = seqlen_q\n        else:\n            max_block_size = (seqlen_q - 1) // (num_m_blocks - 1)\n        if max_block_size != min_block_size and base_m_block != 1:\n            raise ValueError(\n                f\"Block sparse tensors{context} require explicit sparse_block_size[0] \"\n                f\"to disambiguate block size for seqlen_q={seqlen_q} and num_m_blocks={num_m_blocks}.\"\n            )\n        sparse_block_size_q = min_block_size\n\n    if 
sparse_block_size_q % base_m_block != 0:\n        raise ValueError(\n            f\"Block sparse tensors{context} have block size {sparse_block_size_q}, \"\n            f\"which must be a multiple of {base_m_block}.\"\n        )\n\n    expected_m_blocks = ceildiv(seqlen_q, sparse_block_size_q)\n    expected_n_blocks = ceildiv(seqlen_k, sparse_block_size_kv)\n    q_subtile_factor = sparse_block_size_q // base_m_block\n    expected_count_shape = (batch_size, num_head, expected_m_blocks)\n    expected_index_shape = (batch_size, num_head, expected_m_blocks, expected_n_blocks)\n\n    mask_block_cnt = tensors.mask_block_cnt\n    mask_block_idx = tensors.mask_block_idx\n    if mask_block_cnt is None or mask_block_idx is None:\n        raise ValueError(\"mask_block_cnt and mask_block_idx must be provided for block sparsity.\")\n    if mask_block_cnt.ndim != 3 or mask_block_idx.ndim != 4:\n        raise ValueError(\n            f\"Block sparse tensors{context} must have shapes (B, H, M) and (B, H, M, N).\"\n        )\n    for dim_name, cur, tgt in (\n        (\"batch\", mask_block_cnt.shape[0], expected_count_shape[0]),\n        (\"head\", mask_block_cnt.shape[1], expected_count_shape[1]),\n    ):\n        if cur != tgt and cur != 1:\n            raise ValueError(f\"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.\")\n    for dim_name, cur, tgt in (\n        (\"batch\", mask_block_idx.shape[0], expected_index_shape[0]),\n        (\"head\", mask_block_idx.shape[1], expected_index_shape[1]),\n    ):\n        if cur != tgt and cur != 1:\n            raise ValueError(f\"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.\")\n    if mask_block_cnt.shape[2] != mask_block_idx.shape[2]:\n        raise ValueError(f\"Block sparse tensors{context} must share the same m-block dimension.\")\n    if mask_block_idx.shape[3] != expected_n_blocks:\n        raise ValueError(\n            f\"Block sparse tensors{context} n-block dimension must be 
{expected_n_blocks}.\"\n        )\n    if expected_m_blocks != num_m_blocks:\n        raise ValueError(\n            f\"Block sparse tensors{context} m-block dimension {num_m_blocks} does not match \"\n            f\"sparse_block_size_q={sparse_block_size_q}. \"\n            f\"Set BlockSparseTensorsTorch.block_size to match the BlockMask BLOCK_SIZE.\"\n        )\n    return expected_count_shape, expected_index_shape, q_subtile_factor\n\n\ndef get_block_sparse_expected_shapes_bwd(\n    batch_size: int,\n    num_head: int,\n    seqlen_q: int,\n    seqlen_k: int,\n    m_block_size: int,\n    n_block_size: int,\n    subtile_factor: int,\n) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:\n    \"\"\"Return (expected_count_shape, expected_index_shape) for backward block sparse normalization.\n\n    Backward uses Q-direction indexing (transposed from forward), where shapes are\n    indexed by N-blocks first, then M-blocks. The sparse_block_size_q is determined\n    by subtile_factor * m_block_size.\n    \"\"\"\n    sparse_block_size_q = subtile_factor * m_block_size\n    expected_m_blocks = ceildiv(seqlen_q, sparse_block_size_q)\n    expected_n_blocks = ceildiv(seqlen_k, n_block_size)\n    expected_count_shape = (batch_size, num_head, expected_n_blocks)\n    expected_index_shape = (batch_size, num_head, expected_n_blocks, expected_m_blocks)\n    return expected_count_shape, expected_index_shape\n\n\ndef normalize_block_sparse_tensors(\n    tensors: BlockSparseTensorsTorch,\n    *,\n    expected_count_shape: Tuple[int, int, int],\n    expected_index_shape: Tuple[int, int, int, int],\n    context: str | None = None,\n    hint: str | Callable[[], str] | None = None,\n) -> BlockSparseTensorsTorch:\n    if tensors.mask_block_cnt is None or tensors.mask_block_idx is None:\n        raise ValueError(\"mask_block_cnt and mask_block_idx must be provided for block sparsity.\")\n\n    mask_cnt, mask_idx = _check_and_expand_block(\n        \"mask\",\n        
tensors.mask_block_cnt,\n        tensors.mask_block_idx,\n        expected_count_shape,\n        expected_index_shape,\n        context,\n        hint,\n    )\n    if mask_cnt is None or mask_idx is None:\n        raise ValueError(\"mask_block_cnt and mask_block_idx must be provided for block sparsity.\")\n\n    full_cnt, full_idx = _check_and_expand_block(\n        \"full\",\n        tensors.full_block_cnt,\n        tensors.full_block_idx,\n        expected_count_shape,\n        expected_index_shape,\n        context,\n        hint,\n    )\n    if full_cnt is not None and mask_cnt.device != full_cnt.device:\n        raise ValueError(\"All block sparse tensors must be on the same device\")\n\n    return BlockSparseTensorsTorch(\n        mask_block_cnt=mask_cnt,\n        mask_block_idx=mask_idx,\n        full_block_cnt=full_cnt,\n        full_block_idx=full_idx,\n        block_size=tensors.block_size,\n    )\n\n\ndef is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:\n    return any(t is not None for t in (tensors.full_block_cnt, tensors.mask_block_cnt))\n\n\ndef get_block_sparse_broadcast_pattern(\n    tensors: BlockSparseTensorsTorch,\n) -> Tuple[Tuple[bool, ...], ...] 
| None:\n    \"\"\"Return broadcast pattern for block sparse tensors by checking actual strides.\n\n    Returns a tuple of broadcast patterns (one per tensor) where each pattern\n    is a tuple of bools indicating which dims have stride=0.\n    This is used in compile keys to ensure kernels are recompiled when\n    broadcast patterns change, since CuTe's mark_layout_dynamic() keeps\n    stride=0 as static.\n\n    The tensors should already be expanded/normalized before calling this function.\n\n    Returns None if block sparsity is not enabled.\n    \"\"\"\n    if not is_block_sparsity_enabled(tensors):\n        return None\n\n    patterns = []\n    for tensor in (\n        tensors.mask_block_cnt,\n        tensors.mask_block_idx,\n        tensors.full_block_cnt,\n        tensors.full_block_idx,\n    ):\n        if tensor is not None:\n            patterns.append(get_broadcast_dims(tensor))\n        else:\n            patterns.append(None)\n    return tuple(patterns)\n\n\ndef normalize_block_sparse_config(\n    tensors: BlockSparseTensorsTorch,\n    *,\n    batch_size: int,\n    num_head: int,\n    seqlen_q: int,\n    seqlen_k: int,\n    block_size: tuple[int, int],\n    q_stage: int,\n) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] 
| None, int]:\n    m_block_size, n_block_size = block_size\n    if tensors.block_size is None:\n        sparse_block_size_q, sparse_block_size_kv = q_stage * m_block_size, n_block_size\n    else:\n        sparse_block_size_q, sparse_block_size_kv = tensors.block_size\n    if sparse_block_size_kv != n_block_size:\n        raise ValueError(\n            f\"Block sparsity requires sparse_block_size[1]={n_block_size} to match tile_n.\"\n        )\n    expected_count_shape, expected_index_shape, q_subtile_factor = (\n        infer_block_sparse_expected_shapes(\n            tensors,\n            batch_size=batch_size,\n            num_head=num_head,\n            seqlen_q=seqlen_q,\n            seqlen_k=seqlen_k,\n            m_block_size=m_block_size,\n            n_block_size=n_block_size,\n            q_stage=q_stage,\n            context=\" (forward)\",\n            sparse_block_size_q=sparse_block_size_q,\n            sparse_block_size_kv=sparse_block_size_kv,\n        )\n    )\n    normalized_tensors = normalize_block_sparse_tensors(\n        tensors,\n        expected_count_shape=expected_count_shape,\n        expected_index_shape=expected_index_shape,\n    )\n    return (\n        normalized_tensors,\n        get_block_sparse_broadcast_pattern(normalized_tensors),\n        q_subtile_factor,\n    )\n\n\ndef normalize_block_sparse_config_bwd(\n    tensors: BlockSparseTensorsTorch,\n    *,\n    batch_size: int,\n    num_head: int,\n    seqlen_q: int,\n    seqlen_k: int,\n    block_size: tuple[int, int],\n    subtile_factor: int,\n) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] 
| None]:\n    m_block_size, n_block_size = block_size\n    if tensors.block_size is None:\n        sparse_block_size_q, sparse_block_size_kv = subtile_factor * m_block_size, n_block_size\n    else:\n        sparse_block_size_q, sparse_block_size_kv = tensors.block_size\n    if sparse_block_size_q != subtile_factor * m_block_size:\n        raise ValueError(\n            f\"Block sparsity expects sparse_block_size_q={subtile_factor * m_block_size} \"\n            f\"for subtile_factor={subtile_factor}.\"\n        )\n    if sparse_block_size_kv != n_block_size:\n        raise ValueError(\n            f\"Block sparsity expects sparse_block_size[1]={n_block_size} to match tile_n.\"\n        )\n    expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(\n        batch_size,\n        num_head,\n        seqlen_q,\n        seqlen_k,\n        m_block_size,\n        n_block_size,\n        subtile_factor,\n    )\n    normalized_tensors = normalize_block_sparse_tensors(\n        tensors,\n        expected_count_shape=expected_count_shape,\n        expected_index_shape=expected_index_shape,\n        context=\"_flash_attn_bwd\",\n        hint=lambda: (\n            f\"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, \"\n            f\"and optionally full_q_cnt/full_q_idx). 
Regenerate the backward BlockMask with \"\n            f\"BLOCK_SIZE=({subtile_factor * m_block_size}, {n_block_size}).\"\n        ),\n    )\n    return normalized_tensors, get_block_sparse_broadcast_pattern(normalized_tensors)\n\n\ndef to_cute_block_sparse_tensors(\n    tensors: BlockSparseTensorsTorch, enable_tvm_ffi: bool = True\n) -> BlockSparseTensors | None:\n    \"\"\"Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi\"\"\"\n    if not is_block_sparsity_enabled(tensors):\n        return None\n    (\n        mask_block_cnt,\n        mask_block_idx,\n        full_block_cnt,\n        full_block_idx,\n        *_,\n    ) = tensors\n\n    (\n        mask_block_cnt_tensor,\n        mask_block_idx_tensor,\n    ) = [\n        to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)\n        for t in (mask_block_cnt, mask_block_idx)\n    ]\n    (\n        full_block_cnt_tensor,\n        full_block_idx_tensor,\n    ) = [\n        to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)\n        if t is not None\n        else None\n        for t in (full_block_cnt, full_block_idx)\n    ]\n\n    return BlockSparseTensors(\n        mask_block_cnt_tensor,\n        mask_block_idx_tensor,\n        full_block_cnt_tensor,\n        full_block_idx_tensor,\n    )\n\n\ndef fast_sampling(mask_mod):\n    \"\"\"Convenience decorator to mark mask_mod as safe for 5-point fast sampling\"\"\"\n    mask_mod.use_fast_sampling = True\n    return mask_mod\n"
  },
  {
    "path": "flash_attn/cute/cache_utils.py",
    "content": "# Manage Ahead-of-Time (AOT) compiled kernels\nimport fcntl\nimport hashlib\nimport logging\nimport os\nimport pickle\nimport sys\nimport tempfile\nimport time\nfrom functools import lru_cache\nfrom getpass import getuser\nfrom pathlib import Path\nfrom typing import Hashable, TypeAlias\n\nimport ctypes\n\nimport cutlass\nimport cutlass.cute as cute\nimport tvm_ffi\nfrom cutlass.cutlass_dsl import JitCompiledFunction\n\n# Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols\n# (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen.\n# Upstream cute.runtime.load_module loads these without RTLD_GLOBAL, which causes\n# \"undefined symbol\" errors when loading cached kernels from disk.\nfor _lib_path in cute.runtime.find_runtime_libraries(enable_tvm_ffi=False):\n    if Path(_lib_path).exists():\n        ctypes.CDLL(_lib_path, mode=ctypes.RTLD_GLOBAL)\n\nCompileKeyType: TypeAlias = tuple[Hashable, ...]\nCallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function\n\nlogger = logging.getLogger(__name__)\n_handler = logging.StreamHandler()\n_handler.setFormatter(logging.Formatter(\"%(asctime)s.%(msecs)03d %(levelname)s %(message)s\", datefmt=\"%Y-%m-%d %H:%M:%S\"))\nlogger.addHandler(_handler)\nlogger.setLevel(logging.DEBUG)\n\n\n# Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`\nCUTE_DSL_CACHE_ENABLED: bool = os.getenv(\"FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED\", \"0\") == \"1\"\n\n\n# Customize cache dir via `FLASH_ATTENTION_CUTE_DSL_CACHE_DIR`, default is\n# `/tmp/${USER}/flash_attention_cute_dsl_cache``\nCUTE_DSL_CACHE_DIR: str | None = os.getenv(\"FLASH_ATTENTION_CUTE_DSL_CACHE_DIR\", None)\n\n\ndef get_cache_path() -> Path:\n    if CUTE_DSL_CACHE_DIR is not None:\n        cache_dir = Path(CUTE_DSL_CACHE_DIR)\n    else:\n        cache_dir = Path(tempfile.gettempdir()) / getuser() / \"flash_attention_cute_dsl_cache\"\n    cache_dir.mkdir(parents=True, exist_ok=True)\n    return 
cache_dir\n\n\n@lru_cache(maxsize=1)\ndef _compute_source_fingerprint() -> str:\n    \"\"\"\n    Hash all CuTe Python sources plus runtime ABI stamps into a short fingerprint.\n\n    The fingerprint changes whenever:\n    - Any .py file under flash_attn/cute is added, removed, renamed, or modified.\n    - The Python minor version changes (e.g. 3.13 -> 3.14).\n    - The cutlass or tvm_ffi package version changes.\n\n    Computed once per process and cached.\n    \"\"\"\n    cute_root = Path(__file__).resolve().parent\n    h = hashlib.sha256()\n\n    h.update(f\"py{sys.version_info.major}.{sys.version_info.minor}\".encode())\n    h.update(f\"cutlass={cutlass.__version__}\".encode())\n    h.update(f\"tvm_ffi={tvm_ffi.__version__}\".encode())\n\n    for src in sorted(cute_root.rglob(\"*.py\")):\n        if not src.is_file():\n            continue\n        h.update(src.relative_to(cute_root).as_posix().encode())\n        content = src.read_bytes()\n        h.update(len(content).to_bytes(8, \"little\"))\n        h.update(content)\n\n    return h.hexdigest()\n\n\nclass FileLock:\n    \"\"\"Context manager for advisory file locks using fcntl.flock.\n\n    Supports exclusive (write) and shared (read) locks.\n    Always blocks with polling until the lock is acquired or timeout is reached.\n\n    Usage:\n        with FileLock(lock_path, exclusive=True, timeout=15, label=\"abc\"):\n            # do work under lock\n    \"\"\"\n\n    def __init__(\n        self,\n        lock_path: Path,\n        exclusive: bool,\n        timeout: float = 15,\n        label: str = \"\",\n    ):\n        \"\"\"\n        Args:\n            lock_path: Path to the lock file on disk.\n            exclusive: True for exclusive (write) lock, False for shared (read) lock.\n            timeout: Max seconds to wait for lock acquisition before raising RuntimeError.\n            label: Optional human-readable label for error messages.\n        \"\"\"\n        self.lock_path: Path = lock_path\n        
self.exclusive: bool = exclusive\n        self.timeout: float = timeout\n        self.label: str = label\n        # None means \"no open descriptor\"; __enter__ and __exit__ both rely on this sentinel.\n        self._fd: int | None = None\n\n    @property\n    def _lock_label(self) -> str:\n        kind = \"exclusive\" if self.exclusive else \"shared\"\n        return f\"{kind} {self.label}\" if self.label else kind\n\n    def __enter__(self) -> \"FileLock\":\n        open_flags = os.O_WRONLY | os.O_CREAT if self.exclusive else os.O_RDONLY | os.O_CREAT\n        lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH\n\n        self._fd = os.open(str(self.lock_path), open_flags)\n\n        deadline = time.monotonic() + self.timeout\n        acquired = False\n        while time.monotonic() < deadline:\n            try:\n                fcntl.flock(self._fd, lock_type | fcntl.LOCK_NB)\n                acquired = True\n                break\n            except OSError:\n                time.sleep(0.1)\n        if not acquired:\n            os.close(self._fd)\n            self._fd = None\n            raise RuntimeError(\n                f\"Timed out after {self.timeout}s waiting for \"\n                f\"{self._lock_label} lock: {self.lock_path}\"\n            )\n\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb) -> None:\n        if self._fd is not None:\n            fcntl.flock(self._fd, fcntl.LOCK_UN)\n            os.close(self._fd)\n            self._fd = None\n\n\nclass JITCache:\n    \"\"\"\n    In-memory cache for compiled functions.\n    \"\"\"\n\n    def __init__(self):\n        self.cache: dict[CompileKeyType, CallableFunction] = {}\n\n    def __setitem__(self, key: CompileKeyType, fn: CallableFunction) -> None:\n        self.cache[key] = fn\n\n    def __getitem__(self, key: CompileKeyType) -> CallableFunction:\n        return self.cache[key]\n\n    def __contains__(self, key: CompileKeyType) -> bool:\n        return key in self.cache\n\n    def clear(self) -> None:\n        \"\"\"\n        Clear in-memory cache of compiled 
functions\n        \"\"\"\n        self.cache.clear()\n\n\nclass JITPersistentCache(JITCache):\n    \"\"\"\n    In-memory cache for compiled functions, which is also backed by persistent storage.\n    Use cutedsl ahead-of-time (AOT) compilation, only supporting enable_tvm_ffi=True\n    \"\"\"\n\n    EXPORT_FUNCTION_PREFIX = \"func\"\n    LOCK_TIMEOUT_SECONDS = 15\n\n    def __init__(self, cache_path: Path):\n        super().__init__()\n        cache_path.mkdir(parents=True, exist_ok=True)\n        self.cache_path: Path = cache_path\n\n    def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:\n        JITCache.__setitem__(self, key, fn)\n        self._try_export_to_storage(key, fn)\n\n    def __getitem__(self, key: CompileKeyType) -> CallableFunction:\n        # Use __contains__ to try populating in-memory cache with persistent storage\n        self.__contains__(key)\n        return JITCache.__getitem__(self, key)\n\n    def __contains__(self, key: CompileKeyType) -> bool:\n        # Checks in-memory cache first, then tries loading from storage.\n        # When returning True, guarantees the in-memory cache is populated.\n        if JITCache.__contains__(self, key):\n            return True\n        return self._try_load_from_storage(key)\n\n    def _try_load_from_storage(self, key: CompileKeyType) -> bool:\n        \"\"\"\n        Try to load a function from persistent storage into in-memory cache.\n        Returns True if loaded successfully, False if not found on disk.\n        Holds a shared lock during loading to prevent concurrent writes.\n        \"\"\"\n        sha256_hex = self._key_to_hash(key)\n        obj_path = self.cache_path / f\"{sha256_hex}.o\"\n        with FileLock(\n            self._lock_path(sha256_hex),\n            exclusive=False,\n            timeout=self.LOCK_TIMEOUT_SECONDS,\n            label=sha256_hex,\n        ):\n            if obj_path.exists():\n                logger.debug(\"Loading compiled function from 
disk: %s\", obj_path)\n                m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True)\n                fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)\n                JITCache.__setitem__(self, key, fn)\n                return True\n            else:\n                logger.debug(\"Cache miss on disk for key hash %s\", sha256_hex)\n        return False\n\n    def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:\n        \"\"\"Export a compiled function to persistent storage under exclusive lock.\"\"\"\n        sha256_hex = self._key_to_hash(key)\n        with FileLock(\n            self._lock_path(sha256_hex),\n            exclusive=True,\n            timeout=self.LOCK_TIMEOUT_SECONDS,\n            label=sha256_hex,\n        ):\n            obj_path = self.cache_path / f\"{sha256_hex}.o\"\n            if obj_path.exists():\n                # Another process already exported.\n                logger.debug(\"Skipping export, already on disk: %s\", obj_path)\n                return\n            logger.debug(\"Exporting compiled function to disk: %s\", obj_path)\n            fn.export_to_c(\n                object_file_path=str(obj_path),\n                function_name=self.EXPORT_FUNCTION_PREFIX,\n            )\n            logger.debug(\"Successfully exported compiled function to disk: %s\", obj_path)\n\n    def _key_to_hash(self, key: CompileKeyType) -> str:\n        return hashlib.sha256(pickle.dumps(key)).hexdigest()\n\n    def _lock_path(self, sha256_hex: str) -> Path:\n        return self.cache_path / f\"{sha256_hex}.lock\"\n\n    def clear(self) -> None:\n        \"\"\"\n        Not only clear the in-memory cache. 
Also purge the persistent compilation cache.\n        \"\"\"\n        logger.debug(\"Clearing persistent cache at %s\", self.cache_path)\n        super().clear()\n        for child in self.cache_path.iterdir():\n            # Named caches live in subdirectories of the fingerprint directory;\n            # unlink() fails on directories, so only remove this cache's own files.\n            if child.is_file():\n                child.unlink()\n\n\ndef get_jit_cache(name: str | None = None) -> JITCache:\n    \"\"\"\n    JIT cache factory.\n    `name` is an optional identifier; when given, the cache is stored in its own subdirectory.\n\n    When persistent caching is enabled, artifacts are namespaced under a\n    source fingerprint directory so that code or dependency changes\n    automatically invalidate stale entries.\n    \"\"\"\n    if CUTE_DSL_CACHE_ENABLED:\n        path = get_cache_path() / _compute_source_fingerprint()\n        if name:\n            path = path / name\n        logger.debug(\"Creating persistent JIT cache at %s\", path)\n        return JITPersistentCache(path)\n    else:\n        logger.debug(\"Persistent cache disabled, using in-memory JIT cache\")\n        return JITCache()\n"
  },
  {
    "path": "flash_attn/cute/compute_block_sparsity.py",
    "content": "from functools import partial\nfrom typing import Callable, Optional, Tuple\n\nimport cutlass\nimport cutlass.cute as cute\nimport torch\nfrom cutlass import Boolean, Int8, Int32, const_expr\n\nfrom flash_attn.cute.block_sparsity import (\n    BlockSparseTensors,\n    BlockSparseTensorsTorch,\n    to_cute_block_sparse_tensors,\n)\nfrom flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar\nfrom flash_attn.cute.seqlen_info import SeqlenInfoQK\n\n\nclass BlockSparsityKernel:\n    \"\"\"Block sparsity kernel for FlexAttention.\n\n    This kernel computes `mask_mod` for every token of each block\n    to determine if an n block is full, masked, or neither.\n\n    Writes block counts and indices to a BlockSparseTensors object.\n\n    When use_fast_sampling=True, uses 5-point sampling (4 corners + center)\n    which is much faster but only suitable for masks where this is sufficient.\n\n    TODO:\n        - optimize mask_mod evaluation\n        - varlen support\n        - transposed tensors for bwd pass\n    \"\"\"\n\n    def __init__(\n        self,\n        mask_mod: Callable,\n        tile_mn: Tuple[int, int],\n        compute_full_blocks: bool = True,\n        use_aux_tensors: bool = False,\n        use_fast_sampling: bool = False,\n    ):\n        self.mask_mod = mask_mod\n        self.tile_mn = tile_mn\n        self.compute_full_blocks = compute_full_blocks\n        self.use_aux_tensors = use_aux_tensors\n        self.use_fast_sampling = use_fast_sampling\n\n    @cute.jit\n    def __call__(\n        self,\n        blocksparse_tensors: BlockSparseTensors,\n        seqlen_q: Int32,\n        seqlen_k: Int32,\n        aux_tensors: Optional[list] = None,\n    ):\n        self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors\n\n        if const_expr(self.compute_full_blocks):\n            assert self.full_cnt is not None and self.full_idx is not None, (\n                \"full block tensors must be provided 
when computing full blocks\"\n            )\n\n        batch_size, num_heads, num_m_blocks, num_n_blocks = self.mask_idx.shape\n        # launch 1 CTA per m block\n        grid = [num_m_blocks, num_heads, batch_size]\n\n        if const_expr(self.use_fast_sampling):\n            num_threads = 5\n            self.num_warps = 1\n        else:\n            num_threads = self.tile_mn[0]\n            self.num_warps = (num_threads + 32 - 1) // 32\n\n        self.kernel(\n            self.mask_cnt,\n            self.mask_idx,\n            self.full_cnt,\n            self.full_idx,\n            num_n_blocks,\n            seqlen_q,\n            seqlen_k,\n            aux_tensors,\n        ).launch(grid=grid, block=[num_threads, 1, 1])\n\n    @cute.kernel\n    def kernel(\n        self,\n        mask_cnt: cute.Tensor,\n        mask_idx: cute.Tensor,\n        full_cnt: cute.Tensor,\n        full_idx: cute.Tensor,\n        num_n_blocks: Int32,\n        seqlen_q: Int32,\n        seqlen_k: Int32,\n        aux_tensors: Optional[list] = None,\n    ):\n        tidx, _, _ = cute.arch.thread_idx()\n        warp_idx = cute.arch.warp_idx()\n        lane_id = cute.arch.lane_idx()\n        m_block, head_idx, batch_idx = cute.arch.block_idx()\n\n        ssa = partial(scalar_to_ssa, dtype=Int32)\n\n        seqlen = SeqlenInfoQK.create(\n            batch_idx,\n            seqlen_q,\n            seqlen_k,\n            mCuSeqlensQ=None,\n            mCuSeqlensK=None,\n            mSeqUsedQ=None,\n            mSeqUsedK=None,\n        )\n\n        @cute.struct\n        class SharedStorage:\n            reduction_buffer_smem: cute.struct.Align[\n                cute.struct.MemRange[cutlass.Int8, 2 * self.num_warps], 1024\n            ]\n\n        smem = cutlass.utils.SmemAllocator()\n        storage = smem.allocate(SharedStorage, 16)\n\n        reduction_buffer = storage.reduction_buffer_smem.get_tensor(\n            cute.make_layout((self.num_warps, 2))\n        )\n\n        num_mask_blocks = 
Int32(0)\n        num_full_blocks = Int32(0)\n\n        for n_block in cutlass.range(num_n_blocks, unroll_full=True):\n            m_base = m_block * self.tile_mn[0]\n            n_base = n_block * self.tile_mn[1]\n\n            if const_expr(self.use_fast_sampling):\n                # Fast path: 5-point sampling (4 corners + center)\n                # Clamps OOB indices to nearest in bounds.\n                thread_result = Boolean(False)\n                thread_is_valid = Boolean(False)\n                q_idx = Int32(0)\n                kv_idx = Int32(0)\n\n                if tidx == 0:\n                    # Top-left corner (0, 0); always in bounds\n                    q_idx = m_base\n                    kv_idx = n_base\n                elif tidx == 1:\n                    # Top-right corner\n                    q_idx = m_base\n                    kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1)\n                elif tidx == 2:\n                    # Bottom-left corner\n                    q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1)\n                    kv_idx = n_base\n                elif tidx == 3:\n                    # Bottom-right corner\n                    q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1)\n                    kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1)\n                elif tidx == 4:\n                    # Center point\n                    q_idx = m_base + (cutlass.min(seqlen_q - m_base, self.tile_mn[0])) // 2\n                    kv_idx = n_base + (cutlass.min(seqlen_k - n_base, self.tile_mn[1])) // 2\n                else:\n                    thread_is_valid = Boolean(False)\n\n                # Check bounds and determine if this thread has a valid index pair\n                if tidx < 5 and q_idx < seqlen_q and kv_idx < seqlen_k:\n                    thread_is_valid = Boolean(True)\n                    q_idx_ssa = ssa(q_idx)\n                    kv_idx_ssa = 
ssa(kv_idx)\n                    thread_result = ssa_to_scalar(\n                        self.mask_mod(\n                            ssa(batch_idx),\n                            ssa(head_idx),\n                            q_idx_ssa,\n                            kv_idx_ssa,\n                            seqlen,\n                            aux_tensors,\n                        )\n                    )\n                else:\n                    thread_is_valid = Boolean(False)\n\n                # Use vote_any_sync to see if any valid thread found unmasked or masked\n                # Only count results from threads that checked valid indices\n                has_unmasked = cute.arch.vote_any_sync(thread_result & thread_is_valid)\n                has_masked = cute.arch.vote_any_sync((Boolean(not thread_result)) & thread_is_valid)\n\n            else:\n                # Full path: check all elements in the block\n                # Track if this thread's row has any masked or unmasked elements\n                thread_has_unmasked = Boolean(False)\n                thread_has_masked = Boolean(False)\n                thread_is_valid = Boolean(False)\n\n                # Each thread handles 1 row\n                q_idx = m_base + tidx\n                kv_idx = Int32(0)\n                if tidx < self.tile_mn[0] and q_idx < seqlen_q:\n                    thread_is_valid = Boolean(True)\n                    q_idx_ssa = ssa(q_idx)\n\n                    # Loop over all columns in this row\n                    for c in cutlass.range(self.tile_mn[1], unroll_full=True):\n                        kv_idx = n_base + c\n                        kv_idx_ssa = ssa(kv_idx)\n\n                        # Only check elements within valid sequence bounds\n                        if kv_idx < seqlen_k:\n                            # Direct scalar call\n                            mask_val = ssa_to_scalar(\n                                self.mask_mod(\n                                    
ssa(batch_idx),\n                                    ssa(head_idx),\n                                    q_idx_ssa,\n                                    kv_idx_ssa,\n                                    seqlen,\n                                    aux_tensors,\n                                )\n                            )\n\n                            # Update tracking flags\n                            if mask_val:\n                                thread_has_unmasked = Boolean(True)\n                            else:\n                                thread_has_masked = Boolean(True)\n\n                # Block-level reduction to combine results across all threads\n                # Only count votes from threads that checked valid indices\n                warp_has_unmasked_mask = cute.arch.vote_any_sync(\n                    thread_has_unmasked & thread_is_valid\n                )\n                warp_has_masked_mask = cute.arch.vote_any_sync(thread_has_masked & thread_is_valid)\n\n                # lane 0 writes the ballot mask to shared memory\n                lane_id = tidx % 32\n                if lane_id == 0:\n                    # Store as Int8\n                    reduction_buffer[warp_idx, 0] = Int8(1) if warp_has_unmasked_mask else Int8(0)\n                    reduction_buffer[warp_idx, 1] = Int8(1) if warp_has_masked_mask else Int8(0)\n\n                cute.arch.sync_threads()\n\n                # Thread 0 ORs all warp results together\n                has_unmasked = Boolean(False)\n                has_masked = Boolean(False)\n                if tidx == 0:\n                    for w in cutlass.range(self.num_warps):\n                        if reduction_buffer[w, 0]:\n                            has_unmasked = Boolean(True)\n                        if reduction_buffer[w, 1]:\n                            has_masked = Boolean(True)\n\n            # Only thread 0 updates the output arrays (common to both paths)\n            if tidx == 0:\n               
 # Block classification based on what we found:\n                # - If has_masked and has_unmasked: partial block (needs masking)\n                # - If only has_unmasked: full block (no masking needed)\n                # - If only has_masked: skip this block entirely\n                is_partial = Boolean(has_masked and has_unmasked)\n                is_full = Boolean(has_unmasked and (not has_masked))\n\n                if is_partial:\n                    mask_idx[batch_idx, head_idx, m_block, num_mask_blocks] = n_block\n                    num_mask_blocks += 1\n                elif is_full and const_expr(self.compute_full_blocks):\n                    full_idx[batch_idx, head_idx, m_block, num_full_blocks] = n_block\n                    num_full_blocks += 1\n\n        # Only thread 0 writes back the counts\n        if tidx == 0:\n            mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks\n            if const_expr(self.compute_full_blocks):\n                full_cnt[batch_idx, head_idx, m_block] = num_full_blocks\n\n\ndef compute_block_sparsity(\n    tile_m,\n    tile_n,\n    batch_size,\n    num_heads,\n    seqlen_q,\n    seqlen_k,\n    mask_mod: Callable,\n    aux_tensors: Optional[list],  # list[cute.Tensor]\n    device,\n    compute_full_blocks: bool = True,\n    use_fast_sampling: bool = False,\n) -> Tuple[BlockSparseTensors, BlockSparseTensorsTorch]:\n    \"\"\"\n    Computes block sparsity for a given `mask_mod`.\n\n    Args:\n        tile_m: The tile size for the m dimension.\n        tile_n: The tile size for the n dimension.\n        batch_size: The batch size.\n        num_heads: The number of heads.\n        seqlen_q: The sequence length for the query.\n        seqlen_k: The sequence length for the key.\n        mask_mod: The `mask_mod` callable to use.\n        aux_tensors: A list of auxiliary tensors.\n        device: The device to use.\n        compute_full_blocks: Whether to compute full blocks. 
If False, only partially-masked blocks are computed.\n        use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient.\n\n    Returns:\n        A tuple of `BlockSparseTensors` and `BlockSparseTensorsTorch`.\n    \"\"\"\n    # Check if mask_mod is marked as suitable for 5-point fast sampling\n    use_fast_sampling = getattr(mask_mod, \"use_fast_sampling\", use_fast_sampling)\n\n    num_m_blocks = (seqlen_q + tile_m - 1) // tile_m\n    num_n_blocks = (seqlen_k + tile_n - 1) // tile_n\n\n    mask_block_cnt = torch.zeros(\n        (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32\n    )\n    mask_block_idx = torch.zeros(\n        (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32\n    )\n    full_block_cnt = (\n        torch.zeros((batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32)\n        if compute_full_blocks\n        else None\n    )\n    full_block_idx = (\n        torch.zeros(\n            (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32\n        )\n        if compute_full_blocks\n        else None\n    )\n\n    blocksparse_tensors_torch = BlockSparseTensorsTorch(\n        mask_block_cnt=mask_block_cnt,\n        mask_block_idx=mask_block_idx,\n        full_block_cnt=full_block_cnt,\n        full_block_idx=full_block_idx,\n        block_size=(tile_m, tile_n),\n    )\n\n    mask_mod_hash = hash_callable(mask_mod)\n    blocksparse_tensors = to_cute_block_sparse_tensors(\n        blocksparse_tensors_torch, enable_tvm_ffi=True\n    )\n\n    compile_key = (\n        tile_m,\n        tile_n,\n        mask_mod_hash,\n        compute_full_blocks,\n        aux_tensors is not None,\n        use_fast_sampling,\n    )\n    if compile_key not in compute_block_sparsity.compile_cache:\n        kernel = BlockSparsityKernel(\n            mask_mod,\n            
tile_mn=(tile_m, tile_n),\n            compute_full_blocks=compute_full_blocks,\n            use_aux_tensors=aux_tensors is not None,\n            use_fast_sampling=use_fast_sampling,\n        )\n\n        compute_block_sparsity.compile_cache[compile_key] = cute.compile(\n            kernel, blocksparse_tensors, seqlen_q, seqlen_k, aux_tensors, options=\"--enable-tvm-ffi\"\n        )\n\n    compute_block_sparsity.compile_cache[compile_key](\n        blocksparse_tensors_torch[:4],\n        seqlen_q,\n        seqlen_k,\n        aux_tensors,\n    )\n\n    return blocksparse_tensors, blocksparse_tensors_torch\n\n\ncompute_block_sparsity.compile_cache = {}\n"
  },
  {
    "path": "flash_attn/cute/copy_utils.py",
    "content": "# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.\n\nimport math\nfrom typing import Optional, Type, Callable\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Float32, Int32, const_expr\nfrom cutlass.cute.nvgpu import cpasync\nimport cutlass.utils.blackwell_helpers as sm100_utils\nfrom cutlass.cutlass_dsl import T, dsl_user_op\nfrom cutlass._mlir.dialects import llvm\nimport cutlass.pipeline\n\n\n@dsl_user_op\ndef cvt_copy(\n    atom: cute.CopyAtom,\n    src: cute.Tensor,\n    dst: cute.Tensor,\n    *,\n    pred: Optional[cute.Tensor] = None,\n    loc=None,\n    ip=None,\n    **kwargs,\n) -> None:\n    assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem\n    if const_expr(src.element_type != dst.element_type):\n        src_cvt = cute.make_fragment_like(src, dst.element_type, loc=loc, ip=ip)\n        src_cvt.store(src.load().to(dst.element_type))\n        src = src_cvt\n    cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)\n\n\n@dsl_user_op\ndef load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:\n    dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip)\n    cute.autovec_copy(src, dst, loc=loc, ip=ip)\n    return dst\n\n\n@dsl_user_op\ndef get_copy_atom(\n    dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None\n) -> cute.CopyAtom:\n    num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))\n    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()\n    return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)\n\n\n@dsl_user_op\ndef make_tmem_copy(\n    tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None\n) -> cute.CopyAtom:\n    num_dp, num_bits, num_rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom)\n    assert num_dp == 32\n    assert num_bits == 32\n    tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), 
stride=(32, 1)),)\n    layout_tv = cute.make_layout(\n        ((32, 4, num_wg), (num_rep, 32)), stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg))\n    )\n    return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn)\n\n\n@dsl_user_op\ndef copy(\n    src: cute.Tensor,\n    dst: cute.Tensor,\n    *,\n    pred: Optional[cute.Tensor] = None,\n    num_copy_elems: int = 1,\n    is_async: bool = False,\n    loc=None,\n    ip=None,\n    **kwargs,\n) -> None:\n    copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async)\n    cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)\n\n\ndef tiled_copy_1d(\n    dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False\n) -> cute.TiledCopy:\n    num_copy_bits = num_copy_elems * dtype.width\n    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()\n    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)\n    thr_layout = cute.make_layout(num_threads)\n    val_layout = cute.make_layout(num_copy_elems)\n    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)\n\n\ndef tiled_copy_2d(\n    dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False\n) -> cute.TiledCopy:\n    num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width\n    copy_elems = num_copy_bits // dtype.width\n    copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()\n    copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)\n    gmem_threads_per_row = major_mode_size // copy_elems\n    assert num_threads % gmem_threads_per_row == 0\n    thr_layout = cute.make_ordered_layout(\n        (num_threads // gmem_threads_per_row, gmem_threads_per_row),\n        order=(1, 0),\n    )\n    val_layout = cute.make_layout((1, copy_elems))\n    return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)\n\n\n@dsl_user_op\ndef 
atomic_add_fp32x4(\n    a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None\n) -> None:\n    gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()\n    # cache_hint = cutlass.Int64(0x12F0000000000000)\n    llvm.inline_asm(\n        None,\n        [\n            gmem_ptr_i64,\n            Float32(a).ir_value(loc=loc, ip=ip),\n            Float32(b).ir_value(loc=loc, ip=ip),\n            Float32(c).ir_value(loc=loc, ip=ip),\n            Float32(d).ir_value(loc=loc, ip=ip),\n        ],\n        # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()],\n        \"{\\n\\t\"\n        # \".reg .b128 abcd;\\n\\t\"\n        # \"mov.b128 abcd, {$1, $2, $3, $4};\\n\\t\"\n        \".reg .v4 .f32 abcd;\\n\\t\"\n        # \"mov.b128 abcd, {$1, $2, $3, $4};\\n\\t\"\n        \"mov.f32 abcd.x, $1;\\n\\t\"\n        \"mov.f32 abcd.y, $2;\\n\\t\"\n        \"mov.f32 abcd.z, $3;\\n\\t\"\n        \"mov.f32 abcd.w, $4;\\n\\t\"\n        \"red.global.add.v4.f32 [$0], abcd;\\n\\t\"\n        # \"red.global.add.L2::cache_hint.v4.f32 [$0], abcd, 0x14F0000000000000;\\n\\t\"\n        \"}\\n\",\n        # \"red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;\",\n        # \"red.global.add.L2::cache_hint.f32 [$0], $1, $2;\",\n        \"l,f,f,f,f\",\n        # \"l,f,l\",\n        has_side_effects=True,\n        is_align_stack=False,\n        asm_dialect=llvm.AsmDialect.AD_ATT,\n    )\n\n\n@dsl_user_op\ndef set_block_rank(\n    smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None\n) -> Int32:\n    \"\"\"Map the given smem pointer to the address at another CTA rank in the cluster.\"\"\"\n    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()\n    return Int32(\n        llvm.inline_asm(\n            T.i32(),\n            [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],\n            \"mapa.shared::cluster.u32 $0, $1, $2;\",\n            \"=r,r,r\",\n            has_side_effects=False,\n         
   is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n    )\n\n\n@dsl_user_op\ndef store_shared_remote_fp32x4(\n    a: Float32,\n    b: Float32,\n    c: Float32,\n    d: Float32,\n    smem_ptr: cute.Pointer,\n    mbar_ptr: cute.Pointer,\n    peer_cta_rank_in_cluster: Int32,\n    *,\n    loc=None,\n    ip=None,\n) -> None:\n    remote_smem_ptr_i32 = set_block_rank(\n        smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip\n    ).ir_value()\n    remote_mbar_ptr_i32 = set_block_rank(\n        mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip\n    ).ir_value()\n    llvm.inline_asm(\n        None,\n        [\n            remote_smem_ptr_i32,\n            remote_mbar_ptr_i32,\n            Float32(a).ir_value(loc=loc, ip=ip),\n            Float32(b).ir_value(loc=loc, ip=ip),\n            Float32(c).ir_value(loc=loc, ip=ip),\n            Float32(d).ir_value(loc=loc, ip=ip),\n        ],\n        \"{\\n\\t\"\n        \".reg .v4 .f32 abcd;\\n\\t\"\n        \"mov.f32 abcd.x, $2;\\n\\t\"\n        \"mov.f32 abcd.y, $3;\\n\\t\"\n        \"mov.f32 abcd.z, $4;\\n\\t\"\n        \"mov.f32 abcd.w, $5;\\n\\t\"\n        \"st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, [$1];\\n\\t\"\n        \"}\\n\",\n        \"r,r,f,f,f,f\",\n        has_side_effects=True,\n        is_align_stack=False,\n        asm_dialect=llvm.AsmDialect.AD_ATT,\n    )\n\n\n@dsl_user_op\ndef cpasync_bulk_s2cluster(\n    smem_src_ptr: cute.Pointer,\n    smem_dst_ptr: cute.Pointer,\n    mbar_ptr: cute.Pointer,\n    size: int | Int32,\n    peer_cta_rank_in_cluster: Int32,\n    *,\n    loc=None,\n    ip=None,\n):\n    smem_src_ptr_i32 = smem_src_ptr.toint(loc=loc, ip=ip).ir_value()\n    smem_dst_ptr_i32 = set_block_rank(\n        smem_dst_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip\n    ).ir_value()\n    mbar_ptr_i32 = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value()\n    llvm.inline_asm(\n        None,\n        [\n            
smem_dst_ptr_i32,\n            smem_src_ptr_i32,\n            mbar_ptr_i32,\n            Int32(size).ir_value(loc=loc, ip=ip),\n        ],\n        \"cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [$0], [$1], $3, [$2];\",\n        \"r,r,r,r\",\n        has_side_effects=True,\n        is_align_stack=False,\n        asm_dialect=llvm.AsmDialect.AD_ATT,\n    )\n\n\n@dsl_user_op\ndef cpasync_bulk_g2s(\n    gmem_ptr: cute.Pointer,\n    smem_ptr: cute.Pointer,\n    tma_bar_ptr: cute.Pointer,\n    size: int | Int32,\n    *,\n    loc=None,\n    ip=None,\n):\n    gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()\n    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()\n    mbar_ptr_i32 = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value()\n    llvm.inline_asm(\n        None,\n        [gmem_ptr_i64, smem_ptr_i32, mbar_ptr_i32, Int32(size).ir_value()],\n        \"cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];\",\n        \"l,r,r,r\",\n        has_side_effects=True,\n        is_align_stack=False,\n        asm_dialect=llvm.AsmDialect.AD_ATT,\n    )\n\n\n@dsl_user_op\ndef cpasync_reduce_bulk_add_f32(\n    smem_ptr: cute.Pointer,\n    gmem_ptr: cute.Pointer,\n    store_bytes: int | Int32,\n    *,\n    loc=None,\n    ip=None,\n):\n    smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()\n    # cache_hint = cutlass.Int64(0x14F0000000000000)  # EVICT_LAST\n    llvm.inline_asm(\n        None,\n        [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()],\n        \"cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;\",\n        \"l,r,r\",\n        # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), cache_hint.ir_value()],\n        # \"cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;\",\n        # \"l,r,r,l\",\n        has_side_effects=True,\n        is_align_stack=False,\n        asm_dialect=llvm.AsmDialect.AD_ATT,\n    
)\n\n\ndef cpasync_bulk_get_copy_fn(\n    src_tensor: cute.Tensor,\n    dst_tensor: cute.Tensor,\n    single_stage: bool = False,\n    **kwargs,\n) -> Callable:\n    # src_is_smem = const_expr(\n    #     isinstance(src_tensor.iterator, cute.Pointer)\n    #     and src_tensor.memspace == cute.AddressSpace.smem\n    # )\n    group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0))\n    group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0))\n    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)\n    src = cute.group_modes(src_tensor, 0, group_rank_src)\n    dst = cute.group_modes(dst_tensor, 0, group_rank_dst)\n\n    def copy_bulk(src_idx, dst_idx, **new_kwargs):\n        size = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8)\n        cpasync_bulk_g2s(\n            src[None, src_idx].iterator,\n            dst[None, dst_idx].iterator,\n            size=size,\n            **new_kwargs,\n            **kwargs,\n        )\n\n    def copy_bulk_single_stage(**new_kwargs):\n        size = const_expr(cute.size(src.shape) * src.element_type.width // 8)\n        cpasync_bulk_g2s(src.iterator, dst.iterator, size=size, **new_kwargs, **kwargs)\n\n    return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage\n\n\ndef tma_get_copy_fn(\n    atom: cute.CopyAtom,\n    cta_coord: cute.Coord,\n    cta_layout: cute.Layout,\n    src_tensor: cute.Tensor,\n    dst_tensor: cute.Tensor,\n    filter_zeros: bool = False,\n    single_stage: bool = False,\n    **kwargs,\n) -> Callable:\n    src_is_smem = const_expr(\n        isinstance(src_tensor.iterator, cute.Pointer)\n        and src_tensor.memspace == cute.AddressSpace.smem\n    )\n    smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)\n    group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))\n    group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not 
single_stage else 0))\n    # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)\n    s, g = cpasync.tma_partition(\n        atom,\n        cta_coord,\n        cta_layout,\n        cute.group_modes(smem_tensor, 0, group_rank_smem),\n        cute.group_modes(gmem_tensor, 0, group_rank_gmem),\n    )\n    if const_expr(filter_zeros):\n        s = cute.filter_zeros(s)\n        g = cute.filter_zeros(g)\n    src, dst = (s, g) if src_is_smem else (g, s)\n\n    def copy_tma(src_idx, dst_idx, **new_kwargs):\n        cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)\n\n    def copy_tma_single_stage(**new_kwargs):\n        cute.copy(atom, src, dst, **new_kwargs, **kwargs)\n\n    return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g\n\n\ndef tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):\n    def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs):\n        copy(\n            src_idx=src_idx,\n            dst_idx=producer_state.index,\n            tma_bar_ptr=pipeline.producer_get_barrier(producer_state),\n            **new_kwargs,\n        )\n\n    return copy_fn\n"
  },
  {
    "path": "flash_attn/cute/cute_dsl_ptxas.py",
    "content": "\"\"\"\nSystem ptxas replacement for CUTLASS DSL.\nEnvironment variables:\n    CUTE_DSL_PTXAS_PATH    - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)\n    CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output\n\"\"\"\n\nimport os\nimport sys\nimport re\nimport ctypes\nimport subprocess\nfrom pathlib import Path\n\nimport cutlass\n\n\nCUTE_DSL_PTXAS_PATH = os.environ.get(\"CUTE_DSL_PTXAS_PATH\", None)\nVERBOSE = os.environ.get(\"CUTE_DSL_PTXAS_VERBOSE\", \"0\") == \"1\"\n\n_original_load_cuda_library = None\n_user_wanted_ptx = False  # True if user originally set CUTE_DSL_KEEP_PTX=1\n\n\ndef _log(msg):\n    if VERBOSE:\n        print(f\"[ptxas] {msg}\", file=sys.stderr)\n\n\ndef _get_ptx(compiled_func) -> tuple[str, Path] | None:\n    \"\"\"Find and read PTX file, stripping null bytes.\"\"\"\n    func_name = getattr(compiled_func, \"function_name\", None)\n    if not func_name:\n        return None\n\n    dump_dir = os.environ.get(\"CUTE_DSL_DUMP_DIR\", Path.cwd())\n    for ptx_path in Path(dump_dir).glob(f\"*{func_name}*.ptx\"):\n        content = ptx_path.read_text().rstrip(\"\\x00\")\n        if \".entry \" in content and content.rstrip().endswith(\"}\"):\n            _log(f\"Found PTX: {ptx_path}\")\n            return content, ptx_path\n    return None\n\n\ndef _compile_ptx(ptx_path: Path, ptx_content: str) -> bytes:\n    \"\"\"Compile PTX to cubin using system ptxas.\"\"\"\n    # Extract arch from PTX\n    match = re.search(r\"\\.target\\s+(sm_\\d+[a-z]?)\", ptx_content)\n    arch = match.group(1) if match else \"sm_90a\"\n\n    # Write stripped content back if needed\n    if ptx_path.read_text() != ptx_content:\n        ptx_path.write_text(ptx_content)\n\n    # Compile\n    cubin_tmp = ptx_path.with_suffix(\".cubin.tmp\")\n    try:\n        assert CUTE_DSL_PTXAS_PATH is not None\n        result = subprocess.run(\n            [CUTE_DSL_PTXAS_PATH, f\"-arch={arch}\", \"-O3\", \"-o\", str(cubin_tmp), str(ptx_path)],\n            
capture_output=True,\n            text=True,\n        )\n        if result.returncode != 0:\n            raise RuntimeError(f\"ptxas failed: {result.stderr}\")\n\n        cubin_data = cubin_tmp.read_bytes()\n        _log(f\"Compiled {ptx_path.name} -> {len(cubin_data)} bytes ({arch})\")\n\n        # Save cubin if CUTE_DSL_KEEP_CUBIN is set\n        if os.environ.get(\"CUTE_DSL_KEEP_CUBIN\", \"0\") == \"1\":\n            cubin_out = ptx_path.with_suffix(\".cubin\")\n            cubin_out.write_bytes(cubin_data)\n            _log(f\"Saved: {cubin_out}\")\n\n        return cubin_data\n    finally:\n        cubin_tmp.unlink(missing_ok=True)\n\n\ndef _patched_load_cuda_library(self):\n    \"\"\"Replacement for _load_cuda_library that uses system ptxas.\"\"\"\n\n    result = _get_ptx(self)\n    if not result:\n        _log(\"PTX not found, falling back to embedded ptxas\")\n        return _original_load_cuda_library(self)\n\n    ptx_content, ptx_path = result\n\n    try:\n        cubin = _compile_ptx(ptx_path, ptx_content)\n    except Exception as e:\n        _log(f\"Compilation failed ({e}), falling back to embedded ptxas\")\n        return _original_load_cuda_library(self)\n\n    # Load cubin\n    import cuda.bindings.runtime as cuda_runtime\n\n    err, library = cuda_runtime.cudaLibraryLoadData(cubin, None, None, 0, None, None, 0)\n    if err != cuda_runtime.cudaError_t.cudaSuccess:\n        _log(f\"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas\")\n        return _original_load_cuda_library(self)\n\n    # Register kernels on all devices\n    _, cuda_load_to_device = self._get_cuda_init_and_load()\n    lib_ptr = ctypes.c_void_p(int(library))\n    dev_id = ctypes.c_int32(0)\n    err_val = ctypes.c_int32(0)\n    args = (ctypes.c_void_p * 3)(\n        ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p),\n        ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),\n        ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),\n    )\n\n    for dev 
in range(self.num_devices):\n        dev_id.value = dev\n        cuda_load_to_device(args)\n        if err_val.value != 0:\n            _log(\"cuda_load_to_device failed, falling back to embedded ptxas\")\n            return _original_load_cuda_library(self)\n\n    _log(f\"Loaded kernel from {ptx_path.name}\")\n\n    # Delete PTX if user didn't originally want it kept\n    if not _user_wanted_ptx:\n        ptx_path.unlink(missing_ok=True)\n\n    return [cuda_runtime.cudaLibrary_t(lib_ptr.value)]\n\n\ndef patch():\n    \"\"\"Install system ptxas hook. Call before importing cutlass.\"\"\"\n    global _original_load_cuda_library, _user_wanted_ptx\n\n    assert CUTE_DSL_PTXAS_PATH is not None\n    if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK):\n        raise RuntimeError(f\"ptxas not found: {CUTE_DSL_PTXAS_PATH}\")\n\n    # Track if user originally wanted PTX kept\n    _user_wanted_ptx = os.environ.get(\"CUTE_DSL_KEEP_PTX\", \"0\") == \"1\"\n    # os.environ['CUTE_DSL_KEEP_PTX'] = '1'\n    assert os.environ.get(\"CUTE_DSL_KEEP_PTX\", \"0\") == \"1\", (\n        \"Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas\"\n    )\n\n    cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction\n    _original_load_cuda_library = cls._load_cuda_library\n    cls._load_cuda_library = _patched_load_cuda_library\n    _log(\"Patch applied\")\n    return\n"
  },
  {
    "path": "flash_attn/cute/cute_dsl_utils.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n\nimport os\nimport pathlib\nfrom typing import Tuple\nfrom functools import partial, lru_cache\nfrom dataclasses import dataclass, fields\n\nimport torch\n\ntry:\n    from triton.tools.disasm import extract\nexcept ImportError:\n    extract = None\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass.base_dsl.typing import JitArgument\nfrom cutlass.cutlass_dsl import NumericMeta\nfrom cutlass.cute.runtime import from_dlpack\n\nStaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None))\n\n\nload_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data\ncute_compile_og = cute.compile\n\n\ntorch2cute_dtype_map = {\n    torch.float16: cutlass.Float16,\n    torch.bfloat16: cutlass.BFloat16,\n    torch.float32: cutlass.Float32,\n}\n\n\n@lru_cache\ndef get_max_active_clusters(cluster_size):\n    return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size)\n\n\n@lru_cache\ndef get_device_capacity(device: torch.device = None) -> Tuple[int, int]:\n    return torch.cuda.get_device_capability(device)\n\n\n@dataclass\nclass ArgumentsBase(JitArgument):\n    def __c_pointers__(self):\n        all_fields = [getattr(self, field.name) for field in fields(self)]\n        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]\n        c_ptrs = []\n        for obj in non_constexpr_fields:\n            if hasattr(obj, \"__c_pointers__\"):\n                c_ptrs.extend(obj.__c_pointers__())\n        return c_ptrs\n\n    def __get_mlir_types__(self):\n        all_fields = [getattr(self, field.name) for field in fields(self)]\n        non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)]\n        types, self._values_pos = [], []\n        for obj in non_constexpr_fields:\n            if hasattr(obj, \"__get_mlir_types__\"):\n                obj_types = obj.__get_mlir_types__()\n                
types.extend(obj_types)\n                self._values_pos.append(len(obj_types))\n            else:\n                self._values_pos.append(0)\n        return types\n\n    def __new_from_mlir_values__(self, values):\n        all_fields = {field.name: getattr(self, field.name) for field in fields(self)}\n        constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)}\n        non_constexpr_fields = {\n            n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes)\n        }\n        for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos):\n            non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items])\n            values = values[n_items:]\n        return self.__class__(**non_constexpr_fields, **constexpr_fields)\n\n\ndef load_cubin_module_data_patched(cubin_data, filepath):\n    pathlib.Path(filepath).write_bytes(cubin_data)\n    return load_cubin_module_data_og(cubin_data)\n\n\ndef cute_compile_patched(*args, **kwargs):\n    \"\"\"A patched version of cute.compile that dumps the SASS to a file if CUTE_CUBIN_PATH is set.\"\"\"\n    cubin_path = os.getenv(\"CUTE_CUBIN_PATH\", None)\n    if cubin_path is not None:\n        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial(\n            load_cubin_module_data_patched, filepath=cubin_path\n        )\n    output = cute_compile_og(*args, **kwargs)\n    if cubin_path is not None:\n        cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og\n        if extract is not None:\n            sass = extract(cubin_path, None)\n            pathlib.Path(cubin_path).with_suffix(\".annotated.sass\").write_text(sass)\n    return output\n\n\ndef assume_strides_aligned(t):\n    \"\"\"Assume all strides except the last are divisible by 128 bits.\n\n    Python int strides (e.g., stride=0 from GQA expand) are kept as-is\n    since they're static and don't need alignment assumptions.\n    
\"\"\"\n    divby = 128 // t.element_type.width\n    strides = tuple(s if isinstance(s, int) else cute.assume(s, divby=divby) for s in t.stride[:-1])\n    return (*strides, t.stride[-1])\n\n\ndef assume_tensor_aligned(t):\n    \"\"\"Rebuild a tensor with 128-bit aligned stride assumptions. Passes through None.\"\"\"\n    if t is None:\n        return None\n    return cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=assume_strides_aligned(t)))\n\n\ndef to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):\n    \"\"\"Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1.\"\"\"\n    tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)\n    if fully_dynamic:\n        return tensor.mark_layout_dynamic()\n    if leading_dim == -1:\n        leading_dim = t.ndim - 1\n    return tensor.mark_layout_dynamic(leading_dim=leading_dim)\n\n\ndef to_cute_aux_tensor(t, enable_tvm_ffi=True):\n    \"\"\"Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors.\n    This allows the user to specify alignment and leading dimension for aux tensors used in\n    custom score_mod callables.\n    \"\"\"\n    assumed_align: int = getattr(t, \"__assumed_align__\", None)\n    leading_dim: int = getattr(t, \"__leading_dim__\", None)\n    fully_dynamic: bool = leading_dim is None\n\n    return to_cute_tensor(\n        t,\n        assumed_align=assumed_align,\n        leading_dim=leading_dim,\n        fully_dynamic=fully_dynamic,\n        enable_tvm_ffi=enable_tvm_ffi,\n    )\n\n\ndef get_aux_tensor_metadata(aux_tensors):\n    return tuple(\n        (\n            getattr(t, \"__assumed_align__\", 0),\n            getattr(t, \"__leading_dim__\", -1),\n            hasattr(t, \"__leading_dim__\"),\n        )\n        for t in aux_tensors\n    )\n\n\ndef get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]:\n    \"\"\"Return tuple of bools 
indicating which dims have stride=0 (broadcast).\n\n    This is useful for compile keys since CuTe's mark_layout_dynamic() keeps\n    stride=0 as static, meaning kernels compiled with different broadcast\n    patterns are not interchangeable.\n    \"\"\"\n    return tuple(s == 0 for s in tensor.stride())\n"
  },
  {
    "path": "flash_attn/cute/fa_logging.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n\n\"\"\"Unified FlashAttention logging controlled by a single ``FA_LOG_LEVEL`` env var.\n\nHost-side messages go through Python ``logging`` (logger name ``flash_attn``).\nA default ``StreamHandler`` is attached automatically when ``FA_LOG_LEVEL >= 1``\nso that standalone scripts get output without extra setup; applications that\nconfigure their own logging can remove or replace it via the standard API.\n\nFA_LOG_LEVEL mapping::\n\n    0  off       nothing logged\n    1  host      host-side summaries only (no kernel printf)\n    2  kernel    host + curated kernel traces\n    3  max       host + all kernel traces (noisy, perf hit)\n\nSet via environment variable::\n\n    FA_LOG_LEVEL=1 python train.py\n\nDevice-side ``cute.printf`` calls are compile-time eliminated via\n``cutlass.const_expr`` when the log level is below the callsite threshold,\nso there is zero performance cost when device logging is off.\nChanging the log level after kernel compilation requires a recompile\n(the level participates in the forward compile key).\n\"\"\"\n\nimport logging\nimport os\nimport sys\n\nimport cutlass.cute as cute\nfrom cutlass import const_expr\n\n_LOG_LEVEL_NAMES = {\"off\": 0, \"host\": 1, \"kernel\": 2, \"max\": 3}\n\n\ndef _parse_log_level(raw: str) -> int:\n    if raw in _LOG_LEVEL_NAMES:\n        return _LOG_LEVEL_NAMES[raw]\n    try:\n        level = int(raw)\n    except ValueError:\n        return 0\n    return max(0, min(level, 3))\n\n\n_fa_log_level: int = _parse_log_level(os.environ.get(\"FA_LOG_LEVEL\", \"0\"))\n\n_logger = logging.getLogger(\"flash_attn\")\n_logger.addHandler(logging.NullHandler())\n_default_handler: logging.Handler | None = None\n\n\ndef _configure_default_handler() -> None:\n    global _default_handler\n    if _fa_log_level >= 1:\n        if _default_handler is None:\n            _default_handler = logging.StreamHandler(sys.stdout)\n            
_default_handler.setFormatter(logging.Formatter(\"[FA] %(message)s\"))\n            _logger.addHandler(_default_handler)\n        _logger.setLevel(logging.DEBUG)\n    else:\n        if _default_handler is not None:\n            _logger.removeHandler(_default_handler)\n            _default_handler = None\n        _logger.setLevel(logging.WARNING)\n\n\n_configure_default_handler()\n\n\ndef get_fa_log_level() -> int:\n    return _fa_log_level\n\n\ndef set_fa_log_level(level: int | str) -> None:\n    \"\"\"Set the FA log level programmatically.\n\n    Host logging takes effect immediately.  Device logging changes only\n    affect kernels compiled after this call (new compile-key selection).\n    \"\"\"\n    global _fa_log_level\n    if isinstance(level, str):\n        level = _parse_log_level(level)\n    _fa_log_level = max(0, min(int(level), 3))\n    _configure_default_handler()\n\n\ndef fa_log(level: int, msg: str):\n    if _fa_log_level >= level:\n        _logger.info(msg)\n\n\ndef fa_printf(level: int, fmt, *args):\n    if const_expr(_fa_log_level >= level):\n        cute.printf(fmt, *args)\n"
  },
  {
    "path": "flash_attn/cute/fast_math.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Int32\n\n\n@cute.jit\ndef clz(x: Int32) -> Int32:\n    # for i in cutlass.range_constexpr(32):\n    #     if (1 << (31 - i)) & x:\n    #         return Int32(i)\n    # return Int32(32)\n    # Early exit is not supported yet\n    res = Int32(32)\n    done = False\n    for i in cutlass.range(32):\n        if ((1 << (31 - i)) & x) and not done:\n            res = Int32(i)\n            done = True\n    return res\n"
  },
  {
    "path": "flash_attn/cute/flash_bwd.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_bwd_sm80.hpp\n# from Cutlass C++ to Cute-DSL.\nimport math\nfrom types import SimpleNamespace\nfrom typing import Type, Callable, Optional\nfrom functools import partial\n\nimport cuda.bindings.driver as cuda\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass.cute.nvgpu import cpasync, warp\nfrom cutlass import Float32, Int32\nimport cutlass.utils as utils_basic\n\nfrom quack import layout_utils\nfrom flash_attn.cute import ampere_helpers as sm80_utils\nfrom flash_attn.cute.cute_dsl_utils import assume_tensor_aligned\nfrom flash_attn.cute import utils\nfrom flash_attn.cute.mask import AttentionMask\nfrom flash_attn.cute.seqlen_info import SeqlenInfoQK\nfrom quack.cute_dsl_utils import ParamsBase\nfrom flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments\nfrom flash_attn.cute.block_sparsity import BlockSparseTensors\n\n\nclass FlashAttentionBackwardSm80:\n    def __init__(\n        self,\n        dtype: Type[cutlass.Numeric],\n        head_dim: int,\n        head_dim_v: Optional[int] = None,\n        qhead_per_kvhead: int = 1,\n        m_block_size: int = 64,\n        n_block_size: int = 128,\n        num_stages_Q: int = 2,\n        num_stages_dO: int = 2,\n        num_threads: int = 256,\n        pack_gqa: bool = False,\n        is_causal: bool = False,\n        SdP_swapAB: bool = False,\n        dKV_swapAB: bool = False,\n        dQ_swapAB: bool = False,\n        AtomLayoutMSdP: int = 1,\n        AtomLayoutNdKV: int = 8,\n        AtomLayoutMdQ: int = 1,\n        V_in_regs: bool = False,\n    ):\n        \"\"\"Initializes the configuration for a flash attention v2 kernel.\n\n        All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension\n        
should be a multiple of 8.\n\n        :param head_dim: head dimension\n        :type head_dim: int\n        :param m_block_size: m block size\n        :type m_block_size: int\n        :param n_block_size: n block size\n        :type n_block_size: int\n        :param num_threads: number of threads\n        :type num_threads: int\n        :param is_causal: is causal\n        \"\"\"\n        self.dtype = dtype\n        # padding head_dim to a multiple of 32 as k_block_size\n        hdim_multiple_of = 32\n        self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)\n        head_dim_v = head_dim_v if head_dim_v is not None else head_dim\n        self.same_hdim_kv = head_dim == head_dim_v\n        self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)\n        # Can save registers (and hence be faster) if we don't have to check hdim predication\n        self.check_hdim_oob = head_dim != self.head_dim_padded\n        self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded\n        self.qhead_per_kvhead = qhead_per_kvhead\n        self.m_block_size = m_block_size\n        self.n_block_size = n_block_size\n        self.num_threads = num_threads\n        self.pack_gqa = pack_gqa\n        self.is_causal = is_causal\n        self.num_stages_Q = num_stages_Q\n        self.num_stages_dO = num_stages_dO\n        self.SdP_swapAB = SdP_swapAB\n        self.dKV_swapAB = dKV_swapAB\n        self.dQ_swapAB = dQ_swapAB\n        self.AtomLayoutMSdP = AtomLayoutMSdP\n        self.AtomLayoutNdKV = AtomLayoutNdKV\n        self.AtomLayoutMdQ = AtomLayoutMdQ\n        num_mma_warps = self.num_threads // cute.arch.WARP_SIZE\n        self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB\n        self.V_in_regs = V_in_regs\n        self.share_QV_smem = V_in_regs\n\n    @staticmethod\n    def can_implement(\n        dtype, head_dim, head_dim_v, m_block_size, 
n_block_size, num_stages_Q, num_stages_dO,\n        num_threads, is_causal,\n        V_in_regs=False\n    ) -> bool:\n        \"\"\"Check if the kernel can be implemented with the given parameters.\n\n        :param dtype: data type\n        :type dtype: cutlass.Numeric\n        :param head_dim: head dimension\n        :type head_dim: int\n        :param m_block_size: m block size\n        :type m_block_size: int\n        :param n_block_size: n block size\n        :type n_block_size: int\n        :param num_threads: number of threads\n        :type num_threads: int\n        :param is_causal: is causal\n        :type is_causal: bool\n\n        :return: True if the kernel can be implemented, False otherwise\n        :rtype: bool\n        \"\"\"\n        if dtype not in [cutlass.Float16, cutlass.BFloat16]:\n            return False\n        if head_dim % 8 != 0:\n            return False\n        if head_dim_v % 8 != 0:\n            return False\n        if n_block_size % 16 != 0:\n            return False\n        if num_threads % 32 != 0:\n            return False\n        # Check if block size setting is out of shared memory capacity\n        # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size\n        smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2\n        smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2\n        smem_usage_K = n_block_size * head_dim * 2\n        smem_usage_V = n_block_size * head_dim_v * 2\n        smem_usage_QV = (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V)\n        smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K\n        smem_capacity = utils_basic.get_smem_capacity_in_bytes(\"sm_80\")\n        if smem_usage > smem_capacity:\n            return False\n        return True\n\n    def _check_type(\n        self,\n        mQ_type: Type[cutlass.Numeric],\n        mK_type: Type[cutlass.Numeric],\n        mV_type: Type[cutlass.Numeric],\n       
 mdO_type: Type[cutlass.Numeric],\n        mLSE_type: Type[cutlass.Numeric],\n        mdPsum_type: Type[cutlass.Numeric],\n        mdQaccum_type: Type[cutlass.Numeric],\n        mdK_type: Type[cutlass.Numeric],\n        mdV_type: Type[cutlass.Numeric],\n        mCuSeqlensQ_type: Type[cutlass.Numeric] | None,\n        mCuSeqlensK_type: Type[cutlass.Numeric] | None,\n        mSeqUsedQ_type: Type[cutlass.Numeric] | None,\n        mSeqUsedK_type: Type[cutlass.Numeric] | None,\n    ):\n        if cutlass.const_expr(not (mQ_type == mK_type == mV_type == mdO_type)):\n            raise TypeError(\"All tensors must have the same data type\")\n        if cutlass.const_expr(self.qhead_per_kvhead == 1):\n            if cutlass.const_expr(not (mdK_type == mdV_type == mQ_type)):\n                raise TypeError(\"mdK and mdV tensors must have the same data type as mQ\")\n        else:\n            if cutlass.const_expr(not (mdK_type == mdV_type == cutlass.Float32)):\n                raise TypeError(\"mdKaccum and mdVaccum tensors must have the data type Float32\")\n        if cutlass.const_expr(not mQ_type in [cutlass.Float16, cutlass.BFloat16]):\n            raise TypeError(\"Only Float16 or BFloat16 is supported\")\n        if cutlass.const_expr(not mLSE_type in [cutlass.Float32]):\n            raise TypeError(\"LSE tensor must be Float32\")\n        if cutlass.const_expr(not mdPsum_type in [cutlass.Float32]):\n            raise TypeError(\"dPsum tensor must be Float32\")\n        if cutlass.const_expr(not mdQaccum_type in [cutlass.Float32]):\n            raise TypeError(\"dQaccum tensor must be Float32\")\n        if cutlass.const_expr(mCuSeqlensQ_type not in [None, cutlass.Int32]):\n            raise TypeError(\"cuSeqlensQ tensor must be Int32\")\n        if cutlass.const_expr(mCuSeqlensK_type not in [None, cutlass.Int32]):\n            raise TypeError(\"cuSeqlensK tensor must be Int32\")\n        if cutlass.const_expr(mSeqUsedQ_type not in [None, cutlass.Int32]):\n          
  raise TypeError(\"SeqUsedQ tensor must be Int32\")\n        if cutlass.const_expr(mSeqUsedK_type not in [None, cutlass.Int32]):\n            raise TypeError(\"SeqUsedK tensor must be Int32\")\n        assert mQ_type == self.dtype\n\n    def _setup_attributes(self):\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Shared memory layout: Q/K/V\n        # ///////////////////////////////////////////////////////////////////////////////\n        sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_padded)\n        self.sQ_layout = cute.tile_to_shape(\n            sQ_layout_atom, (self.m_block_size, self.head_dim_padded, self.num_stages_Q), (0, 1, 2),\n        )\n        sK_layout_atom = sQ_layout_atom\n        self.sK_layout = cute.tile_to_shape(\n            sK_layout_atom, (self.n_block_size, self.head_dim_padded), (0, 1),\n        )\n        sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_v_padded)\n        self.sV_layout = cute.tile_to_shape(\n            sV_layout_atom, (self.n_block_size, self.head_dim_v_padded), (0, 1),\n        )\n        sdO_layout_atom = sV_layout_atom\n        self.sdO_layout = cute.tile_to_shape(\n            sdO_layout_atom, (self.m_block_size, self.head_dim_v_padded, self.num_stages_dO), (0, 1, 2),\n        )\n        # TODO: do we set swizzle to be 3 here explicitly?\n        sPdS_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.n_block_size)\n        self.sPdS_layout = cute.tile_to_shape(\n            sPdS_layout_atom, (self.m_block_size, self.n_block_size), (0, 1),\n        )\n        # We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds,\n        # it's still a valid smem address.\n        self.sLSE_layout = cute.make_layout(\n            (self.m_block_size, self.num_stages_Q),\n            stride=(1, cute.round_up(self.m_block_size, 64)),\n        )\n        
sLSEMma_layout = cute.make_layout(\n            (self.m_block_size, self.n_block_size, self.num_stages_Q),\n            stride=(1, 0, cute.round_up(self.m_block_size, 64)),\n        )\n        sLSEMma_layout_transposed = cute.make_layout(\n            (self.n_block_size, self.m_block_size, self.num_stages_Q),\n            stride=(0, 1, cute.round_up(self.m_block_size, 64)),\n        )\n        self.sLSEMma_layout = sLSEMma_layout if not self.SdP_swapAB else sLSEMma_layout_transposed\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # GMEM Tiled copy:\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Thread layouts for copies\n        universal_copy_bits = 128\n        async_copy_elems = universal_copy_bits // self.dtype.width\n        # atom_async_copy: async copy atom for QKV load\n        atom_async_copy = cute.make_copy_atom(\n            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),\n            self.dtype,\n            num_bits_per_copy=universal_copy_bits,\n        )\n        # atom_universal_copy: universal copy atom for O store\n        atom_universal_copy = cute.make_copy_atom(\n            cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits,\n        )\n        # tQK_layout: thread layout for QK load\n        tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems\n        assert self.num_threads % tQK_shape_dim_1 == 0, \"num_threads must be divisible by tQK_shape_dim_1\"\n        tQK_layout = cute.make_ordered_layout(\n            (self.num_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0),\n        )\n        # Do we need to check if we overshot kBlockM when we load Q?\n        self.is_even_m_smem_q = self.m_block_size % tQK_layout.shape[0] == 0\n        # Do we need to check if we overshot kBlockN when we load K?\n        self.is_even_n_smem_k = self.n_block_size % tQK_layout.shape[0] == 
0\n        tVdO_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems\n        assert self.num_threads % tVdO_shape_dim_1 == 0, \"num_threads must be divisible by tVdO_shape_dim_1\"\n        tVdO_layout = cute.make_ordered_layout(\n            (self.num_threads // tVdO_shape_dim_1, tVdO_shape_dim_1), order=(1, 0),\n        )\n        # Do we need to check if we overshot kBlockN when we load V?\n        self.is_even_n_smem_v = self.n_block_size % tVdO_layout.shape[0] == 0\n        self.is_even_m_smem_do = self.m_block_size % tVdO_layout.shape[0] == 0\n\n        # Value layouts for copies\n        vQKVdO_layout = cute.make_layout((1, async_copy_elems))\n\n        # gmem_tiled_copy_QK: tiled copy for QK load\n        self.gmem_tiled_copy_QK = cute.make_tiled_copy_tv(atom_async_copy, tQK_layout, vQKVdO_layout)\n        self.gmem_tiled_copy_VdO = cute.make_tiled_copy_tv(atom_async_copy, tVdO_layout, vQKVdO_layout)\n        self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv(atom_universal_copy, tQK_layout, vQKVdO_layout)\n        self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv(atom_universal_copy, tVdO_layout, vQKVdO_layout)\n        async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width\n\n        # I think we wouldn't require this with smarter padding\n        if cutlass.const_expr(not self.varlen_q):\n            async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width\n            atom_async_copy_accum = cute.make_copy_atom(\n                cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),\n                cutlass.Float32,\n                num_bits_per_copy=universal_copy_bits,\n            )\n        else:\n            async_copy_elems_accum = 1\n            atom_async_copy_accum = cute.make_copy_atom(\n                cute.nvgpu.CopyUniversalOp(),\n                cutlass.Float32,\n                num_bits_per_copy=cutlass.Float32.width,\n            )\n        self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(\n   
         atom_async_copy_accum,\n            cute.make_layout(self.num_threads),\n            cute.make_layout(async_copy_elems_accum),\n        )\n        self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv(\n            cute.make_copy_atom(\n                cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=cutlass.Float32.width\n            ),\n            cute.make_layout(self.num_threads),\n            cute.make_layout(1)\n        )\n        if cutlass.const_expr(self.qhead_per_kvhead > 1):\n            self.gmem_tiled_copy_dK = self.gmem_tiled_copy_dQaccum\n            self.gmem_tiled_copy_dV = self.gmem_tiled_copy_dQaccum\n\n    def _get_tiled_mma(self):\n        num_mma_warps = self.num_threads // 32\n        AtomLayoutSdP = (self.AtomLayoutMSdP, num_mma_warps // self.AtomLayoutMSdP, 1) if cutlass.const_expr(not self.SdP_swapAB) else (num_mma_warps // self.AtomLayoutMSdP, self.AtomLayoutMSdP, 1)\n        tiled_mma_sdp = cute.make_tiled_mma(\n            warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)),\n            AtomLayoutSdP,\n            permutation_mnk=(AtomLayoutSdP[0] * 16, AtomLayoutSdP[1] * 16, 16),\n        )\n        AtomLayoutdKV = (self.AtomLayoutNdKV, num_mma_warps // self.AtomLayoutNdKV, 1) if cutlass.const_expr(not self.dKV_swapAB) else (num_mma_warps // self.AtomLayoutNdKV, self.AtomLayoutNdKV, 1)\n        tiled_mma_dkv = cute.make_tiled_mma(\n            warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)),\n            AtomLayoutdKV,\n            permutation_mnk=(AtomLayoutdKV[0] * 16, AtomLayoutdKV[1] * 16, 16),\n        )\n        AtomLayoutdQ = (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if cutlass.const_expr(not self.dQ_swapAB) else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1)\n        tiled_mma_dq = cute.make_tiled_mma(\n            warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)),\n            AtomLayoutdQ,\n            permutation_mnk=(AtomLayoutdQ[0] * 
16, AtomLayoutdQ[1] * 16, 16),\n        )\n        return tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq\n\n    def _get_shared_storage_cls(self):\n        sQ_struct, sK_struct, sV_struct, sdO_struct = [\n            cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024]\n            for layout in (self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout)\n        ]\n        cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))\n        sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]\n        sLSE_struct, sdPsum_struct = [\n            cute.struct.Align[cute.struct.MemRange[cutlass.Float32, cute.cosize(layout)], 128]\n            for layout in (self.sLSE_layout, self.sLSE_layout)\n        ]\n        sP_struct, sdS_struct = [\n            cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 128]\n            for layout in (self.sPdS_layout, self.sPdS_layout)\n        ]\n\n        @cute.struct\n        class SharedStorageSeparateQV:\n            sK: sK_struct\n            sV: sV_struct\n            sQ: sQ_struct\n            sdO: sdO_struct\n            sLSE: sLSE_struct\n            sdPsum: sdPsum_struct\n            sP: sP_struct\n            sdS: sdS_struct\n            # TODO: the case where there's no sP\n\n        @cute.struct\n        class SharedStorageSharedQV:\n            sK: sK_struct\n            sV: sV_struct\n            sQ: sQV_struct\n            sdO: sdO_struct\n            sLSE: sLSE_struct\n            sdPsum: sdPsum_struct\n            sP: sP_struct\n            sdS: sdS_struct\n\n        return SharedStorageSeparateQV if cutlass.const_expr(not self.share_QV_smem) else SharedStorageSharedQV\n\n    @cute.jit\n    def __call__(\n        self,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mV: cute.Tensor,\n        mdO: cute.Tensor,\n        mLSE: cute.Tensor,\n        mdPsum: cute.Tensor,\n        mdQaccum: cute.Tensor,\n        mdK: 
cute.Tensor,\n        mdV: cute.Tensor,\n        softmax_scale: cutlass.Float32,\n        mCuSeqlensQ: Optional[cute.Tensor] = None,\n        mCuSeqlensK: Optional[cute.Tensor] = None,\n        mSeqUsedQ: Optional[cute.Tensor] = None,\n        mSeqUsedK: Optional[cute.Tensor] = None,\n        softcap: Float32 | float | None = None,\n        window_size_left: Int32 | int | None = None,\n        window_size_right: Int32 | int | None = None,\n        mdQ_semaphore: Optional[cute.Tensor] = None,\n        mdK_semaphore: Optional[cute.Tensor] = None,\n        mdV_semaphore: Optional[cute.Tensor] = None,\n        aux_tensors: Optional[list] = None,\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n        # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).\n        stream: cuda.CUstream = None,\n    ):\n        assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, (\n            \"determinism not supported yet for Sm80\"\n        )\n        # Get the data type and check if it is fp16 or bf16\n        self._check_type(*(t.element_type if t is not None else None\n                           for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)))\n        mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [\n            assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)\n        ]\n        self.varlen_q = (mCuSeqlensQ is not None)\n        self._setup_attributes()\n        SharedStorage = self._get_shared_storage_cls()\n        tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq = self._get_tiled_mma()\n\n        num_head = mQ.shape[1] if cutlass.const_expr(mCuSeqlensQ is not None) else mQ.shape[2]\n\n        if cutlass.const_expr(mCuSeqlensK is not None):\n            TileScheduler = SingleTileVarlenScheduler\n            num_batch = mCuSeqlensK.shape[0] - 1\n        else:\n            TileScheduler = 
SingleTileScheduler\n            num_batch = mK.shape[0]\n\n        # Uses seqlen k, etc. since main bwd kernel's blocks are over n\n        tile_sched_args = TileSchedulerArguments(\n            num_block=cute.ceil_div(mK.shape[1], self.n_block_size),\n            num_head=num_head,\n            num_batch=num_batch,\n            num_splits=1,\n            seqlen_k=0,\n            headdim=mK.shape[2],\n            headdim_v=mV.shape[2],\n            total_q=mK.shape[0],\n            tile_shape_mn=(self.n_block_size, self.m_block_size),\n            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1,\n            mCuSeqlensQ=mCuSeqlensK,\n            mSeqUsedQ=mSeqUsedK,\n        )\n\n        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)\n        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)\n\n        softmax_scale_log2 = softmax_scale * math.log2(math.e)\n        self.kernel(\n            mQ,\n            mK,\n            mV,\n            mdO,\n            mLSE,\n            mdPsum,\n            mdQaccum,\n            mdK,\n            mdV,\n            mCuSeqlensQ,\n            mCuSeqlensK,\n            mSeqUsedQ,\n            mSeqUsedK,\n            softmax_scale,\n            softmax_scale_log2,\n            self.sQ_layout,\n            self.sK_layout,\n            self.sV_layout,\n            self.sdO_layout,\n            self.sPdS_layout,\n            self.sLSE_layout,\n            self.sLSEMma_layout,\n            self.gmem_tiled_copy_QK,\n            self.gmem_tiled_copy_VdO,\n            self.gmem_tiled_copy_dK,\n            self.gmem_tiled_copy_dV,\n            self.gmem_tiled_copy_LSE,\n            self.gmem_tiled_copy_dQaccum,\n            tiled_mma_sdp,\n            tiled_mma_dkv,\n            tiled_mma_dq,\n            SharedStorage,\n            tile_sched_params,\n            TileScheduler,\n        ).launch(\n            grid=grid_dim,\n            
block=[self.num_threads, 1, 1],\n            smem=SharedStorage.size_in_bytes(),\n            stream=stream,\n        )\n\n    @cute.kernel\n    def kernel(\n        self,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mV: cute.Tensor,\n        mdO: cute.Tensor,\n        mLSE: cute.Tensor,\n        mdPsum: cute.Tensor,\n        mdQaccum: cute.Tensor,\n        mdK: cute.Tensor,\n        mdV: cute.Tensor,\n        mCuSeqlensQ: Optional[cute.Tensor],\n        mCuSeqlensK: Optional[cute.Tensor],\n        mSeqUsedQ: Optional[cute.Tensor],\n        mSeqUsedK: Optional[cute.Tensor],\n        softmax_scale: cutlass.Float32,\n        softmax_scale_log2: cutlass.Float32,\n        sQ_layout: cute.ComposedLayout,\n        sK_layout: cute.ComposedLayout,\n        sV_layout: cute.ComposedLayout,\n        sdO_layout: cute.ComposedLayout,\n        sPdS_layout: cute.ComposedLayout,\n        sLSE_layout: cute.Layout,\n        sLSEMma_layout: cute.Layout,\n        gmem_tiled_copy_QK: cute.TiledCopy,\n        gmem_tiled_copy_VdO: cute.TiledCopy,\n        gmem_tiled_copy_dK: cute.TiledCopy,\n        gmem_tiled_copy_dV: cute.TiledCopy,\n        gmem_tiled_copy_LSE: cute.TiledCopy,\n        gmem_tiled_copy_dQaccum: cute.TiledCopy,\n        tiled_mma_sdp: cute.TiledMma,\n        tiled_mma_dkv: cute.TiledMma,\n        tiled_mma_dq: cute.TiledMma,\n        SharedStorage: cutlass.Constexpr,\n        tile_sched_params: ParamsBase,\n        TileScheduler: cutlass.Constexpr[Callable],\n    ):\n        # Thread index, block index\n        tidx, _, _ = cute.arch.thread_idx()\n\n        tile_scheduler = TileScheduler.create(tile_sched_params)\n        work_tile = tile_scheduler.initial_work_tile_info()\n\n        n_block, head_idx, batch_idx, _ = work_tile.tile_idx\n\n        if work_tile.is_valid_tile:\n            seqlen = SeqlenInfoQK.create(\n                batch_idx,\n                mQ.shape[1],\n                mK.shape[1],\n                mCuSeqlensQ=mCuSeqlensQ,\n          
      mCuSeqlensK=mCuSeqlensK,\n                mSeqUsedQ=mSeqUsedQ,\n                mSeqUsedK=mSeqUsedK,\n                tile_m=self.m_block_size,\n                tile_n=self.n_block_size,\n            )\n\n            m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size)\n            m_block_min = 0\n            if cutlass.const_expr(self.is_causal):\n                m_block_min = max(\n                    (n_block * self.n_block_size + seqlen.seqlen_q - seqlen.seqlen_k) // self.m_block_size,\n                    m_block_min,\n                )\n            # TODO: return early if m_block_max == 0\n\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Get the appropriate tiles for this thread block.\n            # ///////////////////////////////////////////////////////////////////////////////\n            blkQ_shape = (self.m_block_size, self.head_dim_padded)\n            blkK_shape = (self.n_block_size, self.head_dim_padded)\n            blkV_shape = (self.n_block_size, self.head_dim_v_padded)\n            blkdO_shape = (self.m_block_size, self.head_dim_v_padded)\n\n            if cutlass.const_expr(not seqlen.has_cu_seqlens_q):\n                mQ_cur = mQ[batch_idx, None, head_idx, None]\n                mLSE_cur = mLSE[batch_idx, head_idx, None]\n                mdO_cur = mdO[batch_idx, None, head_idx, None]\n                mdPsum_cur = mdPsum[batch_idx, head_idx, None]\n                mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]\n            else:\n                padded_offset_q = seqlen.padded_offset_q\n                mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None])\n                mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None])\n                mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])\n                mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])\n              
  mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None])\n            head_idx_kv = head_idx // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else head_idx\n\n            if cutlass.const_expr(not seqlen.has_cu_seqlens_k):\n                mK_cur, mV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mK, mV)]\n            else:\n                mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mK, mV)]\n\n            # (m_block_size, head_dim, m_block)\n            gQ = cute.local_tile(mQ_cur, blkQ_shape, (None, 0))\n            # (n_block_size, head_dim)\n            gK = cute.local_tile(mK_cur, blkK_shape, (n_block, 0))\n            # (n_block_size, head_dim_v)\n            gV = cute.local_tile(mV_cur, blkV_shape, (n_block, 0))\n            # (m_block_size, head_dim_v, m_block)\n            gdO = cute.local_tile(mdO_cur, blkdO_shape, (None, 0))\n            gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,))\n            gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,))\n            gdQaccum = cute.local_tile(mdQaccum_cur, (self.m_block_size * self.head_dim_padded,), (None,))\n\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Get shared memory buffer\n            # ///////////////////////////////////////////////////////////////////////////////\n            smem = cutlass.utils.SmemAllocator()\n            storage = smem.allocate(SharedStorage)\n            sQ = storage.sQ.get_tensor(sQ_layout)\n            sK = storage.sK.get_tensor(sK_layout)\n            if cutlass.const_expr(not self.share_QV_smem):\n                sV = storage.sV.get_tensor(sV_layout)\n            else:\n                sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout)\n            sdO = storage.sdO.get_tensor(sdO_layout)\n            sP = 
storage.sP.get_tensor(sPdS_layout)\n            sdS = storage.sdS.get_tensor(sPdS_layout)\n            sLSE = storage.sLSE.get_tensor(sLSE_layout)\n            sdPsum = storage.sdPsum.get_tensor(sLSE_layout)\n            sLSEMma = storage.sLSE.get_tensor(sLSEMma_layout)\n            sdPsumMma = storage.sdPsum.get_tensor(sLSEMma_layout)\n\n            # Transpose view of tensors for tiled mma\n            sQt, sdOt, sKt, sPt, sdSt = [layout_utils.transpose_view(t) for t in (sQ, sdO, sK, sP, sdS)]\n\n            gmem_thr_copy_QK = gmem_tiled_copy_QK.get_slice(tidx)\n            gmem_thr_copy_VdO = gmem_tiled_copy_VdO.get_slice(tidx)\n            gmem_thr_copy_lse = gmem_tiled_copy_LSE.get_slice(tidx)\n            gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)\n            # (CPY_Atom, CPY_M, CPY_K, m_block)\n            tQgQ = gmem_thr_copy_QK.partition_S(gQ)\n            tQsQ = gmem_thr_copy_QK.partition_D(sQ)\n            # (CPY_Atom, CPY_N, CPY_K)\n            tKgK = gmem_thr_copy_QK.partition_S(gK)\n            tKsK = gmem_thr_copy_QK.partition_D(sK)\n            # (CPY_Atom, CPY_N, CPY_K)\n            tVgV = gmem_thr_copy_VdO.partition_S(gV)\n            tVsV = gmem_thr_copy_VdO.partition_D(sV)\n            # (CPY_Atom, CPY_M, CPY_K, m_block)\n            tdOgdO = gmem_thr_copy_VdO.partition_S(gdO)\n            tdOsdO = gmem_thr_copy_VdO.partition_D(sdO)\n            tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE)\n            tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE)\n            tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum)\n            tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum)\n            tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)\n\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Tile MMA compute thread partitions and allocate accumulators\n            # ///////////////////////////////////////////////////////////////////////////////\n            
thr_mma_sdp = tiled_mma_sdp.get_slice(tidx)\n            thr_mma_dkv = tiled_mma_dkv.get_slice(tidx)\n            thr_mma_dq = tiled_mma_dq.get_slice(tidx)\n            acc_shape_dK = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_padded))\n            acc_shape_dV = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_v_padded))\n            acc_dK = cute.make_fragment(acc_shape_dK, cutlass.Float32)\n            acc_dV = cute.make_fragment(acc_shape_dV, cutlass.Float32)\n            acc_dK.fill(0.0)\n            acc_dV.fill(0.0)\n\n            tSrQ = utils.mma_make_fragment_A(sQ[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB)\n            tSrK = utils.mma_make_fragment_B(sK, thr_mma_sdp, swapAB=self.SdP_swapAB)\n            tdPrdO = utils.mma_make_fragment_A(sdO[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB)\n            tdPrV = utils.mma_make_fragment_B(sV, thr_mma_sdp, swapAB=self.SdP_swapAB)\n            tdVrP = utils.mma_make_fragment_A(sPt, thr_mma_dkv, swapAB=self.dKV_swapAB)\n            tdVrdO = utils.mma_make_fragment_B(sdOt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB)\n            tdKrdS = utils.mma_make_fragment_A(sdSt, thr_mma_dkv, swapAB=self.dKV_swapAB)\n            tdKrQ = utils.mma_make_fragment_B(sQt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB)\n            tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB)\n            tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB)\n\n            LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None)\n            tSsLSEMma = layout_utils.reshape_acc_to_mn(thr_mma_sdp.partition_C(sLSEMma))[LSEslice]\n            tSsdPsumMma = layout_utils.reshape_acc_to_mn(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice]\n\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Smem copy atom tiling\n            # 
///////////////////////////////////////////////////////////////////////////////\n            smem_copy_atom = cute.make_copy_atom(\n                warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype,\n            )\n            smem_copy_atom_transposed = cute.make_copy_atom(\n                warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype,\n            )\n            smem_thr_copy_QdO = utils.make_tiled_copy_A(\n                smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB\n            ).get_slice(tidx)\n            smem_thr_copy_KV = utils.make_tiled_copy_B(\n                smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB\n            ).get_slice(tidx)\n            # TODO: should this be smem_copy_atom_transposed?\n            smem_thr_copy_PdSt = utils.make_tiled_copy_A(\n                smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB\n            ).get_slice(tidx)\n            smem_thr_copy_QdOt = utils.make_tiled_copy_B(\n                smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB\n            ).get_slice(tidx)\n            smem_thr_copy_dS = utils.make_tiled_copy_A(\n                smem_copy_atom, tiled_mma_dq, swapAB=self.dQ_swapAB\n            ).get_slice(tidx)\n            smem_thr_copy_Kt = utils.make_tiled_copy_B(\n                smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB\n            ).get_slice(tidx)\n            # TODO: what's the number of bits? 
What if SdP_swapAB\n            r2s_thr_copy_PdS = cute.make_tiled_copy_C(\n                cute.make_copy_atom(\n                    cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width\n                ),\n                tiled_mma_sdp,\n            ).get_slice(tidx)\n\n            tSsQ = smem_thr_copy_QdO.partition_S(sQ)\n            tdPsdO = smem_thr_copy_QdO.partition_S(sdO)\n            tSsK = smem_thr_copy_KV.partition_S(sK)\n            tdPsV = smem_thr_copy_KV.partition_S(sV)\n            tdVsPt = smem_thr_copy_PdSt.partition_S(sPt)\n            tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt)\n            tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt)\n            tdKsQt = smem_thr_copy_QdOt.partition_S(sQt)\n            tdQsdS = smem_thr_copy_dS.partition_S(sdS)\n            tdQsKt = smem_thr_copy_Kt.partition_S(sKt)\n            tPsP = r2s_thr_copy_PdS.partition_D(sP)\n            tdSsdS = r2s_thr_copy_PdS.partition_D(sdS)\n\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Predicate: Mark indices that need to copy when problem_shape isn't a multiple\n            # of tile_shape\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Construct identity layout for Q (reused for dO when the padded head dims match)\n            cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))\n            tQcQ = gmem_thr_copy_QK.partition_S(cQ)\n            t0QcQ = gmem_thr_copy_QK.get_slice(0).partition_S(cQ)\n            if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded):\n                tdOcdO = tQcQ\n                t0dOcdO = t0QcQ\n            else:\n                cdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))\n                tdOcdO = gmem_thr_copy_VdO.partition_S(cdO)\n                t0dOcdO = gmem_thr_copy_VdO.get_slice(0).partition_S(cdO)\n            cLSE = 
cute.make_identity_tensor((self.m_block_size,))\n            tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE)\n\n            # Allocate predicate tensors for m and n, here we only allocate the tile of k, and\n            # use \"if\" on the mn dimension.\n            # This is to reduce register pressure and gets 2-3% performance gain.\n\n            d_head = mQ.shape[cute.rank(mQ) - 1]\n            d_head_v = mdO.shape[cute.rank(mdO) - 1]\n\n            tQpQ = utils.predicate_k(tQcQ, limit=d_head)\n            if cutlass.const_expr(self.same_hdim_kv):\n                tdOpdO = tQpQ\n            else:\n                tdOpdO = utils.predicate_k(tdOcdO, limit=d_head_v)\n\n            # group parameters for compute_one_m_block\n            mma_params = SimpleNamespace(\n                thr_mma_sdp=thr_mma_sdp, thr_mma_dkv=thr_mma_dkv, thr_mma_dq=thr_mma_dq,\n                tSrQ=tSrQ, tSrK=tSrK, tdPrdO=tdPrdO, tdPrV=tdPrV,\n                tdVrP=tdVrP, tdVrdO=tdVrdO, tdKrdS=tdKrdS, tdKrQ=tdKrQ,\n                tdQrdS=tdQrdS, tdQrK=tdQrK,\n                acc_dK=acc_dK, acc_dV=acc_dV,\n            )\n            smem_copy_params = SimpleNamespace(\n                smem_thr_copy_QdO=smem_thr_copy_QdO,\n                smem_thr_copy_KV=smem_thr_copy_KV,\n                smem_thr_copy_PdSt=smem_thr_copy_PdSt,\n                smem_thr_copy_QdOt=smem_thr_copy_QdOt,\n                smem_thr_copy_dS=smem_thr_copy_dS,\n                smem_thr_copy_Kt=smem_thr_copy_Kt,\n                r2s_thr_copy_PdS=r2s_thr_copy_PdS,\n                tSsQ=tSsQ, tSsK=tSsK, tdPsdO=tdPsdO, tdPsV=tdPsV,\n                tSsLSEMma=tSsLSEMma, tSsdPsumMma=tSsdPsumMma,\n                tPsP=tPsP, tdSsdS=tdSsdS,\n                tdVsPt=tdVsPt, tdVsdOt=tdVsdOt, tdKsdSt=tdKsdSt, tdKsQt=tdKsQt,\n                tdQsdS=tdQsdS, tdQsKt=tdQsKt,\n            )\n            gmem_copy_params = SimpleNamespace(\n                gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum\n            
)\n            load_Q_LSE = partial(\n                self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE,\n                tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ,\n                tLSEgLSE, tLSEsLSE, tLSEcLSE, seqlen=seqlen.seqlen_q\n            )\n            load_dO_dPsum = partial(\n                self.load_dO_dPsum, gmem_tiled_copy_VdO, gmem_tiled_copy_LSE,\n                tdOgdO, tdOsdO, tdOcdO, t0dOcdO, tdOpdO,\n                tLSEgdPsum, tLSEsdPsum, tLSEcLSE, seqlen=seqlen.seqlen_q\n            )\n            compute_one_m_block = partial(\n                self.compute_one_m_block, mma_params=mma_params,\n                smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params,\n                load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum,\n                m_block_max=m_block_max,\n                softmax_scale_log2=softmax_scale_log2,\n            )\n\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Prologue\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Start async loads of the last mn-tile, where we take care of the mn residue\n            self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k,\n                        headdim=d_head_v)\n            if cutlass.const_expr(self.V_in_regs):\n                cute.arch.cp_async_commit_group()\n            self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k,\n                        headdim=d_head)\n            cute.arch.cp_async_commit_group()\n\n            if cutlass.const_expr(self.V_in_regs):\n                cute.arch.cp_async_wait_group(1)\n                cute.arch.barrier()\n                tdPrV_copy_view = smem_thr_copy_KV.retile(tdPrV)\n                cute.copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view)\n                # Sync to avoid loading Q to smem_q, which overlaps with smem_v\n                
cute.arch.barrier()\n\n            m_block = m_block_min\n            assert self.num_stages_Q >= self.num_stages_dO\n            for stage in cutlass.range_constexpr(self.num_stages_Q):\n                if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1):\n                    if stage == 0 or m_block + stage < m_block_max:\n                        load_Q_LSE(m_block + stage, smem_pipe_write_q=stage)\n                    cute.arch.cp_async_commit_group()\n                if cutlass.const_expr(stage < self.num_stages_dO):\n                    if stage == 0 or m_block + stage < m_block_max:\n                        load_dO_dPsum(m_block + stage, smem_pipe_write_q=stage)\n                    cute.arch.cp_async_commit_group()\n\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Mainloop\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Start processing of the first n-block.\n            mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen)\n            mask_fn = partial(\n                mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp,\n                batch_idx=batch_idx, head_idx=head_idx,\n                mask_seqlen=True, mask_causal=self.is_causal\n            )\n            smem_pipe_read_q = cutlass.Int32(0)\n            smem_pipe_read_do = cutlass.Int32(0)\n            smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1)\n            smem_pipe_write_do = cutlass.Int32(0)\n            for m_tile in cutlass.range(m_block_min, m_block_max, unroll=1):\n                compute_one_m_block(\n                    m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do,\n                    mask_fn=mask_fn,\n                )\n                smem_pipe_read_q = self.advance_pipeline(smem_pipe_read_q, self.num_stages_Q)\n                smem_pipe_read_do = 
self.advance_pipeline(smem_pipe_read_do, self.num_stages_dO)\n                smem_pipe_write_q = self.advance_pipeline(smem_pipe_write_q, self.num_stages_Q)\n                smem_pipe_write_do = self.advance_pipeline(smem_pipe_write_do, self.num_stages_dO)\n\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Epilogue\n            # ///////////////////////////////////////////////////////////////////////////////\n            # If GQA, we scale dK in the postprocessing kernel instead\n            if cutlass.const_expr(self.qhead_per_kvhead == 1):\n                acc_dK.store(acc_dK.load() * softmax_scale)\n            # reuse sK and sV data iterator\n            sdK = cute.make_tensor(sK.iterator, sK_layout)\n            sdV = cute.make_tensor(sV.iterator, sV_layout)\n            self.epilogue(\n                acc_dK, acc_dV, mdK, mdV, sdK, sdV,\n                gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv,\n                tidx, n_block, head_idx, batch_idx, seqlen, d_head, d_head_v\n            )\n\n    @cute.jit\n    def compute_one_m_block(\n        self,\n        m_block: cutlass.Int32,\n        smem_pipe_read_q: cutlass.Int32,\n        smem_pipe_read_do: cutlass.Int32,\n        smem_pipe_write_q: cutlass.Int32,\n        smem_pipe_write_do: cutlass.Int32,\n        mma_params: SimpleNamespace,\n        smem_copy_params: SimpleNamespace,\n        gmem_copy_params: SimpleNamespace,\n        load_Q_LSE: Callable,\n        load_dO_dPsum: Callable,\n        m_block_max: cutlass.Int32,\n        softmax_scale_log2: cutlass.Float32,\n        mask_fn: Optional[Callable] = None,\n    ):\n        def load_Q_next():\n            m_block_next = m_block + (self.num_stages_Q - 1 if cutlass.const_expr(self.num_stages_Q > 1) else 1)\n            if m_block_next < m_block_max:\n                load_Q_LSE(m_block_next, smem_pipe_write_q)\n            cute.arch.cp_async_commit_group()\n\n        def 
load_dO_next():\n            if m_block + self.num_stages_dO < m_block_max:\n                load_dO_dPsum(m_block + self.num_stages_dO, smem_pipe_write_do)\n            cute.arch.cp_async_commit_group()\n\n        # MMA S\n        acc_shape_SdP = mma_params.thr_mma_sdp.partition_shape_C(\n            (self.m_block_size, self.n_block_size) if cutlass.const_expr(not self.SdP_swapAB) else (self.n_block_size, self.m_block_size)\n        )\n        acc_S = cute.make_fragment(acc_shape_SdP, cutlass.Float32)\n        acc_S.fill(0.0)\n        cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_Q > 1) else 0)\n        cute.arch.barrier()\n        sm80_utils.gemm(\n            mma_params.thr_mma_sdp, acc_S, mma_params.tSrQ, mma_params.tSrK,\n            smem_copy_params.tSsQ[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],\n            smem_copy_params.tSsK,\n            smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,\n            swap_AB=self.SdP_swapAB,\n        )\n        tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0])\n        cute.autovec_copy(\n            smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE\n        )\n        if cutlass.const_expr(mask_fn is not None):\n            mask_fn(acc_S, m_block=m_block)\n        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)\n        bidx = 0\n        # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)\n        # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE)\n        assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE)\n        for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True):\n            acc_S_mn[r, None].store(cute.math.exp2(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r], fastmath=True))\n        # if 
cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)\n\n        # MMA dP\n        acc_dP = cute.make_fragment(acc_shape_SdP, cutlass.Float32)\n        acc_dP.fill(0.0)\n        cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_dO > 1) else 0)\n        cute.arch.barrier()\n        sm80_utils.gemm(\n            mma_params.thr_mma_sdp, acc_dP, mma_params.tdPrdO, mma_params.tdPrV,\n            smem_copy_params.tdPsdO[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0],\n            smem_copy_params.tdPsV,\n            smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,\n            hook_fn=load_Q_next if cutlass.const_expr(self.num_stages_Q > 1) else None,\n            swap_AB=self.SdP_swapAB,\n        )\n        tLSErdPsum = cute.make_fragment_like(smem_copy_params.tSsdPsumMma[None, 0])\n        cute.autovec_copy(\n            smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], tLSErdPsum\n        )\n        acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP)\n        # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)\n        assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum)\n        for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True):\n            acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]))\n        # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)\n        rP = cute.make_fragment_like(acc_S, self.dtype)\n        rP.store(acc_S.load().to(self.dtype))\n        if cutlass.const_expr(not self.Mma_dKV_is_RS):\n            tPrP = smem_copy_params.r2s_thr_copy_PdS.retile(rP)  # ((Atom,AtomNum), MMA_N, MMA_N)\n            cute.copy(smem_copy_params.r2s_thr_copy_PdS, tPrP, smem_copy_params.tPsP)\n        rdS = 
cute.make_fragment_like(acc_dP, self.dtype)\n        rdS.store(acc_dP.load().to(self.dtype))\n        if cutlass.const_expr(not self.Mma_dKV_is_RS):\n            cute.arch.barrier()  # Make sure P is written\n        # For hdim 64, it's faster to write to smem_dS first before the dV gemm\n        if cutlass.const_expr(not self.Mma_dKV_is_RS):\n            tdSrdS = smem_copy_params.r2s_thr_copy_PdS.retile(rdS)\n            cute.copy(smem_copy_params.r2s_thr_copy_PdS, tdSrdS, smem_copy_params.tdSsdS)\n        if cutlass.const_expr(self.Mma_dKV_is_RS):\n            tdVrP = layout_utils.reshape_acc_to_frgA(rP)\n        else:\n            tdVrP = mma_params.tdVrP\n\n        # MMA dV\n        sm80_utils.gemm(\n            mma_params.thr_mma_dkv, mma_params.acc_dV, tdVrP, mma_params.tdVrdO,\n            smem_copy_params.tdVsPt,\n            smem_copy_params.tdVsdOt[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0],\n            smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt,\n            A_in_regs=self.Mma_dKV_is_RS,\n            swap_AB=self.dKV_swapAB,\n        )\n        # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(mma_params.acc_dV)\n        cute.arch.barrier()  # Make sure dS is written\n\n        # MMA dQ\n        def dQ_mma(hook_fn):\n            acc_shape_dQ = mma_params.thr_mma_dq.partition_shape_C(\n                (self.m_block_size, self.head_dim_padded) if cutlass.const_expr(not self.dQ_swapAB) else (self.head_dim_padded, self.m_block_size)\n            )\n            acc_dQ = cute.make_fragment(acc_shape_dQ, cutlass.Float32)\n            acc_dQ.fill(0.0)\n            sm80_utils.gemm(\n                mma_params.thr_mma_dq, acc_dQ, mma_params.tdQrdS, mma_params.tdQrK,\n                smem_copy_params.tdQsdS, smem_copy_params.tdQsKt,\n                smem_copy_params.smem_thr_copy_dS, smem_copy_params.smem_thr_copy_Kt,\n                
swap_AB=self.dQ_swapAB,\n                hook_fn=hook_fn\n            )\n            # ((1, 1), num_elements)\n            acc_dQ_atomic = gmem_copy_params.gmem_thr_copy_dQaccum.retile(acc_dQ)\n            tdQgdQaccum_atomic = gmem_copy_params.tdQgdQaccum[None, None, m_block]\n            assert cute.size(acc_dQ_atomic) == cute.size(tdQgdQaccum_atomic)\n            for i in cutlass.range(cute.size(acc_dQ_atomic), unroll_full=True):\n                utils.atomic_add_fp32(acc_dQ_atomic[i], utils.elem_pointer(tdQgdQaccum_atomic, i))\n                # utils.atomic_add_fp32(acc_dQ[i], tdQgdQaccum_atomic.iterator + i * tdQgdQaccum_atomic.stride[1])\n            # if cute.arch.thread_idx()[0] == 64 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dQ)\n\n        # If num_stages_Q == 1, we want to do Mma_dK first so we can start loading Q for the next iteration\n        if cutlass.const_expr(self.num_stages_Q > 1):\n            dQ_mma(load_dO_next)\n\n        # MMA dK\n        if cutlass.const_expr(self.Mma_dKV_is_RS):\n            tdKrdS = layout_utils.reshape_acc_to_frgA(rdS)\n        else:\n            tdKrdS = mma_params.tdKrdS\n        sm80_utils.gemm(\n            mma_params.thr_mma_dkv, mma_params.acc_dK, tdKrdS, mma_params.tdKrQ,\n            smem_copy_params.tdKsdSt,\n            smem_copy_params.tdKsQt[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],\n            smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt,\n            A_in_regs=self.Mma_dKV_is_RS,\n            swap_AB=self.dKV_swapAB,\n            hook_fn=load_dO_next if cutlass.const_expr(self.num_stages_Q == 1) else None,\n        )\n        # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(mma_params.acc_dK)\n        if cutlass.const_expr(self.num_stages_Q == 1):\n            cute.arch.barrier()\n            dQ_mma(load_Q_next)\n\n    @cute.jit\n    def epilogue(\n        self,\n        acc_dK: cute.Tensor,\n        acc_dV: 
cute.Tensor,\n        mdK: cute.Tensor,\n        mdV: cute.Tensor,\n        sdK: cute.Tensor,\n        sdV: cute.Tensor,\n        gmem_tiled_copy_dK: cute.TiledCopy,\n        gmem_tiled_copy_dV: cute.TiledCopy,\n        tiled_mma: cute.TiledMma,\n        tidx: cutlass.Int32,\n        n_block: cutlass.Int32,\n        num_head: cutlass.Int32,\n        batch_size: cutlass.Int32,\n        seqlen: SeqlenInfoQK,\n        d_head: cutlass.Int32,\n        d_head_v: cutlass.Int32\n    ):\n        rdV = cute.make_fragment_like(acc_dV, self.dtype)\n        rdV.store(acc_dV.load().to(self.dtype))\n        rdK = cute.make_fragment_like(acc_dK, self.dtype)\n        rdK.store(acc_dK.load().to(self.dtype))\n        gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx)\n        gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx)\n\n        batch_idx = batch_size\n        head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head\n\n        if cutlass.const_expr(self.qhead_per_kvhead == 1):\n            # Make sure all threads have finished reading K and V, otherwise we get racy dQ\n            # because smem_q could be changed.\n            cute.arch.barrier()\n            # smem copy atom for dKV\n            smem_copy_atom_dKV = cute.make_copy_atom(\n                cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width\n            )\n            smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma).get_slice(tidx)\n            taccdVrdV = smem_thr_copy_dKV.retile(rdV)\n            taccdKrdK = smem_thr_copy_dKV.retile(rdK)\n            taccdVsdV = smem_thr_copy_dKV.partition_D(sdV)\n            taccdKsdK = smem_thr_copy_dKV.partition_D(sdK)\n            # copy acc O from rmem to smem with the smem copy atom\n            cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV)\n            cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK)\n\n\n            if cutlass.const_expr(not 
seqlen.has_cu_seqlens_k):\n                mdK_cur, mdV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mdK, mdV)]\n            else:\n                mdK_cur, mdV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mdK, mdV)]\n\n            blkdK_shape = (self.n_block_size, self.head_dim_padded)\n            blkdV_shape = (self.n_block_size, self.head_dim_v_padded)\n            gdK = cute.local_tile(mdK_cur, blkdK_shape, (n_block, 0))\n            gdV = cute.local_tile(mdV_cur, blkdV_shape, (n_block, 0))\n            tdKsdK = gmem_thr_copy_dK.partition_S(sdK)\n            tdKgdK = gmem_thr_copy_dK.partition_D(gdK)\n            tdVsdV = gmem_thr_copy_dV.partition_S(sdV)\n            tdVgdV = gmem_thr_copy_dV.partition_D(gdV)\n            tdKrdK = cute.make_fragment_like(tdKgdK, self.dtype)\n            tdVrdV = cute.make_fragment_like(tdVgdV, self.dtype)\n            # sync before all smem stores are done.\n            cute.arch.barrier()\n            # load acc dK and dV from smem to rmem for wider vectorization\n            # Need to check OOB when reading from smem if kBlockN isn't evenly tiled\n            # TODO\n            cute.autovec_copy(tdKsdK, tdKrdK)\n            cute.autovec_copy(tdVsdV, tdVrdV)\n\n            cdK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded))\n            tdKcdK = gmem_thr_copy_dK.partition_S(cdK)\n            t0dKcdK = gmem_tiled_copy_dK.get_slice(0).partition_S(cdK)\n            if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded):\n                tdVcdV = tdKcdK\n                t0dVcdV = t0dKcdK\n            else:\n                cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded))\n                tdVcdV = gmem_thr_copy_dV.partition_S(cdV)\n                t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV)\n            tdKpdK = utils.predicate_k(tdKcdK, limit=d_head)\n            if 
cutlass.const_expr(self.same_hdim_kv):\n                tdVpdV = tdKpdK\n            else:\n                tdVpdV = utils.predicate_k(tdVcdV, limit=d_head_v)\n            # copy acc dK and acc_dV from rmem to gmem\n            for rest_m in cutlass.range_constexpr(cute.size(tdKrdK.shape[1])):\n                if t0dKcdK[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdKcdK[0][0]:\n                    cute.copy(\n                        gmem_tiled_copy_dK,\n                        tdKrdK[None, rest_m, None],\n                        tdKgdK[None, rest_m, None],\n                        pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None,\n                    )\n            for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])):\n                if t0dVcdV[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdVcdV[0][0]:\n                    cute.copy(\n                        gmem_tiled_copy_dV,\n                        tdVrdV[None, rest_m, None],\n                        tdVgdV[None, rest_m, None],\n                        pred=tdVpdV[None, rest_m, None] if cutlass.const_expr(self.check_hdim_v_oob) else None,\n                    )\n\n        else:  # qhead_per_kvhead > 1, do atomic add\n            # For Sm90, we need to sync to avoid racy writes to smem_q\n            # For Sm80, we don't need to sync since we're not touching smem\n            head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head\n\n            if cutlass.const_expr(not seqlen.has_cu_seqlens_k):\n                mdK_cur, mdV_cur = [t[batch_idx, head_idx_kv, None] for t in (mdK, mdV)]\n            else:\n                padded_offset_k = seqlen.offset_k + batch_idx * self.n_block_size\n                mdK_cur = cute.domain_offset((padded_offset_k * self.head_dim_padded,), mdK[head_idx_kv, None])\n                mdV_cur = cute.domain_offset((padded_offset_k * 
self.head_dim_v_padded,), mdV[head_idx_kv, None])\n\n            gdV = cute.local_tile(mdV_cur, (self.n_block_size * self.head_dim_v_padded,), (n_block,))\n            gdK = cute.local_tile(mdK_cur, (self.n_block_size * self.head_dim_padded,), (n_block,))\n            tdVgdVaccum = gmem_thr_copy_dV.partition_S(gdV)\n            tdKgdKaccum = gmem_thr_copy_dK.partition_S(gdK)\n            acc_dV_atomic = gmem_thr_copy_dV.retile(acc_dV)\n            acc_dK_atomic = gmem_thr_copy_dK.retile(acc_dK)\n            assert cute.size(acc_dV_atomic) == cute.size(tdVgdVaccum)\n            assert cute.size(acc_dK_atomic) == cute.size(tdKgdKaccum)\n            for i in cutlass.range(cute.size(acc_dV_atomic), unroll_full=True):\n                utils.atomic_add_fp32(acc_dV_atomic[i], utils.elem_pointer(tdVgdVaccum, i))\n            for i in cutlass.range(cute.size(acc_dK_atomic), unroll_full=True):\n                utils.atomic_add_fp32(acc_dK_atomic[i], utils.elem_pointer(tdKgdKaccum, i))\n\n    @cute.jit\n    def advance_pipeline(self, pipeline_index, num_stages: cutlass.Constexpr):\n        return pipeline_index + 1 if pipeline_index < num_stages - 1 else 0\n\n    @cute.jit\n    def load_K(\n        self,\n        gmem_thr_copy: cute.TiledCopy,\n        tKgK: cute.Tensor,\n        tKsK: cute.Tensor,\n        block: cutlass.Int32,\n        seqlen: cutlass.Int32,\n        headdim: cutlass.Int32,\n    ):\n        cK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded))\n        tKcK = gmem_thr_copy.partition_S(cK)\n        t0KcK = gmem_thr_copy.get_slice(0).partition_S(cK)\n        tKpK = utils.predicate_k(tKcK, limit=headdim)\n        for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):\n            # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked\n            if self.is_even_n_smem_k or n < cute.size(tKsK.shape[1]) - 1 or tKcK[0, n, 0][0] < self.n_block_size:\n                # Instead of using tKcK, we use t0KcK
and subtract the offset from the limit\n                # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time.\n                predicate_n = t0KcK[0, n, 0][0] < seqlen - block * self.n_block_size - tKcK[0][0]\n                predicate = cute.make_fragment_like(tKpK[None, 0, None])\n                for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):\n                    for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):\n                        predicate[i, k] = (tKpK[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n\n                cute.copy(\n                    gmem_thr_copy, tKgK[None, n, None], tKsK[None, n, None], pred=predicate,\n                )\n            # We need to clear the sK smem tiles since we'll use sKt for mma_dq\n\n    @cute.jit\n    def load_V(\n        self,\n        gmem_thr_copy: cute.TiledCopy,\n        tVgV: cute.Tensor,\n        tVsV: cute.Tensor,\n        block: cutlass.Int32,\n        seqlen: cutlass.Int32,\n        headdim: cutlass.Int32,\n    ):\n        cV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded))\n        tVcV = gmem_thr_copy.partition_S(cV)\n        t0VcV = gmem_thr_copy.get_slice(0).partition_S(cV)\n        tVpV = utils.predicate_k(tVcV, limit=headdim)\n        for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):\n            # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked\n            if self.is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size:\n                # Instead of using tVcV, we use t0VcV and subtract the offset from the limit\n                # (seqlen - block * kBlockN). 
This is because the entries of t0VcV are known at compile time.\n                predicate_n = t0VcV[0, n, 0][0] < seqlen - block * self.n_block_size - tVcV[0][0]\n                predicate = cute.make_fragment_like(tVpV[None, 0, None])\n                for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):\n                    for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):\n                        predicate[i, k] = (tVpV[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n\n                cute.copy(\n                    gmem_thr_copy, tVgV[None, n, None], tVsV[None, n, None], pred=predicate,\n                )\n\n    @cute.jit\n    def load_Q_LSE(\n        self,\n        gmem_tiled_copy_Q: cute.TiledCopy,\n        gmem_tiled_copy_LSE: cute.TiledCopy,\n        tQgQ: cute.Tensor,\n        tQsQ: cute.Tensor,\n        tQcQ: cute.Tensor,\n        t0QcQ: cute.Tensor,\n        tQpQ: cute.Tensor,\n        tLSEgLSE: cute.Tensor,\n        tLSEsLSE: cute.Tensor,\n        tLSEcLSE: cute.Tensor,\n        block: cutlass.Int32,\n        smem_pipe_write_q: cutlass.Int32,\n        seqlen: cutlass.Int32,\n    ):\n        for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):\n            # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked\n            if self.is_even_m_smem_q or m < cute.size(tQsQ.shape[1]) - 1 or tQcQ[0, m, 0][0] < self.m_block_size:\n                # Instead of using tQcQ, we use t0QcQ and subtract the offset from the limit\n                # (seqlen - block * kBlockM). 
This is because the entries of t0QcQ are known at compile time.\n                predicate_m = t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0]\n                predicate = cute.make_fragment_like(tQpQ[None, 0, None])\n                for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):\n                    for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):\n                        predicate[i, k] = (tQpQ[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m\n                cute.copy(\n                    gmem_tiled_copy_Q,\n                    tQgQ[None, m, None, block],\n                    tQsQ[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],\n                    pred=predicate,\n                )\n            # We need to clear the sQ smem tiles since we'll use sQt for mma_dK\n        # We made sure LSE length is padded so we read `kBlockM` elements so that all\n        # elements in sLSE are filled. 
Without this we might have uninitialized sLSE values.\n        for m in cutlass.range_constexpr(cute.size(tLSEsLSE.shape[1])):\n            if tLSEcLSE[0, m][0] < self.m_block_size:\n                cute.copy(\n                    gmem_tiled_copy_LSE,\n                    tLSEgLSE[None, m, block],\n                    tLSEsLSE[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],\n                )\n\n    @cute.jit\n    def load_dO_dPsum(\n        self,\n        gmem_tiled_copy_dO: cute.TiledCopy,\n        gmem_tiled_copy_dPsum: cute.TiledCopy,\n        tdOgdO: cute.Tensor,\n        tdOsdO: cute.Tensor,\n        tdOcdO: cute.Tensor,\n        t0dOcdO: cute.Tensor,\n        tdOpdO: cute.Tensor,\n        tdPsumgdPsum: cute.Tensor,\n        tdPsumsdPsum: cute.Tensor,\n        tdPsumcdPsum: cute.Tensor,\n        block: cutlass.Int32,\n        smem_pipe_write_q: cutlass.Int32,\n        seqlen: cutlass.Int32,\n    ):\n        for m in cutlass.range_constexpr(cute.size(tdOsdO.shape[1])):\n            # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked\n            if self.is_even_m_smem_do or m < cute.size(tdOsdO.shape[1]) - 1 or tdOcdO[0, m, 0][0] < self.m_block_size:\n                # Instead of using tdOcdO, we use t0dOcdO and subtract the offset from the limit\n                # (seqlen - block * kBlockM). 
This is because the entries of t0dOcdO are known at compile time.\n                predicate_m = t0dOcdO[0, m, 0][0] < seqlen - block * self.m_block_size - tdOcdO[0][0]\n                predicate = cute.make_fragment_like(tdOpdO[None, 0, None])\n                for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):\n                    for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):\n                        predicate[i, k] = (tdOpdO[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m\n                cute.copy(\n                    gmem_tiled_copy_dO,\n                    tdOgdO[None, m, None, block],\n                    tdOsdO[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0],\n                    pred=predicate,\n                )\n            # We need to clear the sdO smem tiles since we'll use sdOt for mma_dV\n        # We made sure dPsum length is padded so we read `kBlockM` elements so that all\n        # elements in sdPsum are filled. Without this we might have uninitialized sdPsum values.\n        for m in cutlass.range_constexpr(cute.size(tdPsumgdPsum.shape[1])):\n            if tdPsumcdPsum[0, m][0] < self.m_block_size:\n                cute.copy(\n                    gmem_tiled_copy_dPsum,\n                    tdPsumgdPsum[None, m, block],\n                    tdPsumsdPsum[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0],\n                )\n"
  },
  {
    "path": "flash_attn/cute/flash_bwd_postprocess.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h\n# from Cutlass C++ to Cute-DSL.\nimport math\nfrom typing import Callable, Optional, Type\n\nimport cuda.bindings.driver as cuda\n\nimport cutlass\nimport cutlass.cute as cute\nimport cutlass.utils.hopper_helpers as sm90_utils_basic\nimport cutlass.utils.blackwell_helpers as sm100_utils_basic\nfrom cutlass.cute.nvgpu import cpasync, warp, warpgroup\nfrom cutlass import Float32, const_expr\nfrom cutlass.utils import LayoutEnum\n\nfrom quack import copy_utils\nfrom quack import layout_utils\nfrom quack import sm90_utils\n\nfrom flash_attn.cute import utils\nfrom flash_attn.cute.cute_dsl_utils import assume_tensor_aligned\nfrom flash_attn.cute import ampere_helpers as sm80_utils\nfrom flash_attn.cute.seqlen_info import SeqlenInfoQK\nimport cutlass.cute.nvgpu.tcgen05 as tcgen05\nfrom quack.cute_dsl_utils import ParamsBase\nfrom flash_attn.cute.tile_scheduler import (\n    SingleTileScheduler,\n    SingleTileVarlenScheduler,\n    TileSchedulerArguments,\n)\n\n\nclass FlashAttentionBackwardPostprocess:\n    def __init__(\n        self,\n        dtype: Type[cutlass.Numeric],\n        head_dim: int,\n        arch: int,\n        tile_m: int = 128,\n        num_threads: int = 256,\n        AtomLayoutMdQ: int = 1,\n        dQ_swapAB: bool = False,\n        use_2cta_instrs: bool = False,\n        cluster_size: int = 1,  # for varlen offsets\n    ):\n        \"\"\"\n        :param head_dim: head dimension\n        :type head_dim: int\n        :param tile_m: m block size\n        :type tile_m: int\n        \"\"\"\n        self.dtype = dtype\n        self.tile_m = tile_m\n        assert arch // 10 in [8, 9, 10, 11, 12], (\n            \"Only Ampere (8.x), Hopper (9.x), and Blackwell (10.x, 11.x, 12.x) are supported\"\n        )\n        self.arch 
= arch\n        # padding head_dim to a multiple of 32 as k_block_size\n        hdim_multiple_of = 32\n        self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)\n        self.check_hdim_oob = head_dim != self.tile_hdim\n        self.num_threads = num_threads\n        self.AtomLayoutMdQ = AtomLayoutMdQ\n        self.dQ_swapAB = dQ_swapAB\n        self.use_2cta_instrs = use_2cta_instrs and arch // 10 == 10 and head_dim != 64\n        self.cluster_size = cluster_size\n\n    @staticmethod\n    def can_implement(dtype, head_dim, tile_m, num_threads) -> bool:\n        \"\"\"Check if the kernel can be implemented with the given parameters.\n\n        :param dtype: data type\n        :type dtype: cutlass.Numeric\n        :param head_dim: head dimension\n        :type head_dim: int\n        :param tile_m: m block size\n        :type tile_m: int\n\n        :return: True if the kernel can be implemented, False otherwise\n        :rtype: bool\n        \"\"\"\n        if dtype not in [cutlass.Float16, cutlass.BFloat16]:\n            return False\n        if head_dim % 8 != 0:\n            return False\n        if num_threads % 32 != 0:\n            return False\n        return True\n\n    def _get_tiled_mma(self):\n        if const_expr(self.arch // 10 in [8, 12]):\n            num_mma_warps = self.num_threads // 32\n            atom_layout_dQ = (\n                (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1)\n                if const_expr(not self.dQ_swapAB)\n                else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1)\n            )\n            tiled_mma = cute.make_tiled_mma(\n                warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)),\n                atom_layout_dQ,\n                permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16),\n            )\n        elif const_expr(self.arch // 10 == 9):\n            num_wg_mma = self.num_threads // 128\n            atom_layout_dQ = 
(self.AtomLayoutMdQ, num_wg_mma // self.AtomLayoutMdQ)\n            tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])\n            tiled_mma = sm90_utils_basic.make_trivial_tiled_mma(\n                self.dtype,\n                self.dtype,\n                warpgroup.OperandMajorMode.K,  # These don't matter, we only care about the accum\n                warpgroup.OperandMajorMode.K,\n                Float32,\n                atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1])\n                + (1,),\n                tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1],\n            )\n        else:\n            cta_group = tcgen05.CtaGroup.ONE\n            tiled_mma = sm100_utils_basic.make_trivial_tiled_mma(\n                self.dtype,\n                tcgen05.OperandMajorMode.MN,  # dS_major_mode\n                tcgen05.OperandMajorMode.MN,  # Kt_major_mode\n                Float32,\n                cta_group,\n                (self.tile_m, self.tile_hdim),\n            )\n        if const_expr(self.arch // 10 in [8, 9, 12]):\n            assert self.num_threads == tiled_mma.size\n        return tiled_mma\n\n    def _setup_attributes(self):\n        # ///////////////////////////////////////////////////////////////////////////////\n        # GMEM Tiled copy:\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Thread layouts for copies\n        universal_copy_bits = 128\n        async_copy_elems_accum = universal_copy_bits // Float32.width\n        atom_async_copy_accum = cute.make_copy_atom(\n            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),\n            Float32,\n            num_bits_per_copy=universal_copy_bits,\n        )\n        # We don't do bound checking for the gmem -> smem load so we just assert here.\n        assert (self.tile_m * self.tile_hdim // async_copy_elems_accum) % self.num_threads == 0\n     
   self.g2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(\n            atom_async_copy_accum,\n            cute.make_layout(self.num_threads),\n            cute.make_layout(async_copy_elems_accum),\n        )\n        num_s2r_copy_elems = 1 if const_expr(self.arch // 10 in [8, 12]) else 4\n        if const_expr(self.arch // 10 in [8, 12]):\n            self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(\n                Float32, self.num_threads, num_s2r_copy_elems\n            )\n            self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim)\n        elif const_expr(self.arch // 10 == 9):\n            num_threads_per_warp_group = 128\n            num_wg_mma = self.num_threads // 128\n            self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv(\n                cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),\n                cute.make_layout((num_threads_per_warp_group, num_wg_mma)),  # thr_layout\n                cute.make_layout(128 // Float32.width),  # val_layout\n            )\n            self.sdQaccum_layout = cute.make_layout(\n                (self.tile_m * self.tile_hdim // num_wg_mma, num_wg_mma)\n            )\n        else:\n            self.dQ_reduce_ncol = 32\n            dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol\n            assert self.num_threads == 128  # TODO: currently hard-coded\n            self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(\n                Float32, self.num_threads, num_s2r_copy_elems\n            )\n            self.sdQaccum_layout = cute.make_layout(\n                (self.tile_m * self.tile_hdim // dQaccum_reduce_stage, dQaccum_reduce_stage)\n            )\n\n        num_copy_elems = 128 // self.dtype.width\n        threads_per_row = math.gcd(128, self.tile_hdim) // num_copy_elems\n        self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d(\n            self.dtype, threads_per_row, self.num_threads, num_copy_elems\n        )\n        # 
///////////////////////////////////////////////////////////////////////////////\n        # Shared memory layout: dQ\n        # ///////////////////////////////////////////////////////////////////////////////\n        # We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,\n        # then setting kBlockKSmem to 32 will cause \"Static shape_div failure\".\n        # We want to treat it as 64 x 48, so kBlockKSmem should be 16.\n        mma_shape_n = self.tiled_mma.get_tile_size(1)\n        if const_expr(self.arch // 10 in [8, 12]):\n            sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n)\n            self.sdQ_layout = cute.tile_to_shape(\n                sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1)\n            )\n        elif const_expr(self.arch // 10 == 9):\n            wg_d_dQ = num_wg_mma // self.AtomLayoutMdQ\n            self.sdQ_layout = sm90_utils.make_smem_layout(\n                self.dtype,\n                LayoutEnum.ROW_MAJOR,\n                (self.tile_m, self.tile_hdim),\n                major_mode_size=self.tile_hdim // wg_d_dQ,\n            )\n        else:\n            # TODO: this is hard-coded for hdim 128\n            self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi(\n                self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim), 1\n            )\n\n    @cute.jit\n    def __call__(\n        self,\n        mdQaccum: cute.Tensor,\n        mdQ: cute.Tensor,\n        scale: cutlass.Float32,\n        mCuSeqlensQ: Optional[cute.Tensor],\n        mSeqUsedQ: Optional[cute.Tensor],\n        # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).\n        stream: cuda.CUstream = None,\n    ):\n        # Get the data type and check if it is fp16 or bf16\n        if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]):\n            raise TypeError(\"Only Float16 or BFloat16 is supported\")\n        if 
const_expr(mdQaccum is not None):\n            if const_expr(mdQaccum.element_type not in [cutlass.Float32]):\n                raise TypeError(\"dQaccum tensor must be Float32\")\n\n        mdQaccum, mdQ = [assume_tensor_aligned(t) for t in (mdQaccum, mdQ)]\n\n        self.tiled_mma = self._get_tiled_mma()\n        self._setup_attributes()\n\n        smem_size = max(\n            cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout),\n            cute.size_in_bytes(self.dtype, self.sdQ_layout),\n        )\n\n        if const_expr(mCuSeqlensQ is not None):\n            TileScheduler = SingleTileVarlenScheduler\n            num_head = mdQ.shape[1]\n            num_batch = mCuSeqlensQ.shape[0] - 1\n            num_block = cute.ceil_div(mdQ.shape[0], self.tile_m)\n        else:\n            TileScheduler = SingleTileScheduler\n            num_head = mdQ.shape[2]\n            num_batch = mdQ.shape[0]\n            num_block = cute.ceil_div(mdQ.shape[1], self.tile_m)\n\n        tile_sched_args = TileSchedulerArguments(\n            num_block=num_block,\n            num_head=num_head,\n            num_batch=num_batch,\n            num_splits=1,\n            seqlen_k=0,\n            headdim=mdQ.shape[2],\n            headdim_v=0,\n            total_q=mdQ.shape[0],\n            tile_shape_mn=(self.tile_m, 1),\n            mCuSeqlensQ=mCuSeqlensQ,\n            mSeqUsedQ=mSeqUsedQ,\n        )\n\n        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)\n        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)\n\n        # grid_dim: (m_block, num_head, batch_size)\n        self.kernel(\n            mdQaccum,\n            mdQ,\n            mCuSeqlensQ,\n            mSeqUsedQ,\n            scale,\n            self.tiled_mma,\n            self.dQ_swapAB,\n            self.sdQaccum_layout,\n            self.sdQ_layout,\n            self.g2s_tiled_copy_dQaccum,\n            self.s2r_tiled_copy_dQaccum,\n            self.gmem_tiled_copy_dQ,\n    
        tile_sched_params,\n            TileScheduler,\n        ).launch(\n            grid=grid_dim,\n            block=[self.num_threads, 1, 1],\n            smem=smem_size,\n            stream=stream,\n        )\n\n    @cute.kernel\n    def kernel(\n        self,\n        mdQaccum: cute.Tensor,\n        mdQ: cute.Tensor,\n        mCuSeqlensQ: Optional[cute.Tensor],\n        mSeqUsedQ: Optional[cute.Tensor],\n        scale: cutlass.Float32,\n        tiled_mma: cute.TiledMma,\n        dQ_swapAB: cutlass.Constexpr,\n        sdQaccum_layout: cute.Layout,\n        sdQ_layout: cute.ComposedLayout,\n        g2s_tiled_copy_dQaccum: cute.TiledCopy,\n        s2r_tiled_copy_dQaccum: cute.TiledCopy,\n        gmem_tiled_copy_dQ: cute.TiledCopy,\n        tile_sched_params: ParamsBase,\n        TileScheduler: cutlass.Constexpr[Callable],\n    ):\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Get shared memory buffer\n        # ///////////////////////////////////////////////////////////////////////////////\n        smem = cutlass.utils.SmemAllocator()\n        sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024)\n        sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum)))\n        if const_expr(self.arch // 10 in [8, 9, 12]):\n            sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout)\n        else:\n            # extra stage dimension\n            sdQ = cute.make_tensor(\n                cute.recast_ptr(sdQaccum.iterator, sdQ_layout.inner, dtype=self.dtype),\n                sdQ_layout.outer,\n            )[None, None, 0]\n        sdQt = layout_utils.transpose_view(sdQ)\n\n        # Thread index, block index\n        tidx, _, _ = cute.arch.thread_idx()\n\n        tile_scheduler = TileScheduler.create(tile_sched_params)\n        work_tile = tile_scheduler.initial_work_tile_info()\n\n        m_block, head_idx, 
batch_idx, _ = work_tile.tile_idx\n\n        if work_tile.is_valid_tile:\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Get the appropriate tiles for this thread block.\n            # ///////////////////////////////////////////////////////////////////////////////\n\n            seqlen = SeqlenInfoQK.create(\n                batch_idx,\n                mdQ.shape[1],\n                0,\n                mCuSeqlensQ=mCuSeqlensQ,\n                mCuSeqlensK=None,\n                mSeqUsedQ=mSeqUsedQ,\n                mSeqUsedK=None,\n                tile_m=self.tile_m * self.cluster_size,\n            )\n            if const_expr(not seqlen.has_cu_seqlens_q):\n                mdQ_cur = mdQ[batch_idx, None, head_idx, None]\n                mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]\n                head_dim = mdQ.shape[3]\n            else:\n                padded_offset_q = seqlen.padded_offset_q\n                mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None])\n                mdQaccum_cur = cute.domain_offset(\n                    (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None]\n                )\n                head_dim = mdQ.shape[2]\n\n                # HACK: Compiler doesn't seem to recognize that padding\n                # by padded_offset_q * self.tile_hdim keeps alignment\n                # since statically divisible by 4\n\n                mdQaccum_cur_ptr = cute.make_ptr(\n                    dtype=mdQaccum_cur.element_type,\n                    value=mdQaccum_cur.iterator.toint(),\n                    mem_space=mdQaccum_cur.iterator.memspace,\n                    assumed_align=mdQaccum.iterator.alignment,\n                )\n                mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout)\n\n            gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,))\n            gdQ = 
cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))\n\n            seqlen_q = seqlen.seqlen_q\n            seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m)\n\n            if const_expr(self.arch // 10 == 10 and self.use_2cta_instrs):\n                # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ\n                num_reduce_threads = self.num_threads\n                thr_mma_dsk = tiled_mma.get_slice(tidx)\n                dQacc_shape = thr_mma_dsk.partition_shape_C((self.tile_m, self.tile_hdim))\n                tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape)\n                tdQtdQ = cute.make_tensor(tdQtdQ.iterator, tdQtdQ.layout)\n\n                tmem_load_atom = cute.make_copy_atom(\n                    tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32\n                )\n                tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ)\n                thr_tmem_ld = tiled_tmem_ld.get_slice(tidx)\n\n                cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))\n                tdQcdQ = thr_mma_dsk.partition_C(cdQ)\n                tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout)\n                tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor)\n\n                tiled_copy_accum = s2r_tiled_copy_dQaccum\n                g2s_thr_copy = tiled_copy_accum.get_slice(tidx)\n\n                # S -> R\n                tdQrdQ_fp32 = cute.make_fragment(tdQrdQ.shape, cutlass.Float32)\n                tdQrdQ_s2r = cute.make_tensor(tdQrdQ_fp32.iterator, tdQrdQ_fp32.shape)\n\n                smem_copy_atom = sm100_utils_basic.get_smem_store_op(\n                    LayoutEnum.ROW_MAJOR, self.dtype, cutlass.Float32, tiled_tmem_ld\n                )\n                r2s_tiled_copy = cute.make_tiled_copy(\n                    smem_copy_atom,\n                    layout_tv=tiled_tmem_ld.layout_dst_tv_tiled,\n                    
tiler_mn=tiled_tmem_ld.tiler_mn,\n                )\n                tdQsdQ_r2s = thr_tmem_ld.partition_D(thr_mma_dsk.partition_C(sdQ))\n                tdQrdQ_r2s = cute.make_fragment(tdQsdQ_r2s.shape, self.dtype)\n\n                num_stages = cute.size(tdQrdQ_fp32, mode=[1])\n                stage_stride = self.dQ_reduce_ncol\n                row_groups = 2\n                assert num_stages % row_groups == 0\n                assert num_reduce_threads % row_groups == 0\n                stage_groups = num_stages // row_groups\n                threads_per_row_group = num_reduce_threads // row_groups\n                stage_loads = tuple((row_group, row_group) for row_group in range(row_groups))\n                stage_iters = tuple(\n                    (row_group, row_group * threads_per_row_group)\n                    for row_group in range(row_groups)\n                )\n                s2r_lane = tidx % threads_per_row_group\n                s2r_buf = tidx // threads_per_row_group\n\n                gdQaccum_layout_g2s = cute.make_layout(\n                    shape=(self.tile_m * self.dQ_reduce_ncol, 1), stride=(1, 0)\n                )\n                sdQaccum_g2s = g2s_thr_copy.partition_D(sdQaccum)\n\n                # G -> S\n                for stage_group in cutlass.range_constexpr(stage_groups):\n                    for stage_offset, smem_buf in stage_loads:\n                        stage_idx = stage_group + stage_offset * stage_groups\n                        gdQaccum_stage = cute.local_tile(\n                            gdQaccum,\n                            (self.tile_m * self.dQ_reduce_ncol,),\n                            (stage_idx,),\n                        )\n                        gdQaccum_stage_g2s = cute.make_tensor(\n                            gdQaccum_stage.iterator,\n                            gdQaccum_layout_g2s,\n                        )\n                        tdQgdQ = g2s_thr_copy.partition_S(gdQaccum_stage_g2s)\n                  
      cute.copy(\n                            g2s_thr_copy,\n                            tdQgdQ[None, None, 0],\n                            sdQaccum_g2s[None, None, smem_buf],\n                        )\n\n                    cute.arch.fence_view_async_shared()\n                    cute.arch.barrier(barrier_id=6, number_of_threads=num_reduce_threads)\n\n                    # S -> R\n                    for stage_offset, lane_offset in stage_iters:\n                        stage_idx = stage_group + stage_offset * stage_groups\n                        s2r_src_tidx = s2r_lane + lane_offset\n                        s2r_thr_copy = tiled_copy_accum.get_slice(s2r_src_tidx)\n                        sdQaccum_src = s2r_thr_copy.partition_S(sdQaccum)[None, None, s2r_buf]\n\n                        tdQrdQ_s2r_cpy = tdQrdQ_s2r[None, stage_idx, None, None]\n                        tdQrdQ_r2s_cpy = cute.make_tensor(\n                            tdQrdQ_s2r_cpy.iterator, cute.make_layout(sdQaccum_src.shape)\n                        )\n                        cute.copy(s2r_thr_copy, sdQaccum_src, tdQrdQ_r2s_cpy)\n                        cute.arch.fence_view_async_shared()\n                        cute.arch.barrier(barrier_id=7, number_of_threads=num_reduce_threads)\n\n                        # R -> S\n                        stage_lo = stage_idx % stage_stride\n                        stage_hi = stage_idx // stage_stride\n                        tdQrdQ_r2s_cpy = cute.make_tensor(\n                            cute.recast_ptr(tdQrdQ_r2s_cpy.iterator),\n                            tdQrdQ_r2s[((None, 0), (stage_lo, stage_hi), 0, 0)].shape,\n                        )\n                        dQ_vec = tdQrdQ_r2s_cpy.load() * scale\n                        tdQrdQ_r2s[((None, 0), (stage_lo, stage_hi), 0, 0)].store(\n                            dQ_vec.to(self.dtype)\n                        )\n\n                # R -> S\n                cute.copy(\n                    r2s_tiled_copy,\n      
              tdQrdQ_r2s[None, None, None, 0],\n                    tdQsdQ_r2s[None, None, None, 0],\n                )\n                cute.arch.fence_view_async_shared()\n                cute.arch.barrier(barrier_id=8, number_of_threads=num_reduce_threads)\n            else:\n                # Step 1: load dQaccum from gmem to smem\n                g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx)\n                tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum)\n                tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum_flat)\n                cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s)\n                cute.arch.cp_async_commit_group()\n                cute.arch.cp_async_wait_group(0)\n                cute.arch.barrier()\n\n                # Step 2: load dQ from smem to rmem\n                s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx)\n                tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum)\n                tile_shape = (self.tile_m, self.tile_hdim)\n                acc = None\n                tiled_copy_t2r = None\n                if const_expr(self.arch // 10 in [8, 9, 12]):\n                    acc_shape = tiled_mma.partition_shape_C(\n                        tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1]\n                    )\n                    acc = cute.make_fragment(acc_shape, cutlass.Float32)\n                    assert cute.size(acc) == cute.size(tdQsdQaccum)\n                else:\n                    thr_mma = tiled_mma.get_slice(0)  # 1-CTA\n                    dQacc_shape = tiled_mma.partition_shape_C((self.tile_m, self.tile_hdim))\n                    tdQtdQ = tiled_mma.make_fragment_C(dQacc_shape)\n                    tdQcdQ = thr_mma.partition_C(\n                        cute.make_identity_tensor((self.tile_m, self.tile_hdim))\n                    )\n                    tmem_load_atom = cute.make_copy_atom(\n                        
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)),\n                        Float32,\n                    )\n                    tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ)\n                    thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)\n                    tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape\n                    acc = cute.make_fragment(tdQrdQ_t2r_shape, Float32)\n                tdQrdQaccum = cute.make_tensor(acc.iterator, cute.make_layout(tdQsdQaccum.shape))\n                cute.autovec_copy(tdQsdQaccum, tdQrdQaccum)\n                # Convert tdQrdQaccum from fp32 to fp16/bf16\n                rdQ = cute.make_fragment_like(acc, self.dtype)\n                rdQ.store((acc.load() * scale).to(self.dtype))\n\n                # Step 3: Copy dQ from register to smem\n                cute.arch.barrier()  # make sure all threads have finished loading dQaccum\n                if const_expr(self.arch // 10 in [8, 9, 12]):\n                    copy_atom_r2s_dQ = utils.get_smem_store_atom(\n                        self.arch, self.dtype, transpose=self.dQ_swapAB\n                    )\n                    tiled_copy_r2s_dQ = cute.make_tiled_copy_C(copy_atom_r2s_dQ, tiled_mma)\n                else:\n                    # copy_atom_r2s_dQ = sm100_utils_basic.get_smem_store_op(\n                    #     LayoutEnum.ROW_MAJOR, self.dtype, Float32, tiled_copy_t2r,\n                    # )\n                    # tiled_copy_r2s_dQ = cute.make_tiled_copy_D(copy_atom_r2s_dQ, tiled_copy_t2r)\n                    thr_layout_r2s_dQ = cute.make_layout((self.num_threads, 1))  # 128 threads\n                    val_layout_r2s_dQ = cute.make_layout((1, 128 // self.dtype.width))\n                    copy_atom_r2s_dQ = cute.make_copy_atom(\n                        cute.nvgpu.CopyUniversalOp(),\n                        self.dtype,\n                        num_bits_per_copy=128,\n                    )\n                    
tiled_copy_r2s_dQ = cute.make_tiled_copy_tv(\n                        copy_atom_r2s_dQ, thr_layout_r2s_dQ, val_layout_r2s_dQ\n                    )\n                thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx)\n                cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))\n                if const_expr(self.arch // 10 in [8, 9, 12]):\n                    taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ)\n                else:\n                    taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape\n                    taccdQrdQ = cute.make_tensor(rdQ.iterator, taccdQcdQ_shape)\n                taccdQsdQ = thr_copy_r2s_dQ.partition_D(\n                    sdQ if const_expr(not self.dQ_swapAB) else sdQt\n                )\n                cute.copy(thr_copy_r2s_dQ, taccdQrdQ, taccdQsdQ)\n\n            # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem\n            cute.arch.barrier()  # make sure all smem stores are done\n            gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx)\n            tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ)\n            tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ)\n            tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype)\n            # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled\n            cute.autovec_copy(tdQsdQ, tdQrdQ)\n\n            # Step 5: Copy dQ from register to gmem\n            tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ)\n            tdQpdQ = utils.predicate_k(tdQcdQ, limit=head_dim)\n            for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True):\n                if tdQcdQ[0, rest_m, 0][0] < seqlen_q - m_block * self.tile_m:\n                    cute.copy(\n                        gmem_tiled_copy_dQ,\n                        tdQrdQ[None, rest_m, None],\n                        tdQgdQ[None, rest_m, None],\n                        pred=tdQpdQ[None, rest_m, None],\n                    )\n"
  },
  {
    "path": "flash_attn/cute/flash_bwd_preprocess.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h\n# from Cutlass C++ to Cute-DSL.\n#\n# Computes D_i = (dO_i * O_i).sum(dim=-1), optionally adjusted for LSE gradient:\n#   D'_i = D_i - dLSE_i\n# This works because in the backward pass:\n#   dS_ij = P_ij * (dP_ij - D_i)                     [standard]\n# When LSE is differentiable, d(loss)/d(S_ij) gets an extra term dLSE_i * P_ij\n# (since d(LSE_i)/d(S_ij) = P_ij), giving:\n#   dS_ij = P_ij * (dP_ij - D_i) + dLSE_i * P_ij\n#         = P_ij * (dP_ij - (D_i - dLSE_i))\n# So the main backward kernel is unchanged; we just replace D with D' = D - dLSE here.\nimport math\nimport operator\nfrom functools import partial\nfrom typing import Callable, Type, Optional\n\nimport cuda.bindings.driver as cuda\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Float32, const_expr\nfrom cutlass.cutlass_dsl import Arch, BaseDSL\n\nfrom quack import copy_utils, layout_utils\n\nfrom flash_attn.cute import utils\nfrom flash_attn.cute.seqlen_info import SeqlenInfo\nfrom quack.cute_dsl_utils import ParamsBase\nfrom flash_attn.cute.tile_scheduler import (\n    SingleTileScheduler,\n    SingleTileVarlenScheduler,\n    TileSchedulerArguments,\n)\n\n\nclass FlashAttentionBackwardPreprocess:\n    def __init__(\n        self,\n        dtype: Type[cutlass.Numeric],\n        head_dim: int,\n        head_dim_v: int,\n        tile_m: int = 128,\n        num_threads: int = 256,\n    ):\n        \"\"\"\n        All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension\n        should be a multiple of 8.\n\n        :param head_dim: head dimension\n        :type head_dim: int\n        :param tile_m: m block size\n        :type tile_m: int\n        :param num_threads: number of threads\n        :type num_threads: int\n 
       \"\"\"\n        self.use_pdl = BaseDSL._get_dsl().get_arch_enum() >= Arch.sm_90a\n        self.dtype = dtype\n        self.tile_m = tile_m\n        # padding head_dim to a multiple of 32 as k_block_size\n        hdim_multiple_of = 32\n        self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)\n        self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)\n        self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded\n        self.num_threads = num_threads\n\n    @staticmethod\n    def can_implement(dtype, head_dim, tile_m, num_threads) -> bool:\n        \"\"\"Check if the kernel can be implemented with the given parameters.\n\n        :param dtype: data type\n        :type dtype: cutlass.Numeric\n        :param head_dim: head dimension\n        :type head_dim: int\n        :param tile_m: m block size\n        :type tile_m: int\n        :param num_threads: number of threads\n        :type num_threads: int\n\n        :return: True if the kernel can be implemented, False otherwise\n        :rtype: bool\n        \"\"\"\n        if dtype not in [cutlass.Float16, cutlass.BFloat16]:\n            return False\n        if head_dim % 8 != 0:\n            return False\n        if num_threads % 32 != 0:\n            return False\n        if num_threads < tile_m:  # For multiplying lse with log2\n            return False\n        return True\n\n    def _setup_attributes(self):\n        # ///////////////////////////////////////////////////////////////////////////////\n        # GMEM Tiled copy:\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Thread layouts for copies\n        # We want kBlockKGmem to be a power of 2 so that when we do the summing,\n        # it's just between threads in the same warp\n        gmem_k_block_size = (\n            128\n            if self.head_dim_v_padded % 128 == 0\n            else (\n                64\n      
          if self.head_dim_v_padded % 64 == 0\n                else (32 if self.head_dim_v_padded % 32 == 0 else 16)\n            )\n        )\n        num_copy_elems = 128 // self.dtype.width\n        threads_per_row = gmem_k_block_size // num_copy_elems\n        self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d(\n            self.dtype, threads_per_row, self.num_threads, num_copy_elems\n        )\n        universal_copy_bits = 128\n        num_copy_elems_dQaccum = universal_copy_bits // Float32.width\n        assert (\n            self.tile_m * self.head_dim_padded // num_copy_elems_dQaccum\n        ) % self.num_threads == 0\n        self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d(\n            Float32, self.num_threads, num_copy_elems_dQaccum\n        )\n\n    @cute.jit\n    def __call__(\n        self,\n        mO: cute.Tensor,  # (batch, seqlen, nheads, head_dim_v) or (total_q, nheads, head_dim_v)\n        mdO: cute.Tensor,  # same shape as mO\n        mPdPsum: cute.Tensor,  # (batch, nheads, seqlen_padded) or (nheads, total_q_padded)\n        mLSE: Optional[cute.Tensor],  # (batch, nheads, seqlen) or (nheads, total_q)\n        mLSElog2: Optional[cute.Tensor],  # same shape as mPdPsum\n        # (batch, nheads, seqlen_padded * head_dim_v) or (nheads, total_q_padded * head_dim_v)\n        mdQaccum: Optional[cute.Tensor],\n        mCuSeqlensQ: Optional[cute.Tensor],  # (batch + 1,)\n        mSeqUsedQ: Optional[cute.Tensor],  # (batch,)\n        mdLSE: Optional[cute.Tensor],  # (batch, nheads, seqlen) or (nheads, total_q)\n        # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).\n        stream: cuda.CUstream = None,\n    ):\n        # Get the data type and check if it is fp16 or bf16\n        if const_expr(not (mO.element_type == mdO.element_type)):\n            raise TypeError(\"All tensors must have the same data type\")\n        if const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]):\n         
   raise TypeError(\"Only Float16 or BFloat16 is supported\")\n        if const_expr(mPdPsum.element_type not in [Float32]):\n            raise TypeError(\"PdPsum tensor must be Float32\")\n        if const_expr(mdQaccum is not None):\n            if const_expr(mdQaccum.element_type not in [Float32]):\n                raise TypeError(\"dQaccum tensor must be Float32\")\n        if const_expr(mLSE is not None):\n            assert mLSElog2 is not None, \"If mLSE is provided, mLSElog2 must also be provided\"\n            if const_expr(mLSE.element_type not in [Float32]):\n                raise TypeError(\"LSE tensor must be Float32\")\n            if const_expr(mLSElog2.element_type not in [Float32]):\n                raise TypeError(\"LSElog2 tensor must be Float32\")\n        if const_expr(mdLSE is not None):\n            if const_expr(mdLSE.element_type not in [Float32]):\n                raise TypeError(\"dLSE tensor must be Float32\")\n\n        self._setup_attributes()\n\n        # (batch, nheads, seqlen) -> (seqlen, nheads, batch) or (nheads, total_q) -> (total_q, nheads)\n        transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]\n        mPdPsum = layout_utils.select(mPdPsum, transpose)\n        if const_expr(mLSE is not None):\n            mLSE = layout_utils.select(mLSE, transpose)\n            mLSElog2 = layout_utils.select(mLSElog2, transpose)\n        if const_expr(mdLSE is not None):\n            mdLSE = layout_utils.select(mdLSE, transpose)\n        if const_expr(mdQaccum is not None):\n            mdQaccum = layout_utils.select(mdQaccum, transpose)\n\n        if const_expr(mCuSeqlensQ is not None):\n            TileScheduler = SingleTileVarlenScheduler\n            num_head = mO.shape[1]\n            num_batch = mCuSeqlensQ.shape[0] - 1\n        else:\n            TileScheduler = SingleTileScheduler\n            num_head = mO.shape[2]\n            num_batch = mO.shape[0]\n\n        tile_sched_args = TileSchedulerArguments(\n         
   num_block=cute.ceil_div(mO.shape[1], self.tile_m),\n            num_head=num_head,\n            num_batch=num_batch,\n            num_splits=1,\n            seqlen_k=0,\n            headdim=0,\n            headdim_v=mO.shape[2],\n            total_q=mO.shape[0],\n            tile_shape_mn=(self.tile_m, 1),\n            mCuSeqlensQ=mCuSeqlensQ,\n            mSeqUsedQ=mSeqUsedQ,\n        )\n\n        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)\n        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)\n\n        self.kernel(\n            mO,\n            mdO,\n            mPdPsum,\n            mLSE,\n            mLSElog2,\n            mdQaccum,\n            mCuSeqlensQ,\n            mSeqUsedQ,\n            mdLSE,\n            self.gmem_tiled_copy_O,\n            self.gmem_tiled_copy_dQaccum,\n            tile_sched_params,\n            TileScheduler,\n        ).launch(\n            grid=grid_dim,\n            block=[self.num_threads, 1, 1],\n            stream=stream,\n            use_pdl=self.use_pdl,\n        )\n\n    @cute.kernel\n    def kernel(\n        self,\n        mO: cute.Tensor,\n        mdO: cute.Tensor,\n        mPdPsum: cute.Tensor,\n        mLSE: Optional[cute.Tensor],\n        mLSElog2: Optional[cute.Tensor],\n        mdQaccum: Optional[cute.Tensor],\n        mCuSeqlensQ: Optional[cute.Tensor],\n        mSeqUsedQ: Optional[cute.Tensor],\n        mdLSE: Optional[cute.Tensor],\n        gmem_tiled_copy_O: cute.TiledCopy,\n        gmem_tiled_copy_dQaccum: cute.TiledCopy,\n        tile_sched_params: ParamsBase,\n        TileScheduler: cutlass.Constexpr[Callable],\n    ):\n        # Thread index, block index\n        tidx, _, _ = cute.arch.thread_idx()\n\n        tile_scheduler = TileScheduler.create(tile_sched_params)\n        work_tile = tile_scheduler.initial_work_tile_info()\n        m_block, head_idx, batch_idx, _ = work_tile.tile_idx\n\n        if work_tile.is_valid_tile:\n            # 
///////////////////////////////////////////////////////////////////////////////\n            # Get the appropriate tiles for this thread block.\n            # ///////////////////////////////////////////////////////////////////////////////\n            seqlen = SeqlenInfo.create(\n                batch_idx, mO.shape[1], mCuSeqlensQ, mSeqUsedQ, tile=self.tile_m\n            )\n            mO_cur = seqlen.offset_batch(mO, batch_idx, dim=0)[None, head_idx, None]\n            mdO_cur = seqlen.offset_batch(mdO, batch_idx, dim=0)[None, head_idx, None]\n            mPdPsum_cur = seqlen.offset_batch(mPdPsum, batch_idx, dim=2, padded=True)[\n                None, head_idx\n            ]\n            headdim_v = mO_cur.shape[cute.rank(mO_cur) - 1]\n            seqlen_q = seqlen.seqlen\n            seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m)\n            seqlen_limit = seqlen_q - m_block * self.tile_m\n\n            lse = None\n            if const_expr(mLSE is not None):\n                mLSE_cur = seqlen.offset_batch(mLSE, batch_idx, dim=2)[None, head_idx]\n                gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))\n                lse = Float32.inf\n                if tidx < seqlen_limit:\n                    lse = gLSE[tidx]\n\n            blk_shape = (self.tile_m, self.head_dim_v_padded)\n            gO = cute.local_tile(mO_cur, blk_shape, (m_block, 0))\n            gdO = cute.local_tile(mdO_cur, blk_shape, (m_block, 0))\n            gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)\n            # (CPY_Atom, CPY_M, CPY_K)\n            tOgO = gmem_thr_copy_O.partition_S(gO)\n            tOgdO = gmem_thr_copy_O.partition_S(gdO)\n            cO = cute.make_identity_tensor(blk_shape)\n            tOcO = gmem_thr_copy_O.partition_S(cO)\n            t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO)\n            tOpO = None\n            if const_expr(self.check_hdim_v_oob):\n                tOpO = copy_utils.predicate_k(tOcO, limit=headdim_v)\n 
           # Each copy will use the same predicate\n            copy = partial(copy_utils.copy, pred=tOpO)\n\n            tOrO = cute.make_rmem_tensor_like(tOgO)\n            tOrdO = cute.make_rmem_tensor_like(tOgdO)\n            if const_expr(self.check_hdim_v_oob):\n                tOrO.fill(0.0)\n                tOrdO.fill(0.0)\n            assert tOgO.shape == tOgdO.shape\n            for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True):\n                # Instead of using tOcO, we use t0OcO and subtract the offset from the limit.\n                # This is because the entries of t0OcO are known at compile time.\n                if t0OcO[0, m, 0][0] < seqlen_limit - tOcO[0][0]:\n                    copy(tOgO[None, m, None], tOrO[None, m, None])\n                    copy(tOgdO[None, m, None], tOrdO[None, m, None])\n            # O and dO loads are done; signal that the next kernel can start.\n            # Correctness is ensured by griddepcontrol_wait() in bwd_sm90 before it reads our outputs.\n            if const_expr(self.use_pdl):\n                cute.arch.griddepcontrol_launch_dependents()\n            # Sum across the \"k\" dimension\n            pdpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce(\n                cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1)\n            )\n            threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0]\n            assert cute.arch.WARP_SIZE % threads_per_row == 0\n            pdpsum = utils.warp_reduce(pdpsum, operator.add, width=threads_per_row)\n            PdP_sum = cute.make_rmem_tensor(cute.size(tOrO, mode=[1]), Float32)\n            PdP_sum.store(pdpsum)\n\n            # If dLSE is provided, compute D' = D - dLSE (see the module header comment for the derivation).\n            gdLSE = None\n            if const_expr(mdLSE is not None):\n                mdLSE_cur = seqlen.offset_batch(mdLSE, batch_idx, dim=2)[None, head_idx]\n                gdLSE = 
cute.local_tile(mdLSE_cur, (self.tile_m,), (m_block,))\n\n            # Write PdPsum from rmem -> gmem\n            gPdPsum = cute.local_tile(mPdPsum_cur, (self.tile_m,), (m_block,))\n            # Only the thread corresponding to column 0 writes out the PdPsum to gmem\n            if tOcO[0, 0, 0][1] == 0:\n                for m in cutlass.range(cute.size(PdP_sum), unroll_full=True):\n                    row = tOcO[0, m, 0][0]\n                    PdPsum_val = 0.0\n                    if row < seqlen_limit:\n                        PdPsum_val = PdP_sum[m]\n                        if const_expr(mdLSE is not None):\n                            PdPsum_val -= gdLSE[row]\n                    gPdPsum[row] = PdPsum_val\n\n            # Clear dQaccum\n            if const_expr(mdQaccum is not None):\n                mdQaccum_cur = seqlen.offset_batch(\n                    mdQaccum, batch_idx, dim=2, padded=True, multiple=self.head_dim_padded\n                )[None, head_idx]\n                blkdQaccum_shape = (self.tile_m * self.head_dim_padded,)\n                gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,))\n                gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)\n                tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)\n                zero = cute.make_rmem_tensor_like(tdQgdQaccum)\n                zero.fill(0.0)\n                cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum)\n\n            if const_expr(mLSE is not None):\n                mLSElog2_cur = seqlen.offset_batch(mLSElog2, batch_idx, dim=2, padded=True)[\n                    None, head_idx\n                ]\n                gLSElog2 = cute.local_tile(mLSElog2_cur, (self.tile_m,), (m_block,))\n                LOG2_E = math.log2(math.e)\n                if tidx < seqlen_q_rounded - m_block * self.tile_m:\n                    gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0\n"
  },
  {
    "path": "flash_attn/cute/flash_bwd_sm100.py",
    "content": "# Copyright (c) 2025, Ted Zadouri, Markus Hoehnerbach, Jay Shah, Tri Dao.\nimport math\nfrom typing import Callable, Optional\nfrom functools import partial\n\nimport cuda.bindings.driver as cuda\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass.cute import FastDivmodDivisor\nfrom cutlass import Float32, Int32, Int64, const_expr\nfrom cutlass.utils import LayoutEnum\nfrom cutlass.cute.nvgpu import cpasync, tcgen05\nimport cutlass.utils.blackwell_helpers as sm100_utils_basic\nfrom cutlass.pipeline import PipelineAsync\n\nimport quack.activation\nfrom quack import layout_utils\nfrom flash_attn.cute import utils\nfrom flash_attn.cute.cute_dsl_utils import assume_tensor_aligned\nfrom flash_attn.cute import copy_utils\nfrom flash_attn.cute import pipeline\nfrom flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx  # noqa\nfrom flash_attn.cute.mask import AttentionMask\nfrom flash_attn.cute.seqlen_info import SeqlenInfoQK\nfrom flash_attn.cute.block_info import BlockInfo\nfrom quack.cute_dsl_utils import ParamsBase\nfrom flash_attn.cute.tile_scheduler import (\n    TileSchedulerArguments,\n    SingleTileScheduler,\n    SingleTileLPTBwdScheduler,  # noqa\n    SingleTileVarlenScheduler,\n)\n\nfrom flash_attn.cute import barrier\nfrom flash_attn.cute.named_barrier import NamedBarrierBwdSm100\nfrom flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner\nfrom flash_attn.cute.block_sparsity import BlockSparseTensors\nfrom flash_attn.cute.block_sparse_utils import (\n    get_total_q_block_count_bwd,\n    get_block_sparse_iteration_info_bwd,\n    get_m_block_from_iter_bwd,\n    produce_block_sparse_q_loads_bwd_sm100,\n)\n\n\nclass FlashAttentionBackwardSm100:\n    arch = 100\n\n    def __init__(\n        self,\n        head_dim: int,\n        head_dim_v: Optional[int] = None,\n        is_causal: bool = False,\n        is_local: bool = False,\n        qhead_per_kvhead: cutlass.Constexpr[int] = 1,\n        tile_m: 
int = 128,\n        tile_n: int = 128,\n        is_persistent: bool = False,\n        deterministic: bool = False,\n        cluster_size: int = 1,\n        use_2cta_instrs: bool = False,\n        score_mod: cutlass.Constexpr | None = None,\n        score_mod_bwd: cutlass.Constexpr | None = None,\n        mask_mod: cutlass.Constexpr | None = None,\n        has_aux_tensors: cutlass.Constexpr = False,\n        subtile_factor: cutlass.Constexpr[int] = 1,\n    ):\n        # padding head_dim to a multiple of 16 as k_block_size\n        hdim_multiple_of = 16\n        self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)\n        head_dim_v = head_dim_v if head_dim_v is not None else head_dim\n        self.same_hdim_kv = head_dim == head_dim_v\n        self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)\n        self.check_hdim_oob = head_dim != self.tile_hdim\n        self.check_hdim_v_oob = head_dim_v != self.tile_hdimv\n\n        self.tile_m = tile_m\n        self.tile_n = tile_n\n\n        assert self.tile_hdim <= 128 or (self.tile_hdim == 192 and self.tile_hdimv == 128)\n        assert self.tile_hdimv <= 128\n\n        self.use_2cta_instrs = bool(\n            use_2cta_instrs\n            and cluster_size == 2\n            and score_mod is None\n            and score_mod_bwd is None\n            and mask_mod is None\n        )\n        self.cta_group_size = 2 if self.use_2cta_instrs else 1\n\n        assert self.tile_hdim != 192 or self.use_2cta_instrs, \"Must use 2CTA for hdim 192\"\n\n        # CTA tiler\n        self.cta_tiler = (tile_n, tile_m, self.tile_hdim)\n        # S = K @ Q.T\n        self.mma_tiler_kq = (self.cta_group_size * tile_n, tile_m, self.tile_hdim)\n        # dP = V @ dO.T\n        self.mma_tiler_vdo = (self.cta_group_size * tile_n, tile_m, self.tile_hdimv)\n        # dV = P.T @ dO\n        self.mma_tiler_pdo = (self.cta_group_size * tile_n, self.tile_hdimv, tile_m)\n        # dK = dS.T @ 
Q\n        self.mma_tiler_dsq = (self.cta_group_size * tile_n, self.tile_hdim, tile_m)\n        # dQ = dS @ K\n        # 2-CTA: reduction dim is cluster-wide (tile_n * cta_group_size).\n        self.mma_tiler_dsk = (tile_m, self.tile_hdim, tile_n * self.cta_group_size)\n\n        self.acc_dtype = Float32\n\n        assert cluster_size in (1, 2), \"Only cluster_size=1 or 2 is supported\"\n        self.cluster_shape_mn = (cluster_size, 1)\n        self.is_persistent = is_persistent\n        self.is_causal = is_causal\n        self.is_local = is_local\n        self.qhead_per_kvhead = qhead_per_kvhead\n        self.pack_gqa = False\n        self.deterministic = deterministic\n\n        # Score mod and mask mod support\n        self.score_mod = score_mod\n        self.score_mod_bwd = score_mod_bwd\n        self.mask_mod = mask_mod\n        self.has_aux_tensors = has_aux_tensors\n        self.subtile_factor = subtile_factor\n        # For score_mod, use vec_size=1 (like forward) to handle per-element indices\n        if cutlass.const_expr(has_aux_tensors):\n            self.vec_size: cutlass.Constexpr = 1\n        else:\n            self.vec_size: cutlass.Constexpr = 4\n        self.qk_acc_dtype = Float32\n\n        # Speed optimizations; these do not affect correctness\n        self.shuffle_LSE = False\n        self.shuffle_dPsum = False\n        # Generally slower to store dS in smem for dK, and doesn't work for 2cta\n        self.use_smem_dS_for_mma_dK = False\n\n        self.reduce_warp_ids = (0, 1, 2, 3)\n        self.compute_warp_ids = (4, 5, 6, 7, 8, 9, 10, 11)\n        self.mma_warp_id = 12\n        self.load_warp_id = 13\n        self.relay_warp_id = 14\n        self.empty_warp_id = 15\n\n        # 16 warps -> 512 threads\n        self.threads_per_cta = cute.arch.WARP_SIZE * len(\n            (\n                *self.reduce_warp_ids,\n                *self.compute_warp_ids,\n                self.mma_warp_id,\n                self.load_warp_id,\n                
self.relay_warp_id,\n                self.empty_warp_id,\n            )\n        )\n        # NamedBarrier\n        self.compute_sync_barrier = cutlass.pipeline.NamedBarrier(\n            barrier_id=int(NamedBarrierBwdSm100.Compute),\n            num_threads=len(self.compute_warp_ids) * cute.arch.WARP_SIZE,\n        )\n        # self.epilogue_sync_barrier = pipeline.NamedBarrier(\n        #     barrier_id=2,\n        #     num_threads=self.num_compute_warps * self.threads_per_warp,\n        # )\n        self.reduce_sync_barrier = cutlass.pipeline.NamedBarrier(\n            barrier_id=int(NamedBarrierBwdSm100.dQaccReduce),\n            num_threads=len(self.reduce_warp_ids) * cute.arch.WARP_SIZE,\n        )\n        # TMEM setup\n        self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols(\"sm_100\")\n        # self.tmem_dK_offset = 0\n        # self.tmem_dV_offset = self.tmem_dK_offset + self.tile_hdim\n        # self.tmem_dQ_offset = self.tmem_dV_offset + self.tile_hdimv\n        # self.tmem_dP_offset = self.tmem_dQ_offset  # overlap with dQ\n        # self.tmem_S_offset = self.tmem_dQ_offset + max(self.tile_m, self.tile_hdim)\n        # self.tmem_P_offset = self.tmem_S_offset  # overlap with S\n        # self.tmem_total = self.tmem_S_offset + self.tile_n\n        # assert self.tmem_total <= self.tmem_alloc_cols\n\n        if self.use_2cta_instrs and self.tile_hdim == 192 and self.tile_hdimv == 128:\n            assert self.tile_m == 128\n            assert self.tile_n == 128\n            self.tmem_dV_offset = 0\n            self.tmem_dK_offset = self.tmem_dV_offset + self.tile_hdimv\n            self.tmem_S_offset = self.tmem_dK_offset + self.tile_hdim\n            self.tmem_P_offset = self.tmem_S_offset  # overlap with S\n            self.tmem_dP_offset = 512 - self.tile_m\n            self.tmem_dS_offset = self.tmem_dP_offset  # overlaps with dP\n            self.tmem_dQ_offset = 512 - self.tile_hdim // 2\n        else:\n            self.tmem_S_offset = 0\n 
           self.tmem_P_offset = 0  # overlap with S\n            self.tmem_dV_offset = self.tmem_S_offset + self.tile_n\n            self.tmem_dP_offset = self.tmem_dV_offset + self.tile_hdimv\n            self.tmem_dQ_offset = (\n                (self.tmem_S_offset + (self.tile_hdim // 2))\n                if self.use_2cta_instrs\n                else self.tmem_dP_offset\n            )\n            self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m\n            self.tmem_dS_offset = self.tmem_dP_offset  # overlap with dP\n\n        if (not is_causal and not is_local) or deterministic:\n            self.num_regs_reduce = 136 if self.use_2cta_instrs else 152\n            self.num_regs_compute = 136\n            self.num_regs_load = 104 if self.use_2cta_instrs else 96 - 8\n            self.num_regs_mma = 104 if self.use_2cta_instrs else self.num_regs_load\n        else:\n            self.num_regs_reduce = 136 if self.use_2cta_instrs else 136\n            self.num_regs_compute = 136 if self.use_2cta_instrs else 144\n            self.num_regs_load = 104 if self.use_2cta_instrs else 96 - 8\n            self.num_regs_mma = 104 if self.use_2cta_instrs else self.num_regs_load\n        self.num_regs_empty = 24\n\n        if const_expr(self.tile_hdim == 192):\n            if not is_causal and not is_local:\n                self.num_regs_reduce = 128 + 8\n                self.num_regs_compute = 128 + 8\n                self.num_regs_load = 128 - 24\n                self.num_regs_mma = self.num_regs_load\n            else:\n                self.num_regs_reduce = 128 + 8\n                self.num_regs_compute = 128 + 8\n                self.num_regs_load = 128 - 24\n                self.num_regs_mma = self.num_regs_load\n\n        assert (\n            self.num_regs_reduce\n            + self.num_regs_compute * 2\n            + max(self.num_regs_load, self.num_regs_mma)\n            <= 512\n        )\n        self.buffer_align_bytes = 1024\n\n    def 
_setup_attributes(self):\n        self.Q_stage = 1 if self.use_2cta_instrs else 2\n        self.dO_stage = 1\n        self.single_stage = 1\n        # LSE_stage = Q_stage and dPsum_stage = dO_stage\n        self.sdKVaccum_stage = 2\n        # number of tma reduce adds per dQacc mma\n        # todo: try 32/1 or 48/2 for 2cta d=192 dv=128\n        if self.use_2cta_instrs and self.tile_hdim == 192:\n            self.dQ_reduce_ncol_t2r = 32\n            self.dQ_reduce_ncol = 24 if not self.is_causal else 32\n            self.sdQaccum_stage = 2 if not self.is_causal else 1\n        else:\n            if self.use_2cta_instrs:\n                self.dQ_reduce_ncol = 16 if self.deterministic else 8\n                self.sdQaccum_stage = 2 if self.deterministic else 4\n                self.dQ_reduce_ncol_t2r = 32\n            else:\n                self.dQ_reduce_ncol = 32\n                self.sdQaccum_stage = 64 // self.dQ_reduce_ncol\n                self.dQ_reduce_ncol_t2r = self.dQ_reduce_ncol\n        assert (self.tile_hdim // self.cta_group_size) % self.dQ_reduce_ncol == 0\n        self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol\n        self.dQaccum_reduce_stage_t2r = self.tile_hdim // self.dQ_reduce_ncol_t2r\n        self.cluster_reduce_dQ = False and cute.size(self.cluster_shape_mn) > 1\n        # number of tma reduce adds for dKacc and dVacc epilogue (must divide hdim_per_wg)\n        self.dK_reduce_ncol = math.gcd(32, self.tile_hdim // 2)\n        # CTA group for MMA operations\n        self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE\n\n    def _get_tiled_mma(self):\n        # S.T = K @ Q.T\n        tiled_mma_S = sm100_utils_basic.make_trivial_tiled_mma(\n            self.q_dtype,\n            tcgen05.OperandMajorMode.K,\n            tcgen05.OperandMajorMode.K,\n            self.acc_dtype,\n            self.cta_group,\n            self.mma_tiler_kq[:2],\n        )\n        # dP.T = V @ dO.T\n        
tiled_mma_dP = sm100_utils_basic.make_trivial_tiled_mma(\n            self.do_dtype,\n            tcgen05.OperandMajorMode.K,\n            tcgen05.OperandMajorMode.K,\n            self.acc_dtype,\n            self.cta_group,\n            self.mma_tiler_vdo[:2],\n        )\n        # dV += P.T @ dO --> (K, MN) major\n        tiled_mma_dV = sm100_utils_basic.make_trivial_tiled_mma(\n            self.do_dtype,\n            tcgen05.OperandMajorMode.K,  # P_major_mode\n            tcgen05.OperandMajorMode.MN,  # dO_major_mode\n            self.acc_dtype,\n            self.cta_group,\n            self.mma_tiler_pdo[:2],\n            a_source=tcgen05.OperandSource.TMEM,\n        )\n        # dK += dS.T @ Q\n        if const_expr(self.use_smem_dS_for_mma_dK):\n            mma_dK_a_src = tcgen05.OperandSource.SMEM\n        else:\n            mma_dK_a_src = tcgen05.OperandSource.TMEM\n        tiled_mma_dK = sm100_utils_basic.make_trivial_tiled_mma(\n            self.do_dtype,\n            tcgen05.OperandMajorMode.K,  # dS_major_mode\n            tcgen05.OperandMajorMode.MN,  # Q_major_mode\n            self.acc_dtype,\n            self.cta_group,\n            self.mma_tiler_dsq[:2],\n            a_source=mma_dK_a_src,\n        )\n        # dQ = dS @ K\n        tiled_mma_dQ = sm100_utils_basic.make_trivial_tiled_mma(\n            self.k_dtype,\n            tcgen05.OperandMajorMode.MN,  # dS_major_mode\n            tcgen05.OperandMajorMode.MN,  # Kt_major_mode\n            self.acc_dtype,\n            self.cta_group,\n            self.mma_tiler_dsk[:2],\n        )\n        return tiled_mma_S, tiled_mma_dP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ\n\n    def _setup_smem_layout(self):\n        # S.T = K @ Q.T\n        sK_layout = sm100_utils_basic.make_smem_layout_a(\n            self.tiled_mma_S,\n            self.mma_tiler_kq,\n            self.k_dtype,\n            1,\n        )\n        self.sK_layout = cute.slice_(sK_layout, (None, None, None, 0))\n        self.sQ_layout = 
sm100_utils_basic.make_smem_layout_b(\n            self.tiled_mma_S,\n            self.mma_tiler_kq,\n            self.q_dtype,\n            self.Q_stage,\n        )\n        # dP.T = V @ dO.T\n        sV_layout = sm100_utils_basic.make_smem_layout_a(\n            self.tiled_mma_dP,\n            self.mma_tiler_vdo,\n            self.v_dtype,\n            1,\n        )\n        self.sV_layout = cute.slice_(sV_layout, (None, None, None, 0))\n        self.sdOt_layout = sm100_utils_basic.make_smem_layout_b(\n            self.tiled_mma_dP,\n            self.mma_tiler_vdo,\n            self.do_dtype,\n            self.dO_stage,\n        )\n        # dV += P.T @ dO\n        tP_layout = sm100_utils_basic.make_smem_layout_a(\n            self.tiled_mma_dV,\n            self.mma_tiler_pdo,\n            self.do_dtype,\n            1,\n        )\n        self.tP_layout = cute.slice_(tP_layout, (None, None, None, 0))\n        self.sdO_layout = sm100_utils_basic.make_smem_layout_b(\n            self.tiled_mma_dV,\n            self.mma_tiler_pdo,\n            self.do_dtype,\n            self.dO_stage,\n        )\n        # dK += dS.T @ Q\n        sdSt_layout = sm100_utils_basic.make_smem_layout_a(\n            self.tiled_mma_dK,\n            self.mma_tiler_dsq,\n            self.ds_dtype,\n            1,\n        )\n        self.sdSt_layout = cute.slice_(sdSt_layout, (None, None, None, 0))\n        tdS_layout = sm100_utils_basic.make_smem_layout_a(\n            self.tiled_mma_dK,\n            self.mma_tiler_dsq,\n            self.ds_dtype,\n            1,\n        )\n        self.tdS_layout = cute.slice_(tdS_layout, (None, None, None, 0))\n        self.sQt_layout = sm100_utils_basic.make_smem_layout_b(\n            self.tiled_mma_dK,\n            self.mma_tiler_dsq,\n            self.q_dtype,\n            self.Q_stage,\n        )\n        # dQ = dS @ K\n        sdS_layout = sm100_utils_basic.make_smem_layout_a(\n            self.tiled_mma_dQ,\n            self.mma_tiler_dsk,\n    
        self.ds_dtype,\n            1,\n        )\n        self.sdS_layout = cute.slice_(sdS_layout, (None, None, None, 0))\n        sKt_layout = sm100_utils_basic.make_smem_layout_b(\n            self.tiled_mma_dQ,\n            self.mma_tiler_dsk,\n            self.k_dtype,\n            1,\n        )\n        self.sKt_layout = cute.slice_(sKt_layout, (None, None, None, 0))\n        self.sdS_xchg_layout = cute.make_layout(shape=(self.tile_n, self.tile_m // 2))\n\n        self.sdQaccum_layout = cute.make_layout(\n            (self.tile_m * self.dQ_reduce_ncol, self.sdQaccum_stage)\n        )\n        self.sLSE_layout = cute.make_layout(\n            shape=(self.tile_m, self.Q_stage), stride=(1, cute.round_up(self.tile_m, 64))\n        )\n        self.sdPsum_layout = cute.make_layout(\n            shape=(self.tile_m, self.dO_stage),\n            stride=(1, cute.round_up(self.tile_m, 64)),\n        )\n        self.sdK_epi_tile = (\n            self.tile_n,\n            math.gcd(128 // (self.dk_dtype.width // 8), self.tile_hdim // 2),  # 64 or 32\n        )  # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2]\n        self.sdV_epi_tile = (\n            self.tile_n,\n            math.gcd(128 // (self.dk_dtype.width // 8), self.tile_hdimv // 2),  # 64 or 32\n        )  # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2]\n        # headdim_64 gets 1 stage\n        self.num_epi_stages = max(1, (self.tile_hdim // 2) // self.sdK_epi_tile[1])\n        self.num_epi_stages_v = max(1, (self.tile_hdimv // 2) // self.sdV_epi_tile[1])\n        self.sdK_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages\n        self.sdV_flat_epi_tile = self.tile_n * (self.tile_hdimv // 2) // self.num_epi_stages_v\n        if const_expr(not self.dKV_postprocess):\n            self.sdK_layout = sm100_utils_basic.make_smem_layout_epi(\n                self.dk_dtype,\n                LayoutEnum.ROW_MAJOR,\n                self.sdK_epi_tile,\n                2,  # num compute wgs\n     
       )\n            self.sdV_layout = sm100_utils_basic.make_smem_layout_epi(\n                self.dv_dtype,\n                LayoutEnum.ROW_MAJOR,\n                self.sdV_epi_tile,\n                2,  # num compute wgs\n            )\n        else:\n            self.sdK_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2))\n            # self.dK_reduce_ncol same for dV\n            self.sdV_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2))\n\n    @cute.jit\n    def __call__(\n        self,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mV: cute.Tensor,\n        mdO: cute.Tensor,\n        mLSE: cute.Tensor,\n        mdPsum: cute.Tensor,\n        mdQaccum: cute.Tensor,\n        mdK: cute.Tensor,\n        mdV: cute.Tensor,\n        softmax_scale: Float32,\n        mCuSeqlensQ: Optional[cute.Tensor] = None,\n        mCuSeqlensK: Optional[cute.Tensor] = None,\n        mSeqUsedQ: Optional[cute.Tensor] = None,\n        mSeqUsedK: Optional[cute.Tensor] = None,\n        softcap: Float32 | float | None = None,\n        window_size_left: Int32 | int | None = None,\n        window_size_right: Int32 | int | None = None,\n        mdQ_semaphore: Optional[cute.Tensor] = None,\n        mdK_semaphore: Optional[cute.Tensor] = None,\n        mdV_semaphore: Optional[cute.Tensor] = None,\n        aux_tensors: Optional[list] = None,\n        # Block-sparse tensors (Q direction - for iterating m_blocks per n_block):\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n        # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).\n        stream: cuda.CUstream = None,\n    ):\n        self.q_dtype = mQ.element_type\n        self.k_dtype = mK.element_type\n        self.v_dtype = mV.element_type\n        self.do_dtype = mdO.element_type\n        self.lse_dtype = mLSE.element_type\n        self.dpsum_dtype = mdPsum.element_type\n        self.dqaccum_dtype = mdQaccum.element_type\n        
self.dk_dtype = mdK.element_type\n        self.dv_dtype = mdV.element_type\n        self.ds_dtype = self.q_dtype\n\n        self.is_varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None\n        self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None\n        self.use_tma_store = not (self.qhead_per_kvhead == 1 and mCuSeqlensK is not None)\n        # self.use_tma_store = not self.qhead_per_kvhead == 1\n        self.dKV_postprocess = self.qhead_per_kvhead > 1\n\n        if const_expr(self.dKV_postprocess):\n            assert self.dk_dtype.width == 32, \"Must accumulate dK in float precision for GQA\"\n            assert self.dv_dtype.width == 32, \"Must accumulate dV in float precision for GQA\"\n\n        mdQaccum, mdK, mdV = [assume_tensor_aligned(t) for t in (mdQaccum, mdK, mdV)]\n\n        # (b, s, n, h) --> (s, h, n, b) or (t, n, h) -> (t, h, n)\n        QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]\n        mQ, mdO = [layout_utils.select(t, mode=QO_layout_transpose) for t in (mQ, mdO)]\n\n        KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]\n        mK, mV = [layout_utils.select(t, mode=KV_layout_transpose) for t in (mK, mV)]\n\n        # (b, n, s) --> (s, n, b) or (n, t) --> (t, n)\n        LSE_dPsum_dQaccum_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]\n        mLSE, mdPsum, mdQaccum = [\n            layout_utils.select(t, mode=LSE_dPsum_dQaccum_transpose)\n            for t in (mLSE, mdPsum, mdQaccum)\n        ]\n\n        if const_expr(not self.dKV_postprocess):\n            layout_dKV_transpose = KV_layout_transpose\n        else:\n            layout_dKV_transpose = [2, 1, 0] if const_expr(mCuSeqlensK is None) else [1, 0]\n        mdK, mdV = [layout_utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)]\n        # (s, h, n, b) --> (h, s, n, b) or (t, h, n) -> (h, t, n)\n        dO_transpose = [1, 0, 2, 3] if 
const_expr(mCuSeqlensQ is None) else [1, 0, 2]\n        mdO = layout_utils.select(mdO, mode=dO_transpose)\n\n        # Transposes for 2-CTA K/Q paths (Q follows Q seqlens, K follows K seqlens)\n        transpose_sh_q = dO_transpose\n        transpose_sh_k = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2]\n\n        # (b, n, block, stage) -> (block, stage, n, b)\n        semaphore_transpose = [2, 3, 1, 0]\n        if const_expr(self.deterministic):\n            assert mdQ_semaphore is not None\n            mdQ_semaphore = layout_utils.select(mdQ_semaphore, mode=semaphore_transpose)\n\n        if const_expr(self.deterministic and self.qhead_per_kvhead > 1):\n            assert mdK_semaphore is not None\n            assert mdV_semaphore is not None\n            mdK_semaphore, mdV_semaphore = [\n                layout_utils.select(t, mode=semaphore_transpose)\n                for t in (mdK_semaphore, mdV_semaphore)\n            ]\n        else:\n            mdK_semaphore = None\n            mdV_semaphore = None\n\n        self._setup_attributes()\n        (\n            self.tiled_mma_S,\n            self.tiled_mma_dP,\n            self.tiled_mma_dK,\n            self.tiled_mma_dV,\n            self.tiled_mma_dQ,\n        ) = self._get_tiled_mma()\n        self._setup_smem_layout()\n\n        self.cluster_shape_mnk = (*self.cluster_shape_mn, 1)\n        self.cluster_layout_vmnk = cute.tiled_divide(\n            cute.make_layout(self.cluster_shape_mnk),\n            (self.tiled_mma_S.thr_id.shape,),\n        )\n        self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])\n        self.is_q_do_mcast = self.num_mcast_ctas_b > 1\n\n        if const_expr(not self.dKV_postprocess):\n            self.mdK_layout_enum = LayoutEnum.from_tensor(mdK)\n            self.mdV_layout_enum = LayoutEnum.from_tensor(mdV)\n            dK_major_mode = self.mdK_layout_enum.mma_major_mode()\n            dV_major_mode = self.mdV_layout_enum.mma_major_mode()\n     
       if const_expr(dK_major_mode != tcgen05.OperandMajorMode.K):\n                raise RuntimeError(\"The layout of mdK is wrong\")\n            if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K):\n                raise RuntimeError(\"The layout of mdV is wrong\")\n\n        if const_expr(self.use_tma_store and not self.dKV_postprocess):\n            tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp()\n            tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom(\n                tma_copy_op_dKV,\n                mdK,\n                cute.select(self.sdK_layout, mode=[0, 1]),\n                self.sdK_epi_tile,\n                1,  # no mcast\n            )\n            tma_atom_dV, mdV_tma_tensor = cpasync.make_tiled_tma_atom(\n                tma_copy_op_dKV,\n                mdV,\n                cute.select(self.sdV_layout, mode=[0, 1]),\n                self.sdV_epi_tile,\n                1,  # no mcast\n            )\n        else:\n            mdV_tma_tensor = mdV\n            mdK_tma_tensor = mdK\n            tma_atom_dV = None\n            tma_atom_dK = None\n\n        if const_expr(not self.dKV_postprocess):\n            thr_layout_r2s_dKV = cute.make_ordered_layout((128, 1), order=(1, 0))  # 128 threads\n            val_layout_r2s_dKV = cute.make_ordered_layout(\n                (1, 128 // self.dk_dtype.width), order=(1, 0)\n            )  # 4 or 8 vals for 16 byte store\n            copy_atom_r2s_dKV = cute.make_copy_atom(\n                cute.nvgpu.CopyUniversalOp(),\n                self.dk_dtype,\n                num_bits_per_copy=128,\n            )\n            tiled_copy_r2s_dKV = cute.make_tiled_copy_tv(\n                copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV\n            )\n        else:\n            tiled_copy_r2s_dKV = copy_utils.tiled_copy_1d(\n                Float32, 128, num_copy_elems=128 // Float32.width\n            )\n\n        tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group)\n      
  # S.T = K @ Q.T\n        tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A(\n            tma_load_op,\n            mK,\n            cute.select(self.sK_layout, mode=[0, 1, 2]),\n            self.mma_tiler_kq,\n            self.tiled_mma_S,\n            self.cluster_layout_vmnk.shape,\n        )\n        Q_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B(\n            self.cluster_shape_mnk, self.tiled_mma_S.thr_id\n        )\n        tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B(\n            Q_tma_op,\n            mQ,\n            cute.select(self.sQ_layout, mode=[0, 1, 2]),\n            self.mma_tiler_kq,\n            self.tiled_mma_S,\n            self.cluster_layout_vmnk.shape,\n        )\n        # dP.T = V @ dO.T\n        tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A(\n            tma_load_op,\n            mV,\n            cute.select(self.sV_layout, mode=[0, 1, 2]),\n            self.mma_tiler_vdo,\n            self.tiled_mma_dP,\n            self.cluster_layout_vmnk.shape,\n        )\n        # dV = P.T @ dO\n        dO_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B(\n            self.cluster_shape_mnk, self.tiled_mma_dV.thr_id\n        )\n        tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B(\n            dO_tma_op,\n            mdO,\n            cute.select(self.sdO_layout, mode=[0, 1, 2]),\n            self.mma_tiler_pdo,\n            self.tiled_mma_dV,\n            self.cluster_layout_vmnk.shape,\n        )\n        # ------------------------------------------------------------\n        # 2-CTA\n        # ------------------------------------------------------------\n        tma_atom_dOt = tma_tensor_dOt = None\n        if const_expr(self.use_2cta_instrs):\n            tma_atom_dOt, tma_tensor_dOt = cute.nvgpu.make_tiled_tma_atom_B(\n                dO_tma_op,\n                layout_utils.select(mdO, mode=transpose_sh_q),\n                cute.select(self.sdOt_layout, mode=[0, 1, 
2]),\n                self.mma_tiler_vdo,\n                self.tiled_mma_dP,\n                self.cluster_layout_vmnk.shape,\n            )\n        tma_atom_Qt = tma_tensor_Qt = None\n        if const_expr(self.use_2cta_instrs):\n            tma_atom_Qt, tma_tensor_Qt = cute.nvgpu.make_tiled_tma_atom_B(\n                Q_tma_op,\n                layout_utils.select(mQ, mode=transpose_sh_q),\n                cute.select(self.sQt_layout, mode=[0, 1, 2]),\n                self.mma_tiler_dsq,\n                self.tiled_mma_dK,\n                self.cluster_layout_vmnk.shape,\n            )\n        tma_atom_Kt = tma_tensor_Kt = None\n        if const_expr(self.use_2cta_instrs):\n            Kt_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B(\n                self.cluster_shape_mnk, self.tiled_mma_dQ.thr_id\n            )\n            tma_atom_Kt, tma_tensor_Kt = cute.nvgpu.make_tiled_tma_atom_B(\n                Kt_tma_op,\n                layout_utils.select(mK, mode=transpose_sh_k),\n                cute.select(self.sKt_layout, mode=[0, 1, 2]),\n                self.mma_tiler_dsk,\n                self.tiled_mma_dQ,\n                self.cluster_layout_vmnk.shape,\n            )\n\n        self.tma_copy_bytes = {\n            name: self.cta_group_size\n            * cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))\n            for name, mX, layout in [\n                (\"Q\", mQ, self.sQ_layout),\n                (\"K\", mK, self.sK_layout),\n                (\"V\", mV, self.sV_layout),\n                (\"dO\", mdO, self.sdO_layout),\n            ]\n        }\n        self.tma_copy_bytes[\"LSE\"] = self.tile_m * Float32.width // 8\n        self.tma_copy_bytes[\"dPsum\"] = self.tile_m * Float32.width // 8\n        self.tma_copy_bytes[\"dQ\"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8\n        self.tma_copy_bytes[\"dKacc\"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8\n        self.tma_copy_bytes[\"dS\"] = 
cute.size_in_bytes(self.ds_dtype, self.sdS_layout)\n        self.tma_copy_bytes[\"sdS_xchg\"] = self.tma_copy_bytes[\"dS\"] // 2  # Half of dS for exchange\n\n        # TileScheduler = SingleTileScheduler\n        if const_expr(self.is_varlen_k):\n            TileScheduler = SingleTileVarlenScheduler\n        elif const_expr(self.deterministic):\n            TileScheduler = SingleTileLPTBwdScheduler\n        else:\n            TileScheduler = SingleTileScheduler\n        self.spt = (self.is_causal or self.is_local) and self.deterministic\n        tile_sched_args = TileSchedulerArguments(\n            cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]),  # num_blocks\n            cute.size(mQ.shape[2]),  # num_heads = num_query_heads\n            cute.size(mK.shape[3])\n            if const_expr(mCuSeqlensK is None)\n            else cute.size(mCuSeqlensK.shape[0] - 1),  # num_batches\n            1,  # num_splits\n            cute.size(mQ.shape[0]),  # pass seqlen_q or total_q for seqlen_k\n            mQ.shape[1],  # headdim\n            mV.shape[1],  # headdim_v\n            total_q=cute.size(mK.shape[0])  # pass total_k for total_q\n            if const_expr(mCuSeqlensK is not None)\n            else cute.size(mK.shape[0]) * cute.size(mK.shape[3]),\n            tile_shape_mn=self.cta_tiler[:2],  # (tile_n, tile_m)\n            cluster_shape_mn=self.cluster_shape_mnk[:2],\n            mCuSeqlensQ=mCuSeqlensK,\n            mSeqUsedQ=mSeqUsedK,\n            qhead_per_kvhead_packgqa=1,  # pack_gqa disabled for bwd\n            element_size=self.k_dtype.width // 8,\n            is_persistent=self.is_persistent,  # persistent mode not tested\n            lpt=self.spt,\n            head_swizzle=self.deterministic,\n        )\n\n        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)\n        self.tile_scheduler_cls = TileScheduler\n        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)\n\n        # Compute allocation sizes 
for shared buffers that are reused\n        # sQ is reused for sdK, sdO is reused for sdV\n        sQ_alloc_bytes = max(\n            cute.size_in_bytes(self.q_dtype, self.sQ_layout),\n            cute.size_in_bytes(self.dk_dtype, self.sdK_layout),\n        )\n        sdO_alloc_bytes = max(\n            cute.size_in_bytes(self.dv_dtype, self.sdV_layout),\n            cute.size_in_bytes(self.do_dtype, self.sdO_layout),\n        )\n\n        sdK_bytes = cute.size_in_bytes(self.dk_dtype, self.sdK_layout)\n        sdV_bytes = cute.size_in_bytes(self.dv_dtype, self.sdV_layout)\n        assert sdV_bytes <= sdO_alloc_bytes, \"sdV doesn't fit in sdO storage allocation\"\n        assert sdK_bytes <= sQ_alloc_bytes, \"sdK doesn't fit in sQ storage allocation\"\n        # 2-CTA: sdV reuses sV, sdK reuses sK\n        sV_bytes = cute.size_in_bytes(self.v_dtype, self.sV_layout)\n        sK_bytes = cute.size_in_bytes(self.k_dtype, self.sK_layout)\n        if const_expr(self.use_2cta_instrs):\n            assert sdV_bytes <= sV_bytes, \"sdV doesn't fit in sV storage allocation (2-CTA)\"\n            assert sdK_bytes <= sK_bytes, \"sdK doesn't fit in sK storage allocation (2-CTA)\"\n\n        if const_expr(self.use_2cta_instrs):\n            sQt_size = cute.cosize(self.sQt_layout) if const_expr(self.tile_hdim <= 128) else 0\n            sdOt_size = cute.cosize(self.sdOt_layout) if const_expr(self.tile_hdim <= 128) else 0\n            sdS_xchg_size = (\n                cute.cosize(self.sdS_xchg_layout) if const_expr(self.tile_hdim <= 128) else 0\n            )\n\n            @cute.struct\n            class SharedStorage:\n                Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage]\n                dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage]\n                LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage]\n                dPsum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage]\n                
S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage]\n                dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage]\n                dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage]\n                dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.sdKVaccum_stage]\n                dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]\n                dQ_cluster_full_mbar_ptr: cute.struct.MemRange[\n                    cutlass.Int64, self.dQaccum_reduce_stage // 2\n                ]\n                dQ_cluster_empty_mbar_ptr: cute.struct.MemRange[\n                    cutlass.Int64, self.dQaccum_reduce_stage // 2\n                ]\n                tmem_holding_buf: Int32\n                tmem_dealloc_mbar_ptr: cutlass.Int64\n\n                # 2-CTA\n                Qt_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage]\n                Kt_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage]\n                dS_cluster_empty_mbar_ptr: cutlass.Int64\n                dS_cluster_full_mbar_ptr: cutlass.Int64\n                dS_cluster_leader_mbar_ptr: cutlass.Int64\n                dQaccum_empty_mbar_ptr: cutlass.Int64\n\n                sQ: cute.struct.Align[\n                    cute.struct.MemRange[self.q_dtype, cute.cosize(self.sQ_layout)],\n                    self.buffer_align_bytes,\n                ]\n                sK: cute.struct.Align[\n                    cute.struct.MemRange[self.k_dtype, cute.cosize(self.sK_layout)],\n                    self.buffer_align_bytes,\n                ]\n                sV: cute.struct.Align[\n                    cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)],\n                    self.buffer_align_bytes,\n                ]\n                sdO: cute.struct.Align[\n                    cute.struct.MemRange[self.do_dtype, cute.cosize(self.sdO_layout)],\n                    self.buffer_align_bytes,\n           
     ]\n                sQt: cute.struct.Align[\n                    cute.struct.MemRange[self.q_dtype, sQt_size],\n                    self.buffer_align_bytes,\n                ]\n                sdOt: cute.struct.Align[\n                    cute.struct.MemRange[self.do_dtype, sdOt_size],\n                    self.buffer_align_bytes,\n                ]\n                sdS_xchg: cute.struct.Align[\n                    cute.struct.MemRange[self.ds_dtype, sdS_xchg_size],\n                    self.buffer_align_bytes,\n                ]\n                sKt: cute.struct.Align[\n                    cute.struct.MemRange[self.k_dtype, cute.cosize(self.sKt_layout)],\n                    self.buffer_align_bytes,\n                ]\n                sdS: cute.struct.Align[\n                    cute.struct.MemRange[self.ds_dtype, cute.cosize(self.sdSt_layout)],\n                    self.buffer_align_bytes,\n                ]\n                sLSE: cute.struct.Align[\n                    cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)],\n                    128,\n                ]\n                sdPsum: cute.struct.Align[\n                    cute.struct.MemRange[self.dpsum_dtype, cute.cosize(self.sdPsum_layout)],\n                    128,\n                ]\n                sdQaccum: cute.struct.Align[\n                    cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(self.sdQaccum_layout)],\n                    self.buffer_align_bytes if sdS_xchg_size == 0 else 128,\n                ]\n\n        else:\n\n            @cute.struct\n            class SharedStorage:\n                Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage]\n                dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage]\n                LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage]\n                dPsum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage]\n                S_mbar_ptr: 
cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage]\n                dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage]\n                dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage]\n                dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.sdKVaccum_stage]\n                dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]\n                dQ_cluster_full_mbar_ptr: cute.struct.MemRange[\n                    cutlass.Int64, self.dQaccum_reduce_stage // 2\n                ]\n                dQ_cluster_empty_mbar_ptr: cute.struct.MemRange[\n                    cutlass.Int64, self.dQaccum_reduce_stage // 2\n                ]\n                tmem_holding_buf: Int32\n                tmem_dealloc_mbar_ptr: Int64\n\n                sQ: cute.struct.Align[\n                    cute.struct.MemRange[cute.Uint8, sQ_alloc_bytes],\n                    self.buffer_align_bytes,\n                ]\n                sK: cute.struct.Align[\n                    cute.struct.MemRange[self.k_dtype, cute.cosize(self.sK_layout)],\n                    self.buffer_align_bytes,\n                ]\n                sV: cute.struct.Align[\n                    cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)],\n                    self.buffer_align_bytes,\n                ]\n                sdO: cute.struct.Align[\n                    cute.struct.MemRange[cute.Uint8, sdO_alloc_bytes],\n                    self.buffer_align_bytes,\n                ]\n                sdS: cute.struct.Align[\n                    cute.struct.MemRange[self.ds_dtype, cute.cosize(self.sdSt_layout)],\n                    128,\n                ]\n                sLSE: cute.struct.Align[\n                    cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)],\n                    128,\n                ]\n                sdPsum: cute.struct.Align[\n                    cute.struct.MemRange[self.dpsum_dtype, 
cute.cosize(self.sdPsum_layout)],\n                    128,\n                ]\n                sdQaccum: cute.struct.Align[\n                    cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(self.sdQaccum_layout)],\n                    self.buffer_align_bytes,\n                ]\n\n        self.shared_storage = SharedStorage\n\n        LOG2_E = math.log2(math.e)\n        if const_expr(self.score_mod is None):\n            # Without score_mod: bake scale into log2\n            softmax_scale_log2 = softmax_scale * LOG2_E\n        else:\n            # With score_mod: score_mod applied to S * softmax_scale, then use LOG2_E only\n            softmax_scale_log2 = LOG2_E\n\n        if const_expr(window_size_left is not None):\n            window_size_left = Int32(window_size_left)\n        if const_expr(window_size_right is not None):\n            window_size_right = Int32(window_size_right)\n\n        fastdiv_mods = None\n        if const_expr(aux_tensors is not None):\n            seqlen_q = cute.size(mQ.shape[0]) // (\n                self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1\n            )\n            seqlen_k = cute.size(mK.shape[0])\n            seqlen_q_divmod = FastDivmodDivisor(seqlen_q)\n            seqlen_k_divmod = FastDivmodDivisor(seqlen_k)\n            fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)\n        self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)\n\n        if const_expr(self.use_2cta_instrs):\n            assert blocksparse_tensors is None, (\n                \"2-CTA mode does not support block sparsity. 
\"\n                \"Please create kernel with use_2cta_instrs=False for block sparse attention.\"\n            )\n        # 2-CTA: 231424 and 1-CTA: 232448\n        # print(\"SMEM: \", self.shared_storage.size_in_bytes())\n        if const_expr(self.use_block_sparsity or aux_tensors is not None):\n            assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), (\n                \"Variable sequence length is not supported yet for blocksparse or aux tensors in bwd\"\n            )\n\n        self.kernel(\n            tma_tensor_Q,\n            tma_tensor_Qt,\n            tma_tensor_K,\n            tma_tensor_Kt,\n            tma_tensor_V,\n            mLSE,\n            mdPsum,\n            tma_tensor_dO,\n            tma_tensor_dOt,\n            mdV,\n            mdK,\n            mdQaccum,\n            mdV_tma_tensor,\n            mdK_tma_tensor,\n            mdQ_semaphore,\n            mdK_semaphore,\n            mdV_semaphore,\n            mCuSeqlensQ,\n            mCuSeqlensK,\n            mSeqUsedQ,\n            mSeqUsedK,\n            tma_atom_Q,\n            tma_atom_Qt,\n            tma_atom_K,\n            tma_atom_Kt,\n            tma_atom_V,\n            tma_atom_dO,\n            tma_atom_dOt,\n            tma_atom_dV,\n            tma_atom_dK,\n            self.sQ_layout,\n            self.sQt_layout,\n            self.sK_layout,\n            self.sKt_layout,\n            self.sV_layout,\n            self.sLSE_layout,\n            self.sdPsum_layout,\n            self.sdO_layout,\n            self.sdOt_layout,\n            self.sdSt_layout,\n            self.sdS_layout,\n            self.sdS_xchg_layout,\n            self.sdQaccum_layout,\n            self.sdK_layout,\n            self.sdV_layout,\n            self.tP_layout,\n            self.tdS_layout,\n            self.tiled_mma_S,\n            self.tiled_mma_dP,\n            self.tiled_mma_dV,\n            self.tiled_mma_dK,\n            self.tiled_mma_dQ,\n           
 tiled_copy_r2s_dKV,\n            softmax_scale,\n            softmax_scale_log2,\n            window_size_left,\n            window_size_right,\n            tile_sched_params,\n            aux_tensors,\n            fastdiv_mods,\n            blocksparse_tensors,\n        ).launch(\n            grid=grid_dim,\n            block=[self.threads_per_cta, 1, 1],\n            cluster=self.cluster_shape_mnk if cute.size(self.cluster_shape_mnk) > 1 else None,\n            smem=self.shared_storage.size_in_bytes(),\n            stream=stream,\n            min_blocks_per_mp=1,\n        )\n\n    @cute.kernel\n    def kernel(\n        self,\n        mQ: cute.Tensor,\n        mQt: Optional[cute.Tensor],\n        mK: cute.Tensor,\n        mKt: Optional[cute.Tensor],\n        mV: cute.Tensor,\n        mLSE: cute.Tensor,\n        mdPsum: cute.Tensor,\n        mdO: cute.Tensor,\n        mdOt: Optional[cute.Tensor],\n        mdV: cute.Tensor,\n        mdK: cute.Tensor,\n        mdQaccum: cute.Tensor,\n        mdV_tma_tensor: Optional[cute.Tensor],\n        mdK_tma_tensor: Optional[cute.Tensor],\n        mdQ_semaphore: Optional[cute.Tensor],\n        mdK_semaphore: Optional[cute.Tensor],\n        mdV_semaphore: Optional[cute.Tensor],\n        mCuSeqlensQ: Optional[cute.Tensor],\n        mCuSeqlensK: Optional[cute.Tensor],\n        mSeqUsedQ: Optional[cute.Tensor],\n        mSeqUsedK: Optional[cute.Tensor],\n        tma_atom_Q: cute.CopyAtom,\n        tma_atom_Qt: Optional[cute.CopyAtom],\n        tma_atom_K: cute.CopyAtom,\n        tma_atom_Kt: Optional[cute.CopyAtom],\n        tma_atom_V: cute.CopyAtom,\n        tma_atom_dO: cute.CopyAtom,\n        tma_atom_dOt: Optional[cute.CopyAtom],\n        tma_atom_dV: Optional[cute.CopyAtom],\n        tma_atom_dK: Optional[cute.CopyAtom],\n        sQ_layout: cute.ComposedLayout,\n        sQt_layout: cute.ComposedLayout,\n        sK_layout: cute.ComposedLayout,\n        sKt_layout: cute.ComposedLayout,\n        sV_layout: cute.ComposedLayout,\n 
       sLSE_layout: cute.Layout,\n        sdPsum_layout: cute.Layout,\n        sdO_layout: cute.ComposedLayout,\n        sdOt_layout: cute.ComposedLayout,\n        sdSt_layout: cute.ComposedLayout,\n        sdS_layout: cute.ComposedLayout,\n        sdS_xchg_layout: cute.Layout,\n        sdQaccum_layout: cute.Layout,\n        sdK_layout: cute.ComposedLayout | cute.Layout,\n        sdV_layout: cute.ComposedLayout | cute.Layout,\n        tP_layout: cute.ComposedLayout,\n        tdS_layout: cute.ComposedLayout,\n        tiled_mma_S: cute.TiledMma,\n        tiled_mma_dP: cute.TiledMma,\n        tiled_mma_dV: cute.TiledMma,\n        tiled_mma_dK: cute.TiledMma,\n        tiled_mma_dQ: cute.TiledMma,\n        tiled_copy_r2s_dKV: cute.TiledCopy,\n        softmax_scale: cutlass.Float32,\n        softmax_scale_log2: cutlass.Float32,\n        window_size_left: Optional[Int32],\n        window_size_right: Optional[Int32],\n        tile_sched_params: ParamsBase,\n        aux_tensors: Optional[list] = None,\n        fastdiv_mods=(None, None),\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n    ):\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())\n        bidx, _, _ = cute.arch.block_idx()\n        mma_tile_coord_v = bidx % self.cta_group_size\n        is_leader_cta = mma_tile_coord_v == 0\n        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())\n\n        # Prefetch tma descriptor\n        if warp_idx == self.load_warp_id:\n            with cute.arch.elect_one():\n                cpasync.prefetch_descriptor(tma_atom_Q)\n                if const_expr(tma_atom_Qt is not None):\n                    cpasync.prefetch_descriptor(tma_atom_Qt)\n                cpasync.prefetch_descriptor(tma_atom_K)\n                if const_expr(tma_atom_Kt is not None):\n                    cpasync.prefetch_descriptor(tma_atom_Kt)\n                cpasync.prefetch_descriptor(tma_atom_V)\n                if 
const_expr(tma_atom_dOt is not None):\n                    cpasync.prefetch_descriptor(tma_atom_dOt)\n                cpasync.prefetch_descriptor(tma_atom_dO)\n                if const_expr(tma_atom_dV is not None):\n                    cpasync.prefetch_descriptor(tma_atom_dV)\n                if const_expr(tma_atom_dK is not None):\n                    cpasync.prefetch_descriptor(tma_atom_dK)\n\n        cluster_layout_vmnk = cute.tiled_divide(\n            cute.make_layout(self.cluster_shape_mnk),\n            (tiled_mma_S.thr_id.shape,),\n        )\n\n        # Alloc\n        smem = cutlass.utils.SmemAllocator()\n        storage = smem.allocate(self.shared_storage)\n\n        dQ_cluster_full_mbar_ptr = storage.dQ_cluster_full_mbar_ptr.data_ptr()\n        dQ_cluster_empty_mbar_ptr = storage.dQ_cluster_empty_mbar_ptr.data_ptr()\n\n        if const_expr(self.use_2cta_instrs):\n            dS_cluster_full_mbar_ptr = storage.dS_cluster_full_mbar_ptr\n            dS_cluster_empty_mbar_ptr = storage.dS_cluster_empty_mbar_ptr\n            dS_cluster_leader_mbar_ptr = storage.dS_cluster_leader_mbar_ptr\n            dQaccum_empty_mbar_ptr = storage.dQaccum_empty_mbar_ptr\n        else:\n            dS_cluster_full_mbar_ptr = None\n            dS_cluster_empty_mbar_ptr = None\n            dS_cluster_leader_mbar_ptr = None\n            dQaccum_empty_mbar_ptr = None\n\n        # Barrier initialization\n        if const_expr(self.use_2cta_instrs):\n            if const_expr(self.tile_hdim == 192):\n                if warp_idx == 2:\n                    cute.arch.mbarrier_init(\n                        dQaccum_empty_mbar_ptr,\n                        len(self.reduce_warp_ids),\n                    )\n            if warp_idx == 4:\n                cute.arch.mbarrier_init(dS_cluster_full_mbar_ptr, 1)\n                cute.arch.mbarrier_init(dS_cluster_empty_mbar_ptr, 1)\n                cute.arch.mbarrier_init(dS_cluster_leader_mbar_ptr, 2)\n\n        if 
const_expr(self.cluster_reduce_dQ):\n            if warp_idx == 4:\n                for i in range(self.dQaccum_reduce_stage // 2):\n                    cute.arch.mbarrier_init(dQ_cluster_full_mbar_ptr + i, 1)\n                    cute.arch.mbarrier_init(dQ_cluster_empty_mbar_ptr + i, 1)\n\n        tmem_alloc_barrier = cutlass.pipeline.NamedBarrier(\n            barrier_id=int(NamedBarrierBwdSm100.TmemPtr),\n            num_threads=cute.arch.WARP_SIZE\n            * len((self.mma_warp_id, *self.compute_warp_ids, *self.reduce_warp_ids)),\n        )\n        tmem = cutlass.utils.TmemAllocator(\n            storage.tmem_holding_buf,\n            barrier_for_retrieve=tmem_alloc_barrier,\n            allocator_warp_id=self.mma_warp_id,\n            is_two_cta=self.use_2cta_instrs,\n            two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,\n        )\n\n        # UMMA producers and AsyncThread consumers\n        pipeline_producer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup(\n            cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])\n        )\n        pipeline_consumer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup(\n            cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) * self.cta_group_size\n        )\n        pipeline_S_P = cutlass.pipeline.PipelineUmmaAsync.create(\n            num_stages=1,\n            producer_group=pipeline_producer_group_MMA_AsyncThread,\n            consumer_group=pipeline_consumer_group_MMA_AsyncThread,\n            barrier_storage=storage.S_mbar_ptr.data_ptr(),\n            cta_layout_vmnk=cluster_layout_vmnk,\n        )\n        pipeline_dP = cutlass.pipeline.PipelineUmmaAsync.create(\n            num_stages=1,\n            producer_group=pipeline_producer_group_MMA_AsyncThread,\n            consumer_group=pipeline_consumer_group_MMA_AsyncThread,\n            barrier_storage=storage.dP_mbar_ptr.data_ptr(),\n            cta_layout_vmnk=cluster_layout_vmnk,\n        )\n        
pipeline_dKV = cutlass.pipeline.PipelineUmmaAsync.create(\n            num_stages=2,\n            producer_group=pipeline_producer_group_MMA_AsyncThread,\n            consumer_group=pipeline_consumer_group_MMA_AsyncThread,\n            barrier_storage=storage.dKV_mbar_ptr.data_ptr(),\n            cta_layout_vmnk=cluster_layout_vmnk,\n        )\n        pipeline_consumer_group_MMA_AsyncThread_dQ = cutlass.pipeline.CooperativeGroup(\n            cutlass.pipeline.Agent.Thread,\n            len(self.reduce_warp_ids) * self.cta_group_size,\n        )  # Reduce\n        pipeline_dQ = cutlass.pipeline.PipelineUmmaAsync.create(\n            num_stages=1,\n            producer_group=pipeline_producer_group_MMA_AsyncThread,\n            consumer_group=pipeline_consumer_group_MMA_AsyncThread_dQ,\n            barrier_storage=storage.dQ_mbar_ptr.data_ptr(),\n            cta_layout_vmnk=cluster_layout_vmnk,\n        )\n\n        # AsyncThread producers and UMMA consumers\n        # Only 1 thread per warp will signal\n        pipeline_PdS_producer_group = cutlass.pipeline.CooperativeGroup(\n            cutlass.pipeline.Agent.Thread,\n            len(self.compute_warp_ids) * self.cta_group_size,\n        )  # Compute\n        pipeline_PdS_consumer_group = cutlass.pipeline.CooperativeGroup(\n            cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])\n        )  # MMA\n        pipeline_dS = cutlass.pipeline.PipelineAsyncUmma.create(\n            num_stages=1,\n            producer_group=pipeline_PdS_producer_group,\n            consumer_group=pipeline_PdS_consumer_group,\n            barrier_storage=storage.dS_mbar_ptr.data_ptr(),\n            cta_layout_vmnk=cluster_layout_vmnk,\n        )\n\n        # TMA producer and UMMA consumers\n        pipeline_producer_group = cutlass.pipeline.CooperativeGroup(\n            cutlass.pipeline.Agent.Thread, len([self.load_warp_id])\n        )\n        # The arrive count equals the multicast size\n        pipeline_consumer_group = 
cutlass.pipeline.CooperativeGroup(\n            cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) * self.num_mcast_ctas_b\n        )\n        pipeline_consumer_group_compute = cutlass.pipeline.CooperativeGroup(\n            cutlass.pipeline.Agent.Thread,\n            len(self.compute_warp_ids) * 1,\n        )\n        pipeline_LSE = cutlass.pipeline.PipelineTmaAsync.create(\n            barrier_storage=storage.LSE_mbar_ptr.data_ptr(),\n            num_stages=self.Q_stage,\n            producer_group=pipeline_producer_group,\n            consumer_group=pipeline_consumer_group_compute,\n            tx_count=self.tma_copy_bytes[\"LSE\"],\n            # cta_layout_vmnk=cluster_layout_vmnk,\n            defer_sync=True,\n        )\n        pipeline_dPsum = cutlass.pipeline.PipelineTmaAsync.create(\n            barrier_storage=storage.dPsum_mbar_ptr.data_ptr(),\n            num_stages=self.dO_stage,\n            producer_group=pipeline_producer_group,\n            consumer_group=pipeline_consumer_group_compute,\n            tx_count=self.tma_copy_bytes[\"dPsum\"],\n            # cta_layout_vmnk=cluster_layout_vmnk,\n            defer_sync=True,\n        )\n        pipeline_Q = pipeline.PipelineTmaUmma.create(\n            barrier_storage=storage.Q_mbar_ptr.data_ptr(),\n            num_stages=self.Q_stage,\n            producer_group=pipeline_producer_group,\n            consumer_group=pipeline_consumer_group,\n            tx_count=self.tma_copy_bytes[\"Q\"],\n            cta_layout_vmnk=cluster_layout_vmnk,\n            defer_sync=True,\n        )\n\n        if const_expr(self.use_2cta_instrs):\n            if const_expr(self.tile_hdim == 192):\n                pipeline_Qt = pipeline_Q\n            else:\n                pipeline_Qt = pipeline.PipelineTmaUmma.create(\n                    barrier_storage=storage.Qt_mbar_ptr.data_ptr(),\n                    num_stages=self.Q_stage,\n                    producer_group=pipeline_producer_group,\n                    
consumer_group=pipeline_consumer_group,\n                    tx_count=self.tma_copy_bytes[\"Q\"],\n                    cta_layout_vmnk=cluster_layout_vmnk,\n                    defer_sync=True,\n                )\n            pipeline_Kt = pipeline.PipelineTmaUmma.create(\n                barrier_storage=storage.Kt_mbar_ptr.data_ptr(),\n                num_stages=self.single_stage,\n                producer_group=pipeline_producer_group,\n                consumer_group=pipeline_consumer_group,\n                tx_count=self.tma_copy_bytes[\"K\"],\n                cta_layout_vmnk=cluster_layout_vmnk,\n                defer_sync=True,\n            )\n        else:\n            pipeline_Qt = pipeline_Kt = pipeline_Q\n\n        pipeline_dO = pipeline.PipelineTmaUmma.create(\n            barrier_storage=storage.dO_mbar_ptr.data_ptr(),\n            num_stages=self.dO_stage,\n            producer_group=pipeline_producer_group,\n            consumer_group=pipeline_consumer_group,\n            tx_count=self.tma_copy_bytes[\"dO\"],\n            cta_layout_vmnk=cluster_layout_vmnk,\n            defer_sync=False,\n        )\n\n        sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner, dtype=self.q_dtype)\n        if const_expr(self.use_2cta_instrs and self.tile_hdim <= 128):\n            sQt = storage.sQt.get_tensor(\n                sQt_layout.outer, swizzle=sQt_layout.inner, dtype=self.q_dtype\n            )\n        else:\n            sQt = cute.make_tensor(\n                cute.recast_ptr(sQ.iterator, sQt_layout.inner, dtype=self.q_dtype), sQt_layout.outer\n            )\n        sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)\n        if const_expr(self.use_2cta_instrs):\n            sKt = storage.sKt.get_tensor(sKt_layout.outer, swizzle=sKt_layout.inner)\n        else:\n            sKt = cute.make_tensor(cute.recast_ptr(sK.iterator, sKt_layout.inner), sKt_layout.outer)\n        sV = storage.sV.get_tensor(sV_layout.outer, 
swizzle=sV_layout.inner)\n        sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner)\n        sdS = cute.make_tensor(cute.recast_ptr(sdSt.iterator, sdS_layout.inner), sdS_layout.outer)\n        if const_expr(self.use_2cta_instrs):\n            if const_expr(self.tile_hdim <= 128):\n                sdS_xchg = storage.sdS_xchg.get_tensor(sdS_xchg_layout)\n            else:\n                sdS_xchg = storage.sdQaccum.get_tensor(sdS_xchg_layout, dtype=self.ds_dtype)\n        else:\n            sdS_xchg = None\n\n        sdO = storage.sdO.get_tensor(\n            sdO_layout.outer, swizzle=sdO_layout.inner, dtype=self.do_dtype\n        )\n        if const_expr(self.use_2cta_instrs and self.tile_hdim <= 128):\n            sdOt = storage.sdOt.get_tensor(\n                sdOt_layout.outer, swizzle=sdOt_layout.inner, dtype=self.do_dtype\n            )\n        else:\n            sdOt = cute.make_tensor(\n                cute.recast_ptr(sdO.iterator, sdOt_layout.inner, dtype=self.do_dtype),\n                sdOt_layout.outer,\n            )\n\n        sLSE = storage.sLSE.get_tensor(sLSE_layout)\n        sdPsum = storage.sdPsum.get_tensor(sdPsum_layout)\n        if const_expr(self.use_2cta_instrs):\n            if const_expr(not self.dKV_postprocess):\n                sdV = storage.sV.get_tensor(\n                    sdV_layout.outer, swizzle=sdV_layout.inner, dtype=self.dv_dtype\n                )\n                sdK = storage.sK.get_tensor(\n                    sdK_layout.outer, swizzle=sdK_layout.inner, dtype=self.dk_dtype\n                )\n            else:\n                sdV = storage.sV.get_tensor(sdV_layout, dtype=self.dv_dtype)\n                sdK = storage.sK.get_tensor(sdK_layout, dtype=self.dk_dtype)\n        elif const_expr(not self.dKV_postprocess):\n            sdV = storage.sdO.get_tensor(\n                sdV_layout.outer, swizzle=sdV_layout.inner, dtype=self.dv_dtype\n            )\n            sdK = storage.sQ.get_tensor(\n    
            sdK_layout.outer, swizzle=sdK_layout.inner, dtype=self.dk_dtype\n            )\n        else:\n            sdV = storage.sdO.get_tensor(sdV_layout, dtype=self.dv_dtype)\n            sdK = storage.sQ.get_tensor(sdK_layout, dtype=self.dk_dtype)\n\n        # Buffer sizing is guaranteed by max(...) in SharedStorage declarations\n        # for both sQ (reused as sdK) and sdO (reused as sdV)\n        sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout)\n\n        # TMEM\n        # This is a fake tensor: strictly speaking we should retrieve tmem_ptr, but we always\n        # request 512 columns of tmem, so we know that it starts at 0.\n        tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16)\n        # S\n        thr_mma_S = tiled_mma_S.get_slice(mma_tile_coord_v)\n        Sacc_shape = thr_mma_S.partition_shape_C(self.mma_tiler_kq[:2])  # (M, N)\n        tStS = thr_mma_S.make_fragment_C(Sacc_shape)\n        # (MMA, MMA_M, MMA_N)\n        tStS = cute.make_tensor(tmem_ptr + self.tmem_S_offset, tStS.layout)\n        # dP\n        thr_mma_dP = tiled_mma_dP.get_slice(mma_tile_coord_v)\n        dPacc_shape = thr_mma_dP.partition_shape_C(self.mma_tiler_vdo[:2])\n        tdPtdP = thr_mma_dP.make_fragment_C(dPacc_shape)\n        tdPtdP = cute.make_tensor(tmem_ptr + self.tmem_dP_offset, tdPtdP.layout)\n        # dV\n        thr_mma_dV = tiled_mma_dV.get_slice(mma_tile_coord_v)\n        dvacc_shape = thr_mma_dV.partition_shape_C(self.mma_tiler_pdo[:2])\n        tdVtdV = thr_mma_dV.make_fragment_C(dvacc_shape)\n        tdVtdV = cute.make_tensor(tmem_ptr + self.tmem_dV_offset, tdVtdV.layout)\n        tP = cute.make_tensor(\n            cute.recast_ptr(tmem_ptr + self.tmem_P_offset, dtype=self.do_dtype), tP_layout.outer\n        )\n        # dK\n        thr_mma_dK = tiled_mma_dK.get_slice(mma_tile_coord_v)\n        dkacc_shape = thr_mma_dK.partition_shape_C(self.mma_tiler_dsq[:2])\n        tdKtdK = 
thr_mma_dK.make_fragment_C(dkacc_shape)\n        tdKtdK = cute.make_tensor(tmem_ptr + self.tmem_dK_offset, tdKtdK.layout)\n        tdS = cute.make_tensor(\n            cute.recast_ptr(tmem_ptr + self.tmem_dS_offset, dtype=self.ds_dtype), tdS_layout.outer\n        )\n        # dQ\n        thr_mma_dQ = tiled_mma_dQ.get_slice(mma_tile_coord_v)\n        dQacc_shape = thr_mma_dQ.partition_shape_C(self.mma_tiler_dsk[:2])\n        tdQtdQ = thr_mma_dQ.make_fragment_C(dQacc_shape)\n        tdQtdQ = cute.make_tensor(tmem_ptr + self.tmem_dQ_offset, tdQtdQ.layout)\n\n        block_info = BlockInfo(\n            self.tile_m,\n            # self.tile_n,\n            self.tile_n * self.cluster_shape_mnk[0],  # careful, this case is not very well-tested\n            self.is_causal,\n            self.is_local,\n            False,  # is_split_kv\n            window_size_left,\n            window_size_right,\n            qhead_per_kvhead_packgqa=1,\n        )\n        SeqlenInfoCls = partial(\n            SeqlenInfoQK.create,\n            seqlen_q_static=mQ.shape[0],\n            seqlen_k_static=mK.shape[0],\n            mCuSeqlensQ=mCuSeqlensQ,\n            mCuSeqlensK=mCuSeqlensK,\n            mSeqUsedQ=mSeqUsedQ,\n            mSeqUsedK=mSeqUsedK,\n            tile_m=self.tile_m,\n            tile_n=self.tile_n * self.cluster_shape_mnk[0],\n        )\n        TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)\n\n        AttentionMaskCls = partial(\n            AttentionMask,\n            self.tile_m,\n            self.tile_n * self.cta_group_size,\n            swap_AB=True,\n            window_size_left=window_size_left,\n            window_size_right=window_size_right,\n        )\n        #  EMPTY\n        # (15)\n        if warp_idx == self.empty_warp_id:\n            cute.arch.setmaxregister_decrease(self.num_regs_empty)\n\n        #  RELAY\n        # (14)\n        if warp_idx == self.relay_warp_id:\n            cute.arch.setmaxregister_decrease(\n     
           self.num_regs_mma if self.use_2cta_instrs else self.num_regs_empty\n            )\n            if const_expr(self.use_2cta_instrs):\n                self.relay(\n                    dS_cluster_full_mbar_ptr,\n                    dS_cluster_empty_mbar_ptr,\n                    dS_cluster_leader_mbar_ptr,\n                    cluster_layout_vmnk,\n                    block_info,\n                    SeqlenInfoCls,\n                    TileSchedulerCls,\n                )\n\n        #  LOAD\n        # (13)\n        if warp_idx == self.load_warp_id:\n            cute.arch.setmaxregister_decrease(self.num_regs_load)\n            self.load(\n                thr_mma_S,\n                thr_mma_dP,\n                thr_mma_dV,\n                thr_mma_dK,\n                thr_mma_dQ,\n                mQ,\n                mK,\n                mKt,\n                mV,\n                mdO,\n                mQt,\n                mdOt,\n                mLSE,\n                mdPsum,\n                sQ,\n                sK,\n                sKt,\n                sV,\n                sdO,\n                sQt,\n                sdOt,\n                sLSE,\n                sdPsum,\n                tma_atom_Q,\n                tma_atom_K,\n                tma_atom_Kt,\n                tma_atom_V,\n                tma_atom_dO,\n                tma_atom_Qt,\n                tma_atom_dOt,\n                pipeline_Q,\n                pipeline_Qt,\n                pipeline_Kt,\n                pipeline_dO,\n                pipeline_LSE,\n                pipeline_dPsum,\n                cluster_layout_vmnk,\n                block_info,\n                SeqlenInfoCls,\n                TileSchedulerCls,\n                blocksparse_tensors,\n                should_load_Q=True,\n                should_load_dO=True,\n            )\n\n        #  MMA\n        # (12)\n        if warp_idx == self.mma_warp_id:\n            cute.arch.setmaxregister_decrease(self.num_regs_mma)\n\n    
        # Alloc tmem buffer\n            tmem.allocate(self.tmem_alloc_cols)\n            tmem.wait_for_alloc()\n            tmem_ptr = tmem.retrieve_ptr(Float32)\n\n            self.mma(\n                tiled_mma_S,\n                tiled_mma_dP,\n                tiled_mma_dV,\n                tiled_mma_dK,\n                tiled_mma_dQ,\n                sQ,\n                sQt,\n                sK,\n                sKt,\n                sV,\n                sdO,\n                sdOt,\n                tP,\n                sdSt,\n                sdS,\n                tdS,\n                tStS,\n                tdPtdP,\n                tdVtdV,\n                tdKtdK,\n                tdQtdQ,\n                dS_cluster_full_mbar_ptr,\n                dS_cluster_empty_mbar_ptr,\n                dS_cluster_leader_mbar_ptr,\n                pipeline_Q,\n                pipeline_Qt,\n                pipeline_Kt,\n                pipeline_dO,\n                pipeline_S_P,\n                pipeline_dS,\n                pipeline_dKV,\n                pipeline_dP,\n                pipeline_dQ,\n                block_info,\n                SeqlenInfoCls,\n                TileSchedulerCls,\n                is_leader_cta,\n                blocksparse_tensors,\n            )\n            # Dealloc the tensor memory buffer\n            tmem.relinquish_alloc_permit()\n            tmem_alloc_barrier.arrive_and_wait()\n            tmem.free(tmem_ptr)\n\n        # Compute\n        # (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps\n        if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]:\n            cute.arch.setmaxregister_increase(self.num_regs_compute)  # 8 warps\n            tmem.wait_for_alloc()\n            tmem_ptr = tmem.retrieve_ptr(Float32)\n            self.compute_loop(\n                thr_mma_S,\n                thr_mma_dP,\n                thr_mma_dV,\n                thr_mma_dK,\n                tStS,\n                tdPtdP,\n          
      tdVtdV,\n                tdKtdK,\n                sLSE,\n                sdPsum,\n                mdV,\n                mdK,\n                sdS,\n                sdS_xchg,\n                pipeline_LSE,\n                pipeline_dPsum,\n                pipeline_S_P,\n                pipeline_dS,\n                pipeline_dKV,\n                pipeline_dP,\n                dS_cluster_empty_mbar_ptr,\n                dS_cluster_full_mbar_ptr,\n                dQaccum_empty_mbar_ptr,\n                softmax_scale,\n                softmax_scale_log2,\n                block_info,\n                SeqlenInfoCls,\n                AttentionMaskCls,\n                TileSchedulerCls,\n                sdV,\n                sdK,\n                mdV_tma_tensor,\n                mdK_tma_tensor,\n                tma_atom_dV,\n                tma_atom_dK,\n                tiled_copy_r2s_dKV,\n                mdK_semaphore,\n                mdV_semaphore,\n                aux_tensors,\n                fastdiv_mods,\n                blocksparse_tensors,\n            )\n            tmem_alloc_barrier.arrive()\n\n        # Reduce\n        # (0, 1, 2, 3) - dQ\n        if warp_idx >= self.reduce_warp_ids[0] and warp_idx <= self.reduce_warp_ids[-1]:\n            cute.arch.setmaxregister_increase(self.num_regs_reduce)\n            tmem.wait_for_alloc()\n            tmem_ptr = tmem.retrieve_ptr(Float32)\n            self.dQacc_reduce(\n                mdQaccum,\n                sdQaccum,\n                thr_mma_dQ,\n                tdQtdQ,\n                pipeline_dQ,\n                dQaccum_empty_mbar_ptr,\n                block_info,\n                SeqlenInfoCls,\n                TileSchedulerCls,\n                mdQ_semaphore,\n                blocksparse_tensors,\n            )\n            tmem_alloc_barrier.arrive()\n\n        return\n\n    @cute.jit\n    def relay(\n        self,\n        dS_cluster_full_mbar_ptr: cute.Pointer,\n        dS_cluster_empty_mbar_ptr: 
cute.Pointer,\n        dS_cluster_leader_mbar_ptr: cute.Pointer,\n        cluster_layout_vmnk: cute.Layout,\n        block_info: BlockInfo,\n        SeqlenInfoCls: Callable,\n        TileSchedulerCls: Callable,\n    ):\n        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())\n        dS_cluster_phase = Int32(0)\n\n        tile_scheduler = TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        while work_tile.is_valid_tile:\n            n_block, head_idx, batch_idx, _ = work_tile.tile_idx\n            seqlen = SeqlenInfoCls(batch_idx)\n            m_block_min, m_block_max = block_info.get_m_block_min_max(\n                seqlen, n_block // self.cluster_shape_mnk[0]\n            )\n            head_idx_kv = head_idx // self.qhead_per_kvhead\n\n            process_tile = (\n                const_expr(not self.is_local and not self.is_varlen_q) or m_block_min < m_block_max\n            )\n\n            if process_tile:\n                num_iters = m_block_max - m_block_min\n                for _ in cutlass.range(num_iters, unroll=1):\n                    # Wait for dS_xchg from peer CTA\n                    cute.arch.mbarrier_wait(dS_cluster_full_mbar_ptr, phase=dS_cluster_phase)\n\n                    # Arrive on MMA leader warp\n                    with cute.arch.elect_one():\n                        cute.arch.mbarrier_arrive(dS_cluster_leader_mbar_ptr, Int32(0))\n\n                    dS_cluster_phase ^= 1\n\n            tile_scheduler.prefetch_next_work()\n            tile_scheduler.advance_to_next_work()\n            work_tile = tile_scheduler.get_current_work()\n\n    @cute.jit\n    def load(\n        self,\n        thr_mma_S: cute.core.ThrMma,\n        thr_mma_dP: cute.core.ThrMma,\n        thr_mma_dV: cute.core.ThrMma,\n        thr_mma_dK: cute.core.ThrMma,\n        thr_mma_dQ: cute.core.ThrMma,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mKt: Optional[cute.Tensor],\n        
mV: cute.Tensor,\n        mdO: cute.Tensor,\n        mQt: Optional[cute.Tensor],\n        mdOt: Optional[cute.Tensor],\n        mLSE: cute.Tensor,\n        mdPsum: cute.Tensor,\n        sQ: cute.Tensor,\n        sK: cute.Tensor,\n        sKt: cute.Tensor,\n        sV: cute.Tensor,\n        sdO: cute.Tensor,\n        sQt: cute.Tensor,\n        sdOt: cute.Tensor,\n        sLSE: cute.Tensor,\n        sdPsum: cute.Tensor,\n        tma_atom_Q: cute.CopyAtom,\n        tma_atom_K: cute.CopyAtom,\n        tma_atom_Kt: Optional[cute.CopyAtom],\n        tma_atom_V: cute.CopyAtom,\n        tma_atom_dO: cute.CopyAtom,\n        tma_atom_Qt: Optional[cute.CopyAtom],\n        tma_atom_dOt: Optional[cute.CopyAtom],  # 2-CTA only\n        pipeline_Q: PipelineAsync,\n        pipeline_Qt: PipelineAsync,\n        pipeline_Kt: PipelineAsync,\n        pipeline_dO: PipelineAsync,\n        pipeline_LSE: PipelineAsync,\n        pipeline_dPsum: PipelineAsync,\n        cluster_layout_vmnk: cute.Layout,\n        block_info: BlockInfo,\n        SeqlenInfoCls: Callable,\n        TileSchedulerCls: Callable,\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n        should_load_Q: bool = True,\n        should_load_dO: bool = True,\n    ):\n        producer_state_Q_LSE = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Producer, self.Q_stage\n        )\n        producer_state_Qt = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Producer, self.Q_stage\n        )\n        producer_state_Kt = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Producer, self.single_stage\n        )\n        producer_state_dO_dPsum = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Producer, self.dO_stage\n        )\n        producer_state_Q_Qt = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Producer, self.Q_stage\n        
)\n        producer_state_O_Ot = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Producer, self.dO_stage\n        )\n        producer_state_LSE = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Producer, self.Q_stage\n        )\n        producer_state_dPsum = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Producer, self.dO_stage\n        )\n\n        # Compute multicast mask for Q & dO buffer full\n        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())\n        block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)\n        q_do_mcast_mask = None\n        if const_expr(self.is_q_do_mcast):\n            q_do_mcast_mask = cpasync.create_tma_multicast_mask(\n                cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1\n            )\n\n        tile_scheduler = TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        while work_tile.is_valid_tile:\n            n_block, head_idx, batch_idx, _ = work_tile.tile_idx\n            seqlen = SeqlenInfoCls(batch_idx)\n            m_block_min, m_block_max = block_info.get_m_block_min_max(\n                seqlen, n_block // self.cluster_shape_mnk[0]\n            )\n            head_idx_kv = head_idx // self.qhead_per_kvhead\n            n_block_cta_group = n_block // self.cta_group_size\n\n            # GMEM tensors (varlen-aware)\n            mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]\n            mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]\n            mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]\n            if const_expr(not seqlen.has_cu_seqlens_q):\n                mdO_cur = mdO[None, None, head_idx, batch_idx]\n            else:\n                mdO_cur = cute.domain_offset((0, seqlen.offset_q), 
mdO[None, None, head_idx])\n            mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[None, head_idx]\n            mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[\n                None, head_idx\n            ]\n\n            if const_expr(self.use_2cta_instrs):\n                if const_expr(not seqlen.has_cu_seqlens_q):\n                    mQt_cur = mQt[None, None, head_idx, batch_idx]\n                    mdOt_cur = mdOt[None, None, head_idx, batch_idx]\n                else:\n                    mQt_cur = cute.domain_offset((0, seqlen.offset_q, 0), mQt)[None, None, head_idx]\n                    mdOt_cur = cute.domain_offset((seqlen.offset_q, 0, 0), mdOt)[\n                        None, None, head_idx\n                    ]\n                if const_expr(not seqlen.has_cu_seqlens_k):\n                    mKt_cur = mKt[None, None, head_idx_kv, batch_idx]\n                else:\n                    mKt_cur = cute.domain_offset((0, seqlen.offset_k, 0), mKt)[\n                        None, None, head_idx_kv\n                    ]\n\n            # (1) S.T = K @ Q.T\n            gK = cute.local_tile(\n                mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block_cta_group, 0)\n            )\n            tSgK = thr_mma_S.partition_A(gK)\n\n            gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0))\n            tSgQ = thr_mma_S.partition_B(gQ)\n            gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))\n            gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))\n            gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None))\n            tdPgdO = thr_mma_dV.partition_B(gdO)\n\n            a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)\n            load_K, _, _ = copy_utils.tma_get_copy_fn(\n                tma_atom_K,\n                
block_in_cluster_coord_vmnk[2],\n                a_cta_layout,\n                tSgK,\n                sK,\n                single_stage=True,\n            )\n\n            b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)\n            load_Q, _, _ = copy_utils.tma_get_copy_fn(\n                tma_atom_Q,\n                cta_coord=block_in_cluster_coord_vmnk[1],\n                cta_layout=b_cta_layout,\n                src_tensor=tSgQ,\n                dst_tensor=sQ,\n                mcast_mask=q_do_mcast_mask,\n            )\n            load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q)\n\n            # (2) dP = V @ dO.T\n            gV = cute.local_tile(\n                mV_cur, cute.select(self.mma_tiler_vdo, mode=[0, 2]), (n_block_cta_group, 0)\n            )\n            tdPgV = thr_mma_dP.partition_A(gV)\n\n            load_V, _, _ = copy_utils.tma_get_copy_fn(\n                tma_atom_V,\n                0,\n                cute.make_layout(1),\n                tdPgV,\n                sV,\n                single_stage=True,\n            )\n\n            if const_expr(tma_atom_dOt is not None):\n                gdOt = cute.local_tile(\n                    mdOt_cur, cute.select(self.mma_tiler_vdo, mode=[1, 2]), (None, 0)\n                )\n                tdPgdO = thr_mma_dP.partition_B(gdOt)\n                load_dOt, _, _ = copy_utils.tma_get_copy_fn(\n                    tma_atom_dOt,\n                    cta_coord=block_in_cluster_coord_vmnk[1],\n                    cta_layout=b_cta_layout,\n                    src_tensor=tdPgdO,\n                    dst_tensor=sdOt,\n                    mcast_mask=q_do_mcast_mask,\n                )\n                load_dOt = copy_utils.tma_producer_copy_fn(load_dOt, pipeline_dO)\n\n            # (3) dV += P.T @ dO\n            gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None))\n            tdVgdO = thr_mma_dV.partition_B(gdO)\n   
         load_dO, _, _ = copy_utils.tma_get_copy_fn(\n                tma_atom_dO,\n                cta_coord=block_in_cluster_coord_vmnk[1],\n                cta_layout=b_cta_layout,\n                src_tensor=tdVgdO,\n                dst_tensor=sdO,\n                mcast_mask=q_do_mcast_mask,\n            )\n            load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO)\n\n            # (4) dK += dS.T @ Q (2-CTA: needs separate Qt load)\n            if const_expr(tma_atom_Qt is not None):\n                gQt = cute.local_tile(\n                    mQt_cur, cute.select(self.mma_tiler_dsq, mode=[1, 2]), (0, None)\n                )\n                tdKgQt = thr_mma_dK.partition_B(gQt)\n                load_Qt, _, _ = copy_utils.tma_get_copy_fn(\n                    tma_atom_Qt,\n                    cta_coord=block_in_cluster_coord_vmnk[1],\n                    cta_layout=b_cta_layout,\n                    src_tensor=tdKgQt,\n                    dst_tensor=sQt,\n                    mcast_mask=q_do_mcast_mask,\n                )\n                load_Qt = copy_utils.tma_producer_copy_fn(load_Qt, pipeline_Qt)\n\n            # (5) dQ = dS @ K\n            if const_expr(self.use_2cta_instrs):\n                gKt = cute.local_tile(\n                    mKt_cur, cute.select(self.mma_tiler_dsk, mode=[1, 2]), (0, n_block_cta_group)\n                )\n                tdQgK = thr_mma_dQ.partition_B(gKt)\n\n                load_Kt, _, _ = copy_utils.tma_get_copy_fn(\n                    tma_atom_Kt,\n                    block_in_cluster_coord_vmnk[1],\n                    b_cta_layout,\n                    tdQgK,\n                    sKt,\n                    single_stage=True,\n                )\n\n            copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32)\n            copy_stats = partial(cute.copy, copy_atom_stats)\n            # copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SMulticastOp(), Float32)\n            # sLSE = 
cute.logical_divide(sLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]\n            # gLSE = cute.logical_divide(gLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]\n            # sdPsum = cute.logical_divide(sdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]\n            # gdPsum = cute.logical_divide(gdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]\n            # copy_stats = partial(cute.copy, copy_atom_stats, mcast_mask=q_do_mcast_mask)\n\n            # some tiles might be empty due to block sparsity\n            if const_expr(self.use_block_sparsity):\n                total_m_block_cnt = get_total_q_block_count_bwd(\n                    blocksparse_tensors,\n                    batch_idx,\n                    head_idx,\n                    n_block,\n                    subtile_factor=self.subtile_factor,\n                    m_block_max=m_block_max,\n                )\n                process_tile = total_m_block_cnt > Int32(0)\n            else:\n                process_tile = (\n                    const_expr(not self.is_local and not self.is_varlen_q)\n                    or m_block_min < m_block_max\n                )\n\n            if process_tile:\n                if const_expr(self.use_block_sparsity):\n                    producer_state_Q_LSE, producer_state_dO_dPsum = (\n                        produce_block_sparse_q_loads_bwd_sm100(\n                            blocksparse_tensors,\n                            batch_idx,\n                            head_idx,\n                            n_block,\n                            producer_state_Q_LSE,\n                            producer_state_dO_dPsum,\n                            pipeline_Q,\n                            pipeline_LSE,\n                            pipeline_dO,\n                            pipeline_dPsum,\n                            load_K,\n                            load_V,\n                            load_Q,\n                            
load_dO,\n                            copy_stats,\n                            gLSE,\n                            sLSE,\n                            gdPsum,\n                            sdPsum,\n                            self.tma_copy_bytes[\"K\"],\n                            self.tma_copy_bytes[\"V\"],\n                            should_load_Q=should_load_Q,\n                            should_load_dO=should_load_dO,\n                            subtile_factor=self.subtile_factor,\n                            m_block_max=m_block_max,\n                        )\n                    )\n                else:\n                    first_m_block = m_block_min\n                    if const_expr(self.use_2cta_instrs and self.tile_hdim == 192):\n                        #### Prologue ####\n                        assert should_load_Q and should_load_dO\n                        # K & Q (for S)\n                        pipeline_Q.producer_acquire(\n                            producer_state_Q_Qt,\n                            extra_tx_count=self.tma_copy_bytes[\"K\"],\n                        )\n                        load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_Qt))\n                        load_Q(first_m_block, producer_state=producer_state_Q_Qt)\n                        pipeline_Q.producer_commit(producer_state_Q_Qt)\n                        producer_state_Q_Qt.advance()\n                        # LSE\n                        pipeline_LSE.producer_acquire(producer_state_LSE)\n                        with cute.arch.elect_one():\n                            copy_stats(\n                                gLSE[None, first_m_block],\n                                sLSE[None, producer_state_LSE.index],\n                                mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_LSE),\n                            )\n                        producer_state_LSE.advance()\n\n                        # dOt + V, for dP.T = V @ dO.T\n               
         pipeline_dO.producer_acquire(\n                            producer_state_O_Ot,\n                            extra_tx_count=self.tma_copy_bytes[\"V\"],\n                        )\n                        load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_O_Ot))\n                        load_dOt(first_m_block, producer_state=producer_state_O_Ot)\n                        pipeline_dO.producer_commit(producer_state_O_Ot)\n                        producer_state_O_Ot.advance()\n                        # dPsum\n                        pipeline_dPsum.producer_acquire(producer_state_dPsum)\n                        with cute.arch.elect_one():\n                            copy_stats(\n                                gdPsum[None, first_m_block],\n                                sdPsum[None, producer_state_dPsum.index],\n                                mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dPsum),\n                            )\n                        producer_state_dPsum.advance()\n\n                        # Qt, for dK = dS.T @ Q\n                        pipeline_Qt.producer_acquire(\n                            producer_state_Q_Qt,\n                            extra_tx_count=self.tma_copy_bytes[\"K\"],\n                        )\n                        load_Qt(first_m_block, producer_state=producer_state_Q_Qt)\n                        load_Kt(tma_bar_ptr=pipeline_Qt.producer_get_barrier(producer_state_Q_Qt))\n                        pipeline_Qt.producer_commit(producer_state_Q_Qt)\n                        producer_state_Q_Qt.advance()\n\n                        # dO, for dV = P.T @ dO\n                        pipeline_dO.producer_acquire(producer_state_O_Ot)\n                        load_dO(first_m_block, producer_state=producer_state_O_Ot)\n                        pipeline_dO.producer_commit(producer_state_O_Ot)\n                        producer_state_O_Ot.advance()\n\n                        #### Mainloop ####\n              
          # 2CTA: [lse | Q | dOt | dPsum | Qt | dO]\n                        for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):\n                            # LSE\n                            pipeline_LSE.producer_acquire(producer_state_LSE)\n                            with cute.arch.elect_one():\n                                copy_stats(\n                                    gLSE[None, m_block],\n                                    sLSE[None, producer_state_LSE.index],\n                                    mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_LSE),\n                                )\n                            producer_state_LSE.advance()\n\n                            # Q\n                            pipeline_Q.producer_acquire(producer_state_Q_Qt)\n                            load_Q(m_block, producer_state=producer_state_Q_Qt)\n                            pipeline_Q.producer_commit(producer_state_Q_Qt)\n                            producer_state_Q_Qt.advance()\n\n                            # dPsum\n                            pipeline_dPsum.producer_acquire(producer_state_dPsum)\n                            with cute.arch.elect_one():\n                                copy_stats(\n                                    gdPsum[None, m_block],\n                                    sdPsum[None, producer_state_dPsum.index],\n                                    mbar_ptr=pipeline_dPsum.producer_get_barrier(\n                                        producer_state_dPsum\n                                    ),\n                                )\n                            producer_state_dPsum.advance()\n\n                            # dOt, for dP.T = V @ dO.T\n                            pipeline_dO.producer_acquire(producer_state_O_Ot)\n                            load_dOt(m_block, producer_state=producer_state_O_Ot)\n                            pipeline_dO.producer_commit(producer_state_O_Ot)\n                            
producer_state_O_Ot.advance()\n\n                            # Qt, for dK = dS.T @ Q\n                            pipeline_Qt.producer_acquire(producer_state_Q_Qt)\n                            load_Qt(m_block, producer_state=producer_state_Q_Qt)\n                            pipeline_Qt.producer_commit(producer_state_Q_Qt)\n                            producer_state_Q_Qt.advance()\n\n                            # dO, for dV = P.T @ dO\n                            pipeline_dO.producer_acquire(producer_state_O_Ot)\n                            load_dO(m_block, producer_state=producer_state_O_Ot)\n                            pipeline_dO.producer_commit(producer_state_O_Ot)\n                            producer_state_O_Ot.advance()\n\n                    else:\n                        #### Prologue ####\n                        if const_expr(should_load_Q):\n                            # K & Q (for S)\n                            pipeline_Q.producer_acquire(\n                                producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes[\"K\"]\n                            )\n                            load_K(\n                                tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)\n                            )\n                            load_Q(first_m_block, producer_state=producer_state_Q_LSE)\n                            pipeline_Q.producer_commit(producer_state_Q_LSE)\n\n                            # LSE\n                            pipeline_LSE.producer_acquire(producer_state_Q_LSE)\n                            with cute.arch.elect_one():\n                                copy_stats(\n                                    gLSE[None, first_m_block],\n                                    sLSE[None, producer_state_Q_LSE.index],\n                                    mbar_ptr=pipeline_LSE.producer_get_barrier(\n                                        producer_state_Q_LSE\n                                    ),\n                           
     )\n                            producer_state_Q_LSE.advance()\n\n                        if const_expr(should_load_dO):\n                            pipeline_dO.producer_acquire(\n                                producer_state_dO_dPsum,\n                                extra_tx_count=self.tma_copy_bytes[\"V\"] + self.tma_copy_bytes[\"dO\"]\n                                if const_expr(tma_atom_dOt is not None)\n                                else self.tma_copy_bytes[\"V\"],\n                            )\n                            load_V(\n                                tma_bar_ptr=pipeline_dO.producer_get_barrier(\n                                    producer_state_dO_dPsum\n                                )\n                            )\n                            load_dO(first_m_block, producer_state=producer_state_dO_dPsum)\n                            if const_expr(tma_atom_dOt is not None):\n                                load_dOt(first_m_block, producer_state=producer_state_dO_dPsum)\n                            pipeline_dO.producer_commit(producer_state_dO_dPsum)\n\n                            # dPsum\n                            pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)\n                            with cute.arch.elect_one():\n                                copy_stats(\n                                    gdPsum[None, first_m_block],\n                                    sdPsum[None, producer_state_dO_dPsum.index],\n                                    mbar_ptr=pipeline_dPsum.producer_get_barrier(\n                                        producer_state_dO_dPsum\n                                    ),\n                                )\n                            producer_state_dO_dPsum.advance()\n\n                        if const_expr(self.use_2cta_instrs):\n                            pipeline_Kt.producer_acquire(producer_state_Kt)\n                            
load_Kt(tma_bar_ptr=pipeline_Kt.producer_get_barrier(producer_state_Kt))\n                            pipeline_Kt.producer_commit(producer_state_Kt)\n                            producer_state_Kt.advance()\n                        #### Main Loop ####\n                        for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):\n                            if const_expr(should_load_Q):\n                                if const_expr(tma_atom_Qt is not None):\n                                    pipeline_Qt.producer_acquire(producer_state_Qt)\n                                    load_Qt(m_block - 1, producer_state=producer_state_Qt)\n                                    pipeline_Qt.producer_commit(producer_state_Qt)\n                                    producer_state_Qt.advance()\n\n                                # Q (for S)\n                                pipeline_Q.producer_acquire(producer_state_Q_LSE)\n                                load_Q(m_block, producer_state=producer_state_Q_LSE)\n                                pipeline_Q.producer_commit(producer_state_Q_LSE)\n\n                                # LSE\n                                pipeline_LSE.producer_acquire(producer_state_Q_LSE)\n                                with cute.arch.elect_one():\n                                    copy_stats(\n                                        gLSE[None, m_block],\n                                        sLSE[None, producer_state_Q_LSE.index],\n                                        mbar_ptr=pipeline_LSE.producer_get_barrier(\n                                            producer_state_Q_LSE\n                                        ),\n                                    )\n                                producer_state_Q_LSE.advance()\n\n                            if const_expr(should_load_dO):\n                                pipeline_dO.producer_acquire(\n                                    producer_state_dO_dPsum,\n                                 
   extra_tx_count=self.tma_copy_bytes[\"dO\"]\n                                    if const_expr(tma_atom_dOt is not None)\n                                    else 0,\n                                )\n                                load_dO(m_block, producer_state=producer_state_dO_dPsum)\n                                if const_expr(tma_atom_dOt is not None):\n                                    load_dOt(m_block, producer_state=producer_state_dO_dPsum)\n                                pipeline_dO.producer_commit(producer_state_dO_dPsum)\n\n                                # dPsum\n                                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)\n                                with cute.arch.elect_one():\n                                    copy_stats(\n                                        gdPsum[None, m_block],\n                                        sdPsum[None, producer_state_dO_dPsum.index],\n                                        mbar_ptr=pipeline_dPsum.producer_get_barrier(\n                                            producer_state_dO_dPsum\n                                        ),\n                                    )\n                                producer_state_dO_dPsum.advance()\n\n                        #### Tail ####\n                        if const_expr(should_load_Q):\n                            if const_expr(tma_atom_Qt is not None):\n                                pipeline_Qt.producer_acquire(producer_state_Qt)\n                                load_Qt(m_block_max - 1, producer_state=producer_state_Qt)\n                                pipeline_Qt.producer_commit(producer_state_Qt)\n                                producer_state_Qt.advance()\n\n                if const_expr(self.use_2cta_instrs and self.tile_hdim == 192):\n                    pipeline_Q.producer_tail(producer_state_Q_Qt)\n                    pipeline_LSE.producer_tail(producer_state_LSE)\n                    
pipeline_dO.producer_tail(producer_state_O_Ot)\n                    pipeline_dPsum.producer_tail(producer_state_dPsum)\n                else:\n                    if const_expr(should_load_Q):\n                        pipeline_Q.producer_tail(producer_state_Q_LSE.clone())\n                        pipeline_LSE.producer_tail(producer_state_Q_LSE)\n                        if const_expr(tma_atom_Qt is not None):\n                            pipeline_Qt.producer_tail(producer_state_Qt)\n                    if const_expr(should_load_dO):\n                        pipeline_dO.producer_tail(producer_state_dO_dPsum.clone())\n                        pipeline_dPsum.producer_tail(producer_state_dO_dPsum)\n\n            tile_scheduler.prefetch_next_work()\n            tile_scheduler.advance_to_next_work()\n            work_tile = tile_scheduler.get_current_work()\n\n    @cute.jit\n    def mma(\n        self,\n        tiled_mma_S: cute.TiledMma,\n        tiled_mma_dP: cute.TiledMma,\n        tiled_mma_dV: cute.TiledMma,\n        tiled_mma_dK: cute.TiledMma,\n        tiled_mma_dQ: cute.TiledMma,\n        sQ: cute.Tensor,\n        sQt: cute.Tensor,\n        sK: cute.Tensor,\n        sKt: cute.Tensor,\n        sV: cute.Tensor,\n        sdO: cute.Tensor,\n        sdOt: cute.Tensor,\n        tP: cute.Tensor,\n        sdSt: cute.Tensor,\n        sdS: cute.Tensor,\n        tdS: cute.Tensor,\n        tStS: cute.Tensor,\n        tdPtdP: cute.Tensor,\n        tdVtdV: cute.Tensor,\n        tdKtdK: cute.Tensor,\n        tdQtdQ: cute.Tensor,\n        dS_cluster_full_mbar_ptr: cute.Pointer,\n        dS_cluster_empty_mbar_ptr: cute.Pointer,\n        dS_cluster_leader_mbar_ptr: cute.Pointer,\n        pipeline_Q: PipelineAsync,\n        pipeline_Qt: PipelineAsync,\n        pipeline_Kt: PipelineAsync,\n        pipeline_dO: PipelineAsync,\n        pipeline_S_P: PipelineAsync,\n        pipeline_dS: PipelineAsync,\n        pipeline_dKV: PipelineAsync,\n        pipeline_dP: PipelineAsync,\n        
pipeline_dQ: PipelineAsync,\n        block_info: BlockInfo,\n        SeqlenInfoCls: Callable,\n        TileSchedulerCls: Callable,\n        is_leader_cta: cutlass.Boolean,\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n    ):\n        # [2025-10-21] For reasons I don't understand, putting this partitioning in the main\n        # kernel (before warp specialization) is a lot slower than putting it here.\n        # Partition smem / tmem tensors\n        # S = K @ Q.T\n        tSrK = tiled_mma_S.make_fragment_A(sK)\n        tSrQ = tiled_mma_S.make_fragment_B(sQ)\n        # dP = V @ dOt.T\n        tdPrV = tiled_mma_dP.make_fragment_A(sV)\n        tdPrdOt = tiled_mma_dP.make_fragment_B(sdOt)\n        # dK = dS.T @ Q\n        # For 2-CTA, dS (dK mma) MUST come from TMEM (cannot use SMEM)\n        if const_expr(self.use_smem_dS_for_mma_dK and not self.use_2cta_instrs):\n            tdKrdS = tiled_mma_dK.make_fragment_A(sdSt)  # From SMEM\n        else:\n            tdKrdS = tiled_mma_dK.make_fragment_A(tdS)  # From TMEM\n\n        tdKrQ = tiled_mma_dK.make_fragment_B(sQt)\n        # dQ = dS @ K\n        tdQrdS = tiled_mma_dQ.make_fragment_A(sdS)\n        tdQrK = tiled_mma_dQ.make_fragment_B(sKt)\n        # dV = P @ dO.T\n        tdVrdO = tiled_mma_dV.make_fragment_B(sdO)\n        tdVrP = tiled_mma_dV.make_fragment_A(tP)\n\n        # mma_qk_fn = partial(gemm_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, zero_init=True)\n        mma_qk_fn = partial(\n            gemm_ptx_w_idx,\n            tiled_mma_S,\n            tStS,\n            tSrK,\n            tSrQ,\n            sA=sK,\n            sB=sQ,\n            zero_init=True,\n            cta_group=self.cta_group_size,\n        )\n        # mma_dov_fn = partial(gemm_w_idx, tiled_mma_dP, tdPtdP, tdPrV, tdPrdOt, zero_init=True)\n        mma_dov_fn = partial(\n            gemm_ptx_w_idx,\n            tiled_mma_dP,\n            tdPtdP,\n            tdPrV,\n            tdPrdOt,\n            sA=sV,\n            
sB=sdOt,\n            zero_init=True,\n            cta_group=self.cta_group_size,\n        )\n        # mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO)\n        mma_pdo_fn = partial(\n            gemm_ptx_w_idx,\n            tiled_mma_dV,\n            tdVtdV,\n            tdVrP,\n            tdVrdO,\n            sA=None,\n            sB=sdO,\n            tA_addr=self.tmem_P_offset,\n            cta_group=self.cta_group_size,\n        )\n        num_unroll_groups = 2 if const_expr(self.use_2cta_instrs) else 1\n        mma_dsk_fn = partial(\n            gemm_w_idx,\n            tiled_mma_dQ,\n            tdQtdQ,\n            tdQrdS,\n            tdQrK,\n            zero_init=True,\n            num_unroll_groups=num_unroll_groups,\n        )\n        # mma_dsk_fn = partial(\n        #     gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, zero_init=True\n        # )\n        if const_expr(self.use_smem_dS_for_mma_dK and not self.use_2cta_instrs):\n            mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ)\n        else:\n            # Need to explicitly pass in tA_addr for correctness\n            mma_dsq_fn = partial(\n                gemm_ptx_w_idx,\n                tiled_mma_dK,\n                tdKtdK,\n                tdKrdS,\n                tdKrQ,\n                sA=None,\n                sB=sQt,\n                tA_addr=self.tmem_dS_offset,\n                cta_group=self.cta_group_size,\n            )\n\n        pipeline_Q_consumer = pipeline_Q.make_consumer()\n\n        consumer_state_Qt = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage\n        )\n        consumer_state_Q = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage\n        )\n        consumer_state_Kt = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Consumer, self.single_stage\n     
   )\n        consumer_state_dO = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage\n        )\n        producer_phase_acc = Int32(1)  # For S & P, dP, dQ\n        producer_phase_dQ = Int32(1)  # 2-CTA: separate phase for dQ pipeline\n        consumer_state_dS = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Consumer, 1\n        )\n        producer_phase_dKV = Int32(1)\n        cta_group = pipeline_S_P.cta_group\n\n        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())\n        dS_cluster_phase = Int32(0)\n\n        tile_scheduler = TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        while work_tile.is_valid_tile:\n            n_block, head_idx, batch_idx, _ = work_tile.tile_idx\n            seqlen = SeqlenInfoCls(batch_idx)  # must be seqlen_k\n            m_block_min, m_block_max = block_info.get_m_block_min_max(\n                seqlen, n_block // self.cluster_shape_mnk[0]\n            )\n\n            if const_expr(self.use_block_sparsity):\n                block_iter_count = get_total_q_block_count_bwd(\n                    blocksparse_tensors,\n                    batch_idx,\n                    head_idx,\n                    n_block,\n                    subtile_factor=self.subtile_factor,\n                    m_block_max=m_block_max,\n                )\n                process_tile = block_iter_count > Int32(0)\n            else:\n                block_iter_count = m_block_max - m_block_min\n                process_tile = (\n                    const_expr(not self.is_local and not self.is_varlen_q)\n                    or m_block_min < m_block_max\n                )\n\n            if const_expr(self.use_2cta_instrs and self.tile_hdim == 192):\n                if is_leader_cta and process_tile:\n                    accumulate_dK = False\n                    accumulate_dV = False\n\n         
           # -----------------------------------------------------------\n                    ###### MAIN LOOP\n                    # -----------------------------------------------------------\n                    # 1. S.T  = K    @ Q.T\n                    # 2. dP.T = V    @ dO.T\n                    # 3. dK   = dS.T @ Q\n                    # 4. dV   = P.T  @ dO\n                    # 5. dQ   = dS   @ K\n\n                    main_loop_iters = m_block_max - m_block_min\n\n                    # empty waits\n                    # pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)\n                    # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)\n\n                    for _ in cutlass.range(main_loop_iters, unroll=1):\n                        # 1) S.T = K @ Q.T\n                        pipeline_Q.consumer_wait(consumer_state_Q)\n                        pipeline_dQ.sync_object_empty.wait(\n                            0, producer_phase_acc\n                        )  # dQ tmem overlaps with S\n                        mma_qk_fn(B_idx=consumer_state_Q.index)\n                        pipeline_S_P.sync_object_full.arrive(\n                            0, pipeline_S_P.producer_mask, cta_group\n                        )\n                        pipeline_Q.consumer_release(consumer_state_Q)\n                        consumer_state_Q.advance()\n\n                        producer_phase_acc ^= 1\n\n                        # 2) dP.T = V @ dO.T\n                        pipeline_dO.consumer_wait(consumer_state_dO)\n                        pipeline_S_P.sync_object_empty.wait(\n                            0, producer_phase_acc\n                        )  # dP tmem overlaps with S\n                        mma_dov_fn(B_idx=consumer_state_dO.index)\n                        pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)\n                        pipeline_dO.consumer_release(consumer_state_dO)\n                        
consumer_state_dO.advance()\n\n                        # 3) dK = dS.T @ Q\n                        pipeline_Q.consumer_wait(consumer_state_Q)\n                        pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)  # dP -> dS\n                        mma_dsq_fn(B_idx=consumer_state_Q.index, zero_init=not accumulate_dK)\n                        pipeline_Q.consumer_release(consumer_state_Q)\n                        consumer_state_Q.advance()\n                        accumulate_dK = True\n\n                        # 4) dV = P.T @ dO\n                        # Note: if dS is written to tmem, P must be written to tmem\n                        pipeline_dO.consumer_wait(consumer_state_dO)\n                        mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=not accumulate_dV)\n                        pipeline_dO.consumer_release(consumer_state_dO)\n                        consumer_state_dO.advance()\n                        accumulate_dV = True\n\n                        # 5) dQ = dS @ K\n                        pipeline_dS.consumer_wait(consumer_state_dS)\n                        cute.arch.mbarrier_wait(dS_cluster_leader_mbar_ptr, phase=dS_cluster_phase)\n                        mma_dsk_fn()\n                        pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)\n                        pipeline_dS.consumer_release(consumer_state_dS)\n                        consumer_state_dS.advance()\n                        dS_cluster_phase ^= 1\n\n                    # signal to the epilogue that dV is ready\n                    pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV)\n                    pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group)\n                    # signal to the epilogue that dK is ready\n                    pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV)\n                    pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group)\n             
       producer_phase_dKV ^= 1\n            elif const_expr(self.use_2cta_instrs):\n                if is_leader_cta and process_tile:\n                    accumulate_dK = False\n                    # -----------------------------------------------------------\n                    ###### Prologue\n                    # -----------------------------------------------------------\n                    # 1. S  = Q0 @ K.T\n                    # 2. dP = V @ dOt.T\n                    # 3. dV = P @ dO\n\n                    # 1) S = K @ Q\n                    pipeline_Q.consumer_wait(consumer_state_Q)\n                    pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)\n                    mma_qk_fn(B_idx=consumer_state_Q.index)\n                    pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)\n                    pipeline_Q.consumer_release(consumer_state_Q)\n                    consumer_state_Q.advance()\n\n                    # 2) dP = V @ dOt.T\n                    pipeline_dO.consumer_wait(consumer_state_dO)\n                    pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)\n                    mma_dov_fn(B_idx=consumer_state_dO.index)\n                    pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)\n\n                    # 3) dV = P.T @ dO\n                    producer_phase_acc ^= 1\n                    pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)\n                    mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True)\n                    pipeline_dO.consumer_release(consumer_state_dO)\n                    consumer_state_dO.advance()\n\n                    pipeline_Kt.consumer_wait(consumer_state_Kt)\n                    # -----------------------------------------------------------\n                    ###### MAIN LOOP\n                    # -----------------------------------------------------------\n                    # 1. 
S.T  = K    @ Q.T\n                    # 2. dK   = dS.T @ Q\n                    # 3. dP.T = V    @ dO.T\n                    # 4. dQ   = dS   @ K\n                    # 5. dV   = P.T  @ dO\n\n                    main_loop_iters = (\n                        block_iter_count - 1\n                        if const_expr(self.use_block_sparsity)\n                        else m_block_max - m_block_min - 1\n                    )\n\n                    for _ in cutlass.range(main_loop_iters, unroll=1):\n                        # (1) S.T = K @ Q.T (next)\n                        pipeline_Q.consumer_wait(consumer_state_Q)\n                        pipeline_dQ.sync_object_empty.wait(0, producer_phase_dQ)\n                        mma_qk_fn(B_idx=consumer_state_Q.index)\n                        pipeline_S_P.sync_object_full.arrive(\n                            0, pipeline_S_P.producer_mask, cta_group\n                        )\n                        pipeline_Q.consumer_release(consumer_state_Q)\n                        consumer_state_Q.advance()\n\n                        # pipeline_dS.consumer_wait(consumer_state_dS)\n                        # (2) dK += dS.T @ Q (cur)\n                        pipeline_Qt.consumer_wait(consumer_state_Qt)\n                        pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)  # dP -> dS\n                        mma_dsq_fn(B_idx=consumer_state_Qt.index, zero_init=not accumulate_dK)\n                        accumulate_dK = True\n                        pipeline_Qt.consumer_release(consumer_state_Qt)\n                        consumer_state_Qt.advance()\n\n                        # (3) dP.T = V @ dO.T (next)\n                        pipeline_dO.consumer_wait(consumer_state_dO)\n                        mma_dov_fn(B_idx=consumer_state_dO.index)\n                        pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)\n\n                        # (5) dQ = dS @ K (cur)\n                        
pipeline_dS.consumer_wait(consumer_state_dS)\n                        cute.arch.mbarrier_wait(dS_cluster_leader_mbar_ptr, phase=dS_cluster_phase)\n                        mma_dsk_fn()\n                        pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)\n                        pipeline_dS.consumer_release(consumer_state_dS)\n                        consumer_state_dS.advance()\n                        dS_cluster_phase ^= 1\n                        producer_phase_dQ ^= 1\n\n                        # (4) dV += P.T @ dO (next)\n                        producer_phase_acc ^= 1\n                        pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)  # S -> P\n                        mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False)\n                        pipeline_dO.consumer_release(consumer_state_dO)\n                        consumer_state_dO.advance()\n\n                    pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)\n\n                    # signal to the epilogue that dV is ready\n                    pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV)\n                    pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group)\n                    pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV)\n\n                    # -----------------------------------------------------------\n                    # Tail: Remaining dK and dQ\n                    # -----------------------------------------------------------\n                    # pipeline_dS.consumer_wait(consumer_state_dS)\n                    # dK += dS.T @ Q\n                    pipeline_Qt.consumer_wait(consumer_state_Qt)\n                    pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)  # dP -> dS\n                    mma_dsq_fn(B_idx=consumer_state_Qt.index, zero_init=not accumulate_dK)\n                    pipeline_Qt.consumer_release(consumer_state_Qt)\n                
    consumer_state_Qt.advance()\n                    # signal to the epilogue that dK is ready\n                    pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group)\n                    producer_phase_dKV ^= 1\n\n                    # dQ = dS @ K\n                    pipeline_dS.consumer_wait(consumer_state_dS)\n                    cute.arch.mbarrier_wait(dS_cluster_leader_mbar_ptr, phase=dS_cluster_phase)\n                    pipeline_dQ.sync_object_empty.wait(0, producer_phase_dQ)\n                    mma_dsk_fn()\n                    pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)\n                    pipeline_dS.consumer_release(consumer_state_dS)\n                    pipeline_Kt.consumer_release(consumer_state_Kt)\n                    consumer_state_dS.advance()\n                    consumer_state_Kt.advance()\n                    dS_cluster_phase ^= 1\n                    producer_phase_dQ ^= 1\n\n                    producer_phase_acc ^= 1\n            else:\n                if is_leader_cta and process_tile:\n                    accumulate_dK = False\n                    # -----------------------------------------------------------\n                    ###### Prologue\n                    # -----------------------------------------------------------\n                    # 1. S  = Q0 @ K.T\n                    # 2. dP = V @ dOt.T\n                    # 3. 
dV = P @ dO\n\n                    # 1) S = K @ Q\n                    handle_Q = pipeline_Q_consumer.wait_and_advance()\n                    pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)\n                    mma_qk_fn(B_idx=handle_Q.index)\n                    pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)\n\n                    # 2) dP = V @ dOt.T\n                    pipeline_dO.consumer_wait(consumer_state_dO)\n                    pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)\n                    pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc)\n                    mma_dov_fn(B_idx=consumer_state_dO.index)\n                    pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)\n\n                    producer_phase_acc ^= 1\n                    # 3) dV = P.T @ dO\n                    pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)\n                    mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True)\n                    pipeline_dO.consumer_release(consumer_state_dO)\n                    consumer_state_dO.advance()\n\n                    # -----------------------------------------------------------\n                    ###### MAIN LOOP\n                    # -----------------------------------------------------------\n                    # 1. S  = K    @ Q.T\n                    # 2. dQ = dS   @ K\n                    # 3. dK = dS.T @ Q\n                    # 4. dP = V    @ dOt.T\n                    # 5. 
dV = P.T  @ dO\n\n                    # For block sparsity, we use block_iter_count; for dense, use m_block range\n                    # MMA doesn't need actual m_block indices, just the iteration count\n                    main_loop_iters = (\n                        block_iter_count - 1\n                        if const_expr(self.use_block_sparsity)\n                        else m_block_max - m_block_min - 1\n                    )\n\n                    handle_Q_next = handle_Q\n                    for _ in cutlass.range(main_loop_iters, unroll=1):\n                        # (1) S.T = K @ Q.T\n                        handle_Q_next = pipeline_Q_consumer.wait_and_advance()\n                        mma_qk_fn(B_idx=handle_Q_next.index)\n                        pipeline_S_P.sync_object_full.arrive(\n                            0, pipeline_S_P.producer_mask, cta_group\n                        )\n\n                        # (2) dK += dS.T @ Q\n                        pipeline_dS.consumer_wait(consumer_state_dS)\n                        mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)\n                        accumulate_dK = True\n                        handle_Q.release()\n\n                        # (3) dQ = dS @ K\n                        mma_dsk_fn()\n                        pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)\n                        pipeline_dS.consumer_release(consumer_state_dS)\n                        consumer_state_dS.advance()\n\n                        # (4) dP = V @ dO.T\n                        pipeline_dO.consumer_wait(consumer_state_dO)\n                        pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc)\n                        mma_dov_fn(B_idx=consumer_state_dO.index)\n                        pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)\n\n                        # (5) dV += P.T @ dO\n                        producer_phase_acc ^= 1\n                       
 pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)\n                        mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False)\n                        pipeline_dO.consumer_release(consumer_state_dO)\n                        consumer_state_dO.advance()\n\n                        handle_Q = handle_Q_next\n\n                    pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)\n\n                    # signal to the epilogue that dV is ready\n                    # pipeline_dKV.producer_acquire(producer_state_dKV)\n                    pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV)\n                    # pipeline_dKV.producer_commit(producer_state_dKV)\n                    pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group)\n                    # producer_state_dKV.advance()\n                    # pipeline_dKV.producer_acquire(producer_state_dKV)\n                    pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV)\n\n                    # -----------------------------------------------------------\n                    # Tail: Remaining dK and dQ\n                    # -----------------------------------------------------------\n                    # 1) dK += dS.T @ Q\n                    pipeline_dS.consumer_wait(consumer_state_dS)\n                    mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)\n                    # signal to the epilogue that dK is ready\n                    pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group)\n                    producer_phase_dKV ^= 1\n\n                    # 2) dQ = dS @ K\n                    mma_dsk_fn()\n                    pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)\n                    handle_Q.release()\n                    pipeline_dS.consumer_release(consumer_state_dS)\n                    consumer_state_dS.advance()\n\n                    
producer_phase_acc ^= 1\n            tile_scheduler.advance_to_next_work()\n            work_tile = tile_scheduler.get_current_work()\n        # Currently it hangs if we have this S_P.producer_tail, will need to understand why\n        # pipeline_S_P.producer_tail(producer_state_S_P)\n        # pipeline_dP.producer_tail(producer_state_dP)\n        # pipeline_dKV.producer_tail(producer_state_dKV)\n        # pipeline_dQ.producer_tail(producer_state_dQ)\n\n    @cute.jit\n    def split_wg(\n        self,\n        t: cute.Tensor,\n        wg_idx: cutlass.Int32,\n        num_wg: cutlass.Constexpr[int],\n    ):\n        reduced_shape = cute.product_each(t.shape)\n        rank = len(reduced_shape)\n        if const_expr(reduced_shape[1] > 1):\n            assert rank >= 2, \"Need rank >= 2 for t in split_wg\"\n            t = cute.logical_divide(t, (reduced_shape[0], reduced_shape[1] // num_wg))\n            coord = (None, (None, wg_idx)) + (None,) * (rank - 2)\n        else:\n            assert rank >= 3, \"Need rank >= 3 for t in split_wg\"\n            if const_expr(rank == 3):\n                t = cute.logical_divide(\n                    t, (reduced_shape[0], reduced_shape[1], reduced_shape[2] // num_wg)\n                )\n                coord = (\n                    None,\n                    None,\n                    (None, wg_idx),\n                ) + (None,) * (rank - 3)\n            else:\n                t = cute.logical_divide(\n                    t,\n                    (\n                        reduced_shape[0],\n                        reduced_shape[1],\n                        reduced_shape[2],\n                        reduced_shape[3] // num_wg,\n                    ),\n                )\n                coord = (\n                    None,\n                    None,\n                    None,\n                    (None, wg_idx),\n                ) + (None,) * (rank - 4)\n        return t[coord]\n\n    @cute.jit\n    def apply_score_mod(\n        
self,\n        tSrS_t2r,\n        thr_copy_t2r,\n        thr_mma_S,\n        batch_idx,\n        head_idx,\n        m_block,\n        n_block,\n        softmax_scale,\n        seqlen_info,\n        aux_tensors=None,\n        fastdiv_mods=(None, None),\n    ):\n        \"\"\"Apply forward score modification for SM100 backward pass.\"\"\"\n        # In bwd, S is computed as K @ Q.T so dimensions are (tile_n, tile_m)\n        cS = cute.make_identity_tensor((self.tile_n, self.tile_m))\n        cS = cute.domain_offset((n_block * self.tile_n, m_block * self.tile_m), cS)\n        tScS = thr_mma_S.partition_C(cS)\n        tScS_idx = thr_copy_t2r.partition_D(tScS)\n\n        apply_score_mod_inner(\n            tSrS_t2r,\n            tScS_idx,\n            self.score_mod,\n            batch_idx,\n            head_idx,\n            softmax_scale,\n            self.vec_size,\n            self.qk_acc_dtype,\n            aux_tensors,\n            fastdiv_mods,\n            seqlen_info,\n            constant_q_idx=None,\n            qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n            transpose_indices=True,\n        )\n\n    @cute.jit\n    def apply_score_mod_bwd(\n        self,\n        grad_tensor,\n        score_tensor,\n        index_tensor,\n        batch_idx,\n        head_idx,\n        softmax_scale,\n        seqlen_info,\n        aux_tensors=None,\n        fastdiv_mods=(None, None),\n    ):\n        \"\"\"Apply backward score modification (joint graph) for SM100.\"\"\"\n        apply_score_mod_bwd_inner(\n            grad_tensor,\n            score_tensor,\n            index_tensor,\n            self.score_mod_bwd,\n            batch_idx,\n            head_idx,\n            softmax_scale,\n            self.vec_size,\n            self.qk_acc_dtype,\n            aux_tensors,\n            fastdiv_mods,\n            seqlen_info,\n            constant_q_idx=None,\n            qhead_per_kvhead=self.qhead_per_kvhead if 
const_expr(self.pack_gqa) else 1,\n            transpose_indices=True,\n        )\n\n    @cute.jit\n    def compute_loop(\n        self,\n        thr_mma_S: cute.core.ThrMma,\n        thr_mma_dP: cute.core.ThrMma,\n        thr_mma_dV: cute.core.ThrMma,\n        thr_mma_dK: cute.core.ThrMma,\n        tStS: cute.Tensor,\n        tdPtdP: cute.Tensor,\n        tdVtdV: cute.Tensor,\n        tdKtdK: cute.Tensor,\n        sLSE: cute.Tensor,\n        sdPsum: cute.Tensor,\n        mdV: cute.Tensor,\n        mdK: cute.Tensor,\n        sdS: cute.Tensor,\n        sdS_xchg: cute.Tensor,\n        pipeline_LSE: PipelineAsync,\n        pipeline_dPsum: PipelineAsync,\n        pipeline_S_P: PipelineAsync,\n        pipeline_dS: PipelineAsync,\n        pipeline_dKV: PipelineAsync,\n        pipeline_dP: PipelineAsync,\n        dS_cluster_empty_mbar_ptr: cute.Pointer,\n        dS_cluster_full_mbar_ptr: cute.Pointer,\n        dQaccum_empty_mbar_ptr: cute.Pointer,\n        softmax_scale: cutlass.Float32,\n        softmax_scale_log2: cutlass.Float32,\n        block_info: BlockInfo,\n        SeqlenInfoCls: Callable,\n        AttentionMaskCls: Callable,\n        TileSchedulerCls: Callable,\n        sdV: Optional[cute.Tensor],\n        sdK: Optional[cute.Tensor],\n        mdV_tma_tensor: Optional[cute.Tensor],\n        mdK_tma_tensor: Optional[cute.Tensor],\n        tma_atom_dV: Optional[cute.CopyAtom],\n        tma_atom_dK: Optional[cute.CopyAtom],\n        tiled_copy_r2s_dKV: Optional[cute.TiledCopy],\n        mdK_semaphore: Optional[cute.Tensor],\n        mdV_semaphore: Optional[cute.Tensor],\n        aux_tensors: Optional[list] = None,\n        fastdiv_mods=(None, None),\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n    ):\n        sLSE_2D = cute.make_tensor(\n            sLSE.iterator,\n            cute.make_layout(\n                (self.tile_m, self.tile_n, self.Q_stage),\n                stride=(1, 0, cute.round_up(self.tile_m, 64)),\n            ),\n        )\n  
      sdPsum_2D = cute.make_tensor(\n            sdPsum.iterator,\n            cute.make_layout(\n                (self.tile_m, self.tile_n, self.dO_stage),\n                stride=(1, 0, cute.round_up(self.tile_m, 64)),\n            ),\n        )\n        # if const_expr(self.SdP_swapAB):\n        if const_expr(True):\n            sLSE_2D = layout_utils.transpose_view(sLSE_2D)\n            sdPsum_2D = layout_utils.transpose_view(sdPsum_2D)\n\n        # tix: [128...384]  8 warps\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())  # 4-11\n        tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))\n        # tidx = cute.arch.thread_idx()[0] - (cute.arch.WARP_SIZE * self.compute_warp_ids[0])\n        dp_idx = tidx % 128\n        num_wg = len(self.compute_warp_ids) // 4  # 2\n        # wg_idx:\n        # 0: [256...384]\n        # 1: [128...256]\n\n        tileP_f32_like = self.cta_tiler[1] // 32 * self.v_dtype.width\n        # tStS has shape ((128, 128), 1, 1), tStP has shape ((128, 64), 1, 1)\n        # tP overlaps with tS\n        tStP = cute.composition(tStS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))\n        tStP = cute.make_tensor(tStS.iterator, tStP.layout)  # Otherwise the tmem address is wrong\n        tScS = thr_mma_S.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2]))\n        tScP = cute.composition(tScS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))\n        # tdS overlaps with tdP\n        tdPtdS = cute.composition(tdPtdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))\n        tdPcdP = thr_mma_dP.partition_C(cute.make_identity_tensor(self.mma_tiler_vdo[:2]))\n        tdPcdS = cute.composition(tdPcdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))\n\n        # 2-CTA assumes: repetition should always be 32 & 16\n        tmem_load_atom = cute.make_copy_atom(\n            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32\n        )\n        
tmem_store_atom = cute.make_copy_atom(\n            tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32\n        )\n\n        # tmem -> rmem\n        thr_copy_t2r = copy_utils.make_tmem_copy(tmem_load_atom, num_wg).get_slice(tidx)\n        tStS_t2r = thr_copy_t2r.partition_S(tStS)  # (((32, 32), 1), 2, 1, 1)\n        tdPtdP_t2r = thr_copy_t2r.partition_S(tdPtdP)\n        tScS_t2r = thr_copy_t2r.partition_D(tScS)  # ((32, 1), 2, 1, 1)\n        t0ScS_t2r = thr_copy_t2r.get_slice(0).partition_D(tScS)  # ((32, 1), 2, 1, 1)\n        # ((32, 1), 2, 1, 1, STAGE)\n        tSsLSE = thr_copy_t2r.partition_D(thr_mma_S.partition_C(sLSE_2D))\n        tSsdPsum = thr_copy_t2r.partition_D(thr_mma_dP.partition_C(sdPsum_2D))\n        # rmem -> tmem\n        thr_copy_r2t = copy_utils.make_tmem_copy(tmem_store_atom, num_wg).get_slice(tidx)\n        tScP_r2t = thr_copy_r2t.partition_S(tScP)\n        tStP_r2t = thr_copy_r2t.partition_D(tStP)\n        tdPcdS_r2t = thr_copy_r2t.partition_S(tdPcdS)\n        tdPtdS_r2t = thr_copy_r2t.partition_D(tdPtdS)\n        # rmem -> smem\n        # This part is a bit iffy, we might be making a lot of assumptions here\n        copy_atom_r2s = sm100_utils_basic.get_smem_store_op(\n            LayoutEnum.ROW_MAJOR, self.ds_dtype, Float32, thr_copy_t2r\n        )\n        thr_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, thr_copy_t2r).get_slice(tidx)\n\n        # We assume the swizzle (i.e. 
layout.inner) stays the same\n        sdS_epi_layout = sm100_utils_basic.make_smem_layout_epi(\n            self.ds_dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_m), 1\n        )\n        sdS_layout = cute.slice_(sdS_epi_layout.outer, (None, None, 0))  # ((8,16), (64,2))\n        # Need to group into 1 mode to be compatible w thr_copy_r2s\n        sdS_layout = cute.make_layout((sdS_layout.shape,), stride=(sdS_layout.stride,))\n        sdS_epi = cute.make_tensor(sdS.iterator, sdS_layout)\n        tRS_sdS = thr_copy_r2s.partition_D(sdS_epi)\n\n        if const_expr(self.use_2cta_instrs):\n            sdS_xchg_epi = cute.make_tensor(\n                cute.recast_ptr(sdS_xchg.iterator, sdS_epi_layout.inner), sdS_layout\n            )\n            tRS_sdS_xchg = thr_copy_r2s.partition_D(sdS_xchg_epi)\n\n        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())\n        dS_cluster_empty_phase = Int32(1)\n        # 2-CTA: CTA 0 exchanges stage 1 (bottom half), CTA 1 exchanges stage 0 (top half)\n        exchange_stage = cta_rank_in_cluster ^ 1 if const_expr(self.use_2cta_instrs) else Int32(0)\n\n        consumer_state_S_P_dP = pipeline.make_pipeline_state(  # Our impl has shortcut for stage==1\n            cutlass.pipeline.PipelineUserType.Consumer, 1\n        )\n        # consumer_phase_S_P_dP = Int32(0)\n        producer_state_dS = pipeline.make_pipeline_state(  # Our impl has shortcut for stage==1\n            cutlass.pipeline.PipelineUserType.Producer, 1\n        )\n        consumer_state_dKV = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Consumer, 2\n        )\n        consumer_state_LSE = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage\n        )\n        consumer_state_dPsum = pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage\n        )\n\n        tile_scheduler = 
TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        while work_tile.is_valid_tile:\n            n_block, head_idx, batch_idx, _ = work_tile.tile_idx\n            seqlen = SeqlenInfoCls(batch_idx)\n            m_block_min, m_block_max = block_info.get_m_block_min_max(\n                seqlen, n_block // self.cluster_shape_mnk[0]\n            )\n            mask = AttentionMaskCls(seqlen)\n            n_block_for_cluster = n_block // self.cta_group_size\n            # TODO: condition mask_seqlen\n            mask_fn = partial(\n                mask.apply_mask_sm100_transposed,\n                tScS_t2r=tScS_t2r,\n                t0ScS_t2r=t0ScS_t2r,\n                n_block=n_block_for_cluster,\n                mask_seqlen=True,\n                mask_causal=self.is_causal,\n                mask_local=self.is_local,\n                mask_mod=self.mask_mod,\n                batch_idx=batch_idx,\n                head_idx=head_idx,\n                aux_tensors=aux_tensors,\n                fastdiv_mods=fastdiv_mods,\n            )\n\n            # prefetch_LSE = not self.is_causal\n            prefetch_LSE = False\n            # some tiles might be empty due to block sparsity\n            if const_expr(self.use_block_sparsity):\n                (\n                    curr_q_cnt,\n                    curr_q_idx,\n                    curr_full_cnt,\n                    curr_full_idx,\n                    loop_count,\n                ) = get_block_sparse_iteration_info_bwd(\n                    blocksparse_tensors,\n                    batch_idx,\n                    head_idx,\n                    n_block,\n                    subtile_factor=self.subtile_factor,\n                    m_block_max=m_block_max,\n                )\n                process_tile = loop_count > Int32(0)\n            else:\n                process_tile = (\n                    const_expr(not self.is_local and not self.is_varlen_q)\n                    or 
m_block_min < m_block_max\n                )\n                loop_count = m_block_max - m_block_min\n\n            # Mainloop\n            # Block sparsity: iterate over sparse m_block count and derive actual m_block\n            # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly.\n            for iter_idx in cutlass.range(loop_count, unroll=1):\n                if const_expr(self.use_block_sparsity):\n                    m_block, is_full_block = get_m_block_from_iter_bwd(\n                        iter_idx,\n                        curr_q_cnt,\n                        curr_q_idx,\n                        curr_full_cnt,\n                        curr_full_idx,\n                        subtile_factor=self.subtile_factor,\n                        m_block_max=m_block_max,\n                    )\n                    m_block_oob = m_block >= m_block_max\n                else:\n                    m_block = m_block_min + iter_idx\n                    m_block_oob = False\n                    is_full_block = False\n                # Prefetch 1 stage of LSE\n                pipeline_LSE.consumer_wait(consumer_state_LSE)\n                tSrLSE_s2r = cute.make_fragment(tScS_t2r[None, 0, 0, 0].shape, Float32)\n                if const_expr(prefetch_LSE and not self.shuffle_LSE):\n                    cute.autovec_copy(tSsLSE[None, 0, 0, 0, consumer_state_LSE.index], tSrLSE_s2r)\n\n                pipeline_S_P.consumer_wait(consumer_state_S_P_dP)\n                # pipeline_S_P.sync_object_full.wait(0, consumer_phase_S_P_dP)\n                #### TMEM->RMEM (Load S from TMEM)\n                tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32)\n                cute.copy(thr_copy_t2r, tStS_t2r, tSrS_t2r)\n\n                if const_expr(self.tile_hdim == 192):\n                    # Signal S tmem load completion using pipeline_S_P when hdim 192\n                    # dP is overlapped with S\n                    
cute.arch.fence_view_async_tmem_load()\n                    with cute.arch.elect_one():\n                        pipeline_S_P.consumer_release(consumer_state_S_P_dP)\n                elif const_expr(self.use_2cta_instrs and self.tile_hdim <= 128):\n                    # Signal S tmem load completion using pipeline_dS when 2cta hdim 128\n                    # dQ is overlapped with S\n                    if iter_idx > 0:\n                        cute.arch.fence_view_async_tmem_load()\n                        with cute.arch.elect_one():\n                            pipeline_dS.producer_commit(producer_state_dS)\n                        producer_state_dS.advance()\n\n                if const_expr(self.score_mod_bwd is not None):\n                    tSrS_pre = cute.make_fragment_like(tSrS_t2r)\n                    cute.autovec_copy(tSrS_t2r, tSrS_pre)\n\n                if const_expr(self.score_mod is not None):\n                    # Apply score_mod FIRST -> matches forward\n                    self.apply_score_mod(\n                        tSrS_t2r,\n                        thr_copy_t2r,\n                        thr_mma_S,\n                        batch_idx,\n                        head_idx,\n                        m_block,\n                        n_block,\n                        softmax_scale,\n                        seqlen,\n                        aux_tensors,\n                        fastdiv_mods,\n                    )\n\n                #### APPLY MASK (after score_mod, matching forward pass order)\n                check_m_boundary = (m_block + 1) * self.tile_m > seqlen.seqlen_q\n                mask_fn(\n                    tSrS_t2r,\n                    m_block=m_block,\n                    is_full_block=is_full_block,\n                    check_m_boundary=check_m_boundary,\n                )\n                num_stages = cute.size(tScS_t2r, mode=[1])\n                # ---------------------------------------------\n                #### P = exp(S - 
LSE)\n                # ---------------------------------------------\n                lane_idx = cute.arch.lane_idx()\n                tSrP_r2t_f32 = cute.make_fragment(tScP_r2t.shape, Float32)  # 64\n                tSrP_r2t = cute.recast_tensor(tSrP_r2t_f32, self.q_dtype)\n                for stage in cutlass.range_constexpr(num_stages):\n                    tSrS_cur = tSrS_t2r[None, stage, 0, 0]\n                    tSsLSE_cur = tSsLSE[None, stage, 0, 0, consumer_state_LSE.index]\n                    if const_expr(not self.shuffle_LSE):\n                        if const_expr(stage > 0 or not prefetch_LSE):\n                            cute.autovec_copy(tSsLSE_cur, tSrLSE_s2r)\n                        tSrLSE = tSrLSE_s2r\n                    else:\n                        tSrLSE = tSsLSE_cur[lane_idx]\n                    for v in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[0]) // 2):\n                        if const_expr(not self.shuffle_LSE):\n                            lse_pair = (tSrLSE[2 * v], tSrLSE[2 * v + 1])\n                        else:\n                            lse_pair = (\n                                utils.shuffle_sync(tSrLSE, offset=2 * v),\n                                utils.shuffle_sync(tSrLSE, offset=2 * v + 1),\n                            )\n                        tSrS_cur[2 * v], tSrS_cur[2 * v + 1] = cute.arch.fma_packed_f32x2(\n                            ((tSrS_cur[2 * v], tSrS_cur[2 * v + 1])),\n                            (softmax_scale_log2, softmax_scale_log2),\n                            (-lse_pair[0], -lse_pair[1]),\n                        )\n                        tSrS_cur[2 * v] = cute.math.exp2(tSrS_cur[2 * v], fastmath=True)\n                        tSrS_cur[2 * v + 1] = cute.math.exp2(tSrS_cur[2 * v + 1], fastmath=True)\n                    utils.cvt_f16(tSrS_cur, tSrP_r2t[None, stage, 0, 0])\n                    if const_expr(stage == 0):\n                        cute.arch.fence_view_async_tmem_load()\n   
                     # Without this barrier, we could have 1 warp writing to P in tmem while\n                        # another warp is still reading S from tmem.\n                        self.compute_sync_barrier.arrive_and_wait()\n                    cute.copy(\n                        thr_copy_r2t,\n                        tSrP_r2t_f32[None, stage, None, None],\n                        tStP_r2t[None, stage, None, None],\n                    )\n\n                cute.arch.fence_view_async_tmem_store()\n                cute.arch.fence_view_async_shared()\n                self.compute_sync_barrier.arrive_and_wait()\n                if const_expr(not self.tile_hdim == 192):\n                    # Signal tmem store P completion with pipeline_S_P\n                    with cute.arch.elect_one():\n                        pipeline_S_P.consumer_release(consumer_state_S_P_dP)\n                        # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask)\n                # Normally we'd need syncwarp here since only 1 thread will signal in\n                # consumer_release, but we already have the self.compute_sync_barrier before this\n                pipeline_LSE.consumer_release(consumer_state_LSE)\n                consumer_state_LSE.advance()\n                # ---------------------------------------------\n                # dS.T = P.T * (dP.T - D)\n                # ---------------------------------------------\n                pipeline_dPsum.consumer_wait(consumer_state_dPsum)\n                pipeline_dP.consumer_wait(consumer_state_S_P_dP)\n                # pipeline_dP.sync_object_full.wait(0, consumer_phase_S_P_dP)\n                ### Now delayed to after loop\n                # consumer_state_S_P_dP.advance()\n                # consumer_phase_S_P_dP ^= 1\n\n                ##### dS.T = P.T * (dP.T - Psum)\n                for stage in cutlass.range_constexpr(num_stages):\n                    tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, 
None, None].shape, Float32)\n                    cute.copy(thr_copy_t2r, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r)\n                    cute.arch.fence_view_async_tmem_load()\n                    self.compute_sync_barrier.arrive_and_wait()\n                    tdPrdP_cur = tdPrdP_t2r[None, 0, 0]\n                    tSrS_cur = tSrS_t2r[None, stage, 0, 0]\n                    tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, consumer_state_dPsum.index]\n                    if const_expr(not self.shuffle_dPsum):\n                        tSrdPsum = cute.make_fragment_like(tSsdPsum_cur, Float32)\n                        cute.autovec_copy(tSsdPsum_cur, tSrdPsum)\n                    else:\n                        tSrdPsum = tSsdPsum_cur[lane_idx]\n                    for v in cutlass.range_constexpr(cute.size(tdPrdP_t2r, mode=[0]) // 2):\n                        if const_expr(not self.shuffle_dPsum):\n                            dPsum_pair = (tSrdPsum[2 * v], tSrdPsum[2 * v + 1])\n                        else:\n                            dPsum_pair = (\n                                utils.shuffle_sync(tSrdPsum, offset=2 * v),\n                                utils.shuffle_sync(tSrdPsum, offset=2 * v + 1),\n                            )\n                        tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = (\n                            quack.activation.sub_packed_f32x2(\n                                (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), dPsum_pair\n                            )\n                        )\n                        tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = cute.arch.mul_packed_f32x2(\n                            (tSrS_cur[2 * v], tSrS_cur[2 * v + 1]),\n                            (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]),\n                        )\n\n                    if const_expr(self.score_mod_bwd is not None):\n                        tSrS_pre_cur = tSrS_pre[None, stage, 0, 0]\n                        cS_bwd = 
cute.make_identity_tensor((self.tile_n, self.tile_m))\n                        cS_bwd = cute.domain_offset(\n                            (n_block * self.tile_n, m_block * self.tile_m), cS_bwd\n                        )\n                        tScS_bwd = thr_mma_S.partition_C(cS_bwd)\n                        tScS_idx_bwd = thr_copy_t2r.partition_D(tScS_bwd)\n                        tScS_idx_cur = tScS_idx_bwd[None, stage, 0, 0]\n                        self.apply_score_mod_bwd(\n                            tdPrdP_cur,\n                            tSrS_pre_cur,\n                            tScS_idx_cur,\n                            batch_idx,\n                            head_idx,\n                            softmax_scale,\n                            seqlen,\n                            aux_tensors,\n                            fastdiv_mods,\n                        )\n                        # Zero out OOB positions (kv_idx >= seqlen_k) after score_mod_bwd\n                        for i in cutlass.range(cute.size(tdPrdP_cur), unroll_full=True):\n                            kv_idx = tScS_idx_cur[i][0]\n                            tdPrdP_cur[i] = 0.0 if kv_idx >= seqlen.seqlen_k else tdPrdP_cur[i]\n\n                    tdPrdS_cvt = cute.make_fragment_like(tdPrdP_cur, self.ds_dtype)\n                    utils.cvt_f16(tdPrdP_cur, tdPrdS_cvt)\n                    if const_expr(stage == 0):\n                        pipeline_dS.producer_acquire(producer_state_dS)\n                        if const_expr(self.use_2cta_instrs):\n                            tdPrdS_xchg = cute.make_fragment_like(tdPrdS_cvt, self.ds_dtype)\n\n                    # RMEM->TMEM: always write to TMEM for MMA\n                    if const_expr(not self.use_smem_dS_for_mma_dK or self.use_2cta_instrs):\n                        tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32)\n                        cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0])\n\n                    
# RMEM->SMEM: For 2-CTA, keep exchange stage in registers, write non-exchange to sdS\n                    if const_expr(self.use_2cta_instrs):\n                        if exchange_stage == stage:\n                            cute.autovec_copy(tdPrdS_cvt, tdPrdS_xchg)\n                        else:\n                            cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage])\n                    else:\n                        cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage])\n\n                if const_expr(not self.use_smem_dS_for_mma_dK):\n                    cute.arch.fence_view_async_tmem_store()\n\n                if const_expr(self.use_2cta_instrs):\n                    # use pipeline_dP to signal tmem store of dS\n                    with cute.arch.elect_one():\n                        pipeline_dP.consumer_release(consumer_state_S_P_dP)\n                consumer_state_S_P_dP.advance()\n\n                # After the loop: copy exchange registers to sdS_xchg buffer\n                if const_expr(self.use_2cta_instrs):\n                    # when hdim 192, sdQaccum overlapped with sdS_xchg\n                    if const_expr(self.tile_hdim == 192):\n                        cute.arch.mbarrier_wait(\n                            dQaccum_empty_mbar_ptr, phase=producer_state_dS.phase\n                        )\n                    cute.autovec_copy(tdPrdS_xchg, tRS_sdS_xchg[None, 0])\n\n                cute.arch.fence_view_async_shared()\n                self.compute_sync_barrier.arrive_and_wait()\n                # Normally we'd need syncwarp here since only 1 thread will signal in\n                # consumer_release, but we already have the self.compute_sync_barrier before this\n                pipeline_dPsum.consumer_release(consumer_state_dPsum)\n                consumer_state_dPsum.advance()\n                # when 2cta hdim 128, pipeline_dS also signals S tmem load completion so is deferred\n                if const_expr(not (self.use_2cta_instrs and 
self.tile_hdim == 128)):\n                    with cute.arch.elect_one():\n                        pipeline_dS.producer_commit(producer_state_dS)\n                    producer_state_dS.advance()\n\n                # 2-CTA: DSMEM copy from sdS_xchg to peer's sdS buffer\n                if const_expr(self.use_2cta_instrs):\n                    stage_copy_bytes = const_expr(self.tma_copy_bytes[\"dS\"] // 2)\n                    stage_copy_elems = const_expr(stage_copy_bytes // (self.ds_dtype.width // 8))\n                    if tidx == 0:\n                        peer_cta_rank_in_cluster = cta_rank_in_cluster ^ 1\n                        smem_src_ptr = sdS_xchg.iterator\n                        # Destination is peer's sdS at our CTA's offset (exchange_stage position)\n                        smem_dst_ptr = sdS.iterator + cta_rank_in_cluster * stage_copy_elems\n                        cute.arch.mbarrier_arrive_and_expect_tx(\n                            dS_cluster_full_mbar_ptr,\n                            stage_copy_bytes,\n                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,\n                        )\n                        copy_utils.cpasync_bulk_s2cluster(\n                            smem_src_ptr,\n                            smem_dst_ptr,\n                            dS_cluster_full_mbar_ptr,\n                            stage_copy_bytes,\n                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,\n                        )\n\n            # Final signal for dS smem store completion\n            if const_expr(self.use_2cta_instrs and self.tile_hdim == 128):\n                if process_tile:\n                    with cute.arch.elect_one():\n                        pipeline_dS.producer_commit(producer_state_dS)\n                    producer_state_dS.advance()\n\n            # Epilogue\n            # Run epilogue if we processed any m_blocks for this n_block\n            if process_tile:\n                if 
const_expr(not self.use_tma_store):\n                    consumer_state_dKV = self.epilogue_dKV(\n                        dp_idx,\n                        warp_idx,\n                        batch_idx,\n                        head_idx,\n                        n_block,\n                        seqlen,\n                        thr_mma_dV,\n                        thr_mma_dK,\n                        tdVtdV,\n                        tdKtdK,\n                        mdV,\n                        mdK,\n                        pipeline_dKV,\n                        consumer_state_dKV,\n                        softmax_scale,\n                    )\n                else:\n                    thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(dp_idx)\n                    #### STORE dV\n                    consumer_state_dKV = self.epilogue_dK_or_dV_tma(\n                        dp_idx,\n                        batch_idx,\n                        head_idx,\n                        n_block,\n                        seqlen,\n                        thr_mma_dV,\n                        tdVtdV,\n                        mdV_tma_tensor,\n                        sdV,\n                        tma_atom_dV,\n                        thr_copy_r2s_dKV,\n                        pipeline_dKV,\n                        consumer_state_dKV,\n                        None,  # Don't scale\n                        int(NamedBarrierBwdSm100.EpilogueWG1),  # barrier_id\n                        mdV_semaphore,\n                        \"V\",\n                    )\n                    #### STORE dK\n                    consumer_state_dKV = self.epilogue_dK_or_dV_tma(\n                        dp_idx,\n                        batch_idx,\n                        head_idx,\n                        n_block,\n                        seqlen,\n                        thr_mma_dK,\n                        tdKtdK,\n                        mdK_tma_tensor,\n                        sdK,\n                        
tma_atom_dK,\n                        thr_copy_r2s_dKV,\n                        pipeline_dKV,\n                        consumer_state_dKV,\n                        softmax_scale if const_expr(not self.dKV_postprocess) else None,\n                        int(NamedBarrierBwdSm100.EpilogueWG1),  # barrier_id\n                        mdK_semaphore,\n                        \"K\",\n                    )\n            # Zero dK/dV for empty tiles (local attention or block sparsity)\n            # When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile\n            if const_expr(not self.dKV_postprocess):\n                should_zero_dKV = False\n                if const_expr(self.is_local or self.is_varlen_q):\n                    should_zero_dKV = m_block_min >= m_block_max\n                if const_expr(self.use_block_sparsity):\n                    # For block sparsity, zero when no m_blocks contribute to this n_block\n                    if not process_tile:\n                        should_zero_dKV = True\n\n                if should_zero_dKV:\n                    # For 2-CTA: use cluster-wide tile size (cta_group_size * tile_n)\n                    cluster_tile_n = self.tile_n * self.cta_group_size\n                    n_block_for_tile = n_block // self.cta_group_size\n                    gmem_tiled_copy_zero_dK = copy_utils.tiled_copy_2d(\n                        self.dk_dtype,\n                        math.gcd(64, self.tile_hdim),\n                        128,  # num_threads\n                    )\n                    gmem_tiled_copy_zero_dV = copy_utils.tiled_copy_2d(\n                        self.dv_dtype,\n                        math.gcd(64, self.tile_hdimv),\n                        128,  # num_threads\n                    )\n                    gmem_thr_copy_zero_dK = gmem_tiled_copy_zero_dK.get_slice(dp_idx)\n                    gmem_thr_copy_zero_dV = gmem_tiled_copy_zero_dV.get_slice(dp_idx)\n                    mdV_cur = 
seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx]\n                    mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx]\n                    gdK = cute.local_tile(\n                        mdK_cur, (cluster_tile_n, self.tile_hdim), (n_block_for_tile, 0)\n                    )\n                    gdV = cute.local_tile(\n                        mdV_cur, (cluster_tile_n, self.tile_hdimv), (n_block_for_tile, 0)\n                    )\n                    tdKgdK = gmem_thr_copy_zero_dK.partition_D(gdK)\n                    tdVgdV = gmem_thr_copy_zero_dV.partition_D(gdV)\n                    cdK = cute.make_identity_tensor((cluster_tile_n, self.tile_hdim))\n                    cdV = cute.make_identity_tensor((cluster_tile_n, self.tile_hdimv))\n                    tdKcdK = gmem_thr_copy_zero_dK.partition_D(cdK)\n                    tdVcdV = gmem_thr_copy_zero_dV.partition_D(cdV)\n                    assert cute.size(tdKgdK[None, 0, 0]) == cute.size(tdVgdV[None, 0, 0])\n                    zero = cute.make_fragment_like(tdKgdK[None, 0, 0])\n                    zero.fill(0.0)\n                    if tidx < 128:\n                        for i in cutlass.range_constexpr(tdKgdK.shape[1]):\n                            row_idx = tdKcdK[0, i, 0][0]\n                            if row_idx < seqlen.seqlen_k - cluster_tile_n * n_block_for_tile:\n                                for j in cutlass.range_constexpr(tdKgdK.shape[2]):\n                                    cute.copy(gmem_tiled_copy_zero_dK, zero, tdKgdK[None, i, j])\n                    else:\n                        for i in cutlass.range_constexpr(tdVgdV.shape[1]):\n                            row_idx = tdVcdV[0, i, 0][0]\n                            if row_idx < seqlen.seqlen_k - cluster_tile_n * n_block_for_tile:\n                                for j in cutlass.range_constexpr(tdVgdV.shape[2]):\n                                    cute.copy(gmem_tiled_copy_zero_dV, zero, 
tdVgdV[None, i, j])\n\n            tile_scheduler.advance_to_next_work()\n            work_tile = tile_scheduler.get_current_work()\n\n    @cute.jit\n    def dQacc_reduce(\n        self,\n        mdQaccum: cute.Tensor,\n        sdQaccum: cute.Tensor,\n        thr_mma_dQ: cute.core.ThrMma,\n        tdQtdQ: cute.Tensor,\n        pipeline_dQ: PipelineAsync,\n        dQaccum_empty_mbar_ptr: Optional[cute.Pointer],\n        block_info: BlockInfo,\n        SeqlenInfoCls: Callable,\n        TileSchedulerCls: Callable,\n        mdQ_semaphore: Optional[cute.Tensor],\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n    ):\n        num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids)\n        tidx = cute.arch.thread_idx()[0] % num_reduce_threads\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx() % len(self.reduce_warp_ids))\n        is_tma_warp = warp_idx == 0\n        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())\n        # TMEM -> RMEM\n        tmem_load_atom = cute.make_copy_atom(\n            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol_t2r)), Float32\n        )\n        thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ).get_slice(tidx)\n        tdQtdQ_t2r = thr_copy_t2r.partition_S(tdQtdQ)\n        tdQcdQ = thr_mma_dQ.partition_C(cute.make_identity_tensor(self.mma_tiler_dsk[:2]))\n        tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape\n        # For 2-CTA: reduce_stage = dQaccum_reduce_stage_t2r / cta_group_size\n        expected_reduce_stages_t2r = self.dQaccum_reduce_stage_t2r // self.cta_group_size\n        assert cute.size(tdQrdQ_t2r_shape, mode=[1]) == expected_reduce_stages_t2r, (\n            \"dQaccum t2r reduce stage mismatch\"\n        )\n        expected_reduce_stages = self.dQaccum_reduce_stage // self.cta_group_size\n        # 2-CTA: CTA 0 -> (M/2, D) (stage 0, 1) & CTA 1 -> (M/2, D) (stage 2, 3)\n        stage_offset = (\n  
          expected_reduce_stages * cta_rank_in_cluster if const_expr(self.use_2cta_instrs) else 0\n        )\n\n        thr_copy_dQaccum_r2s = copy_utils.tiled_copy_1d(\n            self.dqaccum_dtype, num_reduce_threads, num_copy_elems=128 // self.dqaccum_dtype.width\n        ).get_slice(tidx)\n        tdQsdQ = thr_copy_dQaccum_r2s.partition_D(sdQaccum)\n\n        read_flag = const_expr(not self.deterministic)\n\n        tile_scheduler = TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        dQ_consumer_state = pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Consumer, 1\n        )\n        dQ_tma_store_producer_state = pipeline.make_pipeline_state(\n            pipeline.PipelineUserType.Producer, self.sdQaccum_stage\n        )\n        while work_tile.is_valid_tile:\n            n_block, head_idx, batch_idx, _ = work_tile.tile_idx\n            n_block_cta_group = n_block // self.cta_group_size  # for 2cta\n            seqlen = SeqlenInfoCls(batch_idx)\n            m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block_cta_group)\n            if const_expr(not seqlen.has_cu_seqlens_q):\n                mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]\n            else:\n                mdQaccum_cur = cute.domain_offset(\n                    (seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx]\n                )\n            gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,))\n            # (M * K / STAGE, STAGE, _)\n            gdQaccum = cute.flat_divide(\n                gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,)\n            )\n\n            if const_expr(self.deterministic):\n                mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx]\n\n            # delay_semaphore_release = self.is_causal and not self.tile_hdim == 192\n            delay_semaphore_release = not self.tile_hdim == 
192\n\n            # some tiles might be empty due to block sparsity\n            if const_expr(self.use_block_sparsity):\n                (\n                    curr_q_cnt,\n                    curr_q_idx,\n                    curr_full_cnt,\n                    curr_full_idx,\n                    loop_count,\n                ) = get_block_sparse_iteration_info_bwd(\n                    blocksparse_tensors,\n                    batch_idx,\n                    head_idx,\n                    n_block,\n                    subtile_factor=self.subtile_factor,\n                    m_block_max=m_block_max,\n                )\n                process_tile = loop_count > Int32(0)\n            else:\n                process_tile = (\n                    const_expr(not self.is_local and not self.is_varlen_q)\n                    or m_block_min < m_block_max\n                )\n                loop_count = m_block_max - m_block_min\n\n            # dQacc_reduce mainloop\n            # Block sparsity: iterate over sparse m_block count and derive actual m_block\n            # from Q_IDX/FULL_Q_IDX tensors. 
Dense: iterate m_block_min..m_block_max directly.\n            for iter_idx in cutlass.range(loop_count, unroll=1):\n                if const_expr(self.use_block_sparsity):\n                    m_block, _ = get_m_block_from_iter_bwd(\n                        iter_idx,\n                        curr_q_cnt,\n                        curr_q_idx,\n                        curr_full_cnt,\n                        curr_full_idx,\n                        subtile_factor=self.subtile_factor,\n                        m_block_max=m_block_max,\n                    )\n                    if m_block_max > 0:\n                        m_block = cutlass.min(m_block, m_block_max - 1)\n                else:\n                    m_block = m_block_min + iter_idx\n                pipeline_dQ.consumer_wait(dQ_consumer_state)\n                # TMEM -> RMEM\n                tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32)\n                cute.copy(thr_copy_t2r, tdQtdQ_t2r, tdQrdQ_t2r)\n                cute.arch.fence_view_async_tmem_load()\n                cute.arch.sync_warp()\n                with cute.arch.elect_one():\n                    pipeline_dQ.consumer_release(dQ_consumer_state)\n                dQ_consumer_state.advance()\n\n                gdQaccum_cur = gdQaccum[None, None, m_block]\n\n                tdQrdQ_shape = (\n                    self.dQ_reduce_ncol,\n                    self.tile_hdim // self.cta_group_size // self.dQ_reduce_ncol,\n                )\n                tdQrdQ = cute.make_tensor(tdQrdQ_t2r.iterator, tdQrdQ_shape)\n\n                for stage in cutlass.range_constexpr(cute.size(tdQrdQ, mode=[1])):\n                    smem_idx = dQ_tma_store_producer_state.index\n                    tdQsdQ_r2s = tdQsdQ[None, None, smem_idx]\n                    tdQrdQ_r2s = cute.make_tensor(tdQrdQ[None, stage].iterator, tdQsdQ_r2s.shape)\n                    cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s)\n                    # Fence and barrier to make 
sure shared memory store is visible to TMA store\n                    cute.arch.fence_view_async_shared()\n                    # semaphore acquire\n                    if const_expr(self.deterministic and stage == 0):\n                        if const_expr(self.spt):\n                            _, n_block_max_for_m_block = block_info.get_n_block_min_max(\n                                seqlen, m_block\n                            )\n                            lock_value = n_block_max_for_m_block - 1 - n_block_cta_group\n                        else:\n                            lock_value = n_block_cta_group\n                        barrier.wait_eq(\n                            mdQ_semaphore_cur[(m_block, None)].iterator,\n                            tidx,\n                            cta_rank_in_cluster,\n                            lock_value,\n                        )\n                    self.reduce_sync_barrier.arrive_and_wait()\n                    # Copy from shared memory to global memory\n                    if is_tma_warp:\n                        with cute.arch.elect_one():\n                            copy_utils.cpasync_reduce_bulk_add_f32(\n                                sdQaccum[None, smem_idx].iterator,\n                                gdQaccum_cur[None, stage + stage_offset].iterator,\n                                self.tma_copy_bytes[\"dQ\"] // 1,\n                            )\n                        cute.arch.cp_async_bulk_commit_group()\n                        cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag)\n                    self.reduce_sync_barrier.arrive_and_wait()\n                    dQ_tma_store_producer_state.advance()\n                    # Directly add to gmem, much slower\n                    # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block])\n                    # assert cute.size(tdQrdQ_r2s) == cute.size(tdQgdQ)\n                    # for i in 
cutlass.range(cute.size(tdQrdQ_r2s) // 4, unroll_full=True):\n                    #     copy_utils.atomic_add_fp32x4(\n                    #         tdQrdQ_r2s[4 * i],\n                    #         tdQrdQ_r2s[4 * i + 1],\n                    #         tdQrdQ_r2s[4 * i + 2],\n                    #         tdQrdQ_r2s[4 * i + 3],\n                    #         utils.elem_pointer(tdQgdQ, 4 * i),\n                    #     )\n                    # semaphore release for prior m_block\n                    if const_expr(self.deterministic and stage == 0 and delay_semaphore_release):\n                        if m_block > m_block_min:\n                            barrier.arrive_inc(\n                                mdQ_semaphore_cur[(m_block - 1, None)].iterator,\n                                tidx,\n                                cta_rank_in_cluster,\n                                1,\n                            )\n\n                if const_expr(self.tile_hdim == 192):\n                    if const_expr(self.sdQaccum_stage > 1):\n                        if is_tma_warp:\n                            cute.arch.cp_async_bulk_wait_group(0, read=read_flag)\n                        self.reduce_sync_barrier.arrive_and_wait()\n                    with cute.arch.elect_one():\n                        cute.arch.mbarrier_arrive(dQaccum_empty_mbar_ptr)\n\n                # semaphore release\n                # NOTE: arrive_inc calls red_release which issues membar\n                if const_expr(self.deterministic and not delay_semaphore_release):\n                    if const_expr(self.sdQaccum_stage > 1 and not self.tile_hdim == 192):\n                        if is_tma_warp:\n                            cute.arch.cp_async_bulk_wait_group(0, read=read_flag)\n                        self.reduce_sync_barrier.arrive_and_wait()\n                    barrier.arrive_inc(\n                        mdQ_semaphore_cur[m_block, None].iterator, tidx, cta_rank_in_cluster, 1\n                    
)\n\n            if process_tile:\n                if is_tma_warp:\n                    cute.arch.cp_async_bulk_wait_group(0, read=read_flag)\n                self.reduce_sync_barrier.arrive_and_wait()\n                # final semaphore release\n                if const_expr(self.deterministic and delay_semaphore_release):\n                    barrier.arrive_inc(\n                        mdQ_semaphore_cur[(m_block_max - 1, None)].iterator,\n                        tidx,\n                        cta_rank_in_cluster,\n                        1,\n                    )\n\n            if const_expr(\n                self.deterministic and not self.spt and block_info.window_size_left is not None\n            ):\n                m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m)\n                for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1):\n                    barrier.arrive_inc(\n                        mdQ_semaphore_cur[(m_block, None)].iterator, tidx, cta_rank_in_cluster, 1\n                    )\n\n            tile_scheduler.advance_to_next_work()\n            work_tile = tile_scheduler.get_current_work()\n\n        if const_expr(not self.deterministic):\n            cute.arch.cp_async_bulk_wait_group(0, read=True)\n\n    @cute.jit\n    def epilogue_dKV(\n        self,\n        tidx: Int32,\n        warp_idx: Int32,\n        batch_idx: Int32,\n        head_idx: Int32,\n        n_block: Int32,\n        seqlen,\n        thr_mma_dV: cute.core.ThrMma,\n        thr_mma_dK: cute.core.ThrMma,\n        tdVtdV: cute.Tensor,\n        tdKtdK: cute.Tensor,\n        mdV: cute.Tensor,\n        mdK: cute.Tensor,\n        pipeline_dKV: PipelineAsync,\n        consumer_state_dKV: cutlass.pipeline.PipelineState,\n        softmax_scale: Float32,\n    ):\n        wg_idx = (\n            cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))\n        ) // 128\n        num_wg = cute.arch.WARP_SIZE * 
len(self.compute_warp_ids) // 128\n\n        assert self.qhead_per_kvhead == 1, \"This epilogue path is only for MHA\"\n        mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx]\n        mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx]\n\n        tmem_load_atom = cute.make_copy_atom(\n            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32\n        )\n        # dV\n        pipeline_dKV.consumer_wait(consumer_state_dKV)\n\n        tiled_tmem_ld_dV = tcgen05.make_tmem_copy(tmem_load_atom, tdVtdV)\n        thr_tmem_ld_dV = tiled_tmem_ld_dV.get_slice(tidx)\n\n        tdVtdV_t2r_p = thr_tmem_ld_dV.partition_S(tdVtdV)\n        tdVtdV_t2r = self.split_wg(tdVtdV_t2r_p, wg_idx, num_wg)\n\n        cdV = cute.make_identity_tensor((self.mma_tiler_pdo[0], self.mma_tiler_pdo[1]))\n        tdVcdV = thr_mma_dV.partition_C(cdV)\n        tdVcdV_tensor = cute.make_tensor(tdVcdV.iterator, tdVcdV.layout)\n\n        tdVcdV_t2r_p = thr_tmem_ld_dV.partition_D(tdVcdV_tensor)\n        tdVcdV_t2r = self.split_wg(tdVcdV_t2r_p, wg_idx, num_wg)\n        tdVrdV_t2r = cute.make_fragment(tdVcdV_t2r.shape, Float32)\n\n        cute.copy(thr_tmem_ld_dV, tdVtdV_t2r, tdVrdV_t2r)\n        cute.arch.fence_view_async_tmem_load()\n\n        universal_copy_bits = 128\n        atom_universal_copy = cute.make_copy_atom(\n            cute.nvgpu.CopyUniversalOp(),\n            self.dv_dtype,\n            num_bits_per_copy=universal_copy_bits,\n        )\n        tiled_gmem_store_dV = cute.make_tiled_copy(\n            atom_universal_copy,\n            layout_tv=tiled_tmem_ld_dV.layout_dst_tv_tiled,\n            tiler_mn=tiled_tmem_ld_dV.tiler_mn,\n        )\n\n        tdVrdV_r2s = cute.make_fragment(tdVrdV_t2r.shape, self.dv_dtype)\n        for i in cutlass.range_constexpr(cute.size(tdVrdV_t2r, mode=[1])):\n            dV_vec = tdVrdV_t2r[(None, i, 0, 0)].load()\n            tdVrdV_r2s[(None, i, 0, 0)].store(dV_vec.to(self.dv_dtype))\n\n   
     gdV = cute.local_tile(mdV_cur, (self.mma_tiler_pdo[0], self.tile_hdimv), (None, 0))\n        gdV_tile = gdV[None, None, n_block // self.cta_group_size]\n\n        tdVgdV = thr_mma_dV.partition_C(gdV_tile)\n        tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV)\n        tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg)\n\n        if tidx < seqlen.seqlen_k - self.tile_n * n_block:\n            cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g)\n\n        cute.arch.sync_warp()\n        with cute.arch.elect_one():\n            pipeline_dKV.consumer_release(consumer_state_dKV)\n        consumer_state_dKV.advance()\n\n        # dK\n        pipeline_dKV.consumer_wait(consumer_state_dKV)\n\n        tiled_tmem_ld_dK = tcgen05.make_tmem_copy(tmem_load_atom, tdKtdK)\n        thr_tmem_ld_dK = tiled_tmem_ld_dK.get_slice(tidx)\n\n        tdKtdK_t2r_p = thr_tmem_ld_dK.partition_S(tdKtdK)\n        tdKtdK_t2r = self.split_wg(tdKtdK_t2r_p, wg_idx, num_wg)\n\n        cdK = cute.make_identity_tensor((self.mma_tiler_dsq[0], self.mma_tiler_dsq[1]))\n        tdKcdK = thr_mma_dK.partition_C(cdK)\n        tdKcdK_tensor = cute.make_tensor(tdKcdK.iterator, tdKcdK.layout)\n\n        tdKcdK_t2r_p = thr_tmem_ld_dK.partition_D(tdKcdK_tensor)\n        tdKcdK_t2r = self.split_wg(tdKcdK_t2r_p, wg_idx, num_wg)\n        tdKrdK_t2r = cute.make_fragment(tdKcdK_t2r.shape, Float32)\n\n        cute.copy(tiled_tmem_ld_dK, tdKtdK_t2r, tdKrdK_t2r)\n        cute.arch.fence_view_async_tmem_load()\n\n        universal_copy_bits = 128\n        atom_universal_copy = cute.make_copy_atom(\n            cute.nvgpu.CopyUniversalOp(),\n            self.dk_dtype,\n            num_bits_per_copy=universal_copy_bits,\n        )\n\n        tiled_gmem_store_dK = cute.make_tiled_copy(\n            atom_universal_copy,\n            layout_tv=tiled_tmem_ld_dK.layout_dst_tv_tiled,\n            tiler_mn=tiled_tmem_ld_dK.tiler_mn,\n        )\n\n        tdKrdK_r2s = cute.make_fragment(tdKrdK_t2r.shape, 
self.dk_dtype)\n\n        for i in cutlass.range_constexpr(cute.size(tdKrdK_t2r, mode=[1])):\n            dK_vec = tdKrdK_t2r[(None, i, 0, 0)].load() * softmax_scale\n            tdKrdK_r2s[(None, i, 0, 0)].store(dK_vec.to(self.dk_dtype))\n\n        gdK = cute.local_tile(mdK_cur, (self.mma_tiler_dsq[0], self.tile_hdim), (None, 0))\n        gdK_tile = gdK[None, None, n_block // self.cta_group_size]\n\n        tdKgdK = thr_mma_dK.partition_C(gdK_tile)\n        tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK)\n        tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg)\n\n        if tidx < seqlen.seqlen_k - self.tile_n * n_block:\n            cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g)\n\n        cute.arch.sync_warp()\n        with cute.arch.elect_one():\n            pipeline_dKV.consumer_release(consumer_state_dKV)\n        return consumer_state_dKV\n\n    @cute.jit\n    def epilogue_dK_or_dV_tma(\n        self,\n        tidx: Int32,\n        batch_idx: Int32,\n        head_idx: Int32,\n        n_block: Int32,\n        seqlen,\n        thr_mma: cute.core.ThrMma,\n        tdKVtdKV: cute.Tensor,\n        mdKV: cute.Tensor,\n        sdKV: cute.Tensor,\n        tma_atom_dKV: cute.CopyAtom,\n        thr_copy_r2s_dKV: cute.TiledCopy,\n        pipeline_dKV: PipelineAsync,\n        consumer_state_dKV: cutlass.pipeline.PipelineState,\n        scale: Optional[Float32],\n        barrier_id: Int32,\n        mdKV_semaphore: Optional[cute.Tensor],\n        K_or_V: cutlass.Constexpr[str],\n    ) -> cutlass.pipeline.PipelineState:\n        assert K_or_V in (\"K\", \"V\")\n        tile_hdim = self.tile_hdim if const_expr(K_or_V == \"K\") else self.tile_hdimv\n        dtype = self.dk_dtype if const_expr(K_or_V == \"K\") else self.dv_dtype\n        epi_tile = self.sdK_epi_tile if const_expr(K_or_V == \"K\") else self.sdV_epi_tile\n        flat_epi_tile = (\n            self.sdK_flat_epi_tile if const_expr(K_or_V == \"K\") else self.sdV_flat_epi_tile\n        )\n        
num_compute_threads = cute.arch.WARP_SIZE * len(self.compute_warp_ids)\n        wg_idx = (cute.arch.thread_idx()[0] % num_compute_threads) // 128\n        num_wg = num_compute_threads // 128\n        leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0\n\n        cta_group_tile_n = const_expr(self.tile_n * self.cta_group_size)\n\n        if const_expr(not self.dKV_postprocess):\n            sdKV = sdKV[None, None, wg_idx]  # (tile_n, 64) for bf16\n        else:\n            sdKV = sdKV[None, wg_idx]  # (tile_n * 32) for fp32\n\n        # (8, tile_n / 128, 64 / 8) = (8, 1, 8) or (4, tile_n * 32 / (128 * 4)) = (4, 8)\n        tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV)\n\n        head_idx_kv = head_idx // self.qhead_per_kvhead\n        if const_expr(not self.dKV_postprocess):\n            assert not seqlen.has_cu_seqlens_k, \"varlen uses non tma store path\"\n            mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx]  # (seqlen, hdim)\n            gdKV_p = cute.local_tile(\n                mdKV_cur, (self.tile_n, tile_hdim), (n_block, 0)\n            )  # (tile_n, hdim) - per CTA\n            gdKV = self.split_wg(gdKV_p, wg_idx, num_wg)  # (tile_n, hdim / 2)\n            gdKV_epi = cute.local_tile(\n                gdKV, epi_tile, (0, None)\n            )  # (tile_n, 64, epi_stage = (hdim / 2) / 64)\n        else:\n            # n_block_group = n_block // self.cta_group_size\n            if const_expr(not seqlen.has_cu_seqlens_k):\n                mdKV_cur = mdKV[None, head_idx_kv, batch_idx]  # (seqlen * hdim)\n            else:\n                mdKV_cur = cute.domain_offset(\n                    (seqlen.padded_offset_k * tile_hdim,), mdKV[None, head_idx_kv]\n                )\n            gdKV_p = cute.local_tile(\n                mdKV_cur, (self.tile_n * tile_hdim,), (n_block,)\n            )  # (tile_n * hdim)\n            gdKV = cute.logical_divide(gdKV_p, (self.tile_n * tile_hdim // num_wg,))[\n                ((None, 
wg_idx),)\n            ]  # (tile_n * hdim / 2)\n            gdKV_epi = cute.flat_divide(\n                gdKV, (flat_epi_tile,)\n            )  # (tile_n * hdim / 2 / epi_stage, epi_stage)\n\n        deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1\n        if const_expr(deterministic_KV):\n            mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx]\n\n        if const_expr(not self.dKV_postprocess):\n            tdKVsdKV, tdKVgdKV = cpasync.tma_partition(\n                tma_atom_dKV,\n                0,  # no multicast\n                cute.make_layout(1),\n                cute.group_modes(sdKV, 0, 2),\n                cute.group_modes(gdKV_epi, 0, 2),\n            )  # (TMA) and (TMA, EPI_STAGE)\n            assert len(tdKVsdKV.shape) == 1, \"Wrong rank for SMEM fragment tdKVsdKV\"\n            assert len(tdKVgdKV.shape) == 2, \"Wrong rank for GMEM fragment tdKVgdKV\"\n            num_epi_stages = cute.size(tdKVgdKV.shape[1])\n            if const_expr(K_or_V == \"K\"):\n                assert num_epi_stages == self.num_epi_stages, \"Epi stage calculation is wrong (K)\"\n            else:\n                assert num_epi_stages == self.num_epi_stages_v, \"Epi stage calculation is wrong (V)\"\n        else:\n            num_epi_stages = (\n                self.num_epi_stages if const_expr(K_or_V == \"K\") else self.num_epi_stages_v\n            )\n\n        tmem_load_atom = cute.make_copy_atom(\n            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dK_reduce_ncol)), Float32\n        )\n\n        read_flag = const_expr(not deterministic_KV)\n\n        pipeline_dKV.consumer_wait(consumer_state_dKV)\n\n        # semaphore acquire\n        if const_expr(deterministic_KV):\n            barrier.wait_eq(\n                mdKV_semaphore_cur.iterator, tidx, wg_idx, head_idx % self.qhead_per_kvhead\n            )\n            cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128)\n\n        for 
epi_stage in cutlass.range_constexpr(num_epi_stages):\n            # TMEM -> RMEM -- setup\n            thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV).get_slice(tidx)\n            tdKVtdKV_t2r_p = thr_copy_t2r.partition_S(tdKVtdKV)\n            tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0]\n            if const_expr(num_epi_stages > 1):\n                tdKVtdKV_t2r = tdKVtdKV_t2r[None, epi_stage]\n\n            cdKV = cute.make_identity_tensor((cta_group_tile_n, tile_hdim))\n            tdKVcdKV = thr_mma.partition_C(cdKV)\n            tdKVcdKV_t2r_p = thr_copy_t2r.partition_D(tdKVcdKV)\n            tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0]\n            if const_expr(num_epi_stages > 1):\n                tdKVcdKV_t2r = tdKVcdKV_t2r[None, epi_stage]\n\n            tdKVrdKV_t2r = cute.make_fragment(tdKVcdKV_t2r.shape, Float32)\n\n            assert cute.size(tdKVrdKV_t2r) == cute.size(tdKVtdKV_t2r) // cute.arch.WARP_SIZE, (\n                \"RMEM<->TMEM fragment size mismatch\"\n            )\n\n            # TMEM -> RMEM -- copy and fence\n            cute.copy(thr_copy_t2r, tdKVtdKV_t2r, tdKVrdKV_t2r)\n            cute.arch.fence_view_async_tmem_load()\n\n            # RMEM -- scale and convert\n            if const_expr(scale is not None):\n                for i in cutlass.range(cute.size(tdKVrdKV_t2r.shape) // 2, unroll_full=True):\n                    tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1] = cute.arch.mul_packed_f32x2(\n                        (tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1]), (scale, scale)\n                    )\n            tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, dtype)  # (32 columns)\n            tdKVrdKV.store(tdKVrdKV_t2r.load().to(dtype))\n\n            # RMEM -> SMEM -- copy, fence and barrier\n            tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVsdKV_r2s.shape)\n            cute.copy(thr_copy_r2s_dKV, tdKVrdKV_r2s, 
tdKVsdKV_r2s)\n            cute.arch.fence_view_async_shared()\n            cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128)\n\n            # SMEM -> GMEM\n            if leader_warp:\n                if const_expr(not self.dKV_postprocess):\n                    cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, epi_stage])\n                else:\n                    with cute.arch.elect_one():\n                        copy_utils.cpasync_reduce_bulk_add_f32(\n                            sdKV.iterator,\n                            gdKV_epi[None, epi_stage].iterator,\n                            self.tma_copy_bytes[\"dKacc\"],\n                        )\n                if const_expr(epi_stage < num_epi_stages - 1):\n                    cute.arch.cp_async_bulk_commit_group()\n                    cute.arch.cp_async_bulk_wait_group(0, read=read_flag)\n                cute.arch.barrier_arrive(\n                    barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE\n                )\n\n            # Barrier since all warps need to wait for SMEM to be freed\n            cute.arch.fence_view_async_shared()\n            cute.arch.barrier(\n                barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE\n            )\n\n        # semaphore release\n        # NOTE: arrive_inc calls red_release which issues membar\n        if const_expr(deterministic_KV):\n            if leader_warp:\n                cute.arch.cp_async_bulk_commit_group()\n                cute.arch.cp_async_bulk_wait_group(0, read=read_flag)\n            cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128)\n            barrier.arrive_inc(mdKV_semaphore_cur.iterator, tidx, wg_idx, 1)\n\n        cute.arch.sync_warp()\n        with cute.arch.elect_one():\n            pipeline_dKV.consumer_release(consumer_state_dKV)\n        consumer_state_dKV.advance()\n        return consumer_state_dKV\n"
  },
  {
    "path": "flash_attn/cute/flash_bwd_sm120.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n# SM120 (Blackwell GeForce / DGX Spark) backward pass.\n#\n# SM120 uses the same SM80-era MMA instructions (mma.sync.aligned.m16n8k16) but has\n# a smaller shared memory capacity (99 KB vs 163 KB on SM80). This module subclasses\n# FlashAttentionBackwardSm80 and overrides the SMEM capacity check accordingly.\n\nimport cutlass\nimport cutlass.utils as utils_basic\n\nfrom flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80\n\n\nclass FlashAttentionBackwardSm120(FlashAttentionBackwardSm80):\n    @staticmethod\n    def can_implement(\n        dtype,\n        head_dim,\n        head_dim_v,\n        m_block_size,\n        n_block_size,\n        num_stages_Q,\n        num_stages_dO,\n        num_threads,\n        is_causal,\n        V_in_regs=False,\n    ) -> bool:\n        \"\"\"Check if the kernel can be implemented on SM120.\n\n        Same logic as SM80 but uses SM120's shared memory capacity (99 KB).\n        \"\"\"\n        if dtype not in [cutlass.Float16, cutlass.BFloat16]:\n            return False\n        if head_dim % 8 != 0:\n            return False\n        if head_dim_v % 8 != 0:\n            return False\n        if n_block_size % 16 != 0:\n            return False\n        if num_threads % 32 != 0:\n            return False\n        # Shared memory usage: Q tile + dO tile + K tile + V tile\n        smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2\n        smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2\n        smem_usage_K = n_block_size * head_dim * 2\n        smem_usage_V = n_block_size * head_dim_v * 2\n        smem_usage_QV = (\n            (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V)\n        )\n        smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K\n        # SM120 has 99 KB shared memory (vs 163 KB on SM80)\n        smem_capacity = 
utils_basic.get_smem_capacity_in_bytes(\"sm_120\")\n        if smem_usage > smem_capacity:\n            return False\n        return True\n"
  },
  {
    "path": "flash_attn/cute/flash_bwd_sm90.py",
    "content": "import math\nfrom typing import Callable, Optional, Type\nfrom functools import partial\n\nimport cuda.bindings.driver as cuda\n\nimport cutlass\nimport cutlass.cute as cute\nimport cutlass.utils.hopper_helpers as sm90_utils_basic\nfrom cutlass.cute.nvgpu import cpasync, warpgroup\nfrom cutlass.cute import FastDivmodDivisor\nfrom cutlass import Float32, Int32, Boolean, const_expr\nfrom cutlass.utils import LayoutEnum\n\nfrom quack import copy_utils\nfrom quack import layout_utils\nfrom quack import sm90_utils\nfrom quack.sm90_utils import gemm_zero_init, gemm_w_idx\n\nfrom flash_attn.cute.cute_dsl_utils import assume_tensor_aligned\nfrom flash_attn.cute import utils\nfrom flash_attn.cute.mask import AttentionMask\nfrom flash_attn.cute.seqlen_info import SeqlenInfoQK\nfrom flash_attn.cute.block_info import BlockInfo\nfrom flash_attn.cute import pipeline\nfrom quack.cute_dsl_utils import ParamsBase\nfrom flash_attn.cute.tile_scheduler import (\n    TileSchedulerArguments,\n    SingleTileScheduler,\n    SingleTileLPTBwdScheduler,\n    SingleTileVarlenScheduler,\n)\nfrom flash_attn.cute import barrier\nfrom flash_attn.cute.named_barrier import NamedBarrierBwd\nfrom flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner\nfrom flash_attn.cute.block_sparsity import BlockSparseTensors\nfrom flash_attn.cute.block_sparse_utils import (\n    get_total_q_block_count_bwd,\n    produce_block_sparse_q_loads_bwd_sm90,\n    consume_block_sparse_mma_bwd_sm90,\n    dQaccum_store_block_sparse_bwd_sm90,\n)\n\n\nclass FlashAttentionBackwardSm90:\n    arch = 90\n\n    def __init__(\n        self,\n        dtype: Type[cutlass.Numeric],\n        head_dim: int,\n        head_dim_v: Optional[int] = None,\n        qhead_per_kvhead: int = 1,\n        is_causal: bool = False,\n        is_local: bool = False,\n        deterministic: bool = False,\n        tile_m: int = 64,\n        tile_n: int = 128,\n        Q_stage: int = 2,\n        dO_stage: int = 
2,\n        PdS_stage: int = 2,\n        SdP_swapAB: bool = False,\n        dKV_swapAB: bool = False,\n        dQ_swapAB: bool = False,\n        AtomLayoutMSdP: int = 1,\n        AtomLayoutNdKV: int = 2,\n        AtomLayoutMdQ: int = 1,\n        num_threads: int = 384,\n        V_in_regs: bool = False,\n        score_mod: cutlass.Constexpr | None = None,\n        score_mod_bwd: cutlass.Constexpr | None = None,\n        mask_mod: cutlass.Constexpr | None = None,\n        has_aux_tensors: cutlass.Constexpr = False,\n        subtile_factor: cutlass.Constexpr[int] = 1,\n        dQ_single_wg: bool = False,\n    ):\n        self.dtype = dtype\n        # padding head_dim to a multiple of 16 as k_block_size\n        hdim_multiple_of = 16\n        self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)\n        head_dim_v = head_dim_v if head_dim_v is not None else head_dim\n        self.same_hdim_kv = head_dim == head_dim_v\n        self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)\n        # Can save registers (and hence be faster) if we don't have to check hdim predication\n        self.check_hdim_oob = head_dim != self.tile_hdim\n        self.check_hdim_v_oob = head_dim_v != self.tile_hdimv\n        self.qhead_per_kvhead = qhead_per_kvhead\n        self.is_causal = is_causal\n        self.is_local = is_local\n        self.deterministic = deterministic\n        self.tile_m = tile_m\n        self.tile_n = tile_n\n        self.num_threads = num_threads\n        self.Q_stage = Q_stage\n        self.dO_stage = dO_stage\n        self.PdS_stage = PdS_stage\n        assert self.dO_stage in [1, self.Q_stage]\n        assert self.PdS_stage in [1, self.Q_stage]\n        self.SdP_swapAB = SdP_swapAB\n        self.dKV_swapAB = dKV_swapAB\n        self.dQ_swapAB = dQ_swapAB\n        self.AtomLayoutMSdP = AtomLayoutMSdP\n        self.AtomLayoutNdKV = AtomLayoutNdKV\n        self.AtomLayoutMdQ = AtomLayoutMdQ\n        
self.num_wg_mma = (self.num_threads // 128) - 1\n        self.mma_dkv_is_rs = (\n            AtomLayoutMSdP == 1\n            and AtomLayoutNdKV == self.num_wg_mma\n            and SdP_swapAB\n            and not dKV_swapAB\n        )\n        self.V_in_regs = V_in_regs\n        # May be overridden in __call__ for varlen inputs.\n        if qhead_per_kvhead > 1:\n            assert self.same_hdim_kv, \"GQA backward requires head_dim == head_dim_v\"\n            assert self.num_wg_mma == 2, \"GQA backward assumes 2 warp groups\"\n        # These are tuned for speed\n        # Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share\n        # them and then shuffle to get the value whenever we need? This can reduce register\n        # pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4)\n        # rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows.\n        self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64\n        self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64\n\n        self.buffer_align_bytes = 1024\n\n        self.score_mod = score_mod\n        self.score_mod_bwd = score_mod_bwd\n        self.mask_mod = mask_mod\n        self.has_aux_tensors = has_aux_tensors\n        self.subtile_factor = subtile_factor\n        if cutlass.const_expr(has_aux_tensors):\n            self.vec_size: cutlass.Constexpr = 1\n        else:\n            self.vec_size: cutlass.Constexpr = 4\n        self.qk_acc_dtype = Float32\n        # dQ_single_wg: WG0 computes the full dQ GEMM, WG1 skips it.\n        # Only valid for 2 MMA warp groups.\n        # Credit: Ben Spector\n        if dQ_single_wg:\n            assert self.num_wg_mma == 2, \"dQ_single_wg only supports 2 warp groups\"\n        self.num_wg_dQ = 1 if dQ_single_wg else self.num_wg_mma\n\n    @staticmethod\n    def can_implement(\n        dtype,\n        head_dim,\n        head_dim_v,\n        tile_m,\n        
tile_n,\n        Q_stage,\n        num_threads,\n        V_in_regs=False,\n    ) -> bool:\n        if dtype not in [cutlass.Float16, cutlass.BFloat16]:\n            return False\n        if head_dim % 8 != 0:\n            return False\n        if head_dim_v % 8 != 0:\n            return False\n        if tile_n % 16 != 0:\n            return False\n        if num_threads % 32 != 0:\n            return False\n        if (tile_m * 2) % num_threads != 0:\n            return False\n        return True\n\n    def _check_type(\n        self,\n        mQ_type: Type[cutlass.Numeric],\n        mK_type: Type[cutlass.Numeric],\n        mV_type: Type[cutlass.Numeric],\n        mdO_type: Type[cutlass.Numeric],\n        mLSE_type: Type[cutlass.Numeric],\n        mdPsum_type: Type[cutlass.Numeric],\n        mdQaccum_type: Type[cutlass.Numeric],\n        mdK_type: Type[cutlass.Numeric],\n        mdV_type: Type[cutlass.Numeric],\n    ):\n        # Get the data type and check if it is fp16 or bf16\n        if const_expr(not (mQ_type == mK_type == mV_type == mdO_type)):\n            raise TypeError(\"All tensors must have the same data type\")\n        if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]):\n            raise TypeError(\"Only Float16 or BFloat16 is supported\")\n        if const_expr(mLSE_type not in [Float32]):\n            raise TypeError(\"LSE tensor must be Float32\")\n        if const_expr(mdPsum_type not in [Float32]):\n            raise TypeError(\"dPsum tensor must be Float32\")\n        if const_expr(mdQaccum_type not in [Float32]):\n            raise TypeError(\"dQaccum tensor must be Float32\")\n        if const_expr(self.qhead_per_kvhead == 1):\n            if const_expr(not (mdK_type == mdV_type == mQ_type)):\n                raise TypeError(\"mdK and mdV tensors must have the same data type as mQ\")\n        else:\n            if const_expr(not (mdK_type == mdV_type == Float32)):\n                raise TypeError(\"mdKaccum and mdVaccum 
tensors must have the data type Float32\")\n        assert mQ_type == self.dtype\n\n    def _setup_attributes(self):\n        # We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory.\n        # Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma.\n        # The M dimension (tile_m) doesn't matter for the layout, only the K dimension\n        wg_d_dKV = self.num_wg_mma // self.AtomLayoutNdKV\n        self.sQ_layout, self.sdO_layout = [\n            # Need to set major_mode_size (mms) to accommodate Q and Q.T\n            sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage, mms)\n            for shape, stage, mms in [\n                ((self.tile_m, self.tile_hdim), self.Q_stage, self.tile_hdim // wg_d_dKV),\n                ((self.tile_m, self.tile_hdimv), self.dO_stage, self.tile_hdim // wg_d_dKV),\n            ]\n        ]\n        wg_d_dQ = self.num_wg_dQ // self.AtomLayoutMdQ\n        # Accommodate both K and K.T\n        self.sK_layout = sm90_utils.make_smem_layout(\n            self.dtype,\n            LayoutEnum.ROW_MAJOR,\n            (self.tile_n, self.tile_hdim),\n            stage=None,\n            major_mode_size=self.tile_hdim // wg_d_dQ,\n        )\n        # There's only V, no V.T, so layout is normal\n        self.sV_layout = sm90_utils.make_smem_layout(\n            self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_hdimv), None\n        )\n        # Accommodate both S and S.T\n        wg_n_SdP = self.num_wg_mma // self.AtomLayoutMSdP\n        wg_n_dKV = self.AtomLayoutNdKV\n        self.sPdS_layout = sm90_utils.make_smem_layout(\n            self.dtype,\n            LayoutEnum.ROW_MAJOR,\n            (self.tile_m, self.tile_n),\n            stage=self.PdS_stage,\n            major_mode_size=math.gcd(self.tile_n // wg_n_SdP, self.tile_n // wg_n_dKV),\n        )\n        self.sdQaccum_layout = cute.make_layout(\n            (self.tile_m * self.tile_hdim // self.num_wg_dQ, 
self.num_wg_dQ)\n        )\n        # dQaccum R->S\n        self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(\n            cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),\n            # thr_layout\n            cute.make_layout((self.num_threads_per_warp_group, self.num_wg_dQ)),\n            cute.make_layout(128 // Float32.width),  # val_layout\n        )\n        # dKVaccum for GQA epilogue - reuses sV+sK memory recast as f32\n        # TODO: assert that sVaccum and sKaccum don't overflow smem\n\n    def _get_tiled_mma(self):\n        maybe_swap_mn = lambda shape, swap: (shape[1], shape[0], *shape[2:]) if swap else shape\n        # S = Q @ K.T, dP = dO @ V.T\n        atom_layout_SdP = (self.AtomLayoutMSdP, self.num_wg_mma // self.AtomLayoutMSdP, 1)\n        tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1])\n        tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma(\n            self.dtype,\n            self.dtype,\n            warpgroup.OperandMajorMode.K,\n            warpgroup.OperandMajorMode.K,\n            Float32,\n            atom_layout_mnk=maybe_swap_mn(atom_layout_SdP, self.SdP_swapAB),\n            tiler_mn=(64, tiler_mn_SdP[1] if not self.SdP_swapAB else tiler_mn_SdP[0]),\n        )\n        # dV = P.T @ dO, dK = dS.T @ Q\n        atom_layout_dKV = (self.AtomLayoutNdKV, self.num_wg_mma // self.AtomLayoutNdKV, 1)\n        tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1])\n        tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1])\n        tiled_mma_dK, tiled_mma_dV = [\n            sm90_utils_basic.make_trivial_tiled_mma(\n                self.dtype,\n                self.dtype,\n                warpgroup.OperandMajorMode.MN\n                if not self.mma_dkv_is_rs\n                else warpgroup.OperandMajorMode.K,\n                warpgroup.OperandMajorMode.MN,\n                Float32,\n        
        atom_layout_mnk=maybe_swap_mn(atom_layout_dKV, self.dKV_swapAB),\n                tiler_mn=(64, tiler_mn_d[1] if not self.dKV_swapAB else tiler_mn_d[0]),\n                a_source=warpgroup.OperandSource.RMEM\n                if self.mma_dkv_is_rs\n                else warpgroup.OperandSource.SMEM,\n            )\n            for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV)\n        ]\n        # dQ = dS @ K\n        assert self.num_wg_dQ % self.AtomLayoutMdQ == 0\n        atom_layout_dQ = (self.AtomLayoutMdQ, self.num_wg_dQ // self.AtomLayoutMdQ, 1)\n        tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])\n        tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma(\n            self.dtype,\n            self.dtype,\n            warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN,\n            warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K,\n            Float32,\n            atom_layout_mnk=maybe_swap_mn(atom_layout_dQ, self.dQ_swapAB),\n            tiler_mn=(64, tiler_mn_dQ[1] if not self.dQ_swapAB else tiler_mn_dQ[0]),\n        )\n        return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ\n\n    def _get_shared_storage_cls(self):\n        sQ_struct, sK_struct, sV_struct, sdO_struct, sdQaccum_struct = [\n            cute.struct.Align[cute.struct.MemRange[t, cute.cosize(layout)], self.buffer_align_bytes]\n            for (layout, t) in [\n                (self.sQ_layout, self.dtype),\n                (self.sK_layout, self.dtype),\n                (self.sV_layout, self.dtype),\n                (self.sdO_layout, self.dtype),\n                (self.sdQaccum_layout, Float32),\n            ]\n        ]\n\n        cosize_sdS = cute.cosize(self.sPdS_layout)\n        cosize_sP = cute.cosize(self.sPdS_layout) if const_expr(not self.mma_dkv_is_rs) else 0\n        sLSE_struct = cute.struct.Align[\n            cute.struct.MemRange[Float32, 
cute.round_up(self.tile_m, 64) * self.Q_stage], 128\n        ]\n        sdPsum_struct = cute.struct.Align[\n            cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.dO_stage], 128\n        ]\n\n        @cute.struct\n        class SharedStorageQKV:\n            mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.Q_stage * 2]\n            mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.dO_stage * 2]\n            sLSE: sLSE_struct\n            sdPsum: sdPsum_struct\n            sQ: sQ_struct\n            sV: sV_struct\n            sK: sK_struct\n            sdO: sdO_struct\n            sP: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]\n            sdS: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sdS], 1024]\n            sdQaccum: sdQaccum_struct\n\n        return SharedStorageQKV\n\n    @cute.jit\n    def __call__(\n        self,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mV: cute.Tensor,\n        mdO: cute.Tensor,\n        mLSE: cute.Tensor,\n        mdPsum: cute.Tensor,\n        mdQaccum: cute.Tensor,\n        mdK: cute.Tensor,\n        mdV: cute.Tensor,\n        softmax_scale: Float32,\n        mCuSeqlensQ: Optional[cute.Tensor] = None,\n        mCuSeqlensK: Optional[cute.Tensor] = None,\n        mSeqUsedQ: Optional[cute.Tensor] = None,\n        mSeqUsedK: Optional[cute.Tensor] = None,\n        softcap: Float32 | float | None = None,\n        window_size_left: Int32 | int | None = None,\n        window_size_right: Int32 | int | None = None,\n        mdQ_semaphore: Optional[cute.Tensor] = None,\n        mdK_semaphore: Optional[cute.Tensor] = None,\n        mdV_semaphore: Optional[cute.Tensor] = None,\n        aux_tensors: Optional[list] = None,\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n        # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).\n        stream: cuda.CUstream = None,\n    ):\n        # For GQA 
(qhead_per_kvhead > 1), multiple Q heads accumulate into the same dK/dV,\n        # so we need the float32 accum path + postprocess.\n        # For varlen_k with qhead_per_kvhead == 1, we use ragged TMA tensors.\n        self.varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None\n\n        self._check_type(\n            *(\n                t.element_type if t is not None else None\n                for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)\n            )\n        )\n\n        self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None\n\n        mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [\n            assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)\n        ]\n\n        # Non-varlen inputs are (b, s, n, h), varlen inputs are (s, n, h).\n        # We convert both to a seqlen-major view with head-dim second.\n        # Each tensor may have different rank when Q is padded (seqused_q) but K/V are unpadded (cu_seqlens_k).\n        def _qkv_transpose(t):\n            return layout_utils.select(t, [1, 3, 2, 0] if cute.rank(t.shape) == 4 else [0, 2, 1])\n\n        mQ, mK, mV, mdO = [_qkv_transpose(t) for t in (mQ, mK, mV, mdO)]\n        if const_expr(self.qhead_per_kvhead == 1):\n            mdK, mdV = [_qkv_transpose(t) for t in (mdK, mdV)]\n        else:\n            # Accum tensors are (b, n, s*h) for non-varlen and (n, s*h) for varlen.\n            accum_transpose = [2, 1, 0] if cute.rank(mdK.shape) == 3 else [1, 0]\n            mdK, mdV = [layout_utils.select(t, accum_transpose) for t in (mdK, mdV)]\n        # Non-varlen stats are (b, n, s), varlen stats are (n, s).\n        LSE_dPsum_dQaccum_transpose = [2, 1, 0] if cute.rank(mLSE.shape) == 3 else [1, 0]\n        mLSE, mdPsum, mdQaccum = [\n            layout_utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)\n        ]\n\n        tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma()\n 
       # (batch, num_head, num_m_blocks, cluster_size) -> (num_m_blocks, cluster_size, num_head, batch)\n        if const_expr(self.deterministic):\n            assert mdQ_semaphore is not None\n            mdQ_semaphore = layout_utils.select(mdQ_semaphore, mode=[2, 3, 1, 0])\n\n        self.num_mma_threads = tiled_mma_SdP.size\n        assert self.num_mma_threads + 128 == self.num_threads\n\n        self.num_threads_per_warp_group = 128\n        self.num_producer_threads = 32\n\n        REG_LIMIT = 504 if self.num_wg_mma == 2 else 512\n        if const_expr(self.num_wg_mma == 2):\n            if const_expr(self.num_wg_dQ == 1):\n                self.num_mma_regs_wg0 = 256\n                self.num_mma_regs_wg1 = 224\n            else:\n                self.num_mma_regs_wg0 = 240\n                self.num_mma_regs_wg1 = 240\n            self.num_mma_regs = self.num_mma_regs_wg0  # for backward compat\n            self.num_producer_regs = 24\n            assert (\n                self.num_mma_regs_wg0 + self.num_mma_regs_wg1 + self.num_producer_regs <= REG_LIMIT\n            )\n        else:  # 3 warp groups\n            self.num_mma_regs_wg0 = 160\n            self.num_mma_regs_wg1 = 160\n            self.num_mma_regs = 160\n            self.num_producer_regs = 32\n            assert self.num_mma_regs_wg0 * self.num_wg_mma + self.num_producer_regs <= REG_LIMIT\n\n        self._setup_attributes()\n        SharedStorage = self._get_shared_storage_cls()\n\n        self.tma_copy_bytes = {\n            name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))\n            for name, mX, layout in [\n                (\"Q\", mQ, self.sQ_layout),\n                (\"K\", mK, self.sK_layout),\n                (\"V\", mV, self.sV_layout),\n                (\"dO\", mdO, self.sdO_layout),\n            ]\n        }\n        self.tma_copy_bytes[\"LSE\"] = self.tile_m * Float32.width // 8\n        self.tma_copy_bytes[\"dPsum\"] = self.tile_m * Float32.width // 
8\n        self.tma_copy_bytes[\"dQ\"] = (\n            self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_wg_dQ\n        )\n        self.tma_copy_bytes[\"dKacc\"] = self.tile_n * self.tile_hdim * Float32.width // 8\n        self.tma_copy_bytes[\"dVacc\"] = self.tile_n * self.tile_hdimv * Float32.width // 8\n\n        tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom(\n            cpasync.CopyBulkTensorTileG2SOp(),\n            mQ,\n            cute.select(self.sQ_layout, mode=[0, 1]),\n            (self.tile_m, self.tile_hdim),\n        )\n        tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(\n            cpasync.CopyBulkTensorTileG2SOp(),\n            mK,\n            cute.select(self.sK_layout, mode=[0, 1]),\n            (self.tile_n, self.tile_hdim),\n        )\n        tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(\n            cpasync.CopyBulkTensorTileG2SOp(),\n            mV,\n            cute.select(self.sV_layout, mode=[0, 1]),\n            (self.tile_n, self.tile_hdimv),\n        )\n        tma_atom_dO, tma_tensor_dO = cpasync.make_tiled_tma_atom(\n            cpasync.CopyBulkTensorTileG2SOp(),\n            mdO,\n            cute.select(self.sdO_layout, mode=[0, 1]),\n            (self.tile_m, self.tile_hdimv),\n        )\n        if const_expr(self.qhead_per_kvhead == 1):\n            mdK_tma = (\n                copy_utils.create_ragged_tensor_for_tma(mdK, ragged_dim=0, ptr_shift=True)\n                if self.varlen_k\n                else mdK\n            )\n            mdV_tma = (\n                copy_utils.create_ragged_tensor_for_tma(mdV, ragged_dim=0, ptr_shift=True)\n                if self.varlen_k\n                else mdV\n            )\n            tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom(\n                cpasync.CopyBulkTensorTileS2GOp(),\n                mdK_tma,\n                cute.select(self.sK_layout, mode=[0, 1]),\n                (self.tile_n, self.tile_hdim),\n            )\n    
        tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom(\n                cpasync.CopyBulkTensorTileS2GOp(),\n                mdV_tma,\n                cute.select(self.sV_layout, mode=[0, 1]),\n                (self.tile_n, self.tile_hdimv),\n            )\n        else:\n            tma_atom_dK = tma_atom_dV = tma_tensor_dK = tma_tensor_dV = None\n\n        if const_expr(mCuSeqlensK is not None or mSeqUsedK is not None):\n            TileScheduler = SingleTileVarlenScheduler\n        elif const_expr(self.deterministic):\n            TileScheduler = SingleTileLPTBwdScheduler\n        else:\n            TileScheduler = SingleTileScheduler\n        self.spt = (self.is_causal or self.is_local) and self.deterministic\n        tile_sched_args = TileSchedulerArguments(\n            cute.ceil_div(cute.size(mK.shape[0]), self.tile_n),\n            cute.size(mQ.shape[2]),\n            cute.size(mK.shape[3])\n            if const_expr(mCuSeqlensK is None)\n            else cute.size(mCuSeqlensK.shape[0] - 1),  # num_batch\n            1,  # num_splits\n            cute.size(mQ.shape[0]),  # pass seqlen_q or total_q for seqlen_k\n            mQ.shape[1],  # headdim\n            mV.shape[1],  # headdim_v\n            total_q=cute.size(mK.shape[0])\n            if const_expr(mCuSeqlensK is not None)\n            else cute.size(mK.shape[0]) * cute.size(mK.shape[3]),\n            tile_shape_mn=(self.tile_n, self.tile_m),  # Swapping the role of Q & K\n            mCuSeqlensQ=mCuSeqlensK,\n            mSeqUsedQ=mSeqUsedK,\n            qhead_per_kvhead_packgqa=1,\n            element_size=self.dtype.width // 8,\n            is_persistent=False,\n            lpt=self.spt,\n            head_swizzle=self.deterministic,\n        )\n\n        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)\n        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)\n\n        LOG2_E = math.log2(math.e)\n        if const_expr(self.score_mod is None):\n         
   softmax_scale_log2 = softmax_scale * LOG2_E\n        else:\n            softmax_scale_log2 = LOG2_E\n\n        fastdiv_mods = None\n        if const_expr(aux_tensors is not None):\n            seqlen_q = cute.size(mQ.shape[0])\n            seqlen_k = cute.size(mK.shape[0])\n            seqlen_q_divmod = FastDivmodDivisor(seqlen_q)\n            seqlen_k_divmod = FastDivmodDivisor(seqlen_k)\n            fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)\n\n        qhead_per_kvhead_divmod = None\n        if const_expr(self.qhead_per_kvhead > 1):\n            qhead_per_kvhead_divmod = FastDivmodDivisor(self.qhead_per_kvhead)\n\n        self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)\n\n        if const_expr(window_size_left is not None):\n            window_size_left = Int32(window_size_left)\n        if const_expr(window_size_right is not None):\n            window_size_right = Int32(window_size_right)\n\n        self.kernel(\n            tma_tensor_Q,\n            tma_tensor_K,\n            tma_tensor_V,\n            tma_tensor_dO,\n            tma_tensor_dK if const_expr(self.qhead_per_kvhead == 1) else mdK,\n            tma_tensor_dV if const_expr(self.qhead_per_kvhead == 1) else mdV,\n            tma_atom_Q,\n            tma_atom_K,\n            tma_atom_V,\n            tma_atom_dO,\n            tma_atom_dK,\n            tma_atom_dV,\n            mLSE,\n            mdPsum,\n            mdQaccum,\n            mCuSeqlensQ,\n            mCuSeqlensK,\n            mSeqUsedQ,\n            mSeqUsedK,\n            self.sQ_layout,\n            self.sK_layout,\n            self.sV_layout,\n            self.sPdS_layout,\n            self.sdO_layout,\n            self.sdQaccum_layout,\n            self.r2s_tiled_copy_dQaccum,\n            tiled_mma_SdP,\n            tiled_mma_dK,\n            tiled_mma_dV,\n            tiled_mma_dQ,\n            softmax_scale_log2,\n            softmax_scale,\n            tile_sched_params,\n            
TileScheduler,\n            SharedStorage,\n            aux_tensors,\n            fastdiv_mods,\n            blocksparse_tensors,\n            qhead_per_kvhead_divmod,\n            mdQ_semaphore,\n            window_size_left,\n            window_size_right,\n        ).launch(\n            grid=grid_dim,\n            block=[self.num_threads, 1, 1],\n            stream=stream,\n            min_blocks_per_mp=1,\n            use_pdl=True,\n        )\n\n    @cute.kernel\n    def kernel(\n        self,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mV: cute.Tensor,\n        mdO: cute.Tensor,\n        mdK: cute.Tensor,\n        mdV: cute.Tensor,\n        tma_atom_Q: cute.CopyAtom,\n        tma_atom_K: cute.CopyAtom,\n        tma_atom_V: cute.CopyAtom,\n        tma_atom_dO: cute.CopyAtom,\n        tma_atom_dK: cute.CopyAtom,\n        tma_atom_dV: cute.CopyAtom,\n        mLSE: cute.Tensor,\n        mdPsum: cute.Tensor,\n        mdQaccum: cute.Tensor,\n        mCuSeqlensQ: Optional[cute.Tensor],\n        mCuSeqlensK: Optional[cute.Tensor],\n        mSeqUsedQ: Optional[cute.Tensor],\n        mSeqUsedK: Optional[cute.Tensor],\n        sQ_layout: cute.ComposedLayout,\n        sK_layout: cute.ComposedLayout,\n        sV_layout: cute.ComposedLayout,\n        sPdS_layout: cute.ComposedLayout,\n        sdO_layout: cute.ComposedLayout,\n        sdQaccum_layout: cute.Layout,\n        r2s_tiled_copy_dQaccum: cute.TiledCopy,\n        tiled_mma_SdP: cute.TiledMma,\n        tiled_mma_dK: cute.TiledMma,\n        tiled_mma_dV: cute.TiledMma,\n        tiled_mma_dQ: cute.TiledMma,\n        softmax_scale_log2,\n        softmax_scale,\n        tile_sched_params: ParamsBase,\n        TileScheduler: cutlass.Constexpr[Callable],\n        SharedStorage: cutlass.Constexpr[Callable],\n        aux_tensors: Optional[list] = None,\n        fastdiv_mods=(None, None),\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n        qhead_per_kvhead_divmod: 
Optional[FastDivmodDivisor] = None,\n        mdQ_semaphore: Optional[cute.Tensor] = None,\n        window_size_left: Optional[Int32] = None,\n        window_size_right: Optional[Int32] = None,\n    ):\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())\n\n        # prefetch TMA descriptors\n        if warp_idx == 0:\n            for atom in [tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_dO, tma_atom_dK, tma_atom_dV]:\n                if const_expr(atom is not None):\n                    cpasync.prefetch_descriptor(atom)\n\n        smem = cutlass.utils.SmemAllocator()\n        storage = smem.allocate(SharedStorage)\n\n        pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread)\n        pipeline_consumer_group = cutlass.pipeline.CooperativeGroup(\n            cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE\n        )\n        pipeline_Q = pipeline.PipelineTmaAsync.create(\n            barrier_storage=storage.mbar_ptr_Q.data_ptr(),\n            num_stages=self.Q_stage,\n            producer_group=pipeline_producer_group,\n            consumer_group=pipeline_consumer_group,\n            tx_count=self.tma_copy_bytes[\"Q\"] + self.tma_copy_bytes[\"LSE\"],\n            defer_sync=True,\n        )\n        pipeline_dO = pipeline.PipelineTmaAsync.create(\n            barrier_storage=storage.mbar_ptr_dO.data_ptr(),\n            num_stages=self.dO_stage,\n            producer_group=pipeline_producer_group,\n            consumer_group=pipeline_consumer_group,\n            tx_count=self.tma_copy_bytes[\"dO\"] + self.tma_copy_bytes[\"dPsum\"],\n            defer_sync=False,\n        )\n\n        sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)\n        sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner)\n        sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)\n        sV = storage.sV.get_tensor(sV_layout.outer, 
swizzle=sV_layout.inner)\n        sP = None\n        if const_expr(not self.mma_dkv_is_rs):\n            sP = storage.sP.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner)\n        sdS = storage.sdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner)\n        sLSE = storage.sLSE.get_tensor(\n            cute.make_layout(\n                (self.tile_m, self.Q_stage),\n                stride=(1, cute.round_up(self.tile_m, 64)),\n            )\n        )\n        sdPsum = storage.sdPsum.get_tensor(\n            cute.make_layout(\n                (self.tile_m, self.dO_stage),\n                stride=(1, cute.round_up(self.tile_m, 64)),\n            )\n        )\n        sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout)\n\n        block_info = BlockInfo(\n            self.tile_m,\n            self.tile_n,\n            self.is_causal,\n            self.is_local,\n            False,  # is_split_kv\n            window_size_left,\n            window_size_right,\n            qhead_per_kvhead_packgqa=1,\n        )\n        SeqlenInfoCls = partial(\n            SeqlenInfoQK.create,\n            seqlen_q_static=mQ.shape[0],\n            seqlen_k_static=mK.shape[0],\n            mCuSeqlensQ=mCuSeqlensQ,\n            mCuSeqlensK=mCuSeqlensK,\n            mSeqUsedQ=mSeqUsedQ,\n            mSeqUsedK=mSeqUsedK,\n            tile_m=self.tile_m,\n            tile_n=self.tile_n,\n        )\n        AttentionMaskCls = partial(\n            AttentionMask,\n            self.tile_m,\n            self.tile_n,\n            window_size_left=window_size_left,\n            window_size_right=window_size_right,\n            swap_AB=self.SdP_swapAB,\n        )\n        TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)\n\n        if warp_idx < 4:\n            cute.arch.setmaxregister_decrease(self.num_producer_regs)\n            if warp_idx == 0:\n                self.load(\n                    mQ,\n                    mK,\n                    mV,\n                
    mdO,\n                    mLSE,\n                    mdPsum,\n                    sQ,\n                    sK,\n                    sV,\n                    sdO,\n                    sLSE,\n                    sdPsum,\n                    tma_atom_Q,\n                    tma_atom_K,\n                    tma_atom_V,\n                    tma_atom_dO,\n                    pipeline_Q,\n                    pipeline_dO,\n                    block_info,\n                    SeqlenInfoCls,\n                    TileSchedulerCls,\n                    blocksparse_tensors,\n                    qhead_per_kvhead_divmod,\n                )\n            if warp_idx == 1:\n                self.dQaccum_store(\n                    mdQaccum,\n                    sdQaccum,\n                    block_info,\n                    TileSchedulerCls,\n                    SeqlenInfoCls,\n                    blocksparse_tensors,\n                    mdQ_semaphore,\n                )\n        else:\n            tidx, _, _ = cute.arch.thread_idx()\n            tidx = tidx - 128\n            mma_args = (\n                tiled_mma_SdP,\n                tiled_mma_dK,\n                tiled_mma_dV,\n                tiled_mma_dQ,\n                mdK,\n                mdV,\n                mdQaccum,\n                sQ,\n                sK,\n                sV,\n                sdO,\n                sP,\n                sdS,\n                sLSE,\n                sdPsum,\n                sdQaccum,\n                pipeline_Q,\n                pipeline_dO,\n                tidx,\n                tma_atom_dK,\n                tma_atom_dV,\n                r2s_tiled_copy_dQaccum,\n                softmax_scale_log2,\n                softmax_scale,\n                block_info,\n                SeqlenInfoCls,\n                AttentionMaskCls,\n                TileSchedulerCls,\n                aux_tensors,\n                fastdiv_mods,\n                blocksparse_tensors,\n                
qhead_per_kvhead_divmod,\n            )\n            if const_expr(self.num_wg_dQ == self.num_wg_mma):\n                # Both WGs compute dQ\n                cute.arch.setmaxregister_increase(self.num_mma_regs_wg0)\n                self.mma(*mma_args, is_dQ_wg=True)\n            else:\n                # WG0 computes dQ, WG1 skips it\n                warp_idx_in_mma = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - 4\n                if warp_idx_in_mma < 4:\n                    cute.arch.setmaxregister_increase(self.num_mma_regs_wg0)\n                    self.mma(*mma_args, is_dQ_wg=True)\n                else:\n                    cute.arch.setmaxregister_increase(self.num_mma_regs_wg1)\n                    self.mma(*mma_args, is_dQ_wg=False)\n\n    @cute.jit\n    def load(\n        self,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mV: cute.Tensor,\n        mdO: cute.Tensor,\n        mLSE: cute.Tensor,\n        mdPsum: cute.Tensor,\n        sQ: cute.Tensor,\n        sK: cute.Tensor,\n        sV: cute.Tensor,\n        sdO: cute.Tensor,\n        sLSE: cute.Tensor,\n        sdPsum: cute.Tensor,\n        tma_atom_Q: cute.CopyAtom,\n        tma_atom_K: cute.CopyAtom,\n        tma_atom_V: cute.CopyAtom,\n        tma_atom_dO: cute.CopyAtom,\n        pipeline_Q: cutlass.pipeline.PipelineAsync,\n        pipeline_dO: cutlass.pipeline.PipelineAsync,\n        block_info: BlockInfo,\n        SeqlenInfoCls: Callable,\n        TileSchedulerCls: Callable,\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n        qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,\n    ):\n        warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4\n\n        if warp_idx_in_wg == 0:\n            producer_state_Q = cutlass.pipeline.make_pipeline_state(\n                cutlass.pipeline.PipelineUserType.Producer, self.Q_stage\n            )\n            producer_state_dO = cutlass.pipeline.make_pipeline_state(\n                
cutlass.pipeline.PipelineUserType.Producer, self.dO_stage\n            )\n            tile_scheduler = TileSchedulerCls()\n            work_tile = tile_scheduler.initial_work_tile_info()\n            while work_tile.is_valid_tile:\n                n_block, head_idx, batch_idx, _ = work_tile.tile_idx\n                seqlen = SeqlenInfoCls(batch_idx)\n                head_idx_kv = (\n                    head_idx\n                    if const_expr(self.qhead_per_kvhead == 1)\n                    else head_idx // qhead_per_kvhead_divmod\n                )\n                mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]\n                mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]\n                gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))\n                gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))\n\n                mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]\n                mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[\n                    None, head_idx\n                ]\n                mdO_cur = seqlen.offset_batch_Q(mdO, batch_idx, dim=3)[None, None, head_idx]\n                mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[\n                    None, head_idx\n                ]\n                gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0))\n                gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0))\n                gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))\n                gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))\n\n                load_K, _, _ = copy_utils.tma_get_copy_fn(\n                    tma_atom_K, 0, cute.make_layout(1), gK, sK, single_stage=True\n                )\n                load_V, _, _ = copy_utils.tma_get_copy_fn(\n                    tma_atom_V, 
0, cute.make_layout(1), gV, sV, single_stage=True\n                )\n                load_Q, _, _ = copy_utils.tma_get_copy_fn(\n                    tma_atom_Q, 0, cute.make_layout(1), gQ, sQ\n                )\n                load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q)\n                load_dO, _, _ = copy_utils.tma_get_copy_fn(\n                    tma_atom_dO, 0, cute.make_layout(1), gdO, sdO\n                )\n                load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO)\n                load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE)\n                load_LSE = copy_utils.tma_producer_copy_fn(load_LSE, pipeline_Q)\n                load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum)\n                load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_dO)\n\n                m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)\n\n                if const_expr(not self.use_block_sparsity):\n                    total_m_block_cnt = m_block_max - m_block_min\n                    process_tile = (\n                        const_expr(not self.is_local and not self.is_varlen_q)\n                        or m_block_min < m_block_max\n                    )\n                else:\n                    total_m_block_cnt = get_total_q_block_count_bwd(\n                        blocksparse_tensors,\n                        batch_idx,\n                        head_idx,\n                        n_block,\n                        subtile_factor=self.subtile_factor,\n                        m_block_max=m_block_max,\n                    )\n                    process_tile = total_m_block_cnt > Int32(0)\n\n                if process_tile:\n                    if const_expr(not self.use_block_sparsity):\n                        first_m_block = m_block_min\n                        pipeline_Q.producer_acquire(\n                            producer_state_Q, 
extra_tx_count=self.tma_copy_bytes[\"K\"]\n                        )\n                        load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))\n                        load_Q(first_m_block, producer_state=producer_state_Q)\n                        # Wait for bwd preprocess to finish writing LSE and dPsum\n                        cute.arch.griddepcontrol_wait()\n                        load_LSE(first_m_block, producer_state=producer_state_Q)\n                        producer_state_dO_cur = (\n                            producer_state_dO\n                            if const_expr(self.Q_stage != self.dO_stage)\n                            else producer_state_Q\n                        )\n                        pipeline_dO.producer_acquire(\n                            producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes[\"V\"]\n                        )\n                        load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))\n                        load_dO(first_m_block, producer_state=producer_state_dO_cur)\n                        load_dPsum(first_m_block, producer_state=producer_state_dO_cur)\n                        producer_state_Q.advance()\n                        producer_state_dO.advance()\n\n                        for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):\n                            pipeline_Q.producer_acquire(producer_state_Q)\n                            load_Q(m_block, producer_state=producer_state_Q)\n                            load_LSE(m_block, producer_state=producer_state_Q)\n                            producer_state_dO_cur = (\n                                producer_state_dO\n                                if const_expr(self.Q_stage != self.dO_stage)\n                                else producer_state_Q\n                            )\n                            pipeline_dO.producer_acquire(producer_state_dO_cur)\n                            load_dO(m_block, 
producer_state=producer_state_dO_cur)\n                            load_dPsum(m_block, producer_state=producer_state_dO_cur)\n                            producer_state_Q.advance()\n                            producer_state_dO.advance()\n                    else:\n                        producer_state_Q, producer_state_dO = produce_block_sparse_q_loads_bwd_sm90(\n                            blocksparse_tensors,\n                            batch_idx,\n                            head_idx,\n                            n_block,\n                            producer_state_Q,\n                            producer_state_dO,\n                            pipeline_Q,\n                            pipeline_dO,\n                            load_K,\n                            load_V,\n                            load_Q,\n                            load_dO,\n                            load_LSE,\n                            load_dPsum,\n                            self.tma_copy_bytes[\"K\"],\n                            self.tma_copy_bytes[\"V\"],\n                            Q_stage_eq_dO_stage=(self.Q_stage == self.dO_stage),\n                            subtile_factor=self.subtile_factor,\n                            m_block_max=m_block_max,\n                        )\n\n                tile_scheduler.prefetch_next_work()\n                tile_scheduler.advance_to_next_work()\n                work_tile = tile_scheduler.get_current_work()\n\n    @cute.jit\n    def apply_score_mod(\n        self,\n        acc_S: cute.Tensor,\n        thr_mma_SdP: cute.core.ThrMma,\n        batch_idx,\n        head_idx,\n        m_block,\n        n_block,\n        softmax_scale,\n        seqlen_info: SeqlenInfoQK,\n        aux_tensors=None,\n        fastdiv_mods=(None, None),\n    ):\n        # [NOTE] SdP_swapAB: swapAB transposes the tile, so use (n, m) indexing\n        cS = cute.make_identity_tensor(\n            (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, 
self.tile_n)\n        )\n        cS = cute.domain_offset(\n            (n_block * self.tile_n, m_block * self.tile_m)\n            if self.SdP_swapAB\n            else (m_block * self.tile_m, n_block * self.tile_n),\n            cS,\n        )\n        tScS = thr_mma_SdP.partition_C(cS)\n\n        apply_score_mod_inner(\n            acc_S,\n            tScS,\n            self.score_mod,\n            batch_idx,\n            head_idx,\n            softmax_scale,\n            self.vec_size,\n            self.qk_acc_dtype,\n            aux_tensors,\n            fastdiv_mods,\n            seqlen_info,\n            constant_q_idx=None,\n            qhead_per_kvhead=self.qhead_per_kvhead,\n            transpose_indices=self.SdP_swapAB,\n        )\n\n    @cute.jit\n    def apply_score_mod_bwd(\n        self,\n        grad_tensor: cute.Tensor,\n        score_tensor: cute.Tensor,\n        thr_mma_SdP: cute.core.ThrMma,\n        batch_idx,\n        head_idx,\n        m_block,\n        n_block,\n        softmax_scale,\n        seqlen_info: SeqlenInfoQK,\n        aux_tensors=None,\n        fastdiv_mods=(None, None),\n    ):\n        cS = cute.make_identity_tensor(\n            (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)\n        )\n        cS = cute.domain_offset(\n            (n_block * self.tile_n, m_block * self.tile_m)\n            if self.SdP_swapAB\n            else (m_block * self.tile_m, n_block * self.tile_n),\n            cS,\n        )\n        tScS = thr_mma_SdP.partition_C(cS)\n\n        apply_score_mod_bwd_inner(\n            grad_tensor,\n            score_tensor,\n            tScS,\n            self.score_mod_bwd,\n            batch_idx,\n            head_idx,\n            softmax_scale,\n            self.vec_size,\n            self.qk_acc_dtype,\n            aux_tensors,\n            fastdiv_mods,\n            seqlen_info,\n            constant_q_idx=None,\n            qhead_per_kvhead=self.qhead_per_kvhead,\n            
transpose_indices=self.SdP_swapAB,\n        )\n\n    @cute.jit\n    def mma(\n        self,\n        tiled_mma_SdP: cute.TiledMma,\n        tiled_mma_dK: cute.TiledMma,\n        tiled_mma_dV: cute.TiledMma,\n        tiled_mma_dQ: cute.TiledMma,\n        mdK: cute.Tensor,\n        mdV: cute.Tensor,\n        mdQaccum: cute.Tensor,\n        sQ: cute.Tensor,\n        sK: cute.Tensor,\n        sV: cute.Tensor,\n        sdO: cute.Tensor,\n        sP: Optional[cute.Tensor],\n        sdS: cute.Tensor,\n        sLSE: cute.Tensor,\n        sdPsum: cute.Tensor,\n        sdQaccum: cute.Tensor,\n        pipeline_Q: cutlass.pipeline.PipelineAsync,\n        pipeline_dO: cutlass.pipeline.PipelineAsync,\n        tidx: Int32,\n        tma_atom_dK: cute.CopyAtom,\n        tma_atom_dV: cute.CopyAtom,\n        r2s_tiled_copy_dQaccum: cute.TiledCopy,\n        softmax_scale_log2: Float32,\n        softmax_scale: Float32,\n        block_info: BlockInfo,\n        SeqlenInfoCls: Callable,\n        AttentionMaskCls: Callable,\n        TileSchedulerCls: Callable,\n        aux_tensors: Optional[list] = None,\n        fastdiv_mods=(None, None),\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n        qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,\n        is_dQ_wg: cutlass.Constexpr[bool] = True,\n    ):\n        warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)\n        warp_group_thread_layout = cute.make_layout(\n            self.num_wg_mma, stride=self.num_threads_per_warp_group\n        )\n        thr_mma_SdP = tiled_mma_SdP.get_slice(tidx)\n        wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx))\n        wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx))\n        wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx))\n        wg_mma_dQ = None\n        if const_expr(is_dQ_wg):\n            wg_idx_dQ = warp_group_idx if const_expr(self.num_wg_dQ 
> 1) else 0\n            wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(wg_idx_dQ))\n        # S = Q @ K.T\n        shape_mnk_S = (self.tile_m, self.tile_n, self.tile_hdim)\n        _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(\n            wg_mma_SdP, shape_mnk_S, sQ, sK, swap_AB=self.SdP_swapAB\n        )\n        mma_qk_fn = partial(\n            gemm_zero_init, tiled_mma_SdP, shape_mnk_S[:2], tSrQ, tSrK, swap_AB=self.SdP_swapAB\n        )\n        # dP = dO @ V.T\n        shape_mnk_dP = (self.tile_m, self.tile_n, self.tile_hdimv)\n        _, tdPrdO, tdPrV = sm90_utils.partition_fragment_ABC(\n            wg_mma_SdP, shape_mnk_dP, sdO, sV, swap_AB=self.SdP_swapAB\n        )\n        mma_dov_fn = partial(\n            gemm_zero_init, tiled_mma_SdP, shape_mnk_dP[:2], tdPrdO, tdPrV, swap_AB=self.SdP_swapAB\n        )\n        # dV += P.T @ dO\n        sPt = layout_utils.transpose_view(sP) if sP is not None else None\n        sdOt = layout_utils.transpose_view(sdO)\n        shape_mnk_dV = (self.tile_n, self.tile_hdimv, self.tile_m)\n        acc_dV, tdVrPt, tdVrdOt = sm90_utils.partition_fragment_ABC(\n            wg_mma_dV, shape_mnk_dV, sPt, sdOt, swap_AB=self.dKV_swapAB\n        )\n        if const_expr(not self.mma_dkv_is_rs):\n            mma_pdo_fn = partial(\n                gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt, swap_AB=self.dKV_swapAB\n            )\n        else:\n            mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tCrB=tdVrdOt)\n        # dK += dS.T @ Q\n        sdSt = layout_utils.transpose_view(sdS)\n        sQt = layout_utils.transpose_view(sQ)\n        shape_mnk_dK = (self.tile_n, self.tile_hdim, self.tile_m)\n        acc_dK, tdKrdSt, tdKrQt = sm90_utils.partition_fragment_ABC(\n            wg_mma_dK, shape_mnk_dK, sdSt, sQt, swap_AB=self.dKV_swapAB\n        )\n        if const_expr(not self.mma_dkv_is_rs):\n            mma_dsq_fn = partial(\n                gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt, 
swap_AB=self.dKV_swapAB\n            )\n        else:\n            mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tCrB=tdKrQt)\n        # dQ = dS @ K\n        sKt = layout_utils.transpose_view(sK)\n        shape_mnk_dQ = (self.tile_m, self.tile_hdim, self.tile_n)\n        mma_dsk_fn = None\n        if const_expr(is_dQ_wg):\n            _, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC(\n                wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB\n            )\n            mma_dsk_fn = partial(\n                gemm_zero_init,\n                tiled_mma_dQ,\n                shape_mnk_dQ[:2],\n                tdQrdS,\n                tdQrKt,\n                swap_AB=self.dQ_swapAB,\n            )\n\n        # Smem copy atom tiling for P/dS R2S\n        copy_P_r2s = None\n        mms_PdS = self.tile_n // (self.num_wg_mma // self.AtomLayoutMSdP)\n        if const_expr(sP is not None):\n            sP_cpy = sP if const_expr(not self.SdP_swapAB) else sPt\n            copy_P_r2s, _, _ = copy_utils.get_smem_store_C(\n                tiled_mma_SdP,\n                sP_cpy,\n                tidx,\n                self.arch,\n                transpose=self.SdP_swapAB,\n                position_independent=True,\n                major_mode_size=mms_PdS,\n            )\n        sdS_cpy = sdS if const_expr(not self.SdP_swapAB) else sdSt\n        copy_dS_r2s, _, _ = copy_utils.get_smem_store_C(\n            tiled_mma_SdP,\n            sdS_cpy,\n            tidx,\n            self.arch,\n            transpose=self.SdP_swapAB,\n            position_independent=True,\n            major_mode_size=mms_PdS,\n        )\n\n        tLSEsLSE = layout_utils.mma_partition_C_vec(\n            sLSE, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB\n        )\n        tLSEsdPsum = layout_utils.mma_partition_C_vec(\n            sdPsum, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB\n        )\n        # When shuffle=True, rows 
are distributed across 8 quads (4 threads each) within a warp.\n        # Each thread loads only ceil(num_rows/8) values; _get_stat shuffles in the rest.\n        shfl_copy = copy_utils.tiled_copy_1d(sLSE.element_type, num_threads=8, num_copy_elems=2)\n        if const_expr(self.shuffle_LSE):\n            tLSEsLSE = shfl_copy.get_slice(cute.arch.lane_idx() // 4).partition_S(tLSEsLSE)\n            # ((2, 1), 1, 2) -> (((2, 1), 1), 2)\n            tLSEsLSE = cute.group_modes(tLSEsLSE, 0, 2)\n        if const_expr(self.shuffle_dPsum):\n            tLSEsdPsum = shfl_copy.get_slice(cute.arch.lane_idx() // 4).partition_S(tLSEsdPsum)\n            tLSEsdPsum = cute.group_modes(tLSEsdPsum, 0, 2)\n\n        tdQsdQaccum = None\n        if const_expr(is_dQ_wg):\n            smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)\n            tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)\n\n        PdS_barrier = cutlass.pipeline.NamedBarrier(\n            barrier_id=int(NamedBarrierBwd.PdS), num_threads=self.num_mma_threads\n        )\n        score_mod_fn = partial(\n            self.apply_score_mod,\n            thr_mma_SdP=thr_mma_SdP,\n            softmax_scale=softmax_scale,\n            aux_tensors=aux_tensors,\n            fastdiv_mods=fastdiv_mods,\n        )\n        score_mod_bwd_fn = partial(\n            self.apply_score_mod_bwd,\n            thr_mma_SdP=thr_mma_SdP,\n            softmax_scale=softmax_scale,\n            aux_tensors=aux_tensors,\n            fastdiv_mods=fastdiv_mods,\n        )\n\n        mma_one_m_block_all = partial(\n            self.mma_one_m_block,\n            warp_group_idx=warp_group_idx,\n            mma_qk_fn=mma_qk_fn,\n            mma_dov_fn=mma_dov_fn,\n            mma_pdo_fn=mma_pdo_fn,\n            mma_dsq_fn=mma_dsq_fn,\n            mma_dsk_fn=mma_dsk_fn,\n            copy_P_r2s=copy_P_r2s,\n            copy_dS_r2s=copy_dS_r2s,\n            pipeline_Q=pipeline_Q,\n            pipeline_dO=pipeline_dO,\n            tLSEsLSE=tLSEsLSE,\n            
tLSEsdPsum=tLSEsdPsum,\n            tdQsdQaccum=tdQsdQaccum,\n            softmax_scale_log2=softmax_scale_log2,\n            PdS_barrier=PdS_barrier,\n            # acc_dV=acc_dV,\n            # acc_dK=acc_dK,\n            is_dQ_wg=is_dQ_wg,\n        )\n\n        consumer_state_Q = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage\n        )\n        consumer_state_dO = cutlass.pipeline.make_pipeline_state(\n            cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage\n        )\n        tile_scheduler = TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        while work_tile.is_valid_tile:\n            n_block, head_idx, batch_idx, _ = work_tile.tile_idx\n            seqlen = SeqlenInfoCls(batch_idx)\n            mask = AttentionMaskCls(seqlen)\n            score_mod_fn_cur = partial(\n                score_mod_fn,\n                batch_idx=batch_idx,\n                head_idx=head_idx,\n                n_block=n_block,\n                seqlen_info=seqlen,\n            )\n            score_mod_bwd_fn_cur = partial(\n                score_mod_bwd_fn,\n                batch_idx=batch_idx,\n                head_idx=head_idx,\n                n_block=n_block,\n                seqlen_info=seqlen,\n            )\n            m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)\n\n            if const_expr(not self.use_block_sparsity):\n                process_tile = (\n                    const_expr(not self.is_local and not self.is_varlen_q)\n                    or m_block_min < m_block_max\n                )\n            else:\n                total_m_block_cnt = get_total_q_block_count_bwd(\n                    blocksparse_tensors,\n                    batch_idx,\n                    head_idx,\n                    n_block,\n                    subtile_factor=self.subtile_factor,\n                    m_block_max=m_block_max,\n                )\n   
             process_tile = total_m_block_cnt > Int32(0)\n\n            if process_tile:\n                if const_expr(not self.use_block_sparsity):\n                    mask_fn = partial(\n                        mask.apply_mask,\n                        batch_idx=batch_idx,\n                        head_idx=head_idx,\n                        n_block=n_block,\n                        thr_mma=thr_mma_SdP,\n                        mask_seqlen=True,\n                        mask_causal=self.is_causal,\n                        mask_local=self.is_local,\n                        mask_mod=self.mask_mod,\n                        aux_tensors=aux_tensors,\n                        fastdiv_mods=fastdiv_mods,\n                    )\n                    dKV_accumulate = False\n                    for m_block in cutlass.range(m_block_min, m_block_max, unroll=1):\n                        consumer_state_Q, consumer_state_dO = mma_one_m_block_all(\n                            m_block,\n                            consumer_state_Q,\n                            consumer_state_dO,\n                            mask_fn=mask_fn,\n                            score_mod_fn=score_mod_fn_cur,\n                            score_mod_bwd_fn=score_mod_bwd_fn_cur,\n                            dKV_accumulate=dKV_accumulate,\n                        )\n                        dKV_accumulate = True\n                else:\n                    consumer_state_Q, consumer_state_dO = consume_block_sparse_mma_bwd_sm90(\n                        blocksparse_tensors,\n                        batch_idx,\n                        head_idx,\n                        n_block,\n                        consumer_state_Q,\n                        consumer_state_dO,\n                        mma_one_m_block_all,\n                        mask,\n                        self.mask_mod,\n                        is_causal=self.is_causal,\n                        is_local=self.is_local,\n                        
thr_mma_SdP=thr_mma_SdP,\n                        score_mod_fn=score_mod_fn_cur,\n                        score_mod_bwd_fn=score_mod_bwd_fn_cur,\n                        subtile_factor=self.subtile_factor,\n                        m_block_max=m_block_max,\n                        aux_tensors=aux_tensors,\n                        fastdiv_mods=fastdiv_mods,\n                    )\n\n                if const_expr(self.qhead_per_kvhead == 1):\n                    acc_dK.store(acc_dK.load() * softmax_scale)\n                self.epilogue_dKV(\n                    acc_dV,\n                    mdV,\n                    sV,\n                    acc_dK,\n                    mdK,\n                    sK,\n                    seqlen,\n                    tma_atom_dK,\n                    tma_atom_dV,\n                    tiled_mma_dK,\n                    tiled_mma_dV,\n                    tidx,\n                    n_block,\n                    head_idx,\n                    batch_idx,\n                    qhead_per_kvhead_divmod,\n                )\n            else:\n                # KV tile with zero Q blocks produces no dK/dV; write zeros.\n                if const_expr(self.use_block_sparsity or self.is_local or self.is_varlen_q):\n                    acc_dK.fill(0.0)\n                    acc_dV.fill(0.0)\n                    self.epilogue_dKV(\n                        acc_dV,\n                        mdV,\n                        sV,\n                        acc_dK,\n                        mdK,\n                        sK,\n                        seqlen,\n                        tma_atom_dK,\n                        tma_atom_dV,\n                        tiled_mma_dK,\n                        tiled_mma_dV,\n                        tidx,\n                        n_block,\n                        head_idx,\n                        batch_idx,\n                        qhead_per_kvhead_divmod,\n                    )\n\n            tile_scheduler.advance_to_next_work()\n   
         work_tile = tile_scheduler.get_current_work()\n\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())\n        if warp_idx == 4:\n            cute.arch.cp_async_bulk_wait_group(0, read=True)\n\n    @staticmethod\n    @cute.jit\n    def _get_stat(tSrS: cute.Tensor, row: Int32, lane: Int32, shuffle: bool) -> Float32:\n        \"\"\"Retrieve the statistic for a given accumulator row.\n\n        When shuffle=False, index the register directly.\n        When shuffle=True, warp-shuffle from the thread group that holds the value.\n        \"\"\"\n        if const_expr(not shuffle):\n            return tSrS[row]\n        # tSrS: (((2, 1), 1), 1), distributed across 8 threads in the warp\n        vecsize = cute.size(tSrS, mode=[0, 0])  # 2\n        idx0, off, idx1 = cute.idx2crd(row, (vecsize, 8, cute.shape(tSrS, mode=[0, 1])))\n        # register index: 0, 1, 0, 1, ..., 2, 3, 2, 3, ...\n        return utils.shuffle_sync(tSrS[idx0 + idx1 * vecsize], offset=off * 4 + (lane % 4))\n\n    @cute.jit\n    def mma_one_m_block(\n        self,\n        m_block: Int32,\n        consumer_state_Q: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,\n        consumer_state_dO: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple,\n        warp_group_idx: Int32,\n        mma_qk_fn: Callable,\n        mma_dov_fn: Callable,\n        mma_pdo_fn: Callable,\n        mma_dsq_fn: Callable,\n        mma_dsk_fn: Callable,\n        copy_P_r2s: Optional[Callable],\n        copy_dS_r2s: Callable,\n        pipeline_Q: cutlass.pipeline.PipelineAsync,\n        pipeline_dO: cutlass.pipeline.PipelineAsync,\n        tLSEsLSE: cute.Tensor,\n        tLSEsdPsum: cute.Tensor,\n        tdQsdQaccum: Optional[cute.Tensor],\n        softmax_scale_log2: Float32,\n        PdS_barrier: cutlass.pipeline.NamedBarrier,\n        is_dQ_wg: cutlass.Constexpr[bool] = True,\n        mask_fn: Optional[Callable] = None,\n        score_mod_fn: Optional[Callable] = None,\n        
score_mod_bwd_fn: Optional[Callable] = None,\n        dKV_accumulate: Boolean = True,\n    ):\n        consumer_state_dO_cur = (\n            consumer_state_Q if const_expr(self.Q_stage == self.dO_stage) else consumer_state_dO\n        )\n        smem_idx_Q = consumer_state_Q.index\n        smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0\n        smem_idx_PdS = smem_idx_Q if const_expr(self.PdS_stage > 1) else 0\n        # (1) [GEMM 1] S = Q @ K^T\n        pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q))\n        acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1)\n        # If shuffle_LSE, OOB reads are OK since sLSE is already padded\n        tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q])\n        # (2) [GEMM 2] dP = dO @ V.T\n        pipeline_dO.consumer_wait(\n            consumer_state_dO_cur, pipeline_dO.consumer_try_wait(consumer_state_dO_cur)\n        )\n        acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1)\n\n        if const_expr(self.score_mod_bwd is not None):\n            acc_S_pre = cute.make_fragment_like(acc_S)\n            cute.autovec_copy(acc_S, acc_S_pre)\n\n        if const_expr(self.score_mod is not None):\n            score_mod_fn(acc_S, m_block=m_block)\n\n        # (3) [Pointwise 1] P = exp(S - LSE)\n        if cutlass.const_expr(mask_fn is not None):\n            mask_fn(acc_S, m_block=m_block)\n        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.SdP_swapAB)\n        lane_idx = cute.arch.lane_idx()\n        for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])):\n            lse_val = self._get_stat(tLSErLSE, r, lane_idx, shuffle=self.shuffle_LSE)\n            for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True):\n                acc_S_mn[r, c] = cute.math.exp2(\n                    acc_S_mn[r, c] * softmax_scale_log2 - lse_val, fastmath=True\n                )\n        tLSErdPsum = 
copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO])\n\n        # Convert P from f32 -> f16\n        tdVrP = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_S), self.dtype)\n        # R2S for P\n        if const_expr(not self.mma_dkv_is_rs):\n            # sync to ensure P has already been used in the previous iteration before overwriting\n            if const_expr(self.PdS_stage == 1):\n                PdS_barrier.arrive_and_wait()\n            copy_P_r2s(tdVrP, dst_idx=smem_idx_PdS)\n\n        # (4) [Pointwise 2] dS = P*(dP-dPsum)\n        warpgroup.wait_group(0)\n        acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP, transpose=self.SdP_swapAB)\n        for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])):\n            dpsum_val = self._get_stat(tLSErdPsum, r, lane_idx, shuffle=self.shuffle_dPsum)\n            for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):\n                acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - dpsum_val)\n\n        if const_expr(self.score_mod_bwd is not None):\n            score_mod_bwd_fn(acc_dP, acc_S_pre, m_block=m_block)\n\n        # Convert dS from f32 -> f16\n        tdKrdS = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_dP), self.dtype)\n\n        # If dS is double buffered, we don't need to sync here.\n        # Otherwise WG1 might write to dS before WG2 is done reading from it during MmadQ.\n        # Because both WGs sync at the end of the loop and dS is double buffered,\n        # this race condition is not possible.\n        # This sync ensures (1) P is written in case of !mma_dkv_is_rs and\n        # (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs.\n        if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)):\n            cute.arch.fence_view_async_shared()\n            PdS_barrier.arrive_and_wait()\n\n        # R2S for dS\n        copy_dS_r2s(tdKrdS, 
dst_idx=smem_idx_PdS)\n\n        # (5) [GEMM 3] dV += P.T @ dO\n        if const_expr(not self.mma_dkv_is_rs):\n            mma_pdo_fn(\n                A_idx=smem_idx_PdS, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1\n            )\n        else:\n            mma_pdo_fn(tCrA=tdVrP, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1)\n\n        # smem fence to make sure sdS is written before it's read by WGMMA\n        cute.arch.fence_view_async_shared()\n        PdS_barrier.arrive_and_wait()\n\n        if const_expr(is_dQ_wg):\n            # (6) [GEMM 4] dQ = dS @ K\n            acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1)\n            pipeline_dO.consumer_release(consumer_state_dO_cur)  # release dO as dV mma is done\n\n            # (7) [GEMM 5] dK += dS.T @ Q\n            if const_expr(not self.mma_dkv_is_rs):\n                mma_dsq_fn(\n                    A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1\n                )\n            else:\n                mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1)\n\n            # dQ R2S: wait for dQaccum_store to free the smem buffer, then write dQ to smem\n            # When dQ_single_wg, only WG0 enters here so warp_group_idx == 0\n            cute.arch.barrier(\n                barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,\n                number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,\n            )\n            tdQrdQaccum_flat = cute.make_tensor(\n                acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape)\n            )\n            cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum)\n            cute.arch.fence_view_async_shared()\n            cute.arch.barrier_arrive(\n                barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,\n                number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,\n            )\n\n            
warpgroup.wait_group(0)\n            pipeline_Q.consumer_release(consumer_state_Q)\n        else:\n            # dQ_single_wg: WG1 skips dQ, only does dV wait + dK\n            # (7) [GEMM 5] dK += dS.T @ Q\n            if const_expr(not self.mma_dkv_is_rs):\n                mma_dsq_fn(\n                    A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1\n                )\n            else:\n                mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1)\n            pipeline_dO.consumer_release(consumer_state_dO_cur)\n            warpgroup.wait_group(0)\n            pipeline_Q.consumer_release(consumer_state_Q)\n\n        consumer_state_Q.advance()\n        consumer_state_dO.advance()\n        return consumer_state_Q, consumer_state_dO\n\n    @cute.jit\n    def epilogue_dKV(\n        self,\n        acc_dV: cute.Tensor,\n        mdV: cute.Tensor,\n        sV: cute.Tensor,\n        acc_dK: cute.Tensor,\n        mdK: cute.Tensor,\n        sK: cute.Tensor,\n        seqlen: SeqlenInfoQK,\n        tma_atom_dK: cute.CopyAtom,\n        tma_atom_dV: cute.CopyAtom,\n        tiled_mma_dK: cute.TiledMma,\n        tiled_mma_dV: cute.TiledMma,\n        tidx: Int32,\n        n_block: Int32,\n        head_idx: Int32,\n        batch_idx: Int32,\n        qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,\n    ):\n        epi_barrier = cutlass.pipeline.NamedBarrier(\n            barrier_id=int(NamedBarrierBwd.Epilogue), num_threads=self.num_mma_threads\n        )\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())\n\n        if const_expr(self.qhead_per_kvhead == 1):\n            mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3, ragged=self.varlen_k)[\n                None, None, head_idx\n            ]\n            mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3, ragged=self.varlen_k)[\n                None, None, head_idx\n            ]\n            gdK = 
cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))\n            gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))\n            store_dK, _, _ = copy_utils.tma_get_copy_fn(\n                tma_atom_dK, 0, cute.make_layout(1), sK, gdK, single_stage=True\n            )\n            store_dV, _, _ = copy_utils.tma_get_copy_fn(\n                tma_atom_dV, 0, cute.make_layout(1), sV, gdV, single_stage=True\n            )\n            sdV = sV if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sV)\n            sdK = sK if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sK)\n            copy_dV_r2s, _, _ = copy_utils.get_smem_store_C(\n                tiled_mma_dV,\n                sdV,\n                tidx,\n                self.arch,\n                transpose=self.dKV_swapAB,\n                position_independent=True,\n            )\n            copy_dK_r2s, _, _ = copy_utils.get_smem_store_C(\n                tiled_mma_dK,\n                sdK,\n                tidx,\n                self.arch,\n                transpose=self.dKV_swapAB,\n                position_independent=True,\n            )\n            cute.arch.cp_async_bulk_wait_group(1, read=True)\n            epi_barrier.arrive_and_wait()\n            copy_dV_r2s(acc_dV, dst_idx=None)\n            cute.arch.fence_view_async_shared()\n            epi_barrier.arrive_and_wait()\n            if warp_idx == 4:\n                store_dV()\n                cute.arch.cp_async_bulk_commit_group()\n            cute.arch.cp_async_bulk_wait_group(1, read=True)\n            epi_barrier.arrive_and_wait()\n            copy_dK_r2s(acc_dK, dst_idx=None)\n            cute.arch.fence_view_async_shared()\n            epi_barrier.arrive_and_wait()\n            if warp_idx == 4:\n                store_dK()\n                cute.arch.cp_async_bulk_commit_group()\n        else:\n            sdKaccum_shape0 = self.tile_n * self.tile_hdim // 
self.num_wg_mma\n            sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.num_wg_mma\n            sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.num_wg_mma))\n            sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.num_wg_mma))\n            head_idx_kv = head_idx // qhead_per_kvhead_divmod\n            mdKaccum_cur = seqlen.offset_batch_K(\n                mdK, batch_idx, dim=2, padded=True, multiple=self.tile_hdim\n            )[None, head_idx_kv]\n            mdVaccum_cur = seqlen.offset_batch_K(\n                mdV, batch_idx, dim=2, padded=True, multiple=self.tile_hdimv\n            )[None, head_idx_kv]\n            gdKaccum_ = cute.local_tile(mdKaccum_cur, (self.tile_n * self.tile_hdim,), (n_block,))\n            gdKaccum = cute.flat_divide(gdKaccum_, (sdKaccum_shape0,))\n            gdVaccum_ = cute.local_tile(mdVaccum_cur, (self.tile_n * self.tile_hdimv,), (n_block,))\n            gdVaccum = cute.flat_divide(gdVaccum_, (sdVaccum_shape0,))\n            # These two overlap each other\n            sVaccum_ptr = cute.recast_ptr(sV.iterator, dtype=Float32)\n            sdKaccum = cute.make_tensor(sVaccum_ptr, sdKaccum_layout)\n            sdVaccum = cute.make_tensor(sVaccum_ptr, sdVaccum_layout)\n            tiled_copy_dKVaccum_r2s = cute.make_tiled_copy_tv(\n                cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),\n                cute.make_layout((self.num_threads_per_warp_group, self.num_wg_mma)),\n                cute.make_layout(128 // Float32.width),\n            )\n            thr_copy_dKVaccum_r2s = tiled_copy_dKVaccum_r2s.get_slice(tidx)\n            tdKsdKaccum = thr_copy_dKVaccum_r2s.partition_D(sdKaccum)\n            tdVsdVaccum = thr_copy_dKVaccum_r2s.partition_D(sdVaccum)\n\n            cute.arch.cp_async_bulk_wait_group(0, read=True)\n            epi_barrier.arrive_and_wait()\n            tdKrdKaccum_flat = cute.make_tensor(acc_dK.iterator, tdKsdKaccum.shape)\n           
 cute.autovec_copy(tdKrdKaccum_flat, tdKsdKaccum)\n            cute.arch.fence_view_async_shared()\n            epi_barrier.arrive_and_wait()\n            if warp_idx == 4:\n                with cute.arch.elect_one():\n                    for wg_idx in cutlass.range_constexpr(self.num_wg_mma):\n                        copy_utils.cpasync_reduce_bulk_add_f32(\n                            sdKaccum[None, wg_idx].iterator,\n                            gdKaccum[None, wg_idx].iterator,\n                            self.tma_copy_bytes[\"dKacc\"] // self.num_wg_mma,\n                        )\n                cute.arch.cp_async_bulk_commit_group()\n\n            cute.arch.cp_async_bulk_wait_group(0, read=True)\n            epi_barrier.arrive_and_wait()\n            tdVrdVaccum_flat = cute.make_tensor(acc_dV.iterator, tdVsdVaccum.shape)\n            cute.autovec_copy(tdVrdVaccum_flat, tdVsdVaccum)\n            cute.arch.fence_view_async_shared()\n            epi_barrier.arrive_and_wait()\n            if warp_idx == 4:\n                with cute.arch.elect_one():\n                    for wg_idx in cutlass.range_constexpr(self.num_wg_mma):\n                        copy_utils.cpasync_reduce_bulk_add_f32(\n                            sdVaccum[None, wg_idx].iterator,\n                            gdVaccum[None, wg_idx].iterator,\n                            self.tma_copy_bytes[\"dVacc\"] // self.num_wg_mma,\n                        )\n                cute.arch.cp_async_bulk_commit_group()\n\n    @cute.jit\n    def dQaccum_store(\n        self,\n        mdQaccum: cute.Tensor,\n        sdQaccum: cute.Tensor,\n        block_info: BlockInfo,\n        TileSchedulerCls: cutlass.Constexpr[Callable],\n        SeqlenInfoCls: cutlass.Constexpr[Callable],\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n        mdQ_semaphore: Optional[cute.Tensor] = None,\n    ):\n        tidx, _, _ = cute.arch.thread_idx()\n        # warp-local thread index (dQaccum_store runs on warp 1, 
global tidx 32-63)\n        warp_local_tidx = tidx % cute.arch.WARP_SIZE\n        read_flag = const_expr(not self.deterministic)\n\n        tile_scheduler = TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        while work_tile.is_valid_tile:\n            n_block, head_idx, batch_idx, _ = work_tile.tile_idx\n            seqlen = SeqlenInfoCls(batch_idx)\n            if const_expr(not seqlen.has_cu_seqlens_q):\n                mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]\n            else:\n                mdQaccum_cur = cute.domain_offset(\n                    (seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx]\n                )\n            # ((M * K / num_wg_dQ, num_wg_dQ), num_m_blocks)\n            gdQaccum = cute.local_tile(\n                mdQaccum_cur,\n                (\n                    cute.make_layout(\n                        (self.tile_m * self.tile_hdim // self.num_wg_dQ, self.num_wg_dQ)\n                    ),\n                ),\n                (None,),\n            )\n\n            if const_expr(mdQ_semaphore is not None):\n                # mdQ_semaphore is (num_m_blocks, cluster_size, num_head, batch) after transpose\n                mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx]\n\n            m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)\n            if const_expr(not self.use_block_sparsity):\n                process_tile = (\n                    const_expr(not self.is_local and not self.is_varlen_q)\n                    or m_block_min < m_block_max\n                )\n                loop_count = m_block_max - m_block_min\n            else:\n                total_block_cnt = get_total_q_block_count_bwd(\n                    blocksparse_tensors,\n                    batch_idx,\n                    head_idx,\n                    n_block,\n                    subtile_factor=self.subtile_factor,\n                    
m_block_max=m_block_max,\n                )\n                process_tile = total_block_cnt > Int32(0)\n\n            if process_tile:\n                if const_expr(not self.use_block_sparsity):\n                    for iter_idx in cutlass.range(loop_count, unroll=1):\n                        m_block = m_block_min + iter_idx\n                        m_block_safe = m_block\n\n                        num_dQ_chunks = self.num_wg_dQ\n                        for warp_group_idx in cutlass.range_constexpr(num_dQ_chunks):\n                            if const_expr(not self.deterministic):\n                                # If deterministic, we already waited at the end of the prev iter\n                                cute.arch.cp_async_bulk_wait_group(\n                                    num_dQ_chunks - 1 - warp_group_idx, read=read_flag\n                                )\n                            cute.arch.barrier_arrive(\n                                barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,\n                                number_of_threads=self.num_threads_per_warp_group\n                                + cute.arch.WARP_SIZE,\n                            )\n\n                        # Semaphore acquire: wait for prior n_blocks to finish writing this m_block\n                        if const_expr(self.deterministic):\n                            if const_expr(self.spt):\n                                _, n_block_max_for_m_block = block_info.get_n_block_min_max(\n                                    seqlen, m_block_safe\n                                )\n                                lock_value = n_block_max_for_m_block - 1 - n_block\n                            else:\n                                lock_value = n_block\n                            barrier.wait_eq(\n                                mdQ_semaphore_cur[(m_block_safe, None)].iterator,\n                                warp_local_tidx,\n                                0,  # 
flag_offset\n                                lock_value,\n                            )\n\n                        for warp_group_idx in cutlass.range_constexpr(num_dQ_chunks):\n                            cute.arch.barrier(\n                                barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,\n                                number_of_threads=self.num_threads_per_warp_group\n                                + cute.arch.WARP_SIZE,\n                            )\n                            with cute.arch.elect_one():\n                                copy_utils.cpasync_reduce_bulk_add_f32(\n                                    sdQaccum[None, warp_group_idx].iterator,\n                                    gdQaccum[(None, warp_group_idx), m_block_safe].iterator,\n                                    self.tma_copy_bytes[\"dQ\"],\n                                )\n                            cute.arch.cp_async_bulk_commit_group()\n\n                        # Semaphore release: signal that this n_block is done with this m_block\n                        if const_expr(self.deterministic):\n                            cute.arch.cp_async_bulk_wait_group(0, read=read_flag)\n                            barrier.arrive_inc(\n                                mdQ_semaphore_cur[(m_block_safe, None)].iterator,\n                                warp_local_tidx,\n                                0,  # flag_offset\n                                1,\n                            )\n                else:\n                    assert not self.deterministic, (\n                        \"Deterministic not implemented for block-sparse backward\"\n                    )\n                    dQaccum_store_block_sparse_bwd_sm90(\n                        blocksparse_tensors,\n                        batch_idx,\n                        head_idx,\n                        n_block,\n                        sdQaccum,\n                        gdQaccum,\n                        
subtile_factor=self.subtile_factor,\n                        m_block_max=m_block_max,\n                        num_mma_warp_groups=self.num_wg_mma,\n                        num_threads_per_warp_group=self.num_threads_per_warp_group,\n                        tma_copy_bytes_dQ=self.tma_copy_bytes[\"dQ\"],\n                    )\n\n            # For local masking + deterministic (non-spt): signal remaining m_blocks\n            # that this n_block won't visit, so they don't deadlock waiting.\n            if const_expr(\n                self.deterministic and not self.spt and block_info.window_size_left is not None\n            ):\n                m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m)\n                for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1):\n                    barrier.arrive_inc(\n                        mdQ_semaphore_cur[(m_block, None)].iterator,\n                        warp_local_tidx,\n                        0,  # flag_offset\n                        1,\n                    )\n\n            tile_scheduler.advance_to_next_work()\n            work_tile = tile_scheduler.get_current_work()\n\n        if const_expr(not self.deterministic):\n            cute.arch.cp_async_bulk_wait_group(0, read=True)\n"
  },
  {
    "path": "flash_attn/cute/flash_fwd.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n# A reimplementation of\n# https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm80.h\n# and https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm90.h\n# from Cutlass C++ to Cute-DSL.\n# Built on Cute-DSL example: https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py\n\nimport math\nfrom types import SimpleNamespace\nfrom typing import Type, Callable, Optional, List\nfrom functools import partial\n\nimport cuda.bindings.driver as cuda\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Constexpr, Float32, Int32, const_expr, Boolean\nfrom cutlass.cute.nvgpu import cpasync, warp\nimport cutlass.utils as utils_basic\nfrom cutlass.base_dsl.arch import Arch\nfrom cutlass.cutlass_dsl import BaseDSL\n\nfrom quack import copy_utils\nfrom quack import layout_utils\n\nfrom flash_attn.cute import ampere_helpers as sm80_utils\nfrom flash_attn.cute.cute_dsl_utils import assume_tensor_aligned\nfrom flash_attn.cute import utils\nfrom flash_attn.cute.mask import AttentionMask\nfrom flash_attn.cute.softmax import Softmax\nfrom flash_attn.cute.seqlen_info import SeqlenInfoQK\nfrom flash_attn.cute.block_info import BlockInfo\nfrom flash_attn.cute.pack_gqa import PackGQA\nfrom flash_attn.cute.named_barrier import NamedBarrierFwd\nfrom flash_attn.cute.block_sparsity import BlockSparseTensors\nfrom flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments\n\n\nclass FlashAttentionForwardBase:\n\n    def __init__(\n        self,\n        dtype: Type[cutlass.Numeric],\n        head_dim: int,\n        head_dim_v: Optional[int] = None,\n        qhead_per_kvhead: int = 1,\n        is_causal: bool = False,\n        is_local: bool = False,\n        pack_gqa: bool = True,\n        tile_m: int = 128,\n        tile_n: 
int = 128,\n        num_stages: int = 1,\n        num_threads: int = 128,\n        Q_in_regs: bool = False,\n        score_mod: Optional[cutlass.Constexpr] = None,\n        mask_mod: Optional[cutlass.Constexpr] = None,\n        has_aux_tensors: bool = False,\n        q_subtile_factor: int | None = None,\n    ):\n        \"\"\"Initializes the configuration for a flash attention kernel.\n\n        All contiguous dimensions must be at least 16 bytes aligned, which means that the head dimension\n        should be a multiple of 8.\n\n        :param head_dim: head dimension\n        :type head_dim: int\n        :param tile_m: m block size\n        :type tile_m: int\n        :param tile_n: n block size\n        :type tile_n: int\n        :param num_threads: number of threads\n        :type num_threads: int\n        :param is_causal: is causal\n        :param score_mod: A callable that takes the attention scores and applies a modification.\n            Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Any``\n        :param mask_mod: A callable that takes the attention scores and returns a boolean representing whether that score should be masked.\n            Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Boolean``\n        \"\"\"\n        self.dtype = dtype\n        # padding head_dim to a multiple of 16 as k_block_size\n        hdim_multiple_of = 16\n        self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)\n        head_dim_v = head_dim_v if head_dim_v is not None else head_dim\n        self.same_hdim_kv = head_dim == head_dim_v\n        self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)\n        # Can save registers (and hence be faster) if we don't have to check hdim predication\n        self.check_hdim_oob = head_dim != self.tile_hdim\n        self.check_hdim_v_oob = head_dim_v != self.tile_hdimv\n        self.qhead_per_kvhead = 
qhead_per_kvhead\n        self.is_causal = is_causal\n        self.is_local = is_local\n        self.pack_gqa = pack_gqa\n        self.tile_m = tile_m\n        self.tile_n = tile_n\n        self.num_threads = num_threads\n        self.num_stages = num_stages\n        self.q_subtile_factor = q_subtile_factor\n        self.Q_in_regs = Q_in_regs\n        self.score_mod = score_mod\n        self.mask_mod = mask_mod\n        self.qk_acc_dtype = Float32\n        self.vec_size: cutlass.Constexpr = getattr(\n            score_mod, \"__vec_size__\", 1 if cutlass.const_expr(has_aux_tensors) else 2\n        )\n        if self.vec_size > 2:\n            raise ValueError(\n                f\"score_mod vec_size {self.vec_size} not supported on Sm80/90/120 \"\n                \"due to accumulator thread ownership pattern.\"\n            )\n        self.arch = BaseDSL._get_dsl().get_arch_enum()\n\n    @staticmethod\n    def can_implement(\n        dtype,\n        head_dim,\n        head_dim_v,\n        tile_m,\n        tile_n,\n        num_stages,\n        num_threads,\n        is_causal,\n        Q_in_regs=False,\n    ) -> bool:\n        \"\"\"Check if the kernel can be implemented with the given parameters.\n\n        :param dtype: data type\n        :type dtype: cutlass.Numeric\n        :param head_dim: head dimension\n        :type head_dim: int\n        :param tile_m: m block size\n        :type tile_m: int\n        :param tile_n: n block size\n        :type tile_n: int\n        :param num_threads: number of threads\n        :type num_threads: int\n        :param is_causal: is causal\n        :type is_causal: bool\n\n        :return: True if the kernel can be implemented, False otherwise\n        :rtype: bool\n        \"\"\"\n        if dtype not in [cutlass.Float16, cutlass.BFloat16]:\n            return False\n        if head_dim % 8 != 0:\n            return False\n        if head_dim_v % 8 != 0:\n            return False\n        if tile_n % 16 != 0:\n            return 
False\n        if num_threads % 32 != 0:\n            return False\n        # Check if block size setting is out of shared memory capacity\n        # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size\n        smem_usage_Q = tile_m * head_dim * 2\n        smem_usage_K = tile_n * head_dim * num_stages * 2\n        smem_usage_V = tile_n * head_dim_v * num_stages * 2\n        smem_usage_QV = (\n            (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V)\n        )\n        smem_usage = smem_usage_QV + smem_usage_K\n        # TODO: sm86 and sm89\n        smem_capacity = utils_basic.get_smem_capacity_in_bytes(\"sm_80\")\n        if smem_usage > smem_capacity:\n            return False\n        # Check if twice the block size is divisible by the number of threads\n        if (tile_m * 2) % num_threads != 0:\n            return False\n        return True\n\n    def _check_type(\n        self,\n        mQ_type: Type[cutlass.Numeric],\n        mK_type: Type[cutlass.Numeric],\n        mV_type: Type[cutlass.Numeric],\n        mO_type: Type[cutlass.Numeric],\n        mLSE_type: Type[cutlass.Numeric] | None,\n        mCuSeqlensQ_type: Type[cutlass.Numeric] | None,\n        mCuSeqlensK_type: Type[cutlass.Numeric] | None,\n        mSeqUsedQ_type: Type[cutlass.Numeric] | None,\n        mSeqUsedK_type: Type[cutlass.Numeric] | None,\n    ):\n        # Get the data type and check if it is fp16 or bf16\n        if const_expr(not (mQ_type == mK_type == mV_type == mO_type)):\n            raise TypeError(\"All tensors must have the same data type\")\n        if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]):\n            raise TypeError(\"Only Float16 or BFloat16 is supported\")\n        if const_expr(mLSE_type not in [None, Float32]):\n            raise TypeError(\"LSE tensor must be Float32\")\n        if const_expr(mCuSeqlensQ_type not in [None, Int32]):\n            raise TypeError(\"cu_seqlens_q 
tensor must be Int32\")\n        if const_expr(mCuSeqlensK_type not in [None, Int32]):\n            raise TypeError(\"cu_seqlens_k tensor must be Int32\")\n        if const_expr(mSeqUsedQ_type not in [None, Int32]):\n            raise TypeError(\"seqused_q tensor must be Int32\")\n        if const_expr(mSeqUsedK_type not in [None, Int32]):\n            raise TypeError(\"seqused_k tensor must be Int32\")\n        assert mQ_type == self.dtype\n\n    def _setup_attributes(self):\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Shared memory layout: Q/K/V\n        # ///////////////////////////////////////////////////////////////////////////////\n        sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = (\n            self._get_smem_layout_atom()\n        )\n        self.sQ_layout = cute.tile_to_shape(\n            sQ_layout_atom,\n            (self.tile_m, self.tile_hdim),\n            (0, 1),\n        )\n        self.sK_layout = cute.tile_to_shape(\n            sK_layout_atom,\n            (self.tile_n, self.tile_hdim, self.num_stages),\n            (0, 1, 2),\n        )\n        self.sV_layout = cute.tile_to_shape(\n            sV_layout_atom,\n            (self.tile_n, self.tile_hdimv, self.num_stages),\n            (0, 1, 2),\n        )\n        self.sO_layout = cute.tile_to_shape(\n            sO_layout_atom,\n            (self.tile_m, self.tile_hdimv),\n            (0, 1),\n        )\n        if const_expr(sP_layout_atom is not None):\n            self.sP_layout = cute.tile_to_shape(\n                sP_layout_atom,\n                (self.tile_m, self.tile_n),\n                (0, 1),\n            )\n        else:\n            self.sP_layout = None\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # GMEM Tiled copy:\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Thread 
layouts for copies\n        universal_copy_bits = 128\n        async_copy_elems = universal_copy_bits // self.dtype.width\n        # atom_async_copy: async copy atom for QKV load\n        atom_async_copy = cute.make_copy_atom(\n            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),\n            self.dtype,\n            num_bits_per_copy=universal_copy_bits,\n        )\n        # atom_universal_copy: universal copy atom for O store\n        atom_universal_copy = cute.make_copy_atom(\n            cute.nvgpu.CopyUniversalOp(),\n            self.dtype,\n            num_bits_per_copy=universal_copy_bits,\n        )\n        # tQ_layout and tK_layout: thread layout for QK load\n        tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems\n        assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, (\n            \"num_threads must be divisible by tQK_shape_dim_1\"\n        )\n        assert self.num_producer_threads % tQK_shape_dim_1 == 0, (\n            \"num_threads must be divisible by tQK_shape_dim_1\"\n        )\n        tQ_layout = cute.make_ordered_layout(\n            (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1),\n            order=(1, 0),\n        )\n        tK_layout = cute.make_ordered_layout(\n            (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1),\n            order=(1, 0),\n        )\n        # So that we don't have to check if we overshoot kBlockM when we load Q\n        assert self.tile_m % tQ_layout.shape[0] == 0\n        tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems\n        tV_layout = cute.make_ordered_layout(\n            (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1),\n            order=(1, 0),\n        )\n        # TODO: need a different layout for O if O dtype is not the same as V dtype\n        # tO_layout: thread layout for O store\n        tO_layout = cute.make_ordered_layout(\n            (self.num_epilogue_threads // tV_shape_dim_1, 
tV_shape_dim_1),\n            order=(1, 0),\n        )\n        # So that we don't have to check if we overshoot kBlockM when we store O\n        assert self.tile_m % tO_layout.shape[0] == 0\n\n        # Value layouts for copies\n        vQKV_layout = cute.make_layout((1, async_copy_elems))\n        vO_layout = vQKV_layout\n\n        self.gmem_tiled_copy_Q = cute.make_tiled_copy_tv(atom_async_copy, tQ_layout, vQKV_layout)\n        self.gmem_tiled_copy_K = cute.make_tiled_copy_tv(atom_async_copy, tK_layout, vQKV_layout)\n        self.gmem_tiled_copy_V = cute.make_tiled_copy_tv(atom_async_copy, tV_layout, vQKV_layout)\n        # gmem_tiled_copy_O: tiled copy for O store\n        self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)\n\n    def _get_smem_layout_atom(self):\n        raise NotImplementedError()\n\n    def _get_tiled_mma(self):\n        raise NotImplementedError()\n\n    def _get_shared_storage_cls(self):\n        raise NotImplementedError()\n\n    @cute.jit\n    def __call__(\n        self,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mV: cute.Tensor,\n        mO: cute.Tensor,\n        mLSE: Optional[cute.Tensor],\n        softmax_scale: Float32,\n        # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).\n        stream: cuda.CUstream = None,\n    ):\n        \"\"\"Configures and launches the flash attention kernel.\n\n        mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout:\n        (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)\n        \"\"\"\n        raise NotImplementedError()\n\n    @cute.jit\n    def epilogue(\n        self,\n        acc_O: cute.Tensor,\n        lse: cute.Tensor,\n        mO: cute.Tensor,\n        mLSE: Optional[cute.Tensor],\n        sO: cute.Tensor,\n        seqlen: SeqlenInfoQK,\n        gmem_tiled_copy_O: cute.TiledCopy,\n        tma_atom_O: Optional[cute.CopyAtom],\n        tiled_mma: cute.TiledMma,\n  
      tidx: Int32,\n        m_block: Int32,\n        head_idx: Int32,\n        batch_idx: Int32,\n    ):\n        # store acc_O\n        rO = cute.make_fragment_like(acc_O, self.dtype)\n        rO.store(acc_O.load().to(self.dtype))\n        # Make sure all threads have finished reading V\n        cute.arch.barrier(\n            barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads\n        )\n        smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype)\n        smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)\n        taccOrO = smem_thr_copy_O.retile(rO)\n        taccOsO = smem_thr_copy_O.partition_D(sO)\n        # taccOsO = copy_utils.partition_D_position_independent(smem_thr_copy_O, sO)\n        # copy acc O from rmem to smem with the smem copy atom\n        cute.copy(smem_copy_atom_O, taccOrO, taccOsO)\n\n        cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv))\n        pack_gqa = PackGQA(\n            self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead\n        )\n\n        # Write LSE from rmem -> gmem\n        if const_expr(mLSE is not None):\n            if const_expr(not seqlen.has_cu_seqlens_q):\n                mLSE_cur = mLSE[None, head_idx, batch_idx]\n            else:\n                offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q)\n                mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])\n            if const_expr(not self.pack_gqa):\n                gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))\n                gLSE_expanded_layout = cute.append(\n                    gLSE.layout, cute.make_layout((self.tile_hdimv,), stride=(0,))\n                )\n                gLSE_expanded = cute.make_tensor(gLSE.iterator, gLSE_expanded_layout)\n                thr_mma = tiled_mma.get_slice(tidx)\n                taccOgLSE = 
layout_utils.reshape_acc_to_mn(thr_mma.partition_C(gLSE_expanded))\n                assert cute.size(taccOgLSE, mode=[0]) == cute.size(lse)\n                taccOcO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cO))\n                t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO))\n                # Only the thread corresponding to column 0 writes out the lse to gmem\n                if taccOcO[0][1] == 0:\n                    for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])):\n                        if (\n                            t0accOcO[m, 0][0]\n                            < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0]\n                        ):\n                            taccOgLSE[m, 0] = lse[m]\n            else:\n                pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q)\n\n        ragged = self.use_tma_O and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)\n        mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx]\n        # thr_mma = tiled_mma.get_slice(tidx)\n        # taccOgO = thr_mma.partition_C(gO)\n        # cute.autovec_copy(rO, taccOgO)\n        # sync to make sure all smem stores are done\n        if const_expr(self.use_tma_O):\n            # ensure smem writes are visible to TMA\n            cute.arch.fence_view_async_shared()\n            cute.arch.barrier_arrive(\n                barrier_id=int(NamedBarrierFwd.Epilogue),\n                number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,\n            )\n            gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))\n            store_O, _, _ = copy_utils.tma_get_copy_fn(\n                tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True\n            )\n            warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())\n            if warp_idx == 4:\n                cute.arch.barrier(\n           
         barrier_id=int(NamedBarrierFwd.Epilogue),\n                    number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,\n                )\n                store_O()\n                cute.arch.cp_async_bulk_commit_group()\n                cute.arch.cp_async_bulk_wait_group(0, read=True)\n        else:\n            cute.arch.barrier(\n                barrier_id=int(NamedBarrierFwd.Epilogue),\n                number_of_threads=self.num_epilogue_threads,\n            )\n            gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)\n            tOsO = gmem_thr_copy_O.partition_S(sO)\n            tOrO = cute.make_fragment_like(tOsO, self.dtype)\n            # load acc O from smem to rmem for wider vectorization\n            cute.autovec_copy(tOsO, tOrO)\n            if const_expr(not self.pack_gqa):\n                gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))\n                tOgO = gmem_thr_copy_O.partition_D(gO)\n                tOcO = gmem_thr_copy_O.partition_S(cO)\n                t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)\n                tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])\n                # copy acc O from rmem to gmem\n                for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):\n                    if (\n                        t0OcO[0, rest_m, 0][0]\n                        < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0]\n                    ):\n                        cute.copy(\n                            gmem_tiled_copy_O,\n                            tOrO[None, rest_m, None],\n                            tOgO[None, rest_m, None],\n                            pred=tOpO[None, rest_m, None]\n                            if const_expr(self.check_hdim_v_oob)\n                            else None,\n                        )\n            else:\n                pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q)\n\n    
@cute.jit\n    def advance_pipeline(self, pipeline_index):\n        return pipeline_index + 1 if pipeline_index < self.num_stages - 1 else 0\n\n    @cute.jit\n    def load_Q(\n        self,\n        gmem_thr_copy: cute.TiledCopy,\n        gQ: cute.Tensor,\n        sQ: cute.Tensor,\n        block: Int32,\n        seqlen: Int32,\n        headdim: Int32,\n    ):\n        tQsQ, tQgQ = gmem_thr_copy.partition_D(sQ), gmem_thr_copy.partition_S(gQ)\n        cQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim))\n        tQcQ = gmem_thr_copy.partition_S(cQ)\n        t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ)\n        tQpQ = utils.predicate_k(tQcQ, limit=headdim)\n        for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):\n            # Instead of using tQcQ, we use t0QcQ and subtract the offset from the limit\n            # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time.\n            if t0QcQ[0, m, 0][0] < seqlen - block * self.tile_m - tQcQ[0][0]:\n                cute.copy(\n                    gmem_thr_copy,\n                    tQgQ[None, m, None],\n                    tQsQ[None, m, None],\n                    pred=tQpQ[None, m, None] if const_expr(self.check_hdim_oob) else None,\n                )\n            # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs\n\n    @cute.jit\n    def load_K(\n        self,\n        gmem_tiled_copy: cute.TiledCopy,\n        tKgK: cute.Tensor,\n        tKsK: cute.Tensor,\n        tKcK: cute.Tensor,\n        t0KcK: cute.Tensor,\n        tKpK: cute.Tensor,\n        block: Int32,\n        smem_pipe_write: Int32,\n        seqlen: Int32,\n        need_predicates: cutlass.Constexpr,\n    ):\n        # Do we need to check if we overshoot kBlockN when we load K?\n        is_even_n_smem_k = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0\n        if const_expr(need_predicates or not is_even_n_smem_k):\n            # Instead of using 
tKcK, we use t0KcK and subtract the offset from the limit\n            # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time.\n            if const_expr(is_even_n_smem_k):\n                seqlen_limit = seqlen - block * self.tile_n\n            else:\n                if const_expr(not need_predicates):\n                    seqlen_limit = self.tile_n\n                else:\n                    seqlen_limit = cutlass.min(seqlen - block * self.tile_n, self.tile_n)\n            seqlen_limit -= tKcK[0][0]\n            for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):\n                if t0KcK[0, n, 0][0] < seqlen_limit:\n                    cute.copy(\n                        gmem_tiled_copy,\n                        tKgK[None, n, None, block],\n                        tKsK[\n                            None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0\n                        ],\n                        pred=tKpK[None, n, None] if const_expr(self.check_hdim_oob) else None,\n                    )\n                # We don't need to clear the sK smem tiles since we'll mask out the scores anyway.\n        else:\n            cute.copy(\n                gmem_tiled_copy,\n                tKgK[None, None, None, block],\n                tKsK[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0],\n                pred=tKpK if const_expr(self.check_hdim_oob) else None,\n            )\n\n    @cute.jit\n    def load_V(\n        self,\n        gmem_tiled_copy: cute.TiledCopy,\n        tVgV: cute.Tensor,\n        tVsV: cute.Tensor,\n        tVcV: cute.Tensor,\n        t0VcV: cute.Tensor,\n        tVpV: cute.Tensor,\n        block: Int32,\n        smem_pipe_write: Int32,\n        seqlen: Int32,\n        need_predicates: cutlass.Constexpr,\n    ):\n        # Do we need to check if we overshoot kBlockN when we load V?\n        is_even_n_smem_v = self.tile_n % 
gmem_tiled_copy.tiler_mn[0].shape == 0\n        if const_expr(need_predicates or not is_even_n_smem_v):\n            for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):\n                # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked\n                if (\n                    is_even_n_smem_v\n                    or n < cute.size(tVsV.shape[1]) - 1\n                    or tVcV[0, n, 0][0] < self.tile_n\n                ):\n                    predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None\n                    if const_expr(need_predicates):\n                        seqlen_limit = seqlen - block * self.tile_n - tVcV[0][0]\n                        predicate_n = t0VcV[0, n, 0][0] < seqlen_limit\n                        predicate = cute.make_fragment_like(tVpV[None, 0, None])\n                        for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):\n                            for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):\n                                predicate[i, k] = (\n                                    tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True\n                                ) and predicate_n\n                    cute.copy(\n                        gmem_tiled_copy,\n                        tVgV[None, n, None, block],\n                        tVsV[\n                            None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0\n                        ],\n                        pred=predicate,\n                    )\n        else:\n            cute.copy(\n                gmem_tiled_copy,\n                tVgV[None, None, None, block],\n                tVsV[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0],\n                pred=tVpV if const_expr(self.check_hdim_v_oob) else None,\n            )\n\n\nclass FlashAttentionForwardSm80(FlashAttentionForwardBase):\n    def 
_get_smem_layout_atom(self):\n        sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdim)\n        sK_layout_atom = sQ_layout_atom\n        sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdimv)\n        sO_layout_atom = sV_layout_atom\n        sP_layout_atom = None\n        return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom\n\n    def _get_tiled_mma(self):\n        tiled_mma_qk = cute.make_tiled_mma(\n            warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)),\n            (self.num_threads // 32, 1, 1),\n            permutation_mnk=(self.num_threads // 32 * 16, 16, 16),\n        )\n        tiled_mma_pv = cute.make_tiled_mma(\n            warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)),\n            (self.num_threads // 32, 1, 1),\n            permutation_mnk=(self.num_threads // 32 * 16, 16, 16),\n        )\n        return tiled_mma_qk, tiled_mma_pv\n\n    def _get_shared_storage_cls(self):\n        sQ_struct, sK_struct, sV_struct = [\n            cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024]\n            for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)\n        ]\n        cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))\n        sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]\n\n        @cute.struct\n        class SharedStorageQKV:\n            sV: sV_struct\n            sQ: sQ_struct\n            sK: sK_struct\n\n        @cute.struct\n        class SharedStorageSharedQV:\n            sQ: sQV_struct\n            sK: sK_struct\n\n        return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV\n\n    @cute.jit\n    def __call__(\n        self,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mV: cute.Tensor,\n        mO: cute.Tensor,\n        mLSE: Optional[cute.Tensor],\n        softmax_scale: Float32,\n        mCuSeqlensQ: 
Optional[cute.Tensor] = None,\n        mCuSeqlensK: Optional[cute.Tensor] = None,\n        mSeqUsedQ: Optional[cute.Tensor] = None,\n        mSeqUsedK: Optional[cute.Tensor] = None,\n        mPageTable: Optional[cute.Tensor] = None,\n        window_size_left: Optional[Int32] = None,\n        window_size_right: Optional[Int32] = None,\n        learnable_sink: Optional[cute.Tensor] = None,\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n        aux_tensors=None,\n        # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).\n        stream: cuda.CUstream = None,\n    ):\n        \"\"\"Configures and launches the flash attention kernel.\n\n        mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout:\n        (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)\n        \"\"\"\n        assert learnable_sink is None, \"Learnable sink is not supported in this kernel\"\n        self._check_type(\n            *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK))\n        )\n        tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()\n        self.num_mma_threads = tiled_mma_pv.size\n        self.num_producer_threads = self.num_threads\n        self.num_Q_load_threads = self.num_threads\n        self.num_epilogue_threads = self.num_threads\n        # self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None\n        self.use_tma_O = self.arch >= Arch.sm_90\n        self._setup_attributes()\n        SharedStorage = self._get_shared_storage_cls()\n        mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]\n        # Layout permutation: 4D non-varlen vs 3D varlen\n        QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]\n        KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]\n        mQ, mO = [\n            cute.make_tensor(t.iterator, 
cute.select(t.layout, mode=QO_layout_transpose))\n            for t in (mQ, mO)\n        ]\n        mK, mV = [\n            cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose))\n            for t in (mK, mV)\n        ]\n        if const_expr(mLSE is not None):\n            LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]\n            mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))\n        # TileScheduler for varlen, simple grid for non-varlen\n        if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):\n            TileScheduler = SingleTileVarlenScheduler\n        else:\n            TileScheduler = SingleTileScheduler\n        num_batch = (\n            mCuSeqlensQ.shape[0] - 1\n            if const_expr(mCuSeqlensQ is not None)\n            else mQ.shape[3]\n        )\n        tile_sched_args = TileSchedulerArguments(\n            num_block=cute.ceil_div(mQ.shape[0], self.tile_m),\n            num_head=cute.size(mQ.shape[2]),\n            num_batch=num_batch,\n            num_splits=1,\n            seqlen_k=0,\n            headdim=mQ.shape[1],\n            headdim_v=mV.shape[1],\n            total_q=cute.size(mQ.shape[0])\n            if const_expr(mCuSeqlensQ is not None)\n            else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),\n            tile_shape_mn=(self.tile_m, self.tile_n),\n            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n            mCuSeqlensQ=mCuSeqlensQ,\n            mSeqUsedQ=mSeqUsedQ,\n        )\n        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)\n        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)\n        softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod)\n        fastdiv_mods = utils.compute_fastdiv_mods(mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors)\n\n        
self.kernel(\n            mQ,\n            mK,\n            mV,\n            mO,\n            mLSE,\n            mCuSeqlensQ,\n            mCuSeqlensK,\n            mSeqUsedQ,\n            mSeqUsedK,\n            softmax_scale_log2,\n            softmax_scale,\n            window_size_left,\n            window_size_right,\n            self.sQ_layout,\n            self.sK_layout,\n            self.sV_layout,\n            self.sO_layout,\n            self.sP_layout,\n            self.gmem_tiled_copy_Q,\n            self.gmem_tiled_copy_K,\n            self.gmem_tiled_copy_V,\n            self.gmem_tiled_copy_O,\n            tiled_mma_qk,\n            tiled_mma_pv,\n            SharedStorage,\n            tile_sched_params,\n            TileScheduler,\n            aux_tensors,\n            fastdiv_mods,\n        ).launch(\n            grid=grid_dim,\n            block=[self.num_threads, 1, 1],\n            smem=SharedStorage.size_in_bytes(),\n            stream=stream,\n        )\n\n    @cute.kernel\n    def kernel(\n        self,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mV: cute.Tensor,\n        mO: cute.Tensor,\n        mLSE: Optional[cute.Tensor],\n        mCuSeqlensQ: Optional[cute.Tensor],\n        mCuSeqlensK: Optional[cute.Tensor],\n        mSeqUsedQ: Optional[cute.Tensor],\n        mSeqUsedK: Optional[cute.Tensor],\n        softmax_scale_log2: Float32,\n        softmax_scale: Optional[Float32],\n        window_size_left: Optional[Int32],\n        window_size_right: Optional[Int32],\n        sQ_layout: cute.ComposedLayout,\n        sK_layout: cute.ComposedLayout,\n        sV_layout: cute.ComposedLayout,\n        sO_layout: cute.ComposedLayout,\n        sP_layout: cute.ComposedLayout | None,\n        gmem_tiled_copy_Q: cute.TiledCopy,\n        gmem_tiled_copy_K: cute.TiledCopy,\n        gmem_tiled_copy_V: cute.TiledCopy,\n        gmem_tiled_copy_O: cute.TiledCopy,\n        tiled_mma_qk: cute.TiledMma,\n        tiled_mma_pv: cute.TiledMma,\n   
     SharedStorage: cutlass.Constexpr,\n        tile_sched_params,\n        TileScheduler: cutlass.Constexpr[Callable],\n        aux_tensors=None,\n        fastdiv_mods=None,\n    ):\n        # Thread index, block index\n        tidx, _, _ = cute.arch.thread_idx()\n\n        tile_scheduler = TileScheduler.create(tile_sched_params)\n        work_tile = tile_scheduler.initial_work_tile_info()\n        m_block, num_head, batch_size, _ = work_tile.tile_idx\n\n        block_info = BlockInfo(\n            self.tile_m,\n            self.tile_n,\n            self.is_causal,\n            self.is_local,\n            False,  # is_split_kv\n            window_size_left,\n            window_size_right,\n            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n        )\n        seqlen = SeqlenInfoQK.create(\n            batch_idx=batch_size,\n            seqlen_q_static=mQ.shape[0],\n            seqlen_k_static=mK.shape[0],\n            mCuSeqlensQ=mCuSeqlensQ,\n            mCuSeqlensK=mCuSeqlensK,\n            mSeqUsedQ=mSeqUsedQ,\n            mSeqUsedK=mSeqUsedK,\n        )\n        n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)\n        # For varlen, wasted grid tiles (where batch_idx >= num_batch) will have\n        # seqlen_q=seqlen_k=0 and n_block_max=0.  
Clamp to 0 so we don't use a\n        # negative block index for K/V loads; the load/store predicates already\n        # guard all memory accesses when seqlen is 0.\n        n_block = cutlass.max(n_block_max - 1, 0)\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Get the appropriate tiles for this thread block.\n        # ///////////////////////////////////////////////////////////////////////////////\n        blkQ_shape = (self.tile_m, self.tile_hdim)\n        blkK_shape = (self.tile_n, self.tile_hdim)\n        blkV_shape = (self.tile_n, self.tile_hdimv)\n        num_head_kv = num_head // self.qhead_per_kvhead\n        if const_expr(not seqlen.has_cu_seqlens_q):\n            mQ_cur = mQ[None, None, num_head, batch_size]\n        else:\n            mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, num_head])\n        if const_expr(not seqlen.has_cu_seqlens_k):\n            mK_cur = mK[None, None, num_head_kv, batch_size]\n            mV_cur = mV[None, None, num_head_kv, batch_size]\n        else:\n            mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, num_head_kv])\n            mV_cur = cute.domain_offset((seqlen.offset_k, 0), mV[None, None, num_head_kv])\n        gQ = cute.local_tile(mQ_cur, blkQ_shape, (m_block, 0))\n        gK = cute.local_tile(mK_cur, blkK_shape, (None, 0))\n        gV = cute.local_tile(mV_cur, blkV_shape, (None, 0))\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Get shared memory buffer\n        # ///////////////////////////////////////////////////////////////////////////////\n        smem = cutlass.utils.SmemAllocator()\n        storage = smem.allocate(SharedStorage)\n        sQ = storage.sQ.get_tensor(sQ_layout)\n        sK = storage.sK.get_tensor(sK_layout)\n        if const_expr(not self.Q_in_regs):\n            sV = storage.sV.get_tensor(sV_layout)\n        else:\n            sV = 
cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout)\n        # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma\n        sVt = layout_utils.transpose_view(sV)\n\n        gmem_thr_copy_K = gmem_tiled_copy_K.get_slice(tidx)\n        gmem_thr_copy_V = gmem_tiled_copy_V.get_slice(tidx)\n        # (CPY_Atom, CPY_N, CPY_K, n_block)\n        tKsK, tKgK = gmem_thr_copy_K.partition_D(sK), gmem_thr_copy_K.partition_S(gK)\n        # (CPY_Atom, CPY_N, CPY_K, n_block)\n        tVsV, tVgV = gmem_thr_copy_V.partition_D(sV), gmem_thr_copy_V.partition_S(gV)\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Tile MMA compute thread partitions and allocate accumulators\n        # ///////////////////////////////////////////////////////////////////////////////\n        thr_mma_qk = tiled_mma_qk.get_slice(tidx)\n        thr_mma_pv = tiled_mma_pv.get_slice(tidx)\n        tSrQ = thr_mma_qk.make_fragment_A(thr_mma_qk.partition_A(sQ))\n        tSrK = thr_mma_qk.make_fragment_B(thr_mma_qk.partition_B(sK[None, None, 0]))\n        tOrVt = thr_mma_pv.make_fragment_B(thr_mma_pv.partition_B(sVt[None, None, 0]))\n        acc_shape_O = thr_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv))\n        acc_O = cute.make_fragment(acc_shape_O, Float32)\n        acc_O.fill(0.0)\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Smem copy atom tiling\n        # ///////////////////////////////////////////////////////////////////////////////\n        smem_copy_atom_QK = cute.make_copy_atom(\n            warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4),\n            self.dtype,\n        )\n        smem_copy_atom_V = cute.make_copy_atom(\n            warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4),\n            self.dtype,\n        )\n        smem_thr_copy_Q = utils.make_tiled_copy_A(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx)\n   
     smem_thr_copy_K = utils.make_tiled_copy_B(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx)\n        smem_thr_copy_V = utils.make_tiled_copy_B(smem_copy_atom_V, tiled_mma_pv).get_slice(tidx)\n\n        tSsQ = smem_thr_copy_Q.partition_S(sQ)\n        tSsK = smem_thr_copy_K.partition_S(sK)\n        tOsVt = smem_thr_copy_V.partition_S(sVt)\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Predicate: Mark indices that need to copy when problem_shape isn't a multiple\n        # of tile_shape\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Construct identity layout for KV\n        cK = cute.make_identity_tensor((self.tile_n, self.tile_hdim))\n        tKcK = gmem_thr_copy_K.partition_S(cK)\n        t0KcK = gmem_thr_copy_K.get_slice(0).partition_S(cK)\n        if const_expr(self.tile_hdim == self.tile_hdimv):\n            tVcV = tKcK\n            t0VcV = t0KcK\n        else:\n            cV = cute.make_identity_tensor((self.tile_n, self.tile_hdimv))\n            tVcV = gmem_thr_copy_V.partition_S(cV)\n            t0VcV = gmem_thr_copy_V.get_slice(0).partition_S(cV)\n        # Allocate predicate tensors for m and n, here we only allocate the tile of k, and\n        # use \"if\" on the mn dimension.\n        # This is to reduce register pressure and gets 2-3% performance gain.\n        tKpK = utils.predicate_k(tKcK, limit=mK.shape[1])\n        if const_expr(self.same_hdim_kv):\n            tVpV = tKpK\n        else:\n            tVpV = utils.predicate_k(tVcV, limit=mV.shape[1])\n\n        # shape: (atom_v_m * rest_m)\n        softmax = Softmax.create(\n            softmax_scale_log2,\n            num_rows=acc_O.shape[0][0] * acc_O.shape[1],\n            softmax_scale=softmax_scale,\n        )\n        softmax.reset()\n\n        # group parameters for compute_one_n_block\n        mma_params = SimpleNamespace(\n            thr_mma_qk=thr_mma_qk,\n            
thr_mma_pv=thr_mma_pv,\n            tSrQ=tSrQ,\n            tSrK=tSrK,\n            tOrVt=tOrVt,\n            acc_O=acc_O,\n        )\n        smem_copy_params = SimpleNamespace(\n            smem_thr_copy_Q=smem_thr_copy_Q,\n            smem_thr_copy_K=smem_thr_copy_K,\n            smem_thr_copy_V=smem_thr_copy_V,\n            tSsQ=tSsQ,\n            tSsK=tSsK,\n            tOsVt=tOsVt,\n        )\n        load_K = partial(\n            self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, seqlen=seqlen.seqlen_k\n        )\n        load_V = partial(\n            self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, seqlen=seqlen.seqlen_k\n        )\n\n        compute_one_n_block = partial(\n            self.compute_one_n_block,\n            mma_params=mma_params,\n            smem_copy_params=smem_copy_params,\n            softmax=softmax,\n            load_K=load_K,\n            load_V=load_V,\n            score_mod=self.score_mod,\n            batch_idx=batch_size,\n            head_idx=num_head,\n            m_block=m_block,\n            aux_tensors=aux_tensors,\n            fastdiv_mods=fastdiv_mods,\n        )\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Prologue\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Start async loads of the last mn-tile, where we take care of the mn residue\n        gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx)\n        self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, headdim=mQ.shape[1])\n        cute.arch.cp_async_commit_group()\n\n        def preprocess_Q():\n            cute.arch.cp_async_wait_group(self.num_stages * 2 - 1)\n            if const_expr(self.Q_in_regs):\n                cute.arch.barrier()\n                tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ)\n                cute.copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view)\n\n        # If Q_in_regs, we load Q, then 
load 1 stage of K, then (optionally) rotate Q and\n        # read from smem_q to registers, then load V.\n        # If !Q_in_regs, we load Q, load all stages of K & V, then (optionally) rotate Q.\n        if const_expr(self.Q_in_regs):\n            load_K(n_block, smem_pipe_write=0, need_predicates=True)\n            cute.arch.cp_async_commit_group()\n            preprocess_Q()\n            cute.arch.barrier()  # Make sure all threads have read smem_q before loading V\n\n        for stage in cutlass.range_constexpr(self.num_stages):\n            if const_expr(not self.Q_in_regs or stage > 0):\n                if stage == 0 or n_block - stage >= 0:\n                    load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0)\n                cute.arch.cp_async_commit_group()\n            if const_expr(stage < self.num_stages - 1):\n                if stage == 0 or n_block - stage >= 0:\n                    load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0)\n                cute.arch.cp_async_commit_group()\n        if const_expr(not self.Q_in_regs):\n            preprocess_Q()\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Mainloop\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Start processing of the first n-block.\n        # For performance reasons, we separate out two kinds of iterations:\n        # those that need masking on S, and those that don't.\n        # We need masking on S for the very last block when the length of K and V is not a multiple of tile_n.\n        # We also need masking on S if it's causal, for the last several blocks.\n        mask = AttentionMask(\n            self.tile_m,\n            self.tile_n,\n            seqlen,\n            window_size_left,\n            window_size_right,\n            self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n        )\n        mask_fn = partial(\n       
     mask.apply_mask,\n            batch_idx=batch_size,\n            head_idx=num_head,\n            m_block=m_block,\n            thr_mma=thr_mma_qk,\n            mask_causal=self.is_causal,\n            mask_local=self.is_local,\n            aux_tensors=aux_tensors,\n            fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None,\n        )\n\n        # First iteration with seqlen masking\n        smem_pipe_read = Int32(0)\n        smem_pipe_write = Int32(self.num_stages - 1)\n        compute_one_n_block(\n            n_block,\n            smem_pipe_read,\n            smem_pipe_write,\n            is_first_n_block=True,\n            seqlen=seqlen,\n            mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),\n        )\n        smem_pipe_read = self.advance_pipeline(smem_pipe_read)\n        smem_pipe_write = self.advance_pipeline(smem_pipe_write)\n        # Next couple of iterations with causal masking\n        if const_expr(self.is_causal or self.is_local):\n            n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(\n                seqlen, m_block, n_block_min\n            )\n            for n_tile in cutlass.range(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1):\n                n_block = n_block_max - 2 - n_tile\n                compute_one_n_block(\n                    n_block,\n                    smem_pipe_read,\n                    smem_pipe_write,\n                    seqlen=seqlen,\n                    mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),\n                )\n                smem_pipe_read = self.advance_pipeline(smem_pipe_read)\n                smem_pipe_write = self.advance_pipeline(smem_pipe_write)\n        # The remaining iterations have no masking\n        for n_tile in cutlass.range(n_block, unroll=1):\n            compute_one_n_block(\n                n_block - n_tile - 1, smem_pipe_read, smem_pipe_write,\n                
seqlen=seqlen, is_first_n_block=False,\n                mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False)\n            )\n            smem_pipe_read = self.advance_pipeline(smem_pipe_read)\n            smem_pipe_write = self.advance_pipeline(smem_pipe_write)\n        # TODO: local\n\n        # normalize acc_O by row_sum and calculate the lse\n        row_scale = softmax.finalize()\n        softmax.rescale_O(acc_O, row_scale)\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Epilogue\n        # ///////////////////////////////////////////////////////////////////////////////\n        # reuse sQ's data iterator\n        sO = cute.make_tensor(sQ.iterator, sO_layout)\n        self.epilogue(\n            acc_O,\n            softmax.row_sum,\n            mO,\n            mLSE,\n            sO,\n            seqlen,\n            gmem_tiled_copy_O,\n            None,\n            tiled_mma_pv,\n            tidx,\n            m_block,\n            num_head,\n            batch_size,\n        )\n\n    @cute.jit\n    def compute_one_n_block(\n        self,\n        n_block: Int32,\n        smem_pipe_read: Int32,\n        smem_pipe_write: Int32,\n        mma_params: SimpleNamespace,\n        smem_copy_params: SimpleNamespace,\n        softmax: Softmax,\n        load_K: Callable,\n        load_V: Callable,\n        score_mod: Callable | None,\n        batch_idx: cutlass.Int32,\n        head_idx: cutlass.Int32,\n        m_block: cutlass.Int32,\n        seqlen: SeqlenInfoQK,\n        aux_tensors=None,\n        fastdiv_mods=None,\n        mask_fn: Optional[Callable] = None,\n        is_first_n_block: cutlass.Constexpr = False,\n        check_inf: cutlass.Constexpr = True,\n    ):\n        \"\"\"Compute one n_block of S/O.\n\n        This function provides different variants for processing the first n block versus\n        subsequent blocks.\n        \"\"\"\n\n        def sync():\n            
cute.arch.cp_async_wait_group(self.num_stages * 2 - 2)\n            cute.arch.barrier()\n\n        acc_shape_S = mma_params.thr_mma_qk.partition_shape_C((self.tile_m, self.tile_n))\n        acc_S = cute.make_fragment(acc_shape_S, Float32)\n        acc_S.fill(0.0)\n        # wait for smem tile QK before mma calculation for S\n        sync()\n\n        # need predicates for the first tile\n        def load_V_next():\n            if self.num_stages == 1 or n_block - self.num_stages + 1 >= 0:\n                load_V(\n                    n_block - self.num_stages + 1,\n                    smem_pipe_write,\n                    need_predicates=is_first_n_block and self.num_stages == 1,\n                )\n            cute.arch.cp_async_commit_group()\n\n        load_V_next()\n        sm80_utils.gemm(\n            mma_params.thr_mma_qk,\n            acc_S,\n            mma_params.tSrQ,\n            mma_params.tSrK,\n            smem_copy_params.tSsQ,\n            smem_copy_params.tSsK[\n                None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0\n            ],\n            smem_copy_params.smem_thr_copy_Q,\n            smem_copy_params.smem_thr_copy_K,\n            # hook_fn=load_V_next,\n            A_in_regs=self.Q_in_regs,\n        )\n        if const_expr(score_mod is not None):\n            self.apply_score_mod(\n                mma_params.thr_mma_qk,\n                batch_idx,\n                head_idx,\n                m_block,\n                acc_S,\n                n_block,\n                seqlen,\n                softmax_scale=softmax.softmax_scale,\n                aux_tensors=aux_tensors,\n                fastdiv_mods=fastdiv_mods,\n            )\n\n        smem_pipe_write = self.advance_pipeline(smem_pipe_write)\n\n        def load_K_next():\n            if n_block - self.num_stages >= 0:\n                load_K(n_block - self.num_stages, smem_pipe_write, need_predicates=False)\n            
cute.arch.cp_async_commit_group()\n\n        # wait for smem tile V for O\n        if const_expr(self.num_stages == 1):\n            sync()\n            load_K_next()\n        if const_expr(mask_fn is not None):\n            mask_fn(acc_S, n_block=n_block)\n        row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)\n        softmax.rescale_O(mma_params.acc_O, row_scale)\n        rP = cute.make_fragment_like(acc_S, self.dtype)\n        rP.store(acc_S.load().to(self.dtype))\n        tOrP = layout_utils.reshape_acc_to_frgA(rP)\n        if const_expr(self.num_stages > 1):\n            sync()\n            load_K_next()\n        sm80_utils.gemm_rs(\n            mma_params.thr_mma_pv,\n            mma_params.acc_O,\n            tOrP,\n            mma_params.tOrVt,\n            smem_copy_params.tOsVt[\n                None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0\n            ],\n            smem_copy_params.smem_thr_copy_V,\n            # hook_fn=load_K_next,\n        )\n        # if const_expr(self.num_stages > 1):\n        #     load_K_next()\n\n\n# SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility\ndef __getattr__(name):\n    if name == \"FlashAttentionForwardSm90\":\n        from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90\n        return FlashAttentionForwardSm90\n    raise AttributeError(f\"module {__name__!r} has no attribute {name!r}\")\n"
  },
  {
    "path": "flash_attn/cute/flash_fwd_combine.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_combine_kernel.h\n# from Cutlass C++ to Cute-DSL.\nimport math\nfrom typing import Type, Optional\nfrom functools import partial\n\nimport cuda.bindings.driver as cuda\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass.cute.nvgpu import cpasync\nfrom cutlass import Float32, Int32, Boolean, const_expr\n\nfrom flash_attn.cute import utils\nfrom flash_attn.cute.cute_dsl_utils import assume_tensor_aligned\nfrom flash_attn.cute.seqlen_info import SeqlenInfo\nfrom cutlass.cute import FastDivmodDivisor\n\n\nclass FlashAttentionForwardCombine:\n    def __init__(\n        self,\n        dtype: Type[cutlass.Numeric],\n        dtype_partial: Type[cutlass.Numeric],\n        head_dim: int,\n        tile_m: int = 8,\n        k_block_size: int = 64,\n        log_max_splits: int = 4,\n        num_threads: int = 256,\n        stages: int = 4,\n    ):\n        \"\"\"\n        Forward combine kernel for split attention computation.\n\n        :param dtype: output data type\n        :param dtype_partial: partial accumulation data type\n        :param head_dim: head dimension\n        :param tile_m: m block size\n        :param k_block_size: k block size\n        :param log_max_splits: log2 of maximum splits\n        :param num_threads: number of threads\n        :param varlen: whether using variable length sequences\n        :param stages: number of pipeline stages\n        \"\"\"\n        self.dtype = dtype\n        self.dtype_partial = dtype_partial\n        self.head_dim = head_dim\n        self.tile_m = tile_m\n        self.k_block_size = k_block_size\n        self.max_splits = 1 << log_max_splits\n        self.num_threads = num_threads\n        self.is_even_k = head_dim % k_block_size == 0\n        self.stages = stages\n\n    @staticmethod\n    def 
can_implement(\n        dtype,\n        dtype_partial,\n        head_dim,\n        tile_m,\n        k_block_size,\n        log_max_splits,\n        num_threads,\n    ) -> bool:\n        \"\"\"Check if the kernel can be implemented with the given parameters.\"\"\"\n        if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]:\n            return False\n        if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]:\n            return False\n        if head_dim % 8 != 0:\n            return False\n        if num_threads % 32 != 0:\n            return False\n        if tile_m % 8 != 0:\n            return False\n        max_splits = 1 << log_max_splits\n        if max_splits > 256:\n            return False\n        if (tile_m * max_splits) % num_threads != 0:\n            return False\n        return True\n\n    def _setup_attributes(self):\n        # GMEM copy setup for O partial\n        universal_copy_bits = 128\n        async_copy_elems = universal_copy_bits // self.dtype_partial.width\n        assert self.k_block_size % async_copy_elems == 0\n\n        k_block_gmem = (\n            128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32)\n        )\n        gmem_threads_per_row = k_block_gmem // async_copy_elems\n        assert self.num_threads % gmem_threads_per_row == 0\n\n        # Async copy atom for O partial load\n        atom_async_copy_partial = cute.make_copy_atom(\n            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),\n            self.dtype_partial,\n            num_bits_per_copy=universal_copy_bits,\n        )\n        tOpartial_layout = cute.make_ordered_layout(\n            (self.num_threads // gmem_threads_per_row, gmem_threads_per_row),\n            order=(1, 0),\n        )\n        vOpartial_layout = cute.make_layout((1, async_copy_elems))  # 4 vals per load\n        self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv(\n            atom_async_copy_partial, 
tOpartial_layout, vOpartial_layout\n        )\n\n        # GMEM copy setup for final O (use universal copy for store)\n        atom_universal_copy = cute.make_copy_atom(\n            cute.nvgpu.CopyUniversalOp(),\n            self.dtype,\n            num_bits_per_copy=async_copy_elems * self.dtype.width,\n        )\n        self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(\n            atom_universal_copy,\n            tOpartial_layout,\n            vOpartial_layout,  # 4 vals per store\n        )\n\n        # LSE copy setup with async copy (alignment = 1)\n        lse_copy_bits = Float32.width  # 1 element per copy, width is in bits\n        m_block_smem = (\n            128\n            if self.tile_m % 128 == 0\n            else (\n                64\n                if self.tile_m % 64 == 0\n                else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8))\n            )\n        )\n        gmem_threads_per_row_lse = m_block_smem\n        assert self.num_threads % gmem_threads_per_row_lse == 0\n\n        # Async copy atom for LSE load\n        atom_async_copy_lse = cute.make_copy_atom(\n            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),\n            Float32,\n            num_bits_per_copy=lse_copy_bits,\n        )\n        tLSE_layout = cute.make_ordered_layout(\n            (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse),\n            order=(1, 0),\n        )\n        vLSE_layout = cute.make_layout(1)\n        self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv(\n            atom_async_copy_lse, tLSE_layout, vLSE_layout\n        )\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Shared memory\n        # ///////////////////////////////////////////////////////////////////////////////\n\n        # Shared memory to register copy for LSE\n        self.smem_threads_per_col_lse = self.num_threads // m_block_smem\n        assert 32 % 
self.smem_threads_per_col_lse == 0  # Must divide warp size\n\n        s2r_layout_atom_lse = cute.make_ordered_layout(\n            (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse),\n            order=(0, 1),\n        )\n        self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv(\n            cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32),\n            s2r_layout_atom_lse,\n            cute.make_layout(1),\n        )\n\n        # LSE shared memory layout with swizzling to avoid bank conflicts\n        # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts\n        if const_expr(m_block_smem == 8):\n            smem_lse_swizzle = cute.make_swizzle(5, 0, 5)\n        elif const_expr(m_block_smem == 16):\n            smem_lse_swizzle = cute.make_swizzle(4, 0, 4)\n        else:\n            smem_lse_swizzle = cute.make_swizzle(3, 2, 3)\n        smem_layout_atom_lse = cute.make_composed_layout(\n            smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))\n        )\n        self.smem_layout_lse = cute.tile_to_shape(\n            smem_layout_atom_lse, (self.max_splits, self.tile_m), (0, 1)\n        )\n\n        # O partial shared memory layout (simple layout for pipeline stages)\n        self.smem_layout_o = cute.make_ordered_layout(\n            (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2)\n        )\n\n    @cute.jit\n    def __call__(\n        self,\n        mO_partial: cute.Tensor,\n        mLSE_partial: cute.Tensor,\n        mO: cute.Tensor,\n        mLSE: Optional[cute.Tensor] = None,\n        cu_seqlens: Optional[cute.Tensor] = None,\n        seqused: Optional[cute.Tensor] = None,\n        num_splits_dynamic_ptr: Optional[cute.Tensor] = None,\n        varlen_batch_idx: Optional[cute.Tensor] = None,\n        semaphore_to_reset: Optional[cute.Tensor] = None,\n        # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).\n       
 stream: cuda.CUstream = None,\n    ):\n        # Type checking\n        if const_expr(not (mO_partial.element_type == self.dtype_partial)):\n            raise TypeError(\"O partial tensor must match dtype_partial\")\n        if const_expr(not (mO.element_type == self.dtype)):\n            raise TypeError(\"O tensor must match dtype\")\n        if const_expr(mLSE_partial.element_type not in [Float32]):\n            raise TypeError(\"LSE partial tensor must be Float32\")\n        if const_expr(mLSE is not None and mLSE.element_type not in [Float32]):\n            raise TypeError(\"LSE tensor must be Float32\")\n\n        # Shape validation - input tensors are in user format, need to be converted to kernel format\n        if const_expr(len(mO_partial.shape) not in [4, 5]):\n            raise ValueError(\n                \"O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)\"\n            )\n        if const_expr(len(mLSE_partial.shape) not in [3, 4]):\n            raise ValueError(\n                \"LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)\"\n            )\n        if const_expr(len(mO.shape) not in [3, 4]):\n            raise ValueError(\n                \"O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)\"\n            )\n        if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]):\n            raise ValueError(\n                \"LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)\"\n            )\n\n        mO_partial, mO = [assume_tensor_aligned(t) for t in (mO_partial, mO)]\n        # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)\n        # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)\n        O_partial_layout_transpose = (\n            [2, 4, 0, 3, 1] if const_expr(cu_seqlens 
is None) else [1, 3, 0, 2]\n        )\n        mO_partial = cute.make_tensor(\n            mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)\n        )\n        # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h)\n        O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1]\n        mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))\n        # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b)\n        # or (num_splits, total_q, h) -> (total_q, num_splits, h)\n        LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2]\n        mLSE_partial = cute.make_tensor(\n            mLSE_partial.iterator,\n            cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose),\n        )\n        # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h)\n        LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1]\n        mLSE = (\n            cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))\n            if mLSE is not None\n            else None\n        )\n\n        # Determine if we have variable length sequences\n        varlen = const_expr(cu_seqlens is not None or seqused is not None)\n\n        self._setup_attributes()\n\n        @cute.struct\n        class SharedStorage:\n            sLSE: cute.struct.Align[\n                cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128\n            ]\n            sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.tile_m], 128]\n            sO: cute.struct.Align[\n                cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128\n            ]\n\n        smem_size = SharedStorage.size_in_bytes()\n\n        # Grid dimensions: (ceil_div(seqlen * num_head, tile_m), ceil_div(head_dim, k_block_size), batch_size)\n        seqlen = 
mO_partial.shape[0]\n        num_head = mO_partial.shape[3]\n        batch_size = (\n            mO_partial.shape[4]\n            if const_expr(cu_seqlens is None)\n            else Int32(cu_seqlens.shape[0] - 1)\n        )\n\n        # Create FastDivmodDivisor objects for efficient division\n        seqlen_divmod = FastDivmodDivisor(seqlen)\n        head_divmod = FastDivmodDivisor(num_head)\n\n        grid_dim = (\n            cute.ceil_div(seqlen * num_head, self.tile_m),\n            cute.ceil_div(self.head_dim, self.k_block_size),\n            batch_size,\n        )\n\n        self.kernel(\n            mO_partial,\n            mLSE_partial,\n            mO,\n            mLSE,\n            cu_seqlens,\n            seqused,\n            num_splits_dynamic_ptr,\n            varlen_batch_idx,\n            semaphore_to_reset,\n            SharedStorage,\n            self.smem_layout_lse,\n            self.smem_layout_o,\n            self.gmem_tiled_copy_O_partial,\n            self.gmem_tiled_copy_O,\n            self.gmem_tiled_copy_LSE,\n            self.s2r_tiled_copy_LSE,\n            seqlen_divmod,\n            head_divmod,\n            varlen,\n        ).launch(\n            grid=grid_dim,\n            block=[self.num_threads, 1, 1],\n            smem=smem_size,\n            stream=stream,\n        )\n\n    @cute.kernel\n    def kernel(\n        self,\n        mO_partial: cute.Tensor,\n        mLSE_partial: cute.Tensor,\n        mO: cute.Tensor,\n        mLSE: Optional[cute.Tensor],\n        cu_seqlens: Optional[cute.Tensor],\n        seqused: Optional[cute.Tensor],\n        num_splits_dynamic_ptr: Optional[cute.Tensor],\n        varlen_batch_idx: Optional[cute.Tensor],\n        semaphore_to_reset: Optional[cute.Tensor],\n        SharedStorage: cutlass.Constexpr,\n        smem_layout_lse: cute.Layout | cute.ComposedLayout,\n        smem_layout_o: cute.Layout,\n        gmem_tiled_copy_O_partial: cute.TiledCopy,\n        gmem_tiled_copy_O: cute.TiledCopy,\n      
  gmem_tiled_copy_LSE: cute.TiledCopy,\n        s2r_tiled_copy_LSE: cute.TiledCopy,\n        seqlen_divmod: FastDivmodDivisor,\n        head_divmod: FastDivmodDivisor,\n        varlen: cutlass.Constexpr[bool],\n    ):\n        # Thread and block indices\n        tidx, _, _ = cute.arch.thread_idx()\n        m_block, k_block, maybe_virtual_batch = cute.arch.block_idx()\n\n        # Map virtual batch index to real batch index (for persistent tile schedulers)\n        batch_idx = (\n            varlen_batch_idx[maybe_virtual_batch]\n            if const_expr(varlen_batch_idx is not None)\n            else maybe_virtual_batch\n        )\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Get shared memory buffer\n        # ///////////////////////////////////////////////////////////////////////////////\n        smem = cutlass.utils.SmemAllocator()\n        storage = smem.allocate(SharedStorage)\n        sLSE = storage.sLSE.get_tensor(smem_layout_lse)\n        sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,))\n        sO = storage.sO.get_tensor(smem_layout_o)\n\n        # Handle semaphore reset — wait for dependent grids first\n        if const_expr(semaphore_to_reset is not None):\n            if (\n                tidx == 0\n                and m_block == cute.arch.grid_dim()[0] - 1\n                and k_block == cute.arch.grid_dim()[1] - 1\n                and maybe_virtual_batch == cute.arch.grid_dim()[2] - 1\n            ):\n                cute.arch.griddepcontrol_wait()\n                semaphore_to_reset[0] = 0\n\n        # Get number of splits (use maybe_virtual_batch for per-batch-slot splits)\n        num_splits = (\n            num_splits_dynamic_ptr[maybe_virtual_batch]\n            if const_expr(num_splits_dynamic_ptr is not None)\n            else mLSE_partial.shape[1]\n        )\n        # Handle variable length sequences using SeqlenInfo\n        seqlen_info = SeqlenInfo.create(\n       
     batch_idx=batch_idx,\n            seqlen_static=mO_partial.shape[0],\n            cu_seqlens=cu_seqlens,\n            seqused=seqused,\n            # Don't need to pass in tile size since we won't use offset_padded\n        )\n        seqlen, offset = seqlen_info.seqlen, seqlen_info.offset\n\n        # Extract number of heads (head index will be determined dynamically)\n        num_head = mO_partial.shape[3]\n        max_idx = seqlen * num_head\n\n        # Early exit for single split if dynamic\n        if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (\n            const_expr(not varlen) or m_block * self.tile_m < max_idx\n        ):\n            # Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial)\n            cute.arch.griddepcontrol_wait()\n\n            # ===============================\n            # Step 1: Load LSE_partial from gmem to shared memory\n            # ===============================\n\n            mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3)\n            mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,))\n            gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx)\n            tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE)\n            # Create identity tensor for coordinate tracking\n            cLSE = cute.make_identity_tensor((self.max_splits, self.tile_m))\n            tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE)\n\n            # Load LSE partial values\n            for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True):\n                mi = tLSEcLSE[0, 0, m][1]  # Get m coordinate\n                idx = m_block * self.tile_m + mi\n                if idx < max_idx:\n                    # Calculate actual sequence position and head using FastDivmodDivisor\n                    if const_expr(not varlen):\n                        head_idx, m_idx = divmod(idx, seqlen_divmod)\n                    else:\n     
                   head_idx = idx // seqlen\n                        m_idx = idx - head_idx * seqlen\n                    mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx]\n                    for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):\n                        si = tLSEcLSE[0, s, 0][0]  # Get split coordinate\n                        if si < num_splits:\n                            cute.copy(\n                                gmem_thr_copy_LSE,\n                                mLSE_partial_cur_copy[None, si],\n                                tLSEsLSE[None, s, m],\n                            )\n                        else:\n                            tLSEsLSE[None, s, m].fill(-Float32.inf)\n                # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem\n            cute.arch.cp_async_commit_group()\n\n            # ===============================\n            # Step 2: Load O_partial for pipeline stages\n            # ===============================\n\n            gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx)\n            cO = cute.make_identity_tensor((self.tile_m, self.k_block_size))\n            tOcO = gmem_thr_copy_O_partial.partition_D(cO)\n            tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO)\n            mO_partial_cur = seqlen_info.offset_batch(mO_partial, batch_idx, dim=4)\n\n            # Precompute these values to avoid recomputing them in the loop\n            num_rows = const_expr(cute.size(tOcO, mode=[1]))\n            tOmidx = cute.make_rmem_tensor(num_rows, cutlass.Int32)\n            tOhidx = cute.make_rmem_tensor(num_rows, cutlass.Int32)\n            tOrOptr = cute.make_rmem_tensor(num_rows, cutlass.Int64)\n            for m in cutlass.range(num_rows, unroll_full=True):\n                mi = tOcO[0, m, 0][0]  # m coordinate\n                idx = m_block * self.tile_m + mi\n                if const_expr(not varlen):\n      
              tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod)\n                else:\n                    tOhidx[m] = idx // seqlen\n                    tOmidx[m] = idx - tOhidx[m] * seqlen\n                tOrOptr[m] = utils.elem_pointer(\n                    mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])\n                ).toint()\n                if idx >= max_idx:\n                    tOhidx[m] = -1\n\n            tOpO = None\n            if const_expr(not self.is_even_k):\n                tOpO = cute.make_rmem_tensor(cute.size(tOcO, mode=[2]), Boolean)\n                for k in cutlass.range(cute.size(tOpO), unroll_full=True):\n                    tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size\n                # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO)\n\n            load_O_partial = partial(\n                self.load_O_partial,\n                gmem_tiled_copy_O_partial,\n                tOrOptr,\n                tOsO_partial,\n                tOhidx,\n                tOpO,\n                tOcO,\n                mO_partial_cur.layout,\n            )\n\n            # Load first few stages of O_partial\n            for stage in cutlass.range(self.stages - 1, unroll_full=True):\n                if stage < num_splits:\n                    load_O_partial(stage, stage)\n                cute.arch.cp_async_commit_group()\n\n            # ===============================\n            # Step 3: Load and transpose LSE from smem to registers\n            # ===============================\n\n            # Wait for LSE and initial O partial stages to complete\n            cute.arch.cp_async_wait_group(self.stages - 1)\n            cute.arch.sync_threads()\n            # if cute.arch.thread_idx()[0] == 0:\n            #     # cute.print_tensor(sLSE)\n            #     for i in range(64):\n            #         cute.printf(\"sLSE[%d, 0] = %f\", i, sLSE[i, 0])\n            # 
cute.arch.sync_threads()\n\n            s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx)\n            ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE)\n            ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE)\n            cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE)\n\n            # ===============================\n            # Step 4: Compute final LSE along split dimension\n            # ===============================\n\n            lse_sum = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32)\n            ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE)\n            # We compute the max valid split for each row to short-circuit the computation later\n            max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32)\n            assert cute.size(ts2rrLSE, mode=[0]) == 1\n            # Compute max, scales, and final LSE for each row\n            for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):\n                # Find max LSE value across splits\n                threads_per_col = const_expr(self.smem_threads_per_col_lse)\n                lse_max = cute.arch.warp_reduction_max(\n                    ts2rrLSE[None, None, m]\n                    .load()\n                    .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),\n                    threads_in_group=threads_per_col,\n                )\n                # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max)\n                # Find max valid split index\n                max_valid_idx = -1\n                for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):\n                    if ts2rrLSE[0, s, m] != -Float32.inf:\n                        max_valid_idx = ts2rcLSE[0, s, 0][0]  # Get split coordinate\n                # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx)\n                max_valid_split[m] = cute.arch.warp_reduction_max(\n                    max_valid_idx, 
threads_in_group=threads_per_col\n                )\n                # Compute exp scales and sum\n                lse_max_cur = (\n                    0.0 if lse_max == -Float32.inf else lse_max\n                )  # In case all local LSEs are -inf\n                LOG2_E = math.log2(math.e)\n                lse_sum_cur = 0.0\n                for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):\n                    scale = cute.math.exp2(\n                        ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E), fastmath=True\n                    )\n                    lse_sum_cur += scale\n                    ts2rrLSE[0, s, m] = scale  # Store scale for later use\n                lse_sum_cur = cute.arch.warp_reduction_sum(\n                    lse_sum_cur, threads_in_group=threads_per_col\n                )\n                lse_sum[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max\n                # Normalize scales\n                inv_sum = (\n                    0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur\n                )\n                ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)\n            # Store the scales exp(lse - lse_logsum) back to smem\n            cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)\n\n            # Store max valid split to smem\n            for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):\n                if ts2rcLSE[0, 0, m][0] == 0:  # Only thread responsible for s=0 writes\n                    mi = ts2rcLSE[0, 0, m][1]\n                    if mi < self.tile_m:\n                        sMaxValidSplit[mi] = max_valid_split[m]\n\n            # ===============================\n            # Step 5: Store final LSE to gmem\n            # ===============================\n\n            if const_expr(mLSE is not None):\n                if const_expr(cu_seqlens is None):\n                    mLSE_cur = mLSE[None, None, 
batch_idx]\n                else:\n                    mLSE_cur = cute.domain_offset((offset, 0), mLSE)\n                if k_block == 0:  # Only first k_block writes LSE when mLSE is provided\n                    for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True):\n                        if ts2rcLSE[0, 0, m][0] == 0:  # Only thread responsible for s=0 writes\n                            mi = ts2rcLSE[0, 0, m][1]\n                            idx = m_block * self.tile_m + mi\n                            if idx < max_idx:\n                                if const_expr(not varlen):\n                                    head_idx, m_idx = divmod(idx, seqlen_divmod)\n                                else:\n                                    head_idx = idx // seqlen\n                                    m_idx = idx - head_idx * seqlen\n                                mLSE_cur[m_idx, head_idx] = lse_sum[m]\n\n            # ===============================\n            # Step 6: Read O_partial and accumulate final O\n            # ===============================\n\n            cute.arch.sync_threads()\n\n            # Get max valid split for this thread\n            thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]]\n            for m in cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True):\n                thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]])\n\n            tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0])\n            tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32)\n            tOrO.fill(0.0)\n\n            stage_load = self.stages - 1\n            stage_compute = 0\n\n            # Main accumulation loop\n            for s in cutlass.range(thr_max_valid_split + 1, unroll=4):\n                # Get scales for this split\n                scale = cute.make_rmem_tensor(num_rows, Float32)\n                for m in cutlass.range(num_rows, unroll_full=True):\n         
           scale[m] = sLSE[s, tOcO[0, m, 0][0]]  # Get scale from smem\n\n                # Load next stage if needed\n                split_to_load = s + self.stages - 1\n                if split_to_load <= thr_max_valid_split:\n                    load_O_partial(split_to_load, stage_load)\n                cute.arch.cp_async_commit_group()\n                stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1\n\n                # Wait for the current stage to be ready\n                cute.arch.cp_async_wait_group(self.stages - 1)\n                # We don't need __syncthreads() because each thread is just reading its own data from smem\n                # Copy from smem to registers\n                cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial)\n                stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1\n\n                # Accumulate scaled partial results\n                for m in cutlass.range(num_rows, unroll_full=True):\n                    if tOhidx[m] >= 0 and scale[m] > 0.0:\n                        tOrO[None, m, None].store(\n                            tOrO[None, m, None].load()\n                            + scale[m] * tOrO_partial[None, m, None].load().to(Float32)\n                        )\n\n            # ===============================\n            # Step 7: Write final O to gmem\n            # ===============================\n\n            rO = cute.make_rmem_tensor_like(tOrO, self.dtype)\n            rO.store(tOrO.load().to(self.dtype))\n            mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3)\n            if const_expr(cu_seqlens is None):\n                mO_cur = mO[None, None, None, batch_idx]\n            else:\n                mO_cur = cute.domain_offset((offset, 0, 0), mO)\n            mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur)\n            elems_per_store = 
const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1]))\n            # mO_cur_copy = cute.tiled_divide(mO_cur, (1, elems_per_store,))\n            gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)\n            # Write final results\n            for m in cutlass.range(num_rows, unroll_full=True):\n                if tOhidx[m] >= 0:\n                    mO_cur_copy = cute.tiled_divide(\n                        mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)\n                    )\n                    for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):\n                        k_idx = tOcO[0, 0, k][1] // elems_per_store\n                        if const_expr(self.is_even_k) or tOpO[k]:\n                            cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx])\n\n    @cute.jit\n    def load_O_partial(\n        self,\n        gmem_tiled_copy_O_partial: cute.TiledCopy,\n        tOrOptr: cute.Tensor,\n        tOsO_partial: cute.Tensor,\n        tOhidx: cute.Tensor,\n        tOpO: Optional[cute.Tensor],\n        tOcO: cute.Tensor,\n        mO_cur_partial_layout: cute.Layout,\n        split: Int32,\n        stage: Int32,\n    ) -> None:\n        elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1]))\n        tOsO_partial_cur = tOsO_partial[None, None, None, stage]\n        for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True):\n            if tOhidx[m] >= 0:\n                o_gmem_ptr = cute.make_ptr(\n                    tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16\n                )\n                mO_partial_cur = cute.make_tensor(\n                    o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))\n                )\n                mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))\n                for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):\n                    k_idx = 
tOcO[0, 0, k][1] // elems_per_load\n                    if const_expr(tOpO is None) or tOpO[k]:\n                        cute.copy(\n                            gmem_tiled_copy_O_partial,\n                            mO_partial_cur_copy[None, k_idx, split],\n                            tOsO_partial_cur[None, m, k],\n                        )\n"
  },
  {
    "path": "flash_attn/cute/flash_fwd_sm100.py",
    "content": "# Supported features:\n# - BF16 & FP16 dtype\n# - noncausal & causal attention\n# - MHA, GQA, MQA\n# - hdim 64, 96, 128, (192, 128).\n# - varlen\n# - sliding window\n# - split-kv\n# Unsupported features that will be added later:\n# - page size != 128\n# - more hdim (192, 256)\n# Based on the cutlass example and cute-dsl example:\n# https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha\n# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py\n\nimport math\nfrom typing import Type, Tuple, Callable, Optional, Literal\nfrom functools import partial\n\nimport cuda.bindings.driver as cuda\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Float32, Int32, Int64, Boolean, const_expr\nfrom cutlass.cute.nvgpu import cpasync\nimport cutlass.cute.nvgpu.tcgen05 as tcgen05\nimport cutlass.utils.blackwell_helpers as sm100_utils_basic\nfrom cutlass import pipeline\nfrom cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait\nfrom cutlass.base_dsl.arch import Arch\nfrom cutlass.cutlass_dsl import BaseDSL\n\nfrom quack import copy_utils, layout_utils\n\nfrom flash_attn.cute.paged_kv import PagedKVManager\nfrom flash_attn.cute.cute_dsl_utils import assume_tensor_aligned\nfrom flash_attn.cute import utils\nimport flash_attn.cute.pipeline as pipeline_custom\nfrom flash_attn.cute.mask import AttentionMask\nfrom flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner\nfrom flash_attn.cute.seqlen_info import SeqlenInfoQK\nfrom flash_attn.cute.block_info import BlockInfo\nfrom flash_attn.cute.block_sparsity import BlockSparseTensors\nfrom flash_attn.cute.block_sparse_utils import (\n    get_total_block_count,\n    produce_block_sparse_loads_sm100,\n    softmax_block_sparse_sm100,\n    handle_block_sparse_empty_tile_correction_sm100,\n)\nfrom flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout\nfrom flash_attn.cute import mma_sm100_desc as sm100_desc\nfrom flash_attn.cute import 
blackwell_helpers as sm100_utils\nfrom flash_attn.cute.named_barrier import NamedBarrierFwdSm100\nfrom cutlass.cute import FastDivmodDivisor\nfrom quack.cute_dsl_utils import ParamsBase\nfrom flash_attn.cute.tile_scheduler import (\n    TileSchedulerArguments,\n    SingleTileScheduler,\n    StaticPersistentTileScheduler,\n    SingleTileLPTScheduler,\n    SingleTileVarlenScheduler,\n)\n\nclass FlashAttentionForwardSm100:\n\n    def __init__(\n        self,\n        # dtype: Type[cutlass.Numeric],\n        head_dim: int,\n        head_dim_v: Optional[int] = None,\n        qhead_per_kvhead: cutlass.Constexpr[int] = 1,\n        is_causal: bool = False,\n        is_local: bool = False,\n        is_split_kv: bool = False,\n        pack_gqa: bool = False,\n        q_subtile_factor: int | None = None,\n        m_block_size: int = 128,\n        n_block_size: int = 128,\n        q_stage: cutlass.Constexpr[int] = 2,\n        is_persistent: bool = True,\n        score_mod: cutlass.Constexpr | None = None,\n        mask_mod: cutlass.Constexpr | None = None,\n        has_aux_tensors: cutlass.Constexpr = False,\n        paged_kv_non_tma: bool = False,\n        is_varlen_q: bool = False,\n        use_2cta_instrs: bool = False,\n    ):\n        self.use_tma_KV = not paged_kv_non_tma\n        # self.dtype = dtype\n        # padding head_dim to a multiple of 16 as k_block_size\n        hdim_multiple_of = 16\n        self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)\n        head_dim_v = head_dim_v if head_dim_v is not None else head_dim\n        self.same_hdim_kv = head_dim == head_dim_v\n        self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)\n        self.same_hdim_kv_padded = self.head_dim_padded == self.head_dim_v_padded\n        self.check_hdim_oob = head_dim != self.head_dim_padded\n        self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded\n        self.m_block_size = m_block_size\n        
self.n_block_size = n_block_size\n        self.q_stage = q_stage\n        assert self.q_stage in [1, 2]\n        self.use_2cta_instrs = use_2cta_instrs\n        # If split_P_arrive, the softmax warps write some columns of P first, signal to the MMA warp\n        # to begin the P @ V MMA, then write the rest of P and signal again. This allows some overlap\n        # between computing the last couple of columns of P and the P @ V MMA.\n        self.split_P_arrive = n_block_size // 4 * 3\n        self.split_P_arrive = int(self.split_P_arrive / 32) * 32  # multiple of 32\n        assert self.split_P_arrive % 32 == 0\n        assert self.split_P_arrive < self.n_block_size\n        self.arch = BaseDSL._get_dsl().get_arch_enum()\n        assert self.arch >= Arch.sm_100 and self.arch <= Arch.sm_110f, \"Only SM 10.x and 11.x are supported\"\n\n        self.cta_group_size = 2 if self.use_2cta_instrs else 1\n        # cta_tiler M includes only 1 CTA, the scheduler will take into account the cluster shape\n        self.cta_tiler = (self.q_stage * m_block_size, n_block_size, self.head_dim_padded)\n        # With 2CTA, the MMA tiler M covers both CTAs, so it's cta_group_size * m_block_size.\n        # Each CTA owns m_block_size rows; the 2CTA MMA instruction spans both.\n        self.mma_tiler_qk = (self.cta_group_size * m_block_size, n_block_size, self.head_dim_padded)\n        self.mma_tiler_pv = (self.cta_group_size * m_block_size, self.head_dim_v_padded, n_block_size)\n        self.qk_acc_dtype = Float32\n        self.pv_acc_dtype = Float32\n        self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1)\n        self.is_persistent = is_persistent\n        self.is_causal = is_causal\n        self.is_local = is_local\n        self.is_varlen_q = is_varlen_q\n        self.use_correction_warps_for_epi = is_varlen_q\n        self.qhead_per_kvhead = qhead_per_kvhead\n        self.is_split_kv = is_split_kv\n        self.pack_gqa = pack_gqa\n        self.q_subtile_factor = 
q_subtile_factor\n        if pack_gqa:\n            assert m_block_size % self.qhead_per_kvhead == 0, (\n                \"For PackGQA, m_block_size must be divisible by qhead_per_kvhead\"\n            )\n        assert not (self.is_split_kv and self.head_dim_v_padded >= 192), (\n            \"SplitKV is not supported for hdim >= 192\"\n        )\n        self.score_mod = score_mod\n        self.mask_mod = mask_mod\n        self.vec_size: cutlass.Constexpr = getattr(\n            score_mod, \"__vec_size__\", 1 if cutlass.const_expr(has_aux_tensors) else 2\n        )\n        # Does S1 need to wait for S0 to finish\n        # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local)\n        is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f\n        # self.enable_ex2_emu = self.head_dim_padded <= 128 and not is_sm103\n        self.enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103\n        self.s0_s1_barrier = False\n        self.overlap_sO_sQ = (\n            (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or\n            (self.head_dim_v_padded >= 128 and self.is_split_kv)\n        )\n        if self.overlap_sO_sQ:\n            self.is_persistent = False\n\n        assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), (\n            \"Paged KV does not support irregular head dim\"\n        )\n\n        self.softmax0_warp_ids = (0, 1, 2, 3)\n        self.softmax1_warp_ids = (4, 5, 6, 7)\n        self.correction_warp_ids = (8, 9, 10, 11)\n        self.mma_warp_id = 12\n        self.epilogue_warp_ids = (13,)\n        self.load_warp_ids = (14,)\n        self.empty_warp_ids = (15,)\n        self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols(\"sm_100\")\n\n        self.threads_per_cta = cute.arch.WARP_SIZE * len(\n            (\n                
*self.softmax0_warp_ids,\n                *self.softmax1_warp_ids,\n                *self.correction_warp_ids,\n                self.mma_warp_id,\n                *self.load_warp_ids,\n                *self.epilogue_warp_ids,\n                *self.empty_warp_ids,\n            )\n        )\n\n        if self.q_stage == 1:\n            if not self.use_tma_KV:\n                self.empty_warp_ids = self.empty_warp_ids + self.load_warp_ids\n                self.load_warp_ids = self.softmax1_warp_ids\n            else:\n                self.empty_warp_ids = self.empty_warp_ids + self.softmax1_warp_ids\n            self.softmax1_warp_ids = ()\n        elif not self.use_tma_KV:\n            self.load_warp_ids = (14, 15)\n            self.empty_warp_ids = ()\n\n        if self.use_correction_warps_for_epi:\n            self.empty_warp_ids = self.empty_warp_ids + self.epilogue_warp_ids\n            self.epilogue_warp_ids = self.correction_warp_ids\n        elif self.is_varlen_q: # fallback\n            self.epilogue_warp_ids = (13, 14)\n\n        self.tmem_s_offset = [0, self.n_block_size]  # e.g., 0, 128\n        self.tmem_o_offset = [\n            self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded\n            for i in range(self.q_stage)\n        ]  # e.g., 256, 384\n        self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded\n        assert self.tmem_total <= self.tmem_alloc_cols\n        self.tmem_s_to_p_offset = self.n_block_size // 2\n        self.tmem_p_offset = [\n            self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2)\n        ]  # e.g., 64, 192\n\n        # vec buffer for row_max & row_sum\n        self.tmem_vec_offset = self.tmem_s_offset\n\n        if self.head_dim_padded < 96:\n            self.num_regs_softmax = 200 if not paged_kv_non_tma else 184\n            self.num_regs_correction = 64\n            self.num_regs_other = 48 if not paged_kv_non_tma else 80\n        else:\n            # 
self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184\n            if not self.enable_ex2_emu:\n                self.num_regs_softmax = 192 if not paged_kv_non_tma else 184\n            else:\n                # self.num_regs_softmax = 200 if not paged_kv_non_tma else 184\n                self.num_regs_softmax = 192 if not paged_kv_non_tma else 184\n            # self.num_regs_softmax = 176\n            # self.num_regs_correction = 96\n            # self.num_regs_correction = 64 if self.is_causal or self.is_local else 80\n            if not self.enable_ex2_emu:\n                self.num_regs_correction = 80 if not paged_kv_non_tma else 64\n            else:\n                # self.num_regs_correction = 64\n                self.num_regs_correction = 80 if not paged_kv_non_tma else 64\n            # self.num_regs_other = 32\n            # self.num_regs_other = 64\n            # self.num_regs_other = 80\n            self.num_regs_other = 48 if not paged_kv_non_tma else 80\n            # self.num_regs_other = 96 if self.is_causal or self.is_local else 80\n            # self.num_regs_other = 64 if self.is_causal or self.is_local else 80\n\n        self.buffer_align_bytes = 1024\n\n    def _setup_attributes(self):\n        \"\"\"Set up configurations and parameters for the FMHA kernel operation.\n\n        This method initializes and configures various attributes required for the\n        execution of the fused multi-head attention kernel, mainly about the pipeline stages:\n\n        - Sets up staging parameters for Q, K, V inputs and accumulator data\n        - Configures pipeline stages for softmax, correction, and epilogue operations\n        \"\"\"\n\n        smem_size_q = self.q_stage * self.m_block_size * self.head_dim_padded * self.q_dtype.width // 8\n        smem_size_o = self.q_stage * self.m_block_size * self.head_dim_v_padded * self.o_dtype.width // 8\n        smem_size_q_o = smem_size_q + smem_size_o if not self.overlap_sO_sQ else 
max(smem_size_q, smem_size_o)\n        smem_size_k_per_stage = self.n_block_size * self.head_dim_padded * self.k_dtype.width // 8\n        smem_size_v_per_stage = self.n_block_size * self.head_dim_v_padded * self.v_dtype.width // 8\n        smem_size_kv_per_stage = max(smem_size_k_per_stage, smem_size_v_per_stage) // self.cta_group_size\n        kv_stage = (224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage\n        if self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and kv_stage == 2:\n            # For hdim 192,128, we can fit 3 stages if we use uneven_kv_smem\n             kv_stage = 3\n        self.kv_stage = kv_stage\n        # print(\"kv_stage\", self.kv_stage)\n        self.s_stage = 2\n        assert self.s_stage >= self.q_stage\n        # For hdim 192,128 1CTA, we don't have enough smem to store all 3 stages of KV:\n        # 128 x 192 x 2 bytes x 3 stages = 144KB, and we need 96KB for Q.\n        # Instead we store smem as [smem_large, smem_small, smem_large], where smem_large is\n        # 128 x 192 and smem_small is 128 x 128. 
We set the stride between the stages to be\n        # 128 * 160, so that indexing the 0th and 2nd stages will get the right address,\n        # but for the 1st stage we need to add or subtract (depending on phase) 128 x 64.\n        self.uneven_kv_smem = (\n            self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3\n        )\n        self.uneven_kv_smem_offset = (\n            self.n_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2\n            if self.uneven_kv_smem\n            else 0\n        )\n        assert self.uneven_kv_smem_offset % 1024 == 0\n\n    @cute.jit\n    def __call__(\n        self,\n        mQ: cute.Tensor,  # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q\n        mK: cute.Tensor,  # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table\n        mV: cute.Tensor,  # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table\n        mO: cute.Tensor,  # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q\n        mLSE: Optional[cute.Tensor],\n        softmax_scale: Float32,\n        mCuSeqlensQ: Optional[cute.Tensor] = None,\n        mCuSeqlensK: Optional[cute.Tensor] = None,\n        mSeqUsedQ: Optional[cute.Tensor] = None,\n        mSeqUsedK: Optional[cute.Tensor] = None,\n        mPageTable: Optional[cute.Tensor] = None,  # (b_k, max_num_pages_per_seq)\n        window_size_left: Int32 | int | None = None,\n        window_size_right: Int32 | int | None = None,\n        learnable_sink: Optional[cute.Tensor] = None,\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n        aux_tensors: Optional[list] = None,\n        # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).\n        stream: cuda.CUstream = None,\n    ):\n        \"\"\"Execute the Fused Multi-Head Attention operation on the 
provided tensors.\n\n        This method prepares the input tensors for processing, validates their shapes and types,\n        configures the computation parameters, and launches the CUDA kernel.\n\n        The method handles:\n        1. Tensor layout transformations for specific memory access patterns\n        2. Validation of tensor shapes and data types\n        3. Initialization of hardware-specific parameters and memory layouts\n        4. Configuration of TMA (Tensor Memory Accelerator) operations\n        5. Grid and work scheduling computation\n        6. Kernel launch with appropriate parameters\n        \"\"\"\n        # setup static attributes before smem/grid/tma computation\n        self.q_dtype = mQ.element_type\n        self.k_dtype = mK.element_type\n        self.v_dtype = mV.element_type\n        self.o_dtype = mO.element_type\n        mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]\n        Q_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]\n        mQ = cute.make_tensor(mQ.iterator, cute.select(mQ.layout, mode=Q_layout_transpose))\n        # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there's cu_seqlens_k or (page_size, d, h_k, num_pages) if there's page_table\n        KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]\n        mK, mV = [\n            cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose))\n            for t in (mK, mV)\n        ]\n        if const_expr(self.is_split_kv):\n            O_layout_transpose = [2, 4, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 3, 2, 0]\n            LSE_layout_transpose = [3, 2, 1, 0] if const_expr(mCuSeqlensQ is None) else [2, 1, 0]\n            num_splits = mO.shape[0]\n        else:\n            O_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]\n            LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]\n            
num_splits = Int32(1)\n        mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))\n        mLSE = (\n            cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))\n            if const_expr(mLSE is not None)\n            else None\n        )\n        # (s, d, h, b) -> (d, s, h, b)\n        V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2]\n        mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose))\n\n        # check type consistency\n        if const_expr(self.q_dtype != self.k_dtype):\n            raise TypeError(f\"Type mismatch: {self.q_dtype} != {self.k_dtype}\")\n        if const_expr(self.q_dtype != self.v_dtype):\n            raise TypeError(f\"Type mismatch: {self.q_dtype} != {self.v_dtype}\")\n        self._setup_attributes()\n        self.use_tma_O = self.arch >= Arch.sm_90 and mCuSeqlensQ is None and mSeqUsedQ is None\n        # This can be tuned\n        # This is currently very ad-hoc, we should tune it systematically\n        self.ex2_emu_freq = 0\n        # self.ex2_emu_start_frg = 1 if self.is_causal else 0\n        self.ex2_emu_start_frg = 1\n        if const_expr(self.enable_ex2_emu):\n            self.ex2_emu_freq = 16\n            if const_expr(self.head_dim_padded == 128 and self.use_2cta_instrs):\n                self.ex2_emu_freq = 12\n            if const_expr(\n                self.pack_gqa and self.head_dim_padded > 64 and not self.is_causal and not self.is_local\n            ):\n                self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10\n            if const_expr(self.head_dim_padded > 64 and self.is_causal):\n                self.ex2_emu_freq = 10\n\n        cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE\n        q_major_mode = tcgen05.OperandMajorMode.K\n        k_major_mode = tcgen05.OperandMajorMode.K\n        v_major_mode = 
tcgen05.OperandMajorMode.MN\n        self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO)\n        # the intermediate tensor p is from tmem & mK-major\n        p_source = tcgen05.OperandSource.TMEM\n        p_major_mode = tcgen05.OperandMajorMode.K\n        tiled_mma_qk = sm100_utils_basic.make_trivial_tiled_mma(\n            self.q_dtype,\n            q_major_mode,\n            k_major_mode,\n            self.qk_acc_dtype,\n            cta_group,\n            self.mma_tiler_qk[:2],\n        )\n        tiled_mma_pv = sm100_utils_basic.make_trivial_tiled_mma(\n            self.v_dtype,\n            p_major_mode,\n            v_major_mode,\n            self.pv_acc_dtype,\n            cta_group,\n            self.mma_tiler_pv[:2],\n            p_source,\n        )\n\n        self.cluster_shape_mnk = (*self.cluster_shape_mn, 1)\n        cta_layout_vmnk = cute.tiled_divide(\n            cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)\n        )\n\n        # epi_tile is per-CTA (not full 2CTA) since each CTA writes its own O portion\n        self.epi_tile = (self.m_block_size, self.head_dim_v_padded)\n\n        sQ_layout = sm100_utils_basic.make_smem_layout_a(\n            tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, self.q_stage\n        )\n        sK_layout = sm100_utils_basic.make_smem_layout_b(\n            tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage\n        )\n        tP_layout = sm100_utils_basic.make_smem_layout_a(\n            tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.s_stage\n        )\n        sV_layout = sm100_utils_basic.make_smem_layout_b(\n            tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage\n        )\n        sO_layout = sm100_utils_basic.make_smem_layout_epi(\n            self.o_dtype, self.o_layout, self.epi_tile, self.q_stage\n        )\n        if const_expr(not self.same_hdim_kv_padded):\n            # sK and sV are using the same physical smem so we need to adjust the 
stride so that they line up\n            stride_sK = const_expr(\n                max(sK_layout.outer.stride[-1], 0)\n            )  # take max to turn tuple to Int32\n            stride_sV = const_expr(max(sV_layout.outer.stride[-1], 0))\n            stage_stride = const_expr(\n                max(stride_sK, stride_sV)\n                if not self.uneven_kv_smem\n                else (stride_sK + stride_sV) // 2\n            )\n            sK_layout = cute.make_composed_layout(\n                sK_layout.inner,\n                0,\n                cute.make_layout(\n                    (*sK_layout.outer.shape[:-1], self.kv_stage),\n                    stride=(*sK_layout.outer.stride[:-1], stage_stride),\n                ),\n            )\n            sV_layout = cute.make_composed_layout(\n                sV_layout.inner,\n                0,\n                cute.make_layout(\n                    (*sV_layout.outer.shape[:-1], self.kv_stage),\n                    stride=(*sV_layout.outer.stride[:-1], stage_stride),\n                ),\n            )\n\n        if const_expr(self.pack_gqa):\n            nheads_kv = mK.shape[2]\n            mQ = pack_gqa_layout(mQ, self.qhead_per_kvhead, nheads_kv, head_idx=2)\n            mO = pack_gqa_layout(mO, self.qhead_per_kvhead, nheads_kv, head_idx=2)\n            if const_expr(mLSE is not None):\n                mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, nheads_kv, head_idx=1)\n\n        self.tma_copy_bytes = {\n            name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))\n            for name, mX, layout in [\n                (\"Q\", mQ, sQ_layout),\n                (\"K\", mK, sK_layout),\n                (\"V\", mV, sV_layout),\n            ]\n        }\n        for name in (\"Q\", \"K\", \"V\"):\n            self.tma_copy_bytes[name] *= self.cta_group_size\n\n        # TMA load for Q\n        tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)\n        tma_store_op = 
cpasync.CopyBulkTensorTileS2GOp()\n\n        tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A(\n            tma_load_op,\n            mQ,\n            cute.select(sQ_layout, mode=[0, 1, 2]),\n            self.mma_tiler_qk,\n            tiled_mma_qk,\n            cta_layout_vmnk.shape,\n        )\n\n        tma_atom_K = None\n        tma_atom_V = None\n        if const_expr(self.use_tma_KV):\n            # TMA load for K\n            tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B(\n                tma_load_op,\n                mK,\n                cute.select(sK_layout, mode=[0, 1, 2]),\n                self.mma_tiler_qk,\n                tiled_mma_qk,\n                cta_layout_vmnk.shape,\n            )\n            # TMA load for V\n            tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B(\n                tma_load_op,\n                mV,\n                cute.select(sV_layout, mode=[0, 1, 2]),\n                self.mma_tiler_pv,\n                tiled_mma_pv,\n                cta_layout_vmnk.shape,\n            )\n\n        self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)\n        if const_expr(self.use_tma_O):\n            tma_atom_O, mO = cpasync.make_tiled_tma_atom(\n                tma_store_op, mO, cute.select(sO_layout, mode=[0, 1]), self.epi_tile\n            )\n            gmem_tiled_copy_O = None\n        else:\n            tma_atom_O = None\n            universal_copy_bits = 128\n            async_copy_elems = universal_copy_bits // self.o_dtype.width\n            atom_universal_copy = cute.make_copy_atom(\n                cute.nvgpu.CopyUniversalOp(),\n                self.o_dtype,\n                num_bits_per_copy=universal_copy_bits,\n            )\n            tO_shape_dim_1 = sO_layout.outer.shape[1][0] // async_copy_elems\n            tO_layout = cute.make_ordered_layout(\n                (self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1),\n                order=(1, 0),\n            )\n          
  # So that we don't have to check if we overshoot kBlockM when we store O\n            assert self.m_block_size % tO_layout.shape[0] == 0\n            vO_layout = cute.make_layout((1, async_copy_elems))\n            gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)\n\n        if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):\n            TileScheduler = SingleTileVarlenScheduler\n        else:\n            if const_expr(self.is_causal or self.is_local):\n                TileScheduler = SingleTileLPTScheduler\n            else:\n                TileScheduler = (\n                    SingleTileScheduler\n                    if const_expr(not self.is_persistent)\n                    else StaticPersistentTileScheduler\n                )\n        tile_sched_args = TileSchedulerArguments(\n            cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]),\n            cute.size(mQ.shape[2]),\n            cute.size(mQ.shape[3])\n            if const_expr(mCuSeqlensQ is None)\n            else cute.size(mCuSeqlensQ.shape[0] - 1),\n            num_splits,\n            cute.size(mK.shape[0])\n            if const_expr(mPageTable is None)\n            else mK.shape[0] * mPageTable.shape[1],\n            mQ.shape[1],\n            mV.shape[0],  # Note that this is different from Sm90 since we transpose mV in Sm100\n            total_q=cute.size(mQ.shape[0])\n            if const_expr(mCuSeqlensQ is not None)\n            else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),\n            tile_shape_mn=self.cta_tiler[:2],\n            mCuSeqlensQ=mCuSeqlensQ,\n            mSeqUsedQ=mSeqUsedQ,\n            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n            element_size=self.k_dtype.width // 8,\n            is_persistent=self.is_persistent,\n            lpt=self.is_causal or self.is_local,\n            is_split_kv=self.is_split_kv,\n            
cluster_shape_mn=self.cluster_shape_mn,\n        )\n        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)\n        self.tile_scheduler_cls = TileScheduler\n        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)\n\n        sO_size = cute.cosize(sO_layout) if const_expr(not self.overlap_sO_sQ) else 0\n        sQ_size = (\n            cute.cosize(sQ_layout) if const_expr(not self.overlap_sO_sQ) else\n            cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width)\n        )\n\n        @cute.struct\n        class SharedStorage:\n            # m_barriers for pipelines\n            mbar_load_Q: cute.struct.MemRange[Int64, self.q_stage * 2]\n            mbar_load_KV: cute.struct.MemRange[Int64, self.kv_stage * 2]\n            mbar_S_full_P_full_O_rescaled: cute.struct.MemRange[Int64, self.q_stage * 2]\n            mbar_P_full_lastsplit: cute.struct.MemRange[Int64, self.q_stage * 2]\n            mbar_O_full: cute.struct.MemRange[Int64, self.q_stage * 2]\n            mbar_softmax_stats: cute.struct.MemRange[Int64, self.q_stage * 2]\n            # mbar_softmax_stats: cute.struct.MemRange[Int64, self.q_stage * 4 * 2]\n            mbar_O_epi: cute.struct.MemRange[Int64, self.q_stage * 2]\n            mbar_s0_s1_sequence: cute.struct.MemRange[Int64, 2 * 2]\n            # Tmem dealloc cluster barrier\n            tmem_dealloc_mbar_ptr: Int64\n            # Tmem holding buffer\n            tmem_holding_buf: Int32\n            # Smem tensors\n            # store row max and row sum\n            sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2]\n            sO: cute.struct.Align[\n                cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes\n            ]\n            sQ: cute.struct.Align[\n                cute.struct.MemRange[self.q_dtype, sQ_size], self.buffer_align_bytes\n            ]\n            sK: cute.struct.Align[\n                
# cute.cosize(sK_layout) is correct even in the case of self.uneven_kv_smem\n                cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)],\n                self.buffer_align_bytes,\n            ]\n\n        self.shared_storage = SharedStorage\n\n        softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod)\n        window_size_left = Int32(window_size_left) if window_size_left is not None else None\n        window_size_right = Int32(window_size_right) if window_size_right is not None else None\n        fastdiv_mods = utils.compute_fastdiv_mods(mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors, mPageTable)\n\n        head_divmod = None\n        if cutlass.const_expr(self.pack_gqa):\n            head_divmod = FastDivmodDivisor(self.qhead_per_kvhead)\n\n        self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)\n        if cutlass.const_expr(self.use_block_sparsity and mPageTable is not None):\n            raise NotImplementedError(\"Block sparsity + paged KV not supported on SM100\")\n\n        # Launch the kernel synchronously\n        self.kernel(\n            mQ,\n            mK,\n            mV,\n            mO,\n            mLSE,\n            mCuSeqlensQ,\n            mCuSeqlensK,\n            mSeqUsedQ,\n            mSeqUsedK,\n            mPageTable,\n            tma_atom_Q,\n            tma_atom_K,\n            tma_atom_V,\n            tma_atom_O,\n            softmax_scale_log2,\n            softmax_scale,\n            window_size_left,\n            window_size_right,\n            learnable_sink,\n            blocksparse_tensors,\n            sQ_layout,\n            sK_layout,\n            tP_layout,\n            sV_layout,\n            sO_layout,\n            gmem_tiled_copy_O,\n            tiled_mma_qk,\n            tiled_mma_pv,\n            tile_sched_params,\n            num_splits,\n            aux_tensors,\n            fastdiv_mods,\n            
head_divmod,\n        ).launch(\n            grid=grid_dim,\n            block=[self.threads_per_cta, 1, 1],\n            cluster=self.cluster_shape_mnk if cute.size(self.cluster_shape_mnk) > 1 else None,\n            stream=stream,\n            min_blocks_per_mp=1,\n        )\n\n    #  GPU device kernel\n    @cute.kernel\n    def kernel(\n        self,\n        mQ: cute.Tensor,  # (s_q, d, h, b) or (total_q, d, h) if there is cu_seqlens_q\n        mK: cute.Tensor,  # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there is cu_seqlens_k or (page_size, d, h_k, num_pages) if there is page_table\n        mV: cute.Tensor,  # (d, s_k, h_k, b_k) or (d, total_k, h_k) if there is cu_seqlens_k or (d, page_size, h_k, num_pages) if there is page_table\n        mO: cute.Tensor,\n        mLSE: Optional[cute.Tensor],\n        mCuSeqlensQ: Optional[cute.Tensor],\n        mCuSeqlensK: Optional[cute.Tensor],\n        mSeqUsedQ: Optional[cute.Tensor],\n        mSeqUsedK: Optional[cute.Tensor],\n        mPageTable: Optional[cute.Tensor],\n        tma_atom_Q: cute.CopyAtom,\n        tma_atom_K: Optional[cute.CopyAtom],\n        tma_atom_V: Optional[cute.CopyAtom],\n        tma_atom_O: Optional[cute.CopyAtom],\n        softmax_scale_log2: Float32,\n        softmax_scale: Float32 | None,\n        window_size_left: Optional[Int32],\n        window_size_right: Optional[Int32],\n        learnable_sink: Optional[cute.Tensor],\n        blocksparse_tensors: Optional[BlockSparseTensors],\n        sQ_layout: cute.ComposedLayout,\n        sK_layout: cute.ComposedLayout,\n        tP_layout: cute.ComposedLayout,\n        sV_layout: cute.ComposedLayout,\n        sO_layout: cute.ComposedLayout,\n        gmem_tiled_copy_O: Optional[cute.TiledCopy],\n        tiled_mma_qk: cute.TiledMma,\n        tiled_mma_pv: cute.TiledMma,\n        tile_sched_params: ParamsBase,\n        num_splits: Int32,\n        aux_tensors: Optional[list] = None,\n        fastdiv_mods=(None, None),\n        head_divmod=None,\n    ):\n 
       \"\"\"The device kernel implementation of the Fused Multi-Head Attention.\n\n        This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation:\n        1. Load warp: Loads Q, K, V data from global memory to shared memory using TMA\n        2. MMA warp: Performs matrix multiplications (Q*K^T and P*V)\n        3. Softmax warps: Compute softmax normalization on attention scores\n        4. Correction warps: Apply adjustments to intermediate results\n        5. Epilogue warp: Handles final output transformation and storage\n\n        The kernel implements a complex pipeline with overlapping computation and memory operations,\n        using tensor memory access (TMA) for efficient data loading, warp specialization for different\n        computation phases, and optional attention masking.\n        \"\"\"\n\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())\n\n        # Prefetch tma descriptor\n        if warp_idx == 0:\n            for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O):\n                if const_expr(tma_atom is not None):\n                    cpasync.prefetch_descriptor(tma_atom)\n\n        cta_layout_vmnk = cute.tiled_divide(\n            cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)\n        )\n        # Setup cta/thread coordinates\n        bidx, _, _ = cute.arch.block_idx()\n        if const_expr(cute.size(tiled_mma_qk.thr_id.shape) == 1):\n            mma_tile_coord_v = 0\n        else:\n            mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape)\n        is_leader_cta = mma_tile_coord_v == 0\n\n        # Alloc\n        smem = cutlass.utils.SmemAllocator()\n        storage = smem.allocate(self.shared_storage)\n\n        tmem_alloc_barrier = pipeline.NamedBarrier(\n            barrier_id=int(NamedBarrierFwdSm100.TmemPtr),\n            num_threads=cute.arch.WARP_SIZE * len(\n                (self.mma_warp_id,\n                 
*self.softmax0_warp_ids,\n                 *self.softmax1_warp_ids,\n                 *self.correction_warp_ids)\n            ),\n        )\n        # Tensor memory dealloc barrier init\n        tmem = cutlass.utils.TmemAllocator(\n            storage.tmem_holding_buf,\n            barrier_for_retrieve=tmem_alloc_barrier,\n            allocator_warp_id=self.mma_warp_id,\n            is_two_cta=self.use_2cta_instrs,\n            two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,\n        )\n\n        ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread)\n        mma_warp = ThreadCooperativeGroup(len([self.mma_warp_id]))\n        load_warps = ThreadCooperativeGroup(len(self.load_warp_ids))\n        tma_warp = ThreadCooperativeGroup(1)\n        softmax_warps = ThreadCooperativeGroup(len(self.softmax0_warp_ids))\n        softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.softmax0_warp_ids))\n        # softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE)\n        correction_threads = ThreadCooperativeGroup(\n            cute.arch.WARP_SIZE * len(self.correction_warp_ids)\n        )\n        # correction_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE)\n        softmax_correction_threads = ThreadCooperativeGroup(\n            cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids)\n        )\n        epilogue_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.epilogue_warp_ids))\n        # For UMMA-bridging pipelines: the non-MMA side spans both CTAs in the cluster,\n        # so the thread count must include warps from both CTAs.\n        softmax_warps_cluster = ThreadCooperativeGroup(\n            len(self.softmax0_warp_ids) * self.cta_group_size\n        )\n        correction_threads_cluster = ThreadCooperativeGroup(\n            cute.arch.WARP_SIZE * len(self.correction_warp_ids) * self.cta_group_size\n        )\n        softmax_correction_threads_cluster = 
ThreadCooperativeGroup(\n            cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids) * self.cta_group_size\n        )\n        pipeline_q = pipeline_custom.PipelineTmaUmma.create(\n            barrier_storage=storage.mbar_load_Q.data_ptr(),\n            num_stages=self.q_stage,\n            producer_group=tma_warp,\n            consumer_group=mma_warp,\n            tx_count=self.tma_copy_bytes[\"Q\"],\n            cta_layout_vmnk=cta_layout_vmnk,\n            defer_sync=True,\n        )\n        if const_expr(self.use_tma_KV):\n            pipeline_kv = pipeline_custom.PipelineTmaUmma.create(\n                barrier_storage=storage.mbar_load_KV.data_ptr(),\n                num_stages=self.kv_stage,\n                producer_group=tma_warp,\n                consumer_group=mma_warp,\n                tx_count=self.tma_copy_bytes[\"K\"],\n                cta_layout_vmnk=cta_layout_vmnk,\n                defer_sync=True,\n            )\n        else:\n            cpasync_producer_group = pipeline.CooperativeGroup(\n                pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE\n            )\n            pipeline_kv = pipeline.PipelineAsyncUmma.create(\n                barrier_storage=storage.mbar_load_KV.data_ptr(),\n                num_stages=self.kv_stage,\n                producer_group=cpasync_producer_group,\n                consumer_group=mma_warp,\n                cta_layout_vmnk=cta_layout_vmnk,\n                defer_sync=True,\n            )\n        # This pipeline is not the typical producer-consumer pipeline. The \"producer\" mma warp\n        # uses it to signal that S is ready, and the softmax threads wait for S to be ready.\n        # When softmax threads write P to tmem and the correction threads have rescaled O, they\n        # signal as \"consumer\". 
The mma warp then waits for that signal to do the P @ V gemm.\n        pipeline_s_p_o = pipeline_custom.PipelineUmmaAsync.create(\n            barrier_storage=storage.mbar_S_full_P_full_O_rescaled.data_ptr(),\n            num_stages=self.q_stage,\n            producer_group=mma_warp,\n            consumer_group=softmax_correction_threads_cluster,\n            cta_layout_vmnk=cta_layout_vmnk,\n            defer_sync=True,\n        )\n        pipeline_p_lastsplit = pipeline_custom.PipelineAsyncUmma.create(\n            barrier_storage=storage.mbar_P_full_lastsplit.data_ptr(),\n            num_stages=self.q_stage,\n            producer_group=softmax_warps_cluster,\n            consumer_group=mma_warp,\n            cta_layout_vmnk=cta_layout_vmnk,\n            defer_sync=True,\n        )\n        # MMA warp uses this to signal to the correction warps that O is ready.\n        pipeline_o_acc = pipeline_custom.PipelineUmmaAsync.create(\n            barrier_storage=storage.mbar_O_full.data_ptr(),\n            num_stages=self.q_stage,\n            producer_group=mma_warp,\n            consumer_group=correction_threads_cluster,\n            cta_layout_vmnk=cta_layout_vmnk,\n            defer_sync=True,\n        )\n        pipeline_s0_s1_sequence = None\n        if const_expr(self.s0_s1_barrier and self.q_stage > 1):\n            # This is not a typical producer-consumer pipeline. 
We will directly use\n            # pipeline_s0_s1_sequence.sync_object_full and will not use\n            # pipeline_s0_s1_sequence.sync_object_empty.\n            pipeline_s0_s1_sequence = pipeline_custom.PipelineAsync.create(\n                barrier_storage=storage.mbar_s0_s1_sequence.data_ptr(),\n                num_stages=2,\n                producer_group=softmax_threads,\n                consumer_group=softmax_threads,\n                defer_sync=True,\n            )\n        pipeline_sm_stats = pipeline_custom.PipelineAsync.create(\n            barrier_storage=storage.mbar_softmax_stats.data_ptr(),\n            num_stages=self.q_stage,\n            producer_group=softmax_threads,\n            consumer_group=correction_threads,\n            defer_sync=True,\n        )\n        # Should put the NamedBarrier inside the pipeline class so we'll just have pipeline_sm_stats\n        sm_stats_barrier = pipeline_custom.NamedBarrier(\n            barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), num_threads=cute.arch.WARP_SIZE * 2\n        )\n        pipeline_o_epi = None\n        if const_expr(not self.use_correction_warps_for_epi):\n            pipeline_o_epi = pipeline_custom.PipelineAsync.create(\n                barrier_storage=storage.mbar_O_epi.data_ptr(),\n                num_stages=self.q_stage,\n                producer_group=correction_threads,\n                consumer_group=epilogue_threads,\n                defer_sync=True,\n            )\n\n        # Cluster arrive after barrier init\n        pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True)\n\n        #  Generate smem tensor Q/K/V/O\n        # (MMA, MMA_Q, MMA_D, PIPE)\n        sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)\n        # (MMA, MMA_K, MMA_D, PIPE)\n        sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)\n        # (MMA, MMA_K, MMA_D, PIPE)\n        # Strip swizzle info to reuse smem\n        sV = 
cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer)\n        if const_expr(not self.overlap_sO_sQ):\n            sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner)\n        else:\n            sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner, self.o_dtype), sO_layout.outer)\n\n        sScale = storage.sScale.get_tensor(cute.make_layout(self.q_stage * self.m_block_size * 2))\n\n        thr_mma_qk = tiled_mma_qk.get_slice(mma_tile_coord_v)\n        thr_mma_pv = tiled_mma_pv.get_slice(mma_tile_coord_v)\n\n        qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2])\n        # This is a fake tensor, by right we need to retrieve tmem_ptr. But we know that we always\n        # request 512 columns of tmem, so we know that it starts at 0.\n        tStS = thr_mma_qk.make_fragment_C(cute.append(qk_acc_shape, self.s_stage))\n        pv_acc_shape = thr_mma_pv.partition_shape_C(self.mma_tiler_pv[:2])\n        tOtO = thr_mma_pv.make_fragment_C(cute.append(pv_acc_shape, self.q_stage))\n        tOtO = cute.make_tensor(tOtO.iterator + self.tmem_o_offset[0], tOtO.layout)\n        tP = cute.make_tensor(tStS.iterator, tP_layout.outer)\n        tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0]\n        # Need to multiply by width ratio bc tP is in v_dtype but tmem offsets are in FP32\n        tP_width_ratio = Float32.width // self.v_dtype.width\n        # Need to adjust the stage stride manually since the two stages aren't contiguous in tmem\n        tP_stage_stride = (self.tmem_p_offset[1] - self.tmem_p_offset[0]) * tP_width_ratio\n        tOrP = cute.make_tensor(\n            tOrP.iterator + self.tmem_p_offset[0] * tP_width_ratio,\n            cute.append(tOrP.layout, cute.make_layout((self.s_stage,), stride=(tP_stage_stride,)))\n        )\n\n        block_info = BlockInfo(\n            # This is cta_tiler, not mma_tiler_qk, since we move by block by (2 * mma_tiler[0], mma_tiler[1])\n            
self.cta_tiler[0],\n            self.cta_tiler[1],\n            self.is_causal,\n            self.is_local,\n            self.is_split_kv,\n            window_size_left,\n            window_size_right,\n            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n        )\n        SeqlenInfoCls = partial(\n            SeqlenInfoQK.create,\n            seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1],\n            seqlen_k_static=mK.shape[0]\n            if const_expr(mPageTable is None)\n            else mK.shape[0] * mPageTable.shape[1],\n            mCuSeqlensQ=mCuSeqlensQ,\n            mCuSeqlensK=mCuSeqlensK,\n            mSeqUsedQ=mSeqUsedQ,\n            mSeqUsedK=mSeqUsedK,\n        )\n        AttentionMaskCls = partial(\n            AttentionMask,\n            self.m_block_size,\n            self.n_block_size,\n            window_size_left=window_size_left,\n            window_size_right=window_size_right,\n            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n        )\n        TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)\n\n        # Cluster wait before tensor memory alloc\n        pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk)\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        #  EMPTY\n        # ///////////////////////////////////////////////////////////////////////////////\n        for i in cutlass.range_constexpr(len(self.empty_warp_ids)):\n            if warp_idx == self.empty_warp_ids[i]:\n                cute.arch.setmaxregister_decrease(self.num_regs_other)\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        #  LOAD\n        # ///////////////////////////////////////////////////////////////////////////////\n        if warp_idx >= self.load_warp_ids[0] and warp_idx <= self.load_warp_ids[-1]:\n           
 cute.arch.setmaxregister_decrease(self.num_regs_other)\n            self.load(\n                thr_mma_qk,\n                thr_mma_pv,\n                mQ,\n                mK,\n                mV,\n                sQ,\n                sK,\n                sV,\n                mPageTable,\n                tma_atom_Q,\n                tma_atom_K,\n                tma_atom_V,\n                pipeline_q,\n                pipeline_kv,\n                block_info,\n                num_splits,\n                SeqlenInfoCls,\n                TileSchedulerCls,\n                blocksparse_tensors,\n            )\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        #  MMA\n        # ///////////////////////////////////////////////////////////////////////////////\n        if warp_idx == self.mma_warp_id:\n            cute.arch.setmaxregister_decrease(self.num_regs_other)\n            # Alloc tensor memory buffer\n            tmem.allocate(cute.arch.get_max_tmem_alloc_cols(\"sm_100\"))\n            tmem.wait_for_alloc()\n            tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)\n            self.mma(\n                tiled_mma_qk,\n                tiled_mma_pv,\n                sQ,\n                sK,\n                sV,\n                tStS,\n                tOtO,\n                tOrP,\n                pipeline_q,\n                pipeline_kv,\n                pipeline_s_p_o,\n                pipeline_p_lastsplit,\n                pipeline_o_acc,\n                is_leader_cta,\n                block_info,\n                num_splits,\n                SeqlenInfoCls,\n                TileSchedulerCls,\n                blocksparse_tensors,\n            )\n            # Dealloc the tensor memory buffer\n            tmem.relinquish_alloc_permit()\n            tmem_alloc_barrier.arrive_and_wait()\n            tmem.free(tmem_ptr)\n\n        # ///////////////////////////////////////////////////////////////////////////////\n 
       #  Epilogue\n        # ///////////////////////////////////////////////////////////////////////////////\n        if const_expr(not self.use_correction_warps_for_epi):\n            if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]:\n                cute.arch.setmaxregister_decrease(self.num_regs_other)\n                self.epilogue_s2g(\n                    mO,\n                    sO,\n                    gmem_tiled_copy_O,\n                    tma_atom_O,\n                    pipeline_o_epi,\n                    block_info,\n                    num_splits,\n                    SeqlenInfoCls,\n                    TileSchedulerCls,\n                    mma_tile_coord_v,\n                )\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        #  Softmax\n        # ///////////////////////////////////////////////////////////////////////////////\n        if (\n            (const_expr(self.q_stage == 2) and warp_idx <= self.softmax1_warp_ids[-1]) or\n            (const_expr(self.q_stage == 1) and warp_idx <= self.softmax0_warp_ids[-1])\n        ):\n            # increase register after decreasing\n            cute.arch.setmaxregister_increase(self.num_regs_softmax)\n            # sync with mma warp before retrieving tmem ptr\n            tmem.wait_for_alloc()\n            tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)\n            softmax_loop = partial(\n                self.softmax_loop,\n                softmax_scale_log2=softmax_scale_log2,\n                softmax_scale=softmax_scale,\n                thr_mma_qk=thr_mma_qk,\n                sScale=sScale,\n                mLSE=mLSE,\n                pipeline_s_p_o=pipeline_s_p_o,\n                pipeline_p_lastsplit=pipeline_p_lastsplit,\n                pipeline_sm_stats=pipeline_sm_stats,\n                sm_stats_barrier=sm_stats_barrier,\n                pipeline_s0_s1_sequence=pipeline_s0_s1_sequence,\n                
learnable_sink=learnable_sink,\n                block_info=block_info,\n                num_splits=num_splits,\n                SeqlenInfoCls=SeqlenInfoCls,\n                AttentionMaskCls=AttentionMaskCls,\n                TileSchedulerCls=TileSchedulerCls,\n                aux_tensors=aux_tensors,\n                fastdiv_mods=fastdiv_mods,\n                head_divmod=head_divmod,\n                blocksparse_tensors=blocksparse_tensors,\n            )\n\n            if const_expr(not self.s0_s1_barrier):\n                stage = Int32(0 if const_expr(self.q_stage == 1) or warp_idx < self.softmax1_warp_ids[0] else 1)\n                softmax_loop(stage=stage, tStS=tStS)\n            else:\n                # If there's s0_s1_barrier, it's faster to have 2 WGs having different code\n                if warp_idx < self.softmax1_warp_ids[0]:\n                    softmax_loop(stage=0, tStS=tStS)\n                if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax1_warp_ids[0]:\n                    softmax_loop(stage=1, tStS=tStS)\n\n            tmem_alloc_barrier.arrive()\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        #  Correction\n        # ///////////////////////////////////////////////////////////////////////////////\n        if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id:\n            cute.arch.setmaxregister_decrease(self.num_regs_correction)\n            # sync with mma warp before retrieving tmem ptr\n            tmem.wait_for_alloc()\n            tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)\n            self.correction_loop(\n                thr_mma_qk,\n                thr_mma_pv,\n                tStS,\n                tOtO,\n                sScale,\n                mO,\n                mLSE,\n                sO,\n                pipeline_s_p_o,\n                pipeline_o_acc,\n                pipeline_sm_stats,\n                sm_stats_barrier,\n    
            pipeline_o_epi,\n                learnable_sink,\n                gmem_tiled_copy_O,\n                tma_atom_O,\n                softmax_scale_log2,\n                block_info,\n                num_splits,\n                SeqlenInfoCls,\n                TileSchedulerCls,\n                blocksparse_tensors,\n            )\n            tmem_alloc_barrier.arrive()\n\n        return\n\n    @cute.jit\n    def load(\n        self,\n        thr_mma_qk: cute.core.ThrMma,\n        thr_mma_pv: cute.core.ThrMma,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mV: cute.Tensor,\n        sQ: cute.Tensor,\n        sK: cute.Tensor,\n        sV: cute.Tensor,\n        mPageTable: Optional[cute.Tensor],\n        tma_atom_Q: cute.CopyAtom,\n        tma_atom_K: Optional[cute.CopyAtom],\n        tma_atom_V: Optional[cute.CopyAtom],\n        pipeline_q: pipeline.PipelineAsync,\n        pipeline_kv: pipeline.PipelineAsync,\n        block_info: BlockInfo,\n        num_splits: Int32,\n        SeqlenInfoCls: Callable,\n        TileSchedulerCls: Callable,\n        blocksparse_tensors: Optional[BlockSparseTensors],\n    ):\n        num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE\n        tidx = cute.arch.thread_idx()[0] % num_load_threads\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())\n        q_producer_phase = Int32(1)\n        kv_producer_state = pipeline.make_pipeline_state(\n            pipeline.PipelineUserType.Producer, self.kv_stage\n        )\n        tile_scheduler = TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        while work_tile.is_valid_tile:\n            m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx\n            seqlen = SeqlenInfoCls(batch_idx)\n            mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]\n            tiler_gQ = ((self.mma_tiler_qk[0] * self.q_stage), self.head_dim_padded)\n            gQ = 
cute.local_tile(mQ_cur, tiler_gQ, (m_block, 0))  # (128 * 2, 128)\n            gQ = layout_utils.select(\n                cute.flat_divide(gQ, (self.mma_tiler_qk[0],)), mode=[0, 2, 1]\n            )  # (128, 128, 2)\n\n            head_idx_kv = (\n                head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx\n            )\n            if const_expr(mPageTable is None):\n                if const_expr(not seqlen.has_cu_seqlens_k):\n                    mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)]\n                else:\n                    mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, head_idx_kv])\n                    mV_cur = cute.domain_offset((0, seqlen.offset_k), mV[None, None, head_idx_kv])\n                gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0))\n                gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None))\n            else:\n                # Need to keep batch coord None since we'll index into it with page idx\n                mK_cur, mV_cur = [t[None, None, head_idx_kv, None] for t in (mK, mV)]\n                gK = cute.local_tile(\n                    mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None)\n                )\n                gV = cute.local_tile(\n                    mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)\n                )\n            tSgQ = thr_mma_qk.partition_A(gQ)\n            tSgK = thr_mma_qk.partition_B(gK)\n            tOgV = thr_mma_pv.partition_B(gV)\n            load_Q_fn, _, _ = copy_utils.tma_get_copy_fn(\n                tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ\n            )\n\n            if const_expr(self.use_tma_KV):\n                tKsK, tKgK = cpasync.tma_partition(\n                    tma_atom_K,\n                    0,  # no multicast\n                    cute.make_layout(1),\n                    
cute.group_modes(sK, 0, 3),\n                    cute.group_modes(tSgK, 0, 3),\n                )\n                tVsV, tVgV = cpasync.tma_partition(\n                    tma_atom_V,\n                    0,  # no multicast\n                    cute.make_layout(1),\n                    cute.group_modes(sV, 0, 3),\n                    cute.group_modes(tOgV, 0, 3),\n                )\n                paged_kv_manager = None\n            else:\n                page_size = mK.shape[0]\n                paged_kv_manager = PagedKVManager.create(\n                    mPageTable,\n                    mK,\n                    mV,\n                    FastDivmodDivisor(page_size),\n                    batch_idx,\n                    head_idx_kv,\n                    tidx,\n                    seqlen.seqlen_k,\n                    0,  # leftpad_k\n                    self.n_block_size,\n                    self.head_dim_padded,\n                    self.head_dim_v_padded,\n                    num_load_threads,\n                    mK.element_type,\n                )\n                tKsK, tKgK = None, None\n                tVsV, tVgV = None, None\n\n            load_Q = partial(self.load_Q, load_Q_fn, pipeline_q=pipeline_q, phase=q_producer_phase)\n            load_K = partial(\n                self.load_KV,\n                tma_atom_K,\n                tKgK,\n                tKsK,\n                paged_kv_manager,\n                sK,\n                pipeline_kv=pipeline_kv,\n                K_or_V=\"K\",\n            )\n            load_V = partial(\n                self.load_KV,\n                tma_atom_V,\n                tVgV,\n                tVsV,\n                paged_kv_manager,\n                sV,\n                pipeline_kv=pipeline_kv,\n                K_or_V=\"V\",\n            )\n\n            if const_expr(not self.use_block_sparsity):\n                n_block_min, n_block_max = block_info.get_n_block_min_max(\n                    seqlen, m_block, 
split_idx, num_splits\n                )\n                if const_expr(not self.is_split_kv) or n_block_min < n_block_max:\n                    n_block_first = n_block_max - 1 if n_block_max > 0 else 0\n                    page_idx = (\n                        mPageTable[batch_idx, n_block_first]\n                        if const_expr(mPageTable is not None and self.use_tma_KV)\n                        else None\n                    )\n                    if const_expr(not self.use_tma_KV):\n                        paged_kv_manager.load_page_table(n_block_first)\n                    load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx)  # K0\n                    # load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx, extra_tx_count=self.tma_copy_bytes[\"Q\"])  # K0\n                    if const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]:\n                        # load_Q(block=0, stage=0)  # Q0\n                        pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)\n                        # pipeline_q.sync_object_empty.wait(0, q_producer_phase)\n                        tma_bar_ptr = pipeline_q.sync_object_full.get_barrier(0)\n                        # tma_bar_ptr = pipeline_kv.producer_get_barrier(kv_producer_state)\n                        load_Q_fn(src_idx=0, dst_idx=0, tma_bar_ptr=tma_bar_ptr)\n                    kv_producer_state.advance()\n                    if const_expr(self.q_stage == 2) and (const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]):\n                        # load_Q(block=1, stage=1)  # Q1\n                        pipeline_q.producer_acquire_w_index_phase(1, q_producer_phase)\n                        tma_bar_ptr = pipeline_q.sync_object_full.get_barrier(1)\n                        load_Q_fn(src_idx=1, dst_idx=1, tma_bar_ptr=tma_bar_ptr)\n                    q_producer_phase ^= 1\n                    
load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx)  # V0\n                    kv_producer_state.advance()\n                    for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):\n                        n_block = n_block_max - 2 - i\n                        page_idx = (\n                            mPageTable[batch_idx, n_block]\n                            if const_expr(mPageTable is not None and self.use_tma_KV)\n                            else None\n                        )\n                        if const_expr(not self.use_tma_KV):\n                            paged_kv_manager.load_page_table(n_block)\n                    # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf(\"n_block = {}, page_idx = {}\", n_block, page_idx)\n                        load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)  # Ki\n                        kv_producer_state.advance()\n                        load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)  # Vi\n                        kv_producer_state.advance()\n\n            else:\n                kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100(\n                    blocksparse_tensors,\n                    batch_idx,\n                    head_idx,\n                    m_block,\n                    kv_producer_state,\n                    load_Q,\n                    load_K,\n                    load_V,\n                    pipeline_kv,\n                    self.q_stage,\n                    q_producer_phase,\n                    self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n                    self.q_subtile_factor if self.q_subtile_factor is not None else 1,\n                )\n\n            tile_scheduler.prefetch_next_work()\n            tile_scheduler.advance_to_next_work()\n            work_tile = tile_scheduler.get_current_work()\n            # End of persistent scheduler loop\n\n        
pipeline_kv.producer_tail(kv_producer_state)\n        # This is equivalent to pipeline_q.producer_tail\n        if const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]:\n            pipeline_q.producer_acquire_w_index_phase(self.q_stage - 1, q_producer_phase)\n\n    @cute.jit\n    def mma(\n        self,\n        tiled_mma_qk: cute.core.ThrMma,\n        tiled_mma_pv: cute.core.ThrMma,\n        sQ: cute.Tensor,\n        sK: cute.Tensor,\n        sV: cute.Tensor,\n        tStS: cute.Tensor,\n        tOtO: cute.Tensor,\n        tOrP: cute.Tensor,\n        pipeline_q: pipeline.PipelineAsync,\n        pipeline_kv: pipeline.PipelineAsync,\n        pipeline_s_p_o: pipeline.PipelineAsync,\n        pipeline_p_lastsplit: pipeline.PipelineAsync,\n        pipeline_o_acc: pipeline.PipelineAsync,\n        is_leader_cta: Boolean,\n        block_info: BlockInfo,\n        num_splits: Int32,\n        SeqlenInfoCls: Callable,\n        TileSchedulerCls: Callable,\n        blocksparse_tensors: Optional[BlockSparseTensors],\n    ):\n        tSrQ = tiled_mma_qk.make_fragment_A(sQ)\n        tSrK = tiled_mma_qk.make_fragment_B(sK)\n        tOrV = tiled_mma_pv.make_fragment_B(sV)\n        if const_expr(self.q_stage == 2):\n            tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 1])\n        else:\n            tSrQs = (tSrQ[None, None, None, 0],)\n\n        qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op\n        qk_mma_idesc, pv_mma_idesc = sm100_desc.mma_op_to_idesc(qk_mma_op), sm100_desc.mma_op_to_idesc(pv_mma_op)\n        q_smem_base = sm100_desc.smem_desc_base_from_tensor(sQ, sm100_desc.Major.K)\n        k_smem_base = sm100_desc.smem_desc_base_from_tensor(sK, sm100_desc.Major.K)\n        v_smem_base = sm100_desc.smem_desc_base_from_tensor(sV, sm100_desc.Major.MN)\n        q_smem_start = [sm100_desc.make_smem_desc_start_addr(sQ[None, None, None, stage].iterator) for stage in range(self.q_stage)]\n\n        
sm100_utils.declare_ptx_smem_desc(q_smem_start[self.q_stage - 1], q_smem_base, tSrQ[None, None, None, 0].layout, var_name_prefix=\"fa_fwd_q_smem_desc\")\n        sm100_utils.declare_ptx_idesc(qk_mma_op, var_name=\"fa_fwd_qk_mma_idesc\")\n        sm100_utils.declare_ptx_idesc(pv_mma_op, var_name=\"fa_fwd_pv_mma_idesc\")\n\n        sQ_stage_stride = (sQ.layout.stride[-1] * sQ.element_type.width // 8) >> 4\n        if const_expr(self.q_stage == 1):\n            sQ_stage_stride = 0\n        gemm_Si = [\n            partial(\n                # sm100_utils.gemm_ptx_precomputed,\n                # self.tmem_s_offset[stage],\n                # smem_desc_start_a=q_smem_start[stage],\n                # idesc=qk_mma_idesc,\n                # smem_desc_base_a=q_smem_base,\n                # smem_desc_base_b=k_smem_base,\n                # tCrA_layout=tSrQ[None, None, None, 0].layout,\n                sm100_utils.gemm_ptx_precomputed_varname,\n                self.tmem_s_offset[stage],\n                # idesc=qk_mma_idesc,\n                smem_desc_base_b=k_smem_base,\n                tCrB_layout=tSrK[None, None, None, 0].layout,\n                smem_var_name_prefix=f\"fa_fwd_q_smem_desc\",\n                idesc_var_name=f\"fa_fwd_qk_mma_idesc\",\n                smem_offset=-sQ_stage_stride if stage == 0 else sQ_stage_stride,\n                zero_init=True,\n                cta_group=self.cta_group_size,\n            )\n            for stage in range(self.q_stage)\n        ]\n        # gemm_Si = [\n        #     partial(\n        #         sm100_utils.gemm,\n        #         tiled_mma_qk,\n        #         tStS[None, None, None, stage],\n        #         tCrA=tSrQ[None, None, None, stage],\n        #         zero_init=True,\n        #     )\n        #     for stage in range(self.q_stage)\n        # ]\n        gemm_Pi = [\n            partial(\n                # sm100_utils.gemm_ptx_precomputed,\n                sm100_utils.gemm_ptx_partial,\n                
pv_mma_op,\n                self.tmem_o_offset[stage],\n                tOrP[None, None, None, stage],\n                sA=None,\n                split_arrive=self.split_P_arrive if self.split_P_arrive > 0 else None,\n                # smem_desc_start_a=tOrP[None, None, None, stage].iterator.toint(),\n                # smem_desc_start_a=self.tmem_p_offset[stage],\n                # idesc=pv_mma_idesc,\n                # smem_desc_base_a=None,\n                # smem_desc_base_b=v_smem_base,\n                # tCrA_layout=tOrP[None, None, None, 0].layout,\n                # tCrB_layout=tOrV[None, None, None, 0].layout\n                cta_group=self.cta_group_size,\n            )\n            for stage in range(self.q_stage)\n        ]\n        # gemm_Pi = [\n        #     partial(\n        #         sm100_utils.gemm, tOtO[None, None, None, stage], tCrA=tOrP[None, None, None, stage]\n        #     )\n        #     for stage in range(self.q_stage)\n        # ]\n\n        mma_q_consumer_phase = Int32(0)\n        mma_kv_consumer_state = pipeline.make_pipeline_state(\n            pipeline.PipelineUserType.Consumer, self.kv_stage\n        )\n        P_full_O_rescaled_phase = Int32(0)\n\n        tile_scheduler = TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        while work_tile.is_valid_tile:\n            m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx\n            seqlen = SeqlenInfoCls(batch_idx)\n\n            block_iter_count = Int32(0)\n            process_tile = False\n\n            if const_expr(self.use_block_sparsity):\n                block_iter_count = get_total_block_count(\n                    blocksparse_tensors,\n                    batch_idx,\n                    head_idx,\n                    m_block,\n                    self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n                    self.q_subtile_factor if self.q_subtile_factor is not None else 1,\n                )\n                
process_tile = block_iter_count > Int32(0)\n            else:\n                n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)\n                block_iter_count = n_block_max - n_block_min\n                if const_expr(not self.is_split_kv):\n                    process_tile = True\n                else:\n                    process_tile = n_block_min < n_block_max\n\n            if process_tile and is_leader_cta:\n                for stage in cutlass.range_constexpr(self.q_stage):\n                    # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1)\n                    # 1. wait for Q0 / Q1\n                    pipeline_q.consumer_wait_w_index_phase(stage, mma_q_consumer_phase)\n                    # 2. wait for K0\n                    if const_expr(stage == 0):\n                        pipeline_kv.consumer_wait(mma_kv_consumer_state)\n                    Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase\n                    tSrKi = tSrK[None, None, None, Ki_index]\n                    # We don't need to acquire empty S0 / S1.\n                    # For the first iteration, we don't need to wait as we're guaranteed S0 / S1\n                    # are empty. For subsequent iterations, the wait happened at the end\n                    # of the while loop.\n                    # 3. 
gemm\n                    # sm100_utils.gemm(tiled_mma_qk, tStS[None, None, None, stage], tSrQ[None, None, None, stage], tSrKi, zero_init=True)\n                    sK_cur = sK[None, None, None, Ki_index]\n                    if const_expr(self.uneven_kv_smem):\n                        sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase)\n                    # gemm_Si[stage](tCrB=tSrKi, sB=sK_cur)\n                    gemm_Si[stage](\n                        smem_desc_start_b=sm100_desc.make_smem_desc_start_addr(sK_cur.iterator)\n                    )\n                    # gemm_Si[stage](tCrB=tSrKi)\n                    # 4. release S0 / S1\n                    pipeline_s_p_o.producer_commit_w_index(stage)\n                mma_q_consumer_phase ^= 1\n                # 5. release K0\n                pipeline_kv.consumer_release(mma_kv_consumer_state)\n                mma_kv_consumer_state.advance()\n                # End of GEMM (Q1 * K0 -> S1)\n                # Note: Q0 & Q1 are still needed in the seqlen_kv loop\n                # so we need to release them after the seqlen_kv loop\n\n                # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate\n                block_loop_count = block_iter_count - 1\n                O_should_accumulate = False\n                for i in cutlass.range(block_loop_count, unroll=1):\n                    # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop\n                    # 1. wait for V0\n                    pipeline_kv.consumer_wait(mma_kv_consumer_state)\n                    mma_kv_release_state = mma_kv_consumer_state.clone()\n                    Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase\n                    tOrVi = tOrV[None, None, None, Vi_index]\n                    for stage in cutlass.range_constexpr(self.q_stage):\n                        # 2. 
acquire corrected O0/O1_partial and P0 / P1\n                        # For the first iteration in this work tile, waiting for O0/O1_partial\n                        # means that the correction warps have finished reading tO during\n                        # the last iteration of the previous work tile.\n                        pipeline_s_p_o.producer_acquire_w_index_phase(stage, P_full_O_rescaled_phase)\n                        # 3. gemm\n                        # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True)\n                        # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate)\n                        sV_cur = sV[None, None, None, Vi_index]\n                        if const_expr(self.uneven_kv_smem):\n                            sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase)\n                        gemm_Pi[stage](\n                            tCrB=tOrVi,\n                            sB=sV_cur,\n                            # smem_desc_start_b=sm100_desc.make_smem_desc_start_addr(sV_cur.iterator),\n                            zero_init=not O_should_accumulate,\n                            mbar_ptr=pipeline_p_lastsplit.sync_object_full.get_barrier(stage) if self.split_P_arrive > 0 else None,\n                            mbar_phase=P_full_O_rescaled_phase,\n                        )\n                        # Don't need to signal O_full to the correction warps since the\n                        # correction warps wait for the softmax warps anyway. By the time the softmax\n                        # warps finished, S_i for the next iteration must have been done, so O_i-1\n                        # must have been done as well.\n                        # pipeline_o_acc.producer_commit_w_index(stage)\n                        # 4. 
release V(i-1)\n                        if const_expr(stage == self.q_stage - 1):\n                            pipeline_kv.consumer_release(mma_kv_release_state)\n                            mma_kv_release_state.advance()\n                        # End of GEMM_PV00 (P0 * V0 -> O0_partial)\n\n                        # GEMM_QK0i (Q0 * Ki -> S0)\n                        # 1. wait for Ki\n                        if const_expr(stage == 0):\n                            mma_kv_consumer_state.advance()\n                            pipeline_kv.consumer_wait(mma_kv_consumer_state)\n                        Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase\n                        # 2. gemm\n                        # Don't need to wait for the softmax warp to have finished reading the previous\n                        # Si, since this gemm is scheduled after the PV gemm, which guarantees that Si\n                        # has been read and Pi has been written.\n                        # sm100_utils.gemm(tiled_mma_qk, tStS[None, None, None, stage], tSrQ[None, None, None, stage], tSrK[None, None, None, Ki_index], zero_init=True)\n                        sK_cur = sK[None, None, None, Ki_index]\n                        if const_expr(self.uneven_kv_smem):\n                            sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase)\n                        # gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur)\n                        gemm_Si[stage](\n                            smem_desc_start_b=sm100_desc.make_smem_desc_start_addr(sK_cur.iterator)\n                        )\n                        # gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index])\n                        # 3. release S0 / S1\n                        pipeline_s_p_o.producer_commit_w_index(stage)\n                        # End of GEMM_QK0i (Q0 * Ki -> S0)\n                    # 4. 
release Ki\n                    pipeline_kv.consumer_release(mma_kv_consumer_state)\n                    mma_kv_consumer_state.advance()\n                    P_full_O_rescaled_phase ^= 1\n                    O_should_accumulate = True\n                # End of seqlen_kv loop\n\n                # release Q0 & Q1\n                for stage in cutlass.range(self.q_stage):\n                    pipeline_q.consumer_release_w_index(stage)\n\n                # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop\n                # 1. wait for V0\n                pipeline_kv.consumer_wait(mma_kv_consumer_state)\n                Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase\n                tOrVi = tOrV[None, None, None, Vi_index]\n                for stage in cutlass.range_constexpr(self.q_stage):\n                    # 2. acquire corrected Oi_partial and Pi\n                    pipeline_s_p_o.producer_acquire_w_index_phase(stage, P_full_O_rescaled_phase)\n                    # 3. gemm\n                    # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True)\n                    # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate)\n                    sV_cur = sV[None, None, None, Vi_index]\n                    if const_expr(self.uneven_kv_smem):\n                        sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase)\n                    gemm_Pi[stage](\n                        tCrB=tOrVi,\n                        sB=sV_cur,\n                        # smem_desc_start_b=sm100_desc.make_smem_desc_start_addr(sV_cur.iterator),\n                        zero_init=not O_should_accumulate,\n                        mbar_ptr=pipeline_p_lastsplit.sync_object_full.get_barrier(stage) if self.split_P_arrive > 0 else None,\n                        mbar_phase=P_full_O_rescaled_phase,\n                    )\n                    # 4. 
release accumulated O0_partial\n                    # We do need O_full here since for the last tile, by the time the softmax warp\n                    # has signaled to the correction warps, the softmax warp has just finished\n                    # computing the row sum of the current tile. It does not guarantee that the 1st\n                    # tile of the next work tile has been computed yet.\n                    pipeline_o_acc.producer_commit_w_index(stage)\n                    # End of GEMM_PV00 (P0 * V0 -> O0_partial)\n                P_full_O_rescaled_phase ^= 1\n                # 5. release Vi_end\n                pipeline_kv.consumer_release(mma_kv_consumer_state)\n                mma_kv_consumer_state.advance()\n                # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1)\n\n            # Advance to next tile\n            tile_scheduler.advance_to_next_work()\n            work_tile = tile_scheduler.get_current_work()\n        # End of persistent scheduler loop\n\n        # We don't need pipeline_s_p_o.producer_tail() since there's no dangling mbarrier at the end\n        # pipeline_s_p_o.producer_acquire_w_index_phase(self.q_stage - 1, P_full_O_rescaled_phase)\n        # We don't need pipeline_o_acc.producer_tail() since we don't call\n        # pipeline_o_acc.producer_acquire() inside the loop.\n\n    # for both softmax0 and softmax1 warp group\n    @cute.jit\n    def softmax_loop(\n        self,\n        stage: int | Int32,\n        softmax_scale_log2: Float32,\n        softmax_scale: Float32,\n        thr_mma_qk: cute.core.ThrMma,\n        tStS: cute.Tensor,  # ((TILE_M, TILE_N), 1, 1, q_stage)\n        sScale: cute.Tensor,\n        mLSE: Optional[cute.Tensor],\n        pipeline_s_p_o: pipeline.PipelineAsync,\n        pipeline_p_lastsplit: pipeline.PipelineAsync,\n        pipeline_sm_stats: pipeline.PipelineAsync,\n        sm_stats_barrier: pipeline.NamedBarrier,\n        pipeline_s0_s1_sequence: Optional[pipeline.PipelineAsync],\n        
learnable_sink: Optional[cute.Tensor],\n        block_info: BlockInfo,\n        num_splits: Int32,\n        SeqlenInfoCls: Callable,\n        AttentionMaskCls: Callable,\n        TileSchedulerCls: Callable,\n        aux_tensors: Optional[list] = None,\n        fastdiv_mods=(None, None),\n        head_divmod=None,\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n    ):\n        \"\"\"Compute softmax on attention scores from QK matrix multiplication.\n\n        This method handles the softmax computation for either the first or second half of the\n        attention matrix, depending on the 'stage' parameter. It calculates row-wise maximum\n        and sum values needed for stable softmax computation, applies optional masking, and\n        transforms raw attention scores into probability distributions.\n\n        The implementation uses specialized memory access patterns and efficient math operations\n        for computing exp(x) using exp2 functions. It also coordinates pipeline\n        synchronization between MMA, correction, and sequence processing stages.\n        \"\"\"\n        tidx = cute.arch.thread_idx()[0] % (\n            cute.arch.WARP_SIZE\n            # * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids)\n            * (len(self.softmax0_warp_ids))\n        )\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4\n\n        cta_qk_tiler = (self.mma_tiler_qk[0] // thr_mma_qk.thr_id.shape, self.mma_tiler_qk[1])\n        tSAcc = tStS[(None, None), 0, 0, stage]  # (128, 128)\n        tStScale = cute.composition(tSAcc, cute.make_layout((self.m_block_size, 1)))\n        tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2]))\n        tScS = tScS[(None, None), 0, 0]  # (128, 128)\n        tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1)))\n\n        tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width\n        tStP_layout = 
cute.composition(\n            tSAcc.layout, cute.make_layout((self.m_block_size, tilePlikeFP32))\n        )\n        tStP = cute.make_tensor(tSAcc.iterator + self.tmem_s_to_p_offset, tStP_layout)\n\n        tmem_load_atom = cute.make_copy_atom(\n            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype\n        )\n        thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tSAcc).get_slice(tidx)\n        tStS_t2r = thr_tmem_load.partition_S(tSAcc)  # (((32,32),1),1,4)\n\n        tmem_store_scale_atom = cute.make_copy_atom(\n            tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), Float32\n        )\n        thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice(\n            tidx\n        )\n        tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale)\n        tmem_store_atom = cute.make_copy_atom(\n            tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32\n        )\n        thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx)\n        tStP_r2t = thr_tmem_store.partition_D(tStP)  # (((16,32),1),1,4)\n\n        mma_si_consumer_phase = Int32(0)\n        sm_stats_producer_phase = Int32(1)\n        s0_s1_sequence_phase = Int32(1 if stage == 0 else 0)\n\n        # self.warp_scheduler_barrier_init()\n\n        warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4\n\n        tile_scheduler = TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        while work_tile.is_valid_tile:\n            m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx\n            seqlen = SeqlenInfoCls(batch_idx)\n            n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)\n\n            mask = AttentionMaskCls(seqlen)\n            shared_mask_kwargs = dict(\n                m_block=(self.q_stage * m_block + stage) * self.cta_group_size,\n                
thr_mma=thr_mma_qk,\n                thr_tmem_load=thr_tmem_load,\n                mask_causal=self.is_causal,\n                mask_local=self.is_local,\n                batch_idx=batch_idx,\n                head_idx=head_idx,\n                aux_tensors=aux_tensors,\n            )\n\n            # Recompute fastdiv_mods if necessary\n            recompute_fastdiv_mods_q = cutlass.const_expr(\n                aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)\n            )\n            recompute_fastdiv_mods_k = cutlass.const_expr(\n                aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)\n            )\n\n            if cutlass.const_expr(fastdiv_mods is not None):\n                seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods\n                fastdiv_mods = (\n                    seqlen_q_divmod\n                    if not recompute_fastdiv_mods_q\n                    else FastDivmodDivisor(seqlen.seqlen_q),\n                    seqlen_k_divmod\n                    if not recompute_fastdiv_mods_k\n                    else FastDivmodDivisor(seqlen.seqlen_k),\n                )\n\n            mask_mod = self.mask_mod if const_expr(self.mask_mod is not None) else None\n            mask_fn = partial(\n                mask.apply_mask_sm100,\n                mask_mod=mask_mod,\n                fastdiv_mods=fastdiv_mods,\n                head_divmod=head_divmod,\n                **shared_mask_kwargs,\n            )\n            if const_expr(self.use_block_sparsity):\n                # Full blocks don't need mask_mod\n                mask_fn_none = partial(\n                    mask.apply_mask_sm100,\n                    mask_mod=None,\n                    fastdiv_mods=fastdiv_mods,\n                    head_divmod=head_divmod,\n                    **shared_mask_kwargs,\n                )\n            else:\n                mask_fn_none = None\n\n            softmax = SoftmaxSm100.create(\n             
   softmax_scale_log2,\n                rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0,\n                softmax_scale=softmax_scale,\n            )\n            softmax.reset()\n\n            if const_expr(self.use_block_sparsity):\n                tile_block_count = get_total_block_count(\n                    blocksparse_tensors,\n                    batch_idx,\n                    head_idx,\n                    m_block,\n                    self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n                    self.q_subtile_factor if self.q_subtile_factor is not None else 1,\n                )\n                has_work = tile_block_count > Int32(0)\n            else:\n                tile_block_count = n_block_max - n_block_min\n                has_work = const_expr(not self.is_split_kv) or tile_block_count > Int32(0)\n\n            softmax_step = partial(\n                self.softmax_step,\n                softmax=softmax,\n                thr_mma_qk=thr_mma_qk,\n                pipeline_s_p_o=pipeline_s_p_o,\n                pipeline_p_lastsplit=pipeline_p_lastsplit,\n                pipeline_sm_stats=pipeline_sm_stats,\n                sm_stats_barrier=sm_stats_barrier,\n                pipeline_s0_s1_sequence=pipeline_s0_s1_sequence,\n                thr_tmem_load=thr_tmem_load,\n                thr_tmem_store=thr_tmem_store,\n                thr_tmem_store_scale=thr_tmem_store_scale,\n                tStS_t2r=tStS_t2r,\n                tStScale_r2t=tStScale_r2t,\n                tStP_r2t=tStP_r2t,\n                sScale=sScale,\n                stage=stage,\n                batch_idx=batch_idx,\n                head_idx=head_idx,\n                m_block=(self.q_stage * m_block + stage) * self.cta_group_size,\n                seqlen=seqlen,\n                aux_tensors=aux_tensors,\n                fastdiv_mods=fastdiv_mods,\n                head_divmod=head_divmod,\n            )\n\n            if 
const_expr(self.use_block_sparsity) or has_work:\n                # See block_sparse_utils.py NOTE [SM100 block-sparse empty tiles: mbarrier contract].\n                pipeline_sm_stats.producer_acquire_w_index_phase(stage, sm_stats_producer_phase)\n                sm_stats_producer_phase ^= 1\n\n            # Block sparse or dense iteration\n            if const_expr(self.use_block_sparsity):\n                # When aux_tensors exist, Q indices beyond seqlen_q must be wrapped to avoid\n                # OOB aux_tensor access. Only edge tiles (where m_tile_end > seqlen_q) need this.\n                if const_expr(aux_tensors is not None):\n                    m_tile_end = ((self.q_stage * m_block + stage + 1) * self.cta_group_size) * self.m_block_size\n                    check_m_boundary = m_tile_end > seqlen.seqlen_q\n                else:\n                    check_m_boundary = False\n                (\n                    mma_si_consumer_phase,\n                    sm_stats_producer_phase,\n                    s0_s1_sequence_phase,\n                    empty_tile,\n                ) = softmax_block_sparse_sm100(\n                    blocksparse_tensors,\n                    batch_idx,\n                    head_idx,\n                    m_block,\n                    softmax_step,\n                    mask_fn,\n                    mask_fn_none,\n                    mma_si_consumer_phase,\n                    sm_stats_producer_phase,\n                    s0_s1_sequence_phase,\n                    pipeline_sm_stats,\n                    sm_stats_barrier,\n                    self.q_stage,\n                    Int32(stage),\n                    check_m_boundary,\n                    self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n                    self.q_subtile_factor if self.q_subtile_factor is not None else 1,\n                )\n                if not empty_tile:\n                    sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0]\n  
                  if const_expr(mLSE is not None or learnable_sink is not None):\n                        sScale[\n                            tidx + stage * self.m_block_size + self.q_stage * self.m_block_size\n                        ] = softmax.row_max[0]\n                    # if tidx == 0:\n                    #     cute.printf(\"softmax row sum stage %d: %f, row_max = %f\\n\", stage, softmax.row_sum[0], softmax.row_max[0])\n                    # See block_sparse_utils.py NOTE [SM100 block-sparse empty tiles: mbarrier contract].\n                    # pipeline_sm_stats.producer_commit_w_index(stage)\n                    sm_stats_barrier.arrive_w_index(index=stage * 4 + warp_idx)\n                    # if tidx == 0: cute.printf(\"softmax row sum stage %d: %f\\n\", stage, softmax.row_sum[0])\n            else:\n                if const_expr(not self.is_split_kv) or tile_block_count > Int32(0):\n                    mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = softmax_step(\n                        mma_si_consumer_phase,\n                        sm_stats_producer_phase,\n                        s0_s1_sequence_phase,\n                        n_block_max - 1,\n                        is_first=True,\n                        mask_fn=partial(mask_fn, mask_seqlen=True),\n                    )\n                    n_block_max -= 1\n                    # Next couple of iterations with causal masking\n                    if const_expr(self.is_causal or self.is_local):\n                        n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(\n                            seqlen, m_block, n_block_min\n                        )\n                        for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1):\n                            n_block = n_block_max - 1 - n_tile\n                            mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = (\n                         
       softmax_step(\n                                    mma_si_consumer_phase,\n                                    sm_stats_producer_phase,\n                                    s0_s1_sequence_phase,\n                                    n_block,\n                                    mask_fn=partial(mask_fn, mask_seqlen=False),\n                                )\n                            )\n                        n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask)\n                    # The remaining iterations have no masking (but may still need mask_mod)\n                    n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask(\n                        seqlen, m_block, n_block_min\n                    )\n                    for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1):\n                        n_block = n_block_max - n_tile - 1\n                        if const_expr(self.mask_mod is not None):\n                            mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = softmax_step(\n                                mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase, n_block,\n                                mask_fn=partial(mask_fn, mask_seqlen=False),\n                            )\n                        else:\n                            mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = softmax_step(\n                                mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase, n_block,\n                            )\n                    # Separate iterations with local masking on the left\n                    if const_expr(self.is_local and block_info.window_size_left is not None):\n                        n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask)\n                        for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1):\n                         
   n_block = n_block_max - 1 - n_tile\n                            mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = (\n                                softmax_step(\n                                    mma_si_consumer_phase,\n                                    sm_stats_producer_phase,\n                                    s0_s1_sequence_phase,\n                                    n_block,\n                                    mask_fn=partial(mask_fn, mask_seqlen=False),\n                                )\n                            )\n                            # Now that we no longer already have the 1st iteration, need mask_seqlen=True here\n\n                    # Dense path always writes scale / signals\n                    sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0]\n                    if const_expr(mLSE is not None or learnable_sink is not None):\n                        sScale[\n                            tidx + stage * self.m_block_size + self.q_stage * self.m_block_size\n                        ] = softmax.row_max[0]\n                    # pipeline_sm_stats.producer_commit_w_index(stage)\n                    sm_stats_barrier.arrive_w_index(index=stage * 4 + warp_idx)\n\n            # # Write LSE to gmem\n            # if const_expr(mLSE is not None):\n            #     acc_O_mn_row_is_zero_or_nan = softmax.row_sum[0] == 0.0 or softmax.row_sum[0] != softmax.row_sum[0]\n            #     scale = (\n            #         cute.arch.rcp_approx(softmax.row_sum[0] if not acc_O_mn_row_is_zero_or_nan else 1.0)\n            #     )\n            #     LN2 = math.log(2.0)\n            #     lse = (\n            #         (softmax.row_max[0] * softmax.scale_log2 + cute.math.log2(softmax.row_sum[0], fastmath=True)) * LN2\n            #         if not acc_O_mn_row_is_zero_or_nan else -Float32.inf\n            #     )\n            #     if const_expr(not seqlen.has_cu_seqlens_q):\n            #         mLSE_cur = mLSE[None, 
head_idx, batch_idx]\n            #     else:\n            #         mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx])\n            #     gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2 + stage,))\n            #     if tidx < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size:\n            #         gLSE[tidx] = lse\n\n            # Advance to next tile\n            tile_scheduler.advance_to_next_work()\n            work_tile = tile_scheduler.get_current_work()\n        # End of persistent scheduler loop\n\n        # This is equivalent to pipeline_sm_stats.producer_tail\n        pipeline_sm_stats.producer_acquire_w_index_phase(stage, sm_stats_producer_phase)\n        # This is equivalent to pipeline_s0_s1.producer_tail\n        if const_expr(self.s0_s1_barrier):\n            if stage == 0:\n                pipeline_s0_s1_sequence.sync_object_full.wait(stage, s0_s1_sequence_phase)\n\n    @cute.jit\n    def softmax_step(\n        self,\n        mma_si_consumer_phase: Int32,\n        sm_stats_producer_phase: Int32,\n        s0_s1_sequence_phase: Int32,\n        n_block: Int32,\n        softmax: SoftmaxSm100,\n        thr_mma_qk: cute.core.ThrMma,\n        pipeline_s_p_o: pipeline.PipelineAsync,\n        pipeline_p_lastsplit: pipeline.PipelineAsync,\n        pipeline_sm_stats: pipeline.PipelineAsync,\n        sm_stats_barrier: pipeline.NamedBarrier,\n        pipeline_s0_s1_sequence: Optional[pipeline.PipelineAsync],\n        thr_tmem_load: cute.CopyAtom,\n        thr_tmem_store: cute.CopyAtom,\n        thr_tmem_store_scale: cute.CopyAtom,\n        tStS_t2r: cute.Tensor,\n        tStScale_r2t: cute.Tensor,\n        tStP_r2t: cute.Tensor,\n        sScale: cute.Tensor,\n        stage: int | Int32,\n        batch_idx: Int32,\n        head_idx: Int32,\n        m_block: Int32,\n        seqlen,\n        aux_tensors: Optional[list] = None,\n        fastdiv_mods=(None, None),\n        head_divmod=None,\n        mask_fn: 
Optional[Callable] = None,\n        is_first: bool = False,\n    ) -> Tuple[cute.Int32, cute.Int32, cute.Int32]:\n        \"\"\"Perform a single step of the softmax computation on a block of attention scores.\n\n        This method processes one block of the attention matrix, computing numerically stable\n        softmax by first finding the row maximum, subtracting it from all elements, applying\n        exponential function, and then normalizing by the sum of exponentials. It also handles\n        optional masking of attention scores.\n\n        The method involves several key operations:\n        1. Loading attention scores from tensor memory\n        2. Applying optional masking based on position\n        3. Computing row-wise maximum values for numerical stability\n        4. Transforming scores using exp2(x*scale - max*scale)\n        5. Computing row sums for normalization\n        6. Coordinating pipeline synchronization between different processing stages\n        \"\"\"\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4\n        tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width\n        tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2]))\n        tScS = tScS[(None, None), 0, 0]  # (128, 128)\n        # tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1)))\n        cta_qk_tiler = (self.mma_tiler_qk[0] // thr_mma_qk.thr_id.shape, self.mma_tiler_qk[1])\n        tScS_shape = cta_qk_tiler  # (128, 128)\n        tScP_shape = (tScS_shape[0], tilePlikeFP32)  # (128, 64)\n\n        # Wait for Si\n        pipeline_s_p_o.consumer_wait_w_index_phase(stage, mma_si_consumer_phase)\n        tSrS_t2r = cute.make_fragment(thr_tmem_load.partition_D(tScS).shape, self.qk_acc_dtype)\n        cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r)\n        # tSrS_t2r = copy_utils.load_t2r(thr_tmem_load, tScS_shape, tStS_t2r)\n        if cutlass.const_expr(self.score_mod is not None):\n       
     self.apply_score_mod(\n                tSrS_t2r,\n                thr_tmem_load,\n                thr_mma_qk,\n                batch_idx,\n                head_idx,\n                m_block,\n                n_block,\n                softmax,\n                seqlen,\n                aux_tensors,\n                fastdiv_mods,\n                head_divmod,\n            )\n\n        if const_expr(mask_fn is not None):\n            mask_fn(tSrS_t2r, n_block=n_block)\n        row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first)\n\n        if const_expr(not is_first):\n            # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScScale).shape, Float32)\n            # tSrScale_r2t[0] = acc_scale\n            # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t)\n            # cute.arch.fence_view_async_tmem_store()\n            thread_idx = thr_tmem_load.thr_idx\n            sScale[thread_idx + stage * self.m_block_size] = acc_scale\n            # if thread_idx == 0: cute.printf(\"softmax acc_scale stage %d: %f, row_max = %f\\n\", stage, acc_scale, row_max)\n        # Notify correction wg that row_max is ready\n        # pipeline_sm_stats.producer_commit_w_index(stage)\n        sm_stats_barrier.arrive_w_index(index=stage * 4 + warp_idx)\n\n        # if thread_idx == 0 and stage == 0: cute.print_tensor(tSrS_t2r)\n        softmax.scale_subtract_rowmax(tSrS_t2r, row_max)\n        # Sequence barrier wait\n        if const_expr(self.s0_s1_barrier):\n            pipeline_s0_s1_sequence.sync_object_full.wait(stage, s0_s1_sequence_phase)\n        tSrP_r2t_f32 = cute.make_fragment(\n            thr_tmem_store.partition_S(cute.make_identity_tensor(tScP_shape)).shape, Float32\n        )\n        tSrP_r2t = cute.make_tensor(\n            cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout\n        )\n        # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t)\n        
softmax.apply_exp2_convert(\n            tSrS_t2r,\n            tSrP_r2t,\n            ex2_emu_freq=self.ex2_emu_freq if const_expr(mask_fn is None) else 0,\n            ex2_emu_start_frg=self.ex2_emu_start_frg,\n        )\n        # Sequence barrier arrive\n        if const_expr(self.s0_s1_barrier):\n            pipeline_s0_s1_sequence.sync_object_full.arrive(1 - stage, dst=None)\n        # print(tSrP_r2t_f32, tStP_r2t)\n        # cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t)\n        for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])):\n            cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i])\n            if const_expr(self.split_P_arrive > 0):\n                split_P_arrive_idx = cute.size(tStP_r2t.shape[2]) * self.split_P_arrive // self.n_block_size\n                if const_expr(i + 1 == split_P_arrive_idx):\n                    # Notify mma warp that the 1st half of P is ready\n                    cute.arch.fence_view_async_tmem_store()\n                    pipeline_s_p_o.consumer_release_w_index(stage)\n        # Notify mma warp that the 2nd half of P is ready\n        cute.arch.fence_view_async_tmem_store()\n        if const_expr(self.split_P_arrive > 0):\n            cute.arch.sync_warp()\n            with cute.arch.elect_one():\n                pipeline_p_lastsplit.producer_commit_w_index(stage)\n        else:\n            pipeline_s_p_o.consumer_release_w_index(stage)\n        pipeline_sm_stats.producer_acquire_w_index_phase(stage, sm_stats_producer_phase)\n        softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first)\n        # acc_scale = cute.math.exp2(acc_scale_, fastmath=True)\n        return mma_si_consumer_phase ^ 1, sm_stats_producer_phase ^ 1, s0_s1_sequence_phase ^ 1\n\n    @cute.jit\n    def correction_loop(\n        self,\n        thr_mma_qk: cute.core.ThrMma,\n        thr_mma_pv: cute.core.ThrMma,\n        tStS: cute.Tensor,\n        tOtO: cute.Tensor,\n        sScale: cute.Tensor,\n  
      mO: cute.Tensor,\n        mLSE: cute.Tensor,\n        sO: cute.Tensor,\n        pipeline_s_p_o: pipeline.PipelineAsync,\n        pipeline_o_acc: pipeline.PipelineAsync,\n        pipeline_sm_stats: pipeline.PipelineAsync,\n        sm_stats_barrier: pipeline.NamedBarrier,\n        pipeline_o_epi: pipeline.PipelineAsync,\n        learnable_sink: Optional[cute.Tensor],\n        gmem_tiled_copy_O: cute.TiledCopy,\n        tma_atom_O: cute.CopyAtom,\n        softmax_scale_log2: Float32,\n        block_info: BlockInfo,\n        num_splits: Int32,\n        SeqlenInfoCls: Callable,\n        TileSchedulerCls: Callable,\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n    ):\n        tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids))\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4\n        mma_tile_coord_v = thr_mma_qk.thr_idx\n\n        tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2]))\n        tStScale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1)))\n        tStScales = tuple(\n            cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStScale_layout)\n            for stage in range(self.q_stage)\n        )\n        tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1)))\n        tmem_load_v_atom = cute.make_copy_atom(\n            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), self.qk_acc_dtype\n        )\n        thr_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]).get_slice(tidx)\n\n        tStScales_t2r = [thr_tmem_load_vec.partition_S(tStScales[stage]) for stage in range(self.q_stage)]\n        tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScScale).shape\n\n        # First iter: no correction is required\n        # Notify mma warp that O has been rescaled\n        for stage in cutlass.range(self.q_stage):\n            
pipeline_s_p_o.consumer_release_w_index(stage)\n\n        sm_stats_consumer_phase = Int32(0)\n        o_corr_consumer_phase = Int32(0)\n        corr_epi_producer_phase = Int32(1)\n\n        tile_scheduler = TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        while work_tile.is_valid_tile:\n            m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx\n            seqlen = SeqlenInfoCls(batch_idx)\n            n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)\n\n            if const_expr(self.is_split_kv):\n                mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]\n            else:\n                mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]\n            tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)\n            gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0))  # (128 * 2, 128)\n            gO = layout_utils.select(\n                cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]\n            )  # (128, 128, 2)\n            gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]\n\n            # Default LSE to -inf for invalid split_idx tiles\n            stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage\n\n            if const_expr(self.use_block_sparsity):\n                total_block_count = get_total_block_count(\n                    blocksparse_tensors,\n                    batch_idx,\n                    head_idx,\n                    m_block,\n                    self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n                    self.q_subtile_factor if self.q_subtile_factor is not None else 1,\n                )\n                has_work = total_block_count > Int32(0)\n            else:\n                
total_block_count = n_block_max - n_block_min\n                has_work = const_expr(not self.is_split_kv) or total_block_count > Int32(0)\n\n            if has_work:\n                # Ignore first signal from softmax as no correction is required\n                # pipeline_sm_stats.consumer_wait_w_index_phase(0, sm_stats_consumer_phase)\n                sm_stats_barrier.arrive_and_wait_w_index(index=0 * 4 + warp_idx)\n                pipeline_sm_stats.consumer_release_w_index(0)\n                if const_expr(self.q_stage == 2):\n                    # pipeline_sm_stats.consumer_wait_w_index_phase(1, sm_stats_consumer_phase)\n                    sm_stats_barrier.arrive_and_wait_w_index(index=1 * 4 + warp_idx)\n                sm_stats_consumer_phase ^= 1\n\n                tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32)\n                for i in cutlass.range(total_block_count - 1, unroll=1):\n                    for stage in cutlass.range_constexpr(self.q_stage):\n                        # wait for S0 / S1\n                        # pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase)\n                        sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx)\n                        # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r)\n                        # cute.arch.fence_view_async_tmem_load()\n                        # scale = tSrScale_t2r[0]\n                        scale = sScale[tidx + stage * self.m_block_size]\n                        should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0\n                        # should_rescale = True\n                        # if tidx == 0: cute.printf(\"Correction scale i = %d, for stage %d: %f, should_rescale = %d\\n\", i, stage, scale, should_rescale)\n                        # Don't need O_full anymore, since by the time softmax has signaled the correction\n                        # warps, S_i must have been done, so O_i-1 must 
have been done as well.\n                        # pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase)\n                        if should_rescale:\n                            self.correction_rescale(thr_mma_pv, tOtO[None, None, None, stage], tidx, scale)\n                        # Notify mma warp that O has been rescaled\n                        pipeline_s_p_o.consumer_release_w_index(stage)\n                        pipeline_sm_stats.consumer_release_w_index(self.q_stage - 1 - stage)\n                    sm_stats_consumer_phase ^= 1\n                    # o_corr_consumer_phase ^= 1\n                if const_expr(self.q_stage == 2):\n                    pipeline_sm_stats.consumer_release_w_index(1)\n                # End of seqlen_corr_loop_steps\n\n                # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without\n                # additional sync because the MMA in the top half must have been done.\n                # Similarly we can write to stage 1 of sO without additional sync.\n                learnable_sink_val = [None] * self.q_stage\n                if const_expr(learnable_sink is not None):\n                    if const_expr(not self.pack_gqa):\n                        sink_val = Float32(learnable_sink[head_idx])\n                        learnable_sink_val = [sink_val] * self.q_stage\n                    else:  # Each thread might have a different sink value due to different q_head\n                        for stage in cutlass.range_constexpr(self.q_stage):\n                            q_head_idx = (\n                                ((m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v) * self.m_block_size + tidx\n                            ) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead\n                            learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx])\n                for stage in cutlass.range_constexpr(self.q_stage):\n                    # 
pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase)\n                    sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx)\n                    # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r)\n                    # cute.arch.fence_view_async_tmem_load()\n                    # scale = tSrScale_t2r[0]\n                    row_sum = sScale[tidx + stage * self.m_block_size]\n                    if const_expr(mLSE is not None or learnable_sink is not None):\n                        row_max = sScale[tidx + stage * self.m_block_size + self.q_stage * self.m_block_size]\n                    else:\n                        row_max = None\n                    pipeline_sm_stats.consumer_release_w_index(stage)\n                    if const_expr(learnable_sink is not None):\n                        LOG2_E = math.log2(math.e)\n                        sink_val = learnable_sink_val[stage]\n                        if const_expr(not self.is_split_kv) or split_idx == 0:\n                            if row_max == -Float32.inf:\n                                # It's possible to have an empty row with splitKV.\n                                row_max = sink_val * (LOG2_E / softmax_scale_log2)\n                                row_sum = Float32(1.0)\n                            else:\n                                row_sum += cute.math.exp2(\n                                    sink_val * LOG2_E - row_max * softmax_scale_log2, fastmath=True\n                                )\n                    acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum\n                    stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan)\n                    scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0)\n                    # Wait for the last O to be ready from the MMA warp\n                    pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase)\n               
     if const_expr(not self.use_correction_warps_for_epi):\n                        pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase)\n                    self.correction_epilogue(\n                        thr_mma_pv,\n                        tOtO[None, None, None, stage],\n                        tidx,\n                        stage,\n                        m_block,\n                        seqlen.seqlen_q,\n                        scale,\n                        sO[None, None, stage],\n                        mO_cur,\n                        gO[None, None, stage],\n                        gmem_tiled_copy_O,\n                    )\n                    # Signal for the next work tile that O buffers in tmem are already read, so\n                    # mma warp can write to them\n                    pipeline_s_p_o.consumer_release_w_index(stage)\n                    if const_expr(not self.use_correction_warps_for_epi):\n                        pipeline_o_epi.producer_commit_w_index(stage)\n                    # if tidx == 0: cute.printf(\"Correction final scale for stage %d: %f\\n\", stage, scale)\n\n                o_corr_consumer_phase ^= 1\n                sm_stats_consumer_phase ^= 1\n                corr_epi_producer_phase ^= 1\n            else:\n                gmem_tiled_copy_O_for_empty_tile = None\n                if const_expr(self.use_correction_warps_for_epi):\n                    gmem_tiled_copy_O_for_empty_tile = gmem_tiled_copy_O\n                if const_expr(self.use_block_sparsity):\n                    (\n                        sm_stats_consumer_phase,\n                        o_corr_consumer_phase,\n                        corr_epi_producer_phase,\n                    ) = handle_block_sparse_empty_tile_correction_sm100(\n                        tidx,\n                        self.q_stage,\n                        self.m_block_size,\n                        self.qhead_per_kvhead,\n                        
self.pack_gqa,\n                        self.is_split_kv,\n                        learnable_sink,\n                        mLSE,\n                        seqlen,\n                        m_block,\n                        head_idx,\n                        batch_idx,\n                        split_idx,\n                        sScale,\n                        stats,\n                        self.correction_epilogue,\n                        thr_mma_pv,\n                        tOtO,\n                        sO,\n                        pipeline_sm_stats,\n                        sm_stats_barrier,\n                        pipeline_o_epi,\n                        sm_stats_consumer_phase,\n                        o_corr_consumer_phase,\n                        corr_epi_producer_phase,\n                        softmax_scale_log2,\n                        mO_cur,\n                        gO,\n                        gmem_tiled_copy_O_for_empty_tile,\n                    )\n\n            if const_expr(mLSE is not None):\n                if const_expr(not seqlen.has_cu_seqlens_q):\n                    if const_expr(self.is_split_kv):\n                        mLSE_cur = mLSE[None, head_idx, batch_idx, split_idx]\n                    else:\n                        mLSE_cur = mLSE[None, head_idx, batch_idx]\n                else:\n                    offset = (\n                        seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q)\n                    )\n                    if const_expr(self.is_split_kv):\n                        mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx, split_idx])\n                    else:\n                        mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx])\n                for stage in cutlass.range_constexpr(self.q_stage):\n                    m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v\n                    gLSE = cute.local_tile(mLSE_cur, 
(self.m_block_size,), (m_tile_idx,))\n                    row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage]\n                    # if tidx == 0 and stage <= 1:\n                    #     cute.printf(\"row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\\n\", row_sum, row_max, acc_O_mn_row_is_zero_or_nan)\n                    LN2 = math.log(2.0)\n                    lse = (\n                        (row_max * softmax_scale_log2 + cute.math.log2(row_sum, fastmath=True)) * LN2\n                        if not acc_O_mn_row_is_zero_or_nan\n                        else -Float32.inf\n                    )\n                    seqlen_q = (\n                        seqlen.seqlen_q\n                        if const_expr(not self.pack_gqa)\n                        else seqlen.seqlen_q * self.qhead_per_kvhead\n                    )\n                    if tidx < seqlen_q - m_tile_idx * self.m_block_size:\n                        # This actually just works with PackGQA too\n                        gLSE[tidx] = lse\n\n            # Advance to next tile\n            tile_scheduler.advance_to_next_work()\n            work_tile = tile_scheduler.get_current_work()\n        # End of persistent scheduler loop\n\n        # This is equivalent to pipeline_o_epi.consumer_tail() for the correction warps\n        if const_expr(not self.use_correction_warps_for_epi):\n            pipeline_o_epi.producer_acquire_w_index_phase(self.q_stage - 1, corr_epi_producer_phase)\n\n    @cute.jit\n    def correction_rescale(\n        self,\n        thr_mma: cute.core.ThrMma,\n        tOtO: cute.Tensor,\n        tidx: Int32,\n        scale: Float32,\n    ):\n        \"\"\"Rescale intermediate attention results based on softmax normalization factor.\n\n        This method performs a crucial correction step in the attention computation pipeline.\n        When processing attention in blocks, the softmax normalization factors may change\n        as new blocks are processed. 
This method rescales previously computed partial\n        output values to account for updated normalization factors.\n\n        The implementation uses efficient tensor memory operations to:\n        1. Load existing partial attention output from tensor memory\n        2. Apply the scaling factor to all elements\n        3. Store the rescaled results back to tensor memory\n        \"\"\"\n        tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2]))\n        corr_tile_size = 16  # tuneable parameter\n        tmem_load_atom = cute.make_copy_atom(\n            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.pv_acc_dtype\n        )\n        tmem_store_atom = cute.make_copy_atom(\n            tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),\n            self.pv_acc_dtype,\n        )\n        tOtO_i = cute.composition(tOtO, cute.make_layout((self.m_block_size, corr_tile_size)))\n        tOcO_i = cute.composition(tOcO, cute.make_layout((self.m_block_size, corr_tile_size)))\n        thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i).get_slice(tidx)\n        thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i).get_slice(tidx)\n        tOtO_t2r = thr_tmem_load.partition_S(tOtO_i)\n        tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape\n        tOtO_r2t = thr_tmem_store.partition_D(tOtO_i)\n\n        frg_count = self.head_dim_v_padded // corr_tile_size\n        tOrO_frg = cute.make_fragment((tOrO_t2r_shape, frg_count), self.pv_acc_dtype)\n        for i in cutlass.range_constexpr(frg_count):\n            tOrO_frg = cute.make_fragment(tOrO_t2r_shape, self.pv_acc_dtype)\n            tOtO_t2r_i = cute.make_tensor(tOtO_t2r.iterator + i * corr_tile_size, tOtO_t2r.layout)\n            cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg)\n            for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True):\n                tOrO_frg[j], tOrO_frg[j + 1] = 
cute.arch.mul_packed_f32x2(\n                    (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale)\n                )\n            tOtO_r2t_i = cute.make_tensor(tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout)\n            cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i)\n        cute.arch.fence_view_async_tmem_store()\n\n    @cute.jit\n    def correction_epilogue(\n        self,\n        thr_mma: cute.core.ThrMma,\n        tOtO: cute.Tensor,\n        tidx: Int32,\n        stage: Int32,\n        m_block: Int32,\n        seqlen_q: Int32,\n        scale: Float32,\n        sO: cute.Tensor,\n        mO_cur: Optional[cute.Tensor] = None,\n        gO: Optional[cute.Tensor] = None,\n        gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,\n    ):\n        \"\"\"Apply final scaling and transformation to attention output before writing to global memory.\n\n        This correction_epilogue function handles the final processing step for attention output values.\n        It applies a scaling factor to the accumulated attention results and prepares the\n        data for efficient transfer back to global memory.\n\n        The method performs:\n        1. Loading of accumulated attention results from tensor memory\n        2. Application of the final output scaling factor\n        3. Type conversion if necessary (typically from higher precision accumulator to output precision)\n        4. Reorganization of data for optimal memory access patterns\n        5. 
Preparation for efficient TMA store operations\n\n        :param thr_mma: Thread MMA operation for the computation\n        :type thr_mma: cute.core.ThrMma\n        :param tOtO: Tensor containing accumulated attention output\n        :type tOtO: cute.Tensor\n        :param scale: Final scaling factor to apply to the output\n        :type scale: Float32\n        :param sO: Shared memory tensor for the final output\n        :type sO: cute.Tensor\n        \"\"\"\n\n        corr_tile_size = 8 * 32 // self.o_dtype.width\n        # Use CTA 0 mapping for smem partitioning since sO is per-CTA sized\n        tOsO = thr_mma.get_slice(0).partition_C(sO)\n        tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2]))\n\n        tOtO_i = cute.logical_divide(tOtO, cute.make_layout((self.m_block_size, corr_tile_size)))\n        tOcO_i = cute.logical_divide(tOcO, cute.make_layout((self.m_block_size, corr_tile_size)))\n        tOsO_i = cute.logical_divide(tOsO, cute.make_layout((self.m_block_size, corr_tile_size)))\n\n        epi_subtile = (self.epi_tile[0], corr_tile_size)\n        tmem_copy_atom = sm100_utils_basic.get_tmem_load_op(\n            self.mma_tiler_pv,\n            self.o_layout,\n            self.o_dtype,\n            self.pv_acc_dtype,\n            epi_subtile,\n            use_2cta_instrs=self.use_2cta_instrs,\n        )\n        tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0])\n        thr_tmem_load = tiled_tmem_load.get_slice(tidx)\n        smem_copy_atom = sm100_utils_basic.get_smem_store_op(\n            self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load\n        )\n        tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load)\n\n        tOtO_t2r = thr_tmem_load.partition_S(tOtO_i[(None, None), None])\n        tOsO_s2r = copy_utils.partition_D_position_independent(thr_tmem_load, tOsO_i[(None, None), None])\n        tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), 
None])\n        for i in cutlass.range(self.head_dim_v_padded // corr_tile_size, unroll_full=True):\n            tOtO_t2r_i = tOtO_t2r[None, 0, 0, i]\n            tOsO_r2s_i = tOsO_s2r[None, 0, 0, i]\n            tOrO_frg = cute.make_fragment(tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype)\n            cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg)\n            for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True):\n                tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2(\n                    (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale)\n                )\n            copy_utils.cvt_copy(tiled_smem_store, tOrO_frg, tOsO_r2s_i)\n        cute.arch.fence_view_async_shared()\n\n        if const_expr(self.use_correction_warps_for_epi):\n            assert(not self.use_tma_O)\n            assert(gmem_tiled_copy_O is not None)\n            cute.arch.barrier(barrier_id=int(NamedBarrierFwdSm100.Epilogue),\n                              number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE)\n            mma_tile_coord_v = thr_mma.thr_idx\n            m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v\n            self._store_O_to_gmem(\n                sO, gO, mO_cur, gmem_tiled_copy_O, tidx, seqlen_q, m_tile_idx\n            )\n\n    @cute.jit\n    def _store_O_to_gmem(\n        self,\n        sO_stage: cute.Tensor,\n        gO: cute.Tensor,\n        mO_cur: cute.Tensor,\n        gmem_tiled_copy_O: cute.TiledCopy,\n        tidx: Int32,\n        seqlen_q: Int32,\n        m_tile_idx: Int32,\n    ):\n        \"\"\"Copy a single stage of O from smem to gmem via registers.\"\"\"\n        gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)\n        tOsO = gmem_thr_copy_O.partition_S(sO_stage)\n        cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))\n        tOgO = gmem_thr_copy_O.partition_D(gO)\n        tOcO = gmem_thr_copy_O.partition_S(cO)\n        t0OcO = 
gmem_tiled_copy_O.get_slice(0).partition_S(cO)\n        tOpO = copy_utils.predicate_k(tOcO, limit=mO_cur.shape[1])\n        pack_gqa = PackGQA(\n            self.m_block_size,\n            self.head_dim_v_padded,\n            self.check_hdim_v_oob,\n            self.qhead_per_kvhead,\n        )\n\n        # load acc O from smem to rmem for wider vectorization\n        tOrO = cute.make_fragment_like(tOsO, self.o_dtype)\n        cute.autovec_copy(tOsO, tOrO)\n        # copy acc O from rmem to gmem\n        if const_expr(not self.pack_gqa):\n            for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):\n                if (\n                    t0OcO[0, rest_m, 0][0] < seqlen_q - m_tile_idx * self.m_block_size - tOcO[0][0]\n                ):\n                    cute.copy(\n                        gmem_tiled_copy_O,\n                        tOrO[None, rest_m, None],\n                        tOgO[None, rest_m, None],\n                        pred=tOpO[None, rest_m, None]\n                        if const_expr(self.check_hdim_v_oob)\n                        else None,\n                    )\n        else:\n            pack_gqa.store_O(\n                mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_tile_idx, seqlen_q\n            )\n\n    @cute.jit\n    def epilogue_s2g(\n        self,\n        mO: cute.Tensor,\n        sO: cute.Tensor,\n        gmem_tiled_copy_O: cute.TiledCopy,\n        tma_atom_O: Optional[cute.CopyAtom],\n        pipeline_o_epi: pipeline.PipelineAsync,\n        block_info: BlockInfo,\n        num_splits: int,\n        SeqlenInfoCls: Callable,\n        TileSchedulerCls: Callable,\n        mma_tile_coord_v: Int32 = 0,\n    ):\n        epi_consumer_phase = Int32(0)\n        tile_scheduler = TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        while work_tile.is_valid_tile:\n            m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx\n            seqlen = SeqlenInfoCls(batch_idx)\n            
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)\n\n            if const_expr(not self.is_split_kv) or n_block_min < n_block_max:\n                if const_expr(self.is_split_kv):\n                    mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]\n                else:\n                    mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]\n                tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)\n                gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0))  # (128 * 2, 128)\n                gO = layout_utils.select(\n                    cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]\n                )  # (128, 128, 2)\n                gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]\n\n                if const_expr(self.use_tma_O):\n                    store_O, _, _ = copy_utils.tma_get_copy_fn(\n                        tma_atom_O, 0, cute.make_layout(1), sO, gO\n                    )\n                    for stage in cutlass.range(self.q_stage, unroll_full=True):\n                        # wait from corr, issue tma store on smem\n                        # 1. wait for O0 / O1 final\n                        pipeline_o_epi.consumer_wait_w_index_phase(stage, epi_consumer_phase)\n                        # 2. 
copy O0 / O1 to gmem\n                        store_O(src_idx=stage, dst_idx=stage)\n                        cute.arch.cp_async_bulk_commit_group()\n                    for stage in cutlass.range_constexpr(self.q_stage):\n                        # Ensure O0 / O1 buffer is ready to be released\n                        cute.arch.cp_async_bulk_wait_group(self.q_stage - 1 - stage, read=True)\n                        pipeline_o_epi.consumer_release_w_index(stage)\n                else:\n                    tidx = cute.arch.thread_idx()[0] % (\n                        cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)\n                    )\n                    for stage in cutlass.range_constexpr(self.q_stage):\n                        # wait from corr, then copy smem to gmem via registers (no TMA store on this path)\n                        # 1. wait for O0 / O1 final\n                        pipeline_o_epi.consumer_wait_w_index_phase(stage, epi_consumer_phase)\n                        # 2. copy O0 / O1 to gmem\n                        m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v\n                        self._store_O_to_gmem(\n                            sO[None, None, stage], gO[None, None, stage], mO_cur, gmem_tiled_copy_O,\n                            tidx, seqlen.seqlen_q, m_tile_idx,\n                        )\n                        pipeline_o_epi.consumer_release_w_index(stage)\n\n                epi_consumer_phase ^= 1\n\n            # Advance to next tile\n            tile_scheduler.advance_to_next_work()\n            work_tile = tile_scheduler.get_current_work()\n\n    def load_Q(\n        self,\n        load_Q_fn: Callable,\n        pipeline_q: pipeline.PipelineAsync,\n        block: Int32,\n        stage: int,\n        phase: Int32,\n    ):\n        pipeline_q.producer_acquire_w_index_phase(stage, phase)\n        load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage))\n\n    @cute.jit\n    def load_KV(\n        
self,\n        tma_atom: Optional[cute.CopyAtom],\n        tXgX: Optional[cute.Tensor],\n        tXsX: Optional[cute.Tensor],\n        paged_kv_manager: Optional[PagedKVManager],\n        sX: cute.Tensor,\n        block: Int32,\n        pipeline_kv: pipeline.PipelineAsync,\n        producer_state: pipeline.PipelineState,\n        K_or_V: Literal[\"K\", \"V\"],\n        page_idx: Optional[Int32] = None,\n        extra_tx_count: Optional[Int32] = None,\n    ):\n        assert K_or_V in (\"K\", \"V\")\n        stage, phase = producer_state.index, producer_state.phase\n        extra_tx_count_kv = self.tma_copy_bytes[K_or_V] - self.tma_copy_bytes[\"K\"]\n        extra_tx_count = (\n            extra_tx_count_kv + (extra_tx_count if extra_tx_count is not None else 0) if const_expr(self.use_tma_KV)\n            else None\n        )\n        extra_kwargs = {\"extra_tx_count\": extra_tx_count} if const_expr(self.use_tma_KV) else {}\n        pipeline_kv.producer_acquire(producer_state, **extra_kwargs)\n        if const_expr(K_or_V == \"K\" and self.uneven_kv_smem):\n            # Before this round, the smem location was occupied by V, which is smaller than\n            # K. 
So we need to wait for the stage after that (stage 1) to be empty as well.\n            if stage == 0:\n                pipeline_kv.sync_object_empty.wait(1, phase)\n\n        if const_expr(self.use_tma_KV):\n            assert tXgX is not None and tXsX is not None and tma_atom is not None\n            tXsX_cur = tXsX[None, stage]\n            if const_expr(self.uneven_kv_smem):\n                # Since this is the producer_state, the phase starts at 1, so we have to invert it\n                tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1)\n            # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0\n            tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx]\n            cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=pipeline_kv.producer_get_barrier(producer_state))\n        else:\n            assert paged_kv_manager is not None\n            assert extra_tx_count is None\n            sX_cur = sX[None, None, None, stage]\n            if const_expr(self.uneven_kv_smem):\n                sX_cur = self.offset_kv_smem(sX_cur, stage, phase ^ 1)\n            paged_kv_manager.load_KV(block, sX_cur, K_or_V)\n            cute.arch.cp_async_commit_group()\n            pipeline_kv.sync_object_full.arrive_cp_async_mbarrier(stage)\n\n    @cute.jit\n    def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32):\n        if const_expr(self.uneven_kv_smem):\n            # smem layout is [smem_large, smem_small, smem_large], and the current stride is\n            # (smem_large + smem_small) // 2. 
So for stage == 1, move right by offset if\n            # phase == 0, or left by offset if phase == 1.\n            offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase)\n            # Hint that the offset is 128-bit aligned so that\n            # ptr + offset preserves the alignment needed by cp.async.\n            offset = cute.assume(offset, divby=128 // self.k_dtype.width)\n            return cute.make_tensor(sX.iterator + offset, sX.layout)\n        else:\n            return sX\n\n    # @cute.jit\n    # def warp_scheduler_barrier_init(self):\n    #     warp_group_idx = utils.canonical_warp_group_idx(sync=False)\n    #     if warp_group_idx == 0:\n    #         cute.arch.barrier_arrive(\n    #             barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1), number_of_threads=2 * 128,\n    #         )\n\n    # def warp_scheduler_barrier_sync(self):\n    #     cute.arch.barrier(\n    #         barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False),\n    #         number_of_threads=2 * 128\n    #     )\n\n    # def warp_scheduler_barrier_arrive(self):\n    #     cur_wg = utils.canonical_warp_group_idx(sync=False)\n    #     next_wg = 1 - cur_wg\n    #     cute.arch.barrier_arrive(\n    #         barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128,\n    #     )\n\n    @cute.jit\n    def apply_score_mod(\n        self,\n        tSrS_t2r,\n        thr_tmem_load,\n        thr_mma_qk,\n        batch_idx,\n        head_idx,\n        m_block,\n        n_block,\n        softmax,\n        seqlen: SeqlenInfoQK,\n        aux_tensors=None,\n        fastdiv_mods=(None, None),\n        head_divmod=None,\n    ):\n        \"\"\"Apply score modification for SM100 (constant q_idx).\"\"\"\n        # Prepare index tensor with extra partition\n        cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size))\n        cS = cute.domain_offset((m_block * 
self.m_block_size, n_block * self.n_block_size), cS)\n        tScS = thr_mma_qk.partition_C(cS)\n        tScS = tScS[(None, None), 0, 0]\n        tScS_t2r = thr_tmem_load.partition_D(tScS)\n\n        # Shared q_idx for all scores\n        q_idx_logical = tScS_t2r[0][0]\n\n        # For Pack-GQA, compute the logical head index for this tile\n        if cutlass.const_expr(self.pack_gqa):\n            assert head_divmod is not None\n            # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead)\n            q_physical = q_idx_logical\n            q_idx_logical, head_offset = divmod(q_physical, head_divmod)\n            head_idx = head_idx * self.qhead_per_kvhead + head_offset\n\n        if cutlass.const_expr(aux_tensors is not None):\n            seqlen_q_divmod, _ = fastdiv_mods\n            _, q_idx_logical = divmod(q_idx_logical, seqlen_q_divmod)\n\n        apply_score_mod_inner(\n            tSrS_t2r,\n            tScS_t2r,\n            self.score_mod,\n            batch_idx,\n            head_idx,\n            softmax.softmax_scale,\n            self.vec_size,\n            self.qk_acc_dtype,\n            aux_tensors,\n            fastdiv_mods,\n            seqlen_info=seqlen,\n            constant_q_idx=q_idx_logical,\n            qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1,\n        )\n"
  },
  {
    "path": "flash_attn/cute/flash_fwd_sm120.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n# SM120 (Blackwell GeForce / DGX Spark) forward pass.\n#\n# SM120 uses the same SM80-era MMA instructions (mma.sync.aligned.m16n8k16) but has\n# a smaller shared memory capacity (99 KB vs 163 KB on SM80). This module subclasses\n# FlashAttentionForwardSm80 and overrides the SMEM capacity check accordingly.\n\nimport cutlass\nimport cutlass.utils as utils_basic\n\nfrom flash_attn.cute.flash_fwd import FlashAttentionForwardSm80\n\n\nclass FlashAttentionForwardSm120(FlashAttentionForwardSm80):\n    # Keep arch = 80 to use CpAsync code paths (no TMA for output).\n    # The compilation target is determined by the GPU at compile time, not this field.\n    arch = 80\n\n    @staticmethod\n    def can_implement(\n        dtype,\n        head_dim,\n        head_dim_v,\n        tile_m,\n        tile_n,\n        num_stages,\n        num_threads,\n        is_causal,\n        Q_in_regs=False,\n    ) -> bool:\n        \"\"\"Check if the kernel can be implemented on SM120.\n\n        Same logic as SM80 but uses SM120's shared memory capacity (99 KB).\n        \"\"\"\n        if dtype not in [cutlass.Float16, cutlass.BFloat16]:\n            return False\n        if head_dim % 8 != 0:\n            return False\n        if head_dim_v % 8 != 0:\n            return False\n        if tile_n % 16 != 0:\n            return False\n        if num_threads % 32 != 0:\n            return False\n        # Shared memory usage: Q tile + (K tile + V tile)\n        smem_usage_Q = tile_m * head_dim * 2\n        smem_usage_K = tile_n * head_dim * num_stages * 2\n        smem_usage_V = tile_n * head_dim_v * num_stages * 2\n        smem_usage_QV = (\n            (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V)\n        )\n        smem_usage = smem_usage_QV + smem_usage_K\n        # SM120 has 99 KB shared memory (vs 163 KB on SM80)\n        
smem_capacity = utils_basic.get_smem_capacity_in_bytes(\"sm_120\")\n        if smem_usage > smem_capacity:\n            return False\n        if (tile_m * 2) % num_threads != 0:\n            return False\n        return True\n"
  },
  {
    "path": "flash_attn/cute/flash_fwd_sm90.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n# SM90 (Hopper) forward pass for flash attention, extracted from flash_fwd.py.\n\nfrom types import SimpleNamespace\nfrom typing import Callable, Literal, Optional\nfrom functools import partial\n\nimport cuda.bindings.driver as cuda\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Float32, Int32, const_expr\nfrom cutlass.cute.nvgpu import cpasync, warpgroup\nfrom cutlass.utils import LayoutEnum\nimport cutlass.utils.hopper_helpers as sm90_utils_basic\nfrom cutlass import pipeline\nfrom cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait\nfrom cutlass.base_dsl.arch import Arch\n\nfrom quack import copy_utils\nfrom quack import layout_utils\nfrom quack import sm90_utils\n\nfrom flash_attn.cute.cute_dsl_utils import assume_tensor_aligned\nfrom flash_attn.cute import utils\nfrom flash_attn.cute.mask import AttentionMask\nfrom flash_attn.cute.softmax import Softmax, apply_score_mod_inner\nfrom flash_attn.cute.seqlen_info import SeqlenInfoQK\nfrom flash_attn.cute.block_info import BlockInfo\nfrom flash_attn.cute.block_sparsity import BlockSparseTensors\nfrom flash_attn.cute.block_sparse_utils import (\n    produce_block_sparse_loads,\n    consume_block_sparse_loads,\n)\nfrom flash_attn.cute import pipeline as pipeline_custom\nfrom flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom\nfrom flash_attn.cute.paged_kv import PagedKVManager\nfrom flash_attn.cute.named_barrier import NamedBarrierFwd\nfrom quack.cute_dsl_utils import ParamsBase\nfrom flash_attn.cute.tile_scheduler import (\n    TileSchedulerArguments,\n    SingleTileScheduler,\n    SingleTileLPTScheduler,\n    SingleTileVarlenScheduler,\n)\nfrom cutlass.cute import FastDivmodDivisor\n\nfrom flash_attn.cute.flash_fwd import FlashAttentionForwardBase\n\n\nclass FlashAttentionForwardSm90(FlashAttentionForwardBase):\n    def 
__init__(\n        self,\n        *args,\n        intra_wg_overlap: bool = True,\n        mma_pv_is_rs: bool = True,\n        paged_kv_non_tma: bool = False,\n        **kwargs,\n    ):\n        super().__init__(*args, **kwargs)\n        self.intra_wg_overlap = intra_wg_overlap\n        self.mma_pv_is_rs = mma_pv_is_rs\n        self.buffer_align_bytes = 1024\n        self.use_tma_KV = not paged_kv_non_tma\n        assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), (\n            \"Paged KV does not support irregular head dim\"\n        )\n        self.cluster_shape_mn = (1, 1)\n        assert self.arch >= Arch.sm_90 and self.arch <= Arch.sm_90a, \"Only SM 9.x is supported\"\n\n    def _get_smem_layout_atom(self):\n        sQ_layout_atom = warpgroup.make_smem_layout_atom(\n            sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim),\n            self.dtype,\n        )\n        sK_layout_atom = sQ_layout_atom\n        sV_layout_atom = warpgroup.make_smem_layout_atom(\n            sm90_utils_basic.get_smem_layout_atom(\n                LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv\n            ),\n            self.dtype,\n        )\n        sO_layout_atom = sV_layout_atom\n        if not self.mma_pv_is_rs:\n            sP_layout_atom = warpgroup.make_smem_layout_atom(\n                sm90_utils_basic.get_smem_layout_atom(\n                    LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n\n                ),\n                self.dtype,\n            )\n        else:\n            sP_layout_atom = None\n        return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom\n\n    def _get_tiled_mma(self):\n        tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma(\n            self.dtype,\n            self.dtype,\n            warpgroup.OperandMajorMode.K,\n            warpgroup.OperandMajorMode.K,\n            Float32,\n            atom_layout_mnk=(self.tile_m // 64, 1, 
1),\n            tiler_mn=(64, self.tile_n),\n        )\n        tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma(\n            self.dtype,\n            self.dtype,\n            warpgroup.OperandMajorMode.K,\n            warpgroup.OperandMajorMode.MN,\n            Float32,\n            atom_layout_mnk=(self.tile_m // 64, 1, 1),  # Might need (1, 2, 1) for hdim 512\n            tiler_mn=(64, self.tile_hdimv),\n            a_source=warpgroup.OperandSource.RMEM\n            if self.mma_pv_is_rs\n            else warpgroup.OperandSource.SMEM,\n        )\n        return tiled_mma_qk, tiled_mma_pv\n\n    def _get_shared_storage_cls(self):\n        sQ_struct, sK_struct, sV_struct = [\n            cute.struct.Align[\n                cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes\n            ]\n            for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)\n        ]\n        cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))\n        sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]\n        cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0\n        sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]\n        # 1 stage * 2 for Q pipeline (full + empty), self.num_stages*2 for K, self.num_stages*2 for V,\n        mbar_ptr_Q_struct = cute.struct.MemRange[cutlass.Int64, 1 * 2]\n        mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]\n        mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]\n\n        @cute.struct\n        class SharedStorageQKV:\n            mbar_ptr_Q: mbar_ptr_Q_struct\n            mbar_ptr_K: mbar_ptr_K_struct\n            mbar_ptr_V: mbar_ptr_V_struct\n            sV: sV_struct\n            sQ: sQ_struct\n            sK: sK_struct\n            sP: sP_struct\n\n        @cute.struct\n        class SharedStorageSharedQV:\n         
   mbar_ptr_Q: mbar_ptr_Q_struct\n            mbar_ptr_K: mbar_ptr_K_struct\n            mbar_ptr_V: mbar_ptr_V_struct\n            sQ: sQV_struct\n            sK: sK_struct\n            sP: sP_struct\n\n        return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV\n\n    @cute.jit\n    def __call__(\n        self,\n        mQ: cute.Tensor,  # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q\n        mK: cute.Tensor,  # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table\n        mV: cute.Tensor,  # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table\n        mO: cute.Tensor,  # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q\n        mLSE: Optional[cute.Tensor],\n        softmax_scale: Float32,\n        mCuSeqlensQ: Optional[cute.Tensor] = None,\n        mCuSeqlensK: Optional[cute.Tensor] = None,\n        mSeqUsedQ: Optional[cute.Tensor] = None,\n        mSeqUsedK: Optional[cute.Tensor] = None,\n        mPageTable: Optional[cute.Tensor] = None,  # (b_k, max_num_pages_per_seq)\n        window_size_left: Int32 | int | None = None,\n        window_size_right: Int32 | int | None = None,\n        learnable_sink: Optional[cute.Tensor] = None,\n        blocksparse_tensors: Optional[BlockSparseTensors] = None,\n        aux_tensors: Optional[list] = None,\n        # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).\n        stream: cuda.CUstream = None,\n    ):\n        \"\"\"Configures and launches the flash attention kernel.\n\n        mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout:\n        (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)\n        \"\"\"\n\n        self._check_type(\n            *(\n                t.element_type if t is not None else None\n                for t in (mQ, mK, mV, mO, mLSE, 
mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)\n            )\n        )\n\n        self.varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None\n\n        mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]\n        QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]\n        mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)]\n        KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]\n        mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)]\n        LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]\n        mLSE = (\n            layout_utils.select(mLSE, LSE_layout_transpose)\n            if const_expr(mLSE is not None)\n            else None\n        )\n\n        tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()\n        self.num_mma_threads = tiled_mma_qk.size\n        self.num_threads_per_warp_group = 128\n        self.num_wg_mma = self.num_mma_threads // self.num_threads_per_warp_group\n        assert self.num_wg_mma in [1, 2, 3]\n        self.num_threads = self.num_threads_per_warp_group * (self.num_wg_mma + 1)\n        self.num_producer_threads = 32\n        self.num_Q_load_threads = self.num_mma_threads  # If not TMA_Q, MMA threads load Q\n        self.num_epilogue_threads = self.num_mma_threads\n        self.num_mma_regs, self.num_producer_regs = {1: (256, 56), 2: (240, 24), 3: (160, 32)}[\n            self.num_wg_mma\n        ]\n        self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)\n\n        self.use_scheduler_barrier = (\n            (self.num_wg_mma >= 2 and self.tile_hdim <= 128)\n            if const_expr(self.intra_wg_overlap)\n            else (self.num_wg_mma == 2)\n        )\n        self.use_tma_Q = self.arch >= Arch.sm_90 and not (\n            self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0\n        )\n        self.use_tma_O = 
self.use_tma_Q\n        self.rescale_O_before_gemm = self.tile_hdimv > 128 and self.intra_wg_overlap\n        self._setup_attributes()\n        # TODO: we prob don't need most of what's in _setup_attributes\n        self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [\n            sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage)\n            for mX, shape, stage in [\n                (mQ, (self.tile_m, self.tile_hdim), None),\n                (mK, (self.tile_n, self.tile_hdim), self.num_stages),\n                (mV, (self.tile_n, self.tile_hdimv), self.num_stages),\n                (mO, (self.tile_m, self.tile_hdimv), None),\n            ]\n        ]\n        self.sP_layout = None\n        if const_expr(not self.mma_pv_is_rs):\n            self.sP_layout = sm90_utils.make_smem_layout(\n                mV.element_type, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n)\n            )\n\n        SharedStorage = self._get_shared_storage_cls()\n\n        mQ_og, mO_og = mQ, mO\n        if const_expr(self.pack_gqa):\n            nheads_kv = mK.shape[2]\n            mQ = pack_gqa_layout(mQ, self.qhead_per_kvhead, nheads_kv, head_idx=2)\n            mO = pack_gqa_layout(mO, self.qhead_per_kvhead, nheads_kv, head_idx=2)\n            if const_expr(mLSE is not None):\n                mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, nheads_kv, head_idx=1)\n\n        # TMA\n        gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp()\n        gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp()  # Might multicast\n        gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp()\n        self.tma_copy_bytes = {\n            name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))\n            for name, mX, layout in [\n                (\"Q\", mQ, self.sQ_layout),\n                (\"K\", mK, self.sK_layout),\n                (\"V\", mV, self.sV_layout),\n            ]\n        }\n        make_tiled_tma_atom_fn = 
(\n            partial(make_packgqa_tiled_tma_atom, qhead_per_kvhead=self.qhead_per_kvhead, head_idx=2)\n            if const_expr(self.pack_gqa)\n            else cpasync.make_tiled_tma_atom\n        )\n        tma_atom_Q, tma_tensor_Q = None, None\n        if const_expr(self.use_tma_Q):\n            tma_atom_Q, tma_tensor_Q = make_tiled_tma_atom_fn(\n                gmem_tiled_copy_Q,\n                mQ_og if const_expr(self.pack_gqa) else mQ,\n                self.sQ_layout,\n                (self.tile_m, self.tile_hdim),  # No mcast\n            )\n        tma_atom_K, tma_tensor_K = None, None\n        tma_atom_V, tma_tensor_V = None, None\n        if const_expr(self.use_tma_KV):\n            tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(\n                gmem_tiled_copy_KV,\n                mK,\n                cute.select(self.sK_layout, mode=[0, 1]),\n                (self.tile_n, self.tile_hdim),\n                1,  # No mcast for now\n            )\n            tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(\n                gmem_tiled_copy_KV,\n                mV,\n                cute.select(self.sV_layout, mode=[0, 1]),\n                (self.tile_n, self.tile_hdimv),\n                1,  # No mcast for now\n            )\n        tma_atom_O, tma_tensor_O = None, None\n        if const_expr(self.use_tma_O):\n            mO_tma = mO_og if const_expr(self.pack_gqa) else mO\n            if const_expr(self.varlen_q):\n                mO_tma = copy_utils.create_ragged_tensor_for_tma(\n                    mO_tma, ragged_dim=0, ptr_shift=True\n                )\n            tma_atom_O, tma_tensor_O = make_tiled_tma_atom_fn(\n                gmem_tiled_copy_O,\n                mO_tma,\n                self.sO_layout,\n                (self.tile_m, self.tile_hdimv),  # No mcast\n            )\n        if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):\n            TileScheduler = SingleTileVarlenScheduler\n        else:\n         
   TileScheduler = (\n                SingleTileScheduler\n                if const_expr(not self.is_causal or self.is_local)\n                else SingleTileLPTScheduler\n            )\n        tile_sched_args = TileSchedulerArguments(\n            cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m),\n            cute.size(mQ.shape[2]),\n            cute.size(mQ.shape[3])\n            if const_expr(mCuSeqlensQ is None)\n            else cute.size(mCuSeqlensQ.shape[0] - 1),\n            1,  # num_splits\n            cute.size(mK.shape[0])\n            if const_expr(mPageTable is None)\n            else mK.shape[0] * mPageTable.shape[1],\n            mQ.shape[1],\n            mV.shape[1],\n            total_q=cute.size(mQ.shape[0])\n            if const_expr(mCuSeqlensQ is not None)\n            else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),\n            tile_shape_mn=(self.tile_m, self.tile_n),\n            mCuSeqlensQ=mCuSeqlensQ,\n            mSeqUsedQ=mSeqUsedQ,\n            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n            element_size=self.dtype.width // 8,\n            is_persistent=False,\n            lpt=self.is_causal or self.is_local,\n        )\n        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)\n        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)\n        softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(\n            softmax_scale, self.score_mod\n        )\n        window_size_left = Int32(window_size_left) if window_size_left is not None else None\n        window_size_right = Int32(window_size_right) if window_size_right is not None else None\n        fastdiv_mods = utils.compute_fastdiv_mods(\n            mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors, mPageTable\n        )\n\n        self.kernel(\n            tma_tensor_Q if const_expr(self.use_tma_Q) else mQ,\n            tma_tensor_K if const_expr(self.use_tma_KV) else 
mK,\n            tma_tensor_V if const_expr(self.use_tma_KV) else mV,\n            tma_tensor_O if const_expr(self.use_tma_O) else mO,\n            mLSE,\n            mCuSeqlensQ,\n            mCuSeqlensK,\n            mSeqUsedQ,\n            mSeqUsedK,\n            mPageTable,\n            tma_atom_Q,\n            tma_atom_K,\n            tma_atom_V,\n            tma_atom_O,\n            softmax_scale_log2,\n            softmax_scale,\n            window_size_left,\n            window_size_right,\n            learnable_sink,\n            blocksparse_tensors,\n            self.sQ_layout,\n            self.sK_layout,\n            self.sV_layout,\n            self.sO_layout,\n            self.sP_layout,\n            self.gmem_tiled_copy_Q,\n            self.gmem_tiled_copy_K,\n            self.gmem_tiled_copy_V,\n            self.gmem_tiled_copy_O,\n            tiled_mma_qk,\n            tiled_mma_pv,\n            tile_sched_params,\n            TileScheduler,\n            SharedStorage,\n            aux_tensors,\n            fastdiv_mods,\n        ).launch(\n            grid=grid_dim,\n            block=[self.num_threads, 1, 1],\n            stream=stream,\n            min_blocks_per_mp=1,\n        )\n\n    @cute.kernel\n    def kernel(\n        self,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mV: cute.Tensor,\n        mO: cute.Tensor,\n        mLSE: Optional[cute.Tensor],\n        mCuSeqlensQ: Optional[cute.Tensor],\n        mCuSeqlensK: Optional[cute.Tensor],\n        mSeqUsedQ: Optional[cute.Tensor],\n        mSeqUsedK: Optional[cute.Tensor],\n        mPageTable: Optional[cute.Tensor],\n        tma_atom_Q: Optional[cute.CopyAtom],\n        tma_atom_K: Optional[cute.CopyAtom],\n        tma_atom_V: Optional[cute.CopyAtom],\n        tma_atom_O: Optional[cute.CopyAtom],\n        softmax_scale_log2: Float32,\n        softmax_scale: Optional[Float32],\n        window_size_left: Optional[Int32],\n        window_size_right: Optional[Int32],\n        
learnable_sink: Optional[cute.Tensor],\n        blocksparse_tensors: Optional[BlockSparseTensors],\n        sQ_layout: cute.ComposedLayout,\n        sK_layout: cute.ComposedLayout,\n        sV_layout: cute.ComposedLayout,\n        sO_layout: cute.ComposedLayout,\n        sP_layout: cute.ComposedLayout | None,\n        gmem_tiled_copy_Q: cute.TiledCopy,\n        gmem_tiled_copy_K: cute.TiledCopy,\n        gmem_tiled_copy_V: cute.TiledCopy,\n        gmem_tiled_copy_O: cute.TiledCopy,\n        tiled_mma_qk: cute.TiledMma,\n        tiled_mma_pv: cute.TiledMma,\n        tile_sched_params: ParamsBase,\n        TileScheduler: cutlass.Constexpr[Callable],\n        SharedStorage: cutlass.Constexpr[Callable],\n        aux_tensors: Optional[list[cute.Tensor]] = None,\n        fastdiv_mods=None,\n    ):\n        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())\n        # Prefetch TMA descriptors\n        if warp_idx == 0:\n            for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O):\n                if const_expr(tma_atom is not None):\n                    cpasync.prefetch_descriptor(tma_atom)\n\n        smem = cutlass.utils.SmemAllocator()\n        storage = smem.allocate(SharedStorage)\n\n        # Mbarrier / pipeline init\n        mbar_ptr_Q = storage.mbar_ptr_Q.data_ptr()\n        if const_expr(not self.use_tma_Q):\n            if warp_idx == 1:\n                cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads)\n\n        ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread)\n        tma_warp = ThreadCooperativeGroup(1)\n        load_threads = ThreadCooperativeGroup(self.num_threads_per_warp_group)\n        mma_warps = ThreadCooperativeGroup(self.num_mma_threads // cute.arch.WARP_SIZE)\n        mma_threads = ThreadCooperativeGroup(self.num_mma_threads)\n        pipeline_q = None\n        if const_expr(self.use_tma_Q):\n            pipeline_q = pipeline_custom.PipelineTmaAsync.create(\n                
barrier_storage=mbar_ptr_Q,\n                num_stages=1,\n                producer_group=tma_warp,\n                consumer_group=mma_warps,\n                tx_count=self.tma_copy_bytes[\"Q\"],\n                defer_sync=True,\n            )\n\n        # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync\n        if const_expr(self.use_tma_KV):\n            # PipelineTmaAsync: consumer_release has internal per-warp gating\n            # (is_signalling_thread), so arrive count = num_consumer_warps\n            pipeline_k = pipeline_custom.PipelineTmaAsync.create(\n                barrier_storage=storage.mbar_ptr_K.data_ptr(),\n                num_stages=self.num_stages,\n                producer_group=tma_warp,\n                consumer_group=mma_warps,\n                tx_count=self.tma_copy_bytes[\"K\"],\n                defer_sync=True,\n            )\n            pipeline_v = pipeline_custom.PipelineTmaAsync.create(\n                barrier_storage=storage.mbar_ptr_V.data_ptr(),\n                num_stages=self.num_stages,\n                producer_group=tma_warp,\n                consumer_group=mma_warps,\n                tx_count=self.tma_copy_bytes[\"V\"],\n                defer_sync=True,\n            )\n        else:\n            # PipelineAsync: no thread gating in producer_commit/consumer_release,\n            # so arrive counts must equal actual thread counts\n            pipeline_k = pipeline.PipelineAsync.create(\n                num_stages=self.num_stages,\n                producer_group=load_threads,\n                consumer_group=mma_threads,\n                barrier_storage=storage.mbar_ptr_K.data_ptr(),\n                defer_sync=True,\n            )\n            pipeline_v = pipeline.PipelineAsync.create(\n                num_stages=self.num_stages,\n                producer_group=load_threads,\n                consumer_group=mma_threads,\n                barrier_storage=storage.mbar_ptr_V.data_ptr(),\n          
      defer_sync=True,\n            )\n\n        # Cluster arrive after barrier init\n        pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Get shared memory buffer\n        # ///////////////////////////////////////////////////////////////////////////////\n        sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)\n        sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)\n        if const_expr(not self.Q_in_regs):\n            sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)\n        else:\n            sV = storage.sQ.get_tensor(\n                sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type\n            )\n        # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma\n        sVt = layout_utils.transpose_view(sV)\n        sP = None\n        if const_expr(sP_layout is not None):\n            sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner)\n        # reuse sQ's data iterator\n        sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype)\n\n        block_info = BlockInfo(\n            self.tile_m,\n            self.tile_n,\n            self.is_causal,\n            self.is_local,\n            False,  # is_split_kv\n            window_size_left,\n            window_size_right,\n            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n        )\n        SeqlenInfoCls = partial(\n            SeqlenInfoQK.create,\n            seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1],\n            seqlen_k_static=mK.shape[0]\n            if const_expr(mPageTable is None)\n            else mK.shape[0] * mPageTable.shape[1],\n            mCuSeqlensQ=mCuSeqlensQ,\n            mCuSeqlensK=mCuSeqlensK,\n            
mSeqUsedQ=mSeqUsedQ,\n            mSeqUsedK=mSeqUsedK,\n            # Don't need to pass in tile_mn because we won't access offset_padded\n        )\n        AttentionMaskCls = partial(\n            AttentionMask,\n            self.tile_m,\n            self.tile_n,\n            window_size_left=window_size_left,\n            window_size_right=window_size_right,\n            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n        )\n        TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)\n\n        # Cluster wait before starting\n        pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)\n\n        if warp_idx < 4:  # Producer\n            cute.arch.setmaxregister_decrease(self.num_producer_regs)\n            self.load(\n                mQ,\n                mK,\n                mV,\n                sQ,\n                sK,\n                sV,\n                tma_atom_Q,\n                tma_atom_K,\n                tma_atom_V,\n                pipeline_k,\n                pipeline_v,\n                pipeline_q,\n                mPageTable,\n                blocksparse_tensors,\n                block_info,\n                SeqlenInfoCls,\n                TileSchedulerCls,\n            )\n\n        else:  # Consumer\n            cute.arch.setmaxregister_increase(self.num_mma_regs)\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Tile MMA compute thread partitions and allocate accumulators\n            # ///////////////////////////////////////////////////////////////////////////////\n            tidx, _, _ = cute.arch.thread_idx()\n            tidx = tidx - 128\n            self.mma(\n                tiled_mma_qk,\n                tiled_mma_pv,\n                mQ,\n                mO,\n                mLSE,\n                sQ,\n                sK,\n                sVt,\n                sP,\n                sO,\n                
learnable_sink,\n                pipeline_k,\n                pipeline_v,\n                pipeline_q,\n                mbar_ptr_Q,\n                gmem_tiled_copy_Q,\n                gmem_tiled_copy_O,\n                tma_atom_O,\n                tidx,\n                softmax_scale_log2,\n                softmax_scale,\n                block_info,\n                SeqlenInfoCls,\n                AttentionMaskCls,\n                TileSchedulerCls,\n                blocksparse_tensors,\n                aux_tensors,\n                fastdiv_mods,\n            )\n\n    @cute.jit\n    def load(\n        self,\n        mQ: cute.Tensor,\n        mK: cute.Tensor,\n        mV: cute.Tensor,\n        sQ: cute.Tensor,\n        sK: cute.Tensor,\n        sV: cute.Tensor,\n        tma_atom_Q: Optional[cute.CopyAtom],\n        tma_atom_K: Optional[cute.CopyAtom],\n        tma_atom_V: Optional[cute.CopyAtom],\n        pipeline_k: pipeline.PipelineAsync,\n        pipeline_v: pipeline.PipelineAsync,\n        pipeline_q: Optional[pipeline.PipelineAsync],\n        mPageTable: Optional[cute.Tensor],\n        blocksparse_tensors: Optional[BlockSparseTensors],\n        block_info: BlockInfo,\n        SeqlenInfoCls: Callable,\n        TileSchedulerCls: Callable,\n    ):\n        warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4\n        tidx, _, _ = cute.arch.thread_idx()\n\n        # TMA: only warp 0 loads. 
cp_async: all warps load\n        is_load_warp = warp_idx_in_wg == 0 or const_expr(not self.use_tma_KV)\n\n        if is_load_warp:\n            q_producer_phase = Int32(1)\n            kv_producer_state = pipeline.make_pipeline_state(\n                pipeline.PipelineUserType.Producer, self.num_stages\n            )\n            tile_scheduler = TileSchedulerCls()\n            work_tile = tile_scheduler.initial_work_tile_info()\n            while work_tile.is_valid_tile:\n                # if work_tile.is_valid_tile:\n                m_block, head_idx, batch_idx, _ = work_tile.tile_idx\n                seqlen = SeqlenInfoCls(batch_idx)\n                mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]\n                head_idx_kv = (\n                    head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx\n                )\n\n                load_Q = None\n                if const_expr(self.use_tma_Q):\n                    gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))\n                    load_Q, _, _ = copy_utils.tma_get_copy_fn(\n                        tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True\n                    )\n\n                paged_kv_manager = None\n                tma_load_K_fn = None\n                tma_load_V_fn = None\n                if const_expr(self.use_tma_KV):\n                    # === TMA path (non-paged and paged with page_size == n_block_size) ===\n                    if const_expr(mPageTable is not None):\n                        # Paged TMA: keep page dimension indexable\n                        mK_cur = mK[None, None, head_idx_kv, None]\n                        mV_cur = mV[None, None, head_idx_kv, None]\n                        gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (0, 0, None))\n                        gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (0, 0, None))\n                    else:\n      
                  # Non-paged TMA\n                        mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[\n                            None, None, head_idx_kv\n                        ]\n                        mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[\n                            None, None, head_idx_kv\n                        ]\n                        gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0))\n                        gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0))\n                    # TODO: mcast\n                    tma_load_K_fn, _, _ = copy_utils.tma_get_copy_fn(\n                        tma_atom_K, 0, cute.make_layout(1), gK, sK\n                    )\n                    tma_load_K_fn = copy_utils.tma_producer_copy_fn(tma_load_K_fn, pipeline_k)\n                    tma_load_V_fn, _, _ = copy_utils.tma_get_copy_fn(\n                        tma_atom_V, 0, cute.make_layout(1), gV, sV\n                    )\n                    tma_load_V_fn = copy_utils.tma_producer_copy_fn(tma_load_V_fn, pipeline_v)\n                else:\n                    # === cp_async path (paged KV with page_size != n_block_size) ===\n                    paged_kv_manager = PagedKVManager.create(\n                        mPageTable,\n                        mK,\n                        mV,\n                        FastDivmodDivisor(mK.shape[0]),\n                        batch_idx,\n                        head_idx_kv,\n                        tidx,\n                        seqlen.seqlen_k,\n                        0,  # leftpad_k\n                        self.tile_n,\n                        self.tile_hdim,\n                        self.tile_hdimv,\n                        self.num_threads_per_warp_group,\n                        mK.element_type,\n                        arch=self.arch.major * 10 + self.arch.minor,\n                    )\n\n                load_K = partial(\n                    self.load_KV,\n    
                tma_load_K_fn,\n                    paged_kv_manager,\n                    sK,\n                    pipeline_kv=pipeline_k,\n                    K_or_V=\"K\",\n                )\n                load_V = partial(\n                    self.load_KV,\n                    tma_load_V_fn,\n                    paged_kv_manager,\n                    sV,\n                    pipeline_kv=pipeline_v,\n                    K_or_V=\"V\",\n                )\n\n                if const_expr(not self.use_block_sparsity):\n                    n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)\n                    # if cute.arch.thread_idx()[0] == 0:\n                    #     cute.printf(\"m_block = %d, n_block_min: %d, n_block_max: %d\", m_block, n_block_min, n_block_max)\n                    # Clamp n_block to 0 when n_block_max == 0 (can happen with causal\n                    # + pack_gqa when seqlen_k < tile_n). TMA handles n_block=-1\n                    # gracefully (fills zeros), but cp.async would crash on\n                    # out-of-bounds page table access.\n                    n_block = (\n                        n_block_max - 1\n                        if const_expr(self.use_tma_KV)\n                        else cutlass.max(n_block_max - 1, 0)\n                    )\n                    page_idx = (\n                        mPageTable[batch_idx, n_block]\n                        if const_expr(mPageTable is not None and self.use_tma_KV)\n                        else None\n                    )\n\n                    # First iteration: load Q on pipeline_q, K on pipeline_k\n                    pipeline_k.producer_acquire(kv_producer_state)\n                    if const_expr(not self.use_tma_KV):\n                        paged_kv_manager.load_page_table(n_block)\n                    load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)\n                    if const_expr(self.use_tma_Q):\n                        
if warp_idx_in_wg == 0:\n                            pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)\n                            load_Q(tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(0))\n                            q_producer_phase ^= 1\n\n                    if const_expr(not self.intra_wg_overlap or not self.use_tma_KV):\n                        pipeline_v.producer_acquire(kv_producer_state)\n                        load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)\n                        kv_producer_state.advance()\n                        for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):\n                            n_block = n_block_max - 1 - i - 1\n                            page_idx = (\n                                mPageTable[batch_idx, n_block]\n                                if const_expr(mPageTable is not None and self.use_tma_KV)\n                                else None\n                            )\n                            if const_expr(not self.use_tma_KV):\n                                paged_kv_manager.load_page_table(n_block)\n                            pipeline_k.producer_acquire(kv_producer_state)\n                            load_K(\n                                block=n_block, producer_state=kv_producer_state, page_idx=page_idx\n                            )\n                            pipeline_v.producer_acquire(kv_producer_state)\n                            load_V(\n                                block=n_block, producer_state=kv_producer_state, page_idx=page_idx\n                            )\n                            kv_producer_state.advance()\n                    else:\n                        for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):\n                            n_block_prev = n_block_max - i - 1\n                            n_block = n_block_prev - 1\n                            page_idx = (\n                                
mPageTable[batch_idx, n_block]\n                                if const_expr(mPageTable is not None)\n                                else None\n                            )\n                            page_idx_prev = (\n                                mPageTable[batch_idx, n_block_prev]\n                                if const_expr(mPageTable is not None)\n                                else None\n                            )\n                            kv_producer_state_prev = kv_producer_state.clone()\n                            kv_producer_state.advance()\n                            pipeline_k.producer_acquire(kv_producer_state)\n                            load_K(\n                                block=n_block, producer_state=kv_producer_state, page_idx=page_idx\n                            )\n                            pipeline_v.producer_acquire(kv_producer_state_prev)\n                            load_V(\n                                block=n_block_prev,\n                                producer_state=kv_producer_state_prev,\n                                page_idx=page_idx_prev,\n                            )\n                        n_block = n_block_min\n                        page_idx = (\n                            mPageTable[batch_idx, n_block]\n                            if const_expr(mPageTable is not None)\n                            else None\n                        )\n                        pipeline_v.producer_acquire(kv_producer_state)\n                        load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)\n                        kv_producer_state.advance()\n                else:\n                    # Block sparsity: use TMA closures directly (not paged)\n                    # Load Q on pipeline_q, separate from K/V pipeline\n                    if const_expr(self.use_tma_Q):\n                        if warp_idx_in_wg == 0:\n                            pipeline_q.producer_acquire_w_index_phase(0, 
q_producer_phase)\n                            load_Q(tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(0))\n                            q_producer_phase ^= 1\n                    kv_producer_state = produce_block_sparse_loads(\n                        blocksparse_tensors,\n                        batch_idx,\n                        head_idx,\n                        m_block,\n                        kv_producer_state,\n                        tma_load_K_fn,\n                        tma_load_V_fn,\n                        pipeline_k,\n                        pipeline_v,\n                        self.intra_wg_overlap,\n                        self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n                        self.q_subtile_factor if self.q_subtile_factor is not None else 1,\n                    )\n\n                tile_scheduler.prefetch_next_work()\n                tile_scheduler.advance_to_next_work()\n                work_tile = tile_scheduler.get_current_work()\n                # End of persistent scheduler loop\n\n            # Producer tail is only useful for cluster to avoid early exit of blocks.\n            # We only need producer_tail on V since that's the last that's loaded, we don't\n            # need it for Q (no cluster) and K.\n            pipeline_v.producer_tail(kv_producer_state)\n\n    @cute.jit\n    def load_KV(\n        self,\n        tma_load_fn: Optional[Callable],\n        paged_kv_manager: Optional[PagedKVManager],\n        sX: cute.Tensor,\n        block: Int32,\n        pipeline_kv: pipeline.PipelineAsync,\n        producer_state: pipeline.PipelineState,\n        K_or_V: Literal[\"K\", \"V\"],\n        page_idx: Optional[Int32] = None,\n    ):\n        if const_expr(self.use_tma_KV):\n            src_idx = block if const_expr(page_idx is None) else page_idx\n            tma_load_fn(src_idx=src_idx, producer_state=producer_state)\n        else:\n            paged_kv_manager.load_KV(block, sX[None, None, 
producer_state.index], K_or_V)\n            cute.arch.cp_async_commit_group()\n            cute.arch.cp_async_wait_group(0)\n            pipeline_kv.producer_commit(producer_state)\n\n    @cute.jit\n    def mma(\n        self,\n        tiled_mma_qk: cute.TiledMma,\n        tiled_mma_pv: cute.TiledMma,\n        # softmax: Softmax,\n        # acc_O: cute.Tensor,\n        mQ: cute.Tensor,\n        mO: cute.Tensor,\n        mLSE: Optional[cute.Tensor],\n        sQ: cute.Tensor,\n        sK: cute.Tensor,\n        sVt: cute.Tensor,\n        sP: Optional[cute.Tensor],\n        sO: cute.Tensor,\n        learnable_sink: Optional[cute.Tensor],\n        pipeline_k: pipeline.PipelineAsync,\n        pipeline_v: pipeline.PipelineAsync,\n        pipeline_q: Optional[pipeline.PipelineAsync],\n        mbar_ptr_Q: cutlass.Pointer,\n        gmem_tiled_copy_Q: cute.TiledCopy,\n        gmem_tiled_copy_O: cute.TiledCopy,\n        tma_atom_O: Optional[cute.CopyAtom],\n        tidx: Int32,\n        softmax_scale_log2: Float32,\n        softmax_scale: Optional[Float32],\n        block_info: BlockInfo,\n        SeqlenInfoCls: Callable,\n        AttentionMaskCls: Callable,\n        TileSchedulerCls: Callable,\n        blocksparse_tensors: Optional[BlockSparseTensors],\n        aux_tensors: Optional[list],\n        fastdiv_mods=None,\n    ):\n        warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)\n        warp_group_thread_layout = cute.make_layout(\n            self.num_wg_mma, stride=self.num_threads_per_warp_group\n        )\n        thr_mma_qk = tiled_mma_qk.get_slice(tidx)\n        wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx))\n        wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx))\n        _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(\n            wg_mma_qk, (self.tile_m, self.tile_n, self.tile_hdim), sQ, sK\n        )\n        mma_qk_fn = partial(\n            
sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK\n        )\n        acc_O, tOrP, tOrVt = sm90_utils.partition_fragment_ABC(\n            wg_mma_pv, (self.tile_m, self.tile_hdimv, self.tile_n), sP, sVt\n        )\n        mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt)\n\n        # ///////////////////////////////////////////////////////////////////////////////\n        # Smem copy atom tiling\n        # ///////////////////////////////////////////////////////////////////////////////\n        smem_copy_atom_P = utils.get_smem_store_atom(\n            self.arch.major * 10 + self.arch.minor, self.dtype\n        )\n        smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx)\n        tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None\n        smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP)\n\n        self.mma_init()\n\n        mma_q_consumer_phase = Int32(0)\n        q_consumer_phase = Int32(0)\n        kv_consumer_state = pipeline.make_pipeline_state(\n            pipeline.PipelineUserType.Consumer, self.num_stages\n        )\n\n        tile_scheduler = TileSchedulerCls()\n        work_tile = tile_scheduler.initial_work_tile_info()\n        softmax = Softmax.create(\n            softmax_scale_log2,\n            num_rows=acc_O.shape[0][0] * acc_O.shape[1],\n            softmax_scale=softmax_scale,\n        )\n\n        # For RescaleOBeforeGemm: persistent scores_scale across iterations\n        scores_scale = None\n        if const_expr(self.rescale_O_before_gemm):\n            scores_scale = cute.make_rmem_tensor_like(softmax.row_max, Float32)\n\n        mma_one_n_block_all = partial(\n            self.mma_one_n_block_intrawg_overlap\n            if const_expr(self.intra_wg_overlap)\n            else self.mma_one_n_block,\n            mma_qk_fn=mma_qk_fn,\n            pipeline_k=pipeline_k,\n            
pipeline_v=pipeline_v,\n            acc_O=acc_O,\n            tOrP=tOrP,\n            smem_copy_params=smem_copy_params,\n            check_inf=True,\n            scores_scale=scores_scale,\n        )\n\n        process_first_half_block = partial(\n            self.first_half_block_overlap,\n            mma_qk_fn=mma_qk_fn,\n            pipeline_k=pipeline_k,\n            tOrP=tOrP,\n            smem_copy_params=smem_copy_params,\n            scores_scale=scores_scale,\n            softmax=softmax,\n            acc_O=acc_O,\n        )\n        process_last_half_block = partial(\n            self.last_half_block_overlap,\n            pipeline_v=pipeline_v,\n            mma_pv_fn=mma_pv_fn,\n            scores_scale=scores_scale,\n            softmax=softmax,\n            acc_O=acc_O,\n        )\n        while work_tile.is_valid_tile:\n            # if work_tile.is_valid_tile:\n\n            # shape: (atom_v_m * rest_m)\n            m_block, head_idx, batch_idx, _ = work_tile.tile_idx\n            seqlen = SeqlenInfoCls(batch_idx)\n\n            # Recompute fastdiv_mods if necessary for varlen with aux_tensors\n            recompute_fastdiv_mods_q = cutlass.const_expr(\n                aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)\n            )\n            recompute_fastdiv_mods_k = cutlass.const_expr(\n                aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)\n            )\n            if cutlass.const_expr(fastdiv_mods is not None):\n                seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods\n                fastdiv_mods = (\n                    seqlen_q_divmod\n                    if not recompute_fastdiv_mods_q\n                    else FastDivmodDivisor(seqlen.seqlen_q),\n                    seqlen_k_divmod\n                    if not recompute_fastdiv_mods_k\n                    else FastDivmodDivisor(seqlen.seqlen_k),\n                )\n\n            mask = AttentionMaskCls(seqlen)\n  
          mask_fn = partial(\n                mask.apply_mask,\n                batch_idx=batch_idx,\n                head_idx=head_idx,\n                m_block=m_block,\n                thr_mma=thr_mma_qk,\n                mask_causal=self.is_causal,\n                mask_local=self.is_local,\n                aux_tensors=aux_tensors,\n                fastdiv_mods=fastdiv_mods,\n            )\n            score_mod_fn = None\n            if const_expr(self.score_mod is not None):\n                score_mod_fn = partial(\n                    self.apply_score_mod,\n                    thr_mma_qk,\n                    batch_idx,\n                    head_idx,\n                    m_block,\n                    softmax_scale=softmax_scale,\n                    aux_tensors=aux_tensors,\n                    fastdiv_mods=fastdiv_mods,\n                )\n            mma_one_n_block = partial(\n                mma_one_n_block_all, seqlen=seqlen, softmax=softmax, score_mod_fn=score_mod_fn\n            )\n            # Load Q if not TMA_Q\n            if const_expr(not self.use_tma_Q):\n                pack_gqa = PackGQA(\n                    self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead\n                )\n                mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]\n                # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx)\n                # gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))\n                # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q,\n                #             headdim=mQ.shape[1])\n                pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q)\n                cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q)\n\n            n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)\n            if const_expr(self.use_tma_Q):\n                pipeline_q.consumer_wait_w_index_phase(0, 
mma_q_consumer_phase)\n            else:\n                cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase)\n                q_consumer_phase ^= 1\n            # For performance reasons, we separate out two kinds of iterations:\n            # those that need masking on S, and those that don't.\n            # We need masking on S for the very last block when K and V have a length that is not a multiple of tile_n.\n            # We also need masking on S if it's causal, for the last several blocks.\n            # softmax.reset()  # Don't need reset as we explicitly call softmax with is_first=True\n            O_should_accumulate = False\n\n            # ==========================================\n            # MAINLOOP\n            # ==========================================\n            if const_expr(not self.use_block_sparsity):\n                # ==========================================\n                # No block-sparsity (original path)\n                # ==========================================\n                # First iteration with seqlen masking\n                if const_expr(self.intra_wg_overlap):\n                    kv_consumer_state = process_first_half_block(\n                        n_block=n_block_max - 1,\n                        seqlen=seqlen,\n                        kv_consumer_state=kv_consumer_state,\n                        mask_fn=partial(mask_fn, mask_mod=self.mask_mod),\n                        score_mod_fn=score_mod_fn,\n                        is_first_block=True,\n                    )\n                else:\n                    self.warp_scheduler_barrier_sync()\n                    kv_consumer_state = mma_one_n_block(\n                        kv_consumer_state,\n                        n_block=n_block_max - 1,\n                        seqlen=seqlen,\n                        mma_pv_fn=partial(mma_pv_fn, zero_init=True),\n                        is_first_n_block=True,\n                        mask_fn=partial(mask_fn, 
mask_mod=self.mask_mod, mask_seqlen=True),\n                    )\n                    O_should_accumulate = True\n                # if cute.arch.thread_idx()[0] == 128: cute.printf(\"m_block = {}, n_block_max = {}, n_block_min = {}\", m_block, n_block_max, n_block_min)\n                n_block_max -= 1\n                # Next couple of iterations with causal masking\n                if const_expr(self.is_causal or self.is_local):\n                    n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(\n                        seqlen, m_block, n_block_min\n                    )\n                    # if cute.arch.thread_idx()[0] == 128: cute.printf(\"n_block_min_causal_local_mask = {}\", n_block_min_causal_local_mask)\n                    for n_tile in cutlass.range(\n                        n_block_max - n_block_min_causal_local_mask, unroll=1\n                    ):\n                        kv_consumer_state = mma_one_n_block(\n                            kv_consumer_state,\n                            n_block=n_block_max - 1 - n_tile,\n                            seqlen=seqlen,\n                            mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),\n                            mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),\n                        )\n                        O_should_accumulate = True\n                    n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask)\n                # The remaining iterations have no masking\n                n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask(\n                    seqlen, m_block, n_block_min\n                )\n                # if cute.arch.thread_idx()[0] == 128: cute.printf(\"n_block_min_before_local_mask = {}, n_block_min = {}\", n_block_min_before_local_mask, n_block_min)\n                for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1):\n                  
  kv_consumer_state = mma_one_n_block(\n                        kv_consumer_state,\n                        n_block=n_block_max - 1 - n_tile,\n                        seqlen=seqlen,\n                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),\n                        mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),\n                    )\n                    O_should_accumulate = True\n                # Separate iterations with local masking on the left\n                if const_expr(self.is_local and block_info.window_size_left is not None):\n                    n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask)\n                    for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1):\n                        kv_consumer_state = mma_one_n_block(\n                            kv_consumer_state,\n                            n_block=n_block_max - 1 - n_tile,\n                            seqlen=seqlen,\n                            mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),\n                            mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),\n                        )\n                        O_should_accumulate = True\n                # Last \"half\" iteration\n                if const_expr(self.intra_wg_overlap):\n                    kv_consumer_state = process_last_half_block(\n                        kv_consumer_state=kv_consumer_state,\n                        zero_init=not O_should_accumulate,\n                    )\n                    O_should_accumulate = True\n                else:\n                    self.warp_scheduler_barrier_arrive()\n\n            else:\n                # ==========================================\n                # Block sparsity\n                # ==========================================\n                kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads(\n                 
   blocksparse_tensors,\n                    batch_idx,\n                    head_idx,\n                    m_block,\n                    seqlen,\n                    kv_consumer_state,\n                    mma_pv_fn,\n                    mma_one_n_block,\n                    process_first_half_block,\n                    process_last_half_block,\n                    mask_fn,\n                    score_mod_fn,\n                    O_should_accumulate,\n                    self.mask_mod,\n                    fastdiv_mods,\n                    self.intra_wg_overlap,\n                    self.warp_scheduler_barrier_sync,\n                    self.warp_scheduler_barrier_arrive,\n                    self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n                    self.q_subtile_factor if self.q_subtile_factor is not None else 1,\n                )\n\n                # Handle empty case (when no blocks to process)\n                if not processed_any:\n                    softmax.reset()\n                    acc_O.fill(0.0)\n\n            sink_val = None\n            if const_expr(learnable_sink is not None):\n                if const_expr(not self.pack_gqa):\n                    sink_val = Float32(learnable_sink[head_idx])\n                else:  # Each thread might have a different sink value due to different q_head\n                    sink_val = cute.make_rmem_tensor_like(softmax.row_max, Float32)\n                    cS = cute.make_identity_tensor((self.tile_m, self.tile_n))\n                    tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS))\n                    for r in cutlass.range(cute.size(sink_val), unroll_full=True):\n                        row = m_block * self.tile_m + tScS_mn[r][0]\n                        q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead\n                        sink_val[r] = Float32(learnable_sink[q_head_idx])\n\n            # normalize acc_O by row_sum and calculate the 
lse\n            row_scale = softmax.finalize(sink_val=sink_val)\n            softmax.rescale_O(acc_O, row_scale)\n\n            # ///////////////////////////////////////////////////////////////////////////////\n            # Epilogue\n            # ///////////////////////////////////////////////////////////////////////////////\n            self.epilogue(\n                acc_O,\n                softmax.row_sum,\n                mO,\n                mLSE,\n                sO,\n                seqlen,\n                gmem_tiled_copy_O,\n                tma_atom_O,\n                tiled_mma_pv,\n                tidx,\n                m_block,\n                head_idx,\n                batch_idx,\n            )\n\n            # Release Q pipeline so the producer can load the next tile's Q\n            if const_expr(self.use_tma_Q):\n                pipeline_q.consumer_release_w_index(0)\n                mma_q_consumer_phase ^= 1\n\n            tile_scheduler.advance_to_next_work()\n            work_tile = tile_scheduler.get_current_work()\n\n    @cute.jit\n    def first_half_block_overlap(\n        self,\n        n_block: Int32,\n        mma_qk_fn: Callable,\n        kv_consumer_state,\n        pipeline_k,\n        tOrP: cute.Tensor,\n        smem_copy_params: SimpleNamespace,\n        softmax: Softmax,\n        seqlen: SeqlenInfoQK,\n        scores_scale: Optional[cute.Tensor] = None,\n        acc_O: Optional[cute.Tensor] = None,\n        mask_fn: Callable = None,\n        score_mod_fn: Optional[Callable] = None,\n        is_first_block: bool = False,\n    ):\n        \"\"\"Processes the first half block when using intra-warpgroup-overlap\"\"\"\n\n        pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state))\n        acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0)\n        pipeline_k.consumer_release(kv_consumer_state)\n\n        # Apply score modification if present\n        if const_expr(score_mod_fn is not 
None):\n            score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)\n\n        # Apply mask; mask_seqlen always True for first block\n        # Caveat: if full block further right than mask block, seqlen masking is redundant;\n        # however, masking is being applied anyway, so essentially no perf hit\n        mask_fn(acc_S, n_block=n_block, mask_seqlen=True)\n\n        row_scale = softmax.online_softmax(acc_S, is_first=is_first_block)\n\n        tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)\n        tOrP_cur = (\n            tOrP\n            if const_expr(self.mma_pv_is_rs)\n            else cute.make_rmem_tensor_like(tOrP_acc, self.dtype)\n        )\n        tOrP_cur.store(tOrP_acc.load().to(self.dtype))\n\n        if const_expr(not self.mma_pv_is_rs):\n            tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)\n            cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)\n            # Fence and barrier to make smem store visible to WGMMA\n            cute.arch.fence_view_async_shared()\n            cute.arch.sync_warp()\n\n        # For RescaleOBeforeGemm: initialize acc_O\n        if const_expr(self.rescale_O_before_gemm):\n            acc_O.fill(0.0)\n            scores_scale.store(row_scale.load())\n\n        return kv_consumer_state\n\n    @cute.jit\n    def last_half_block_overlap(\n        self,\n        kv_consumer_state,\n        pipeline_v,\n        mma_pv_fn: Callable,\n        zero_init: bool,\n        scores_scale: Optional[cute.Tensor] = None,\n        softmax: Optional[Softmax] = None,\n        acc_O: Optional[cute.Tensor] = None,\n    ):\n        \"\"\"Processes the final PV GEMM when using intra-warpgroup-overlap\"\"\"\n\n        # For RescaleOBeforeGemm: rescale O before the final PV GEMM\n        if const_expr(self.rescale_O_before_gemm):\n            softmax.rescale_O(acc_O, scores_scale)\n\n        pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state))\n     
   mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0)\n        pipeline_v.consumer_release(kv_consumer_state)\n        kv_consumer_state.advance()\n        return kv_consumer_state\n\n    @cute.jit\n    def mma_one_n_block(\n        self,\n        smem_pipe_read: pipeline.PipelineState | pipeline_custom.PipelineStateSimple,\n        n_block: Int32,\n        mma_qk_fn: Callable,\n        mma_pv_fn: Callable,\n        pipeline_k: pipeline.PipelineAsync,\n        pipeline_v: pipeline.PipelineAsync,\n        acc_O: cute.Tensor,\n        tOrP: cute.Tensor,\n        smem_copy_params: SimpleNamespace,\n        softmax: Softmax,\n        seqlen: SeqlenInfoQK,\n        scores_scale: Optional[cute.Tensor] = None,  # not used\n        score_mod_fn: Optional[Callable] = None,\n        mask_fn: Optional[Callable] = None,\n        is_first_n_block: cutlass.Constexpr = False,\n        check_inf: cutlass.Constexpr = True,\n    ):\n        pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))\n        # S = Q @ K.T\n        acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)\n        self.warp_scheduler_barrier_arrive()\n        warpgroup.wait_group(0)\n        pipeline_k.consumer_release(smem_pipe_read)\n\n        # handle score mods and masking\n        if const_expr(score_mod_fn is not None):\n            score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)\n        if const_expr(mask_fn is not None):\n            mask_fn(acc_S=acc_S, n_block=n_block)\n\n        row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)\n        # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))\n        tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)\n        tOrP_cur = (\n            tOrP\n            if const_expr(self.mma_pv_is_rs)\n            else cute.make_rmem_tensor_like(tOrP_acc, self.dtype)\n        )\n        # 
tOrP.store(tOrP_acc.load().to(self.dtype))\n        # the \"to(self.dtype)\" conversion fails to vectorize for block sizes other\n        # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of\n        # 2 elements. So we just call ptx directly.\n        utils.cvt_f16(tOrP_acc, tOrP_cur)\n        if const_expr(not self.mma_pv_is_rs):\n            tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)\n            cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)\n        softmax.rescale_O(acc_O, row_scale)\n        if const_expr(not self.mma_pv_is_rs):\n            # Fence and barrier to make sure smem store is visible to WGMMA\n            cute.arch.fence_view_async_shared()\n            cute.arch.sync_warp()  # Only need syncwarp since each warp is using its own P values for MmaPV\n        pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read))\n        self.warp_scheduler_barrier_sync()\n        # O += P @ V\n        mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0)\n        pipeline_v.consumer_release(smem_pipe_read)\n        smem_pipe_read.advance()\n        return smem_pipe_read\n\n    @cute.jit\n    def mma_one_n_block_intrawg_overlap(\n        self,\n        smem_pipe_read: pipeline.PipelineState | pipeline_custom.PipelineStateSimple,\n        n_block: Int32,\n        mma_qk_fn: Callable,\n        mma_pv_fn: Callable,\n        pipeline_k: pipeline.PipelineAsync,\n        pipeline_v: pipeline.PipelineAsync,\n        acc_O: cute.Tensor,\n        tOrP: cute.Tensor,\n        smem_copy_params: SimpleNamespace,\n        softmax: Softmax,\n        seqlen: SeqlenInfoQK,\n        scores_scale: Optional[cute.Tensor] = None,\n        score_mod_fn: Optional[Callable] = None,\n        mask_fn: Optional[Callable] = None,\n        check_inf: cutlass.Constexpr = True,\n    ):\n        smem_pipe_read_v = smem_pipe_read.clone()\n        smem_pipe_read.advance()\n        
pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read))\n        self.warp_scheduler_barrier_sync()\n        # S = Q @ K.T\n        acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1)\n        # RescaleOBeforeGemm: rescale O while QK GEMM is in flight, before PV GEMM\n        if const_expr(self.rescale_O_before_gemm):\n            softmax.rescale_O(acc_O, scores_scale)\n        pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v))\n        # O += P @ V\n        mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1)\n        self.warp_scheduler_barrier_arrive()\n        warpgroup.wait_group(1)\n        pipeline_k.consumer_release(smem_pipe_read)\n\n        # handle score mods and masking\n        if const_expr(score_mod_fn is not None):\n            score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)\n        if const_expr(mask_fn is not None):\n            mask_fn(acc_S=acc_S, n_block=n_block)\n        # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))\n\n        row_scale = softmax.online_softmax(acc_S, check_inf=check_inf)\n        warpgroup.wait_group(0)\n        pipeline_v.consumer_release(smem_pipe_read_v)\n        tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)\n        tOrP_cur = (\n            tOrP\n            if const_expr(self.mma_pv_is_rs)\n            else cute.make_rmem_tensor_like(tOrP_acc, self.dtype)\n        )\n        # tOrP_cur.store(tOrP_acc.load().to(self.dtype))\n        # the \"to(self.dtype)\" conversion fails to vectorize for block sizes other\n        # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of\n        # 2 elements. 
So we just call ptx directly.\n        utils.cvt_f16(tOrP_acc, tOrP_cur)\n        if const_expr(not self.mma_pv_is_rs):\n            tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)\n            cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)\n        if const_expr(not self.rescale_O_before_gemm):\n            softmax.rescale_O(acc_O, row_scale)\n        if const_expr(self.rescale_O_before_gemm):\n            scores_scale.store(row_scale.load())\n        if const_expr(not self.mma_pv_is_rs):\n            # Fence and barrier to make sure smem store is visible to WGMMA\n            cute.arch.fence_view_async_shared()\n            cute.arch.sync_warp()  # Only need syncwarp since each warp is using its own P values for MmaPV\n        return smem_pipe_read\n\n    @cute.jit\n    def mma_init(self):\n        warp_group_idx = utils.canonical_warp_group_idx(sync=False)\n        if const_expr(self.use_scheduler_barrier):\n            if warp_group_idx == 1:\n                cute.arch.barrier_arrive(\n                    barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1),\n                    number_of_threads=2 * self.num_threads_per_warp_group,\n                )\n\n    @cute.jit\n    def apply_score_mod(\n        self,\n        thr_mma_qk,\n        batch_idx,\n        head_idx,\n        m_block,\n        acc_S,\n        n_block,\n        softmax_scale,\n        seqlen,\n        aux_tensors: Optional[list] = None,\n        fastdiv_mods=None,\n    ):\n        # Prepare index tensor\n        cS = cute.make_identity_tensor((self.tile_m, self.tile_n))\n        cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS)\n        tScS = thr_mma_qk.partition_C(cS)\n\n        apply_score_mod_inner(\n            acc_S,\n            tScS,\n            self.score_mod,\n            batch_idx,\n            head_idx,\n            softmax_scale,\n            self.vec_size,\n            self.qk_acc_dtype,\n            aux_tensors,\n           
 fastdiv_mods,\n            seqlen_info=seqlen,\n            constant_q_idx=None,\n            qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,\n        )\n\n    def warp_scheduler_barrier_sync(self):\n        if const_expr(self.use_scheduler_barrier):\n            cute.arch.barrier(\n                barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1)\n                - 1\n                + utils.canonical_warp_group_idx(sync=False),\n                number_of_threads=2 * self.num_threads_per_warp_group,\n            )\n\n    def warp_scheduler_barrier_arrive(self):\n        if const_expr(self.use_scheduler_barrier):\n            assert self.num_wg_mma in [2, 3]\n            cur_wg = utils.canonical_warp_group_idx(sync=False) - 1\n            if const_expr(self.num_wg_mma == 2):\n                next_wg = 1 - cur_wg\n            else:\n                t = cur_wg + 1\n                next_wg = t % self.num_wg_mma\n            cute.arch.barrier_arrive(\n                barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg,\n                number_of_threads=2 * self.num_threads_per_warp_group,\n            )\n"
  },
  {
    "path": "flash_attn/cute/interface.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0.\n\n# Supported features:\n# - BF16 & FP16 dtype\n# - noncausal & causal attention\n# - MHA, GQA, MQA\n# - hdim 64, 96, 128.\n# - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape)\n# - varlen\n# - sliding window\n# - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow)\n\n# Features not supported yet:\n# - split (i.e. FlashDecoding)\n# - tuned block sizes\n# - paged KV\n# - append KV to existing KV cache\n# - FP8\n# - bwd pass optimized for Hopper/Blackwell\n\nimport os\nimport math\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import Optional, Tuple, Callable\n\nimport torch\n\n\nimport cuda.bindings.driver as cuda\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Int32, Float32\nfrom quack.compile_utils import make_fake_tensor as fake_tensor\nfrom flash_attn.cute.cache_utils import get_jit_cache\nfrom flash_attn.cute.testing import is_fake_mode\n\n\nif os.environ.get(\"CUTE_DSL_PTXAS_PATH\", None) is not None:\n    from flash_attn.cute import cute_dsl_ptxas  # noqa: F401\n\n    # Patch to dump ptx and then use system ptxas to compile to cubin\n    cute_dsl_ptxas.patch()\n\n\nfrom flash_attn.cute import utils\nfrom flash_attn.cute import fa_logging\nfrom flash_attn.cute.cute_dsl_utils import (\n    to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims,\n)\nfrom flash_attn.cute.flash_fwd import FlashAttentionForwardSm80\nfrom flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90\nfrom flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100\nfrom flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120\nfrom flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess\nfrom 
flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80\nfrom flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90\nfrom flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100\nfrom flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120\nfrom flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess\nfrom flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine\n\nfrom flash_attn.cute.block_sparsity import (\n    BlockSparseTensorsTorch,\n    to_cute_block_sparse_tensors,\n    normalize_block_sparse_config,\n    normalize_block_sparse_config_bwd,\n)\n\ndef _parse_arch_str(arch_str):\n    \"\"\"Parse arch string (e.g. 'sm_80', 'sm_90a', '80', '100') to int (e.g. 80, 90, 100).\"\"\"\n    import re\n    match = re.match(r\"^(?:sm_?|SM_?)?(\\d+)(\\d)([af]?)$\", arch_str)\n    if not match:\n        raise ValueError(f\"Invalid arch format: {arch_str}\")\n    major, minor, _ = match.groups()\n    return int(major) * 10 + int(minor)\n\n\n@lru_cache(maxsize=None)\ndef _get_device_arch():\n    \"\"\"Cached device arch check.\n\n    Override with FLASH_ATTENTION_ARCH (e.g. 
'sm_80' or '80') to select which\n    kernel path to use (SM80/SM90/SM100/SM120) independently of the compilation\n    target (CUTE_DSL_ARCH).\n\n    For CPU-only compilation (no GPU), set both:\n      FLASH_ATTENTION_ARCH=sm_80  (kernel selection)\n      CUTE_DSL_ARCH=sm_80         (compilation target)\n    \"\"\"\n    arch_override = os.environ.get(\"FLASH_ATTENTION_ARCH\", None)\n    if arch_override is not None:\n        return _parse_arch_str(arch_override)\n    major, minor = torch.cuda.get_device_capability()\n    return major * 10 + int(minor)\n\n\ndef _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None:\n    \"\"\"Validate head dimension constraints based on compute capability.\"\"\"\n    is_deepseek_shape = head_dim == 192 and head_dim_v == 128\n    is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128\n\n    is_sm90_range = 8 <= head_dim <= 256 and 8 <= head_dim_v <= 256\n    if compute_capability == 9:\n        assert is_sm90_range and head_dim % alignment == 0 and head_dim_v % alignment == 0, (\n            f\"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. \"\n            f\"head_dim and head_dim_v must be between 8 and 256 and divisible by {alignment}.\"\n        )\n    elif compute_capability in [10, 11]:\n        assert (is_standard_range or is_deepseek_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, (\n            f\"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. 
\"\n            f\"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek.\"\n        )\n\n\n@dataclass(frozen=True)\nclass FwdConfig:\n    m_block_size: int\n    n_block_size: int\n    mma_pv_is_rs: bool\n    intra_wg_overlap: bool\n\n\ndef _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, use_block_sparsity):\n    \"\"\"Return FwdConfig for SM90 forward.\n\n    Tile sizes and flags based on tile_size_fwd_sm90 in hopper/tile_size.h, adjusted\n    for the Python kernel's different register/smem tradeoffs (benchmarked on H100 SXM).\n    \"\"\"\n    if head_dim <= 64:\n        # C++: 192×192 non-causal, 192×128 causal/local.\n        # Python: 192×128 RS+OL is consistently best across seqlens.\n        return FwdConfig(192, 128, True, True)\n    elif head_dim <= 96:\n        # C++: 192×144 noRS+OL for all cases.\n        # Python: RS is catastrophic with 192× tiles (~300 vs ~600 TFLOPS).\n        # noRS+OL is always required. 
Causal: 192×128 slightly better short seqlen.\n        if is_causal or is_local:\n            return FwdConfig(192, 128, False, True)\n        else:\n            return FwdConfig(192, 144, False, True)\n    elif head_dim <= 128:\n        return FwdConfig(128, 128, True, True)\n    elif head_dim <= 192:\n        tile_n = 96 if is_local else (128 if head_dim_v <= 128 else 112)\n        return FwdConfig(128, tile_n, True, True)\n    else:  # hdim 256\n        tile_n = 64 if is_local else 80\n        return FwdConfig(128, tile_n, True, True)\n\n\n@dataclass(frozen=True)\nclass BwdConfig:\n    m_block_size: int\n    n_block_size: int\n    num_stages_Q: int\n    num_stages_dO: int\n    num_stages_PdS: int\n    SdP_swapAB: bool\n    dKV_swapAB: bool\n    dQ_swapAB: bool\n    AtomLayoutMSdP: int\n    AtomLayoutNdKV: int\n    AtomLayoutMdQ: int\n    num_wg: int = 2  # MMA warp groups (total threads = (num_wg + 1) * 128)\n    dQ_single_wg: bool = False\n\n\ndef _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local):\n    \"\"\"Return BwdConfig for SM90.\n\n    Configs based on C++ FA3 hopper/flash_bwd_launch_template.h,\n    benchmarked on H100 SXM.\n    \"\"\"\n    if head_dim <= 64:\n        # C++ FA3: 128, 128, 64, ..., 2, 2, true, false, false, 2, 1, 2, 2\n        return BwdConfig(\n            m_block_size=128, n_block_size=128,\n            num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2,\n            SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False,\n            AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=2,\n        )\n    elif head_dim <= 96:\n        # C++ FA3: 64, 128, 96, dQ_swapAB=False\n        return BwdConfig(\n            m_block_size=64, n_block_size=128,\n            num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2,\n            SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False,\n            AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,\n            dQ_single_wg=True,\n        )\n    elif head_dim <= 128:\n        # C++ FA3: 
causal/local: 64, 128; non-causal: 80, 128 with dQ_swapAB\n        is_causal_or_local = causal or local\n        m_block_size = 64 if is_causal_or_local else 80\n        return BwdConfig(\n            m_block_size=m_block_size,\n            n_block_size=128,\n            num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2,\n            SdP_swapAB=True, dKV_swapAB=False,\n            dQ_swapAB=m_block_size % 64 != 0,\n            AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,\n        )\n    elif head_dim <= 192:\n        hdimv128 = head_dim_v <= 128\n        if hdimv128:\n            return BwdConfig(\n                m_block_size=64, n_block_size=96,\n                num_stages_Q=2, num_stages_dO=2, num_stages_PdS=1,\n                SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False,\n                AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,\n                num_wg=2,\n            )\n        else:\n            return BwdConfig(\n                m_block_size=64, n_block_size=96,\n                num_stages_Q=2, num_stages_dO=1, num_stages_PdS=1,\n                SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False,\n                AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,\n                num_wg=2,\n            )\n    else:\n        # hdim 256\n        return BwdConfig(\n            m_block_size=64, n_block_size=64,\n            num_stages_Q=1, num_stages_dO=1, num_stages_PdS=1,\n            SdP_swapAB=False, dKV_swapAB=False, dQ_swapAB=False,\n            AtomLayoutMSdP=1, AtomLayoutNdKV=1, AtomLayoutMdQ=1,\n        )\n\n\n\ndef maybe_contiguous(x):\n    return x.contiguous() if x is not None and x.stride(-1) != 1 else x\n\n\ndef _validate_tensor(t, name, expected_shape, expected_dtype, expected_device):\n    assert t.shape == expected_shape, f\"{name} shape {t.shape} != expected {expected_shape}\"\n    assert t.dtype == expected_dtype, f\"{name} dtype {t.dtype} != expected {expected_dtype}\"\n    assert t.device == expected_device, 
f\"{name} device {t.device} != expected {expected_device}\"\n    if not is_fake_mode():\n        assert t.is_cuda, f\"{name} must be on CUDA\"\n\n\ntorch2cute_dtype_map = {\n    torch.float16: cutlass.Float16,\n    torch.bfloat16: cutlass.BFloat16,\n    torch.float32: cutlass.Float32,\n}\n\n\ndef num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits):\n    # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.\n    if num_n_blocks <= 4:\n        return 1\n\n    # NOTE: We should revisit this heuristic after persistence is supported for split KV.\n    # Sometimes, it's ideal to over-schedule splits for better efficiency.\n    return min(num_SMs // total_mblocks, max_splits, num_n_blocks)\n\n\ndef _resolve_causal_local_window(causal, window_size_left, window_size_right, mask_mod=None):\n    \"\"\"Resolve causal/local/window settings into canonical form.\n\n    Returns (causal, local, window_size_left, window_size_right).\n    \"\"\"\n    if mask_mod is not None:\n        return False, False, window_size_left, window_size_right\n    if causal:\n        window_size_right = 0\n    if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0:\n        window_size_left = None\n        window_size_right = None\n    if window_size_left is not None or window_size_right is not None:\n        if window_size_left is None and window_size_right == 0:\n            causal, local = True, False\n            window_size_right = None\n        else:\n            causal, local = False, True\n    else:\n        local = False\n    return causal, local, window_size_left, window_size_right\n\n\ndef _flash_attn_fwd(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k: Optional[torch.Tensor] = None,\n    seqused_q: Optional[torch.Tensor] = None,\n    seqused_k: Optional[torch.Tensor] = None,\n    
max_seqlen_q: Optional[int] = None,\n    max_seqlen_k: Optional[int] = None,\n    page_table: Optional[torch.Tensor] = None,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    softcap: Optional[float] = None,\n    window_size_left: Optional[int] = None,\n    window_size_right: Optional[int] = None,\n    learnable_sink: Optional[torch.Tensor] = None,\n    tile_mn: Optional[Tuple[int, int]] = None,\n    mma_pv_is_rs: Optional[bool] = None,\n    intra_wg_overlap: Optional[bool] = None,\n    num_threads: int = 384,\n    num_splits: int = 1,\n    pack_gqa: Optional[bool] = None,\n    _arch: Optional[int] = None,\n    score_mod: Optional[Callable] = None,\n    mask_mod: Optional[Callable] = None,\n    block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,\n    return_lse: bool = False,\n    out: Optional[torch.Tensor] = None,\n    lse: Optional[torch.Tensor] = None,\n    aux_tensors: Optional[list[torch.Tensor]] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Forward pass for FlashAttention.\n\n    Args:\n        ...\n        score_mod: A callable that takes the attention scores and applies a modification.\n        mask_mod: A callable that takes token position information and selectively masks\n            attention scores.\n        block_sparse_tensors: A tuple of tensors used for block sparsity.\n        return_lse: Whether to return the log-sum-exp (LSE) of the attention scores. If True, the LSE\n            is always computed. The returned LSE supports taking gradients.\n        out: Optional pre-allocated output tensor. If None, will be allocated internally.\n        lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.\n        aux_tensors: Some score_mods will want to read from global aux_tensors. 
This is how we thread them through to the inner kernel.\n    \"\"\"\n    q, k, v = [maybe_contiguous(t) for t in (q, k, v)]\n    num_head, head_dim = q.shape[-2:]\n    if cu_seqlens_q is None:\n        batch_size, seqlen_q = q.shape[:2]\n        total_q = batch_size * seqlen_q\n    else:\n        batch_size = cu_seqlens_q.shape[0] - 1\n        seqlen_q = None\n        total_q = q.shape[0]\n    if page_table is not None:\n        assert cu_seqlens_k is None, \"page_table is not supported with cu_seqlens_k\"\n        assert page_table.dtype == torch.int32, \"page_table must be int32\"\n        assert page_table.stride(-1) == 1, \"page_table must be contiguous in the last dimension\"\n        max_num_pages_per_seq = page_table.shape[1]\n        assert page_table.shape == (batch_size, max_num_pages_per_seq)\n        num_pages, page_size = k.shape[:2]\n        seqlen_k = num_pages * page_size\n    else:\n        num_pages, page_size = None, None\n        seqlen_k = k.shape[-3]\n    num_head_kv = k.shape[-2]\n    head_dim_v = v.shape[-1]\n    if cu_seqlens_k is None:\n        if page_table is None:\n            assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)\n            assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)\n        else:\n            assert k.shape == (num_pages, page_size, num_head_kv, head_dim)\n            assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v)\n    else:\n        assert k.shape == (seqlen_k, num_head_kv, head_dim)\n        assert v.shape == (seqlen_k, num_head_kv, head_dim_v)\n        assert cu_seqlens_k.shape == (batch_size + 1,), (\n            \"cu_seqlens_k must have shape (batch_size + 1,)\"\n        )\n\n    if cu_seqlens_q is not None:\n        assert cu_seqlens_q.shape == (batch_size + 1,), (\n            \"cu_seqlens_q must have shape (batch_size + 1,)\"\n        )\n    assert seqused_q is None or seqused_q.shape == (batch_size,), (\n        \"seqused_q must have shape (batch_size,)\"\n 
   )\n    assert seqused_k is None or seqused_k.shape == (batch_size,), (\n        \"seqused_k must have shape (batch_size,)\"\n    )\n    assert q.dtype in [torch.float16, torch.bfloat16], \"inputs must be float16 or bfloat16\"\n    assert q.dtype == k.dtype == v.dtype, \"inputs must have the same dtype\"\n    for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]:\n        if t is not None:\n            assert t.dtype == torch.int32, (\n                \"cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32\"\n            )\n            assert t.stride(0) == 1, (\n                \"cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous\"\n            )\n    if learnable_sink is not None:\n        assert learnable_sink.shape == (num_head,)\n        assert learnable_sink.dtype == torch.bfloat16, \"learnable_sink must be bfloat16\"\n\n    if not is_fake_mode():\n        assert all(\n            t is None or t.is_cuda\n            for t in (\n                q,\n                k,\n                v,\n                cu_seqlens_q,\n                cu_seqlens_k,\n                seqused_q,\n                seqused_k,\n                page_table,\n                learnable_sink,\n            )\n        ), \"inputs must be on CUDA device\"\n    arch = _get_device_arch() if _arch is None else _arch\n    assert arch // 10 in [8, 9, 10, 11, 12], \"Unsupported compute capability. 
Supported: 8.x, 9.x, 10.x, 11.x, 12.x\"\n    assert num_head % num_head_kv == 0, \"num_head must be divisible by num_head_kv\"\n    alignment = 16 // q.element_size()\n    if arch // 10 not in [8, 12]:\n        _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment)\n    if softmax_scale is None:\n        softmax_scale = 1.0 / math.sqrt(head_dim)\n    if softcap == 0.0:\n        softcap = None\n    qhead_per_kvhead = num_head // num_head_kv\n    if pack_gqa is None:\n        pack_gqa = qhead_per_kvhead > 1\n\n    out_torch_dtype = q.dtype\n    device = q.device\n    q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,)\n    lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q)\n    requires_grad = q.requires_grad or k.requires_grad or v.requires_grad\n\n    if out is None:\n        out = torch.empty(\n            *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device\n        )\n    else:\n        _validate_tensor(out, \"out\", (*q_batch_seqlen_shape, num_head, head_dim_v), out_torch_dtype, device)\n\n    if lse is None:\n        lse = (\n            torch.empty(lse_shape, dtype=torch.float32, device=device)\n            if requires_grad or return_lse\n            else None\n        )\n    else:\n        _validate_tensor(lse, \"lse\", lse_shape, torch.float32, device)\n\n    dtype = torch2cute_dtype_map[q.dtype]\n    use_block_sparsity = block_sparse_tensors is not None\n\n    causal, local, window_size_left, window_size_right = _resolve_causal_local_window(\n        causal, window_size_left, window_size_right, mask_mod\n    )\n\n    # In fake mode (CPU-only compilation), use a fake stream placeholder.\n    current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)\n\n    # SM80/SM120: uses SM80 MMA, 128 threads (4 warps)\n    if arch // 10 in [8, 12]:\n        num_threads = 128\n\n    fwd_cfg = FwdConfig(128, 128, True, True) 
 # default\n    if tile_mn is None:\n        if arch // 10 == 12:\n            # SM120 tile sizes tuned for 99 KB SMEM capacity:\n            # D<=64:  128x128 → 48 KB (good occupancy)\n            # D>64:   128x64  → 64 KB (128x128 would use 96 KB, hurting occupancy)\n            if head_dim <= 64:\n                fwd_cfg = FwdConfig(128, 128, True, True)\n            else:\n                fwd_cfg = FwdConfig(128, 64, True, True)\n        elif arch // 10 == 8:\n            fwd_cfg = FwdConfig(128, 64, True, True)  # SM80, should tune\n        elif arch // 10 == 9:\n            fwd_cfg = _tile_size_fwd_sm90(head_dim, head_dim_v, causal, local, use_block_sparsity)\n    else:\n        fwd_cfg = FwdConfig(tile_mn[0], tile_mn[1], fwd_cfg.mma_pv_is_rs, fwd_cfg.intra_wg_overlap)\n    tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size\n    if mma_pv_is_rs is None:\n        mma_pv_is_rs = fwd_cfg.mma_pv_is_rs\n    if intra_wg_overlap is None:\n        intra_wg_overlap = fwd_cfg.intra_wg_overlap\n\n    # TODO: fix GQA + SplitKV + non-varlen\n    if pack_gqa and num_splits != 1 and cu_seqlens_q is None:\n        pack_gqa = False\n\n    if arch // 10 in [10, 11]:\n        if pack_gqa and (128 % qhead_per_kvhead != 0):\n            pack_gqa = False\n\n    if max_seqlen_q is None:\n        max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q\n    if max_seqlen_k is None:\n        max_seqlen_k = seqlen_k\n    seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead\n    if arch // 10 == 10:\n        q_stage = 2 if seqlen_q_packgqa > tile_m else 1\n    else:\n        q_stage = 1\n\n    m_block_size_effective = q_stage * tile_m\n    seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, (window_size_right or max_seqlen_k) + (window_size_left or max_seqlen_k) + 1 + tile_m))\n    num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective\n    total_mblocks = batch_size * num_head_kv * num_m_blocks\n    num_n_blocks = 
(seqlen_k_loaded + tile_n - 1) // tile_n\n    num_SMs = 132 if is_fake_mode() else torch.cuda.get_device_properties(device).multi_processor_count\n    if num_splits < 1:\n        num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128)\n\n    # SplitKV uses float32 partial output, which doubles the O buffer size\n    # in shared memory, causing OOM for diff-headdim (192, 128)\n    if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1:\n        if num_n_blocks >= 64:\n            tile_n = 64\n            num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n\n            num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128)\n        else:\n            num_splits = 1\n\n    is_split_kv = num_splits > 1\n    if is_split_kv:\n        out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device)\n        lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device)\n\n    use_2cta_instrs = (\n        arch // 10 in [10, 11]\n        and not causal\n        and not local\n        and not is_split_kv\n        and cu_seqlens_q is None\n        and seqused_q is None\n        and not use_block_sparsity\n        and page_size in [None, 128]\n        and int(math.ceil(head_dim / 16) * 16) == 128\n        and int(math.ceil(head_dim_v / 16) * 16) == 128\n        and seqlen_q_packgqa > 2 * tile_m\n    )\n\n    # hash score and mask mods for compile cache\n    score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False\n    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False\n\n    if softcap is not None:\n        assert score_mod is None, \"softcap and score_mod cannot be used together\"\n        score_mod = utils.create_softcap_scoremod(softcap)\n\n    is_varlen = (\n        cu_seqlens_q is not None\n        or cu_seqlens_k is not None\n        or seqused_q is not None\n        or 
seqused_k is not None\n    )\n\n    if mask_mod is not None:\n        if is_varlen:\n            raise NotImplementedError(\n                \"mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR.\"\n            )\n\n    if use_block_sparsity:\n        if is_varlen:\n            raise NotImplementedError(\n                \"Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR.\"\n            )\n        # NB: pack_gqa requires block sparse head dim == 1 (broadcasted)\n        if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1:\n            pack_gqa = False\n        if is_split_kv:\n            raise NotImplementedError(\n                \"Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split.\"\n            )\n\n    # See get_broadcast_dims for why this is needed in compile key\n    block_sparse_broadcast_pattern = None\n    normalized_block_sparse_tensors = None\n    q_subtile_factor = None\n    if block_sparse_tensors is not None:\n        if seqlen_q is None:\n            raise ValueError(\"Block sparsity requires fixed-length sequences (seqlen_q must be known).\")\n        (\n            normalized_block_sparse_tensors,\n            block_sparse_broadcast_pattern,\n            q_subtile_factor,\n        ) = normalize_block_sparse_config(\n            block_sparse_tensors,\n            batch_size=batch_size,\n            num_head=num_head,\n            seqlen_q=seqlen_q,\n            seqlen_k=seqlen_k,\n            block_size=(tile_m, tile_n),\n            q_stage=q_stage,\n        )\n    if aux_tensors is not None:\n        aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors)\n    else:\n        aux_tensor_metadata = None\n\n    compile_key = (\n        dtype,\n        head_dim,\n        head_dim_v,\n        qhead_per_kvhead,\n        causal,\n        score_mod_hash,\n        mask_mod_hash,\n        
use_block_sparsity,\n        block_sparse_broadcast_pattern,\n        aux_tensor_metadata,\n        lse is None,\n        cu_seqlens_q is None,\n        cu_seqlens_k is None,\n        seqused_q is None,\n        seqused_k is None,\n        page_table is not None,\n        window_size_left is not None,\n        window_size_right is not None,\n        learnable_sink is not None,\n        tile_m,\n        tile_n,\n        q_stage,\n        num_threads,\n        is_split_kv,\n        pack_gqa,\n        arch,\n        page_size not in [None, tile_n],  # paged KV non-TMA\n        use_2cta_instrs,\n        q_subtile_factor,\n        mma_pv_is_rs,\n        intra_wg_overlap,\n        fa_logging.get_fa_log_level(),\n    )\n    if compile_key not in _flash_attn_fwd.compile_cache:\n        (\n            cu_seqlens_q_tensor,\n            cu_seqlens_k_tensor,\n            seqused_q_tensor,\n            seqused_k_tensor,\n            learnable_sink_tensor,\n        ) = [\n            to_cute_tensor(t, assumed_align=4, leading_dim=0)\n            if t is not None\n            else None\n            for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)\n        ]\n        page_table_tensor = (\n            to_cute_tensor(page_table, assumed_align=4, leading_dim=1)\n            if page_table is not None\n            else None\n        )\n        q_tensor, k_tensor, v_tensor, o_tensor = [\n            to_cute_tensor(t) for t in (q, k, v, out if not is_split_kv else out_partial)\n        ]\n        if is_split_kv:\n            lse_tensor = to_cute_tensor(lse_partial, assumed_align=4)\n        elif lse is not None:\n            lse_tensor = to_cute_tensor(lse, assumed_align=4)\n        else:\n            lse_tensor = None\n\n        sparse_tensors = None\n        if normalized_block_sparse_tensors is not None:\n            sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)\n\n        cute_aux_tensors = None\n        aux_tensor_metadata 
= None\n        if aux_tensors is not None:\n            cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors]\n\n        if arch // 10 == 8:\n            assert page_table is None, \"paged KV not supported on SM 8.0\"\n            assert not is_split_kv, \"SplitKV not supported on SM 8.0\"\n            fa_fwd = FlashAttentionForwardSm80(\n                dtype,\n                head_dim,\n                head_dim_v,\n                qhead_per_kvhead,\n                is_causal=causal,\n                is_local=local,\n                pack_gqa=pack_gqa,\n                tile_m=tile_m,\n                tile_n=tile_n,\n                num_stages=1,\n                num_threads=num_threads,\n                Q_in_regs=False,\n                score_mod=score_mod,\n                mask_mod=mask_mod,\n                has_aux_tensors=aux_tensors is not None,\n            )\n        elif arch // 10 == 9:\n            assert not is_split_kv, \"SplitKV not supported on SM 9.0\"\n            fa_fwd = FlashAttentionForwardSm90(\n                dtype,\n                head_dim,\n                head_dim_v,\n                qhead_per_kvhead,\n                is_causal=causal,\n                is_local=local,\n                pack_gqa=pack_gqa,\n                tile_m=tile_m,\n                tile_n=tile_n,\n                # num_stages=1,\n                num_stages=2,\n                num_threads=num_threads,\n                Q_in_regs=False,\n                intra_wg_overlap=intra_wg_overlap,\n                mma_pv_is_rs=mma_pv_is_rs,\n                mask_mod=mask_mod,\n                score_mod=score_mod,\n                has_aux_tensors=aux_tensors is not None,\n                q_subtile_factor=q_subtile_factor,\n                paged_kv_non_tma=page_size not in [None, tile_n],\n            )\n        elif arch // 10 in [10, 11]:\n            fa_fwd = FlashAttentionForwardSm100(\n                head_dim,\n                head_dim_v,\n                
qhead_per_kvhead=qhead_per_kvhead,\n                is_causal=causal,\n                is_local=local,\n                is_split_kv=is_split_kv,\n                pack_gqa=pack_gqa,\n                m_block_size=tile_m,\n                n_block_size=tile_n,\n                q_stage=q_stage,\n                is_persistent=not causal\n                    and not local\n                    and cu_seqlens_q is None\n                    and seqused_q is None\n                    and not is_split_kv,\n                score_mod=score_mod,\n                mask_mod=mask_mod,\n                has_aux_tensors=aux_tensors is not None,\n                paged_kv_non_tma=page_size not in [None, tile_n],\n                is_varlen_q=cu_seqlens_q is not None or seqused_q is not None,\n                q_subtile_factor=q_subtile_factor,\n                use_2cta_instrs=use_2cta_instrs,\n            )\n        elif arch // 10 == 12:\n            # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity\n            assert not use_block_sparsity, \"Block sparsity not supported on SM 12.0\"\n            assert page_table is None, \"Paged KV not supported on SM 12.0 in this PR\"\n            assert not is_split_kv, \"SplitKV not supported on SM 12.0 in this PR\"\n            fa_fwd = FlashAttentionForwardSm120(\n                dtype,\n                head_dim,\n                head_dim_v,\n                qhead_per_kvhead,\n                is_causal=causal,\n                is_local=local,\n                pack_gqa=pack_gqa,\n                tile_m=tile_m,\n                tile_n=tile_n,\n                num_stages=1,\n                num_threads=num_threads,\n                Q_in_regs=False,\n                score_mod=score_mod,\n                mask_mod=mask_mod,\n                has_aux_tensors=aux_tensors is not None,\n            )\n        else:\n            raise ValueError(\n                f\"Unsupported compute capability: {arch}. 
Supported: 8.x, 9.x, 10.x, 11.x, 12.x\"\n            )\n        # TODO: check @can_implement\n        _flash_attn_fwd.compile_cache[compile_key] = cute.compile(\n            fa_fwd,\n            q_tensor,\n            k_tensor,\n            v_tensor,\n            o_tensor,\n            lse_tensor,\n            softmax_scale,\n            cu_seqlens_q_tensor,\n            cu_seqlens_k_tensor,\n            seqused_q_tensor,\n            seqused_k_tensor,\n            page_table_tensor,\n            window_size_left,\n            window_size_right,\n            learnable_sink_tensor,\n            sparse_tensors,\n            cute_aux_tensors,\n            current_stream,\n            options=\"--enable-tvm-ffi\",\n        )\n\n    # In \"fake mode\", we will take torch fake tensors as input and the expected behaviors are:\n    # - Use those fake metadata to populate compilation cache\n    # - Return \"fake\" output tensors, which could be needed in follow-up fake operations\n    # Thus, we skip the actual kernel invocation here.\n    if not is_fake_mode():\n        _flash_attn_fwd.compile_cache[compile_key](\n            q.detach(),\n            k.detach(),\n            v.detach(),\n            out.detach() if not is_split_kv else out_partial,\n            lse_partial if is_split_kv else lse,\n            softmax_scale,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            page_table,\n            window_size_left,\n            window_size_right,\n            learnable_sink,\n            normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,\n            aux_tensors,\n        )\n    if is_split_kv:\n        _flash_attn_fwd_combine(\n            out_partial,\n            lse_partial.transpose(-1, -2),\n            out,\n            lse.transpose(-1, -2) if lse is not None else None,\n            cu_seqlens_q,\n            seqused_q,\n        )\n    return out, 
lse\n\n\n_flash_attn_fwd.compile_cache = get_jit_cache(\"fwd\")\n\n\ndef make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k):\n    sym = cute.sym_int\n    # divisibility in elements: assumed_align_bytes = divisibility * dtype.width // 8\n    # For 16-byte align: fp16/bf16 → divisibility=8, float32 → divisibility=4\n    div = 128 // dtype.width  # 8 for fp16/bf16\n    # Shared sym_ints for dimensions that must match across tensors\n    b, seqlen_q, seqlen_k, h_q, d, d_v = sym(), sym(), sym(), sym(), sym(), sym()\n    h_kv = h_q if not has_gqa else sym()\n    seqlen_q_rounded, seqlen_k_rounded = sym(), sym()\n    seqlen_q_d_rounded, seqlen_k_d_rounded, seqlen_k_dv_rounded = sym(), sym(), sym()\n    total_q, total_k, total_q_rounded, total_k_rounded = sym(), sym(), sym(), sym()\n    total_q_d_rounded, total_k_d_rounded, total_k_dv_rounded = sym(), sym(), sym()\n    b_seqlenq = (b, seqlen_q) if not varlen_q else (total_q,)\n    b_seqlenk = (b, seqlen_k) if not varlen_k else (total_k,)\n    mQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div)\n    mO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div)\n    mdO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div)\n    mK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div)\n    mV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div)\n    mdQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div)\n    mdK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div)\n    mdV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div)\n    if not varlen_q:\n        mLSE = fake_tensor(Float32, (b, h_q, seqlen_q), divisibility=1)\n        mLSElog2 = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4)\n        mPdPsum = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4)\n        dQaccum = fake_tensor(Float32, (b, h_q, seqlen_q_d_rounded), divisibility=4)\n    else:\n        mLSE = fake_tensor(Float32, (h_q, total_q), 
divisibility=1)\n        mLSElog2 = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4)\n        mPdPsum = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4)\n        dQaccum = fake_tensor(Float32, (h_q, total_q_d_rounded), divisibility=4)\n    if not has_gqa:\n        mdKaccum, mdVaccum = None, None\n    else:\n        if not varlen_k:\n            mdKaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_rounded), divisibility=4)\n            mdVaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_dv_rounded), divisibility=4)\n        else:\n            mdKaccum = fake_tensor(Float32, (h_kv, total_k_rounded), divisibility=4)\n            mdVaccum = fake_tensor(Float32, (h_kv, total_k_dv_rounded), divisibility=4)\n    return mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, dQaccum, mdKaccum, mdVaccum\n\n\ndef _compile_bwd_preprocess(\n    dtype, head_dim, head_dim_v, m_block_size, has_cuseqlens_q, has_seqused_q, has_dlse,\n):\n    \"\"\"Compile bwd preprocess kernel using cute fake tensors (no real GPU tensors needed).\"\"\"\n    mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors(\n        dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False\n    )\n    batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int()\n    batchp1 = cute.sym_int()\n    mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None\n    mSequsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None\n    mdLSE = fake_tensor(Float32, mLSE.shape, divisibility=1) if has_dlse else None\n    fa_bwd_pre = FlashAttentionBackwardPreprocess(dtype, head_dim, head_dim_v, m_block_size)\n    return cute.compile(\n        fa_bwd_pre, mO, mdO, mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSequsedQ, mdLSE,\n        cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),\n        options=\"--enable-tvm-ffi\",\n    )\n\n\ndef _bwd_preprocess(\n    out, dout, 
dpsum, lse, lse_log2, dq_accum,\n    cu_seqlens_q, seqused_q, dlse,\n    dtype, head_dim, head_dim_v, m_block_size,\n):\n    \"\"\"Backward preprocess: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum.\"\"\"\n    is_varlen = cu_seqlens_q is not None\n    compile_key = (\n        dtype, head_dim, head_dim_v, m_block_size, is_varlen, seqused_q is not None, dlse is not None,\n    )\n    if compile_key not in _bwd_preprocess.compile_cache:\n        _bwd_preprocess.compile_cache[compile_key] = _compile_bwd_preprocess(*compile_key)\n    if not is_fake_mode():\n        _bwd_preprocess.compile_cache[compile_key](\n            out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse\n        )\n\n\n_bwd_preprocess.compile_cache = get_jit_cache(\"bwd_pre\")\n\n\ndef _compile_bwd_postprocess(\n    dtype, hdim, block_size, num_threads, atom_layout, swap_ab,\n    has_cuseqlens_q, has_seqused_q,\n    use_2cta_instrs, cluster_size, arch,\n):\n    \"\"\"Compile bwd postprocess kernel using cute fake tensors.\"\"\"\n    mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors(\n        dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False\n    )\n    batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int()\n    batchp1 = cute.sym_int()\n    mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None\n    mSeqUsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None\n    fa_bwd_post = FlashAttentionBackwardPostprocess(\n        dtype, hdim, arch, block_size, num_threads, atom_layout, swap_ab,\n        use_2cta_instrs=use_2cta_instrs,\n        cluster_size=cluster_size,\n    )\n    return cute.compile(\n        fa_bwd_post, mdQaccum, mdQ, Float32(0.0), mCuSeqlensQ, mSeqUsedQ,\n        cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),\n        options=\"--enable-tvm-ffi\",\n    )\n\n\ndef 
_bwd_postprocess_convert(\n    accum, output, scale,\n    cu_seqlens, seqused,\n    arch, dtype, hdim, block_size, num_threads,\n    atom_layout, swap_ab,\n    use_2cta_instrs=False, cluster_size=1,\n):\n    \"\"\"Backward postprocess: convert float32 accumulator to bf16/fp16 output.\"\"\"\n    compile_key = (\n        dtype, hdim, block_size, num_threads, atom_layout, swap_ab,\n        cu_seqlens is not None, seqused is not None,\n        use_2cta_instrs, cluster_size, arch,\n    )\n    if compile_key not in _bwd_postprocess_convert.compile_cache:\n        _bwd_postprocess_convert.compile_cache[compile_key] = _compile_bwd_postprocess(*compile_key)\n    if not is_fake_mode():\n        _bwd_postprocess_convert.compile_cache[compile_key](\n            accum, output, scale, cu_seqlens, seqused,\n        )\n\n\n_bwd_postprocess_convert.compile_cache = get_jit_cache(\"bwd_post\")\n\n\ndef _flash_attn_bwd(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    out: torch.Tensor,\n    dout: torch.Tensor,\n    lse: torch.Tensor,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    softcap: float = 0.0,\n    window_size_left: Optional[int] = None,\n    window_size_right: Optional[int] = None,\n    m_block_size: int = 64,\n    n_block_size: int = 128,\n    num_threads: int = 256,\n    pack_gqa: bool = False,\n    num_stages_Q: int = 2,\n    num_stages_dO: int = 2,\n    SdP_swapAB: bool = False,\n    dKV_swapAB: bool = False,\n    dQ_swapAB: bool = False,\n    AtomLayoutMSdP: int = 2,\n    AtomLayoutNdKV: int = 2,\n    AtomLayoutMdQ: int = 2,\n    V_in_regs: bool = False,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k: Optional[torch.Tensor] = None,\n    seqused_q: Optional[torch.Tensor] = None,\n    seqused_k: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    max_seqlen_k: Optional[int] = None,\n    deterministic: bool = False,\n    dq: Optional[torch.Tensor] = None,\n    dk: 
Optional[torch.Tensor] = None,\n    dv: Optional[torch.Tensor] = None,\n    score_mod: Optional[Callable] = None,\n    score_mod_bwd: Optional[Callable] = None,\n    mask_mod: Optional[Callable] = None,\n    aux_tensors: Optional[list[torch.Tensor]] = None,\n    block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,\n    dlse: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    arch = _get_device_arch()\n    assert arch // 10 in [9, 10, 11, 12], \"Unsupported compute capability. Supported: 9.x, 10.x, 11.x, 12.x\"\n\n    num_head, head_dim = q.shape[-2:]\n    head_dim_v = v.shape[-1]\n\n    causal, local, window_size_left, window_size_right = _resolve_causal_local_window(\n        causal, window_size_left, window_size_right\n    )\n\n    if arch // 10 == 12:\n        # SM120: uses SM80 MMA with 99 KB SMEM, 128 threads (4 warps).\n        m_block_size = 64\n        n_block_size = 64\n        if head_dim <= 64:\n            num_stages_Q = 2\n            num_stages_dO = 2\n        else:\n            num_stages_Q = 1\n            num_stages_dO = 1\n        SdP_swapAB = False\n        dKV_swapAB = False\n        dQ_swapAB = False\n        AtomLayoutMSdP = 4\n        AtomLayoutNdKV = 4\n        AtomLayoutMdQ = 4\n        V_in_regs = False\n        cluster_size = 1\n        use_2cta_instrs = False\n        num_threads = 128\n        # dQ_single_wg is an SM90-only option, but the shared compile_key below reads it on SM120 too.\n        dQ_single_wg = False\n        assert block_sparse_tensors is None, \"Block sparsity backward not supported on SM 12.0\"\n        assert score_mod is None and score_mod_bwd is None, \"score_mod backward not supported on SM 12.0\"\n        assert mask_mod is None, \"mask_mod backward not supported on SM 12.0\"\n        assert deterministic is False, \"deterministic backward not supported on SM 12.0\"\n    elif arch // 10 == 9:\n        cfg = _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local)\n        m_block_size = cfg.m_block_size\n        n_block_size = cfg.n_block_size\n        num_stages_Q = 
cfg.num_stages_Q\n        num_stages_dO = cfg.num_stages_dO\n        num_stages_PdS = cfg.num_stages_PdS\n        SdP_swapAB = cfg.SdP_swapAB\n        dKV_swapAB = cfg.dKV_swapAB\n        dQ_swapAB = cfg.dQ_swapAB\n        AtomLayoutMSdP = cfg.AtomLayoutMSdP\n        AtomLayoutNdKV = cfg.AtomLayoutNdKV\n        AtomLayoutMdQ = cfg.AtomLayoutMdQ\n        num_threads = (cfg.num_wg + 1) * 128\n        dQ_single_wg = cfg.dQ_single_wg\n        cluster_size = 1\n        use_2cta_instrs = False\n        is_varlen = (\n            cu_seqlens_q is not None\n            or cu_seqlens_k is not None\n            or seqused_q is not None\n            or seqused_k is not None\n        )\n    else:\n        m_block_size = 128\n        n_block_size = 128\n        dQ_swapAB = False\n        dKV_swapAB = False\n        AtomLayoutMdQ = 1\n        AtomLayoutNdKV = 1\n        disable_2cta = (\n            score_mod is not None\n            or score_mod_bwd is not None\n            or mask_mod is not None\n        )\n        cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1\n        use_2cta_instrs = cluster_size == 2\n\n    q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [\n        maybe_contiguous(t)\n        for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)\n    ]\n    if cu_seqlens_q is None:\n        batch_size, seqlen_q = q.shape[:2]\n        total_q = batch_size * seqlen_q\n    else:\n        batch_size = cu_seqlens_q.shape[0] - 1\n        total_q = q.shape[0]\n        seqlen_q = max_seqlen_q if max_seqlen_q is not None else total_q\n\n    if cu_seqlens_k is None:\n        batch_size, seqlen_k = k.shape[:2]\n        total_k = batch_size * seqlen_k\n    else:\n        batch_size = cu_seqlens_k.shape[0] - 1\n        total_k = k.shape[0]\n        seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k\n\n    num_head_kv = k.shape[-2]\n\n    use_block_sparsity = block_sparse_tensors is not None\n\n    
# SM90 block-sparse backward: tile_m=64 is the GCD between a m_block_size that fits,\n    # the base block_m of 128 from forward, and block-sparse size for subtiling.\n    if arch // 10 == 9 and use_block_sparsity:\n        m_block_size = 64\n        # dQ_swapAB tuning: use False when m_block_size=64 (same as causal case)\n        dQ_swapAB = False\n\n    # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2\n    subtile_factor = 2\n    seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size\n    seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size\n    num_n_blocks = seqlen_k_rounded // n_block_size\n    if cluster_size == 2 and num_n_blocks % cluster_size != 0:\n        seqlen_k_rounded = seqlen_k_rounded + n_block_size\n\n    if cu_seqlens_k is None:\n        assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)\n        assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)\n    else:\n        assert k.shape == (total_k, num_head_kv, head_dim)\n        assert v.shape == (total_k, num_head_kv, head_dim_v)\n        assert cu_seqlens_k.shape == (batch_size + 1,), (\n            \"cu_seqlens_k must have shape (batch_size + 1,)\"\n        )\n\n    if cu_seqlens_q is not None:\n        assert cu_seqlens_q.shape == (batch_size + 1,), (\n            \"cu_seqlens_q must have shape (batch_size + 1,)\"\n        )\n\n        assert out.shape == (total_q, num_head, head_dim_v)\n        assert dout.shape == (total_q, num_head, head_dim_v)\n        assert lse.shape == (num_head, total_q), \"lse must have shape (num_head, total_q)\"\n    else:\n        assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v)\n        assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v)\n        assert lse.shape == (batch_size, num_head, seqlen_q), (\n            \"lse must have shape (batch_size, num_head, seqlen_q)\"\n        )\n\n    assert q.dtype in 
[torch.float16, torch.bfloat16], \"inputs must be float16 or bfloat16\"\n    assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, (\n        \"inputs must have the same dtype\"\n    )\n    for t in [cu_seqlens_q, cu_seqlens_k]:\n        if t is not None:\n            assert t.dtype == torch.int32, \"cu_seqlens_q, cu_seqlens_k must be int32\"\n    assert lse.dtype == torch.float32, \"lse must be float32\"\n    if dlse is not None:\n        dlse = maybe_contiguous(dlse)\n    if not is_fake_mode():\n        assert all(\n            t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)\n        ), \"inputs must be on CUDA device\"\n    assert num_head % num_head_kv == 0, \"num_head must be divisible by num_head_kv\"\n    alignment = 16 // q.element_size()\n    if arch // 10 != 12:\n        _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment)\n    if softmax_scale is None:\n        softmax_scale = 1.0 / math.sqrt(head_dim)\n    qhead_per_kvhead = num_head // num_head_kv\n    if pack_gqa is None:\n        pack_gqa = qhead_per_kvhead > 1\n    # pack_gqa is not yet supported in the backward pass, so it is always disabled here.\n    pack_gqa = False\n    if score_mod is not None:\n        assert score_mod_bwd is not None, \"score_mod_bwd is required when score_mod is provided\"\n        assert softcap == 0.0, \"softcap and score_mod are mutually exclusive (different log2 scaling)\"\n        assert cu_seqlens_q is None and cu_seqlens_k is None, (\n            \"varlen + score_mod not supported in bwd yet\"\n        )\n\n    device = q.device\n    out_torch_dtype = q.dtype\n\n    if dq is None:\n        dq = torch.empty_like(q)\n    else:\n        _validate_tensor(dq, \"dq\", q.shape, out_torch_dtype, device)\n\n    if dk is None:\n        dk = torch.empty_like(k)\n    else:\n        _validate_tensor(dk, \"dk\", k.shape, out_torch_dtype, device)\n\n    if dv is None:\n        dv = torch.empty_like(v)\n    else:\n        _validate_tensor(dv, \"dv\", v.shape, 
out_torch_dtype, device)\n\n    head_dim_rounded = (head_dim + 32 - 1) // 32 * 32\n\n    if cu_seqlens_q is None:\n        dq_accum = torch.empty(\n            batch_size,\n            num_head,\n            seqlen_q_rounded * head_dim_rounded,\n            dtype=torch.float32,\n            device=device,\n        )\n        dpsum = torch.empty(\n            batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device\n        )\n        lse_log2 = torch.empty(\n            batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device\n        )\n    else:\n        total_q_rounded_padded = (\n            (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size\n        )\n        dq_accum = torch.empty(\n            num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device\n        )\n        dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)\n        lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)\n\n    # GQA (qhead_per_kvhead > 1) needs dK/dV accum+postprocess since multiple Q heads\n    # accumulate into the same dK/dV. 
SM90 varlen_k with qhead_per_kvhead==1 now uses\n    # ragged TMA tensors for direct store, so no longer needs accum+postprocess.\n    dKV_postprocess = qhead_per_kvhead > 1\n    if dKV_postprocess:\n        head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32\n        if cu_seqlens_k is None:\n            dk_accum = torch.zeros(\n                batch_size,\n                num_head_kv,\n                seqlen_k_rounded * head_dim_rounded,\n                dtype=torch.float32,\n                device=device,\n            )\n            dv_accum = torch.zeros(\n                batch_size,\n                num_head_kv,\n                seqlen_k_rounded * head_dim_v_rounded,\n                dtype=torch.float32,\n                device=device,\n            )\n        else:\n            cluster_tile_n = cluster_size * n_block_size\n            total_k_rounded_padded = (\n                (total_k + cu_seqlens_k.shape[0] * cluster_tile_n - 1) // cluster_tile_n * cluster_tile_n\n            )\n            dk_accum = torch.zeros(\n                num_head_kv,\n                total_k_rounded_padded * head_dim_rounded,\n                dtype=torch.float32,\n                device=device,\n            )\n            dv_accum = torch.zeros(\n                num_head_kv,\n                total_k_rounded_padded * head_dim_v_rounded,\n                dtype=torch.float32,\n                device=device,\n            )\n\n    dtype = torch2cute_dtype_map[q.dtype]\n    current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)\n\n    if deterministic:\n        dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size, dtype=torch.int32, device=device)\n    else:\n        dQ_semaphore = None\n\n    if deterministic and qhead_per_kvhead > 1:\n        dK_semaphore = torch.zeros(batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device)\n        dV_semaphore = torch.zeros(batch_size, 
num_head_kv, seqlen_k_rounded // n_block_size, 2, dtype=torch.int32, device=device)\n    else:\n        dK_semaphore = None\n        dV_semaphore = None\n\n    # Preprocess kernel: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum.\n    _bwd_preprocess(\n        out, dout, dpsum, lse, lse_log2, dq_accum,\n        cu_seqlens_q, seqused_q, dlse,\n        dtype, head_dim, head_dim_v, m_block_size,\n    )\n    # num_threads: SM90 derives it from BwdConfig.num_wg and SM120 sets it to 128 above;\n    # SM100/SM110 always use 384, overriding the function signature default (256).\n    if arch // 10 not in [9, 12]:\n        num_threads = 384\n\n    # Backward kernel: compute dk, dv, dq_accum.\n    score_mod_hash = utils.hash_callable(score_mod) if score_mod else False\n    score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False\n    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod else False\n    num_aux_tensors = len(aux_tensors) if aux_tensors else 0\n    cute_aux_tensors = None\n    if aux_tensors is not None:\n        cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors]\n\n    block_sparse_broadcast_pattern = None\n    normalized_block_sparse_tensors = None\n    if block_sparse_tensors is not None:\n        (\n            normalized_block_sparse_tensors,\n            block_sparse_broadcast_pattern,\n        ) = normalize_block_sparse_config_bwd(\n            block_sparse_tensors,\n            batch_size=batch_size,\n            num_head=num_head,\n            seqlen_q=seqlen_q,\n            seqlen_k=seqlen_k,\n            block_size=(m_block_size, n_block_size),\n            subtile_factor=subtile_factor,\n        )\n\n    if arch // 10 in [8, 9, 12]:\n        compile_key = (\n            arch,\n            dtype,\n            head_dim,\n            head_dim_v,\n            qhead_per_kvhead,\n            causal,\n            window_size_left is not None,\n            window_size_right is not None,
\n            softcap != 0.0,\n            m_block_size,\n            n_block_size,\n            num_threads,\n            pack_gqa,\n            num_stages_Q,\n            num_stages_dO,\n            SdP_swapAB,\n            dKV_swapAB,\n            dQ_swapAB,\n            AtomLayoutMSdP,\n            AtomLayoutNdKV,\n            AtomLayoutMdQ,\n            V_in_regs,\n            dQ_single_wg,\n            deterministic,\n            cu_seqlens_q is None,\n            cu_seqlens_k is None,\n            seqused_q is None,\n            seqused_k is None,\n            score_mod_hash,\n            score_mod_bwd_hash,\n            mask_mod_hash,\n            num_aux_tensors,\n            use_block_sparsity,\n            block_sparse_broadcast_pattern,\n            get_broadcast_dims(q),\n            get_broadcast_dims(k),\n            get_broadcast_dims(v),\n            get_broadcast_dims(dout),\n        )\n    else:\n        compile_key = (\n            arch,\n            dtype,\n            head_dim,\n            head_dim_v,\n            qhead_per_kvhead,\n            causal,\n            window_size_left is not None,\n            window_size_right is not None,\n            softcap != 0.0,\n            m_block_size,\n            n_block_size,\n            num_threads,\n            pack_gqa,\n            cluster_size,\n            use_2cta_instrs,\n            deterministic,\n            score_mod_hash,\n            score_mod_bwd_hash,\n            mask_mod_hash,\n            num_aux_tensors,\n            use_block_sparsity,\n            block_sparse_broadcast_pattern,\n            cu_seqlens_q is None,\n            cu_seqlens_k is None,\n            seqused_q is None,\n            seqused_k is None,\n            get_broadcast_dims(q),\n            get_broadcast_dims(k),\n            get_broadcast_dims(v),\n            get_broadcast_dims(dout),\n        )\n    if compile_key not in _flash_attn_bwd.compile_cache:\n        q_tensor, k_tensor, v_tensor, 
do_tensor, dq_tensor, dk_tensor, dv_tensor = [\n            to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv)\n        ]\n        dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [\n            to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)\n        ]\n        if dKV_postprocess:\n            dk_accum_tensor, dv_accum_tensor = [\n                to_cute_tensor(t) for t in (dk_accum, dv_accum)\n            ]\n        cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [\n            to_cute_tensor(t, assumed_align=4) if t is not None else None\n            for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)\n        ]\n        dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [\n            utils.convert_from_dlpack_leading_static(t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order())\n            if t is not None else None\n            for t in (dQ_semaphore, dK_semaphore, dV_semaphore)\n        ]\n        if arch // 10 in [8, 12]:\n            flash_bwd_obj_cls = FlashAttentionBackwardSm120 if arch // 10 == 12 else FlashAttentionBackwardSm80\n            fa_bwd_obj = flash_bwd_obj_cls(\n                dtype,\n                head_dim,\n                head_dim_v,\n                qhead_per_kvhead,\n                m_block_size,\n                n_block_size,\n                num_stages_Q,\n                num_stages_dO,\n                num_threads,\n                pack_gqa,\n                causal,\n                SdP_swapAB,\n                dKV_swapAB,\n                dQ_swapAB,\n                AtomLayoutMSdP,\n                AtomLayoutNdKV,\n                AtomLayoutMdQ,\n                V_in_regs=V_in_regs,\n            )\n        elif arch // 10 == 9:\n            fa_bwd_obj = FlashAttentionBackwardSm90(\n                dtype,\n                head_dim,\n                head_dim_v,\n                qhead_per_kvhead,\n                causal,\n                is_local=local,\n   
             deterministic=deterministic,\n                tile_m=m_block_size,\n                tile_n=n_block_size,\n                Q_stage=num_stages_Q,\n                dO_stage=num_stages_dO,\n                PdS_stage=num_stages_PdS,\n                SdP_swapAB=SdP_swapAB,\n                dKV_swapAB=dKV_swapAB,\n                dQ_swapAB=dQ_swapAB,\n                AtomLayoutMSdP=AtomLayoutMSdP,\n                AtomLayoutNdKV=AtomLayoutNdKV,\n                AtomLayoutMdQ=AtomLayoutMdQ,\n                num_threads=num_threads,\n                V_in_regs=V_in_regs,\n                score_mod=score_mod,\n                score_mod_bwd=score_mod_bwd,\n                mask_mod=mask_mod,\n                has_aux_tensors=aux_tensors is not None,\n                subtile_factor=subtile_factor,\n                dQ_single_wg=dQ_single_wg,\n            )\n        else:\n            fa_bwd_obj = FlashAttentionBackwardSm100(\n                head_dim,\n                head_dim_v,\n                is_causal=causal,\n                is_local=local,\n                qhead_per_kvhead=qhead_per_kvhead,\n                tile_m=m_block_size,\n                tile_n=n_block_size,\n                cluster_size=cluster_size,\n                use_2cta_instrs=use_2cta_instrs,\n                deterministic=deterministic,\n                score_mod=score_mod,\n                score_mod_bwd=score_mod_bwd,\n                mask_mod=mask_mod,\n                has_aux_tensors=aux_tensors is not None,\n                subtile_factor=subtile_factor,\n            )\n\n        # Block sparse tensors for backward use Q-direction indexing (transposed from forward).\n        sparse_tensors_compile = None\n        if normalized_block_sparse_tensors is not None:\n            sparse_tensors_compile = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)\n\n        # TODO: check @can_implement\n        _flash_attn_bwd.compile_cache[compile_key] = cute.compile(\n            fa_bwd_obj,\n  
          q_tensor,\n            k_tensor,\n            v_tensor,\n            do_tensor,\n            lse_log2_tensor,\n            dpsum_tensor,\n            dq_accum_tensor,\n            dk_tensor if not dKV_postprocess else dk_accum_tensor,\n            dv_tensor if not dKV_postprocess else dv_accum_tensor,\n            softmax_scale,\n            cu_seqlens_q_tensor,\n            cu_seqlens_k_tensor,\n            seqused_q_tensor,\n            seqused_k_tensor,\n            None,  # softcap - not yet supported in backward\n            window_size_left,\n            window_size_right,\n            dQ_semaphore_tensor,\n            dK_semaphore_tensor,\n            dV_semaphore_tensor,\n            cute_aux_tensors,\n            sparse_tensors_compile,\n            current_stream,\n            options=\"--enable-tvm-ffi\",\n        )\n    if not is_fake_mode():\n        _flash_attn_bwd.compile_cache[compile_key](\n            q.detach(),\n            k.detach(),\n            v.detach(),\n            dout,\n            lse_log2,\n            dpsum,\n            dq_accum,\n            dk if not dKV_postprocess else dk_accum,\n            dv if not dKV_postprocess else dv_accum,\n            softmax_scale,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            None,  # softcap - not yet supported in backward\n            window_size_left,\n            window_size_right,\n            dQ_semaphore,\n            dK_semaphore,\n            dV_semaphore,\n            aux_tensors,\n            normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,\n        )\n\n    if arch // 10 == 9:\n        # dQ postprocess: match main kernel's MMA WG count, unless dQ_single_wg\n        num_threads_post_dQ = 128 if dQ_single_wg else cfg.num_wg * 128\n        num_threads_post_dKV = cfg.num_wg * 128\n    else:\n        num_threads_post_dQ = 128\n        num_threads_post_dKV = 128\n\n    # 
Postprocess: convert dq_accum from float32 to dq in bf16/fp16\n    _bwd_postprocess_convert(\n        dq_accum, dq, softmax_scale,\n        cu_seqlens_q, seqused_q,\n        arch, dtype, head_dim, m_block_size, num_threads_post_dQ,\n        AtomLayoutMdQ, dQ_swapAB,\n        use_2cta_instrs=use_2cta_instrs, cluster_size=1,\n    )\n\n    if dKV_postprocess:\n        # Postprocess: convert dk_accum from float32 to dk in bf16/fp16\n        _bwd_postprocess_convert(\n            dk_accum, dk, softmax_scale,\n            cu_seqlens_k, seqused_k,\n            arch, dtype, head_dim, n_block_size, num_threads_post_dKV,\n            AtomLayoutNdKV, dKV_swapAB,\n            cluster_size=cluster_size,\n        )\n        # Postprocess: convert dv_accum from float32 to dv in bf16/fp16\n        _bwd_postprocess_convert(\n            dv_accum, dv, 1.0,\n            cu_seqlens_k, seqused_k,\n            arch, dtype, head_dim_v, n_block_size, num_threads_post_dKV,\n            AtomLayoutNdKV, dKV_swapAB,\n            cluster_size=cluster_size,\n        )\n\n    return dq, dk, dv\n\n\n_flash_attn_bwd.compile_cache = get_jit_cache(\"bwd\")\n\n\nclass FlashAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        softmax_scale: Optional[float] = None,\n        causal: bool = False,\n        window_size: Tuple[Optional[int], Optional[int]] = (None, None),\n        learnable_sink: Optional[torch.Tensor] = None,\n        softcap: float = 0.0,\n        num_splits: int = 1,\n        pack_gqa: Optional[bool] = None,\n        deterministic: bool = False,\n        mask_mod: Optional[Callable] = None,\n        full_block_cnt: Optional[torch.Tensor] = None,\n        full_block_idx: Optional[torch.Tensor] = None,\n        mask_block_cnt: Optional[torch.Tensor] = None,\n        mask_block_idx: Optional[torch.Tensor] = None,\n        block_size: Optional[Tuple[int, int]] = None,\n    
    return_lse: bool = False,\n    ):\n        # Only create block sparse tensors if at least one block sparse parameter is provided\n        block_sparse_tensors = None\n        if any(t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]):\n            block_sparse_tensors = BlockSparseTensorsTorch(\n                full_block_cnt=full_block_cnt,\n                full_block_idx=full_block_idx,\n                mask_block_cnt=mask_block_cnt,\n                mask_block_idx=mask_block_idx,\n                block_size=block_size,\n            )\n        out, lse = _flash_attn_fwd(\n            q,\n            k,\n            v,\n            softmax_scale=softmax_scale,\n            causal=causal,\n            window_size_left=window_size[0],\n            window_size_right=window_size[1],\n            learnable_sink=learnable_sink,\n            softcap=softcap,\n            num_splits=num_splits,\n            pack_gqa=pack_gqa,\n            mask_mod=mask_mod,\n            block_sparse_tensors=block_sparse_tensors,\n            return_lse=return_lse,\n        )\n        ctx.save_for_backward(q, k, v, out, lse)\n        ctx.softmax_scale = softmax_scale\n        ctx.causal = causal\n        ctx.window_size = window_size\n        ctx.softcap = softcap\n        ctx.deterministic = deterministic\n        ctx.return_lse = return_lse\n        ctx.set_materialize_grads(False)\n        return out, lse\n\n    @staticmethod\n    def backward(ctx, dout, dlse):\n        q, k, v, out, lse = ctx.saved_tensors\n        if not ctx.return_lse:\n            dlse = None\n        if dout is None:\n            dout = torch.zeros_like(out)\n        dq, dk, dv = _flash_attn_bwd(\n            q,\n            k,\n            v,\n            out,\n            dout,\n            lse,\n            ctx.softmax_scale,\n            ctx.causal,\n            ctx.softcap,\n            window_size_left=ctx.window_size[0],\n            
window_size_right=ctx.window_size[1],\n            deterministic=ctx.deterministic,\n            dlse=dlse,\n        )\n        return dq, dk, dv, *((None,) * 20)  # Extra Nones are fine\n\n\nclass FlashAttnVarlenFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        cu_seqlens_q: Optional[torch.Tensor],\n        cu_seqlens_k: Optional[torch.Tensor],\n        seqused_q: Optional[torch.Tensor] = None,\n        seqused_k: Optional[torch.Tensor] = None,\n        max_seqlen_q: Optional[int] = None,\n        max_seqlen_k: Optional[int] = None,\n        page_table: Optional[torch.Tensor] = None,\n        softmax_scale: Optional[float] = None,\n        causal: bool = False,\n        window_size: Tuple[Optional[int], Optional[int]] = (None, None),\n        learnable_sink: Optional[torch.Tensor] = None,\n        softcap: float = 0.0,\n        num_splits: int = 1,\n        pack_gqa: Optional[bool] = None,\n        deterministic: bool = False,\n        score_mod: Optional[Callable] = None,\n        aux_tensors: Optional[list] = None,\n        return_lse: bool = False,\n    ):\n        out, lse = _flash_attn_fwd(\n            q,\n            k,\n            v,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            max_seqlen_q=max_seqlen_q,\n            max_seqlen_k=max_seqlen_k,\n            page_table=page_table,\n            softmax_scale=softmax_scale,\n            causal=causal,\n            window_size_left=window_size[0],\n            window_size_right=window_size[1],\n            learnable_sink=learnable_sink,\n            softcap=softcap,\n            num_splits=num_splits,\n            pack_gqa=pack_gqa,\n            score_mod=score_mod,\n            aux_tensors=aux_tensors,\n            return_lse=return_lse,\n        )\n        ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, 
seqused_q, seqused_k)\n        ctx.softmax_scale = softmax_scale\n        ctx.causal = causal\n        ctx.window_size = window_size\n        ctx.softcap = softcap\n        ctx.deterministic = deterministic\n        ctx.max_seqlen_q = max_seqlen_q\n        ctx.max_seqlen_k = max_seqlen_k\n        ctx.return_lse = return_lse\n        ctx.set_materialize_grads(False)\n        return out, lse\n\n    @staticmethod\n    def backward(ctx, dout, dlse):\n        q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors\n        assert ctx.softcap == 0.0\n        if not ctx.return_lse:\n            dlse = None\n        if dout is None:\n            dout = torch.zeros_like(out)\n        dq, dk, dv = _flash_attn_bwd(\n            q,\n            k,\n            v,\n            out,\n            dout,\n            lse,\n            ctx.softmax_scale,\n            ctx.causal,\n            ctx.softcap,\n            window_size_left=ctx.window_size[0],\n            window_size_right=ctx.window_size[1],\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            seqused_q=seqused_q,\n            seqused_k=seqused_k,\n            max_seqlen_q=ctx.max_seqlen_q,\n            max_seqlen_k=ctx.max_seqlen_k,\n            deterministic=ctx.deterministic,\n            dlse=dlse,\n        )\n\n        return dq, dk, dv, *((None,) * 20)\n\n\ndef flash_attn_func(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    window_size: Tuple[Optional[int], Optional[int]] = (None, None),\n    learnable_sink: Optional[torch.Tensor] = None,\n    softcap: float = 0.0,\n    num_splits: int = 1,\n    pack_gqa: Optional[bool] = None,\n    deterministic: bool = False,\n    mask_mod: Optional[Callable] = None,\n    full_block_cnt: Optional[torch.Tensor] = None,\n    full_block_idx: Optional[torch.Tensor] = None,\n    mask_block_cnt: Optional[torch.Tensor] = 
None,\n    mask_block_idx: Optional[torch.Tensor] = None,\n    block_size: Optional[Tuple[int, int]] = None,\n    return_lse: bool = False,\n):\n    return FlashAttnFunc.apply(\n        q,\n        k,\n        v,\n        softmax_scale,\n        causal,\n        window_size,\n        learnable_sink,\n        softcap,\n        num_splits,\n        pack_gqa,\n        deterministic,\n        mask_mod,\n        full_block_cnt,\n        full_block_idx,\n        mask_block_cnt,\n        mask_block_idx,\n        block_size,\n        return_lse,\n    )\n\n\ndef flash_attn_varlen_func(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    max_seqlen_k: Optional[int] = None,\n    seqused_q: Optional[torch.Tensor] = None,\n    seqused_k: Optional[torch.Tensor] = None,\n    page_table: Optional[torch.Tensor] = None,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    window_size: Tuple[Optional[int], Optional[int]] = (None, None),\n    learnable_sink: Optional[torch.Tensor] = None,\n    softcap: float = 0.0,\n    num_splits: int = 1,\n    pack_gqa: Optional[bool] = None,\n    deterministic: bool = False,\n    score_mod: Optional[Callable] = None,\n    aux_tensors: Optional[list] = None,\n    return_lse: bool = False,\n):\n    return FlashAttnVarlenFunc.apply(\n        q,\n        k,\n        v,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        seqused_q,\n        seqused_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        page_table,\n        softmax_scale,\n        causal,\n        window_size,\n        learnable_sink,\n        softcap,\n        num_splits,\n        pack_gqa,\n        deterministic,\n        score_mod,\n        aux_tensors,\n        return_lse,\n    )\n\n\ndef _compile_fwd_combine(\n    dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits,\n    
has_cu_seqlens, has_seqused, has_lse, has_varlen_batch_idx,\n):\n    \"\"\"Compile fwd combine kernel using cute fake tensors (no real GPU tensors needed).\"\"\"\n    sym = cute.sym_int\n    div = 128 // dtype_partial.width  # 16-byte alignment in elements\n\n    fa_combine = FlashAttentionForwardCombine(\n        dtype=dtype,\n        dtype_partial=dtype_partial,\n        head_dim=head_dim,\n        tile_m=tile_m,\n        k_block_size=k_block_size,\n        log_max_splits=log_max_splits,\n    )\n    if not fa_combine.can_implement(\n        dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits,\n        num_threads=256,\n    ):\n        raise RuntimeError(\n            \"FlashAttention combine kernel cannot be implemented with given parameters\"\n        )\n\n    if has_cu_seqlens:\n        # Varlen: (num_splits, total_q, nheads, headdim)\n        num_splits, total_q, nheads = sym(), sym(), sym()\n        mO_partial = fake_tensor(dtype_partial, (num_splits, total_q, nheads, head_dim), divisibility=div)\n        mLSE_partial = fake_tensor(Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=1)\n        mO = fake_tensor(dtype, (total_q, nheads, head_dim), divisibility=div)\n        mLSE = fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=0) if has_lse else None\n    else:\n        # Batched: (num_splits, batch, seqlen, nheads, headdim)\n        num_splits, batch, seqlen, nheads = sym(), sym(), sym(), sym()\n        mO_partial = fake_tensor(dtype_partial, (num_splits, batch, seqlen, nheads, head_dim), divisibility=div)\n        mLSE_partial = fake_tensor(Float32, (num_splits, batch, seqlen, nheads), divisibility=1, leading_dim=2)\n        mO = fake_tensor(dtype, (batch, seqlen, nheads, head_dim), divisibility=div)\n        mLSE = fake_tensor(Float32, (batch, seqlen, nheads), divisibility=1, leading_dim=1) if has_lse else None\n        batch = mO_partial.shape[1]\n\n    batch_for_1d = batch if not has_cu_seqlens else 
sym()\n    batchp1 = sym()\n    mCuSeqlens = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_seqlens else None\n    mSeqused = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_seqused else None\n    mNumSplitsDynamic = None  # Not parametrized in compile_key\n    mVarlenBatchIdx = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_varlen_batch_idx else None\n    mSemaphore = None  # Not parametrized in compile_key\n\n    return cute.compile(\n        fa_combine,\n        mO_partial, mLSE_partial, mO, mLSE,\n        mCuSeqlens, mSeqused, mNumSplitsDynamic, mVarlenBatchIdx, mSemaphore,\n        cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),\n        options=\"--enable-tvm-ffi\",\n    )\n\n\ndef _flash_attn_fwd_combine(\n    out_partial: torch.Tensor,\n    lse_partial: torch.Tensor,\n    out: torch.Tensor,\n    lse: Optional[torch.Tensor] = None,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    seqused: Optional[torch.Tensor] = None,\n    num_splits_dynamic_ptr: Optional[torch.Tensor] = None,\n    varlen_batch_idx: Optional[torch.Tensor] = None,\n    semaphore_to_reset: Optional[torch.Tensor] = None,\n) -> None:\n    \"\"\"Forward combine kernel for split attention computation.\n\n    Combines partial outputs and log-sum-exp values from multiple splits\n    of attention computation into final outputs.\n\n    Args:\n        out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim) or\n                                            (num_splits, total_q, nheads, headdim) if there's cu_seqlens\n        lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads) or\n                                       (num_splits, total_q, nheads) if there's cu_seqlens\n        out: Output tensor (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim) if there's cu_seqlens\n        lse: Output LSE tensor (batch, seqlen, nheads) or (total_q, nheads) if there's cu_seqlens.\n        cu_seqlens: Cumulative 
sequence lengths for variable length sequences\n        seqused: Used sequence lengths for each batch\n        num_splits_dynamic_ptr: Dynamic number of splits per batch\n        varlen_batch_idx: Optional mapping from virtual batch index to real batch index\n        semaphore_to_reset: Semaphore for synchronization\n\n    Returns:\n        None\n    \"\"\"\n    assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], (\n        \"out_partial must be fp16, bf16, or fp32\"\n    )\n    if not is_fake_mode():\n        assert out_partial.is_cuda and lse_partial.is_cuda, \"tensors must be on CUDA device\"\n    # Determine if this is variable length based on dimensions\n    is_varlen = out_partial.dim() == 4\n    # Validate optional tensors\n    for t, name in [\n        (cu_seqlens, \"cu_seqlens\"),\n        (seqused, \"seqused\"),\n        (num_splits_dynamic_ptr, \"num_splits_dynamic_ptr\"),\n    ]:\n        if t is not None:\n            if not is_fake_mode():\n                assert t.is_cuda, f\"{name} must be on CUDA device\"\n            assert t.is_contiguous(), f\"{name} must be contiguous\"\n    head_dim = out_partial.shape[-1]\n    num_splits = out_partial.shape[0]\n    assert num_splits <= 256\n    # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively\n    # so that kBlockM is smaller and we have more parallelism.\n    k_block_size = 64 if head_dim <= 64 else 128\n    # We want kBlockM to be as small as possible to maximize parallelism.\n    # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).\n    tile_m = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32)\n    log_max_splits = max(math.ceil(math.log2(num_splits)), 4)\n    if tile_m == 8:\n        # If kBlockM == 8 then the minimum number of splits is 32.\n        # TODO: we can deal w this by using 128 threads instead\n        log_max_splits = max(log_max_splits, 5)\n\n    # Create combine kernel configuration\n    
dtype = torch2cute_dtype_map[out.dtype]\n    dtype_partial = torch2cute_dtype_map[out_partial.dtype]\n    compile_key = (\n        dtype,\n        dtype_partial,\n        head_dim,\n        tile_m,\n        k_block_size,\n        log_max_splits,\n        cu_seqlens is not None,\n        seqused is not None,\n        lse is not None,\n        varlen_batch_idx is not None,\n    )\n    if compile_key not in _flash_attn_fwd_combine.compile_cache:\n        _flash_attn_fwd_combine.compile_cache[compile_key] = _compile_fwd_combine(\n            *compile_key\n        )\n    if not is_fake_mode():\n        _flash_attn_fwd_combine.compile_cache[compile_key](\n            out_partial, lse_partial, out, lse,\n            cu_seqlens, seqused, num_splits_dynamic_ptr, varlen_batch_idx,\n            semaphore_to_reset,\n        )\n\n\n_flash_attn_fwd_combine.compile_cache = get_jit_cache(\"fwd_combine\")\n\n\ndef flash_attn_combine(\n    out_partial: torch.Tensor,\n    lse_partial: torch.Tensor,\n    out: Optional[torch.Tensor] = None,\n    out_dtype: Optional[torch.dtype] = None,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    seqused: Optional[torch.Tensor] = None,\n    varlen_batch_idx: Optional[torch.Tensor] = None,\n    return_lse: bool = True,\n) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n    \"\"\"Flash Attention combine function for split attention computation.\n\n    Combines partial outputs and log-sum-exp values from multiple splits\n    of attention computation into final outputs. 
This is the main user-facing\n    interface for the combine kernel.\n\n    Args:\n        out_partial: Partial outputs tensor with shape:\n            - (num_splits, batch_size, seqlen, num_heads, head_size) for regular batched input\n            - (num_splits, total_q, num_heads, head_size) for variable length input\n        lse_partial: Partial LSE tensor with shape:\n            - (num_splits, batch_size, seqlen, num_heads) for regular batched input\n            - (num_splits, total_q, num_heads) for variable length input\n        out: Optional output tensor. If None, will be created automatically.\n        out_dtype: Optional output dtype. If None, defaults to the dtype of out_partial.\n        cu_seqlens: Cumulative sequence lengths for variable length sequences\n        seqused: Used sequence lengths for each batch\n        varlen_batch_idx: Optional mapping from virtual batch index to real batch index\n            (int32 tensor of shape (batch_size,)). Used by persistent tile schedulers\n            that reorder batch processing for load balancing.\n        return_lse: Whether to return the combined LSE tensor. Default is True.\n\n    Returns:\n        Tuple of (out, lse) where:\n        - out: Combined output tensor with shape (batch_size, seqlen, num_heads, head_size)\n              or (total_q, num_heads, head_size) for varlen\n        - lse: Combined log-sum-exp tensor with shape (batch_size, seqlen, num_heads)\n              or (total_q, num_heads) for varlen. 
None if return_lse=False\n\n    Note:\n        This function expects the input tensors to be in the format produced by\n        split attention computation, where the first dimension is num_splits.\n        The permuting from user format to kernel format is now done inside the kernel.\n    \"\"\"\n    # Input validation\n    assert out_partial.dim() in [4, 5], \"out_partial must have 4 or 5 dimensions\"\n    # Determine if this is variable length based on dimensions\n    is_varlen = out_partial.dim() == 4\n    if is_varlen:\n        # Variable length: (num_splits, total_q, num_heads, head_size)\n        num_splits, total_q, num_heads, head_size = out_partial.shape\n        batch_size = 1  # Treat as single batch for varlen\n        seqlen = total_q\n    else:\n        # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size)\n        num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape\n    # Determine output dtype\n    if out_dtype is None:\n        out_dtype = out_partial.dtype\n    # Create output if not provided\n    device = out_partial.device\n    if out is None:\n        if is_varlen:\n            out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device)\n        else:\n            out = torch.empty(\n                batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device\n            )\n    # Create lse output only if requested\n    if return_lse:\n        if is_varlen:\n            lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device)\n        else:\n            lse = torch.empty(batch_size, num_heads, seqlen, dtype=torch.float32, device=device)\n        lse = lse.transpose(-1, -2)\n    else:\n        lse = None\n    _flash_attn_fwd_combine(\n        out_partial,\n        lse_partial,\n        out,\n        lse,\n        cu_seqlens,\n        seqused,\n        varlen_batch_idx=varlen_batch_idx,\n    )\n    return out, lse\n"
  },
  {
    "path": "flash_attn/cute/mask.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n\nfrom typing import Optional, Callable, TypeAlias\nfrom dataclasses import dataclass\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Float32, Int32, Uint32, const_expr\n\nfrom quack import layout_utils\nimport flash_attn.cute.utils as utils\nfrom flash_attn.cute.seqlen_info import SeqlenInfoQK\n\nMaskGenFn: TypeAlias = Callable[[int], Uint32]\nMASK_R2P_CHUNK_SIZE: int = 32\n\n\n@cute.jit\ndef r2p_bitmask_below(limit: Int32, s: int) -> Uint32:\n    \"\"\"32-bit R2P bitmask keeping positions < limit (exclusive upper bound).\n\n    Positions 0..limit-1 in chunk `s` get bit=1 (keep), the rest bit=0 (mask).\n    Uses inline PTX to avoid shift-by-type-width UB.\n    \"\"\"\n    m = max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0)\n    return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(m))\n\n\n@cute.jit\ndef r2p_bitmask_above(limit: Int32, s: int) -> Uint32:\n    \"\"\"32-bit R2P bitmask keeping positions >= limit (inclusive lower bound).\n\n    Positions limit..31 in chunk `s` get bit=1 (keep), the rest bit=0 (mask).\n    Uses inline PTX to avoid shift-by-type-width UB.\n    \"\"\"\n    n = max(limit - s * MASK_R2P_CHUNK_SIZE, 0)\n    return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(n))\n\n\n@cute.jit\ndef mask_r2p_lambda(\n    X: cute.Tensor,\n    mask_gen_fn: cutlass.Constexpr[MaskGenFn],\n    rank1: bool = False,\n) -> None:\n    \"\"\"Apply R2P masking with a custom bitmask generator.\n\n    mask_gen_fn(chunk_idx: constexpr int) -> Uint32:\n        Returns a 32-bit bitmask for the chunk. Bit i set means column\n        chunk_idx * chunk_size + i is KEPT; bit i clear means masked to -inf.\n    \"\"\"\n    ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape))\n    # 32-column chunks. 
The mask_gen_fn returns a Uint32 bitmask (1=keep).\n    CHUNK_SIZE = MASK_R2P_CHUNK_SIZE\n    for s in cutlass.range_constexpr(cute.ceil_div(ncol, CHUNK_SIZE)):\n        mask = mask_gen_fn(s)\n        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction\n        for i in cutlass.range_constexpr(min(CHUNK_SIZE, ncol - s * CHUNK_SIZE)):\n            in_bound = cutlass.Boolean(mask & (Uint32(1) << i))\n            c = s * CHUNK_SIZE + i\n            if const_expr(rank1):\n                X[c] = X[c] if in_bound else -Float32.inf\n            else:\n                for r in cutlass.range_constexpr(cute.size(X.shape[0])):\n                    X[r, c] = X[r, c] if in_bound else -Float32.inf\n\n\n@cute.jit\ndef sm90_col_to_r2p_idx(col_limit: Int32) -> Int32:\n    \"\"\"Transform SM90 MMA column coordinate to R2P element index.\n\n    SM90 MMA accumulator column indices are non-contiguous: 0, 1, 8, 9, 16, 17, ...\n    Element indices are contiguous: 0, 1, 2, 3, 4, 5, ...\n    This converts a column-space threshold to element-space for r2p_bitmask_below/above.\n    \"\"\"\n    return col_limit // 8 * 2 + min(col_limit % 8, 2)\n\n\n@cute.jit\ndef row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32:\n    \"\"\"Convert a row coordinate to an R2P element index in the warp-group interleaved layout.\n\n    In the SM100 backward pass, 2 warp groups share TMEM. 
The TMEM load atom\n    distributes rows in an interleaved pattern: elements 0..num_rep-1 map to\n    rows 0..num_rep-1 (warp group 0), elements num_rep..2*num_rep-1 map to\n    rows num_rep*num_wg..num_rep*num_wg+num_rep-1 (warp group 1), and so on.\n    Row-coordinate thresholds (causal limits, window bounds, uih_len) must be\n    converted to element indices before use with r2p_bitmask_above/below.\n\n    Rows not owned by this thread (in the gap between warp groups) are clamped\n    to the boundary element index, which is safe because R2P thresholds are\n    monotonic.\n\n    Example with num_rep=16, num_wg=2:\n        row  0 -> elem  0,  row 15 -> elem 15,\n        row 16 -> elem 16 (clamped), row 31 -> elem 16 (clamped),\n        row 32 -> elem 16, row 33 -> elem 17, row 47 -> elem 31.\n    \"\"\"\n    return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep)\n\n\n@dataclass(frozen=True)\nclass AttentionMask:\n    tile_m: cutlass.Constexpr[int]\n    tile_n: cutlass.Constexpr[int]\n    seqlen_info: SeqlenInfoQK\n    window_size_left: Optional[Int32] = None\n    window_size_right: Optional[Int32] = None\n    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1  # only pass in if we're doing PackGQA\n    swap_AB: cutlass.Constexpr[bool] = False\n\n    @property\n    def seqlen_q(self) -> Int32:\n        return self.seqlen_info.seqlen_q\n\n    @property\n    def seqlen_k(self) -> Int32:\n        return self.seqlen_info.seqlen_k\n\n    @cute.jit\n    def apply_mask(\n        self,\n        acc_S: cute.Tensor,\n        batch_idx: cutlass.Int32,\n        head_idx: cutlass.Int32,\n        m_block: cutlass.Int32,\n        n_block: cutlass.Int32,\n        thr_mma: cute.TiledMma,\n        mask_seqlen: cutlass.Constexpr[bool],\n        mask_causal: cutlass.Constexpr[bool],\n        mask_local: cutlass.Constexpr[bool] = False,\n        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,\n        aux_tensors: Optional[list] = None,\n        
fastdiv_mods=(None, None),\n    ) -> None:\n        assert not (mask_causal and mask_local), \"mask_causal and mask_local cannot be both True\"\n        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.swap_AB)\n        acc_shape = (self.tile_m, self.tile_n)\n        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])\n        tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cS), transpose=self.swap_AB)\n        # We use t0ScS as these indices are known at compile time. We then must subtract the\n        # column limit by the thread column offset.\n        t0ScS_mn = layout_utils.reshape_acc_to_mn(\n            thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB\n        )\n        ROW = 0 if const_expr(not self.swap_AB) else 1\n        COL = 1 if const_expr(not self.swap_AB) else 0\n        thr_col_offset = tScS_mn[0][COL]\n        # To handle edge cases of completely masked out rows where n_block_max = 0,\n        # we treat negative n_blocks as 0th n_block\n        # TODO: find more transparent solution\n        if n_block < 0:\n            n_block = 0\n        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset\n        if const_expr(not mask_causal and not mask_local and mask_mod is None):\n            if const_expr(mask_seqlen):\n                r2p = const_expr(not self.swap_AB)\n                if const_expr(not r2p):\n                    # traverse column index.\n                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):\n                        oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit\n                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):\n                            acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c]\n                else:\n                    seqlenk_col_limit_r2p = sm90_col_to_r2p_idx(seqlenk_col_limit)\n                    mask_r2p_lambda(acc_S_mn, lambda s: 
r2p_bitmask_below(seqlenk_col_limit_r2p, s))\n\n        elif const_expr(\n            not mask_causal and not mask_local and mask_mod is not None\n        ):  # FlexAttention mask mod\n            nrow = const_expr(cute.size(tScS_mn.shape[0]))\n            ncol = const_expr(cute.size(tScS_mn.shape[1]))\n            has_fastdiv = const_expr(\n                fastdiv_mods is not None\n                and fastdiv_mods[0] is not None\n                and fastdiv_mods[1] is not None\n            )\n            wrap_aux_indices = const_expr(\n                has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)\n            )\n\n            for r in cutlass.range_constexpr(nrow):\n                # Respect swap_AB: ROW/COL determine which coordinate component corresponds to Q/KV.\n                local_row = tScS_mn[r, 0][ROW]\n                global_row_idx = local_row + m_block * self.tile_m\n                row_for_mod = global_row_idx\n                head_idx_for_mod = head_idx\n                if const_expr(self.qhead_per_kvhead_packgqa != 1):\n                    head_offset = global_row_idx % self.qhead_per_kvhead_packgqa\n                    head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset\n                    row_for_mod = global_row_idx // self.qhead_per_kvhead_packgqa\n                row_for_seqlen = row_for_mod\n                if const_expr(wrap_aux_indices):\n                    _, row_for_mod = divmod(row_for_mod, fastdiv_mods[0])\n\n                for col in cutlass.range_constexpr(ncol):\n                    col_idx_local = t0ScS_mn[0, col][COL]\n                    # Convert to absolute column index\n                    global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n\n                    col_for_mod = global_col_idx\n                    if const_expr(wrap_aux_indices):\n                        _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1])\n\n                    batch_idx_ssa = 
utils.scalar_to_ssa(batch_idx, cutlass.Int32)\n                    head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)\n                    q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32)\n                    kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32)\n                    mask_value = mask_mod(\n                        batch_idx_ssa,\n                        head_idx_ssa,\n                        q_idx_ssa,\n                        kv_idx_ssa,\n                        self.seqlen_info,\n                        aux_tensors,\n                    )\n                    cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))\n                    if const_expr(mask_seqlen):\n                        out_of_bounds = (row_for_seqlen >= self.seqlen_q) or (\n                            global_col_idx >= self.seqlen_k\n                        )\n                        if out_of_bounds:\n                            acc_S_mn[r, col] = -cutlass.Float32.inf\n                        else:\n                            acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf\n                    else:\n                        acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf\n\n        else:  # Causal or local\n            if const_expr(not self.swap_AB):\n                # If PackGQA, we split the work of compute divmod among threads in the same row\n                threads_per_row = thr_mma.tv_layout_C.shape[0][0]\n                mma_m_idx = None\n                if const_expr(self.qhead_per_kvhead_packgqa != 1):\n                    assert not self.swap_AB, \"swap_AB with PackGQA not supported yet\"\n                    assert cute.arch.WARP_SIZE % threads_per_row == 0, (\n                        \"threads_per_row must divide WARP_SIZE\"\n                    )\n                    assert cute.size(acc_S_mn.shape[0]) <= threads_per_row\n                    tidx = thr_mma.thr_idx\n                    
mma_m_idx = (\n                        m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0]\n                    ) // self.qhead_per_kvhead_packgqa\n                causal_row_offset = (\n                    1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset\n                )\n                if const_expr(mask_causal):\n                    r2p = const_expr(not self.swap_AB)  # R2P trick, see apply_mask_sm100\n                    for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):\n                        # get the column index limit based on current row. Only consider the row index, so the column index sets to 0.\n                        if const_expr(self.qhead_per_kvhead_packgqa == 1):\n                            row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m\n                        else:\n                            row_idx = utils.shuffle_sync(\n                                mma_m_idx, r % threads_per_row, width=threads_per_row\n                            )\n                        col_limit_right = row_idx + causal_row_offset\n                        if const_expr(mask_seqlen):\n                            col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)\n                        if const_expr(not r2p):\n                            # traverse column index.\n                            for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):\n                                acc_S_mn[r, c] = (\n                                    -Float32.inf\n                                    if t0ScS_mn[0, c][1] >= col_limit_right\n                                    else acc_S_mn[r, c]\n                                )\n                        else:\n                            col_limit_r2p = sm90_col_to_r2p_idx(col_limit_right)\n                            mask_r2p_lambda(\n                                acc_S_mn[r, None],\n                                lambda s: 
r2p_bitmask_below(col_limit_r2p, s),\n                                rank1=True,\n                            )\n                else:  # Local\n                    local_row_offset_right = (\n                        causal_row_offset + self.window_size_right\n                        if const_expr(self.window_size_right is not None)\n                        else None\n                    )\n                    local_row_offset_left = (\n                        causal_row_offset - 1 - self.window_size_left\n                        if const_expr(self.window_size_left is not None)\n                        else None\n                    )\n                    r2p_local = const_expr(not self.swap_AB)\n                    for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):\n                        if const_expr(self.qhead_per_kvhead_packgqa == 1):\n                            row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m\n                        else:\n                            row_idx = utils.shuffle_sync(\n                                mma_m_idx, r % threads_per_row, width=threads_per_row\n                            )\n                        if const_expr(self.window_size_right is not None):\n                            col_limit_right = row_idx + local_row_offset_right\n                        else:\n                            col_limit_right = self.tile_n\n                        if const_expr(mask_seqlen):\n                            col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)\n                        col_limit_left = (\n                            row_idx + local_row_offset_left\n                            if const_expr(self.window_size_left is not None)\n                            else 0\n                        )\n                        if const_expr(not r2p_local):\n                            # traverse column index.\n                            for c in cutlass.range(cute.size(tScS_mn.shape[1]), 
unroll_full=True):\n                                col_idx = t0ScS_mn[0, c][1]\n                                if col_idx >= col_limit_right or col_idx < col_limit_left:\n                                    acc_S_mn[r, c] = -Float32.inf\n                        else:\n                            col_limit_right_r2p = sm90_col_to_r2p_idx(col_limit_right)\n                            col_limit_left_r2p = sm90_col_to_r2p_idx(col_limit_left)\n\n                            def mask_gen_fn(s: int) -> Uint32:\n                                return r2p_bitmask_below(\n                                    col_limit_right_r2p, s\n                                ) & r2p_bitmask_above(col_limit_left_r2p, s)\n\n                            mask_r2p_lambda(acc_S_mn[r, None], mask_gen_fn, rank1=True)\n            else:  # swap_AB\n                assert self.qhead_per_kvhead_packgqa == 1\n                thr_row_offset = tScS_mn[0][ROW]\n                causal_row_offset = (\n                    seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset\n                )\n                if const_expr(mask_causal):\n                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):\n                        col0 = t0ScS_mn[0, c][COL]\n                        # If col0 is beyond the column limit, we want to mask out the entire\n                        # column, by setting row limit to be self.tile_m.\n                        row_limit_top = (\n                            self.tile_m\n                            if col0 >= seqlenk_col_limit and mask_seqlen\n                            else col0 - causal_row_offset\n                        )\n                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):\n                            acc_S_mn[r, c] = (\n                                -Float32.inf\n                                if t0ScS_mn[r, 0][ROW] < row_limit_top\n                                else 
acc_S_mn[r, c]\n                            )\n                else:\n                    for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True):\n                        col0 = t0ScS_mn[0, c][COL]\n                        # If col0 is beyond the column limit, we want to mask out the entire\n                        # column, by setting row limit to be self.tile_m.\n                        row_limit_top = (\n                            self.tile_m\n                            if col0 >= seqlenk_col_limit and mask_seqlen\n                            else (\n                                col0 - causal_row_offset - self.window_size_right\n                                if const_expr(self.window_size_right is not None)\n                                else 0\n                            )\n                        )\n                        row_limit_bot = (\n                            col0 - causal_row_offset + self.window_size_left\n                            if const_expr(self.window_size_left is not None)\n                            else self.tile_m\n                        )\n                        for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True):\n                            row_idx = t0ScS_mn[r, 0][ROW]\n                            acc_S_mn[r, c] = (\n                                -Float32.inf\n                                if row_idx < row_limit_top or row_idx > row_limit_bot\n                                else acc_S_mn[r, c]\n                            )\n\n    @cute.jit\n    def apply_mask_sm100(\n        self,\n        acc_S: cute.Tensor,\n        m_block: Int32,\n        n_block: Int32,\n        thr_mma: cute.TiledMma,\n        thr_tmem_load: cute.TiledCopy,\n        mask_seqlen: cutlass.Constexpr[bool],\n        mask_causal: cutlass.Constexpr[bool],\n        mask_local: cutlass.Constexpr[bool] = False,\n        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,\n        batch_idx: Int32 = None,\n        
head_idx: Int32 = None,\n        aux_tensors: Optional[list] = None,\n        fastdiv_mods=(None, None),\n        head_divmod=None,\n        check_q_boundary: bool = False,\n    ) -> None:\n        assert not (mask_causal and mask_local), \"mask_causal and mask_local cannot be both True\"\n        acc_shape = (self.tile_m, self.tile_n)\n        cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])\n        tScS = thr_mma.partition_C(cS)\n        tScS = tScS[(None, None), 0, 0]\n        tScS_t2r = thr_tmem_load.partition_D(tScS)\n        # To handle edge cases of completely masked out rows where n_block_max = 0,\n        # we treat negative n_blocks as 0th n_block\n        # TODO: find more transparent solution\n        if n_block < 0:\n            n_block = 0\n        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n\n        r2p = True\n        if const_expr(not mask_causal and not mask_local and mask_mod is None):\n            if const_expr(mask_seqlen):\n                if const_expr(not r2p):\n                    for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):\n                        # if tScS_t2r[i][1] >= seqlenk_col_limit:\n                        #     acc_S[i] = -Float32.inf\n                        # For some reason the 2 lines above generate really bad SASS\n                        acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i]\n                else:\n                    mask_r2p_lambda(\n                        acc_S,\n                        lambda s: r2p_bitmask_below(seqlenk_col_limit, s),\n                        rank1=True,\n                    )\n\n        elif const_expr(not mask_causal and not mask_local and mask_mod is not None):\n            # Block sparse case w/ mask_mod\n            has_fastdiv = const_expr(\n                fastdiv_mods is not None\n                and fastdiv_mods[0] is not None\n                and fastdiv_mods[1] is not None\n        
    )\n            batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)\n\n            ncol = const_expr(cute.size(tScS_t2r.shape))\n            for i in cutlass.range_constexpr(ncol):\n                row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1]\n                col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0]\n                global_row = row_coord + m_block * self.tile_m\n                global_col = col_coord + n_block * self.tile_n\n\n                if const_expr(self.qhead_per_kvhead_packgqa != 1):\n                    assert head_divmod is not None\n                    mask_row, head_offset = divmod(global_row, head_divmod)\n                    head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset\n                else:\n                    head_idx_for_mod = head_idx\n                    mask_row = global_row\n\n                mask_row_for_mod = mask_row\n                if const_expr(has_fastdiv and aux_tensors is not None):\n                    if check_q_boundary:\n                        _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0])\n                global_col_for_mod = global_col\n                if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None):\n                    _, global_col_for_mod = divmod(global_col, fastdiv_mods[1])\n\n                head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32)\n                mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32)\n                kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32)\n                mask_value = mask_mod(\n                    batch_idx_ssa,\n                    head_idx_ssa,\n                    mask_row_ssa,\n                    kv_idx_ssa,\n                    self.seqlen_info,\n                    aux_tensors,\n                )\n                cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))\n                acc_S[i] = acc_S[i] 
if cond else -Float32.inf\n                if const_expr(mask_seqlen):\n                    acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i]\n                if check_q_boundary:\n                    acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i]\n\n        else:  # Causal or local\n            causal_row_offset = self.seqlen_k - n_block * self.tile_n - self.seqlen_q\n            row_idx = tScS_t2r[0][0] + m_block * self.tile_m\n            if const_expr(self.qhead_per_kvhead_packgqa != 1):\n                row_idx = row_idx // self.qhead_per_kvhead_packgqa\n            if const_expr(mask_causal):\n                col_limit_right = row_idx + causal_row_offset + 1\n                if const_expr(mask_seqlen):\n                    col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)\n                # if cute.arch.thread_idx()[0] % 32 == 0:\n                #     cute.printf(\"tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\\n\", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset)\n                ncol = const_expr(cute.size(tScS_t2r.shape))\n                if const_expr(not r2p):\n                    for i in cutlass.range(ncol, unroll_full=True):\n                        acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i]\n                else:\n                    mask_r2p_lambda(\n                        acc_S,\n                        lambda s: r2p_bitmask_below(col_limit_right, s),\n                        rank1=True,\n                    )\n            else:\n                local_row_offset_right = (\n                    causal_row_offset + 1 + self.window_size_right\n                    if const_expr(self.window_size_right is not None)\n                    else None\n                )\n                local_row_offset_left = (\n                    causal_row_offset - self.window_size_left\n               
     if const_expr(self.window_size_left is not None)\n                    else None\n                )\n                if const_expr(self.window_size_right is not None):\n                    col_limit_right = row_idx + local_row_offset_right\n                else:\n                    col_limit_right = self.tile_n\n                if const_expr(mask_seqlen):\n                    col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit)\n                col_limit_left = (\n                    row_idx + local_row_offset_left\n                    if const_expr(self.window_size_left is not None)\n                    else 0\n                )\n                if const_expr(not r2p):\n                    # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf(\"m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}\", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left)\n                    for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):\n                        col_idx = tScS_t2r[i][1]\n                        acc_S[i] = (\n                            -Float32.inf\n                            if col_idx >= col_limit_right or col_idx < col_limit_left\n                            else acc_S[i]\n                        )\n                else:\n                    # Dual-bound R2P masking for SM100.\n                    # Masks elements where: NOT (col_limit_left <= col < col_limit_right)\n\n                    def mask_gen_fn(s: int) -> Uint32:\n                        return r2p_bitmask_below(col_limit_right, s) & r2p_bitmask_above(\n                            col_limit_left, s\n                        )\n\n                    mask_r2p_lambda(acc_S, mask_gen_fn, rank1=True)\n\n    @cute.jit\n    def apply_mask_sm100_transposed(\n        self,\n        acc_S: cute.Tensor,\n        tScS_t2r: cute.Tensor,\n        t0ScS_t2r: 
cute.Tensor,\n        m_block: cutlass.Int32,\n        n_block: cutlass.Int32,\n        mask_seqlen: cutlass.Constexpr,\n        mask_causal: cutlass.Constexpr,\n        mask_local: cutlass.Constexpr,\n        mask_mod: cutlass.Constexpr[Optional[Callable]] = None,\n        batch_idx: Int32 = None,\n        head_idx: Int32 = None,\n        aux_tensors: Optional[list] = None,\n        fastdiv_mods=(None, None),\n        is_full_block: bool = False,\n        check_m_boundary: bool = True,\n    ) -> None:\n        \"\"\"\n        Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q.\n\n        Coordinate convention:\n        - ROW corresponds to Q (m_block)\n        - COL corresponds to KV (n_block)\n\n        is_full_block: If True, skip mask_mod (all elements valid). Only apply seqlen masking.\n        check_m_boundary: If False, skip seqlen_q boundary check (optimization for non-boundary m_blocks).\n                          When iterating m_blocks in forward order, only the last m_block may be partial.\n        \"\"\"\n        assert not (mask_causal and mask_local), \"mask_causal and mask_local cannot be both True\"\n        ROW = 0 if const_expr(not self.swap_AB) else 1\n        COL = 1 if const_expr(not self.swap_AB) else 0\n        # assert t0ScS_t2r[0][COL] == 0, \"col0 == 0\" # tmp comment for 2-cta bwd\n        thr_col_offset = tScS_t2r[0][COL]\n        seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset\n\n        if const_expr(not mask_causal and not mask_local and mask_mod is not None):\n            # Block sparse case with mask_mod (backward)\n            #\n            # Coordinate convention: ROW → Q (m_block), COL → KV (n_block).\n            # These already account for swap_AB.\n            #\n            # FULL blocks: mask_mod returns True for all elements, so skip it.\n            #   Still need seqlen bounds check (elements may be OOB on last m_block).\n            # PARTIAL blocks: apply 
mask_mod element-wise, then seqlen bounds.\n            if is_full_block:\n                if const_expr(mask_seqlen):\n                    if seqlenk_col_limit <= 0:\n                        # Entire tile is OOB for K\n                        for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):\n                            acc_S[i] = -cutlass.Float32.inf\n                    elif check_m_boundary:\n                        # Last m_block: check Q and K boundaries\n                        ncol = const_expr(cute.size(tScS_t2r.shape))\n                        for i in cutlass.range_constexpr(ncol):\n                            row_coord = tScS_t2r[i][ROW]\n                            col_coord = tScS_t2r[i][COL]\n                            global_q = row_coord + m_block * self.tile_m\n                            global_kv = col_coord + n_block * self.tile_n\n                            q_out_of_bounds = global_q >= self.seqlen_q\n                            kv_out_of_bounds = global_kv >= self.seqlen_k\n                            out_of_bounds = q_out_of_bounds or kv_out_of_bounds\n                            acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]\n            else:\n                # Partial block\n                has_fastdiv = const_expr(\n                    fastdiv_mods is not None\n                    and fastdiv_mods[0] is not None\n                    and fastdiv_mods[1] is not None\n                )\n                wrap_aux_indices = const_expr(\n                    has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None)\n                )\n                batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32)\n                head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32)\n\n                ncol = const_expr(cute.size(tScS_t2r.shape))\n                for i in cutlass.range_constexpr(ncol):\n                    row_coord = tScS_t2r[i][ROW]\n                    col_coord = 
tScS_t2r[i][COL]\n                    global_q = row_coord + m_block * self.tile_m\n                    global_kv = col_coord + n_block * self.tile_n\n\n                    q_idx_for_mod = global_q\n                    kv_idx_for_mod = global_kv\n                    if const_expr(wrap_aux_indices):\n                        _, q_idx_for_mod = divmod(global_q, fastdiv_mods[0])\n                        _, kv_idx_for_mod = divmod(global_kv, fastdiv_mods[1])\n\n                    q_idx_ssa = utils.scalar_to_ssa(q_idx_for_mod, cutlass.Int32)\n                    kv_idx_ssa = utils.scalar_to_ssa(kv_idx_for_mod, cutlass.Int32)\n\n                    mask_value = mask_mod(\n                        batch_idx_ssa,\n                        head_idx_ssa,\n                        q_idx_ssa,\n                        kv_idx_ssa,\n                        self.seqlen_info,\n                        aux_tensors,\n                    )\n                    cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value))\n                    acc_S[i] = acc_S[i] if cond else -cutlass.Float32.inf\n\n                    if const_expr(mask_seqlen):\n                        # check_m_boundary=False skips q check for non-boundary m_blocks\n                        q_out_of_bounds = check_m_boundary and (global_q >= self.seqlen_q)\n                        kv_out_of_bounds = global_kv >= self.seqlen_k\n                        out_of_bounds = q_out_of_bounds or kv_out_of_bounds\n                        acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i]\n\n        elif const_expr(not mask_causal and not mask_local):\n            if const_expr(mask_seqlen):\n                if seqlenk_col_limit <= 0:\n                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):\n                        acc_S[i] = -cutlass.Float32.inf\n        else:  # Causal or local\n            thr_row_offset = tScS_t2r[0][ROW]\n            seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - 
thr_row_offset\n            causal_offset = seqlenq_row_limit - seqlenk_col_limit\n            if const_expr(mask_causal):\n                # tidx = cute.arch.thread_idx()[0] % 256\n                # if tidx < 32:\n                #     cute.printf(\"tidx = {}, {} {}, {} {}\", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1])\n                row_limit_top = causal_offset\n                if const_expr(mask_seqlen):\n                    # If col is beyond the column limit, we want to mask out the entire\n                    # column, by setting row limit to be self.tile_m.\n                    if seqlenk_col_limit <= 0:\n                        row_limit_top = self.tile_m\n                r2p = True\n                if const_expr(not r2p):\n                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):\n                        acc_S[i] = (\n                            -cutlass.Float32.inf if t0ScS_t2r[i][ROW] < row_limit_top else acc_S[i]\n                        )\n                else:\n                    num_rep = cute.size(tScS_t2r, mode=[0])  # 16 or 32\n                    num_wg = 2\n                    row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg)\n                    mask_r2p_lambda(\n                        acc_S,\n                        lambda s: r2p_bitmask_above(row_limit, s),\n                        rank1=True,\n                    )\n            else:\n                if const_expr(self.window_size_right is not None):\n                    row_limit_top = causal_offset - self.window_size_right\n                else:\n                    row_limit_top = 0\n                if const_expr(self.window_size_left is not None):\n                    row_limit_bot = causal_offset + self.window_size_left\n                if const_expr(mask_seqlen):\n                    if seqlenk_col_limit <= 0:\n                        row_limit_top = self.tile_m\n                r2p = True\n                if 
const_expr(not r2p):\n                    for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):\n                        row_idx = t0ScS_t2r[i][ROW]\n                        local_mask = row_idx < row_limit_top\n                        if const_expr(self.window_size_left is not None):\n                            local_mask |= row_idx > row_limit_bot\n                        acc_S[i] = -cutlass.Float32.inf if local_mask else acc_S[i]\n                else:\n\n                    def mask_gen_fn(s: int) -> Uint32:\n                        num_rep = cute.size(tScS_t2r, mode=[0])\n                        num_wg = 2\n\n                        row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg)\n                        mask = r2p_bitmask_above(row_limit, s)\n\n                        if const_expr(self.window_size_left is not None):\n                            row_limit_bottom = row_to_r2p_idx(row_limit_bot + 1, num_rep, num_wg)\n                            mask = mask & r2p_bitmask_below(row_limit_bottom, s)\n\n                        return mask\n\n                    mask_r2p_lambda(\n                        acc_S,\n                        mask_gen_fn,\n                        rank1=True,\n                    )\n"
  },
  {
    "path": "flash_attn/cute/mma_sm100_desc.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n# Ported Cutlass code from C++ to Python:\n# https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm100_desc.hpp\n# https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm100.hpp\n\nfrom enum import IntEnum\n\nimport cutlass\nimport cutlass.cute as cute\n\n# ---------------------------------------------------------------------------\n# Enumerations that match the HW encodings (values MUST stay identical)\n# ---------------------------------------------------------------------------\n\n\nclass Major(IntEnum):  # matrix “layout” in the ISA docs\n    K = 0\n    MN = 1\n\n\nclass ScaleIn(IntEnum):  # negate flags\n    One = 0\n    Neg = 1\n\n\nclass Saturate(IntEnum):\n    False_ = 0\n    True_ = 1\n\n\nclass CFormat(IntEnum):  # 2-bit field (bits 4-5)\n    F16 = 0\n    F32 = 1\n    S32 = 2\n\n\nclass F16F32Format(IntEnum):  # 3-bit field (A/B element type)\n    F16 = 0\n    BF16 = 1\n    TF32 = 2\n\n\nclass S8Format(IntEnum):\n    UINT8 = 0\n    INT8 = 1\n\n\nclass MXF8F6F4Format(IntEnum):\n    E4M3 = 0\n    E5M2 = 1\n    E2M3 = 3\n    E3M2 = 4\n    E2M1 = 5\n\n\nclass MaxShift(IntEnum):\n    NoShift = 0\n    MaxShift8 = 1\n    MaxShift16 = 2\n    MaxShift32 = 3\n\n\n# ---------------------------------------------------------------------------\n# CUTLASS-type → encoding helpers\n# ---------------------------------------------------------------------------\n\n\ndef to_UMMA_format(cutlass_type) -> int:\n    \"\"\"\n    Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B.\n    \"\"\"\n    if cutlass_type is cutlass.Int8:\n        return S8Format.INT8\n    # Unsigned 8-bit (if available in your CUTLASS build)\n    if cutlass_type is cutlass.Uint8:\n        return S8Format.UINT8\n    # FP-16 / BF-16\n    if cutlass_type is cutlass.Float16:\n        return F16F32Format.F16\n    if cutlass_type is cutlass.BFloat16:\n        return F16F32Format.BF16\n    # TensorFloat-32 (8-bit 
exponent, 10-bit mantissa packed in 19 bits)\n    if cutlass_type is cutlass.TFloat32:\n        return F16F32Format.TF32\n    # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them\n    if cutlass_type is cutlass.FloatE4M3FN:\n        return MXF8F6F4Format.E4M3\n    if cutlass_type is cutlass.FloatE5M2:\n        return MXF8F6F4Format.E5M2\n    raise TypeError(f\"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}\")\n\n\ndef to_C_format(cutlass_type) -> int:\n    \"\"\"\n    Map a CUTLASS scalar class to the 2-bit accumulator encoding.\n    \"\"\"\n    if cutlass_type is cutlass.Float16:\n        return CFormat.F16\n    if cutlass_type is cutlass.Float32:\n        return CFormat.F32\n    if cutlass_type is cutlass.Int32:\n        return CFormat.S32\n    raise TypeError(f\"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}\")\n\n\n# ---------------------------------------------------------------------------\n# The constructor – accepts only CUTLASS scalar classes\n# ---------------------------------------------------------------------------\n\n\ndef make_instr_desc(\n    a_type,  # CUTLASS scalar class, e.g. 
cutlass.Int8\n    b_type,\n    c_type,\n    M: int,  # 64, 128 or 256\n    N: int,  # 8 … 256 (multiple of 8)\n    a_major: Major,\n    b_major: Major,\n    a_neg: ScaleIn = ScaleIn.One,\n    b_neg: ScaleIn = ScaleIn.One,\n    c_sat: Saturate = Saturate.False_,\n    is_sparse: bool = False,\n    max_shift: MaxShift = MaxShift.NoShift,\n) -> int:\n    \"\"\"\n    Build the 32-bit instruction descriptor for Blackwell MMA.\n    All matrix/accumulator **types must be CUTLASS scalar classes** –\n    passing integers is forbidden.\n    \"\"\"\n    # --- encode element formats -------------------------------------------------\n    a_fmt = int(to_UMMA_format(a_type))\n    b_fmt = int(to_UMMA_format(b_type))\n    c_fmt = int(to_C_format(c_type))\n\n    # --- range checks on M/N -----------------------------------------------------\n    if M not in (64, 128, 256):\n        raise ValueError(\"M must be 64, 128 or 256\")\n    if N < 8 or N > 256 or (N & 7):\n        raise ValueError(\"N must be a multiple of 8 in the range 8…256\")\n\n    m_dim = M >> 4  # 5-bit field\n    n_dim = N >> 3  # 6-bit field\n\n    # fmt: off\n    # --- pack the bit-fields -----------------------------------------------------\n    desc = 0\n    desc |= (0                 & 0x3) << 0        # sparse_id2 (always 0 here)\n    desc |= (int(is_sparse)    & 0x1) << 2        # sparse_flag\n    desc |= (int(c_sat)        & 0x1) << 3        # saturate\n    desc |= (c_fmt             & 0x3) << 4        # c_format\n    desc |= (a_fmt             & 0x7) << 7        # a_format\n    desc |= (b_fmt             & 0x7) << 10       # b_format\n    desc |= (int(a_neg)        & 0x1) << 13       # a_negate\n    desc |= (int(b_neg)        & 0x1) << 14       # b_negate\n    desc |= (int(a_major)      & 0x1) << 15       # a_major\n    desc |= (int(b_major)      & 0x1) << 16       # b_major\n    desc |= (n_dim             & 0x3F) << 17      # n_dim (6 bits)\n    desc |= (m_dim             & 0x1F) << 24      # m_dim (5 
bits)\n    desc |= (int(max_shift)    & 0x3) << 30       # max_shift (2 bits)\n    # fmt: on\n\n    return desc & 0xFFFF_FFFF  # ensure 32-bit result\n\n\ndef mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp):\n    return make_instr_desc(\n        op.a_dtype,\n        op.b_dtype,\n        op.acc_dtype,\n        op.shape_mnk[0],\n        op.shape_mnk[1],\n        Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN,\n        Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN,\n    )\n\n\nclass LayoutType(IntEnum):  # occupies the top-3 bits [61:64)\n    SWIZZLE_NONE = 0  # (a.k.a. “INTERLEAVE” in older docs)\n    SWIZZLE_128B_BASE32B = 1\n    SWIZZLE_128B = 2\n    SWIZZLE_64B = 4\n    SWIZZLE_32B = 6\n    # values 3,5,7 are reserved / illegal for UMMA\n\n\n# ---------------------------------------------------------------------------\n#  Helpers – figure out the SWIZZLE_* family from the tensor layout\n# ---------------------------------------------------------------------------\n\n\ndef _layout_type(swizzle: cute.Swizzle) -> LayoutType:\n    B, M, S = swizzle.num_bits, swizzle.num_base, swizzle.num_shift\n\n    if M == 4:  # Swizzle<*,4,3>\n        if S != 3:\n            raise ValueError(\"Unexpected swizzle shift – want S==3 for M==4\")\n        return {\n            0: LayoutType.SWIZZLE_NONE,\n            1: LayoutType.SWIZZLE_32B,\n            2: LayoutType.SWIZZLE_64B,\n            3: LayoutType.SWIZZLE_128B,\n        }[B]  # KeyError ⇒ invalid B→ raise\n    if M == 5:  # Swizzle<2,5,2> (the only legal triple for M==5)\n        if (B, S) != (2, 2):\n            raise ValueError(\"Only Swizzle<2,5,2> supported for 128B_BASE32B\")\n        return LayoutType.SWIZZLE_128B_BASE32B\n\n    # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout\n    raise ValueError(\"Unsupported swizzle triple for UMMA smem descriptor\")\n\n\ndef make_smem_desc_base(layout: cute.Layout, swizzle: 
cute.Swizzle, major: Major) -> int:\n    \"\"\"\n    Convert a 2-D *shared-memory* Cute layout into the Blackwell 64-bit\n    smem-descriptor, without the smem start address.\n    layout must correspond to layout of an uint128 tensor.\n    \"\"\"\n    # ------------------------------------------------------------------ meta\n    layout_type = _layout_type(swizzle)  # resolve SWIZZLE_* family\n\n    VERSION = 1  # bits 46–47\n    LBO_MODE = 0  # bit  52\n    BASE_OFFSET = 0  # bits 49–51   (CUTLASS always 0)\n\n    # ---------------------------------------------------------- strides  (units: uint128_t = 16 B)\n    swizzle_atom_mn_size = {\n        LayoutType.SWIZZLE_NONE: 1,\n        LayoutType.SWIZZLE_32B: 2,\n        LayoutType.SWIZZLE_64B: 4,\n        LayoutType.SWIZZLE_128B: 8,\n        LayoutType.SWIZZLE_128B_BASE32B: 8,\n    }[layout_type]\n\n    if major is Major.MN:\n        swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8\n        canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size))\n        if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):\n            raise ValueError(\"Not a canonical UMMA_MN Layout: Expected profile failure.\")\n        stride_00 = canonical_layout.stride[0][0]\n        if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1:\n            raise ValueError(\"Not a canonical UMMA_MN Layout: Expected stride failure.\")\n        stride_10 = canonical_layout.stride[1][0]\n        if stride_10 != swizzle_atom_mn_size:\n            raise ValueError(\"Not a canonical UMMA_MN Layout: Expected stride failure.\")\n        stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1]\n        if layout_type is LayoutType.SWIZZLE_NONE:\n            stride_byte_offset, leading_byte_offset = stride_01, stride_11\n        else:\n            stride_byte_offset, leading_byte_offset = stride_11, stride_01\n    else:\n        if layout_type == 
LayoutType.SWIZZLE_128B_BASE32B:\n            raise ValueError(\"SWIZZLE_128B_BASE32B is invalid for Major-K\")\n        if not cute.size(layout.shape[0]) % 8 == 0:\n            raise ValueError(\"Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.\")\n        canonical_layout = cute.logical_divide(layout, (8, 2))\n        if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))):\n            raise ValueError(\"Not a canonical UMMA_K Layout: Expected profile failure.\")\n        stride_00 = canonical_layout.stride[0][0]\n        if stride_00 != swizzle_atom_mn_size:\n            raise ValueError(\"Not a canonical UMMA_K Layout: Expected stride failure.\")\n        stride_10 = canonical_layout.stride[1][0]\n        if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1:\n            raise ValueError(\"Not a canonical UMMA_K Layout: Expected stride failure.\")\n        stride_01 = canonical_layout.stride[0][1]\n        stride_byte_offset, leading_byte_offset = stride_01, stride_10\n\n    # ------------------------------------------------------------------ pack\n    desc = 0\n    # leading_byte_offset_  [16:30)\n    desc |= (leading_byte_offset & 0x3FFF) << 16\n    # stride_byte_offset_   [32:46)\n    desc |= (stride_byte_offset & 0x3FFF) << 32\n    # version_             [46:48)\n    desc |= (VERSION & 0x3) << 46\n    # base_offset_         [49:52)\n    desc |= (BASE_OFFSET & 0x7) << 49\n    # lbo_mode_            [52:53)\n    desc |= (LBO_MODE & 0x1) << 52\n    # layout_type_         [61:64)\n    desc |= (int(layout_type) & 0x7) << 61\n\n    return desc & 0xFFFF_FFFF_FFFF_FFFF  # force 64-bit width\n\n\ndef make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32:\n    # 14 bits, remove 4 LSB (bits 0-13 in desc)\n    return (start_addr.toint() & 0x3FFFF) >> 4\n\n\ndef smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int:\n    sA_swizzle = sA.iterator.type.swizzle_type\n    return make_smem_desc_base(\n        
cute.recast_layout(128, sA.element_type.width, sA.layout[0]),\n        sA_swizzle,\n        major,\n    )\n"
  },
  {
    "path": "flash_attn/cute/named_barrier.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n\nimport enum\n\n\nclass NamedBarrierFwd(enum.IntEnum):\n    Epilogue = enum.auto()  # starts from 1 as barrier 0 is reserved for sync_threads()\n    WarpSchedulerWG1 = enum.auto()\n    WarpSchedulerWG2 = enum.auto()\n    WarpSchedulerWG3 = enum.auto()\n    PFull = enum.auto()\n    PEmpty = enum.auto()\n\n\nclass NamedBarrierFwdSm100(enum.IntEnum):\n    Epilogue = enum.auto()  # starts from 1 as barrier 0 is reserved for sync_threads()\n    TmemPtr = enum.auto()\n    SoftmaxStatsW0 = enum.auto()\n    SoftmaxStatsW1 = enum.auto()\n    SoftmaxStatsW2 = enum.auto()\n    SoftmaxStatsW3 = enum.auto()\n    SoftmaxStatsW4 = enum.auto()\n    SoftmaxStatsW5 = enum.auto()\n    SoftmaxStatsW6 = enum.auto()\n    SoftmaxStatsW7 = enum.auto()\n\n\nclass NamedBarrierBwd(enum.IntEnum):\n    Epilogue = enum.auto()\n    WarpSchedulerWG1 = enum.auto()\n    WarpSchedulerWG2 = enum.auto()\n    WarpSchedulerWG3 = enum.auto()\n    PdS = enum.auto()\n    dQFullWG0 = enum.auto()\n    dQFullWG1 = enum.auto()\n    dQFullWG2 = enum.auto()\n    dQEmptyWG0 = enum.auto()\n    dQEmptyWG1 = enum.auto()\n    dQEmptyWG2 = enum.auto()\n\n\nclass NamedBarrierBwdSm100(enum.IntEnum):\n    EpilogueWG1 = enum.auto()\n    EpilogueWG2 = enum.auto()\n    Compute = enum.auto()\n    dQaccReduce = enum.auto()\n    TmemPtr = enum.auto()\n"
  },
  {
    "path": "flash_attn/cute/pack_gqa.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n\nfrom typing import Union, Tuple\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass.cute.nvgpu import cpasync\n\n\nfrom quack import layout_utils\nimport flash_attn.cute.utils as utils\n\n\ndef pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx):\n    \"\"\"Reshape a tensor to fold qhead_per_kvhead into the seqlen dimension (mode 0).\n\n    The head dimension is at mode ``head_idx``.  Modes before it (1..head_idx-1)\n    are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept\n    as-is (e.g. batch).\n\n    For Q/O tensors (head_idx=2):\n        (seqlen_q, headdim, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...)\n    For LSE tensors (head_idx=1):\n        (seqlen_q, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...)\n    \"\"\"\n    head_stride = T.stride[head_idx]\n    shape_packed = (\n        (qhead_per_kvhead, T.shape[0]),\n        *[T.shape[i] for i in range(1, head_idx)],\n        nheads_kv,\n        *[T.shape[i] for i in range(head_idx + 1, len(T.shape))],\n    )\n    stride_packed = (\n        (head_stride, T.stride[0]),\n        *[T.stride[i] for i in range(1, head_idx)],\n        head_stride * qhead_per_kvhead,\n        *[T.stride[i] for i in range(head_idx + 1, len(T.shape))],\n    )\n    return cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed))\n\n\ndef make_packgqa_tiled_tma_atom(\n    op: cute.atom.CopyOp,\n    gmem_tensor: cute.Tensor,\n    smem_layout: Union[cute.Layout, cute.ComposedLayout],\n    cta_tiler: Tuple[int, int],\n    qhead_per_kvhead: int,\n    head_idx: int,\n):\n    # This packing and unpacking of the layout is so that we keep the same TMA dimension as usual.\n    # e.g. 
for (seqlen, d, nheads, b) layout, we still have 4D TMA after packing to\n    # ((nheads, seqlen), d, b).\n    # If we instead pack directly to ((qhead_per_kvhead, seqlen), d, nheads_kv, b) we'd have 5D TMA.\n    # Pack nheads and seqlen dims into 1: (seqlen, d, nheads, b) -> ((nheads, seqlen), d, b)\n    gmem_tensor = layout_utils.select(\n        gmem_tensor, [head_idx, *range(head_idx), *range(head_idx + 1, cute.rank(gmem_tensor))]\n    )\n    gmem_tensor = cute.group_modes(gmem_tensor, 0, 2)\n    assert cta_tiler[0] % qhead_per_kvhead == 0, (\n        \"CTA tile size in the seqlen dimension must be divisible by qhead_per_kvhead\"\n    )\n    tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(\n        op,\n        gmem_tensor,\n        smem_layout,\n        ((qhead_per_kvhead, cta_tiler[0] // qhead_per_kvhead), cta_tiler[1]),  # No mcast\n    )\n    # Unpack from ((nheads, seqlen), d, b) -> ((qhead_per_kvhead, seqlen), d, nheads_kv, b)\n    T = tma_tensor\n    shape_packed = (\n        (qhead_per_kvhead, T.shape[0][1]),\n        *[T.shape[i] for i in range(1, head_idx)],\n        T.shape[0][0] // qhead_per_kvhead,\n        *[T.shape[i] for i in range(head_idx, len(T.shape))],\n    )\n    stride_packed = (\n        *[T.stride[i] for i in range(head_idx)],\n        T.stride[0][0] * qhead_per_kvhead,\n        *[T.stride[i] for i in range(head_idx, len(T.shape))],\n    )\n    tma_tensor = cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed))\n    return tma_atom, tma_tensor\n\n\ndef unpack_gqa_layout(T, qhead_per_kvhead, head_idx):\n    \"\"\"Reverse of pack_gqa_layout: unfold qhead_per_kvhead from the seqlen dimension (mode 0).\n\n    The head dimension is at mode ``head_idx``.  Modes before it (1..head_idx-1)\n    are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept\n    as-is (e.g. batch).\n\n    For Q/O tensors (head_idx=2):\n        ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) 
-> (seqlen_q, headdim, nheads, batch, ...)\n    For LSE tensors (head_idx=1):\n        ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) -> (seqlen_q, nheads, batch, ...)\n    \"\"\"\n    seqlen_stride = T.stride[0][1]\n    head_stride = T.stride[0][0]\n    shape_unpacked = (\n        T.shape[0][1],\n        *[T.shape[i] for i in range(1, head_idx)],\n        T.shape[head_idx] * qhead_per_kvhead,\n        *[T.shape[i] for i in range(head_idx + 1, len(T.shape))],\n    )\n    stride_unpacked = (\n        seqlen_stride,\n        *[T.stride[i] for i in range(1, head_idx)],\n        head_stride,\n        *[T.stride[i] for i in range(head_idx + 1, len(T.shape))],\n    )\n    return cute.make_tensor(T.iterator, cute.make_layout(shape_unpacked, stride=stride_unpacked))\n\n\nclass PackGQA:\n    def __init__(\n        self,\n        m_block_size: cutlass.Constexpr[int],\n        head_dim_padded: cutlass.Constexpr[int],\n        check_hdim_oob: cutlass.Constexpr[bool],\n        qhead_per_kvhead: cutlass.Constexpr[int],\n    ):\n        self.m_block_size = m_block_size\n        self.head_dim_padded = head_dim_padded\n        self.check_hdim_oob = check_hdim_oob\n        self.qhead_per_kvhead = qhead_per_kvhead\n\n    @cute.jit\n    def compute_ptr(\n        self,\n        tensor: cute.Tensor,\n        cRows: cute.Tensor,\n        tidx: cutlass.Int32,\n        block: cutlass.Int32,\n        threads_per_row: cutlass.Constexpr[int],\n        num_threads: cutlass.Constexpr[int],\n    ):\n        num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row)\n        tPrPtr = cute.make_fragment(num_ptr_per_thread, cutlass.Int64)\n        for i in cutlass.range_constexpr(num_ptr_per_thread):\n            row = i * num_threads + cRows[tidx % threads_per_row][0]\n            idx = block * self.m_block_size + row\n            m_idx = idx // self.qhead_per_kvhead\n            h_idx = idx - m_idx * self.qhead_per_kvhead\n            tPrPtr[i] = utils.elem_pointer(tensor, 
((h_idx, m_idx),)).toint()\n        return tPrPtr\n\n    @cute.jit\n    def load_Q(\n        self,\n        mQ: cute.Tensor,  # ((qhead_per_kvhead, seqlen_q), headdim)\n        sQ: cute.Tensor,  # (m_block_size, head_dim_padded)\n        gmem_tiled_copy: cute.TiledCopy,\n        tidx: cutlass.Int32,\n        block: cutlass.Int32,\n        seqlen: cutlass.Int32,\n    ):\n        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)\n        cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))\n        tQsQ = gmem_thr_copy.partition_D(sQ)\n        tQcQ = gmem_thr_copy.partition_S(cQ)\n        t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ)\n        tQpQ = utils.predicate_k(tQcQ, limit=mQ.shape[1])\n        tQcQ_row = tQcQ[0, None, 0]\n        threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]\n        assert cute.arch.WARP_SIZE % threads_per_row == 0, \"threads_per_row must divide WARP_SIZE\"\n        num_threads = gmem_tiled_copy.size\n        tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads)\n        for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):\n            q_ptr_i64 = utils.shuffle_sync(\n                tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row\n            )\n            q_gmem_ptr = cute.make_ptr(\n                mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16\n            )\n            if (\n                t0QcQ[0, m, 0][0]\n                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0]\n            ):\n                mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,))\n                elems_per_load = cute.size(tQsQ.shape[0][0])\n                mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,))\n                for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])):\n                    ki = tQcQ[0, 0, k][1] // elems_per_load\n                    cute.copy(\n      
                  gmem_thr_copy,\n                        mQ_cur_copy[None, ki],\n                        tQsQ[None, m, k],\n                        pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,\n                    )\n            # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs\n\n    @cute.jit\n    def store_LSE(\n        self,\n        mLSE: cute.Tensor,  # (qhead_per_kvhead, seqlen_q)\n        tLSErLSE: cute.Tensor,  # (m_block_size, head_dim_padded)\n        tiled_mma: cute.TiledMma,\n        tidx: cutlass.Int32,\n        block: cutlass.Int32,\n        seqlen: cutlass.Int32,\n    ):\n        thr_mma = tiled_mma.get_slice(tidx)\n        caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))\n        taccOcO = thr_mma.partition_C(caccO)\n        taccOcO_row = layout_utils.reshape_acc_to_mn(taccOcO)[None, 0]\n        assert cute.size(tLSErLSE) == cute.size(taccOcO_row)\n        threads_per_row = tiled_mma.tv_layout_C.shape[0][0]\n        assert cute.arch.WARP_SIZE % threads_per_row == 0, \"threads_per_row must divide WARP_SIZE\"\n        assert cute.size(tLSErLSE) <= threads_per_row\n        num_threads = tiled_mma.size\n        tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, num_threads)\n        for m in cutlass.range_constexpr(cute.size(tLSErLSE)):\n            lse_ptr_i64 = utils.shuffle_sync(\n                tPrLSEPtr[m // threads_per_row],\n                m % threads_per_row,\n                width=threads_per_row,\n            )\n            lse_gmem_ptr = cute.make_ptr(\n                mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4\n            )\n            row = block * self.m_block_size + taccOcO_row[m][0]\n            # Only the thread corresponding to column 0 writes out the lse to gmem\n            if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead:\n                mLSE_copy = 
cute.make_tensor(lse_gmem_ptr, (1,))\n                mLSE_copy[0] = tLSErLSE[m]\n\n    @cute.jit\n    def store_O(\n        self,\n        mO: cute.Tensor,  # ((qhead_per_kvhead, seqlen_q), headdim)\n        tOrO: cute.Tensor,  # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy\n        gmem_tiled_copy: cute.TiledCopy,\n        tidx: cutlass.Int32,\n        block: cutlass.Int32,\n        seqlen: cutlass.Int32,\n    ):\n        gmem_thr_copy = gmem_tiled_copy.get_slice(tidx)\n        cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))\n        tOcO = gmem_thr_copy.partition_S(cO)\n        t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO)\n        tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])\n        tOcO_row = tOcO[0, None, 0]\n        threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0]\n        assert cute.arch.WARP_SIZE % threads_per_row == 0, \"threads_per_row must divide WARP_SIZE\"\n        num_threads = gmem_tiled_copy.size\n        tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads)\n        for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):\n            o_ptr_i64 = utils.shuffle_sync(\n                tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row\n            )\n            o_gmem_ptr = cute.make_ptr(\n                mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16\n            )\n            if (\n                t0OcO[0, m, 0][0]\n                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0]\n            ):\n                mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,))\n                elems_per_load = cute.size(tOrO.shape[0][0])\n                mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,))\n                for k in cutlass.range_constexpr(cute.size(tOrO.shape[2])):\n                    ki = tOcO[0, 0, k][1] // 
elems_per_load\n                    cute.copy(\n                        gmem_thr_copy,\n                        tOrO[None, m, k],\n                        mO_cur_copy[None, ki],\n                        pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,\n                    )\n"
  },
  {
    "path": "flash_attn/cute/paged_kv.py",
    "content": "from typing import Type\nfrom dataclasses import dataclass\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass.cute.nvgpu import cpasync\nfrom cutlass import Int32, const_expr\n\nfrom flash_attn.cute import utils\nfrom quack.cute_dsl_utils import ParamsBase\nfrom cutlass.cute import FastDivmodDivisor\n\nimport math\n\n\n@dataclass\nclass PagedKVManager(ParamsBase):\n    mPageTable: cute.Tensor\n    mK_paged: cute.Tensor\n    mV_paged: cute.Tensor\n    thread_idx: Int32\n\n    page_size_divmod: FastDivmodDivisor\n    seqlen_k: Int32\n    leftpad_k: Int32\n    n_block_size: Int32\n    num_threads: cutlass.Constexpr[Int32]\n    head_dim_padded: cutlass.Constexpr[Int32]\n    head_dim_v_padded: cutlass.Constexpr[Int32]\n\n    arch: cutlass.Constexpr[Int32]\n    v_gmem_transposed: cutlass.Constexpr[bool]\n\n    gmem_threads_per_row: cutlass.Constexpr[Int32]\n    page_entry_per_thread: Int32\n    async_copy_elems: Int32\n\n    gmem_tiled_copy_KV: cute.TiledCopy\n    gmem_thr_copy_KV: cute.TiledCopy\n    tPrPage: cute.Tensor\n    tPrPageOffset: cute.Tensor\n    tKpK: cute.Tensor\n    tVpV: cute.Tensor\n\n    @staticmethod\n    def create(\n        mPageTable: cute.Tensor,\n        mK_paged: cute.Tensor,\n        mV_paged: cute.Tensor,\n        page_size_divmod: FastDivmodDivisor,\n        bidb: Int32,\n        bidh: Int32,\n        thread_idx: Int32,\n        seqlen_k: Int32,\n        leftpad_k: Int32,\n        n_block_size: cutlass.Constexpr[Int32],\n        head_dim_padded: cutlass.Constexpr[Int32],\n        head_dim_v_padded: cutlass.Constexpr[Int32],\n        num_threads: cutlass.Constexpr[Int32],\n        dtype: Type[cutlass.Numeric],\n        arch: cutlass.Constexpr[int] = 100,\n    ):\n        # SM100 transposes V in gmem to (dv, page_size, num_pages);\n        # SM90 keeps V as (page_size, dv, num_pages), same layout as K.\n        v_gmem_transposed = arch != 90\n        universal_copy_bits = 128\n        async_copy_elems = 
universal_copy_bits // dtype.width\n        dtype_bytes = dtype.width // 8\n        gmem_k_block_size = math.gcd(\n            head_dim_padded,\n            head_dim_v_padded,\n            128 // dtype_bytes,\n        )\n        assert gmem_k_block_size % async_copy_elems == 0\n        gmem_threads_per_row = gmem_k_block_size // async_copy_elems\n        assert cute.arch.WARP_SIZE % gmem_threads_per_row == 0\n        atom_async_copy = cute.make_copy_atom(\n            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),\n            dtype,\n            num_bits_per_copy=universal_copy_bits,\n        )\n        thr_layout = cute.make_ordered_layout(\n            (num_threads // gmem_threads_per_row, gmem_threads_per_row),\n            order=(1, 0),\n        )\n        val_layout = cute.make_layout((1, async_copy_elems))\n        gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout)\n        gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx)\n        page_entry_per_thread = n_block_size // num_threads\n\n        tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32)\n        tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32)\n\n        mPageTable = mPageTable[bidb, None]\n        mK_paged = mK_paged[None, None, bidh, None]\n        mV_paged = mV_paged[None, None, bidh, None]\n\n        cK = cute.make_identity_tensor((n_block_size, head_dim_padded))\n        tKcK = gmem_thr_copy_KV.partition_S(cK)\n        tKpK = utils.predicate_k(tKcK, limit=mK_paged.shape[1])\n\n        if const_expr(head_dim_padded == head_dim_v_padded):\n            tVpV = tKpK\n        else:\n            cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded))\n            tVcV = gmem_thr_copy_KV.partition_S(cV)\n            # When V is transposed in gmem, dv is shape[0]; otherwise dv is shape[1] (same as K)\n            tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0 if v_gmem_transposed else 1])\n\n     
   return PagedKVManager(\n            mPageTable,\n            mK_paged,\n            mV_paged,\n            thread_idx,\n            page_size_divmod,\n            seqlen_k,\n            leftpad_k,\n            n_block_size,\n            num_threads,\n            head_dim_padded,\n            head_dim_v_padded,\n            arch,\n            v_gmem_transposed,\n            gmem_threads_per_row,\n            page_entry_per_thread,\n            async_copy_elems,\n            gmem_tiled_copy_KV,\n            gmem_thr_copy_KV,\n            tPrPage,\n            tPrPageOffset,\n            tKpK,\n            tVpV,\n        )\n\n    @cute.jit\n    def load_page_table(self, n_block: Int32):\n        for i in cutlass.range(self.page_entry_per_thread, unroll=1):\n            row = (\n                i * self.num_threads\n                + (self.thread_idx % self.gmem_threads_per_row)\n                * (self.num_threads // self.gmem_threads_per_row)\n                + (self.thread_idx // self.gmem_threads_per_row)\n            )\n            row_idx = n_block * self.n_block_size + row\n\n            page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod)\n\n            is_valid = (\n                (i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size\n            ) and row_idx < self.seqlen_k\n            page = self.mPageTable[page_idx] if is_valid else 0\n\n            self.tPrPage[i] = page\n            self.tPrPageOffset[i] = page_offset\n\n    @cute.jit\n    def compute_X_ptr(self, K_or_V: str):\n        tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64)\n        mX = self.mK_paged if const_expr(K_or_V == \"K\") else self.mV_paged\n        # K is always (page_size, d, num_pages). 
V matches K when not transposed,\n        # but is (dv, page_size, num_pages) when transposed (SM100).\n        transposed = const_expr(K_or_V == \"V\" and self.v_gmem_transposed)\n        for i in cutlass.range(self.page_entry_per_thread, unroll=1):\n            page = self.tPrPage[i]\n            page_offset = self.tPrPageOffset[i]\n            if const_expr(transposed):\n                tPrXPtr[i] = utils.elem_pointer(mX, (0, page_offset, page)).toint()\n            else:\n                tPrXPtr[i] = utils.elem_pointer(mX, (page_offset, 0, page)).toint()\n        return tPrXPtr\n\n    @cute.jit\n    def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str):\n        assert K_or_V in (\"K\", \"V\")\n\n        tPrXPtr = self.compute_X_ptr(K_or_V)\n\n        if const_expr(self.arch == 90):\n            # SM90: sX is already stage-sliced by caller (sK[None, None, stage]).\n            # Flatten hierarchical modes to get (n_block_size, head_dim).\n            sX_pi = cute.group_modes(sX, 0, 1)\n            # SM90 does NOT transpose V here (it's transposed via utils.transpose_view before MMA)\n        else:\n            # SM100: Finesse sX layout to be (M, N).\n            sX_pi = cute.make_tensor(\n                sX.iterator,\n                cute.make_layout(\n                    (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])),\n                    stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])),\n                ),\n            )\n\n            if const_expr(K_or_V == \"V\"):\n                # Transpose smem V to match transposed gmem layout\n                sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0]))\n\n        head_dim = self.head_dim_v_padded if const_expr(K_or_V == \"V\") else self.head_dim_padded\n        cX = cute.make_identity_tensor((self.n_block_size, head_dim))\n        tXsX = self.gmem_thr_copy_KV.partition_D(sX_pi)\n        tXcX = self.gmem_thr_copy_KV.partition_S(cX)\n        tXc0X = 
self.gmem_thr_copy_KV.get_slice(0).partition_S(cX)\n\n        seqlenk_row_limit = (\n            self.seqlen_k - n_block * self.n_block_size - tXcX[0][0] if n_block >= 0 else 0\n        )\n        for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])):\n            row_valid = tXc0X[0, m, 0][0] < seqlenk_row_limit\n            should_load = cute.make_fragment_like(tXsX[(0, None), m, 0], cute.Boolean)\n            should_load.fill(row_valid)\n\n            x_ptr_i64 = utils.shuffle_sync(\n                tPrXPtr[m // self.gmem_threads_per_row],\n                m % self.gmem_threads_per_row,\n                width=self.gmem_threads_per_row,\n            )\n            x_gmem_ptr = cute.make_ptr(\n                self.mK_paged.element_type, x_ptr_i64, cute.AddressSpace.gmem, assumed_align=16\n            )\n            mX_paged_cur = cute.make_tensor(x_gmem_ptr, cute.make_layout((head_dim,)))\n            mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,))\n\n            for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])):\n                ki = tXcX[0, 0, k][1] // self.async_copy_elems\n                mX_paged_cur_copy_ki = mX_paged_cur_copy[None, ki]\n                tXsX_k = tXsX[None, m, k]\n                mX_paged_cur_copy_ki = cute.make_tensor(\n                    mX_paged_cur_copy_ki.iterator, tXsX_k.layout\n                )\n                cute.copy(\n                    self.gmem_tiled_copy_KV,\n                    mX_paged_cur_copy_ki,\n                    tXsX_k,\n                    pred=should_load,\n                )\n"
  },
  {
    "path": "flash_attn/cute/pipeline.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n\n# import math\nfrom typing import Optional\nfrom dataclasses import dataclass\n\nimport cutlass.cute as cute\nfrom cutlass import Boolean, Int32, const_expr\nfrom cutlass.cutlass_dsl import if_generate, dsl_user_op\nfrom cutlass.pipeline import PipelineState\nfrom cutlass.pipeline import PipelineUserType\nfrom cutlass.pipeline import NamedBarrier as NamedBarrierOg\nfrom cutlass.pipeline import PipelineAsync as PipelineAsyncOg\nfrom cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg\nfrom cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg\nfrom cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg\nfrom cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg\n\n\nclass PipelineStateSimple:\n    \"\"\"\n    Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.\n    Use a single Int32 to store both the index and phase bit, then we use divmod to get the\n    index and phase. 
If stages is a power of 2, divmod turns into bit twiddling.\n    \"\"\"\n\n    def __init__(self, stages: int, phase_index: Int32):\n        # assert stages < 2**16\n        # self._log_stages = int(math.log2(stages))\n        # assert 1 << self._log_stages == stages, \"Number of stages must be a power of 2.\"\n        self._stages = stages\n        self._phase_index = phase_index\n\n    def clone(self) -> \"PipelineStateSimple\":\n        return PipelineStateSimple(self.stages, self._phase_index)\n\n    @property\n    def stages(self) -> int:\n        # return 1 << self._log_stages\n        return self._stages\n\n    @property\n    def index(self) -> Int32:\n        # return self._phase_index & 0xFFFF\n        # return self._phase_index & ((1 << self._log_stages) - 1)\n        if const_expr(self._stages == 1):\n            return Int32(0)\n        else:\n            return self._phase_index % self._stages\n\n    @property\n    def phase(self) -> Int32:\n        # return self._phase_index >> 16\n        # PTX docs say that the phase parity needs to be 0 or 1, so by right we need to\n        # take modulo 2. 
But in practice just passing the phase in without modulo works fine.\n        # return (self._phase_index >> self._log_stages) % 2\n        # return self._phase_index >> self._log_stages\n        if const_expr(self._stages == 1):\n            return self._phase_index\n        else:\n            return self._phase_index // self._stages\n\n    def advance(self):\n        if const_expr(self._stages == 1):\n            self._phase_index ^= 1\n        else:\n            self._phase_index += 1\n\n        # def then_body(phase_index):\n        #     # XOR the phase bit and set the index to 0\n        #     return (phase_index & 0xFFFF0000) ^ (1 << 16)\n\n        # def else_body(phase_index):\n        #     return phase_index\n\n        # self._phase_index = if_generate(\n        #     (self._phase_index & 0xFFFF) == self.stages,\n        #     then_body,\n        #     else_body,\n        #     [self._phase_index],\n        #     [Int32],\n        # )\n\n    def __extract_mlir_values__(self):\n        phase_index = self._phase_index\n        return [phase_index.ir_value()]\n\n    def __new_from_mlir_values__(self, values):\n        return PipelineStateSimple(self.stages, Int32(values[0]))\n\n\ndef make_pipeline_state(type: PipelineUserType, stages: int):\n    \"\"\"\n    Creates a pipeline state. 
Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.\n    \"\"\"\n    if type is PipelineUserType.Producer:\n        # return PipelineStateSimple(stages, Int32(1 << 16))\n        return PipelineStateSimple(stages, Int32(stages))\n    elif type is PipelineUserType.Consumer:\n        return PipelineStateSimple(stages, Int32(0))\n    else:\n        assert False, \"Error: invalid PipelineUserType specified for make_pipeline_state.\"\n\n\n@dataclass(frozen=True)\nclass NamedBarrier(NamedBarrierOg):\n    @staticmethod\n    def create(*args, **kwargs):\n        obj = NamedBarrierOg.create(*args, **kwargs)\n        # Can't assign to __class__ directly since the dataclass is frozen\n        object.__setattr__(obj, \"__class__\", NamedBarrier)\n        return obj\n\n    @dsl_user_op\n    def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None:\n        \"\"\"\n        The aligned flavor of arrive is used when all threads in the CTA will execute the\n        same instruction. 
See PTX documentation.\n        \"\"\"\n        cute.arch.barrier_arrive(\n            barrier_id=self.barrier_id + index,\n            number_of_threads=self.num_threads,\n            loc=loc,\n            ip=ip,\n        )\n\n    @dsl_user_op\n    def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None:\n        cute.arch.barrier(\n            barrier_id=self.barrier_id + index,\n            number_of_threads=self.num_threads,\n            loc=loc,\n            ip=ip,\n        )\n\n\n@dataclass(frozen=True)\nclass PipelineAsync(PipelineAsyncOg):\n    @staticmethod\n    def create(*args, **kwargs):\n        obj = PipelineAsyncOg.create(*args, **kwargs)\n        # Can't assign to __class__ directly since the dataclass is frozen\n        # obj.__class__ = PipelineAsync\n        object.__setattr__(obj, \"__class__\", PipelineAsync)\n        return obj\n\n    @dsl_user_op\n    def producer_acquire_w_index_phase(\n        self,\n        index: Int32,\n        phase: Int32,\n        try_acquire_token: Optional[Boolean] = None,\n        *,\n        loc=None,\n        ip=None,\n    ):\n        if_generate(\n            try_acquire_token is None or try_acquire_token == 0,\n            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        )\n\n    @dsl_user_op\n    def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):\n        self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)\n\n    @dsl_user_op\n    def consumer_wait_w_index_phase(\n        self,\n        index: Int32,\n        phase: Int32,\n        try_wait_token: Optional[Boolean] = None,\n        *,\n        loc=None,\n        ip=None,\n    ):\n        if_generate(\n            try_wait_token is None or try_wait_token == 0,\n            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        )\n\n    @dsl_user_op\n    def 
consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):\n        self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip)\n\n\n@dataclass(frozen=True)\nclass PipelineTmaAsync(PipelineTmaAsyncOg):\n    \"\"\"\n    Override producer_acquire to take in extra_tx_count parameter.\n    \"\"\"\n\n    @staticmethod\n    def create(*args, **kwargs):\n        obj = PipelineTmaAsyncOg.create(*args, **kwargs)\n        # Can't assign to __class__ directly since the dataclass is frozen\n        object.__setattr__(obj, \"__class__\", PipelineTmaAsync)\n        return obj\n\n    @dsl_user_op\n    def producer_acquire(\n        self,\n        state: PipelineState,\n        try_acquire_token: Optional[Boolean] = None,\n        extra_tx_count: int = 0,\n        *,\n        loc=None,\n        ip=None,\n    ):\n        \"\"\"\n        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.\n        \"\"\"\n        if_generate(\n            try_acquire_token is None or try_acquire_token == 0,\n            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        )\n        if const_expr(extra_tx_count == 0):\n            self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip)\n        else:\n            tx_count = self.sync_object_full.tx_count + extra_tx_count\n            self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip)\n\n    @dsl_user_op\n    def producer_acquire_w_index_phase(\n        self,\n        index: Int32,\n        phase: Int32,\n        try_acquire_token: Optional[Boolean] = None,\n        *,\n        loc=None,\n        ip=None,\n    ):\n        if_generate(\n            try_acquire_token is None or try_acquire_token == 0,\n            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        )\n        
self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)\n\n    @dsl_user_op\n    def consumer_wait_w_index_phase(\n        self,\n        index: Int32,\n        phase: Int32,\n        try_wait_token: Optional[Boolean] = None,\n        *,\n        loc=None,\n        ip=None,\n    ):\n        if_generate(\n            try_wait_token is None or try_wait_token == 0,\n            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        )\n\n    @dsl_user_op\n    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):\n        \"\"\"\n        TMA consumer release conditionally signals the empty buffer to the producer.\n        \"\"\"\n        if_generate(\n            self.is_signalling_thread,\n            lambda: self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip),\n        )\n\n\n@dataclass(frozen=True)\nclass PipelineTmaUmma(PipelineTmaUmmaOg):\n    \"\"\"\n    Override producer_acquire to take in extra_tx_count parameter.\n    \"\"\"\n\n    @staticmethod\n    def create(*args, **kwargs):\n        obj = PipelineTmaUmmaOg.create(*args, **kwargs)\n        # Can't assign to __class__ directly since the dataclass is frozen\n        # obj.__class__ = PipelineTmaUmma\n        object.__setattr__(obj, \"__class__\", PipelineTmaUmma)\n        return obj\n\n    @dsl_user_op\n    def producer_acquire(\n        self,\n        state: PipelineState,\n        try_acquire_token: Optional[Boolean] = None,\n        extra_tx_count: int = 0,\n        *,\n        loc=None,\n        ip=None,\n    ):\n        \"\"\"\n        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.\n        \"\"\"\n        if_generate(\n            try_acquire_token is None or try_acquire_token == 0,\n            lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        
)\n        if const_expr(extra_tx_count == 0):\n            if_generate(\n                self.is_leader_cta,\n                lambda: self.sync_object_full.arrive(\n                    state.index, self.producer_mask, loc=loc, ip=ip\n                ),\n                loc=loc,\n                ip=ip,\n            )\n        else:\n            tx_count = self.sync_object_full.tx_count + extra_tx_count\n            if_generate(\n                self.is_leader_cta,\n                lambda: self.sync_object_full.arrive_and_expect_tx(\n                    state.index, tx_count, loc=loc, ip=ip\n                ),\n                loc=loc,\n                ip=ip,\n            )\n\n    @dsl_user_op\n    def producer_acquire_w_index_phase(\n        self,\n        index: Int32,\n        phase: Int32,\n        try_acquire_token: Optional[Boolean] = None,\n        *,\n        loc=None,\n        ip=None,\n    ):\n        \"\"\"\n        TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks.\n        \"\"\"\n        if_generate(\n            try_acquire_token is None or try_acquire_token == 0,\n            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        )\n        if_generate(\n            self.is_leader_cta,\n            lambda: self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        )\n\n    @dsl_user_op\n    def consumer_wait_w_index_phase(\n        self,\n        index: Int32,\n        phase: Int32,\n        try_wait_token: Optional[Boolean] = None,\n        *,\n        loc=None,\n        ip=None,\n    ):\n        if_generate(\n            try_wait_token is None or try_wait_token == 0,\n            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        )\n\n    @dsl_user_op\n    def consumer_release_w_index(self, 
index: Int32, *, loc=None, ip=None):\n        \"\"\"\n        UMMA consumer release buffer empty, cta_group needs to be provided.\n        \"\"\"\n        self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)\n\n\n@dataclass(frozen=True)\nclass PipelineUmmaAsync(PipelineUmmaAsyncOg):\n    @staticmethod\n    def create(*args, **kwargs):\n        obj = PipelineUmmaAsyncOg.create(*args, **kwargs)\n        # Can't assign to __class__ directly since the dataclass is frozen\n        object.__setattr__(obj, \"__class__\", PipelineUmmaAsync)\n        return obj\n\n    @dsl_user_op\n    def producer_acquire_w_index_phase(\n        self,\n        index: Int32,\n        phase: Int32,\n        try_acquire_token: Optional[Boolean] = None,\n        *,\n        loc=None,\n        ip=None,\n    ):\n        if_generate(\n            try_acquire_token is None or try_acquire_token == 0,\n            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        )\n\n    @dsl_user_op\n    def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):\n        \"\"\"\n        UMMA producer commit buffer full, cta_group needs to be provided.\n        \"\"\"\n        self.sync_object_full.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip)\n\n    @dsl_user_op\n    def consumer_wait_w_index_phase(\n        self,\n        index: Int32,\n        phase: Int32,\n        try_wait_token: Optional[Boolean] = None,\n        *,\n        loc=None,\n        ip=None,\n    ):\n        if_generate(\n            try_wait_token is None or try_wait_token == 0,\n            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        )\n\n    @dsl_user_op\n    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):\n        self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, 
ip=ip)\n\n\n@dataclass(frozen=True)\nclass PipelineAsyncUmma(PipelineAsyncUmmaOg):\n    @staticmethod\n    def create(*args, **kwargs):\n        obj = PipelineAsyncUmmaOg.create(*args, **kwargs)\n        # Can't assign to __class__ directly since the dataclass is frozen\n        object.__setattr__(obj, \"__class__\", PipelineAsyncUmma)\n        return obj\n\n    @dsl_user_op\n    def producer_acquire_w_index_phase(\n        self,\n        index: Int32,\n        phase: Int32,\n        try_acquire_token: Optional[Boolean] = None,\n        *,\n        loc=None,\n        ip=None,\n    ):\n        if_generate(\n            try_acquire_token is None or try_acquire_token == 0,\n            lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        )\n\n    @dsl_user_op\n    def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None):\n        self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip)\n\n    @dsl_user_op\n    def consumer_wait_w_index_phase(\n        self,\n        index: Int32,\n        phase: Int32,\n        try_wait_token: Optional[Boolean] = None,\n        *,\n        loc=None,\n        ip=None,\n    ):\n        if_generate(\n            try_wait_token is None or try_wait_token == 0,\n            lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip),\n            loc=loc,\n            ip=ip,\n        )\n\n    @dsl_user_op\n    def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None):\n        \"\"\"\n        UMMA consumer release buffer empty, cta_group needs to be provided.\n        \"\"\"\n        self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip)\n"
  },
  {
    "path": "flash_attn/cute/pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=75\", \"setuptools-scm>=8\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"flash-attn-4\"\ndynamic = [\"version\"]\ndescription = \"Flash Attention CUTE (CUDA Template Engine) implementation\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = {text = \"BSD 3-Clause License\"}\nauthors = [\n    {name = \"Tri Dao\"},\n]\nclassifiers = [\n    \"Development Status :: 3 - Alpha\",\n    \"License :: OSI Approved :: BSD License\",\n    \"Programming Language :: Python :: 3\",\n    \"Programming Language :: Python :: 3.10\",\n    \"Programming Language :: Python :: 3.11\",\n    \"Programming Language :: Python :: 3.12\",\n]\n\ndependencies = [\n    \"nvidia-cutlass-dsl>=4.4.2\",\n    \"torch\",\n    \"einops\",\n    \"typing_extensions\",\n    \"apache-tvm-ffi>=0.1.5,<0.2\",\n    \"torch-c-dlpack-ext\",\n    \"quack-kernels>=0.3.3\",\n]\n\n[project.optional-dependencies]\ndev = [\n    \"pytest\",\n    \"ruff\",\n]\n\n[project.urls]\nHomepage = \"https://github.com/Dao-AILab/flash-attention\"\nRepository = \"https://github.com/Dao-AILab/flash-attention\"\n\n[tool.setuptools]\npackages = [\"flash_attn.cute\"]\npackage-dir = {\"flash_attn.cute\" = \".\"}\n\n[tool.setuptools_scm]\nroot = \"../..\"\ntag_regex = \"^fa4-v(?P<version>.+)$\"\ngit_describe_command = \"git describe --dirty --tags --long --match 'fa4-v*'\"\nfallback_version = \"0.0.0\"\n\n[tool.ruff]\nline-length = 100\n\n[tool.ruff.lint]\nignore = [\n    \"E731\",  # do not assign a lambda expression, use a def\n    \"E741\",  # Do not use variables named 'I', 'O', or 'l'\n    \"F841\",  # local variable is assigned to but never used\n    \"D102\",  # Missing docstring in public methods\n]\n"
  },
  {
    "path": "flash_attn/cute/seqlen_info.py",
    "content": "from typing import Optional\nfrom dataclasses import dataclass\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Int32, const_expr\n\nfrom quack import copy_utils\n\n\"\"\"\nThis consolidates all the info related to sequence length. This is so that we can do all\nthe gmem reads once at the beginning of each tile, rather than having to repeat these reads\nto compute various things like n_block_min, n_block_max, etc.\n\"\"\"\n\n\n@dataclass(frozen=True)\nclass SeqlenInfo:\n    offset: Int32\n    offset_padded: Int32\n    seqlen: Int32\n    has_cu_seqlens: cutlass.Constexpr[bool] = False\n\n    @staticmethod\n    def create(\n        batch_idx: Int32,\n        seqlen_static: Int32,\n        cu_seqlens: Optional[cute.Tensor] = None,\n        seqused: Optional[cute.Tensor] = None,\n        tile: cutlass.Constexpr[int] = 128,\n    ):\n        offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx]\n        offset_padded = (\n            0\n            if const_expr(cu_seqlens is None)\n            # Add divby so that the compiler knows the alignment when moving by offset_padded\n            else cute.assume((offset + batch_idx * tile) // tile * tile, divby=tile)\n        )\n        if const_expr(seqused is not None):\n            seqlen = seqused[batch_idx]\n        elif const_expr(cu_seqlens is not None):\n            seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]\n        else:\n            seqlen = seqlen_static\n        return SeqlenInfo(offset, offset_padded, seqlen, has_cu_seqlens=cu_seqlens is not None)\n\n    def offset_batch(\n        self,\n        mT: cute.Tensor,\n        batch_idx: Int32,\n        dim: int,\n        padded: cutlass.Constexpr[bool] = False,\n        multiple: int = 1,\n    ) -> cute.Tensor:\n        \"\"\"Offset a tensor by batch index. 
batch dim is at position `dim`, seqlen is at dim=0.\"\"\"\n        if const_expr(not self.has_cu_seqlens):\n            idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mT) - 1 - dim)\n            return mT[idx]\n        else:\n            off = multiple * (self.offset if const_expr(not padded) else self.offset_padded)\n            offset = off if const_expr(cute.rank(mT.shape[0]) == 1) else (0, off)\n            idx = (offset,) + (None,) * (cute.rank(mT) - 1)\n            return cute.domain_offset(idx, mT)\n\n\n@dataclass(frozen=True)\nclass SeqlenInfoQK:\n    offset_q: Int32\n    offset_k: Int32\n    padded_offset_q: Int32\n    padded_offset_k: Int32\n    seqlen_q: Int32\n    seqlen_k: Int32\n    has_cu_seqlens_q: cutlass.Constexpr[bool]\n    has_cu_seqlens_k: cutlass.Constexpr[bool]\n    has_seqused_q: cutlass.Constexpr[bool]\n    has_seqused_k: cutlass.Constexpr[bool]\n\n    @staticmethod\n    def create(\n        batch_idx: Int32,\n        seqlen_q_static: Int32,\n        seqlen_k_static: Int32,\n        mCuSeqlensQ: Optional[cute.Tensor] = None,\n        mCuSeqlensK: Optional[cute.Tensor] = None,\n        mSeqUsedQ: Optional[cute.Tensor] = None,\n        mSeqUsedK: Optional[cute.Tensor] = None,\n        tile_m: cutlass.Constexpr[Int32] = 128,\n        tile_n: cutlass.Constexpr[Int32] = 128,\n    ):\n        offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]\n        offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx]\n        padded_offset_q = (\n            0\n            if const_expr(mCuSeqlensQ is None)\n            else cute.assume((offset_q + batch_idx * tile_m) // tile_m * tile_m, divby=tile_m)\n        )\n        padded_offset_k = (\n            0\n            if const_expr(mCuSeqlensK is None)\n            else cute.assume((offset_k + batch_idx * tile_n) // tile_n * tile_n, divby=tile_n)\n        )\n        if const_expr(mSeqUsedQ is not None):\n            seqlen_q = mSeqUsedQ[batch_idx]\n    
    else:\n            seqlen_q = (\n                seqlen_q_static\n                if const_expr(mCuSeqlensQ is None)\n                else mCuSeqlensQ[batch_idx + 1] - offset_q\n            )\n        if const_expr(mSeqUsedK is not None):\n            seqlen_k = mSeqUsedK[batch_idx]\n        else:\n            seqlen_k = (\n                seqlen_k_static\n                if const_expr(mCuSeqlensK is None)\n                else mCuSeqlensK[batch_idx + 1] - offset_k\n            )\n        return SeqlenInfoQK(\n            offset_q,\n            offset_k,\n            padded_offset_q,\n            padded_offset_k,\n            seqlen_q,\n            seqlen_k,\n            has_cu_seqlens_q=mCuSeqlensQ is not None,\n            has_cu_seqlens_k=mCuSeqlensK is not None,\n            has_seqused_q=mSeqUsedQ is not None,\n            has_seqused_k=mSeqUsedK is not None,\n        )\n\n    def offset_batch_Q(\n        self,\n        mQ: cute.Tensor,\n        batch_idx: Int32,\n        dim: int,\n        padded: cutlass.Constexpr[bool] = False,\n        ragged: cutlass.Constexpr[bool] = False,\n    ) -> cute.Tensor:\n        \"\"\"Seqlen must be the first dimension of mQ\"\"\"\n        if const_expr(not ragged):\n            if const_expr(not self.has_cu_seqlens_q):\n                idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)\n                return mQ[idx]\n            else:\n                offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q\n                offset_q = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (None, offset_q)\n                idx = (offset_q,) + (None,) * (cute.rank(mQ) - 1)\n                return cute.domain_offset(idx, mQ)\n        else:\n            if const_expr(not self.has_cu_seqlens_q):\n                offset_q = 0\n                idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)\n                mQ = mQ[idx]\n            else:\n                
offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q\n            if const_expr(cute.rank(mQ.shape[0]) == 1):\n                return copy_utils.offset_ragged_tensor(\n                    mQ, offset_q, self.seqlen_q, ragged_dim=0, ptr_shift=True\n                )\n            else:  # PackGQA\n                assert cute.rank(mQ.shape[0]) == 2\n                # Unpack before calling offset_ragged_tensor, then pack\n                idx = ((None, None),) + (None,) * (cute.rank(mQ) - 1)\n                mQ = mQ[idx]\n                mQ = copy_utils.offset_ragged_tensor(\n                    mQ, offset_q, self.seqlen_q, ragged_dim=1, ptr_shift=True\n                )\n                return cute.group_modes(mQ, 0, 2)\n\n    def offset_batch_K(\n        self,\n        mK: cute.Tensor,\n        batch_idx: Int32,\n        dim: int,\n        padded: cutlass.Constexpr[bool] = False,\n        ragged: cutlass.Constexpr[bool] = False,\n        multiple: int = 1,\n    ) -> cute.Tensor:\n        \"\"\"Seqlen must be the first dimension of mK\"\"\"\n        if const_expr(not ragged):\n            if const_expr(not self.has_cu_seqlens_k):\n                idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)\n                return mK[idx]\n            else:\n                offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k\n                offset_k *= multiple\n                idx = (offset_k,) + (None,) * (cute.rank(mK) - 1)\n                return cute.domain_offset(idx, mK)\n        else:\n            if const_expr(not self.has_cu_seqlens_k):\n                offset_k = 0\n                idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)\n                mK = mK[idx]\n            else:\n                offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k\n                offset_k *= multiple\n            return copy_utils.offset_ragged_tensor(\n                
mK, offset_k, self.seqlen_k, ragged_dim=0, ptr_shift=True\n            )\n\n\n@dataclass(frozen=True)\nclass SeqlenInfoQKNewK:\n    \"\"\"Sequence length info for append-KV with left-padding and new K support.\n\n    Extends SeqlenInfoQK with:\n    - leftpad_k: left padding for K (tokens to skip at the start of the KV cache)\n    - offset_k_new: offset into the new K tensor\n    - seqlen_k_og: original K length (before appending new K), excluding leftpad\n    - seqlen_k_new: length of new K to append\n    - seqlen_k: total K length (seqlen_k_og + seqlen_k_new)\n    - seqlen_rotary: position for rotary embedding computation\n    \"\"\"\n\n    leftpad_k: Int32\n    offset_q: Int32\n    offset_k: Int32\n    offset_k_new: Int32\n    seqlen_q: Int32\n    seqlen_k_og: Int32\n    seqlen_k_new: Int32\n    seqlen_k: Int32\n    seqlen_rotary: Int32\n\n    @staticmethod\n    def create(\n        batch_idx: Int32,\n        seqlen_q_static: Int32,\n        seqlen_k_static: Int32,\n        shape_K_new_0: Int32,\n        mCuSeqlensQ: Optional[cute.Tensor] = None,\n        mCuSeqlensK: Optional[cute.Tensor] = None,\n        mCuSeqlensKNew: Optional[cute.Tensor] = None,\n        mSeqUsedQ: Optional[cute.Tensor] = None,\n        mSeqUsedK: Optional[cute.Tensor] = None,\n        mLeftpadK: Optional[cute.Tensor] = None,\n        mSeqlensRotary: Optional[cute.Tensor] = None,\n    ):\n        leftpad_k = 0 if const_expr(mLeftpadK is None) else mLeftpadK[batch_idx]\n        offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]\n        if const_expr(mCuSeqlensK is not None):\n            offset_k = mCuSeqlensK[batch_idx] + leftpad_k\n        else:\n            offset_k = leftpad_k if const_expr(mCuSeqlensQ is not None) else 0\n        offset_k_new = 0 if const_expr(mCuSeqlensKNew is None) else mCuSeqlensKNew[batch_idx]\n        # seqlen_q\n        if const_expr(mSeqUsedQ is not None):\n            seqlen_q = mSeqUsedQ[batch_idx]\n        elif const_expr(mCuSeqlensQ 
is not None):\n            seqlen_q = mCuSeqlensQ[batch_idx + 1] - mCuSeqlensQ[batch_idx]\n        else:\n            seqlen_q = seqlen_q_static\n        # seqlen_k_og: original K length (excluding leftpad)\n        if const_expr(mSeqUsedK is not None):\n            seqlen_k_og = mSeqUsedK[batch_idx] - leftpad_k\n        elif const_expr(mCuSeqlensK is not None):\n            seqlen_k_og = mCuSeqlensK[batch_idx + 1] - mCuSeqlensK[batch_idx] - leftpad_k\n        else:\n            seqlen_k_og = (\n                seqlen_k_static - leftpad_k\n                if const_expr(mCuSeqlensQ is not None)\n                else seqlen_k_static\n            )\n        # seqlen_k_new\n        if const_expr(mCuSeqlensKNew is None):\n            seqlen_k_new = 0 if const_expr(mCuSeqlensQ is None) else shape_K_new_0\n        else:\n            seqlen_k_new = mCuSeqlensKNew[batch_idx + 1] - mCuSeqlensKNew[batch_idx]\n        seqlen_k = seqlen_k_og if const_expr(mCuSeqlensQ is None) else seqlen_k_og + seqlen_k_new\n\n        # seqlen_rotary: defaults to seqlen_k_og + leftpad_k unless explicitly provided\n        if const_expr(mSeqlensRotary is not None):\n            seqlen_rotary = mSeqlensRotary[batch_idx]\n        else:\n            seqlen_rotary = seqlen_k_og + leftpad_k\n        return SeqlenInfoQKNewK(\n            leftpad_k,\n            offset_q,\n            offset_k,\n            offset_k_new,\n            seqlen_q,\n            seqlen_k_og,\n            seqlen_k_new,\n            seqlen_k,\n            seqlen_rotary,\n        )\n"
  },
  {
    "path": "flash_attn/cute/sm90_config_search.py",
    "content": "\"\"\"Search feasible SM90 fwd/bwd attention configs for given (head_dim, head_dim_v).\n\nEnumerates tile sizes, swap modes, atom layouts, and staging options.\nChecks GMMA divisibility, register budget, and shared memory budget.\n\nUsage:\n    python flash_attn/cute/sm90_config_search.py --headdim 128\n    python flash_attn/cute/sm90_config_search.py --mode fwd --headdim 192-128\n    python flash_attn/cute/sm90_config_search.py --mode bwd --headdim 192 --tile-n 64,96\n\"\"\"\n\nimport math\n\n# H100 hardware limits\nSMEM_LIMIT = 224 * 1024  # 228 KB minus ~3 KB for LSE, dPsum, mbarriers\nREG_LIMITS = {2: 216, 3: 128}  # per-WG budget: 2WG=240-24, 3WG=160-32\nTHREADS_PER_WG = 128\n\n\ndef _divisors(n):\n    return [d for d in range(1, n + 1) if n % d == 0]\n\n\ndef _acc_regs(M, N, num_wg):\n    \"\"\"Accumulator registers per thread per WG.\"\"\"\n    return M * N // (num_wg * THREADS_PER_WG)\n\n\ndef _check_mma(M, N, num_wg, atom_layout_m, swap_AB):\n    \"\"\"Check MMA feasibility. Returns regs per WG, or None if infeasible.\n\n    GMMA atom M=64. 
Swap exchanges (M, N) and atom layout.\n    Requires: M divisible by (atom_layout_m * 64), N by (atom_layout_n * 8).\n    \"\"\"\n    if swap_AB:\n        M, N = N, M\n        atom_layout_m = num_wg // atom_layout_m\n    atom_layout_n = num_wg // atom_layout_m\n    if M % (atom_layout_m * 64) != 0 or N % (atom_layout_n * 8) != 0:\n        return None\n    return _acc_regs(M, N, num_wg)\n\n\ndef _mma_traffic(M_eff, N_eff, K_red, num_wg, wg_n, is_rs=False):\n    \"\"\"Total SMEM read traffic for one MMA (all WGs combined).\n\n    num_instr = (M_eff / 64) * wg_n instructions total.\n    Each reads A(64, K_red) and B(N_eff/wg_n, K_red) from smem (bf16).\n    \"\"\"\n    num_instr = (M_eff // 64) * wg_n\n    A_per = 64 * K_red * 2 if not is_rs else 0\n    B_per = (N_eff // wg_n) * K_red * 2\n    return num_instr * (A_per + B_per)\n\n\n# ============================================================================\n# Backward\n# ============================================================================\n\n\ndef _check_bwd_config(\n    hdim,\n    hdimv,\n    tile_m,\n    tile_n,\n    num_wg,\n    SdP_swapAB,\n    dKV_swapAB,\n    dQ_swapAB,\n    AtomLayoutMSdP,\n    AtomLayoutNdKV,\n    AtomLayoutMdQ,\n):\n    reg_limit = REG_LIMITS[num_wg]\n\n    # MMA feasibility\n    regs_SdP = _check_mma(tile_m, tile_n, num_wg, AtomLayoutMSdP, SdP_swapAB)\n    regs_dK = _check_mma(tile_n, hdim, num_wg, AtomLayoutNdKV, dKV_swapAB)\n    regs_dV = _check_mma(tile_n, hdimv, num_wg, AtomLayoutNdKV, dKV_swapAB)\n    regs_dQ = _check_mma(tile_m, hdim, num_wg, AtomLayoutMdQ, dQ_swapAB)\n    if any(r is None for r in (regs_SdP, regs_dK, regs_dV, regs_dQ)):\n        return None\n\n    # Peak regs: max(S+dP, dQ) + dK + dV\n    total_regs = max(2 * regs_SdP, regs_dQ) + regs_dK + regs_dV\n    if total_regs > reg_limit:\n        return None\n\n    # SMEM\n    mma_dkv_is_rs = (\n        AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_wg and SdP_swapAB and not dKV_swapAB\n    )\n    Q_stage, PdS_stage 
= 2, 1\n\n    for dO_stage in (2, 1):\n        sQ = tile_m * hdim * 2 * Q_stage\n        sK = tile_n * hdim * 2\n        sV = tile_n * hdimv * 2\n        sdO = tile_m * hdimv * 2 * dO_stage\n        sPdS = tile_m * tile_n * 2 * PdS_stage\n        sP = sPdS if not mma_dkv_is_rs else 0\n        sdQaccum = tile_m * hdim * 4\n        smem = sQ + sK + sV + sdO + sP + sPdS + sdQaccum\n        if smem <= SMEM_LIMIT:\n            break\n    else:\n        return None\n\n    # SMEM traffic\n    def _swap(a, b, s):\n        return (b, a) if s else (a, b)\n\n    def _wg_n(al_m, s):\n        return al_m if s else num_wg // al_m\n\n    M_s, N_s = _swap(tile_m, tile_n, SdP_swapAB)\n    wn_SdP = _wg_n(AtomLayoutMSdP, SdP_swapAB)\n    traffic_S = _mma_traffic(M_s, N_s, hdim, num_wg, wn_SdP)\n    traffic_dP = _mma_traffic(M_s, N_s, hdimv, num_wg, wn_SdP)\n\n    wn_dKV = _wg_n(AtomLayoutNdKV, dKV_swapAB)\n    M_dv, N_dv = _swap(tile_n, hdimv, dKV_swapAB)\n    traffic_dV = _mma_traffic(M_dv, N_dv, tile_m, num_wg, wn_dKV, is_rs=mma_dkv_is_rs)\n    M_dk, N_dk = _swap(tile_n, hdim, dKV_swapAB)\n    traffic_dK = _mma_traffic(M_dk, N_dk, tile_m, num_wg, wn_dKV, is_rs=mma_dkv_is_rs)\n\n    M_dq, N_dq = _swap(tile_m, hdim, dQ_swapAB)\n    wn_dQ = _wg_n(AtomLayoutMdQ, dQ_swapAB)\n    traffic_dQ = _mma_traffic(M_dq, N_dq, tile_n, num_wg, wn_dQ)\n\n    traffic_P_store = tile_m * tile_n * 2 if not mma_dkv_is_rs else 0\n    traffic_dS_store = tile_m * tile_n * 2\n    traffic_dQ_smem = tile_m * hdim * 4 * 2  # store + TMA load\n\n    smem_traffic = (\n        traffic_S\n        + traffic_dP\n        + traffic_dV\n        + traffic_dK\n        + traffic_dQ\n        + traffic_P_store\n        + traffic_dS_store\n        + traffic_dQ_smem\n    )\n\n    return dict(\n        tile_m=tile_m,\n        tile_n=tile_n,\n        num_wg=num_wg,\n        Q_stage=Q_stage,\n        dO_stage=dO_stage,\n        PdS_stage=PdS_stage,\n        SdP_swapAB=SdP_swapAB,\n        dKV_swapAB=dKV_swapAB,\n        
dQ_swapAB=dQ_swapAB,\n        AtomLayoutMSdP=AtomLayoutMSdP,\n        AtomLayoutNdKV=AtomLayoutNdKV,\n        AtomLayoutMdQ=AtomLayoutMdQ,\n        mma_dkv_is_rs=mma_dkv_is_rs,\n        regs_SdP=regs_SdP,\n        regs_dK=regs_dK,\n        regs_dV=regs_dV,\n        regs_dQ=regs_dQ,\n        total_regs=total_regs,\n        reg_limit=reg_limit,\n        smem_bytes=smem,\n        smem_kb=smem / 1024,\n        smem_traffic=smem_traffic,\n        smem_traffic_kb=smem_traffic / 1024,\n        smem_traffic_per_block=smem_traffic / (tile_m * tile_n),\n    )\n\n\ndef find_feasible_bwd_configs(\n    head_dim,\n    head_dim_v=None,\n    tile_m_choices=(64, 80, 96, 112, 128),\n    tile_n_choices=(64, 80, 96, 112, 128),\n):\n    if head_dim_v is None:\n        head_dim_v = head_dim\n    hdim = int(math.ceil(head_dim / 32) * 32)\n    hdimv = int(math.ceil(head_dim_v / 32) * 32)\n\n    results = []\n    for num_wg in (2, 3):\n        divs = _divisors(num_wg)\n        for tile_m in tile_m_choices:\n            for tile_n in tile_n_choices:\n                for SdP_swap in (False, True):\n                    if (tile_n if SdP_swap else tile_m) % 64 != 0:\n                        continue\n                    for dKV_swap in (False, True):\n                        if not dKV_swap and tile_n % 64 != 0:\n                            continue\n                        if dKV_swap and (hdim % 64 != 0 or hdimv % 64 != 0):\n                            continue\n                        for dQ_swap in (False, True):\n                            if (hdim if dQ_swap else tile_m) % 64 != 0:\n                                continue\n                            for a1 in divs:\n                                for a2 in divs:\n                                    for a3 in divs:\n                                        cfg = _check_bwd_config(\n                                            hdim,\n                                            hdimv,\n                                            tile_m,\n 
                                           tile_n,\n                                            num_wg,\n                                            SdP_swap,\n                                            dKV_swap,\n                                            dQ_swap,\n                                            a1,\n                                            a2,\n                                            a3,\n                                        )\n                                        if cfg is not None:\n                                            results.append(cfg)\n\n    results.sort(key=lambda c: (-c[\"tile_n\"], -c[\"tile_m\"], c[\"smem_traffic_per_block\"]))\n    return results\n\n\ndef print_bwd_configs(configs, max_results=20):\n    if not configs:\n        print(\"No feasible configs found!\")\n        return\n    n = min(len(configs), max_results)\n    print(f\"Found {len(configs)} feasible configs (showing top {n}):\\n\")\n    hdr = (\n        f\"{'wg':>2} {'tm':>3} {'tn':>3}  \"\n        f\"{'SdP':>3} {'dKV':>3} {'dQ':>3}  \"\n        f\"{'aSdP':>4} {'adKV':>4} {'adQ':>4}  \"\n        f\"{'Qs':>2} {'dOs':>3}  \"\n        f\"{'rS':>3} {'rdK':>3} {'rdV':>3} {'rdQ':>3} {'tot':>4}/{'':<3}  \"\n        f\"{'smem':>5}  {'traffic':>7}  {'tr/blk':>6}\"\n    )\n    print(hdr)\n    print(\"-\" * len(hdr))\n    B = lambda b: \"T\" if b else \"F\"\n    for c in configs[:max_results]:\n        print(\n            f\"{c['num_wg']:>2} {c['tile_m']:>3} {c['tile_n']:>3}  \"\n            f\"{B(c['SdP_swapAB']):>3} {B(c['dKV_swapAB']):>3} {B(c['dQ_swapAB']):>3}  \"\n            f\"{c['AtomLayoutMSdP']:>4} {c['AtomLayoutNdKV']:>4} {c['AtomLayoutMdQ']:>4}  \"\n            f\"{c['Q_stage']:>2} {c['dO_stage']:>3}  \"\n            f\"{c['regs_SdP']:>3} {c['regs_dK']:>3} {c['regs_dV']:>3} {c['regs_dQ']:>3} \"\n            f\"{c['total_regs']:>4}/{c['reg_limit']:<3}  \"\n            f\"{c['smem_kb']:>4.0f}K  \"\n            f\"{c['smem_traffic_kb']:>6.0f}K  \"\n        
    f\"{c['smem_traffic_per_block']:>6.1f}\"\n        )\n\n\n# ============================================================================\n# Forward\n# ============================================================================\n\n\ndef _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg):\n    reg_limit = REG_LIMITS[num_wg]\n    tile_m = num_wg * 64\n\n    if tile_n % 8 != 0:\n        return None\n\n    regs_S = _acc_regs(tile_m, tile_n, num_wg)\n    regs_O = _acc_regs(tile_m, hdimv, num_wg)\n    regs_P = regs_S // 2  # bf16 = half of f32\n\n    if overlap_wg:\n        total_regs = regs_S + regs_P + regs_O\n    else:\n        total_regs = regs_S + regs_O\n\n    if total_regs > reg_limit:\n        return None\n\n    # SMEM: 1 stage Q, 2 stages K/V, O overlaps Q, sP if not RS\n    sQ = tile_m * hdim * 2\n    sK = tile_n * hdim * 2 * 2\n    sV = tile_n * hdimv * 2 * 2\n    sO = tile_m * hdimv * 2\n    sP = tile_m * tile_n * 2 if not pv_is_rs else 0\n    smem = max(sQ, sO) + sK + sV + sP\n    if smem > SMEM_LIMIT:\n        return None\n\n    # SMEM traffic: num_instr = num_wg (all WGs in M, wg_n=1)\n    traffic_S = num_wg * (64 * hdim * 2 + tile_n * hdim * 2)\n    A_pv = 64 * tile_n * 2 if not pv_is_rs else 0\n    traffic_O = num_wg * (A_pv + hdimv * tile_n * 2)\n    traffic_P_store = tile_m * tile_n * 2 if not pv_is_rs else 0\n    smem_traffic = traffic_S + traffic_O + traffic_P_store\n\n    return dict(\n        tile_m=tile_m,\n        tile_n=tile_n,\n        num_wg=num_wg,\n        pv_is_rs=pv_is_rs,\n        overlap_wg=overlap_wg,\n        regs_S=regs_S,\n        regs_O=regs_O,\n        regs_P=regs_P,\n        total_regs=total_regs,\n        reg_limit=reg_limit,\n        smem_bytes=smem,\n        smem_kb=smem / 1024,\n        smem_traffic=smem_traffic,\n        smem_traffic_kb=smem_traffic / 1024,\n        smem_traffic_per_block=smem_traffic / (tile_m * tile_n),\n    )\n\n\ndef find_feasible_fwd_configs(\n    head_dim, head_dim_v=None, 
tile_n_choices=(64, 80, 96, 112, 128, 144, 160, 176, 192)\n):\n    if head_dim_v is None:\n        head_dim_v = head_dim\n    hdim = int(math.ceil(head_dim / 32) * 32)\n    hdimv = int(math.ceil(head_dim_v / 32) * 32)\n\n    results = []\n    for num_wg in (2, 3):\n        for tile_n in tile_n_choices:\n            for pv_is_rs in (True, False):\n                for overlap_wg in (True, False):\n                    cfg = _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg)\n                    if cfg is not None:\n                        results.append(cfg)\n\n    results.sort(key=lambda c: (-c[\"tile_n\"], c[\"smem_traffic_per_block\"]))\n    return results\n\n\ndef print_fwd_configs(configs, max_results=20):\n    if not configs:\n        print(\"No feasible configs found!\")\n        return\n    n = min(len(configs), max_results)\n    print(f\"Found {len(configs)} feasible configs (showing top {n}):\\n\")\n    hdr = (\n        f\"{'wg':>2} {'tm':>3} {'tn':>3}  \"\n        f\"{'RS':>2} {'olap':>4}  \"\n        f\"{'rS':>3} {'rP':>3} {'rO':>3} {'tot':>4}/{'':<3}  \"\n        f\"{'smem':>5}  {'traffic':>7}  {'tr/blk':>6}\"\n    )\n    print(hdr)\n    print(\"-\" * len(hdr))\n    B = lambda b: \"T\" if b else \"F\"\n    for c in configs[:max_results]:\n        print(\n            f\"{c['num_wg']:>2} {c['tile_m']:>3} {c['tile_n']:>3}  \"\n            f\"{B(c['pv_is_rs']):>2} {B(c['overlap_wg']):>4}  \"\n            f\"{c['regs_S']:>3} {c['regs_P']:>3} {c['regs_O']:>3} \"\n            f\"{c['total_regs']:>4}/{c['reg_limit']:<3}  \"\n            f\"{c['smem_kb']:>4.0f}K  \"\n            f\"{c['smem_traffic_kb']:>6.0f}K  \"\n            f\"{c['smem_traffic_per_block']:>6.1f}\"\n        )\n\n\n# ============================================================================\n# CLI\n# ============================================================================\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = 
argparse.ArgumentParser(description=\"Search feasible SM90 MMA configs\")\n    parser.add_argument(\"--mode\", choices=[\"fwd\", \"bwd\", \"both\"], default=\"both\")\n    parser.add_argument(\n        \"--headdim\", type=str, default=\"128\", help=\"Head dim, or hdim-hdimv (e.g. 192-128)\"\n    )\n    parser.add_argument(\"--tile-m\", type=str, default=\"64,80,96,112,128\", help=\"Bwd tile_m choices\")\n    parser.add_argument(\n        \"--tile-n\",\n        type=str,\n        default=None,\n        help=\"tile_n choices (default: fwd up to 192, bwd up to 128)\",\n    )\n    parser.add_argument(\"-n\", \"--num-results\", type=int, default=30)\n    args = parser.parse_args()\n\n    parts = args.headdim.split(\"-\")\n    hdim = int(parts[0])\n    hdimv = int(parts[1]) if len(parts) > 1 else hdim\n\n    TN_FWD = \"64,80,96,112,128,144,160,176,192\"\n    TN_BWD = \"64,80,96,112,128\"\n\n    if args.mode in (\"fwd\", \"both\"):\n        tn = tuple(int(x) for x in (args.tile_n or TN_FWD).split(\",\"))\n        print(f\"=== FWD configs: hdim={hdim}, hdimv={hdimv} ===\\n\")\n        print_fwd_configs(find_feasible_fwd_configs(hdim, hdimv, tn), args.num_results)\n        print()\n\n    if args.mode in (\"bwd\", \"both\"):\n        tm = tuple(int(x) for x in args.tile_m.split(\",\"))\n        tn = tuple(int(x) for x in (args.tile_n or TN_BWD).split(\",\"))\n        print(f\"=== BWD configs: hdim={hdim}, hdimv={hdimv} ===\\n\")\n        print_bwd_configs(find_feasible_bwd_configs(hdim, hdimv, tm, tn), args.num_results)\n"
  },
  {
    "path": "flash_attn/cute/softmax.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n\nimport math\nimport operator\nfrom typing import Tuple\nfrom dataclasses import dataclass\n\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass import Float32\n\nfrom quack import layout_utils\nimport flash_attn.cute.utils as utils\nfrom quack.cute_dsl_utils import ParamsBase\nfrom flash_attn.cute.seqlen_info import SeqlenInfoQK\n\n\n@dataclass\nclass Softmax(ParamsBase):\n    scale_log2: Float32\n    num_rows: cutlass.Constexpr[int]\n    row_max: cute.Tensor\n    row_sum: cute.Tensor\n    arch: cutlass.Constexpr[int] = 80\n    softmax_scale: Float32 | None = None\n\n    @staticmethod\n    def create(\n        scale_log2: Float32,\n        num_rows: cutlass.Constexpr[int],\n        arch: cutlass.Constexpr[int] = 80,\n        softmax_scale: Float32 | None = None,\n    ):\n        row_max = cute.make_rmem_tensor(num_rows, Float32)\n        row_sum = cute.make_rmem_tensor(num_rows, Float32)\n        return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale)\n\n    def reset(self) -> None:\n        self.row_max.fill(-Float32.inf)\n        self.row_sum.fill(0.0)\n\n    def _compute_row_max(\n        self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None\n    ) -> Float32:\n        return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch)\n\n    def _compute_row_sum(\n        self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None\n    ) -> Float32:\n        return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch)\n\n    @cute.jit\n    def online_softmax(\n        self,\n        acc_S: cute.Tensor,\n        is_first: cutlass.Constexpr[bool] = False,\n        check_inf: cutlass.Constexpr[bool] = True,\n    ) -> cute.Tensor:\n        \"\"\"Apply online softmax and return the row_scale to rescale O.\n\n        :param acc_S: acc_S tensor\n        :type acc_S: cute.Tensor\n        :param is_first: is first n_block\n        :type is_first: 
cutlass.Constexpr\n        \"\"\"\n        # Change acc_S to M,N layout view.\n        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)\n        row_scale = cute.make_fragment_like(self.row_max, Float32)\n\n        row_max = self.row_max\n        row_sum = self.row_sum\n        scale_log2 = self.scale_log2\n        arch = self.arch\n\n        # Each iteration processes one row of acc_S\n        for r in cutlass.range(cute.size(row_max), unroll_full=True):\n            acc_S_row = acc_S_mn[r, None].load()  # (n_block_size)\n\n            row_max_cur = utils.fmax_reduce(\n                acc_S_row,\n                init_val=row_max[r] if cutlass.const_expr(not is_first) else None,\n                arch=arch,\n            )\n\n            row_max_cur = cute.arch.warp_reduction_max(row_max_cur, threads_in_group=4)\n            # Update row_max before changing row_max_cur to safe value for -inf\n            row_max_prev = row_max[r]\n            row_max[r] = row_max_cur\n\n            if cutlass.const_expr(check_inf):\n                row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur\n\n            if cutlass.const_expr(is_first):\n                row_max_cur_scaled = row_max_cur * scale_log2\n                acc_S_row_exp = cute.math.exp2(\n                    acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True\n                )\n                acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch)\n                row_scale[r] = 1.0\n            else:\n                row_max_cur_scaled = row_max_cur * scale_log2\n                acc_S_row_exp = cute.math.exp2(\n                    acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True\n                )\n                # row_scale[r] = cute.math.exp2(row_max_prev * self.scale_log2 - row_max_cur_scaled)\n                row_scale[r] = cute.math.exp2(\n                    (row_max_prev - row_max_cur) * scale_log2, fastmath=True\n                )\n                
acc_S_row_sum = utils.fadd_reduce(\n                    acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch\n                )\n\n            row_sum[r] = acc_S_row_sum\n            acc_S_mn[r, None].store(acc_S_row_exp)\n\n        return row_scale\n\n    @cute.jit\n    def finalize(\n        self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None\n    ) -> cute.Tensor:\n        \"\"\"Finalize the online softmax by computing the scale and logsumexp.\"\"\"\n        if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)):\n            assert cute.size(sink_val) == cute.size(self.row_sum)\n        row_sum = self.row_sum\n        row_max = self.row_max\n        scale_log2 = self.scale_log2\n\n        # quad reduction for row_sum as we didn't do it during each iteration of online softmax\n        row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4))\n        row_scale = cute.make_fragment_like(row_max, Float32)\n\n        for r in cutlass.range(cute.size(row_sum), unroll_full=True):\n            if cutlass.const_expr(sink_val is not None):\n                sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r]\n                LOG2_E = math.log2(math.e)\n                row_sum[r] += cute.math.exp2(\n                    sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True\n                )\n\n            # if row_sum is zero or NaN, divide by 1.0 instead and set the LSE to -inf\n            acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]\n            row_scale[r] = (\n                cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0)\n            ) * final_scale\n            row_sum_cur = row_sum[r]\n            LN2 = math.log(2.0)\n            row_sum[r] = (\n                (row_max[r] * scale_log2 + cute.math.log2(row_sum_cur, fastmath=True)) * LN2\n                if not acc_O_mn_row_is_zero_or_nan\n                else 
-Float32.inf\n            )\n        return row_scale\n\n    @cute.jit\n    def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None:\n        \"\"\"Scale each row of acc_O by the given scale tensor.\n        :param acc_O: input tensor\n        :type acc_O: cute.Tensor\n        :param row_scale: row_scale tensor\n        :type row_scale: cute.Tensor\n        \"\"\"\n        acc_O_mn = layout_utils.reshape_acc_to_mn(acc_O)\n        assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0])\n        for r in cutlass.range(cute.size(row_scale), unroll_full=True):\n            acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r])\n\n\n@dataclass\nclass SoftmaxSm100(Softmax):\n    rescale_threshold: cutlass.Constexpr[float] = 0.0\n\n    @staticmethod\n    def create(\n        scale_log2: Float32,\n        rescale_threshold: cutlass.Constexpr[float] = 0.0,\n        softmax_scale: Float32 | None = None,\n    ):\n        num_rows = 1\n        arch = 100\n        row_max = cute.make_rmem_tensor(num_rows, Float32)\n        row_sum = cute.make_rmem_tensor(num_rows, Float32)\n        return SoftmaxSm100(\n            scale_log2,\n            num_rows,\n            row_max,\n            row_sum,\n            arch,\n            softmax_scale,\n            rescale_threshold=rescale_threshold,\n        )\n\n    @cute.jit\n    def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]:\n        if cutlass.const_expr(is_first):\n            row_max_new = self._compute_row_max(acc_S_row)\n            row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0\n            acc_scale = 0.0\n        else:\n            row_max_old = self.row_max[0]\n            row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old)\n            row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0\n            acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2\n            acc_scale = 
cute.math.exp2(acc_scale_, fastmath=True)\n            if cutlass.const_expr(self.rescale_threshold > 0.0):\n                if acc_scale_ >= -self.rescale_threshold:\n                    row_max_new = row_max_old\n                    row_max_safe = row_max_old\n                    acc_scale = 1.0\n        self.row_max[0] = row_max_new\n        return row_max_safe, acc_scale\n\n    def update_row_sum(\n        self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False\n    ) -> None:\n        init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None\n        # self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale)\n        self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val)\n        # tmp = self._compute_row_sum(acc_S_row_exp)\n        # self.row_sum[0] = self.row_sum[0] * row_scale + tmp\n\n    @cute.jit\n    def scale_subtract_rowmax(\n        self,\n        acc_S_row: cute.Tensor,\n        row_max: Float32,\n    ):\n        assert cute.size(acc_S_row.shape) % 2 == 0, \"acc_S_row must have an even number of elements\"\n        row_max_scaled = row_max * self.scale_log2\n        for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True):\n            acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2(\n                (acc_S_row[i], acc_S_row[i + 1]),\n                (self.scale_log2, self.scale_log2),\n                (-row_max_scaled, -row_max_scaled),\n            )\n\n    @cute.jit\n    def apply_exp2_convert(\n        self,\n        acc_S_row: cute.Tensor,\n        acc_S_row_converted: cute.Tensor,\n        ex2_emu_freq: cutlass.Constexpr[int] = 0,\n        ex2_emu_res: cutlass.Constexpr[int] = 4,\n        ex2_emu_start_frg: cutlass.Constexpr[int] = 0,\n    ):\n        assert cute.size(acc_S_row.shape) % 2 == 0, \"acc_S_row must have an even number of elements\"\n        frg_tile = 32\n        assert frg_tile % 2 == 0\n        
frg_cnt = cute.size(acc_S_row) // frg_tile\n        assert cute.size(acc_S_row) % frg_tile == 0\n        acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))\n        acc_S_row_converted_frg = cute.logical_divide(\n            acc_S_row_converted, cute.make_layout(frg_tile)\n        )\n        for j in cutlass.range_constexpr(frg_cnt):\n            for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):\n                # acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)\n                # acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)\n                if cutlass.const_expr(ex2_emu_freq == 0):\n                    acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)\n                    acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)\n                else:\n                    if cutlass.const_expr(\n                        k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res\n                        or j >= frg_cnt - 1\n                        or j < ex2_emu_start_frg\n                    ):\n                        acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)\n                        acc_S_row_frg[k + 1, j] = cute.math.exp2(\n                            acc_S_row_frg[k + 1, j], fastmath=True\n                        )\n                    else:\n                        # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j])\n                        acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2(\n                            acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]\n                        )\n            acc_S_row_converted_frg[None, j].store(\n                acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)\n            )\n\n    @cute.jit\n    def scale_apply_exp2_convert(\n        self,\n        
acc_S_row: cute.Tensor,\n        row_max: Float32,\n        acc_S_row_converted: cute.Tensor,\n    ):\n        assert cute.size(acc_S_row.shape) % 2 == 0, \"acc_S_row must have an even number of elements\"\n        minus_row_max_scaled = -row_max * self.scale_log2\n        for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2):\n            acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2(\n                (acc_S_row[i], acc_S_row[i + 1]),\n                (self.scale_log2, self.scale_log2),\n                (minus_row_max_scaled, minus_row_max_scaled),\n            )\n\n        # for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2):\n        #     acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2(\n        #         (acc_S_row[i], acc_S_row[i + 1]),\n        #         (self.scale_log2, self.scale_log2),\n        #         (minus_row_max_scaled, minus_row_max_scaled),\n        #     )\n        #     acc_S_row[i] = cute.math.exp2(acc_S_row[i], fastmath=True)\n        #     acc_S_row[i + 1] = cute.math.exp2(acc_S_row[i + 1], fastmath=True)\n\n        frg_tile = 32\n        assert frg_tile % 2 == 0\n        frg_cnt = cute.size(acc_S_row) // frg_tile\n        assert cute.size(acc_S_row) % frg_tile == 0\n        acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile))\n        acc_S_row_converted_frg = cute.logical_divide(\n            acc_S_row_converted, cute.make_layout(frg_tile)\n        )\n        for j in cutlass.range_constexpr(frg_cnt):\n            for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):\n                # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = (\n                #     cute.arch.fma_packed_f32x2(\n                #         (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]),\n                #         (self.scale_log2, self.scale_log2),\n                #         (minus_row_max_scaled, minus_row_max_scaled),\n                #     )\n                # )\n     
           # acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)\n                # acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)\n                acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)\n                acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)\n            acc_S_row_converted_frg[None, j].store(\n                acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)\n            )\n\n\n@cute.jit\ndef floor_if_packed(\n    q_idx,\n    qhead_per_kvhead: cutlass.Constexpr[int],\n) -> cute.Tensor:\n    \"\"\"Convert q_idx to packed format for Pack-GQA.\"\"\"\n    if cutlass.const_expr(qhead_per_kvhead == 1):\n        return q_idx\n    return q_idx // qhead_per_kvhead\n\n\n@cute.jit\ndef apply_score_mod_inner(\n    score_tensor,\n    index_tensor,\n    score_mod: cutlass.Constexpr,\n    batch_idx,\n    head_idx,\n    softmax_scale,\n    vec_size: cutlass.Constexpr,\n    qk_acc_dtype: cutlass.Constexpr,\n    aux_tensors,\n    fastdiv_mods,\n    seqlen_info: SeqlenInfoQK,\n    constant_q_idx: cutlass.Constexpr,\n    qhead_per_kvhead: cutlass.Constexpr[int] = 1,\n    transpose_indices: cutlass.Constexpr[bool] = False,\n):\n    \"\"\"Shared implementation for applying score modification.\n\n    Args:\n        score_tensor: The scores to modify (acc_S for flash_fwd, tSrS_t2r for sm100)\n        index_tensor: Index positions (tScS for flash_fwd, tScS_t2r for sm100)\n        score_mod: The score modification function to apply\n        batch_idx: Batch index\n        head_idx: Head index\n        softmax_scale: Scale to apply\n        vec_size: Vector size for processing elements\n        qk_acc_dtype: Data type for accumulator\n        aux_tensors: Optional aux_tensors for FlexAttention\n        fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping\n        seqlen_info: Sequence length info\n        constant_q_idx: If 
provided, use this constant for all q_idx values\n                        If None, compute q_idx per-element\n        qhead_per_kvhead: Pack-GQA replication factor. Divide q_idx by this\n                          when greater than 1 so score mods see logical heads.\n        transpose_indices: If True, swap q_idx/kv_idx in index_tensor (for bwd kernel where S is transposed)\n    \"\"\"\n    # Index positions in the index_tensor tuple\n    # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx\n    # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx\n    if cutlass.const_expr(transpose_indices):\n        q_idx_pos = cutlass.const_expr(1)\n        kv_idx_pos = cutlass.const_expr(0)\n    else:\n        q_idx_pos = cutlass.const_expr(0)\n        kv_idx_pos = cutlass.const_expr(1)\n\n    n_vals = cutlass.const_expr(cute.size(score_tensor.shape))\n    score_vec = cute.make_rmem_tensor(vec_size, qk_acc_dtype)\n    kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)\n\n    # SSA values for batch (constant across all elements)\n    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,))\n\n    # Handle q_idx based on whether it's constant\n    q_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)\n\n    # For Pack-GQA with non-constant q_idx, we need per-element head indices\n    # since a thread may process multiple query head indices\n    if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):\n        head_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32)\n\n    for i in cutlass.range(0, n_vals, vec_size, unroll_full=True):\n        for j in cutlass.range(vec_size, unroll_full=True):\n            score_vec[j] = score_tensor[i + j] * softmax_scale\n\n            # Extract head offset from packed q_idx for Pack-GQA\n            if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):\n                q_idx_packed = index_tensor[i 
+ j][q_idx_pos]\n                # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead)\n                q_idx_logical = q_idx_packed // qhead_per_kvhead\n                head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead\n                head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset\n\n            # If aux tensors will be loaded, wrap indices modulo the sequence length to avoid OOB reads\n            if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None):\n                if cutlass.const_expr(constant_q_idx is None):\n                    seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods\n                    q_idx_floored = floor_if_packed(\n                        index_tensor[i + j][q_idx_pos], qhead_per_kvhead\n                    )\n                    _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod)\n                    q_idx_vec[j] = q_idx_wrapped\n                else:\n                    _, seqlen_k_divmod = fastdiv_mods\n\n                _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod)\n                kv_idx_vec[j] = kv_idx_wrapped\n            else:\n                # No bounds checking - direct indexing\n                if constant_q_idx is None:\n                    q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead)\n                kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos]\n\n        # Convert to SSA for score_mod call\n        score_ssa = score_vec.load()\n        kv_idx_ssa = kv_idx_vec.load()\n        if cutlass.const_expr(constant_q_idx is None):\n            q_idx_ssa = q_idx_vec.load()\n        else:\n            # NB we do not apply Pack-GQA division here, as constant_q_idx is assumed to already be logical\n            q_idx_const = constant_q_idx\n            q_idx_ssa = utils.scalar_to_ssa(q_idx_const, cutlass.Int32).broadcast_to((vec_size,))\n\n        # Compute head_idx_ssa: per-element for Pack-GQA 
with non-constant q_idx, constant otherwise\n        if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):\n            head_idx_ssa = head_idx_vec.load()\n        else:\n            head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,))\n\n        aux_args = []\n        if cutlass.const_expr(aux_tensors is not None):\n            aux_args = aux_tensors\n\n        post_mod_scores = score_mod(\n            score_ssa,\n            batch_idx_ssa,\n            head_idx_ssa,\n            q_idx=q_idx_ssa,\n            kv_idx=kv_idx_ssa,\n            seqlen_info=seqlen_info,\n            aux_tensors=aux_args,\n        )\n\n        # Write back modified scores\n        score_vec.store(post_mod_scores)\n        for j in cutlass.range(vec_size, unroll_full=True):\n            score_tensor[i + j] = score_vec[j]\n\n\n@cute.jit\ndef apply_score_mod_bwd_inner(\n    grad_tensor,\n    score_tensor,\n    index_tensor,\n    score_mod_bwd: cutlass.Constexpr,\n    batch_idx,\n    head_idx,\n    softmax_scale,\n    vec_size: cutlass.Constexpr,\n    qk_acc_dtype: cutlass.Constexpr,\n    aux_tensors,\n    fastdiv_mods,\n    seqlen_info,\n    constant_q_idx: cutlass.Constexpr,\n    qhead_per_kvhead: cutlass.Constexpr[int] = 1,\n    transpose_indices: cutlass.Constexpr[bool] = False,\n):\n    \"\"\"Apply backward score modification (joint graph).\n\n    Args:\n        grad_tensor: in/out: dlogits rewritten in-place with d(scaled_scores)\n        score_tensor: pre-mod scores (unscaled QK tile), scaled by softmax_scale internally\n        index_tensor: Index positions (same as forward)\n        score_mod_bwd: The backward score modification function (joint graph)\n        batch_idx: Batch index\n        head_idx: Head index\n        softmax_scale: Scale to apply to score_tensor\n        vec_size: Vector size for processing elements\n        qk_acc_dtype: Data type for accumulator\n        aux_tensors: Optional aux_tensors for FlexAttention\n   
     fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping\n        seqlen_info: Sequence length info\n        constant_q_idx: If provided, use this constant for all q_idx values\n        qhead_per_kvhead: Pack-GQA replication factor\n        transpose_indices: If True, swap q_idx/kv_idx in index_tensor\n    \"\"\"\n    # Index positions in the index_tensor tuple\n    # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx\n    # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx\n    if cutlass.const_expr(transpose_indices):\n        q_idx_pos = cutlass.const_expr(1)\n        kv_idx_pos = cutlass.const_expr(0)\n    else:\n        q_idx_pos = cutlass.const_expr(0)\n        kv_idx_pos = cutlass.const_expr(1)\n    n_vals = cutlass.const_expr(cute.size(grad_tensor.shape))\n    grad_vec = cute.make_fragment(vec_size, qk_acc_dtype)\n    score_vec = cute.make_fragment(vec_size, qk_acc_dtype)\n    kv_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)\n    batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,))\n    q_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)\n\n    # For Pack-GQA with non-constant q_idx, we need per-element head indices\n    if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):\n        head_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)\n\n    for i in cutlass.range(0, n_vals, vec_size, unroll_full=True):\n        for j in cutlass.range(vec_size, unroll_full=True):\n            grad_vec[j] = grad_tensor[i + j]\n            # Scale score so joint graph sees same value as forward score_mod\n            score_vec[j] = score_tensor[i + j] * softmax_scale\n\n            if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):\n                q_idx_packed = index_tensor[i + j][q_idx_pos]\n                q_idx_logical = q_idx_packed // qhead_per_kvhead\n                head_offset = q_idx_packed - q_idx_logical 
* qhead_per_kvhead\n                head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset\n\n            if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None):\n                if cutlass.const_expr(constant_q_idx is None):\n                    seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods\n                    q_idx_floored = floor_if_packed(\n                        index_tensor[i + j][q_idx_pos], qhead_per_kvhead\n                    )\n                    _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod)\n                    q_idx_vec[j] = q_idx_wrapped\n                else:\n                    _, seqlen_k_divmod = fastdiv_mods\n\n                _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod)\n                kv_idx_vec[j] = kv_idx_wrapped\n            else:\n                # No bounds checking - direct indexing\n                if constant_q_idx is None:\n                    q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead)\n                kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos]\n\n        grad_ssa = grad_vec.load()\n        score_ssa = score_vec.load()\n        kv_idx_ssa = kv_idx_vec.load()\n\n        if cutlass.const_expr(constant_q_idx is None):\n            q_idx_ssa = q_idx_vec.load()\n        else:\n            q_idx_ssa = utils.scalar_to_ssa(constant_q_idx, cutlass.Int32).broadcast_to((vec_size,))\n\n        if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None):\n            head_idx_ssa = head_idx_vec.load()\n        else:\n            head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,))\n\n        aux_args = []\n        if cutlass.const_expr(aux_tensors is not None):\n            aux_args = aux_tensors\n\n        grad_out_ssa = score_mod_bwd(\n            grad_ssa,\n            score_ssa,\n            batch_idx_ssa,\n            head_idx_ssa,\n            q_idx=q_idx_ssa,\n            
kv_idx=kv_idx_ssa,\n            seqlen_info=seqlen_info,\n            aux_tensors=aux_args,\n        )\n\n        grad_vec.store(grad_out_ssa)\n        for j in cutlass.range(vec_size, unroll_full=True):\n            grad_tensor[i + j] = grad_vec[j]\n"
  },
  {
    "path": "flash_attn/cute/testing.py",
    "content": "import math\nfrom contextlib import nullcontext\nfrom functools import wraps\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom torch._guards import active_fake_mode\nfrom torch._subclasses.fake_tensor import FakeTensorMode\n\n\nclass IndexFirstAxis(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, indices):\n        ctx.save_for_backward(indices)\n        assert input.ndim >= 2\n        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]\n        second_dim = other_shape.numel()\n        return torch.gather(\n            rearrange(input, \"b ... -> b (...)\"),\n            0,\n            repeat(indices, \"z -> z d\", d=second_dim),\n        ).reshape(-1, *other_shape)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        (indices,) = ctx.saved_tensors\n        assert grad_output.ndim >= 2\n        other_shape = grad_output.shape[1:]\n        grad_output = rearrange(grad_output, \"b ... 
-> b (...)\")\n        grad_input = torch.zeros(\n            [ctx.first_axis_dim, grad_output.shape[1]],\n            device=grad_output.device,\n            dtype=grad_output.dtype,\n        )\n        grad_input.scatter_(0, repeat(indices, \"z -> z d\", d=grad_output.shape[1]), grad_output)\n        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None\n\n\nindex_first_axis = IndexFirstAxis.apply\n\n\nclass IndexPutFirstAxis(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, values, indices, first_axis_dim):\n        ctx.save_for_backward(indices)\n        assert indices.ndim == 1\n        assert values.ndim >= 2\n        output = torch.zeros(\n            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype\n        )\n        output[indices] = values\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        (indices,) = ctx.saved_tensors\n        grad_values = grad_output[indices]\n        return grad_values, None, None\n\n\nindex_put_first_axis = IndexPutFirstAxis.apply\n\n\ndef unpad_input(hidden_states, attention_mask, unused_mask=None):\n    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask\n    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)\n    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    in_fake_mode = active_fake_mode() is not None\n    if not in_fake_mode:\n        indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()\n        max_seqlen_in_batch = seqlens_in_batch.max().item()\n    else:\n        # torch.nonzero and .item() are not supported in FakeTensorMode\n        batch_size, seqlen = attention_mask.shape\n        indices = torch.arange(batch_size * seqlen, device=hidden_states.device)\n        max_seqlen_in_batch = seqlen\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    return (\n        
index_first_axis(rearrange(hidden_states, \"b s ... -> (b s) ...\"), indices),\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n        used_seqlens_in_batch,\n    )\n\n\ndef pad_input(hidden_states, indices, batch, seqlen):\n    output = index_put_first_axis(hidden_states, indices, batch * seqlen)\n    return rearrange(output, \"(b s) ... -> b s ...\", b=batch)\n\n\ndef generate_random_padding_mask(max_seqlen, batch_size, device, mode=\"random\", zero_lengths=False):\n    assert mode in [\"full\", \"random\", \"third\"]\n    if mode == \"full\":\n        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)\n    elif mode == \"random\":\n        lengths = torch.randint(\n            max(0 if zero_lengths else 1, max_seqlen - 20),\n            max_seqlen + 1,\n            (batch_size, 1),\n            device=device,\n        )\n    else:\n        lengths = torch.randint(\n            max(0 if zero_lengths else 1, max_seqlen // 3),\n            max_seqlen + 1,\n            (batch_size, 1),\n            device=device,\n        )\n\n    if zero_lengths:\n        for i in range(batch_size):\n            if i % 5 == 0:\n                lengths[i] = 0\n        lengths[-1] = 0\n    padding_mask = (\n        repeat(torch.arange(max_seqlen, device=device), \"s -> b s\", b=batch_size) < lengths\n    )\n    return padding_mask\n\n\ndef generate_qkv(\n    q,\n    k,\n    v,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    qv=None,\n    kvpacked=False,\n    qkvpacked=False,\n    query_unused_mask=None,\n    key_unused_mask=None,\n):\n    assert not (kvpacked and qkvpacked)\n    batch_size, seqlen_q, nheads, d = q.shape\n    d_v = v.shape[-1]\n    _, seqlen_k, nheads_k, _ = k.shape\n    assert k.shape == (batch_size, seqlen_k, nheads_k, d)\n    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)\n    if query_unused_mask is not None or key_unused_mask is not None:\n        assert not kvpacked\n        assert 
not qkvpacked\n\n    if query_padding_mask is not None:\n        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(\n            q, query_padding_mask, query_unused_mask\n        )\n        output_pad_fn = lambda output_unpad: pad_input(\n            output_unpad, indices_q, batch_size, seqlen_q\n        )\n        qv_unpad = rearrange(qv, \"b s ... -> (b s) ...\")[indices_q] if qv is not None else None\n    else:\n        q_unpad = rearrange(q, \"b s h d -> (b s) h d\")\n        cu_seqlens_q = torch.arange(\n            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device\n        )\n        seqused_q = None\n        max_seqlen_q = seqlen_q\n        output_pad_fn = lambda output_unpad: rearrange(\n            output_unpad, \"(b s) h d -> b s h d\", b=batch_size\n        )\n        qv_unpad = rearrange(qv, \"b s ... -> (b s) ...\") if qv is not None else None\n\n    if key_padding_mask is not None:\n        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(\n            k, key_padding_mask, key_unused_mask\n        )\n        v_unpad, *_ = unpad_input(v, key_padding_mask, key_unused_mask)\n    else:\n        k_unpad = rearrange(k, \"b s h d -> (b s) h d\")\n        v_unpad = rearrange(v, \"b s h d -> (b s) h d\")\n        cu_seqlens_k = torch.arange(\n            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device\n        )\n        seqused_k = None\n        max_seqlen_k = seqlen_k\n\n    if qkvpacked:\n        assert (query_padding_mask == key_padding_mask).all()\n        assert nheads == nheads_k\n        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)\n        qkv = torch.stack([q, k, v], dim=2)\n        if query_padding_mask is not None:\n            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)\n        else:\n            dqkv_pad_fn = lambda dqkv_unpad: rearrange(\n                dqkv_unpad, 
\"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n            qkv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            max_seqlen_q,\n            qkv.detach().requires_grad_(),\n            output_pad_fn,\n            dqkv_pad_fn,\n        )\n    elif kvpacked:\n        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)\n        kv = torch.stack([k, v], dim=2)\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)\n        else:\n            dkv_pad_fn = lambda dkv_unpad: rearrange(\n                dkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n            q_unpad.detach().requires_grad_(),\n            kv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            kv.detach().requires_grad_(),\n            output_pad_fn,\n            dq_pad_fn,\n            dkv_pad_fn,\n        )\n    else:\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)\n        else:\n            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, \"(b s) h d -> b s h d\", b=batch_size)\n        return (\n            q_unpad.detach().requires_grad_(),\n            k_unpad.detach().requires_grad_(),\n            v_unpad.detach().requires_grad_(),\n            qv_unpad.detach() if qv is not None else None,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            k.detach().requires_grad_(),\n            v.detach().requires_grad_(),\n            qv.detach() if qv is not None else None,\n     
       output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        )\n\n\ndef construct_local_mask(\n    seqlen_q,\n    seqlen_k,\n    window_size=(None, None),\n    sink_token_length=0,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    key_leftpad=None,\n    device=None,\n):\n    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n    if key_leftpad is not None:\n        key_leftpad = rearrange(key_leftpad, \"b -> b 1 1 1\")\n        col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)\n    sk = (\n        seqlen_k\n        if key_padding_mask is None\n        else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sq = (\n        seqlen_q\n        if query_padding_mask is None\n        else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    if window_size[0] is None:\n        return col_idx > row_idx + sk - sq + window_size[1]\n    else:\n        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk\n        if window_size[1] is None:\n            local_mask_left = col_idx > sk\n        else:\n            local_mask_left = col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk)\n        return torch.logical_or(\n            local_mask_left,\n            torch.logical_and(\n                col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length\n            ),\n        )\n\n\ndef construct_chunk_mask(\n    seqlen_q,\n    seqlen_k,\n    attention_chunk,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    key_leftpad=None,\n    device=None,\n):\n    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n    if key_leftpad is not None:\n   
     key_leftpad = rearrange(key_leftpad, \"b -> b 1 1 1\")\n        col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)\n    sk = (\n        seqlen_k\n        if key_padding_mask is None\n        else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sq = (\n        seqlen_q\n        if query_padding_mask is None\n        else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk\n    col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk\n    return torch.logical_or(\n        col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk\n    )\n\n\ndef attention_ref(\n    q,\n    k,\n    v,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    key_leftpad=None,\n    attn_bias=None,\n    dropout_p=0.0,\n    dropout_mask=None,\n    causal=False,\n    qv=None,\n    q_descale=None,\n    k_descale=None,\n    v_descale=None,\n    window_size=(None, None),\n    attention_chunk=0,\n    sink_token_length=0,\n    learnable_sink: Optional[torch.Tensor] = None,\n    softcap=0.0,\n    upcast=True,\n    reorder_ops=False,\n    intermediate_dtype=None,\n):\n    if causal:\n        window_size = (window_size[0], 0)\n    dtype_og = q.dtype\n    if upcast:\n        q, k, v = q.float(), k.float(), v.float()\n        qv = qv.float() if qv is not None else None\n    if q_descale is not None:\n        q_descale = repeat(q_descale, \"b h -> b 1 (h g) 1\", g=q.shape[2] // k.shape[2])\n        q = (q.float() * q_descale).to(q.dtype)\n        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None\n    if k_descale is not None:\n        k = (k.float() * rearrange(k_descale, \"b h -> b 1 h 1\")).to(dtype=k.dtype)\n    if v_descale is not None:\n        v = (v.float() * rearrange(v_descale, \"b h -> b 1 h 
1\")).to(dtype=v.dtype)\n    seqlen_q, seqlen_k = q.shape[1], k.shape[1]\n    k = repeat(k, \"b s h d -> b s (h g) d\", g=q.shape[2] // k.shape[2])\n    v = repeat(v, \"b s h d -> b s (h g) d\", g=q.shape[2] // v.shape[2])\n    d = q.shape[-1]\n    dv = v.shape[-1]\n    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)\n    if not reorder_ops:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q * softmax_scale, k)\n    else:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q, k * softmax_scale)\n    if qv is not None:\n        scores = scores + torch.einsum(\"bthd,bshd->bhts\", qv * softmax_scale, v)\n    if softcap > 0:\n        scores = torch.tanh(scores / softcap) * softcap\n    if key_padding_mask is not None:\n        scores.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n    local_mask = None\n    if window_size[0] is not None or window_size[1] is not None:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            sink_token_length,\n            query_padding_mask,\n            key_padding_mask,\n            key_leftpad=key_leftpad,\n            device=q.device,\n        )\n    if attention_chunk > 0:\n        chunk_mask = construct_chunk_mask(\n            seqlen_q,\n            seqlen_k,\n            attention_chunk,\n            query_padding_mask,\n            key_padding_mask,\n            key_leftpad=key_leftpad,\n            device=q.device,\n        )\n        local_mask = (\n            torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask\n        )\n    if local_mask is not None:\n        scores.masked_fill_(local_mask, float(\"-inf\"))\n    if attn_bias is not None:\n        scores = scores + attn_bias\n    if learnable_sink is None:\n        attention = torch.softmax(scores, dim=-1).to(v.dtype)\n    else:\n        scores_fp32 = scores.to(torch.float32)\n        logits_max = torch.amax(scores_fp32, 
dim=-1, keepdim=True)\n        learnable_sink = rearrange(learnable_sink, \"h -> h 1 1\")\n        logits_or_sinks_max = torch.maximum(learnable_sink, logits_max)\n        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)\n        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(\n            learnable_sink - logits_or_sinks_max\n        )\n        attention = (unnormalized_scores / normalizer).to(v.dtype)\n    if query_padding_mask is not None:\n        attention = attention.masked_fill(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), 0.0)\n    if key_padding_mask is not None:\n        attention = attention.masked_fill(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), 0.0)\n    if local_mask is not None:\n        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)\n    dropout_scaling = 1.0 / (1 - dropout_p)\n    if dropout_mask is not None:\n        attention_drop = attention.masked_fill(~dropout_mask, 0.0)\n    else:\n        attention_drop = attention\n    if intermediate_dtype is not None:\n        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)\n    output = torch.einsum(\"bhts,bshd->bthd\", attention_drop, v * dropout_scaling)\n    if query_padding_mask is not None:\n        output.masked_fill_(rearrange(~query_padding_mask, \"b s -> b s 1 1\"), 0.0)\n    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)\n\n\ndef maybe_fake_tensor_mode(fake: bool = True):\n    \"\"\"\n    One way to populate/pre-compile cache is to use torch fake tensor mode,\n    which does not allocate actual GPU tensors but retains tensor shape/dtype\n    metadata for cute.compile.\n    \"\"\"\n\n    def decorator(fn):\n        @wraps(fn)\n        def wrapper(*args, **kwargs):\n            with FakeTensorMode() if fake else nullcontext():\n                return fn(*args, **kwargs)\n\n        return wrapper\n\n    return decorator\n\n\ndef is_fake_mode() -> bool:\n    
return active_fake_mode() is not None\n"
  },
  {
    "path": "flash_attn/cute/tile_scheduler.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n\nfrom typing import Optional, Tuple\nfrom dataclasses import dataclass\n\ntry:\n    from typing import override\nexcept ImportError:  # Python < 3.12\n    from typing_extensions import override\n\nimport cutlass\nfrom cutlass._mlir import ir\nimport cutlass.cute as cute\nfrom cutlass import Int32, const_expr\nfrom cutlass.cute import FastDivmodDivisor\n\nfrom quack.cute_dsl_utils import ParamsBase\n\nimport flash_attn.cute.utils as utils\nfrom flash_attn.cute.fast_math import clz\n\n\nclass WorkTileInfo(cutlass.utils.WorkTileInfo):\n    \"\"\"Altered WorkTileInfo which includes four axes: (block, head, batch, split)\"\"\"\n\n    @override\n    def __new_from_mlir_values__(self, values: list[ir.Value]) -> \"WorkTileInfo\":\n        assert len(values) == 5\n        new_tile_idx = cutlass.new_from_mlir_values(self._tile_idx, values[:-1])\n        new_is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [values[-1]])\n        return WorkTileInfo(new_tile_idx, new_is_valid_tile)\n\n\n@dataclass\nclass TileSchedulerArguments(ParamsBase):\n    num_block: Int32\n    num_head: Int32\n    num_batch: Int32\n    num_splits: Int32\n    seqlen_k: Int32\n    headdim: Int32\n    headdim_v: Int32\n    total_q: Int32\n    tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]\n    cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)\n    mCuSeqlensQ: Optional[cute.Tensor] = None\n    mSeqUsedQ: Optional[cute.Tensor] = None\n    qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1\n    element_size: cutlass.Constexpr[int] = 2\n    is_persistent: cutlass.Constexpr[bool] = False\n    lpt: cutlass.Constexpr[bool] = False\n    is_split_kv: cutlass.Constexpr[bool] = False\n    head_swizzle: cutlass.Constexpr[bool] = False\n\n\nclass SingleTileScheduler:\n    @dataclass\n    class Params(ParamsBase):\n        num_block: Int32\n        num_head: Int32\n        num_batch: Int32\n        num_splits: Int32\n        
num_splits_divmod: FastDivmodDivisor\n        is_split_kv: cutlass.Constexpr[bool] = False\n        cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)\n\n        @staticmethod\n        def create(\n            args: TileSchedulerArguments, *, loc=None, ip=None\n        ) -> \"SingleTileScheduler.Params\":\n            return SingleTileScheduler.Params(\n                args.num_block,\n                args.num_head,\n                args.num_batch,\n                args.num_splits,\n                FastDivmodDivisor(args.num_splits),\n                args.is_split_kv,\n                args.cluster_shape_mn,\n            )\n\n    def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None):\n        self.params = params\n        self._blk_coord = blk_coord\n        self._is_first_block = True\n        self._loc = loc\n        self._ip = ip\n\n    @staticmethod\n    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:\n        return SingleTileScheduler.Params.create(args, loc=loc, ip=ip)\n\n    @staticmethod\n    def create(params: Params, *, loc=None, ip=None) -> \"SingleTileScheduler\":\n        # if const_expr(cute.size(params.cluster_shape_mn) == 1):\n        #     blk_coord = cute.arch.block_idx()\n        # else:\n        #     # All CTAs in a cluster must get the same block coordinate\n        #     blk_coord = cute.arch.cluster_idx()\n        # Temporary set to block_idx until we sort out the best way to handle cluster\n        blk_coord = cute.arch.block_idx()\n        return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip)\n\n    # called by host\n    @staticmethod\n    def get_grid_shape(\n        params: Params,\n        *,\n        loc=None,\n        ip=None,\n    ) -> Tuple[Int32, Int32, Int32]:\n        # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)\n        assert params.cluster_shape_mn[1] == 1, \"Only cluster_shape_mn[1] == 1 is supported\"\n     
   return (\n            cute.round_up(params.num_block, params.cluster_shape_mn[0]),\n            params.num_head * params.num_splits,\n            params.num_batch,\n        )\n\n    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:\n        block_idx, head_idx, batch_idx = self._blk_coord\n        if const_expr(self.params.is_split_kv):\n            head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod)\n        else:\n            split_idx = Int32(0)\n        return WorkTileInfo(\n            (block_idx, head_idx, batch_idx, split_idx),\n            self._is_first_block,\n        )\n\n    def initial_work_tile_info(self, *, loc=None, ip=None):\n        return self.get_current_work(loc=loc, ip=ip)\n\n    def prefetch_next_work(self, *, loc=None, ip=None):\n        pass\n\n    def advance_to_next_work(self, *, loc=None, ip=None):\n        self._is_first_block = False\n\n    def __extract_mlir_values__(self):\n        values, self._values_pos = [], []\n        for obj in [self.params, self._blk_coord]:\n            obj_values = cutlass.extract_mlir_values(obj)\n            values += obj_values\n            self._values_pos.append(len(obj_values))\n        return values\n\n    def __new_from_mlir_values__(self, values):\n        obj_list = []\n        for obj, n_items in zip([self.params, self._blk_coord], self._values_pos):\n            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))\n            values = values[n_items:]\n        return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc)\n\n\nclass StaticPersistentTileScheduler:\n    @dataclass\n    class Params(ParamsBase):\n        num_block_cluster_divmod: FastDivmodDivisor\n        num_head_divmod: FastDivmodDivisor\n        total_blocks_cluster: Int32\n        cluster_shape_m: cutlass.Constexpr[int] = 1\n\n        @staticmethod\n        def create(\n            args: TileSchedulerArguments, *, loc=None, ip=None\n        ) -> 
\"StaticPersistentTileScheduler.Params\":\n            num_block_cluster = cute.ceil_div(args.num_block, cute.size(args.cluster_shape_mn))\n            total_blocks_cluster = num_block_cluster * args.num_head * args.num_batch\n            return StaticPersistentTileScheduler.Params(\n                FastDivmodDivisor(num_block_cluster),\n                FastDivmodDivisor(args.num_head),\n                total_blocks_cluster,\n                cluster_shape_m=args.cluster_shape_mn[0],\n            )\n\n    def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):\n        self.params = params\n        self._tile_idx = tile_idx\n        self._loc = loc\n        self._ip = ip\n\n    @staticmethod\n    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:\n        return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip)\n\n    @staticmethod\n    def create(params: Params, *, loc=None, ip=None) -> \"StaticPersistentTileScheduler\":\n        if const_expr(cute.size(params.cluster_shape_m) == 1):\n            tile_idx = cute.arch.block_idx()[0]\n        else:\n            tile_idx = cute.arch.cluster_idx()[0]\n        return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip)\n\n    # called by host\n    @staticmethod\n    def get_grid_shape(\n        params: Params,\n        *,\n        loc=None,\n        ip=None,\n    ) -> Tuple[Int32, Int32, Int32]:\n        hardware_info = cutlass.utils.HardwareInfo()\n        sm_count = hardware_info.get_device_multiprocessor_count()\n        # Grid must be a multiple of cluster_shape_m for CUDA cluster launch.\n        max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m\n        grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * params.cluster_shape_m)\n        return (grid_x, Int32(1), Int32(1))\n\n    # @cute.jit\n    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:\n        hn_idx, block_idx = 
divmod(self._tile_idx, self.params.num_block_cluster_divmod)\n        batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod)\n        is_valid = self._tile_idx < self.params.total_blocks_cluster\n        # if cute.arch.thread_idx()[0] == 0:\n        #     cute.printf(\"TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d\", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid)\n        return WorkTileInfo(\n            (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid\n        )\n\n    def initial_work_tile_info(self, *, loc=None, ip=None):\n        return self.get_current_work(loc=loc, ip=ip)\n\n    def prefetch_next_work(self, *, loc=None, ip=None):\n        pass\n\n    def advance_to_next_work(self, *, loc=None, ip=None):\n        if const_expr(self.params.cluster_shape_m == 1):\n            self._tile_idx += cute.arch.grid_dim()[0]\n        else:\n            self._tile_idx += cute.arch.cluster_dim()[0]\n\n    def __extract_mlir_values__(self):\n        values, self._values_pos = [], []\n        for obj in [self.params, self._tile_idx]:\n            obj_values = cutlass.extract_mlir_values(obj)\n            values += obj_values\n            self._values_pos.append(len(obj_values))\n        return values\n\n    def __new_from_mlir_values__(self, values):\n        obj_list = []\n        for obj, n_items in zip(\n            [self.params, self._tile_idx],\n            self._values_pos,\n        ):\n            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))\n            values = values[n_items:]\n        return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc)\n\n\nclass SingleTileLPTScheduler:\n    @dataclass\n    class Params(ParamsBase):\n        total_blocks: Int32\n        num_splits: Int32\n        num_block: Int32\n        l2_minor: Int32\n        num_block_divmod: FastDivmodDivisor\n        num_head_divmod: FastDivmodDivisor\n        
l2_minor_divmod: FastDivmodDivisor\n        l2_major_divmod: FastDivmodDivisor\n        l2_minor_residual_divmod: FastDivmodDivisor\n        num_hb_quotient: Int32\n        is_split_kv: cutlass.Constexpr[bool] = False\n\n        @staticmethod\n        @cute.jit\n        def create(\n            args: TileSchedulerArguments, *, loc=None, ip=None\n        ) -> \"SingleTileLPTScheduler.Params\":\n            # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size)\n            size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size\n            size_one_head = size_one_kv_head\n            size_l2 = 50 * 1024 * 1024  # 50 MB for K & V\n            # Swizzle is the size of each \"section\". Round swizzle to a power of 2\n            # Need to be careful about the case where only one head will fit\n            # swizzle is how many heads can fit in L2\n            # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)\n            # Seems faster if swizzle is a power of 2\n            log2_floor = lambda n: 31 - clz(n)\n            swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))\n            # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)\n            # If we're in the last section (called residual), we don't want to divide by\n            # swizzle. 
Instead we want to divide by the remainder.\n            num_hb_quotient = (args.num_head * args.num_batch) // swizzle\n            num_hb_remainder = (args.num_head * args.num_batch) % swizzle\n            return SingleTileLPTScheduler.Params(\n                total_blocks=args.num_block * args.num_head * args.num_batch,\n                num_block=args.num_block,\n                l2_minor=Int32(swizzle),\n                num_block_divmod=FastDivmodDivisor(args.num_block),\n                num_head_divmod=FastDivmodDivisor(args.num_head),\n                l2_minor_divmod=FastDivmodDivisor(swizzle),\n                l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block),\n                l2_minor_residual_divmod=FastDivmodDivisor(\n                    max(num_hb_remainder, 1)\n                ),  # don't divide by 0\n                num_hb_quotient=Int32(num_hb_quotient),\n                num_splits=args.num_splits,\n                is_split_kv=args.is_split_kv,\n            )\n\n    def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):\n        self.params = params\n        self._tile_idx = tile_idx\n        self._split_idx = split_idx\n        self._loc = loc\n        self._ip = ip\n\n    @staticmethod\n    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:\n        return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip)\n\n    @staticmethod\n    @cute.jit\n    def create(params: Params, *, loc=None, ip=None) -> \"SingleTileLPTScheduler\":\n        tile_idx, split_idx, _ = cute.arch.block_idx()\n        return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)\n\n    # called by host\n    @staticmethod\n    def get_grid_shape(\n        params: Params,\n        *,\n        loc=None,\n        ip=None,\n    ) -> Tuple[Int32, Int32, Int32]:\n        return (params.total_blocks, params.num_splits, Int32(1))\n\n    @cute.jit\n    def get_current_work(self, *, 
loc=None, ip=None) -> WorkTileInfo:\n        params = self.params\n        # Implement LPT scheduling coordinate calculation\n        bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod)\n        # If we're in the last section (called residual), we don't want to divide by\n        # swizzle. Instead we want to divide by the remainder.\n        block, bidhb_residual = 0, 0\n        if bidhb < params.num_hb_quotient:\n            block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)\n        else:\n            block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)\n        bidhb_actual = bidhb * params.l2_minor + bidhb_residual\n        batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)\n        # Longest-processing-time-first\n        block = params.num_block - 1 - block\n        is_valid = self._tile_idx < params.total_blocks\n        return WorkTileInfo(\n            (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid\n        )\n\n    def initial_work_tile_info(self, *, loc=None, ip=None):\n        return self.get_current_work(loc=loc, ip=ip)\n\n    def prefetch_next_work(self, *, loc=None, ip=None):\n        pass\n\n    def advance_to_next_work(self, *, loc=None, ip=None):\n        # Single tile scheduler - set to invalid tile_idx to indicate no more work\n        self._tile_idx = self.params.total_blocks\n\n    def __extract_mlir_values__(self):\n        values, self._values_pos = [], []\n        for obj in [self.params, self._tile_idx, self._split_idx]:\n            obj_values = cutlass.extract_mlir_values(obj)\n            values += obj_values\n            self._values_pos.append(len(obj_values))\n        return values\n\n    def __new_from_mlir_values__(self, values):\n        obj_list = []\n        for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos):\n            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))\n      
      values = values[n_items:]\n        return self.__class__(*(tuple(obj_list)), loc=self._loc)\n\n\nclass SingleTileLPTBwdScheduler:\n    @dataclass\n    class Params(ParamsBase):\n        total_blocks: Int32\n        num_block: Int32\n        l2_minor: Int32\n        num_head_divmod: FastDivmodDivisor\n        l2_minor_divmod: FastDivmodDivisor\n        l2_major_divmod: FastDivmodDivisor\n        l2_minor_residual_divmod: FastDivmodDivisor\n        num_hb_quotient: Int32\n        cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)\n        spt: cutlass.Constexpr[bool] = True\n\n        @staticmethod\n        @cute.jit\n        def create(\n            args: TileSchedulerArguments, *, loc=None, ip=None\n        ) -> \"SingleTileLPTBwdScheduler.Params\":\n            size_l2 = 50 * 1024 * 1024\n            size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size\n            # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4\n            size_one_dqaccum_head = 0\n            size_one_head = size_one_qdo_head + size_one_dqaccum_head\n            log2_floor = lambda n: 31 - clz(n)\n            swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))\n            # swizzle = 8\n            # If we're in the last section (called residual), we don't want to divide by\n            # swizzle. 
Instead we want to divide by the remainder.\n            num_hb_quotient = (args.num_head * args.num_batch) // swizzle\n            num_hb_remainder = (args.num_head * args.num_batch) % swizzle\n            num_block = cute.ceil_div(args.num_block, args.cluster_shape_mn[0])\n            return SingleTileLPTBwdScheduler.Params(\n                total_blocks=(num_block * args.cluster_shape_mn[0])\n                * args.num_head\n                * args.num_batch,\n                num_block=num_block,\n                l2_minor=Int32(swizzle),\n                num_head_divmod=FastDivmodDivisor(args.num_head),\n                l2_minor_divmod=FastDivmodDivisor(swizzle),\n                l2_major_divmod=FastDivmodDivisor(swizzle * num_block),\n                l2_minor_residual_divmod=FastDivmodDivisor(\n                    max(num_hb_remainder, 1)\n                ),  # don't divide by 0\n                num_hb_quotient=Int32(num_hb_quotient),\n                cluster_shape_mn=args.cluster_shape_mn,\n                spt=args.lpt,\n            )\n\n    def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):\n        self.params = params\n        self._tile_idx = tile_idx\n        self._loc = loc\n        self._ip = ip\n\n    @staticmethod\n    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:\n        return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip)\n\n    @staticmethod\n    @cute.jit\n    def create(params: Params, *, loc=None, ip=None) -> \"SingleTileLPTBwdScheduler\":\n        tile_idx = cute.arch.block_idx()[0]\n        return SingleTileLPTBwdScheduler(params, tile_idx, loc=loc, ip=ip)\n\n    # called by host\n    @staticmethod\n    def get_grid_shape(\n        params: Params,\n        *,\n        loc=None,\n        ip=None,\n    ) -> Tuple[Int32, Int32, Int32]:\n        return (params.total_blocks, Int32(1), Int32(1))\n\n    @cute.jit\n    def get_current_work(self, *, loc=None, ip=None) 
-> cutlass.utils.WorkTileInfo:\n        cluster_idx = self._tile_idx // self.params.cluster_shape_mn[0]\n        params = self.params\n        # Implement LPT scheduling coordinate calculation\n        bidhb, l2_mod = divmod(cluster_idx, params.l2_major_divmod)\n        # If we're in the last section (called residual), we don't want to divide by\n        # swizzle. Instead we want to divide by the remainder.\n        block, bidhb_residual = 0, 0\n        if bidhb < params.num_hb_quotient:\n            block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)\n        else:\n            block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)\n        bidhb_actual = bidhb * params.l2_minor + bidhb_residual\n        batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)\n        if cutlass.const_expr(params.spt):\n            block = params.num_block - 1 - block\n        if cutlass.const_expr(params.cluster_shape_mn[0] > 1):\n            bidx_in_cluster = cute.arch.block_in_cluster_idx()\n            block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0]\n        is_valid = self._tile_idx < params.total_blocks\n        return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid)\n\n    def initial_work_tile_info(self, *, loc=None, ip=None):\n        return self.get_current_work(loc=loc, ip=ip)\n\n    def prefetch_next_work(self, *, loc=None, ip=None):\n        pass\n\n    def advance_to_next_work(self, *, loc=None, ip=None):\n        # Single tile scheduler - set to invalid tile_idx to indicate no more work\n        self._tile_idx = self.params.total_blocks\n\n    def __extract_mlir_values__(self):\n        values, self._values_pos = [], []\n        for obj in [self.params, self._tile_idx]:\n            obj_values = cutlass.extract_mlir_values(obj)\n            values += obj_values\n            self._values_pos.append(len(obj_values))\n        return values\n\n    def 
__new_from_mlir_values__(self, values):\n        obj_list = []\n        for obj, n_items in zip([self.params, self._tile_idx], self._values_pos):\n            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))\n            values = values[n_items:]\n        return self.__class__(*(tuple(obj_list)), loc=self._loc)\n\n\nclass SingleTileVarlenScheduler:\n    @dataclass\n    class Params(ParamsBase):\n        num_head: Int32\n        num_batch: Int32\n        total_q: Int32\n        num_splits: Int32\n        max_kvblock_in_l2: Int32\n        tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]\n        mCuSeqlensQ: Optional[cute.Tensor] = None\n        mSeqUsedQ: Optional[cute.Tensor] = None\n        qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1\n        lpt: cutlass.Constexpr[bool] = False\n        is_split_kv: cutlass.Constexpr[bool] = False\n        head_swizzle: cutlass.Constexpr[bool] = False\n        cluster_shape_m: cutlass.Constexpr[int] = 1\n\n        @staticmethod\n        @cute.jit\n        def create(\n            args: TileSchedulerArguments, *, loc=None, ip=None\n        ) -> \"SingleTileVarlenScheduler.Params\":\n            size_l2 = 50 * 1024 * 1024  # 50 MB for K & V\n            max_kvblock_in_l2 = size_l2 // (\n                (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]\n            )\n            assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (\n                \"At least one of mCuSeqlensQ or mSeqUsedQ must be provided\"\n            )\n            assert args.cluster_shape_mn[1] == 1, \"Only cluster_shape_mn[1] == 1 is supported\"\n            return SingleTileVarlenScheduler.Params(\n                num_head=args.num_head,\n                num_batch=args.num_batch,\n                total_q=args.total_q,\n                num_splits=args.num_splits,\n                max_kvblock_in_l2=max_kvblock_in_l2,\n                tile_shape_mn=args.tile_shape_mn,\n                
mCuSeqlensQ=args.mCuSeqlensQ,\n                mSeqUsedQ=args.mSeqUsedQ,\n                qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa,\n                lpt=args.lpt,\n                is_split_kv=args.is_split_kv,\n                head_swizzle=args.head_swizzle,\n                cluster_shape_m=args.cluster_shape_mn[0],\n            )\n\n    def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):\n        self.params = params\n        self._tile_idx = tile_idx\n        self._split_idx = split_idx\n        self._is_first_block = True\n        self._loc = loc\n        self._ip = ip\n\n    @staticmethod\n    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:\n        return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip)\n\n    @staticmethod\n    def create(params: Params, *, loc=None, ip=None) -> \"SingleTileVarlenScheduler\":\n        tile_idx, split_idx, _ = cute.arch.block_idx()\n        return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)\n\n    # called by host\n    @staticmethod\n    def get_grid_shape(\n        params: Params,\n        *,\n        loc=None,\n        ip=None,\n    ) -> Tuple[Int32, Int32, Int32]:\n        total_blocks_max = (\n            params.total_q\n            + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1)\n        ) // params.tile_shape_mn[0]\n        # round down to nearest multiple of cluster since odd excess is always padding\n        total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m\n        return (total_blocks_max * params.num_head, params.num_splits, Int32(1))\n\n    @cute.jit\n    def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32:\n        params = self.params\n        batch_idx = lane + bidb_start\n        if cutlass.const_expr(params.mSeqUsedQ is not None):\n            seqlen = Int32(0)\n            if batch_idx < 
params.num_batch:\n                seqlen = params.mSeqUsedQ[batch_idx]\n        else:\n            assert params.mCuSeqlensQ is not None\n            cur_cu_seqlen = Int32(0)\n            if batch_idx <= params.num_batch:\n                cur_cu_seqlen = params.mCuSeqlensQ[batch_idx]\n            next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)\n            seqlen = next_cu_seqlen - cur_cu_seqlen\n        if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1):\n            seqlen *= params.qhead_per_kvhead_packgqa\n        return (\n            cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), params.cluster_shape_m)\n            if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1\n            else Int32(0)\n        )\n\n    @cute.jit\n    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:\n        params = self.params\n        lane_idx = cute.arch.lane_idx()\n        num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0)\n        num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)\n        # Total number of blocks for the next 31 batches\n        m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1)\n        # Same for all lanes\n        group_end_tile = m_blocks_in_group * params.num_head\n        # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf(\"SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d\", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group)\n        block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0)\n        next_tile_idx = self._tile_idx // params.cluster_shape_m\n        while group_end_tile <= next_tile_idx:\n            batch_idx += cute.arch.WARP_SIZE - 1\n            if batch_idx >= params.num_batch:\n                batch_idx = Int32(params.num_batch)\n                group_end_tile = 
next_tile_idx + 1\n            else:\n                num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx)\n                num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)\n                m_blocks_in_group = cute.arch.shuffle_sync(\n                    num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1\n                )\n                group_end_tile += m_blocks_in_group * params.num_head\n        is_valid = False\n        if batch_idx >= params.num_batch:\n            block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch)\n        else:\n            group_start_tile = group_end_tile - m_blocks_in_group * params.num_head\n            # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf(\"SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d\", self._tile_idx, group_end_tile, num_m_blocks, batch_idx)\n            # The next problem to process is the first one whose ending tile position is\n            # greater than the tile index; the ballot counts the batches that end at or before it.\n            batch_idx_in_group = cute.arch.popc(\n                cute.arch.vote_ballot_sync(\n                    group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx\n                )\n            )\n            batch_idx += batch_idx_in_group\n            num_m_blocks_prev_lane = (\n                0\n                if batch_idx_in_group == 0\n                else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1)\n            )\n            num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group)\n            mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head\n            if cutlass.const_expr(params.lpt or params.head_swizzle):\n                # This is a version of the SingleTileLPTScheduler, complicated by the fact that\n                # the seqlen can vary per batch.\n                # TODO: is there any 
case where num_m_blocks is 0?\n                # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here\n                num_n_blocks = (\n                    num_m_blocks\n                    * params.tile_shape_mn[0]\n                    // params.qhead_per_kvhead_packgqa\n                    // params.tile_shape_mn[1]\n                )\n                # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head)\n                # Seems faster to have this be a power of 2\n                nheads_in_l2 = (\n                    16\n                    if num_n_blocks * 16 <= params.max_kvblock_in_l2\n                    else (\n                        8\n                        if num_n_blocks * 8 <= params.max_kvblock_in_l2\n                        else (\n                            4\n                            if num_n_blocks * 4 <= params.max_kvblock_in_l2\n                            else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1)\n                        )\n                    )\n                )\n                nheads_in_l2 = min(nheads_in_l2, params.num_head)\n                mh_in_l2 = nheads_in_l2 * num_m_blocks\n                section_idx = mh_block // mh_in_l2\n                l2_mod = mh_block - section_idx * mh_in_l2\n                # Deal with tail section\n                nheads_in_this_section = (\n                    nheads_in_l2\n                    if nheads_in_l2 * (section_idx + 1) <= params.num_head\n                    else params.num_head - section_idx * nheads_in_l2\n                )\n                block = l2_mod // nheads_in_this_section\n                head_idx_residual = l2_mod - block * nheads_in_this_section\n                head_idx = section_idx * nheads_in_l2 + head_idx_residual\n                if cutlass.const_expr(params.lpt):\n                    block = num_m_blocks - 1 - block\n            else:\n                head_idx = mh_block // 
num_m_blocks\n                block = mh_block - head_idx * num_m_blocks\n            is_valid = self._is_first_block and batch_idx < params.num_batch\n            if cutlass.const_expr(params.cluster_shape_m > 1):\n                bidx_in_cluster = cute.arch.block_in_cluster_idx()\n                block = block * params.cluster_shape_m + bidx_in_cluster[0]\n        # if cute.arch.thread_idx()[0] == 128: cute.printf(\"SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d\", self._tile_idx, batch_idx, head_idx, block, is_valid)\n        split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)\n        return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)\n\n    def initial_work_tile_info(self, *, loc=None, ip=None):\n        return self.get_current_work(loc=loc, ip=ip)\n\n    def prefetch_next_work(self, *, loc=None, ip=None):\n        pass\n\n    def advance_to_next_work(self, *, loc=None, ip=None):\n        # Single tile scheduler - set to invalid tile_idx to indicate no more work\n        self._is_first_block = False\n\n    def __extract_mlir_values__(self):\n        values, self._values_pos = [], []\n        for obj in [self.params, self._tile_idx, self._split_idx]:\n            obj_values = cutlass.extract_mlir_values(obj)\n            values += obj_values\n            self._values_pos.append(len(obj_values))\n        return values\n\n    def __new_from_mlir_values__(self, values):\n        obj_list = []\n        for obj, n_items in zip(\n            [self.params, self._tile_idx, self._split_idx],\n            self._values_pos,\n        ):\n            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))\n            values = values[n_items:]\n        return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc)\n"
  },
  {
    "path": "flash_attn/cute/utils.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n\nimport math\nimport hashlib\nimport inspect\nfrom typing import Type, Callable, Optional, Tuple, overload\n\nimport cutlass\nimport cutlass.cute as cute\n\nfrom cutlass import Float32, const_expr\nfrom cutlass.cute import FastDivmodDivisor\nfrom cutlass.cutlass_dsl import T, dsl_user_op\nfrom cutlass._mlir.dialects import nvvm, llvm\nfrom cutlass.cute.runtime import from_dlpack\n\n\nimport quack.activation\n\n_MIXER_ATTRS = (\"__vec_size__\",)\n\n# Obtained from sollya:\n# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative);\nPOLY_EX2 = {\n    0: (1.0,),\n    1: (\n        1.0,\n        0.922497093677520751953125,\n    ),\n    2: (\n        1.0,\n        0.6657850742340087890625,\n        0.330107033252716064453125,\n    ),\n    3: (\n        1.0,\n        0.695146143436431884765625,\n        0.227564394474029541015625,\n        0.077119089663028717041015625,\n    ),\n    4: (\n        1.0,\n        0.693042695522308349609375,\n        0.2412912547588348388671875,\n        5.2225358784198760986328125e-2,\n        1.3434938155114650726318359375e-2,\n    ),\n    5: (\n        1.0,\n        0.693151414394378662109375,\n        0.24016360938549041748046875,\n        5.5802188813686370849609375e-2,\n        9.01452265679836273193359375e-3,\n        1.86810153536498546600341796875e-3,\n    ),\n}\n\n\ndef _compute_base_hash(func: Callable) -> str:\n    \"\"\"Compute hash from source code or bytecode and closure values.\"\"\"\n    try:\n        data = inspect.getsource(func).encode()\n    except (OSError, TypeError):\n        if hasattr(func, \"__code__\") and func.__code__ is not None:\n            data = func.__code__.co_code\n        else:\n            data = repr(func).encode()\n\n    hasher = hashlib.sha256(data)\n\n    if hasattr(func, \"__closure__\") and func.__closure__ is not None:\n        for cell in func.__closure__:\n            hasher.update(repr(cell.cell_contents).encode())\n\n    return 
hasher.hexdigest()\n\n\ndef hash_callable(\n    func: Callable, mixer_attrs: Tuple[str, ...] = _MIXER_ATTRS, set_cute_hash: bool = True\n) -> str:\n    \"\"\"Hash a callable based on the source code or bytecode and closure values.\n    Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__``\n    attribute, that value is returned immediately as the base hash, then\n    metadata dunders are mixed in to produce the final dict-key hash.\n    set_cute_hash: whether or not to set func.__cute_hash__\n    \"\"\"\n    # Resolve base hash\n    if hasattr(func, \"__cute_hash__\"):\n        base_hash = func.__cute_hash__\n    else:\n        # Unwrap decorated functions (e.g., cute.jit wrappers).\n        base_func = getattr(func, \"__wrapped__\", func)\n\n        if hasattr(base_func, \"__cute_hash__\"):\n            base_hash = base_func.__cute_hash__\n        else:\n            base_hash = _compute_base_hash(base_func)\n\n            if set_cute_hash:\n                base_func.__cute_hash__ = base_hash\n\n    # Mix in mutable metadata dunders\n    mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs)\n\n    if all(v is None for v in mixer_values):\n        return base_hash\n\n    hasher = hashlib.sha256(base_hash.encode())\n\n    for attr, val in zip(mixer_attrs, mixer_values):\n        hasher.update(f\"{attr}={val!r}\".encode())\n\n    return hasher.hexdigest()\n\n\ndef create_softcap_scoremod(softcap_val):\n    inv_softcap = 1.0 / softcap_val\n\n    @cute.jit\n    def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):\n        scores = acc_S_SSA * inv_softcap\n        return scores * cute.math.tanh(scores, fastmath=True)\n\n    return scoremod_premask_fn\n\n\nLOG2_E = math.log2(math.e)\n\n\ndef compute_softmax_scale_log2(softmax_scale, score_mod):\n    \"\"\"Compute softmax_scale_log2 and adjusted softmax_scale based on whether score_mod is used.\n\n    When score_mod is None, fold the log2(e) factor into 
softmax_scale_log2 and set softmax_scale\n    to None. When score_mod is present, keep softmax_scale separate so it can be applied before\n    the score_mod, and set softmax_scale_log2 to just the change-of-base constant.\n\n    Returns (softmax_scale_log2, softmax_scale).\n    \"\"\"\n    if const_expr(score_mod is None):\n        return softmax_scale * LOG2_E, None\n    else:\n        return LOG2_E, softmax_scale\n\n\ndef compute_fastdiv_mods(mQ, mK, qhead_per_kvhead, pack_gqa, aux_tensors, mPageTable=None):\n    \"\"\"Compute FastDivmodDivisor pairs for aux_tensors index computation.\n\n    Returns a (seqlen_q_divmod, seqlen_k_divmod) tuple, or None if aux_tensors is None.\n    \"\"\"\n    if const_expr(aux_tensors is None):\n        return None\n    seqlen_q = cute.size(mQ.shape[0]) // (qhead_per_kvhead if const_expr(pack_gqa) else 1)\n    seqlen_k = (\n        cute.size(mK.shape[0])\n        if const_expr(mPageTable is None)\n        else mK.shape[0] * mPageTable.shape[1]\n    )\n    return (FastDivmodDivisor(seqlen_q), FastDivmodDivisor(seqlen_k))\n\n\ndef convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:\n    return (\n        from_dlpack(x, assumed_align=alignment)\n        .mark_layout_dynamic(leading_dim=leading_dim)\n        .mark_compact_shape_dynamic(\n            mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility\n        )\n    )\n\n\ndef convert_from_dlpack_leading_static(\n    x, leading_dim, alignment=16, static_modes=None, stride_order=None\n) -> cute.Tensor:\n    if stride_order is None:\n        stride_order = x.dim_order()\n    x_ = from_dlpack(x, assumed_align=alignment)\n    for i in range(x.ndim):\n        if i != leading_dim and (static_modes is None or i not in static_modes):\n            x_ = x_.mark_compact_shape_dynamic(mode=i, stride_order=stride_order)\n    return x_\n\n\ndef make_tiled_copy_A(\n    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] 
= False\n) -> cute.TiledCopy:\n    if const_expr(swapAB):\n        return cute.make_tiled_copy_B(copy_atom, tiled_mma)\n    else:\n        return cute.make_tiled_copy_A(copy_atom, tiled_mma)\n\n\ndef make_tiled_copy_B(\n    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False\n) -> cute.TiledCopy:\n    if const_expr(swapAB):\n        return cute.make_tiled_copy_A(copy_atom, tiled_mma)\n    else:\n        return cute.make_tiled_copy_B(copy_atom, tiled_mma)\n\n\ndef mma_make_fragment_A(\n    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False\n) -> cute.Tensor:\n    if const_expr(swapAB):\n        return mma_make_fragment_B(smem, thr_mma)\n    else:\n        return thr_mma.make_fragment_A(thr_mma.partition_A(smem))\n\n\ndef mma_make_fragment_B(\n    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False\n) -> cute.Tensor:\n    if const_expr(swapAB):\n        return mma_make_fragment_A(smem, thr_mma)\n    else:\n        return thr_mma.make_fragment_B(thr_mma.partition_B(smem))\n\n\ndef get_smem_store_atom(\n    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False\n) -> cute.CopyAtom:\n    if const_expr(arch < 90 or element_type.width != 16):\n        return cute.make_copy_atom(\n            cute.nvgpu.CopyUniversalOp(),\n            element_type,\n            num_bits_per_copy=2 * element_type.width,\n        )\n    else:\n        return cute.make_copy_atom(\n            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),\n            element_type,\n        )\n\n\n@cute.jit\ndef warp_reduce(\n    val: cute.TensorSSA | cute.Numeric,\n    op: Callable,\n    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,\n) -> cute.TensorSSA | cute.Numeric:\n    if const_expr(isinstance(val, cute.TensorSSA)):\n        res = cute.make_fragment(val.shape, val.dtype)\n        res.store(val)\n        for i in 
cutlass.range_constexpr(cute.size(val.shape)):\n            res[i] = warp_reduce(res[i], op, width)\n        return res.load()\n    else:\n        for i in cutlass.range_constexpr(int(math.log2(width))):\n            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))\n    return val\n\n\n@dsl_user_op\ndef fmax(\n    a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None\n) -> Float32:\n    from cutlass import CUDA_VERSION\n\n    # * NVVM call based on nvvm version\n    if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9:\n        # Old API: requires explicit result type as first positional argument\n        return Float32(\n            nvvm.fmax(\n                T.f32(),\n                Float32(a).ir_value(loc=loc, ip=ip),\n                Float32(b).ir_value(loc=loc, ip=ip),\n                c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,\n                loc=loc,\n                ip=ip,\n            )\n        )\n    else:\n        # New API: infers result type automatically\n        return Float32(\n            nvvm.fmax(\n                Float32(a).ir_value(loc=loc, ip=ip),\n                Float32(b).ir_value(loc=loc, ip=ip),\n                c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,\n                loc=loc,\n                ip=ip,\n            )\n        )\n\n\n@cute.jit\ndef fmax_reduce(\n    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80\n) -> Float32:\n    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):\n        # if const_expr(init_val is None):\n        #     init_val = -cutlass.Float32.inf\n        # return x.reduce(cute.ReductionOp.MAX, init_val, 0)\n        res = cute.make_fragment(x.shape, Float32)\n        res.store(x)\n        # local_max = [res[0], res[1]]\n        # for i in cutlass.range_constexpr(2, cute.size(x.shape), 2):\n        #     local_max[0] = fmax(local_max[0], res[i + 0])\n       
 #     local_max[1] = fmax(local_max[1], res[i + 1])\n        # local_max[0] = fmax(local_max[0], local_max[1])\n        # return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)\n        local_max = [res[0], res[1], res[2], res[3]]\n        for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):\n            local_max[0] = fmax(local_max[0], res[i + 0])\n            local_max[1] = fmax(local_max[1], res[i + 1])\n            local_max[2] = fmax(local_max[2], res[i + 2])\n            local_max[3] = fmax(local_max[3], res[i + 3])\n        local_max[0] = fmax(local_max[0], local_max[1])\n        local_max[2] = fmax(local_max[2], local_max[3])\n        local_max[0] = fmax(local_max[0], local_max[2])\n        return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)\n    else:\n        # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max\n        # We instead force the 3-input max.\n        res = cute.make_fragment(x.shape, Float32)\n        res.store(x)\n        local_max_0 = (\n            fmax(init_val, res[0], res[1])\n            if const_expr(init_val is not None)\n            else fmax(res[0], res[1])\n        )\n        local_max = [\n            local_max_0,\n            fmax(res[2], res[3]),\n            fmax(res[4], res[5]),\n            fmax(res[6], res[7]),\n        ]\n        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):\n            local_max[0] = fmax(local_max[0], res[i], res[i + 1])\n            local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3])\n            local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5])\n            local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7])\n        local_max[0] = fmax(local_max[0], local_max[1])\n        return fmax(local_max[0], local_max[2], local_max[3])\n\n\n@cute.jit\ndef fadd_reduce(\n    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80\n) -> 
Float32:\n    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):\n        if const_expr(init_val is None):\n            init_val = Float32.zero\n        return x.reduce(cute.ReductionOp.ADD, init_val, 0)\n        # res = cute.make_fragment(x.shape, Float32)\n        # res.store(x)\n        # local_sum = [res[0], res[1], res[2], res[3]]\n        # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):\n        #     local_sum[0] += res[i + 0]\n        #     local_sum[1] += res[i + 1]\n        #     local_sum[2] += res[i + 2]\n        #     local_sum[3] += res[i + 3]\n        # local_sum[0] += local_sum[1]\n        # local_sum[2] += local_sum[3]\n        # local_sum[0] += local_sum[2]\n        # return local_sum[0] if const_expr(init_val is None) else local_sum[0] + init_val\n    else:\n        res = cute.make_fragment(x.shape, Float32)\n        res.store(x)\n        local_sum_0 = (\n            cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1]))\n            # cute.arch.add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1]))\n            if const_expr(init_val is not None)\n            else (res[0], res[1])\n        )\n        local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])]\n        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):\n            local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1]))\n            local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3]))\n            local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5]))\n            local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7]))\n        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1])\n        local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3])\n        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2])\n        return local_sum[0][0] + 
local_sum[0][1]\n\n\n@dsl_user_op\ndef atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:\n    # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()\n    # # cache_hint = cutlass.Int64(0x12F0000000000000)\n    # llvm.inline_asm(\n    #     None,\n    #     [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip)],\n    #     # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()],\n    #     \"red.global.add.f32 [$0], $1;\",\n    #     # \"red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;\",\n    #     # \"red.global.add.L2::cache_hint.f32 [$0], $1, $2;\",\n    #     \"l,f\",\n    #     # \"l,f,l\",\n    #     has_side_effects=True,\n    #     is_align_stack=False,\n    #     asm_dialect=llvm.AsmDialect.AD_ATT,\n    # )\n    nvvm.atomicrmw(\n        res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value()\n    )\n\n\n@dsl_user_op\ndef elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:\n    return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)\n\n\n@cute.jit\ndef predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:\n    # Only compute predicates for the \"k\" dimension. 
For the mn dimension, we will use \"if\"\n    tApA = cute.make_fragment(\n        cute.make_layout(\n            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),\n            stride=(cute.size(tAcA, mode=[2]), 0, 1),\n        ),\n        cutlass.Boolean,\n    )\n    for rest_v in cutlass.range_constexpr(tApA.shape[0]):\n        for rest_k in cutlass.range_constexpr(tApA.shape[2]):\n            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)\n    return tApA\n\n\ndef canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:\n    warp_group_idx = cute.arch.thread_idx()[0] // 128\n    if const_expr(sync):\n        warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx)\n    return warp_group_idx\n\n\n# @dsl_user_op\n# def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean:\n#     mask = cutlass.Int32(-1)\n#     return cutlass.Boolean(\n#         llvm.inline_asm(\n#             T.i32(),\n#             [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)],\n#             \".pred p1, p2;\\n\"\n#             \"setp.lt.f32 p1, $1, $2;\\n\"\n#             \"vote.sync.any.pred p2, p1, $3;\\n\"\n#             \"selp.u32 $0, 1, 0, p2;\",\n#             # \"selp.u32 $0, 1, 0, p1;\",\n#             \"=r,f,f,r\",\n#             has_side_effects=False,\n#             is_align_stack=False,\n#             asm_dialect=llvm.AsmDialect.AD_ATT,\n#         )\n#     )\n\n\n@cute.jit\ndef shuffle_sync(\n    value: cute.Numeric,\n    offset: cute.typing.Int,\n    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,\n) -> cute.Numeric:\n    assert value.width % 32 == 0, \"value type must be a multiple of 32 bits\"\n    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000\n    mask = cute.arch.WARP_SIZE - width\n    clamp = cute.arch.WARP_SIZE - 1\n    mask_and_clamp = mask << 8 | 
clamp\n    # important: need stride 1 and not 0 for recast_tensor to work\n    val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))\n    val[0] = value\n    val_i32 = cute.recast_tensor(val, cutlass.Int32)\n    for i in cutlass.range_constexpr(cute.size(val_i32)):\n        val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp)\n    return val[0]\n\n\n@dsl_user_op\ndef shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:\n    \"\"\"\n    Left-shift val by shift bits using PTX shl.b32 (sign-agnostic).\n\n    Named ``shl_u32`` (not ``shl_b32``) because python type annotations\n    distinguish signed/unsigned.\n\n    PTX semantics (§9.7.8.8): \"Shift amounts greater than the register width N\n    are clamped to N.\"  So ``shl.b32 d, a, 32`` is well-defined and yields 0.\n\n    This differs from C/C++ and LLVM IR, where shifting by >= the type width is\n    undefined behavior.  CuTeDSL compiles through MLIR -> LLVM IR, so a plain\n    Python-level ``Uint32(x) << Uint32(n)`` inherits LLVM's UB: the optimizer\n    may treat the result as poison and eliminate dependent code.  
Inline PTX\n    bypasses the LLVM IR shift entirely — the instruction is emitted verbatim\n    into PTX where clamping makes it safe for all shift amounts.\n    \"\"\"\n    return cutlass.Uint32(\n        llvm.inline_asm(\n            T.i32(),\n            [\n                cutlass.Uint32(val).ir_value(loc=loc, ip=ip),\n                cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),\n            ],\n            \"shl.b32 $0, $1, $2;\",\n            \"=r,r,r\",\n            has_side_effects=False,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n    )\n\n\n@dsl_user_op\ndef shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:\n    \"\"\"\n    Unsigned right-shift val by shift bits using PTX shr.u32 (zero-fills).\n\n    See ``shl_u32`` docstring for why inline PTX is used instead of plain\n    CuTeDSL shift operators (LLVM shift-by-type-width UB).\n    \"\"\"\n    return cutlass.Uint32(\n        llvm.inline_asm(\n            T.i32(),\n            [\n                cutlass.Uint32(val).ir_value(loc=loc, ip=ip),\n                cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),\n            ],\n            \"shr.u32 $0, $1, $2;\",\n            \"=r,r,r\",\n            has_side_effects=False,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n    )\n\n\n@cute.jit\ndef warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:\n    if const_expr(lane is None):\n        lane = cute.arch.lane_idx()\n    # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf(\"tidx = %d, val = %d\", cute.arch.thread_idx()[0] % 32, val)\n    for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):\n        offset = 1 << i\n        # Very important that we set mask_and_clamp to 0\n        partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, 
mask_and_clamp=0)\n        if lane >= offset:\n            val += partial_sum\n        # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf(\"tidx = %d, partial_sum = %d, val = %d\", cute.arch.thread_idx()[0] % 32, partial_sum, val)\n    return val\n\n\n@dsl_user_op\ndef cvt_f16x2_f32(\n    a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None\n) -> cutlass.Int32:\n    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], \"to_dtype must be BFloat16 or Float16\"\n    return cutlass.Int32(\n        llvm.inline_asm(\n            T.i32(),\n            [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)],\n            f\"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;\",\n            \"=r,f,f\",\n            has_side_effects=False,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n    )\n\n\n@overload\ndef cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...\n\n\n@overload\ndef cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...\n\n\n@cute.jit\ndef cvt_f16(src: cute.Tensor, dst_or_dtype):\n    \"\"\"Convert Float32 tensor to Float16/BFloat16.\n\n    Args:\n        src: Source tensor with Float32 element type\n        dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16)\n\n    Returns:\n        None if dst is a tensor, or a new tensor if dtype is provided\n    \"\"\"\n    if const_expr(isinstance(dst_or_dtype, type)):\n        # dtype variant: create new tensor and call the tensor variant\n        dtype = dst_or_dtype\n        dst = cute.make_fragment(src.shape, dtype)\n        cvt_f16(src, dst)\n        return dst\n    else:\n        # tensor variant: write to dst\n        dst = dst_or_dtype\n        assert cute.size(dst.shape) == cute.size(src.shape), \"dst and src must have the same size\"\n        assert cute.size(src.shape) % 2 
== 0, \"src must have an even number of elements\"\n        assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (\n            \"dst must be BFloat16 or Float16\"\n        )\n        assert src.element_type is Float32, \"src must be Float32\"\n        dst_i32 = cute.recast_tensor(dst, cutlass.Int32)\n        assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)\n        for i in cutlass.range_constexpr(cute.size(dst_i32)):\n            dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)\n\n\n@dsl_user_op\n@cute.jit\ndef evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:\n    deg = len(poly) - 1\n    out = poly[deg]\n    for i in cutlass.range_constexpr(deg - 1, -1, -1):\n        out = out * x + poly[i]\n    return out\n\n\n@dsl_user_op\n@cute.jit\ndef evaluate_polynomial_2(\n    x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None\n) -> Tuple[Float32, Float32]:\n    deg = len(poly) - 1\n    out = (poly[deg], poly[deg])\n    for i in cutlass.range_constexpr(deg - 1, -1, -1):\n        out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i]))\n    return out\n\n\n@dsl_user_op\ndef add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:\n    # There's probably a way to call llvm or nvvm to do this instead of ptx\n    return cutlass.Float32(\n        llvm.inline_asm(\n            T.f32(),\n            [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],\n            \"add.rm.ftz.f32 $0, $1, $2;\",\n            \"=f,f,f\",\n            has_side_effects=False,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n    )\n\n\n@dsl_user_op\ndef combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:\n    return cutlass.Float32(\n        llvm.inline_asm(\n            T.f32(),\n            [\n                
Float32(x_rounded).ir_value(loc=loc, ip=ip),\n                Float32(frac_ex2).ir_value(loc=loc, ip=ip),\n            ],\n            \"{\\n\\t\"\n            \".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\\n\\t\"\n            \"mov.b32 x_rounded_i, $1;\\n\\t\"\n            \"mov.b32 frac_ex_i, $2;\\n\\t\"\n            \"shl.b32 x_rounded_e, x_rounded_i, 23;\\n\\t\"\n            # add.u32 generates IMAD instruction and add.s32 generates LEA instruction\n            # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik\n            \"add.s32 out_i, x_rounded_e, frac_ex_i;\\n\\t\"\n            \"mov.b32 $0, out_i;\\n\\t\"\n            \"}\\n\",\n            \"=f,f,f\",\n            has_side_effects=False,\n            is_align_stack=False,\n            asm_dialect=llvm.AsmDialect.AD_ATT,\n        )\n    )\n\n\n@dsl_user_op\ndef ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32:\n    assert poly_degree in POLY_EX2, f\"Polynomial degree {poly_degree} not supported\"\n    # We assume x <= 127.0\n    fp32_round_int = float(2**23 + 2**22)\n    x_clamped = cute.arch.fmax(x, -127.0)\n    # We want to round down here, so that the fractional part is in [0, 1)\n    x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip)\n    # The integer floor of x is now in the last 8 bits of x_rounded\n    # We assume the next 2 ops round to nearest even. 
The rounding mode is important.\n    x_rounded_back = x_rounded - fp32_round_int\n    x_frac = x_clamped - x_rounded_back\n    x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)\n    return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip)\n\n\n# TODO: check that the ex2_emulation_2 produces the same SASS as the ptx version\n@dsl_user_op\ndef ex2_emulation_2(\n    x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None\n) -> Tuple[Float32, Float32]:\n    # We assume x <= 127.0 and y <= 127.0\n    fp32_round_int = float(2**23 + 2**22)\n    xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))\n    # We want to round down here, so that the fractional part is in [0, 1)\n    xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd=\"rm\")\n    # The integer floor of x & y are now in the last 8 bits of xy_rounded\n    # We want the next 2 ops to round to nearest even. The rounding mode is important.\n    xy_rounded_back = quack.activation.sub_packed_f32x2(\n        xy_rounded, (fp32_round_int, fp32_round_int)\n    )\n    xy_frac = quack.activation.sub_packed_f32x2(xy_clamped, xy_rounded_back)\n    xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)\n    x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip)\n    y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip)\n    return x_out, y_out\n\n\n@dsl_user_op\ndef e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:\n    out_f32x2 = llvm.inline_asm(\n        llvm.StructType.get_literal([T.f32(), T.f32()]),\n        [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()],\n        \"{\\n\\t\"\n        \".reg .f32 f1, f2, f3, f4, f5, f6, f7;\\n\\t\"\n        \".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\\n\\t\"\n        \".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\\n\\t\"\n        \"max.ftz.f32 f1, 
$2, 0fC2FE0000;\\n\\t\"\n        \"max.ftz.f32 f2, $3, 0fC2FE0000;\\n\\t\"\n        \"mov.b64 l1, {f1, f2};\\n\\t\"\n        \"mov.f32 f3, 0f4B400000;\\n\\t\"\n        \"mov.b64 l2, {f3, f3};\\n\\t\"\n        \"add.rm.ftz.f32x2 l7, l1, l2;\\n\\t\"\n        \"sub.rn.ftz.f32x2 l8, l7, l2;\\n\\t\"\n        \"sub.rn.ftz.f32x2 l9, l1, l8;\\n\\t\"\n        \"mov.f32 f7, 0f3D9DF09D;\\n\\t\"\n        \"mov.b64 l6, {f7, f7};\\n\\t\"\n        \"mov.f32 f6, 0f3E6906A4;\\n\\t\"\n        \"mov.b64 l5, {f6, f6};\\n\\t\"\n        \"mov.f32 f5, 0f3F31F519;\\n\\t\"\n        \"mov.b64 l4, {f5, f5};\\n\\t\"\n        \"mov.f32 f4, 0f3F800000;\\n\\t\"\n        \"mov.b64 l3, {f4, f4};\\n\\t\"\n        \"fma.rn.ftz.f32x2 l10, l9, l6, l5;\\n\\t\"\n        \"fma.rn.ftz.f32x2 l10, l10, l9, l4;\\n\\t\"\n        \"fma.rn.ftz.f32x2 l10, l10, l9, l3;\\n\\t\"\n        \"mov.b64 {r1, r2}, l7;\\n\\t\"\n        \"mov.b64 {r3, r4}, l10;\\n\\t\"\n        \"shl.b32 r5, r1, 23;\\n\\t\"\n        \"add.s32 r7, r5, r3;\\n\\t\"\n        \"shl.b32 r6, r2, 23;\\n\\t\"\n        \"add.s32 r8, r6, r4;\\n\\t\"\n        \"mov.b32 $0, r7;\\n\\t\"\n        \"mov.b32 $1, r8;\\n\\t\"\n        \"}\\n\",\n        \"=r,=r,f,f\",\n        has_side_effects=False,\n        is_align_stack=False,\n        asm_dialect=llvm.AsmDialect.AD_ATT,\n    )\n    out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))\n    out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))\n    return out0, out1\n\n\n@dsl_user_op\ndef domain_offset_aligned(\n    coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None\n) -> cute.Tensor:\n    assert isinstance(tensor.iterator, cute.Pointer)\n    # We assume that applying the offset does not change the pointer alignment\n    new_ptr = cute.make_ptr(\n        tensor.element_type,\n        elem_pointer(tensor, coord).toint(),\n        tensor.memspace,\n        assumed_align=tensor.iterator.alignment,\n    )\n    return cute.make_tensor(new_ptr, 
tensor.layout)\n\n\n@cute.jit\ndef scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:\n    \"\"\"Convert a scalar to a cute TensorSSA of shape (1,) and given dtype\"\"\"\n    vec = cute.make_fragment(1, dtype)\n    vec[0] = a\n    return vec.load()\n\n\ndef ssa_to_scalar(val):\n    \"\"\"Could inline but nice for reflecting the above api\"\"\"\n    return val[0]\n"
  },
  {
    "path": "flash_attn/flash_attn_interface.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nfrom typing import Optional, Sequence, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport os\nimport warnings\n\n# isort: off\n# We need to import the CUDA kernels after importing torch\nUSE_TRITON_ROCM = os.getenv(\"FLASH_ATTENTION_TRITON_AMD_ENABLE\", \"FALSE\") == \"TRUE\"\nif not USE_TRITON_ROCM and getattr(torch.version, 'hip', None) is not None:\n    try:\n        import flash_attn_2_cuda\n    except ImportError:\n        warnings.warn(\"flash_attn_2_cuda (which has ROCm/HIP kernels) not found, falling back to Triton implementation\")\n        USE_TRITON_ROCM = True\n\nif USE_TRITON_ROCM:\n    from aiter.ops.triton._triton_kernels.flash_attn_triton_amd import flash_attn_2 as flash_attn_gpu\nelse:\n    import flash_attn_2_cuda as flash_attn_gpu\n\n# isort: on\n\ndef maybe_contiguous(x):\n    return x.contiguous() if x is not None and x.stride(-1) != 1 else x\n\n\ndef _get_block_size_n(device, head_dim, is_dropout, is_causal):\n    # This should match the block sizes in the CUDA kernel\n    assert head_dim <= 256\n    major, minor = torch.cuda.get_device_capability(device)\n    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)\n    is_sm80 = major == 8 and minor == 0\n    is_sm90 = major == 9 and minor == 0\n    if head_dim <= 32:\n        return 128\n    if head_dim <= 64:\n        return 128 if not is_dropout else 64\n    elif head_dim <= 96:\n        return 64\n    elif head_dim <= 128:\n        if is_sm8x:\n            return 64 if (not is_dropout and is_causal) else 32\n        else:\n            return 64 if not is_dropout else 32\n    elif head_dim <= 192:\n        return 64\n    elif head_dim <= 224:\n        return 64\n    elif head_dim <= 256:\n        return 64\n\n\ndef round_multiple(x, m):\n    return (x + m - 1) // m * m\n\n\n# torch.compile() support is only enabled for pytorch >= 2.4\n# The reason for this is that we are using the new custom_op and 
register_fake\n# APIs, which support inplace modification of inputs in the function itself\nif torch.__version__ >= \"2.4.0\":\n    _torch_custom_op_wrapper = torch.library.custom_op\n    _torch_register_fake_wrapper = torch.library.register_fake\nelse:\n    def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):\n        def wrap(func):\n            return func\n        if fn is None:\n            return wrap\n        return fn\n    def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):\n        def wrap(func):\n            return func\n        if fn is None:\n            return wrap\n        return fn\n    _torch_custom_op_wrapper = noop_custom_op_wrapper\n    _torch_register_fake_wrapper = noop_register_fake_wrapper\n\n\n@_torch_custom_op_wrapper(\"flash_attn::_flash_attn_forward\", mutates_args=(), device_types=\"cuda\")\ndef _flash_attn_forward(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    dropout_p: float,\n    softmax_scale: float,\n    causal: bool,\n    window_size_left: int,\n    window_size_right: int,\n    softcap: float,\n    alibi_slopes: Optional[torch.Tensor],\n    return_softmax: bool\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]\n    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(\n        q,\n        k,\n        v,\n        None,\n        alibi_slopes,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size_left,\n        window_size_right,\n        softcap,\n        return_softmax,\n        None,\n    )\n    return out, softmax_lse, S_dmask, rng_state\n\n\n@_torch_register_fake_wrapper(\"flash_attn::_flash_attn_forward\")\ndef _flash_attn_forward_fake(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    dropout_p: float,\n    softmax_scale: float,\n    causal: bool,\n    window_size_left: int,\n    window_size_right: int,\n    
softcap: float,\n    alibi_slopes: Optional[torch.Tensor],\n    return_softmax: bool\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]\n    batch_size, seqlen_q, num_heads, head_size = q.shape\n    seqlen_k = k.shape[1]\n    out = torch.empty_like(q)\n    softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)\n    p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)\n    if return_softmax:\n        if torch.cuda.is_available() and torch.version.hip:\n            p = torch.empty((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout)\n        else:\n            p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)\n    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)\n\n    return out, softmax_lse, p, rng_state\n\n\nif torch.__version__ >= \"2.4.0\":\n    _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward\nelse:\n    _wrapped_flash_attn_forward = _flash_attn_forward\n\n\n@_torch_custom_op_wrapper(\"flash_attn::_flash_attn_varlen_forward\", mutates_args=(), device_types=\"cuda\")\ndef _flash_attn_varlen_forward(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    dropout_p: float,\n    softmax_scale: float,\n    causal: bool,\n    window_size_left: int = -1,\n    window_size_right: int = -1,\n    softcap: float = 0.0,\n    alibi_slopes: Optional[torch.Tensor] = None,\n    return_softmax: bool = False,\n    block_table: Optional[torch.Tensor] = None,\n    leftpad_k: Optional[torch.Tensor] = None,\n    seqused_k: Optional[torch.Tensor] = None,\n    zero_tensors: bool = False,\n) -> Tuple[torch.Tensor, torch.Tensor, 
torch.Tensor, torch.Tensor]:\n    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]\n    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(\n        q,\n        k,\n        v,\n        None,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        seqused_k,\n        leftpad_k,\n        block_table,\n        alibi_slopes,\n        max_seqlen_q,\n        max_seqlen_k,\n        dropout_p,\n        softmax_scale,\n        zero_tensors,\n        causal,\n        window_size_left,\n        window_size_right,\n        softcap,\n        return_softmax,\n        None,\n    )\n    # if out.isnan().any() or softmax_lse.isnan().any():\n    #     breakpoint()\n    return out, softmax_lse, S_dmask, rng_state\n\n\n@_torch_register_fake_wrapper(\"flash_attn::_flash_attn_varlen_forward\")\ndef _flash_attn_varlen_forward_fake(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    dropout_p: float,\n    softmax_scale: float,\n    causal: bool,\n    window_size_left: int = -1,\n    window_size_right: int = -1,\n    softcap: float = 0.0,\n    alibi_slopes: Optional[torch.Tensor] = None,\n    return_softmax: bool = False,\n    block_table: Optional[torch.Tensor] = None,\n    leftpad_k: Optional[torch.Tensor] = None,\n    seqused_k: Optional[torch.Tensor] = None,\n    zero_tensors: bool = False,\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]\n    paged_kv = block_table is not None\n    batch_size = cu_seqlens_q.numel() - 1\n    total_q, num_heads, _ = q.shape\n    \n    out = torch.empty_like(q)\n    softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)\n    p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)\n    if return_softmax:\n        if torch.cuda.is_available() and torch.version.hip:\n        
    p = torch.empty((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout)\n        else:\n            p = torch.empty((batch_size, num_heads, round_multiple(max_seqlen_q, 128), round_multiple(max_seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)\n    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)\n    return out, softmax_lse, p, rng_state\n\n\nif torch.__version__ >= \"2.4.0\":\n    _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward\nelse:\n    _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward\n\n\n@_torch_custom_op_wrapper(\"flash_attn::_flash_attn_backward\", mutates_args=(\"dq\", \"dk\", \"dv\"), device_types=\"cuda\")\ndef _flash_attn_backward(\n    dout: torch.Tensor,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    out: torch.Tensor,\n    softmax_lse: torch.Tensor,\n    dq: Optional[torch.Tensor],\n    dk: Optional[torch.Tensor],\n    dv: Optional[torch.Tensor],\n    dropout_p: float,\n    softmax_scale: float,\n    causal: bool,\n    window_size_left: int,\n    window_size_right: int,\n    softcap: float,\n    alibi_slopes: Optional[torch.Tensor],\n    deterministic: bool,\n    rng_state: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    # dq, dk, dv are allocated by us so they should already be contiguous\n    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]\n    (\n        dq,\n        dk,\n        dv,\n        softmax_d,\n    ) = flash_attn_gpu.bwd(\n        dout,\n        q,\n        k,\n        v,\n        out,\n        softmax_lse,\n        dq,\n        dk,\n        dv,\n        alibi_slopes,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size_left,\n        window_size_right,\n        softcap,\n        deterministic,\n        None,\n        rng_state,\n    )\n    return 
softmax_d\n\n\n@_torch_register_fake_wrapper(\"flash_attn::_flash_attn_backward\")\ndef _flash_attn_backward_fake(\n    dout: torch.Tensor,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    out: torch.Tensor,\n    softmax_lse: torch.Tensor,\n    dq: Optional[torch.Tensor],\n    dk: Optional[torch.Tensor],\n    dv: Optional[torch.Tensor],\n    dropout_p: float,\n    softmax_scale: float,\n    causal: bool,\n    window_size_left: int,\n    window_size_right: int,\n    softcap: float,\n    alibi_slopes: Optional[torch.Tensor],\n    deterministic: bool,\n    rng_state: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]\n    if dq is None:\n        dq = torch.empty_like(q)\n    if dk is None:\n        dk = torch.empty_like(k)\n    if dv is None:\n        dv = torch.empty_like(v)\n    batch_size, seqlen_q, num_heads, _ = q.shape\n    if torch.cuda.is_available() and torch.version.hip:\n        softmax_d = torch.empty((batch_size, num_heads, seqlen_q), device=q.device, dtype=torch.float32)\n    else:\n        softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)\n    \n    return softmax_d\n\n\nif torch.__version__ >= \"2.4.0\":\n    _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward\nelse:\n    _wrapped_flash_attn_backward = _flash_attn_backward\n\n\n@_torch_custom_op_wrapper(\"flash_attn::_flash_attn_varlen_backward\", mutates_args=(\"dq\", \"dk\", \"dv\"), device_types=\"cuda\")\ndef _flash_attn_varlen_backward(\n    dout: torch.Tensor,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    out: torch.Tensor,\n    softmax_lse: torch.Tensor,\n    dq: Optional[torch.Tensor],\n    dk: Optional[torch.Tensor],\n    dv: Optional[torch.Tensor],\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    dropout_p: 
float,\n    softmax_scale: float,\n    causal: bool,\n    window_size_left: int,\n    window_size_right: int,\n    softcap: float,\n    alibi_slopes: Optional[torch.Tensor],\n    deterministic: bool,\n    rng_state: Optional[torch.Tensor] = None,\n    zero_tensors: bool = False,\n) -> torch.Tensor:\n    # dq, dk, dv are allocated by us so they should already be contiguous\n    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]\n    (\n        dq,\n        dk,\n        dv,\n        softmax_d,\n    ) = flash_attn_gpu.varlen_bwd(\n        dout,\n        q,\n        k,\n        v,\n        out,\n        softmax_lse,\n        dq,\n        dk,\n        dv,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        alibi_slopes,\n        max_seqlen_q,\n        max_seqlen_k,\n        dropout_p,\n        softmax_scale,\n        zero_tensors,\n        causal,\n        window_size_left,\n        window_size_right,\n        softcap,\n        deterministic,\n        None,\n        rng_state,\n    )\n    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():\n    #     breakpoint()\n    return softmax_d\n\n\n@_torch_register_fake_wrapper(\"flash_attn::_flash_attn_varlen_backward\")\ndef _flash_attn_varlen_backward_fake(\n    dout: torch.Tensor,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    out: torch.Tensor,\n    softmax_lse: torch.Tensor,\n    dq: Optional[torch.Tensor],\n    dk: Optional[torch.Tensor],\n    dv: Optional[torch.Tensor],\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    dropout_p: float,\n    softmax_scale: float,\n    causal: bool,\n    window_size_left: int,\n    window_size_right: int,\n    softcap: float,\n    alibi_slopes: Optional[torch.Tensor],\n    deterministic: bool,\n    rng_state: Optional[torch.Tensor] = None,\n    zero_tensors: bool = False,\n) -> torch.Tensor:\n    dout, q, k, v, out = 
[maybe_contiguous(x) for x in (dout, q, k, v, out)]\n    batch_size = cu_seqlens_q.numel() - 1\n    total_q, num_heads, _ = q.shape\n\n    if dq is None:\n        dq = torch.empty_like(q)\n    if dk is None:\n        dk = torch.empty_like(k)\n    if dv is None:\n        dv = torch.empty_like(v)\n    if torch.cuda.is_available() and torch.version.hip:\n        softmax_d = torch.empty((num_heads, total_q), device=q.device, dtype=torch.float32)\n    else:\n        softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)\n    \n    return softmax_d\n\n\nif torch.__version__ >= \"2.4.0\":\n    _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward\nelse:\n    _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward\n\n\nclass FlashAttnQKVPackedFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        qkv,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size,\n        softcap,\n        alibi_slopes,\n        deterministic,\n        return_softmax,\n        is_grad_enabled,\n    ):\n        is_grad = is_grad_enabled and qkv.requires_grad\n        if softmax_scale is None:\n            softmax_scale = qkv.shape[-1] ** (-0.5)\n        q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()\n        head_size_og = q.size(3)\n        if head_size_og % 8 != 0:\n            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])\n            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])\n            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])\n        out_padded, softmax_lse, S_dmask, rng_state =  _wrapped_flash_attn_forward(\n            q,\n            k,\n            v,\n            dropout_p,\n            softmax_scale,\n            causal=causal,\n            window_size_left=window_size[0],\n            window_size_right=window_size[1],\n            softcap=softcap,\n          
  alibi_slopes=alibi_slopes,\n            return_softmax=return_softmax and dropout_p > 0,\n        )\n        if is_grad:\n            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)\n            ctx.dropout_p = dropout_p\n            ctx.softmax_scale = softmax_scale\n            ctx.causal = causal\n            ctx.window_size = window_size\n            ctx.softcap = softcap\n            ctx.alibi_slopes = alibi_slopes\n            ctx.deterministic = deterministic\n        out = out_padded[..., :head_size_og]\n        return out if not return_softmax else (out, softmax_lse, S_dmask)\n\n    @staticmethod\n    def backward(ctx, dout, *args):\n        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors\n        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])\n        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)\n        head_size_og = dout.size(3)\n        dout_padded = dout\n        if head_size_og % 8 != 0:\n            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])\n        _wrapped_flash_attn_backward(\n            dout_padded,\n            q,\n            k,\n            v,\n            out,\n            softmax_lse,\n            dqkv[:, :, 0],\n            dqkv[:, :, 1],\n            dqkv[:, :, 2],\n            ctx.dropout_p,\n            ctx.softmax_scale,\n            ctx.causal,\n            ctx.window_size[0],\n            ctx.window_size[1],\n            ctx.softcap,\n            ctx.alibi_slopes,\n            ctx.deterministic,\n            rng_state=rng_state,\n        )\n        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension\n        return dqkv, None, None, None, None, None, None, None, None, None\n\n\nclass FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        qkv,\n        cu_seqlens,\n        max_seqlen,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size,\n        
softcap,\n        alibi_slopes,\n        deterministic,\n        return_softmax,\n        is_grad_enabled,\n    ):\n        is_grad = is_grad_enabled and qkv.requires_grad\n        if softmax_scale is None:\n            softmax_scale = qkv.shape[-1] ** (-0.5)\n        q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()\n        head_size_og = q.size(2)\n        if head_size_og % 8 != 0:\n            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])\n            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])\n            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])\n        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(\n            q,\n            k,\n            v,\n            cu_seqlens,\n            cu_seqlens,\n            max_seqlen,\n            max_seqlen,\n            dropout_p,\n            softmax_scale,\n            causal=causal,\n            window_size_left=window_size[0],\n            window_size_right=window_size[1],\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            return_softmax=return_softmax and dropout_p > 0,\n            block_table=None,\n        )\n        if is_grad:\n            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)\n            ctx.dropout_p = dropout_p\n            ctx.max_seqlen = max_seqlen\n            ctx.softmax_scale = softmax_scale\n            ctx.causal = causal\n            ctx.window_size = window_size\n            ctx.softcap = softcap\n            ctx.alibi_slopes = alibi_slopes\n            ctx.deterministic = deterministic\n        out = out_padded[..., :head_size_og]\n        return out if not return_softmax else (out, softmax_lse, S_dmask)\n\n    @staticmethod\n    def backward(ctx, dout, *args):\n        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors\n        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])\n        dqkv = torch.empty(qkv_shape, 
dtype=q.dtype, device=q.device)\n        head_size_og = dout.size(2)\n        dout_padded = dout\n        if head_size_og % 8 != 0:\n            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])\n        _wrapped_flash_attn_varlen_backward(\n            dout_padded,\n            q,\n            k,\n            v,\n            out,\n            softmax_lse,\n            dqkv[:, 0],\n            dqkv[:, 1],\n            dqkv[:, 2],\n            cu_seqlens,\n            cu_seqlens,\n            ctx.max_seqlen,\n            ctx.max_seqlen,\n            ctx.dropout_p,\n            ctx.softmax_scale,\n            ctx.causal,\n            ctx.window_size[0],\n            ctx.window_size[1],\n            ctx.softcap,\n            ctx.alibi_slopes,\n            ctx.deterministic,\n            rng_state=rng_state,\n        )\n        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension\n        return dqkv, None, None, None, None, None, None, None, None, None, None, None\n\n\nclass FlashAttnKVPackedFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        q,\n        kv,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size,\n        softcap,\n        alibi_slopes,\n        deterministic,\n        return_softmax,\n        is_grad_enabled,\n    ):\n        is_grad = is_grad_enabled and any(\n            x.requires_grad for x in [q, kv]\n        )\n        if softmax_scale is None:\n            softmax_scale = q.shape[-1] ** (-0.5)\n        k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach()\n        head_size_og = q.size(3)\n        if head_size_og % 8 != 0:\n            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])\n            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])\n            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])\n        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(\n            q,\n       
     k,\n            v,\n            dropout_p,\n            softmax_scale,\n            causal=causal,\n            window_size_left=window_size[0],\n            window_size_right=window_size[1],\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            return_softmax=return_softmax and dropout_p > 0,\n        )\n        if is_grad:\n            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)\n            ctx.dropout_p = dropout_p\n            ctx.softmax_scale = softmax_scale\n            ctx.causal = causal\n            ctx.window_size = window_size\n            ctx.softcap = softcap\n            ctx.alibi_slopes = alibi_slopes\n            ctx.deterministic = deterministic\n        out = out_padded[..., :head_size_og]\n        return out if not return_softmax else (out, softmax_lse, S_dmask)\n\n    @staticmethod\n    def backward(ctx, dout, *args):\n        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors\n        dq = torch.empty_like(q)\n        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])\n        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)\n        head_size_og = dout.size(3)\n        dout_padded = dout\n        if head_size_og % 8 != 0:\n            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])\n        _wrapped_flash_attn_backward(\n            dout_padded,\n            q,\n            k,\n            v,\n            out,\n            softmax_lse,\n            dq,\n            dkv[:, :, 0],\n            dkv[:, :, 1],\n            ctx.dropout_p,\n            ctx.softmax_scale,\n            ctx.causal,\n            ctx.window_size[0],\n            ctx.window_size[1],\n            ctx.softcap,\n            ctx.alibi_slopes,\n            ctx.deterministic,\n            rng_state=rng_state,\n        )\n        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension\n        dkv = dkv[..., : dout.shape[-1]]\n        return dq, dkv, None, None, None, 
None, None, None, None, None, None\n\n\nclass FlashAttnVarlenKVPackedFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        q,\n        kv,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size,\n        softcap,\n        alibi_slopes,\n        deterministic,\n        return_softmax,\n        is_grad_enabled,\n    ):\n        is_grad = is_grad_enabled and any(\n            x.requires_grad for x in [q, kv]\n        )\n        if softmax_scale is None:\n            softmax_scale = q.shape[-1] ** (-0.5)\n        k, v = kv[:, 0].detach(), kv[:, 1].detach()\n        head_size_og = q.size(2)\n        if head_size_og % 8 != 0:\n            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])\n            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])\n            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])\n        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(\n            q,\n            k,\n            v,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            dropout_p,\n            softmax_scale,\n            causal=causal,\n            window_size_left=window_size[0],\n            window_size_right=window_size[1],\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            return_softmax=return_softmax and dropout_p > 0,\n            block_table=None,\n        )\n        if is_grad:\n            ctx.save_for_backward(\n                q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state\n            )\n            ctx.dropout_p = dropout_p\n            ctx.max_seqlen_q = max_seqlen_q\n            ctx.max_seqlen_k = max_seqlen_k\n            ctx.softmax_scale = softmax_scale\n            ctx.causal = causal\n            ctx.window_size = 
window_size\n            ctx.softcap = softcap\n            ctx.alibi_slopes = alibi_slopes\n            ctx.deterministic = deterministic\n        out = out_padded[..., :head_size_og]\n        return out if not return_softmax else (out, softmax_lse, S_dmask)\n\n    @staticmethod\n    def backward(ctx, dout, *args):\n        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors\n        dq = torch.empty_like(q)\n        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])\n        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)\n        head_size_og = dout.size(2)\n        dout_padded = dout\n        if head_size_og % 8 != 0:\n            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])\n        _wrapped_flash_attn_varlen_backward(\n            dout_padded,\n            q,\n            k,\n            v,\n            out,\n            softmax_lse,\n            dq,\n            dkv[:, 0],\n            dkv[:, 1],\n            cu_seqlens_q,\n            cu_seqlens_k,\n            ctx.max_seqlen_q,\n            ctx.max_seqlen_k,\n            ctx.dropout_p,\n            ctx.softmax_scale,\n            ctx.causal,\n            ctx.window_size[0],\n            ctx.window_size[1],\n            ctx.softcap,\n            ctx.alibi_slopes,\n            ctx.deterministic,\n            rng_state=rng_state,\n        )\n        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension\n        dkv = dkv[..., : dout.shape[-1]]\n        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None\n\n\nclass FlashAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        q,\n        k,\n        v,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size,\n        softcap,\n        alibi_slopes,\n        deterministic,\n        return_softmax,\n        is_grad_enabled,\n    ):\n        is_grad = is_grad_enabled and 
any(\n            x.requires_grad for x in [q, k, v]\n        )\n        if softmax_scale is None:\n            softmax_scale = q.shape[-1] ** (-0.5)\n        head_size_og = q.size(3)\n        if head_size_og % 8 != 0:\n            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])\n            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])\n            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])\n        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(\n            q,\n            k,\n            v,\n            dropout_p,\n            softmax_scale,\n            causal=causal,\n            window_size_left=window_size[0],\n            window_size_right=window_size[1],\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            return_softmax=return_softmax and dropout_p > 0,\n        )\n        if is_grad:\n            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)\n            ctx.dropout_p = dropout_p\n            ctx.softmax_scale = softmax_scale\n            ctx.causal = causal\n            ctx.window_size = window_size\n            ctx.softcap = softcap\n            ctx.alibi_slopes = alibi_slopes\n            ctx.deterministic = deterministic\n        out = out_padded[..., :head_size_og]\n        return out if not return_softmax else (out, softmax_lse, S_dmask)\n\n    @staticmethod\n    def backward(ctx, dout, *args):\n        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors\n        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)\n        head_size_og = dout.size(3)\n        dout_padded = dout\n        if head_size_og % 8 != 0:\n            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])\n        _wrapped_flash_attn_backward(\n            dout_padded,\n            q,\n            k,\n            v,\n            out,\n            softmax_lse,\n            dq,\n            dk,\n            dv,\n        
    ctx.dropout_p,\n            ctx.softmax_scale,\n            ctx.causal,\n            ctx.window_size[0],\n            ctx.window_size[1],\n            ctx.softcap,\n            ctx.alibi_slopes,\n            ctx.deterministic,\n            rng_state=rng_state,\n        )\n        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension\n        dk = dk[..., : dout.shape[-1]]\n        dv = dv[..., : dout.shape[-1]]\n        return dq, dk, dv, None, None, None, None, None, None, None, None, None\n\n\nclass FlashAttnVarlenFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        q,\n        k,\n        v,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size,\n        softcap,\n        alibi_slopes,\n        deterministic,\n        return_softmax,\n        block_table,\n        is_grad_enabled,\n    ):\n        is_grad = is_grad_enabled and any(\n            x.requires_grad for x in [q, k, v]\n        )\n        if softmax_scale is None:\n            softmax_scale = q.shape[-1] ** (-0.5)\n        head_size_og = q.size(2)\n        if head_size_og % 8 != 0:\n            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])\n            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])\n            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])\n        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(\n            q,\n            k,\n            v,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            dropout_p,\n            softmax_scale,\n            causal=causal,\n            window_size_left=window_size[0],\n            window_size_right=window_size[1],\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            return_softmax=return_softmax and 
dropout_p > 0,\n            block_table=block_table,\n        )\n        if is_grad:\n            ctx.save_for_backward(\n                q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state\n            )\n            ctx.dropout_p = dropout_p\n            ctx.max_seqlen_q = max_seqlen_q\n            ctx.max_seqlen_k = max_seqlen_k\n            ctx.softmax_scale = softmax_scale\n            ctx.causal = causal\n            ctx.window_size = window_size\n            ctx.softcap = softcap\n            ctx.alibi_slopes = alibi_slopes\n            ctx.deterministic = deterministic\n\n        out = out_padded[..., :head_size_og]\n        return out if not return_softmax else (out, softmax_lse, S_dmask)\n\n    @staticmethod\n    def backward(ctx, dout, *args):\n        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors\n        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)\n        head_size_og = dout.size(2)\n        dout_padded = dout\n        if head_size_og % 8 != 0:\n            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])\n        _wrapped_flash_attn_varlen_backward(\n            dout_padded,\n            q,\n            k,\n            v,\n            out,\n            softmax_lse,\n            dq,\n            dk,\n            dv,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            ctx.max_seqlen_q,\n            ctx.max_seqlen_k,\n            ctx.dropout_p,\n            ctx.softmax_scale,\n            ctx.causal,\n            ctx.window_size[0],\n            ctx.window_size[1],\n            ctx.softcap,\n            ctx.alibi_slopes,\n            ctx.deterministic,\n            rng_state=rng_state,\n        )\n        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension\n        dk = dk[..., : dout.shape[-1]]\n        dv = dv[..., : dout.shape[-1]]\n        return dq, dk, dv, None, None, None, None, None, None, None, 
None, None, None, None, None, None, None\n\n\ndef flash_attn_qkvpacked_func(\n    qkv,\n    dropout_p=0.0,\n    softmax_scale=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite context window\n    softcap=0.0,  # <=0.0 means deactivate\n    alibi_slopes=None,\n    deterministic=False,\n    return_attn_probs=False,\n):\n    \"\"\"dropout_p should be set to 0.0 during evaluation\n    If Q, K, V are already stacked into 1 tensor, this function will be faster than\n    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation\n    of the gradients of Q, K, V.\n    For multi-query and grouped-query attention (MQA/GQA), please see\n    flash_attn_kvpacked_func and flash_attn_func.\n\n    If window_size != (-1, -1), implements sliding window local attention. Query at position i\n    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.\n\n    Arguments:\n        qkv: (batch_size, seqlen, 3, nheads, headdim)\n        dropout_p: float. Dropout probability.\n        softmax_scale: float. The scaling of QK^T before applying softmax.\n            Default to 1 / sqrt(headdim).\n        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n        window_size: (left, right). If not (-1, -1), implements sliding window local attention.\n        softcap: float. Anything > 0 activates softcapping attention.\n        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to\n            the attention score of query i and key j.\n        deterministic: bool. Whether to use the deterministic implementation of the backward pass,\n            which is slightly slower and uses more memory. The forward pass is always deterministic.\n        return_attn_probs: bool. Whether to return the attention probabilities. This option is for\n           testing only. 
The returned probabilities are not guaranteed to be correct\n           (they might not have the right scaling).\n    Return:\n        out: (batch_size, seqlen, nheads, headdim).\n        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The\n            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax\n            normalization factor).\n        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).\n            The output of softmax (possibly with different scaling). It also encodes the dropout\n            pattern (negative means that location was dropped, nonnegative means it was kept).\n    \"\"\"\n    return FlashAttnQKVPackedFunc.apply(\n        qkv,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size,\n        softcap,\n        alibi_slopes,\n        deterministic,\n        return_attn_probs,\n        torch.is_grad_enabled(),\n    )\n\n\ndef flash_attn_kvpacked_func(\n    q,\n    kv,\n    dropout_p=0.0,\n    softmax_scale=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite context window\n    softcap=0.0,  # 0.0 means deactivated\n    alibi_slopes=None,\n    deterministic=False,\n    return_attn_probs=False,\n):\n    \"\"\"dropout_p should be set to 0.0 during evaluation\n    If K, V are already stacked into 1 tensor, this function will be faster than\n    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation\n    of the gradients of K, V.\n    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads\n    than Q. 
Note that the number of heads in Q must be divisible by the number of heads in KV.\n    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head\n    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.\n\n    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.\n    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:\n        1 1 1 1 0\n        1 1 1 1 1\n    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:\n        0 0\n        0 0\n        0 0\n        1 0\n        1 1\n    If the row of the mask is all zero, the output will be zero.\n\n    If window_size != (-1, -1), implements sliding window local attention. Query at position i\n    will only attend to keys between\n    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.\n\n    Arguments:\n        q: (batch_size, seqlen, nheads, headdim)\n        kv: (batch_size, seqlen, 2, nheads_k, headdim)\n        dropout_p: float. Dropout probability.\n        softmax_scale: float. The scaling of QK^T before applying softmax.\n            Default to 1 / sqrt(headdim).\n        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n        window_size: (left, right). If not (-1, -1), implements sliding window local attention.\n        softcap: float. Anything > 0 activates softcapping attention.\n        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of\n            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)\n            is added to the attention score of query i and key j.\n        deterministic: bool. Whether to use the deterministic implementation of the backward pass,\n            which is slightly slower and uses more memory. The forward pass is always deterministic.\n        return_attn_probs: bool. Whether to return the attention probabilities. 
This option is for\n           testing only. The returned probabilities are not guaranteed to be correct\n           (they might not have the right scaling).\n    Return:\n        out: (batch_size, seqlen, nheads, headdim).\n        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The\n            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax\n            normalization factor).\n        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).\n            The output of softmax (possibly with different scaling). It also encodes the dropout\n            pattern (negative means that location was dropped, nonnegative means it was kept).\n    \"\"\"\n    return FlashAttnKVPackedFunc.apply(\n        q,\n        kv,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size,\n        softcap,\n        alibi_slopes,\n        deterministic,\n        return_attn_probs,\n        torch.is_grad_enabled(),\n    )\n\n\ndef flash_attn_func(\n    q,\n    k,\n    v,\n    dropout_p=0.0,\n    softmax_scale=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite context window\n    softcap=0.0, # 0.0 means deactivated\n    alibi_slopes=None,\n    deterministic=False,\n    return_attn_probs=False,\n):\n    \"\"\"dropout_p should be set to 0.0 during evaluation\n    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads\n    than Q. 
Note that the number of heads in Q must be divisible by the number of heads in KV.\n    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head\n    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.\n\n    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.\n    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:\n        1 1 1 1 0\n        1 1 1 1 1\n    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:\n        0 0\n        0 0\n        0 0\n        1 0\n        1 1\n    If the row of the mask is all zero, the output will be zero.\n\n    If window_size != (-1, -1), implements sliding window local attention. Query at position i\n    will only attend to keys between\n    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.\n\n    Arguments:\n        q: (batch_size, seqlen, nheads, headdim)\n        k: (batch_size, seqlen, nheads_k, headdim)\n        v: (batch_size, seqlen, nheads_k, headdim)\n        dropout_p: float. Dropout probability.\n        softmax_scale: float. The scaling of QK^T before applying softmax.\n            Default to 1 / sqrt(headdim).\n        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n        window_size: (left, right). If not (-1, -1), implements sliding window local attention.\n        softcap: float. Anything > 0 activates softcapping attention.\n        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of\n            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)\n            is added to the attention score of query i and key j.\n        deterministic: bool. Whether to use the deterministic implementation of the backward pass,\n            which is slightly slower and uses more memory. The forward pass is always deterministic.\n        return_attn_probs: bool. Whether to return the attention probabilities. This option is for\n           testing only. 
The returned probabilities are not guaranteed to be correct\n           (they might not have the right scaling).\n    Return:\n        out: (batch_size, seqlen, nheads, headdim).\n        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The\n            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax\n            normalization factor).\n        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).\n            The output of softmax (possibly with different scaling). It also encodes the dropout\n            pattern (negative means that location was dropped, nonnegative means it was kept).\n    \"\"\"\n    return FlashAttnFunc.apply(\n        q,\n        k,\n        v,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size,\n        softcap,\n        alibi_slopes,\n        deterministic,\n        return_attn_probs,\n        torch.is_grad_enabled(),\n    )\n\n\ndef flash_attn_varlen_qkvpacked_func(\n    qkv,\n    cu_seqlens,\n    max_seqlen,\n    dropout_p=0.0,\n    softmax_scale=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite context window\n    softcap=0.0, # 0.0 means deactivated\n    alibi_slopes=None,\n    deterministic=False,\n    return_attn_probs=False,\n):\n    \"\"\"dropout_p should be set to 0.0 during evaluation\n    If Q, K, V are already stacked into 1 tensor, this function will be faster than\n    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation\n    of the gradients of Q, K, V.\n    For multi-query and grouped-query attention (MQA/GQA), please see\n    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.\n\n    If window_size != (-1, -1), implements sliding window local attention. 
Query at position i\n    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.\n\n    Arguments:\n        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.\n        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths\n           of the sequences in the batch, used to index into qkv.\n        max_seqlen: int. Maximum sequence length in the batch.\n        dropout_p: float. Dropout probability.\n        softmax_scale: float. The scaling of QK^T before applying softmax.\n            Default to 1 / sqrt(headdim).\n        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n        window_size: (left, right). If not (-1, -1), implements sliding window local attention.\n        softcap: float. Anything > 0 activates softcapping attention.\n        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)\n            is added to the attention score of query i and key j.\n        deterministic: bool. Whether to use the deterministic implementation of the backward pass,\n            which is slightly slower and uses more memory. The forward pass is always deterministic.\n        return_attn_probs: bool. Whether to return the attention probabilities. This option is for\n           testing only. The returned probabilities are not guaranteed to be correct\n           (they might not have the right scaling).\n    Return:\n        out: (total, nheads, headdim).\n        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The\n            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax\n            normalization factor).\n        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).\n            The output of softmax (possibly with different scaling). 
It also encodes the dropout\n            pattern (negative means that location was dropped, nonnegative means it was kept).\n    \"\"\"\n    return FlashAttnVarlenQKVPackedFunc.apply(\n        qkv,\n        cu_seqlens,\n        max_seqlen,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size,\n        softcap,\n        alibi_slopes,\n        deterministic,\n        return_attn_probs,\n        torch.is_grad_enabled(),\n    )\n\n\ndef flash_attn_varlen_kvpacked_func(\n    q,\n    kv,\n    cu_seqlens_q,\n    cu_seqlens_k,\n    max_seqlen_q,\n    max_seqlen_k,\n    dropout_p=0.0,\n    softmax_scale=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite context window\n    softcap=0.0, # 0.0 means deactivated\n    alibi_slopes=None,\n    deterministic=False,\n    return_attn_probs=False,\n):\n    \"\"\"dropout_p should be set to 0.0 during evaluation\n    If K, V are already stacked into 1 tensor, this function will be faster than\n    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation\n    of the gradients of K, V.\n    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads\n    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.\n    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head\n    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.\n\n    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.\n    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:\n        1 1 1 1 0\n        1 1 1 1 1\n    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:\n        0 0\n        0 0\n        0 0\n        1 0\n        1 1\n    If the row of the mask is all zero, the output will be zero.\n\n    If window_size != (-1, -1), implements sliding window local attention. 
Query at position i\n    will only attend to keys between\n    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.\n\n    Arguments:\n        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.\n        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.\n        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths\n           of the sequences in the batch, used to index into q.\n        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths\n           of the sequences in the batch, used to index into kv.\n        max_seqlen_q: int. Maximum query sequence length in the batch.\n        max_seqlen_k: int. Maximum key sequence length in the batch.\n        dropout_p: float. Dropout probability.\n        softmax_scale: float. The scaling of QK^T before applying softmax.\n            Default to 1 / sqrt(headdim).\n        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n        window_size: (left, right). If not (-1, -1), implements sliding window local attention.\n        softcap: float. Anything > 0 activates softcapping attention.\n        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of\n            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)\n            is added to the attention score of query i and key j.\n        deterministic: bool. Whether to use the deterministic implementation of the backward pass,\n            which is slightly slower and uses more memory. The forward pass is always deterministic.\n        return_attn_probs: bool. Whether to return the attention probabilities. This option is for\n           testing only. 
The returned probabilities are not guaranteed to be correct\n           (they might not have the right scaling).\n    Return:\n        out: (total, nheads, headdim).\n        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The\n            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax\n            normalization factor).\n        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).\n            The output of softmax (possibly with different scaling). It also encodes the dropout\n            pattern (negative means that location was dropped, nonnegative means it was kept).\n    \"\"\"\n    return FlashAttnVarlenKVPackedFunc.apply(\n        q,\n        kv,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size,\n        softcap,\n        alibi_slopes,\n        deterministic,\n        return_attn_probs,\n        torch.is_grad_enabled(),\n    )\n\n\ndef flash_attn_varlen_func(\n    q,\n    k,\n    v,\n    cu_seqlens_q,\n    cu_seqlens_k,\n    max_seqlen_q,\n    max_seqlen_k,\n    dropout_p=0.0,\n    softmax_scale=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite context window\n    softcap=0.0, # 0.0 means deactivated\n    alibi_slopes=None,\n    deterministic=False,\n    return_attn_probs=False,\n    block_table=None,\n):\n    \"\"\"dropout_p should be set to 0.0 during evaluation\n    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads\n    than Q. 
Note that the number of heads in Q must be divisible by the number of heads in KV.\n    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head\n    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.\n\n    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.\n    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:\n        1 1 1 1 0\n        1 1 1 1 1\n    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:\n        0 0\n        0 0\n        0 0\n        1 0\n        1 1\n    If the row of the mask is all zero, the output will be zero.\n\n    If window_size != (-1, -1), implements sliding window local attention. Query at position i\n    will only attend to keys between\n    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.\n\n    Arguments:\n        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.\n        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.\n        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.\n        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths\n           of the sequences in the batch, used to index into q.\n        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths\n           of the sequences in the batch, used to index into k and v.\n        max_seqlen_q: int. Maximum query sequence length in the batch.\n        max_seqlen_k: int. Maximum key sequence length in the batch.\n        dropout_p: float. Dropout probability.\n        softmax_scale: float. The scaling of QK^T before applying softmax.\n            Default to 1 / sqrt(headdim).\n        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n        window_size: (left, right). 
If not (-1, -1), implements sliding window local attention.\n        softcap: float. Anything > 0 activates softcapping attention.\n        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of\n            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)\n            is added to the attention score of query i and key j.\n        deterministic: bool. Whether to use the deterministic implementation of the backward pass,\n            which is slightly slower and uses more memory. The forward pass is always deterministic.\n        return_attn_probs: bool. Whether to return the attention probabilities. This option is for\n           testing only. The returned probabilities are not guaranteed to be correct\n           (they might not have the right scaling).\n    Return:\n        out: (total, nheads, headdim).\n        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The\n            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax\n            normalization factor).\n        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).\n            The output of softmax (possibly with different scaling). 
It also encodes the dropout\n            pattern (negative means that location was dropped, nonnegative means it was kept).\n    \"\"\"\n    return FlashAttnVarlenFunc.apply(\n        q,\n        k,\n        v,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        dropout_p,\n        softmax_scale,\n        causal,\n        window_size,\n        softcap,\n        alibi_slopes,\n        deterministic,\n        return_attn_probs,\n        block_table,\n        torch.is_grad_enabled(),\n    )\n\n\ndef flash_attn_with_kvcache(\n    q,\n    k_cache,\n    v_cache,\n    k=None,\n    v=None,\n    rotary_cos=None,\n    rotary_sin=None,\n    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,\n    cache_batch_idx: Optional[torch.Tensor] = None,\n    cache_leftpad: Optional[torch.Tensor] = None,\n    block_table: Optional[torch.Tensor] = None,\n    softmax_scale=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite context window\n    softcap=0.0, # 0.0 means deactivated\n    rotary_interleaved=True,\n    alibi_slopes=None,\n    num_splits=0,\n    return_softmax_lse=False,\n):\n    \"\"\"\n    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from\n    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from\n    the previous step, and update them with the new keys/values from the current step, and do\n    attention with the updated cache, all in 1 kernel.\n\n    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.\n    For example, the KV cache could be pre-allocated with the max sequence length, and you can use\n    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.\n\n    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. 
The key @k will be\n    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.\n    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos\n    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.\n    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at\n    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).\n\n    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.\n\n    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads\n    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.\n    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head\n    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.\n\n    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.\n    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:\n        1 1 1 1 0\n        1 1 1 1 1\n    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:\n        0 0\n        0 0\n        0 0\n        1 0\n        1 1\n    If the row of the mask is all zero, the output will be zero.\n\n    If window_size != (-1, -1), implements sliding window local attention. Query at position i\n    will only attend to keys between\n    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.\n\n    Note: Does not support backward pass.\n\n    Arguments:\n        q: (batch_size, seqlen, nheads, headdim)\n        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,\n            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. 
paged KV cache)\n            page_block_size must be a multiple of 256.\n        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,\n            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)\n        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate\n            k with k_cache, starting at the indices specified by cache_seqlens.\n        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.\n        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding\n            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.\n        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.\n        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the\n            KV cache.\n        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.\n            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].\n            If the indices are not distinct, and k and v are provided, the values updated in the cache\n                 might come from any of the duplicate indices.\n        cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.\n        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.\n        softmax_scale: float. The scaling of QK^T before applying softmax.\n            Default to 1 / sqrt(headdim).\n        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n        window_size: (left, right). If not (-1, -1), implements sliding window local attention.\n        softcap: float. Anything > 0 activates softcapping attention.\n        rotary_interleaved: bool. 
Only applicable if rotary_cos and rotary_sin are passed in.\n            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,\n            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1\n            (i.e. GPT-NeoX style).\n        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of\n            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)\n            is added to the attention score of query i and key j.\n        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.\n           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic\n           to automatically determine the number of splits.\n           Don't change this unless you know what you are doing.\n        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.\n\n    Return:\n        out: (batch_size, seqlen, nheads, headdim).\n        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). 
The\n            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax\n            normalization factor).\n    \"\"\"\n    assert k_cache.stride(-1) == 1, \"k_cache must have contiguous last dimension\"\n    assert v_cache.stride(-1) == 1, \"v_cache must have contiguous last dimension\"\n    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]\n    if softmax_scale is None:\n        softmax_scale = q.shape[-1] ** (-0.5)\n    if cache_seqlens is not None and isinstance(cache_seqlens, int):\n        cache_seqlens = torch.full(\n            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device\n        )\n        cache_seqlens = maybe_contiguous(cache_seqlens)\n    cache_batch_idx = maybe_contiguous(cache_batch_idx)\n    block_table = maybe_contiguous(block_table)\n    out, softmax_lse = flash_attn_gpu.fwd_kvcache(\n        q,\n        k_cache,\n        v_cache,\n        k,\n        v,\n        cache_seqlens,\n        rotary_cos,\n        rotary_sin,\n        cache_batch_idx,\n        cache_leftpad,\n        block_table,\n        alibi_slopes,\n        None,\n        softmax_scale,\n        causal,\n        window_size[0],\n        window_size[1],\n        softcap,\n        rotary_interleaved,\n        num_splits,\n    )\n    return (out, softmax_lse) if return_softmax_lse else out\n"
  },
  {
    "path": "flash_attn/flash_attn_triton.py",
    "content": "\"\"\"\n*Experimental* implementation of FlashAttention in Triton.\nTested with triton==2.0.0.dev20221202.\nTriton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions\nother than 64:\nhttps://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207\nWe'll update this implementation with the new Triton backend once this is fixed.\n\nWe use the FlashAttention implementation from Phil Tillet as a starting point.\nhttps://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py\n\nChanges:\n- Implement both causal and non-causal attention.\n- Implement both self-attention and cross-attention.\n- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.\n- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.\n- Support attention bias.\n- Speed up the forward pass a bit, and only store the LSE instead of m and l.\n- Make the backward for d=128 much faster by reducing register spilling.\n- Optionally parallelize the backward pass across seqlen_k, to deal with the case of\nsmall batch size * nheads.\n\nCaution:\n- This is an *experimental* implementation. The forward pass should be quite robust but\nI'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).\n- This implementation has only been tested on A100.\n- If you plan to use headdim other than 64 and 128, you should test for race conditions\n(due to the Triton compiler), as done in tests/test_flash_attn.py\n\"test_flash_attn_triton_race_condition\". 
I've tested and fixed many race conditions\nfor different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident\nthat there are none left for other head dimensions.\n\nDifferences between this Triton version and the CUDA version:\n- Triton version doesn't support dropout.\n- Triton forward is generally faster than CUDA forward, while Triton backward is\ngenerally slower than CUDA backward. Overall Triton forward + backward is slightly slower\nthan CUDA forward + backward.\n- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).\n- Triton version supports attention bias, while CUDA version doesn't.\n\"\"\"\n\nimport math\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128\n# @triton.autotune(\n#     configs=[\n#         triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128}, num_warps=4, num_stages=1),\n#         # This config has a race condition when EVEN_M == False, disabling it for now.\n#         # triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 64}, num_warps=4, num_stages=1),\n#     ],\n#     key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']\n# )\n@triton.heuristics(\n    {\n        \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n        \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n        \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n    }\n)\n@triton.jit\ndef _fwd_kernel(\n    Q,\n    K,\n    V,\n    Bias,\n    Out,\n    Lse,\n    TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug\n    softmax_scale,\n    stride_qb,\n    stride_qh,\n    stride_qm,\n    stride_kb,\n    stride_kh,\n    stride_kn,\n    stride_vb,\n    stride_vh,\n    stride_vn,\n    stride_bb,\n    stride_bh,\n    stride_bm,\n    stride_ob,\n    stride_oh,\n    stride_om,\n    
nheads,\n    seqlen_q,\n    seqlen_k,\n    seqlen_q_rounded,\n    headdim,\n    CACHE_KEY_SEQLEN_Q,\n    CACHE_KEY_SEQLEN_K,\n    BIAS_TYPE: tl.constexpr,\n    IS_CAUSAL: tl.constexpr,\n    BLOCK_HEADDIM: tl.constexpr,\n    EVEN_M: tl.constexpr,\n    EVEN_N: tl.constexpr,\n    EVEN_HEADDIM: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    start_m = tl.program_id(0)\n    off_hb = tl.program_id(1)\n    off_b = off_hb // nheads\n    off_h = off_hb % nheads\n    # off_b = tl.program_id(1)\n    # off_h = tl.program_id(2)\n    # off_hb = off_b * nheads + off_h\n    # initialize offsets\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_n = tl.arange(0, BLOCK_N)\n    offs_d = tl.arange(0, BLOCK_HEADDIM)\n    # Initialize pointers to Q, K, V\n    # Adding parenthesis around indexing might use int32 math instead of int64 math?\n    # https://github.com/openai/triton/issues/741\n    # I'm seeing a tiny bit of difference (5-7us)\n    q_ptrs = (\n        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])\n    )\n    k_ptrs = (\n        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])\n    )\n    v_ptrs = (\n        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])\n    )\n    if BIAS_TYPE == \"vector\":\n        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n\n    elif BIAS_TYPE == \"matrix\":\n        b_ptrs = (\n            Bias\n            + off_b * stride_bb\n            + off_h * stride_bh\n            + (offs_m[:, None] * stride_bm + offs_n[None, :])\n        )\n    # initialize pointer to m and l\n    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m\n    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)\n    # load q: it will stay in SRAM 
throughout\n    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call\n    # tl.load(q_ptrs), we get the wrong output!\n    if EVEN_M & EVEN_N:\n        if EVEN_HEADDIM:\n            q = tl.load(q_ptrs)\n        else:\n            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n    else:\n        if EVEN_HEADDIM:\n            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)\n        else:\n            q = tl.load(\n                q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0\n            )\n    # loop over k, v and update accumulator\n    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)\n    for start_n in range(0, end_n, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        # -- compute qk ----\n        if EVEN_N & EVEN_M:  # If we just do \"if EVEN_N\", there seems to be some race condition\n            if EVEN_HEADDIM:\n                k = tl.load(k_ptrs + start_n * stride_kn)\n            else:\n                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)\n        else:\n            if EVEN_HEADDIM:\n                k = tl.load(\n                    k_ptrs + start_n * stride_kn,\n                    mask=(start_n + offs_n)[:, None] < seqlen_k,\n                    other=0.0,\n                )\n            else:\n                k = tl.load(\n                    k_ptrs + start_n * stride_kn,\n                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n                    other=0.0,\n                )\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n        qk += tl.dot(q, k, trans_b=True)\n        # Trying to combine the two masks seem to make the result wrong\n        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong\n            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, 
float(\"-inf\"))\n        if IS_CAUSAL:\n            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float(\"-inf\"))\n        if BIAS_TYPE != \"none\":\n            if BIAS_TYPE == \"vector\":\n                if EVEN_N:\n                    bias = tl.load(b_ptrs + start_n).to(tl.float32)\n                else:\n                    bias = tl.load(\n                        b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0\n                    ).to(tl.float32)\n                bias = bias[None, :]\n            elif BIAS_TYPE == \"matrix\":\n                if EVEN_M & EVEN_N:\n                    bias = tl.load(b_ptrs + start_n).to(tl.float32)\n                else:\n                    bias = tl.load(\n                        b_ptrs + start_n,\n                        mask=(offs_m[:, None] < seqlen_q)\n                        & ((start_n + offs_n)[None, :] < seqlen_k),\n                        other=0.0,\n                    ).to(tl.float32)\n            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler\n            # can then fuse the mult and add into an fma instruction. 
But if we have bias we need\n            # to multiply with softmax_scale here.\n            qk = qk * softmax_scale + bias\n            m_ij = tl.maximum(tl.max(qk, 1), lse_i)\n            p = tl.exp(qk - m_ij[:, None])\n        else:\n            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)\n            p = tl.exp(qk * softmax_scale - m_ij[:, None])\n        l_ij = tl.sum(p, 1)\n\n        # scale acc_o\n        acc_o_scale = tl.exp(m_i - m_ij)\n\n        # # -- update output accumulator --\n        # BUG: have to store and immediately load\n        tl.store(t_ptrs, acc_o_scale)\n        acc_o_scale = tl.load(t_ptrs)\n        acc_o = acc_o * acc_o_scale[:, None]\n        # update acc_o\n        if EVEN_N & EVEN_M:  # If we just do \"if EVEN_N\", there seems to be some race condition\n            if EVEN_HEADDIM:\n                v = tl.load(v_ptrs + start_n * stride_vn)\n            else:\n                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)\n        else:\n            if EVEN_HEADDIM:\n                v = tl.load(\n                    v_ptrs + start_n * stride_vn,\n                    mask=(start_n + offs_n)[:, None] < seqlen_k,\n                    other=0.0,\n                )\n            else:\n                v = tl.load(\n                    v_ptrs + start_n * stride_vn,\n                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),\n                    other=0.0,\n                )\n        p = p.to(v.dtype)\n        acc_o += tl.dot(p, v)\n\n        # -- update statistics\n        m_i = m_ij\n        l_i_new = tl.exp(lse_i - m_ij) + l_ij\n        lse_i = m_ij + tl.log(l_i_new)\n\n    o_scale = tl.exp(m_i - lse_i)\n    # BUG: have to store and immediately load\n    tl.store(t_ptrs, o_scale)\n    o_scale = tl.load(t_ptrs)\n    acc_o = acc_o * o_scale[:, None]\n    # rematerialize offsets to save registers\n    start_m = tl.program_id(0)\n    offs_m = start_m * 
BLOCK_M + tl.arange(0, BLOCK_M)\n    # write back l and m\n    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m\n    tl.store(lse_ptrs, lse_i)\n    # initialize pointers to output\n    offs_d = tl.arange(0, BLOCK_HEADDIM)\n    out_ptrs = (\n        Out\n        + off_b * stride_ob\n        + off_h * stride_oh\n        + (offs_m[:, None] * stride_om + offs_d[None, :])\n    )\n    if EVEN_M:\n        if EVEN_HEADDIM:\n            tl.store(out_ptrs, acc_o)\n        else:\n            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)\n    else:\n        if EVEN_HEADDIM:\n            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)\n        else:\n            tl.store(\n                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)\n            )\n\n\n@triton.jit\ndef _bwd_preprocess_do_o_dot(\n    Out,\n    DO,\n    Delta,\n    stride_ob,\n    stride_oh,\n    stride_om,\n    stride_dob,\n    stride_doh,\n    stride_dom,\n    nheads,\n    seqlen_q,\n    seqlen_q_rounded,\n    headdim,\n    BLOCK_M: tl.constexpr,\n    BLOCK_HEADDIM: tl.constexpr,\n):\n    start_m = tl.program_id(0)\n    off_hb = tl.program_id(1)\n    off_b = off_hb // nheads\n    off_h = off_hb % nheads\n    # initialize offsets\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_d = tl.arange(0, BLOCK_HEADDIM)\n    # load\n    o = tl.load(\n        Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],\n        mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n        other=0.0,\n    ).to(tl.float32)\n    do = tl.load(\n        DO\n        + off_b * stride_dob\n        + off_h * stride_doh\n        + offs_m[:, None] * stride_dom\n        + offs_d[None, :],\n        mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n        other=0.0,\n    ).to(tl.float32)\n    delta = tl.sum(o * do, axis=1)\n    # write-back\n    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, 
delta)\n\n\n@triton.jit\ndef _bwd_store_dk_dv(\n    dk_ptrs,\n    dv_ptrs,\n    dk,\n    dv,\n    offs_n,\n    offs_d,\n    seqlen_k,\n    headdim,\n    EVEN_M: tl.constexpr,\n    EVEN_N: tl.constexpr,\n    EVEN_HEADDIM: tl.constexpr,\n):\n    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,\n    # if we just call tl.store(dv_ptrs), there's a race condition\n    if EVEN_N & EVEN_M:\n        if EVEN_HEADDIM:\n            tl.store(dv_ptrs, dv)\n            tl.store(dk_ptrs, dk)\n        else:\n            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)\n            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)\n    else:\n        if EVEN_HEADDIM:\n            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)\n            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)\n        else:\n            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))\n            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))\n\n\n@triton.jit\ndef _bwd_kernel_one_col_block(\n    start_n,\n    Q,\n    K,\n    V,\n    Bias,\n    DO,\n    DQ,\n    DK,\n    DV,\n    LSE,\n    D,\n    softmax_scale,\n    stride_qm,\n    stride_kn,\n    stride_vn,\n    stride_bm,\n    stride_dom,\n    stride_dqm,\n    stride_dkn,\n    stride_dvn,\n    seqlen_q,\n    seqlen_k,\n    headdim,\n    ATOMIC_ADD: tl.constexpr,\n    BIAS_TYPE: tl.constexpr,\n    IS_CAUSAL: tl.constexpr,\n    BLOCK_HEADDIM: tl.constexpr,\n    EVEN_M: tl.constexpr,\n    EVEN_N: tl.constexpr,\n    EVEN_HEADDIM: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)\n    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M\n    # initialize row/col offsets\n    offs_qm = begin_m + tl.arange(0, BLOCK_M)\n    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n    offs_m = tl.arange(0, BLOCK_M)\n    
offs_d = tl.arange(0, BLOCK_HEADDIM)\n    # initialize pointers to value-like data\n    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])\n    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])\n    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])\n    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])\n    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])\n    if BIAS_TYPE == \"vector\":\n        b_ptrs = Bias + offs_n\n    elif BIAS_TYPE == \"matrix\":\n        b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])\n    # initialize dv and dk\n    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)\n    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)\n    # There seems to be some problem with Triton pipelining that makes results wrong for\n    # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop\n    # may have zero step, and pipelining with the bias matrix could screw it up.\n    # So we just exit early.\n    if begin_m >= seqlen_q:\n        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])\n        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])\n        _bwd_store_dk_dv(\n            dk_ptrs,\n            dv_ptrs,\n            dk,\n            dv,\n            offs_n,\n            offs_d,\n            seqlen_k,\n            headdim,\n            EVEN_M=EVEN_M,\n            EVEN_N=EVEN_N,\n            EVEN_HEADDIM=EVEN_HEADDIM,\n        )\n        return\n    # k and v stay in SRAM throughout\n    # [2022-10-30] TD: Same bug as the fwd. 
In the case of EVEN_N=True and EVEN_M=False,\n    # if we just call tl.load(k_ptrs), we get the wrong output!\n    if EVEN_N & EVEN_M:\n        if EVEN_HEADDIM:\n            k = tl.load(k_ptrs)\n            v = tl.load(v_ptrs)\n        else:\n            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n    else:\n        if EVEN_HEADDIM:\n            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)\n            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)\n        else:\n            k = tl.load(\n                k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0\n            )\n            v = tl.load(\n                v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0\n            )\n    # loop over rows\n    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)\n    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):\n        start_m = tl.multiple_of(start_m, BLOCK_M)\n        offs_m_curr = start_m + offs_m\n        # load q, k, v, do on-chip\n        # Same bug as below. 
Otherwise gives wrong result for headdim=40, seqlen=(128, 117)\n        if EVEN_M & EVEN_HEADDIM:\n            q = tl.load(q_ptrs)\n        else:\n            if EVEN_HEADDIM:\n                q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)\n            else:\n                q = tl.load(\n                    q_ptrs,\n                    mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n                    other=0.0,\n                )\n        # recompute p = softmax(qk, dim=-1).T\n        qk = tl.dot(q, k, trans_b=True)\n        # Trying to combine the two masks seem to make the result wrong\n        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong\n            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float(\"-inf\"))\n        if IS_CAUSAL:\n            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n        if BIAS_TYPE != \"none\":\n            tl.debug_barrier()  # Race condition otherwise\n            if BIAS_TYPE == \"vector\":\n                if EVEN_N:\n                    bias = tl.load(b_ptrs).to(tl.float32)\n                else:\n                    bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)\n                bias = bias[None, :]\n            elif BIAS_TYPE == \"matrix\":\n                if EVEN_M & EVEN_N:\n                    bias = tl.load(b_ptrs).to(tl.float32)\n                else:\n                    bias = tl.load(\n                        b_ptrs,\n                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),\n                        other=0.0,\n                    ).to(tl.float32)\n            qk = qk * softmax_scale + bias\n        # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.\n        # Also wrong for headdim=64.\n        if not (EVEN_M & EVEN_HEADDIM):\n            tl.debug_barrier()\n        lse_i = tl.load(LSE + offs_m_curr)\n        if 
BIAS_TYPE == \"none\":\n            p = tl.exp(qk * softmax_scale - lse_i[:, None])\n        else:\n            p = tl.exp(qk - lse_i[:, None])\n        # compute dv\n        # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call\n        # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs\n        # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,\n        # the output is correct.\n        if EVEN_M & EVEN_HEADDIM:\n            do = tl.load(do_ptrs)\n        else:\n            # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.\n            do = tl.load(\n                do_ptrs,\n                mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n                other=0.0,\n            )\n        # if EVEN_M:\n        #     if EVEN_HEADDIM:\n        #         do = tl.load(do_ptrs)\n        #     else:\n        #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)\n        # else:\n        #     if EVEN_HEADDIM:\n        #         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)\n        #     else:\n        #         do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)\n        #                                    & (offs_d[None, :] < headdim), other=0.0)\n        dv += tl.dot(p.to(do.dtype), do, trans_a=True)\n        # compute dp = dot(v, do)\n        # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.\n        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True\n        # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False\n        if not (EVEN_M & EVEN_HEADDIM):\n            tl.debug_barrier()\n        dp = tl.dot(do, v, trans_b=True)\n        # There's a race condition for headdim=48\n        if not EVEN_HEADDIM:\n            tl.debug_barrier()\n        # compute ds = p * (dp - 
delta[:, None])\n        # Putting the subtraction after the dp matmul (instead of before) is slightly faster\n        Di = tl.load(D + offs_m_curr)\n        # Converting ds to q.dtype here reduces register pressure and makes it much faster\n        # for BLOCK_HEADDIM=128\n        ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)\n        # compute dk = dot(ds.T, q)\n        dk += tl.dot(ds, q, trans_a=True)\n        # compute dq\n        if not (\n            EVEN_M & EVEN_HEADDIM\n        ):  # Otherwise there's a race condition when BIAS_TYPE='matrix'\n            tl.debug_barrier()\n        if not ATOMIC_ADD:\n            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M\n                dq = tl.load(dq_ptrs, eviction_policy=\"evict_last\")\n                dq += tl.dot(ds, k)\n                tl.store(dq_ptrs, dq, eviction_policy=\"evict_last\")\n            else:\n                if EVEN_HEADDIM:\n                    dq = tl.load(\n                        dq_ptrs,\n                        mask=offs_m_curr[:, None] < seqlen_q,\n                        other=0.0,\n                        eviction_policy=\"evict_last\",\n                    )\n                    dq += tl.dot(ds, k)\n                    tl.store(\n                        dq_ptrs,\n                        dq,\n                        mask=offs_m_curr[:, None] < seqlen_q,\n                        eviction_policy=\"evict_last\",\n                    )\n                else:\n                    dq = tl.load(\n                        dq_ptrs,\n                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n                        other=0.0,\n                        eviction_policy=\"evict_last\",\n                    )\n                    dq += tl.dot(ds, k)\n                    tl.store(\n                        dq_ptrs,\n                        dq,\n                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < 
headdim),\n                        eviction_policy=\"evict_last\",\n                    )\n        else:  # If we're parallelizing across the seqlen_k dimension\n            dq = tl.dot(ds, k)\n            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M\n                tl.atomic_add(dq_ptrs, dq)\n            else:\n                if EVEN_HEADDIM:\n                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)\n                else:\n                    tl.atomic_add(\n                        dq_ptrs,\n                        dq,\n                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),\n                    )\n        # increment pointers\n        dq_ptrs += BLOCK_M * stride_dqm\n        q_ptrs += BLOCK_M * stride_qm\n        do_ptrs += BLOCK_M * stride_dom\n        if BIAS_TYPE == \"matrix\":\n            b_ptrs += BLOCK_M * stride_bm\n    # write-back\n    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])\n    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])\n    _bwd_store_dk_dv(\n        dk_ptrs,\n        dv_ptrs,\n        dk,\n        dv,\n        offs_n,\n        offs_d,\n        seqlen_k,\n        headdim,\n        EVEN_M=EVEN_M,\n        EVEN_N=EVEN_N,\n        EVEN_HEADDIM=EVEN_HEADDIM,\n    )\n\n\ndef init_to_zero(name):\n    return lambda nargs: nargs[name].zero_()\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": False},\n            num_warps=8,\n            num_stages=1,\n            pre_hook=init_to_zero(\"DQ\"),\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"SEQUENCE_PARALLEL\": True},\n            num_warps=8,\n            num_stages=1,\n            pre_hook=init_to_zero(\"DQ\"),\n        ),\n        # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now\n        # # Kernel is buggy 
(give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*\n        # triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"SEQUENCE_PARALLEL\": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),\n        # triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"SEQUENCE_PARALLEL\": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),\n        # triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 64, \"SEQUENCE_PARALLEL\": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),\n        # triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 64, \"SEQUENCE_PARALLEL\": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),\n    ],\n    key=[\"CACHE_KEY_SEQLEN_Q\", \"CACHE_KEY_SEQLEN_K\", \"BIAS_TYPE\", \"IS_CAUSAL\", \"BLOCK_HEADDIM\"],\n)\n@triton.heuristics(\n    {\n        \"EVEN_M\": lambda args: args[\"seqlen_q\"] % args[\"BLOCK_M\"] == 0,\n        \"EVEN_N\": lambda args: args[\"seqlen_k\"] % args[\"BLOCK_N\"] == 0,\n        \"EVEN_HEADDIM\": lambda args: args[\"headdim\"] == args[\"BLOCK_HEADDIM\"],\n    }\n)\n@triton.jit\ndef _bwd_kernel(\n    Q,\n    K,\n    V,\n    Bias,\n    DO,\n    DQ,\n    DK,\n    DV,\n    LSE,\n    D,\n    softmax_scale,\n    stride_qb,\n    stride_qh,\n    stride_qm,\n    stride_kb,\n    stride_kh,\n    stride_kn,\n    stride_vb,\n    stride_vh,\n    stride_vn,\n    stride_bb,\n    stride_bh,\n    stride_bm,\n    stride_dob,\n    stride_doh,\n    stride_dom,\n    stride_dqb,\n    stride_dqh,\n    stride_dqm,\n    stride_dkb,\n    stride_dkh,\n    stride_dkn,\n    stride_dvb,\n    stride_dvh,\n    stride_dvn,\n    nheads,\n    seqlen_q,\n    seqlen_k,\n    seqlen_q_rounded,\n    headdim,\n    CACHE_KEY_SEQLEN_Q,\n    CACHE_KEY_SEQLEN_K,\n    BIAS_TYPE: tl.constexpr,\n    IS_CAUSAL: tl.constexpr,\n    BLOCK_HEADDIM: tl.constexpr,\n    SEQUENCE_PARALLEL: tl.constexpr,\n    EVEN_M: tl.constexpr,\n    EVEN_N: tl.constexpr,\n    EVEN_HEADDIM: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: 
tl.constexpr,\n):\n    off_hb = tl.program_id(1)\n    off_b = off_hb // nheads\n    off_h = off_hb % nheads\n    # offset pointers for batch/head\n    Q += off_b * stride_qb + off_h * stride_qh\n    K += off_b * stride_kb + off_h * stride_kh\n    V += off_b * stride_vb + off_h * stride_vh\n    DO += off_b * stride_dob + off_h * stride_doh\n    DQ += off_b * stride_dqb + off_h * stride_dqh\n    DK += off_b * stride_dkb + off_h * stride_dkh\n    DV += off_b * stride_dvb + off_h * stride_dvh\n    if BIAS_TYPE != \"none\":\n        Bias += off_b * stride_bb + off_h * stride_bh\n    # pointer to row-wise quantities in value-like data\n    D += off_hb * seqlen_q_rounded\n    LSE += off_hb * seqlen_q_rounded\n    if not SEQUENCE_PARALLEL:\n        num_block_n = tl.cdiv(seqlen_k, BLOCK_N)\n        for start_n in range(0, num_block_n):\n            _bwd_kernel_one_col_block(\n                start_n,\n                Q,\n                K,\n                V,\n                Bias,\n                DO,\n                DQ,\n                DK,\n                DV,\n                LSE,\n                D,\n                softmax_scale,\n                stride_qm,\n                stride_kn,\n                stride_vn,\n                stride_bm,\n                stride_dom,\n                stride_dqm,\n                stride_dkn,\n                stride_dvn,\n                seqlen_q,\n                seqlen_k,\n                headdim,\n                ATOMIC_ADD=False,\n                BIAS_TYPE=BIAS_TYPE,\n                IS_CAUSAL=IS_CAUSAL,\n                BLOCK_HEADDIM=BLOCK_HEADDIM,\n                EVEN_M=EVEN_M,\n                EVEN_N=EVEN_N,\n                EVEN_HEADDIM=EVEN_HEADDIM,\n                BLOCK_M=BLOCK_M,\n                BLOCK_N=BLOCK_N,\n            )\n    else:\n        start_n = tl.program_id(0)\n        _bwd_kernel_one_col_block(\n            start_n,\n            Q,\n            K,\n            V,\n            Bias,\n            DO,\n        
    DQ,\n            DK,\n            DV,\n            LSE,\n            D,\n            softmax_scale,\n            stride_qm,\n            stride_kn,\n            stride_vn,\n            stride_bm,\n            stride_dom,\n            stride_dqm,\n            stride_dkn,\n            stride_dvn,\n            seqlen_q,\n            seqlen_k,\n            headdim,\n            ATOMIC_ADD=True,\n            BIAS_TYPE=BIAS_TYPE,\n            IS_CAUSAL=IS_CAUSAL,\n            BLOCK_HEADDIM=BLOCK_HEADDIM,\n            EVEN_M=EVEN_M,\n            EVEN_N=EVEN_N,\n            EVEN_HEADDIM=EVEN_HEADDIM,\n            BLOCK_M=BLOCK_M,\n            BLOCK_N=BLOCK_N,\n        )\n\n\ndef _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):\n    # shape constraints\n    batch, seqlen_q, nheads, d = q.shape\n    _, seqlen_k, _, _ = k.shape\n    assert k.shape == (batch, seqlen_k, nheads, d)\n    assert v.shape == (batch, seqlen_k, nheads, d)\n    assert d <= 128, \"FlashAttention only support head dimensions up to 128\"\n    assert q.dtype == k.dtype == v.dtype, \"All tensors must have the same type\"\n    assert q.dtype in [torch.float16, torch.bfloat16], \"Only support fp16 and bf16\"\n    assert q.is_cuda and k.is_cuda and v.is_cuda\n    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n\n    has_bias = bias is not None\n    bias_type = \"none\"\n    if has_bias:\n        assert bias.dtype in [q.dtype, torch.float]\n        assert bias.is_cuda\n        assert bias.dim() == 4\n        if bias.stride(-1) != 1:\n            bias = bias.contiguous()\n        if bias.shape[2:] == (1, seqlen_k):\n            bias_type = \"vector\"\n        elif bias.shape[2:] == (seqlen_q, seqlen_k):\n            bias_type = \"matrix\"\n        else:\n            raise RuntimeError(\n                \"Last 2 dimensions of bias must be (1, seqlen_k)\" \" or (seqlen_q, seqlen_k)\"\n            )\n        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n    bias_strides = 
(bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n\n    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)\n    o = torch.empty_like(q)\n\n    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n    BLOCK = 128\n    num_warps = 4 if d <= 64 else 8\n    grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n    _fwd_kernel[grid](\n        q,\n        k,\n        v,\n        bias,\n        o,\n        lse,\n        tmp,\n        softmax_scale,\n        q.stride(0),\n        q.stride(2),\n        q.stride(1),\n        k.stride(0),\n        k.stride(2),\n        k.stride(1),\n        v.stride(0),\n        v.stride(2),\n        v.stride(1),\n        *bias_strides,\n        o.stride(0),\n        o.stride(2),\n        o.stride(1),\n        nheads,\n        seqlen_q,\n        seqlen_k,\n        seqlen_q_rounded,\n        d,\n        seqlen_q // 32,\n        seqlen_k // 32,  # key for triton cache (limit number of compilations)\n        # Can't use kwargs here because triton autotune expects key to be args, not kwargs\n        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,\n        bias_type,\n        causal,\n        BLOCK_HEADDIM,\n        BLOCK_M=BLOCK,\n        BLOCK_N=BLOCK,\n        num_warps=num_warps,\n        num_stages=1,\n    )\n    return o, lse, softmax_scale  # softmax_scale could have been updated\n\n\ndef _flash_attn_backward(\n    do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None\n):\n    # Make sure that the last dimension is contiguous\n    if do.stride(-1) != 1:\n        do = do.contiguous()\n    batch, seqlen_q, nheads, d = q.shape\n    _, seqlen_k, _, _ = k.shape\n    # assert d in {16, 32, 64, 128}\n    assert d <= 128\n    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n    assert lse.shape 
== (batch, nheads, seqlen_q_rounded)\n    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1\n    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1\n    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)\n    # dq_accum = torch.zeros_like(q, dtype=torch.float32)\n    dq_accum = torch.empty_like(q, dtype=torch.float32)\n    delta = torch.empty_like(lse)\n    # delta = torch.zeros_like(lse)\n\n    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n    grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), batch * nheads)\n    _bwd_preprocess_do_o_dot[grid](\n        o,\n        do,\n        delta,\n        o.stride(0),\n        o.stride(2),\n        o.stride(1),\n        do.stride(0),\n        do.stride(2),\n        do.stride(1),\n        nheads,\n        seqlen_q,\n        seqlen_q_rounded,\n        d,\n        BLOCK_M=128,\n        BLOCK_HEADDIM=BLOCK_HEADDIM,\n    )\n\n    has_bias = bias is not None\n    bias_type = \"none\"\n    if has_bias:\n        assert bias.dtype in [q.dtype, torch.float]\n        assert bias.is_cuda\n        assert bias.dim() == 4\n        assert bias.stride(-1) == 1\n        if bias.shape[2:] == (1, seqlen_k):\n            bias_type = \"vector\"\n        elif bias.shape[2:] == (seqlen_q, seqlen_k):\n            bias_type = \"matrix\"\n        else:\n            raise RuntimeError(\n                \"Last 2 dimensions of bias must be (1, seqlen_k)\" \" or (seqlen_q, seqlen_k)\"\n            )\n        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)\n    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)\n\n    # BLOCK_M = 128\n    # BLOCK_N = 64\n    # num_warps = 4\n    grid = lambda META: (\n        triton.cdiv(seqlen_k, META[\"BLOCK_N\"]) if META[\"SEQUENCE_PARALLEL\"] else 1,\n        batch * nheads,\n    )\n    _bwd_kernel[grid](\n        q,\n        k,\n        v,\n        bias,\n        do,\n        dq_accum,\n        dk,\n        dv,\n        
lse,\n        delta,\n        softmax_scale,\n        q.stride(0),\n        q.stride(2),\n        q.stride(1),\n        k.stride(0),\n        k.stride(2),\n        k.stride(1),\n        v.stride(0),\n        v.stride(2),\n        v.stride(1),\n        *bias_strides,\n        do.stride(0),\n        do.stride(2),\n        do.stride(1),\n        dq_accum.stride(0),\n        dq_accum.stride(2),\n        dq_accum.stride(1),\n        dk.stride(0),\n        dk.stride(2),\n        dk.stride(1),\n        dv.stride(0),\n        dv.stride(2),\n        dv.stride(1),\n        nheads,\n        seqlen_q,\n        seqlen_k,\n        seqlen_q_rounded,\n        d,\n        seqlen_q // 32,\n        seqlen_k // 32,  # key for triton cache (limit number of compilations)\n        # Can't use kwargs here because triton autotune expects key to be args, not kwargs\n        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,\n        bias_type,\n        causal,\n        BLOCK_HEADDIM,\n        # SEQUENCE_PARALLEL=False,\n        # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,\n        # num_warps=num_warps,\n        # num_stages=1,\n    )\n    dq.copy_(dq_accum)\n\n\nclass FlashAttnQKVPackedFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):\n        \"\"\"\n        qkv: (batch, seqlen, 3, nheads, headdim)\n        bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).\n            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).\n            ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)\n        \"\"\"\n        # Make sure that the last dimension is contiguous\n        if qkv.stride(-1) != 1:\n            qkv = qkv.contiguous()\n        o, lse, ctx.softmax_scale = _flash_attn_forward(\n            qkv[:, :, 0],\n            qkv[:, :, 1],\n            qkv[:, :, 2],\n            bias=bias,\n            causal=causal,\n            softmax_scale=softmax_scale,\n        )\n        
ctx.save_for_backward(qkv, o, lse, bias)\n        ctx.causal = causal\n        return o\n\n    @staticmethod\n    def backward(ctx, do):\n        qkv, o, lse, bias = ctx.saved_tensors\n        assert not ctx.needs_input_grad[1], \"FlashAttention does not support bias gradient yet\"\n        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd\n        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.\n        with torch.inference_mode():\n            dqkv = torch.empty_like(qkv)\n            _flash_attn_backward(\n                do,\n                qkv[:, :, 0],\n                qkv[:, :, 1],\n                qkv[:, :, 2],\n                o,\n                lse,\n                dqkv[:, :, 0],\n                dqkv[:, :, 1],\n                dqkv[:, :, 2],\n                bias=bias,\n                causal=ctx.causal,\n                softmax_scale=ctx.softmax_scale,\n            )\n        return dqkv, None, None, None\n\n\nflash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply\n\n\nclass FlashAttnKVPackedFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):\n        \"\"\"\n        q: (batch, seqlen_q, nheads, headdim)\n        kv: (batch, seqlen_k, 2, nheads, headdim)\n        bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).\n            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).\n            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)\n        \"\"\"\n        # Make sure that the last dimension is contiguous\n        q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]\n        o, lse, ctx.softmax_scale = _flash_attn_forward(\n            q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale\n        )\n        ctx.save_for_backward(q, kv, o, lse, bias)\n        ctx.causal = 
causal\n        return o\n\n    @staticmethod\n    def backward(ctx, do):\n        q, kv, o, lse, bias = ctx.saved_tensors\n        if len(ctx.needs_input_grad) >= 3:\n            assert not ctx.needs_input_grad[2], \"FlashAttention does not support bias gradient yet\"\n        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd\n        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.\n        with torch.inference_mode():\n            dq = torch.empty_like(q)\n            dkv = torch.empty_like(kv)\n            _flash_attn_backward(\n                do,\n                q,\n                kv[:, :, 0],\n                kv[:, :, 1],\n                o,\n                lse,\n                dq,\n                dkv[:, :, 0],\n                dkv[:, :, 1],\n                bias=bias,\n                causal=ctx.causal,\n                softmax_scale=ctx.softmax_scale,\n            )\n        return dq, dkv, None, None, None\n\n\nflash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply\n\n\nclass FlashAttnFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):\n        \"\"\"\n        q: (batch_size, seqlen_q, nheads, headdim)\n        k, v: (batch_size, seqlen_k, nheads, headdim)\n        bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).\n            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).\n            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)\n        \"\"\"\n        # Make sure that the last dimension is contiguous\n        q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]\n        o, lse, ctx.softmax_scale = _flash_attn_forward(\n            q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale\n        )\n        ctx.save_for_backward(q, k, v, o, lse, bias)\n        ctx.causal = causal\n        
return o\n\n    @staticmethod\n    def backward(ctx, do):\n        q, k, v, o, lse, bias = ctx.saved_tensors\n        assert not ctx.needs_input_grad[3], \"FlashAttention does not support bias gradient yet\"\n        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd\n        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.\n        with torch.inference_mode():\n            dq = torch.empty_like(q)\n            dk = torch.empty_like(k)\n            dv = torch.empty_like(v)\n            _flash_attn_backward(\n                do,\n                q,\n                k,\n                v,\n                o,\n                lse,\n                dq,\n                dk,\n                dv,\n                bias=bias,\n                causal=ctx.causal,\n                softmax_scale=ctx.softmax_scale,\n            )\n        return dq, dk, dv, None, None, None\n\n\nflash_attn_func = FlashAttnFunc.apply\n"
  },
  {
    "path": "flash_attn/flash_attn_triton_og.py",
    "content": "# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py\n# for benchmarking.\n# We fixed a few dtype cast to make it work for bf16\n\n\"\"\"\nFused Attention\n===============\nThis is a Triton implementation of the Flash Attention algorithm\n(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)\n\"\"\"\n\nimport pytest\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n    Q,\n    K,\n    V,\n    sm_scale,\n    TMP,\n    L,\n    M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug\n    Out,\n    stride_qz,\n    stride_qh,\n    stride_qm,\n    stride_qk,\n    stride_kz,\n    stride_kh,\n    stride_kn,\n    stride_kk,\n    stride_vz,\n    stride_vh,\n    stride_vk,\n    stride_vn,\n    stride_oz,\n    stride_oh,\n    stride_om,\n    stride_on,\n    Z,\n    H,\n    N_CTX,\n    BLOCK_M: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    start_m = tl.program_id(0)\n    off_hz = tl.program_id(1)\n    # initialize offsets\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_n = tl.arange(0, BLOCK_N)\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n    off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk\n    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk\n    # Initialize pointers to Q, K, V\n    q_ptrs = Q + off_q\n    k_ptrs = K + off_k\n    v_ptrs = V + off_v\n    # initialize pointer to m and l\n    t_ptrs = TMP + off_hz * N_CTX + offs_m\n    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n    # load q: it will stay in SRAM throughout\n    q = 
tl.load(q_ptrs)\n    # loop over k, v and update accumulator\n    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        # -- compute qk ----\n        k = tl.load(k_ptrs + start_n * stride_kn)\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n        qk += tl.dot(q, k, trans_b=True)\n        qk *= sm_scale\n        qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float(\"-inf\"))\n        # -- compute m_ij, p, l_ij\n        m_ij = tl.max(qk, 1)\n        p = tl.exp(qk - m_ij[:, None])\n        l_ij = tl.sum(p, 1)\n        # -- update m_i and l_i\n        m_i_new = tl.maximum(m_i, m_ij)\n        alpha = tl.exp(m_i - m_i_new)\n        beta = tl.exp(m_ij - m_i_new)\n        l_i_new = alpha * l_i + beta * l_ij\n        # -- update output accumulator --\n        # scale p\n        p_scale = beta / l_i_new\n        p = p * p_scale[:, None]\n        # scale acc\n        acc_scale = l_i / l_i_new * alpha\n        tl.store(t_ptrs, acc_scale)\n        acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load\n        acc = acc * acc_scale[:, None]\n        # update acc\n        v = tl.load(v_ptrs + start_n * stride_vk)\n        p = p.to(v.dtype)\n        acc += tl.dot(p, v)\n        # update m_i and l_i\n        l_i = l_i_new\n        m_i = m_i_new\n    # rematerialize offsets to save registers\n    start_m = tl.program_id(0)\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    # write back l and m\n    l_ptrs = L + off_hz * N_CTX + offs_m\n    m_ptrs = M + off_hz * N_CTX + offs_m\n    tl.store(l_ptrs, l_i)\n    tl.store(m_ptrs, m_i)\n    # initialize pointers to output\n    offs_n = tl.arange(0, BLOCK_DMODEL)\n    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on\n    out_ptrs = Out + off_o\n    tl.store(out_ptrs, acc)\n\n\n@triton.jit\ndef _bwd_preprocess(\n    Out,\n    DO,\n    L,\n    NewDO,\n    Delta,\n    BLOCK_M: 
tl.constexpr,\n    D_HEAD: tl.constexpr,\n):\n    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n    off_n = tl.arange(0, D_HEAD)\n    # load\n    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n    denom = tl.load(L + off_m).to(tl.float32)\n    # compute\n    do = do / denom[:, None]\n    delta = tl.sum(o * do, axis=1)\n    # write-back\n    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n    tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n    Q,\n    K,\n    V,\n    sm_scale,\n    Out,\n    DO,\n    DQ,\n    DK,\n    DV,\n    L,\n    M,\n    D,\n    stride_qz,\n    stride_qh,\n    stride_qm,\n    stride_qk,\n    stride_kz,\n    stride_kh,\n    stride_kn,\n    stride_kk,\n    stride_vz,\n    stride_vh,\n    stride_vk,\n    stride_vn,\n    Z,\n    H,\n    N_CTX,\n    num_block,\n    BLOCK_M: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    off_hz = tl.program_id(0)\n    off_z = off_hz // H\n    off_h = off_hz % H\n    # offset pointers for batch/head\n    Q += off_z * stride_qz + off_h * stride_qh\n    K += off_z * stride_qz + off_h * stride_qh\n    V += off_z * stride_qz + off_h * stride_qh\n    DO += off_z * stride_qz + off_h * stride_qh\n    DQ += off_z * stride_qz + off_h * stride_qh\n    DK += off_z * stride_qz + off_h * stride_qh\n    DV += off_z * stride_qz + off_h * stride_qh\n    for start_n in range(0, num_block):\n        lo = start_n * BLOCK_M\n        # initialize row/col offsets\n        offs_qm = lo + tl.arange(0, BLOCK_M)\n        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n        offs_m = tl.arange(0, BLOCK_N)\n        offs_k = tl.arange(0, BLOCK_DMODEL)\n        # initialize pointers to value-like data\n        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] 
* stride_kk)\n        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n        # pointer to row-wise quantities in value-like data\n        D_ptrs = D + off_hz * N_CTX\n        m_ptrs = M + off_hz * N_CTX\n        # initialize dv amd dk\n        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n        # k and v stay in SRAM throughout\n        k = tl.load(k_ptrs)\n        v = tl.load(v_ptrs)\n        # loop over rows\n        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n            offs_m_curr = start_m + offs_m\n            # load q, k, v, do on-chip\n            q = tl.load(q_ptrs)\n            # recompute p = softmax(qk, dim=-1).T\n            # NOTE: `do` is pre-divided by `l`; no normalization here\n            qk = tl.dot(q, k, trans_b=True)\n            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float(\"-inf\"))\n            m = tl.load(m_ptrs + offs_m_curr)\n            p = tl.exp(qk * sm_scale - m[:, None])\n            # compute dv\n            do = tl.load(do_ptrs)\n            dv += tl.dot(p.to(do.dtype), do, trans_a=True)\n            # compute dp = dot(v, do)\n            Di = tl.load(D_ptrs + offs_m_curr)\n            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n            dp += tl.dot(do, v, trans_b=True)\n            # compute ds = p * (dp - delta[:, None])\n            ds = p * dp * sm_scale\n            # compute dk = dot(ds.T, q)\n            dk += tl.dot(ds.to(q.dtype), q, trans_a=True)\n            # # compute dq\n            dq = tl.load(dq_ptrs, eviction_policy=\"evict_last\")\n            dq += tl.dot(ds.to(k.dtype), k)\n            tl.store(dq_ptrs, dq, eviction_policy=\"evict_last\")\n            # # increment pointers\n  
          dq_ptrs += BLOCK_M * stride_qm\n            q_ptrs += BLOCK_M * stride_qm\n            do_ptrs += BLOCK_M * stride_qm\n        # write-back\n        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)\n        tl.store(dv_ptrs, dv)\n        tl.store(dk_ptrs, dk)\n\n\nclass _attention(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, q, k, v, sm_scale):\n        BLOCK = 128\n        # shape constraints\n        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n        assert Lq == Lk and Lk == Lv\n        assert Lk in {16, 32, 64, 128}\n        o = torch.empty_like(q)\n        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])\n        tmp = torch.empty(\n            (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32\n        )\n        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n        num_warps = 4 if Lk <= 64 else 8\n\n        _fwd_kernel[grid](\n            q,\n            k,\n            v,\n            sm_scale,\n            tmp,\n            L,\n            m,\n            o,\n            q.stride(0),\n            q.stride(1),\n            q.stride(2),\n            q.stride(3),\n            k.stride(0),\n            k.stride(1),\n            k.stride(2),\n            k.stride(3),\n            v.stride(0),\n            v.stride(1),\n            v.stride(2),\n            v.stride(3),\n            o.stride(0),\n            o.stride(1),\n            o.stride(2),\n            o.stride(3),\n            q.shape[0],\n            q.shape[1],\n            q.shape[2],\n            BLOCK_M=BLOCK,\n            BLOCK_N=BLOCK,\n            BLOCK_DMODEL=Lk,\n            num_warps=num_warps,\n            num_stages=1,\n        )\n        
ctx.save_for_backward(q, k, v, o, L, m)\n        ctx.BLOCK = BLOCK\n        ctx.grid = grid\n        ctx.sm_scale = sm_scale\n        ctx.BLOCK_DMODEL = Lk\n        return o\n\n    @staticmethod\n    def backward(ctx, do):\n        q, k, v, o, l, m = ctx.saved_tensors\n        do = do.contiguous()\n        dq = torch.zeros_like(q, dtype=torch.float32)\n        dk = torch.empty_like(k)\n        dv = torch.empty_like(v)\n        do_scaled = torch.empty_like(do)\n        delta = torch.empty_like(l)\n        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](\n            o,\n            do,\n            l,\n            do_scaled,\n            delta,\n            BLOCK_M=ctx.BLOCK,\n            D_HEAD=ctx.BLOCK_DMODEL,\n        )\n\n        # NOTE: kernel currently buggy for other values of `num_warps`\n        num_warps = 8\n        _bwd_kernel[(ctx.grid[1],)](\n            q,\n            k,\n            v,\n            ctx.sm_scale,\n            o,\n            do_scaled,\n            dq,\n            dk,\n            dv,\n            l,\n            m,\n            delta,\n            q.stride(0),\n            q.stride(1),\n            q.stride(2),\n            q.stride(3),\n            k.stride(0),\n            k.stride(1),\n            k.stride(2),\n            k.stride(3),\n            v.stride(0),\n            v.stride(1),\n            v.stride(2),\n            v.stride(3),\n            q.shape[0],\n            q.shape[1],\n            q.shape[2],\n            ctx.grid[0],\n            BLOCK_M=ctx.BLOCK,\n            BLOCK_N=ctx.BLOCK,\n            BLOCK_DMODEL=ctx.BLOCK_DMODEL,\n            num_warps=num_warps,\n            num_stages=1,\n        )\n        return dq.to(q.dtype), dk, dv, None\n\n\nattention = _attention.apply\n"
  },
  {
    "path": "flash_attn/flash_blocksparse_attention.py",
    "content": "import math\n\nimport hydra\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\n\nfrom flash_attn.bert_padding import index_first_axis, pad_input, unpad_input\nfrom flash_attn.flash_blocksparse_attn_interface import (\n    convert_blockmask,\n    flash_blocksparse_attn_func,\n)\n\n\nclass FlashBlocksparseAttention(nn.Module):\n    \"\"\"Implement the scaled dot product attention with softmax.\n    Arguments\n    ---------\n        softmax_temp: The temperature to use for the softmax attention.\n                      (default: 1/sqrt(d_keys) where d_keys is computed at\n                      runtime)\n        attention_dropout: The dropout rate to apply to the attention\n                           (default: 0.1)\n    \"\"\"\n\n    def __init__(\n        self,\n        sparsity_config,\n        softmax_temp=None,\n        attention_dropout=0.0,\n        max_seq_length=2048,\n        device=None,\n        dtype=None,\n    ):\n        super().__init__()\n        self.sparsity_config = hydra.utils.instantiate(sparsity_config)\n        self.softmax_temp = softmax_temp\n        self.dropout_p = attention_dropout\n\n        # initialize sparse layout and register as buffer\n        max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256\n        layout = self.sparsity_config.make_layout(max_seq_length)\n        self.register_buffer(\"layout\", layout)\n        blockmask_converted = convert_blockmask(self.layout, causal=False)\n        self.register_buffer(\"blockmask_converted\", blockmask_converted)\n        # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')\n\n    def forward(\n        self,\n        qkv,\n        attn_mask=None,\n        key_padding_mask=None,\n        causal=False,\n        cu_seqlens=None,\n        max_s=None,\n        need_weights=False,\n        convert_mask=True,\n    ):\n        \"\"\"Implements the multihead softmax attention.\n        Arguments\n        ---------\n           
 qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None\n            attn_mask: An implementation of BaseMask that encodes where each\n                       query can attend to\n            key_padding_mask: An implementation of BaseMask that encodes how\n                         many queries each sequence in the batch consists of\n        \"\"\"\n        assert not need_weights\n        assert attn_mask is None\n        assert qkv.dtype == torch.float16\n        assert qkv.is_cuda\n\n        if cu_seqlens is None:\n            batch_size = qkv.shape[0]\n            seqlen = qkv.shape[1]\n            # Convert mask to take a subset\n            seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256\n            assert (\n                seqlen_rounded // 16 <= self.layout.shape[0]\n                and seqlen_rounded // 256 <= self.layout.shape[1]\n            )\n            blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]\n            if key_padding_mask is None:\n                qkv = rearrange(qkv, \"b s ... -> (b s) ...\")\n                max_s = seqlen\n                cu_seqlens = torch.arange(\n                    0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device\n                )\n                output = flash_blocksparse_attn_func(\n                    qkv,\n                    cu_seqlens,\n                    blockmask,\n                    self.dropout_p if self.training else 0.0,\n                    max_s,\n                    softmax_scale=self.softmax_temp,\n                    causal=causal,\n                )\n                output = rearrange(output, \"(b s) ... 
-> b s ...\", b=batch_size)\n            else:\n                key_padding_mask_bool = key_padding_mask.bool_matrix\n                nheads = qkv.shape[-2]\n                x = rearrange(qkv, \"b s three h d -> b s (three h d)\")\n                x_unpad, indices, cu_seqlens, max_s, _ = unpad_input(x, key_padding_mask_bool)\n                x_unpad = rearrange(x_unpad, \"nnz (three h d) -> nnz three h d\", three=3, h=nheads)\n                output_unpad = flash_blocksparse_attn_func(\n                    x_unpad,\n                    cu_seqlens,\n                    blockmask,\n                    self.dropout_p if self.training else 0.0,\n                    max_s,\n                    softmax_scale=self.softmax_temp,\n                    causal=causal,\n                )\n                output = rearrange(\n                    pad_input(\n                        rearrange(output_unpad, \"nnz h d -> nnz (h d)\"), indices, batch_size, seqlen\n                    ),\n                    \"b s (h d) -> b s h d\",\n                    h=nheads,\n                )\n        else:\n            assert max_s is not None\n            seqlen = max_s\n            # Convert mask to take a subset\n            seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256\n            assert (\n                seqlen_rounded // 16 <= self.layout.shape[0]\n                and seqlen_rounded // 256 <= self.layout.shape[1]\n            )\n            blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]\n            if convert_mask:\n                output = flash_blocksparse_attn_func(\n                    qkv,\n                    cu_seqlens,\n                    blockmask,\n                    self.dropout_p if self.training else 0.0,\n                    max_s,\n                    softmax_scale=self.softmax_temp,\n                    causal=causal,\n                )\n            else:\n                output = flash_blocksparse_attn_func(\n                    qkv,\n                    
cu_seqlens,\n                    self.blockmask_converted,\n                    self.dropout_p if self.training else 0.0,\n                    max_s,\n                    softmax_scale=self.softmax_temp,\n                    causal=causal,\n                    convert_mask=False,\n                )\n\n        return output, None\n\n\nclass FlashBlocksparseMHA(nn.Module):\n    def __init__(\n        self,\n        embed_dim,\n        num_heads,\n        sparsity_config,\n        bias=True,\n        batch_first=True,\n        attention_dropout=0.0,\n        causal=False,\n        max_seq_length=2048,\n        device=None,\n        dtype=None,\n        **kwargs,\n    ) -> None:\n        assert batch_first\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.causal = causal\n\n        self.num_heads = num_heads\n        assert self.embed_dim % num_heads == 0, \"self.kdim must be divisible by num_heads\"\n        self.head_dim = self.embed_dim // num_heads\n        assert self.head_dim in [16, 32, 64], \"Only support head_dim == 16, 32, or 64\"\n\n        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)\n        self.inner_attn = FlashBlocksparseAttention(\n            sparsity_config,\n            attention_dropout=attention_dropout,\n            max_seq_length=max_seq_length,\n            **factory_kwargs,\n        )\n        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)\n\n    def forward(\n        self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False\n    ):\n        qkv = self.Wqkv(x)\n        qkv = rearrange(qkv, \"b s (three h d) -> b s three h d\", three=3, h=self.num_heads)\n        context, attn_weights = self.inner_attn(\n            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal\n        )\n        return 
self.out_proj(rearrange(context, \"b s h d -> b s (h d)\")), attn_weights\n"
  },
  {
    "path": "flash_attn/flash_blocksparse_attn_interface.py",
    "content": "# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py\nimport flash_attn_cuda\nimport torch\nimport torch.nn as nn\n\n\ndef convert_blockmask(blockmask, causal):\n    \"\"\"Convert from the 0-1 format to the format used by the CUDA code.\n    0 means the block is skipped.\n    nonzero means the block is not skipped.\n    Argument:\n        blockmask: (row, col): a 0-1 tensor\n    Return:\n        blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row\n            indices of the nonzero blocks, padded with -1 to reach length @row.\n            The indices are multiplied by 4, with the smallest bit used to encode whether\n            it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is\n            the last nonzero in its row..\n    \"\"\"\n    assert not causal\n    # TD [2022-05-13]: The indexing and sorting is very tricky\n    nrow, ncol = blockmask.shape\n    # Sort does not support bool on CUDA\n    blockmask = blockmask.to(dtype=torch.uint8)\n    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)\n    nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)\n    last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]\n    last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[\n        torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row\n    ]\n    first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]\n    first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[\n        torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row\n    ]\n    nonzero_idx = nonzero_sorted_rowidx * 4\n    nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2\n    nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1\n    
nonzero_idx[nonzero_val == 0] = -1\n    return nonzero_idx.T.contiguous().to(dtype=torch.int32)\n\n\ndef _flash_blocksparse_attn_forward(\n    qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax\n):\n    context, softmax_lse, *rest = flash_attn_cuda.fwd_block(\n        qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None\n    )\n    # if context.isnan().any() or softmax_lse.isnan().any():\n    #     breakpoint()\n    S_dmask = rest[0] if return_softmax else None\n    return context, softmax_lse, S_dmask\n\n\ndef _flash_blocksparse_attn_backward(\n    dout,\n    qkv,\n    out,\n    S_dmask,\n    softmax_lse,\n    cu_seqlens,\n    blockmask,\n    dropout_p,\n    max_s,\n    softmax_scale,\n    causal,\n):\n    dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(\n        dout,\n        qkv,\n        out,\n        S_dmask,\n        softmax_lse,\n        cu_seqlens,\n        blockmask,\n        dropout_p,\n        softmax_scale,\n        max_s,\n        causal,\n        None,\n    )\n    # if dqkv.isnan().any() or softmax_d.isnan().any():\n    #     breakpoint()\n    return dqkv\n\n\nclass FlashBlocksparseAttnFun(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):\n        # Save rng_state because the backward pass will regenerate the dropout mask\n        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None\n        if softmax_scale is None:\n            softmax_scale = qkv.shape[-1] ** (-0.5)\n        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(\n            qkv,\n            cu_seqlens,\n            blockmask,\n            dropout_p,\n            max_s,\n            softmax_scale,\n            causal=causal,\n            return_softmax=False,\n        )\n        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)\n        ctx.dropout_p = 
dropout_p\n        ctx.max_s = max_s\n        ctx.softmax_scale = softmax_scale\n        ctx.causal = causal\n        return context\n\n    @staticmethod\n    def backward(ctx, dout):\n        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors\n        if rng_state is not None:\n            cur_rng_state = torch.cuda.get_rng_state()\n            torch.cuda.set_rng_state(rng_state)\n        # S_dmask is None, temporarily use another tensor just to get it running\n        dqkv = _flash_blocksparse_attn_backward(\n            dout,\n            qkv,\n            context,\n            context,\n            softmax_lse,\n            cu_seqlens,\n            blockmask,\n            ctx.dropout_p,\n            ctx.max_s,\n            ctx.softmax_scale,\n            ctx.causal,\n        )\n        if rng_state is not None:\n            torch.cuda.set_rng_state(cur_rng_state)\n        return dqkv, None, None, None, None, None, None, None\n\n\n# We duplicate code to return both the output and the softmax for testing\n# Returning both makes backward a bit slower, so we want to keep using the other version for speed.\nclass FlashBlocksparseAttnFunWithS(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):\n        # Save rng_state because the backward pass is gonna regenerate the dropout mask\n        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None\n        if softmax_scale is None:\n            softmax_scale = qkv.shape[-1] ** (-0.5)\n        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(\n            qkv,\n            cu_seqlens,\n            blockmask,\n            dropout_p,\n            max_s,\n            softmax_scale,\n            causal=causal,\n            return_softmax=True,\n        )\n        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)\n        ctx.dropout_p = 
dropout_p\n        ctx.max_s = max_s\n        ctx.softmax_scale = softmax_scale\n        ctx.causal = causal\n        return context, S_dmask, softmax_lse\n\n    @staticmethod\n    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):\n        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors\n        if rng_state is not None:\n            cur_rng_state = torch.cuda.get_rng_state()\n            torch.cuda.set_rng_state(rng_state)\n        dqkv = _flash_blocksparse_attn_backward(\n            dout,\n            qkv,\n            context,\n            S_dmask,\n            softmax_lse,\n            cu_seqlens,\n            blockmask,\n            ctx.dropout_p,\n            ctx.max_s,\n            ctx.softmax_scale,\n            ctx.causal,\n        )\n        if rng_state is not None:\n            torch.cuda.set_rng_state(cur_rng_state)\n        return dqkv, None, None, None, None, None, None\n\n\ndef flash_blocksparse_attn_func(\n    qkv,\n    cu_seqlens,\n    blockmask,\n    dropout_p,\n    max_s,\n    softmax_scale=None,\n    causal=False,\n    return_attn_probs=False,\n    convert_mask=True,\n):\n    \"\"\"dropout_p should be set to 0.0 during evaluation\"\"\"\n    func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS\n    if convert_mask:\n        blockmask = convert_blockmask(blockmask, causal=causal)\n    return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal)\n"
  },
  {
    "path": "flash_attn/layers/__init__.py",
    "content": ""
  },
  {
    "path": "flash_attn/layers/patch_embed.py",
    "content": "# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py\n# But we use nn.Linear instead of Conv2d and it's about 8x faster.\n\nfrom functools import partial\n\nimport torch.nn as nn\nfrom einops import rearrange\nfrom torch import _assert\nfrom torch.nn.modules.utils import _pair\n\ntry:\n    from flash_attn.ops.fused_dense import FusedDense\nexcept ImportError:\n    FusedDense = None\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"2D Image to Patch Embedding\"\"\"\n\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        embed_dim=768,\n        norm_layer=None,\n        flatten=True,\n        bias=True,\n        fused_bias_fc=False,\n    ):\n        super().__init__()\n        img_size = _pair(img_size)\n        patch_size = _pair(patch_size)\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])\n        self.num_patches = self.grid_size[0] * self.grid_size[1]\n        self.flatten = flatten\n        if fused_bias_fc and FusedDense is None:\n            raise ImportError(\"fused_dense is not installed\")\n\n        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense\n        self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x):\n        _, _, H, W = x.shape\n        _assert(\n            H == self.img_size[0],\n            f\"Input image height ({H}) doesn't match model ({self.img_size[0]}).\",\n        )\n        _assert(\n            W == self.img_size[1],\n            f\"Input image width ({W}) doesn't match model ({self.img_size[1]}).\",\n        )\n        x = self.proj(\n            rearrange(\n                x,\n                \"b c (h p1) (w p2) -> b h w (c p1 
p2)\",\n                p1=self.patch_size[0],\n                p2=self.patch_size[1],\n            )\n        )\n        if self.flatten:\n            x = rearrange(x, \"b h w c -> b (h w) c\")\n        x = self.norm(x)\n        return x\n"
  },
  {
    "path": "flash_attn/layers/rotary.py",
    "content": "# Copyright (c) 2025, Tri Dao\n\nimport math\nfrom functools import partial\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom einops import rearrange, repeat\nfrom flash_attn.ops.triton.rotary import apply_rotary\n\n\ndef rotate_half(x, interleaved=False):\n    if not interleaved:\n        x1, x2 = x.chunk(2, dim=-1)\n        return torch.cat((-x2, x1), dim=-1)\n    else:\n        x1, x2 = x[..., ::2], x[..., 1::2]\n        return rearrange(torch.stack((-x2, x1), dim=-1), \"... d two -> ... (d two)\", two=2)\n\n\ndef apply_rotary_emb_torch(x, cos, sin, interleaved=False):\n    \"\"\"\n    x: (batch_size, seqlen, nheads, headdim)\n    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)\n    \"\"\"\n    ro_dim = cos.shape[-1] * 2\n    assert ro_dim <= x.shape[-1]\n    cos = repeat(cos, \"... d -> ... 1 (2 d)\" if not interleaved else \"... d -> ... 1 (d 2)\")\n    sin = repeat(sin, \"... d -> ... 1 (2 d)\" if not interleaved else \"... d -> ... 
1 (d 2)\")\n    return torch.cat(\n        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],\n        dim=-1,\n    )\n\n\nclass ApplyRotaryEmb(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        x,\n        cos,\n        sin,\n        interleaved=False,\n        inplace=False,\n        seqlen_offsets: Union[int, Tensor] = 0,\n        cu_seqlens: Optional[Tensor] = None,\n        max_seqlen: Optional[int] = None,\n    ):\n        out = apply_rotary(\n            x,\n            cos,\n            sin,\n            seqlen_offsets=seqlen_offsets,\n            cu_seqlens=cu_seqlens,\n            max_seqlen=max_seqlen,\n            interleaved=interleaved,\n            inplace=inplace,\n        )\n        if isinstance(seqlen_offsets, int):\n            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward\n            ctx.seqlen_offsets = seqlen_offsets\n        else:\n            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)\n            ctx.seqlen_offsets = None\n        ctx.interleaved = interleaved\n        ctx.inplace = inplace\n        ctx.max_seqlen = max_seqlen\n        return out if not inplace else x\n\n    @staticmethod\n    def backward(ctx, do):\n        seqlen_offsets = ctx.seqlen_offsets\n        if seqlen_offsets is None:\n            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors\n        else:\n            cos, sin, cu_seqlens = ctx.saved_tensors\n        dx = apply_rotary(\n            do,\n            cos,\n            sin,\n            seqlen_offsets=seqlen_offsets,\n            cu_seqlens=cu_seqlens,\n            max_seqlen=ctx.max_seqlen,\n            interleaved=ctx.interleaved,\n            inplace=ctx.inplace,\n            conjugate=True,\n        )\n        return dx, None, None, None, None, None, None, None\n\n\ndef apply_rotary_emb(\n    x,\n    cos,\n    sin,\n    interleaved=False,\n    inplace=False,\n    
seqlen_offsets: Union[int, Tensor] = 0,\n    cu_seqlens: Optional[Tensor] = None,\n    max_seqlen: Optional[int] = None,\n):\n    \"\"\"\n    Arguments:\n        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None\n            else (total_seqlen, nheads, headdim)\n        cos, sin: (seqlen_rotary, rotary_dim / 2)\n        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead\n            of 1st half and 2nd half (GPT-NeoX style).\n        inplace: if True, apply rotary embedding in-place.\n        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.\n            Most commonly used in inference when we have KV cache.\n        cu_seqlens: (batch + 1,) or None\n        max_seqlen: int\n    Return:\n        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None\n            else (total_seqlen, nheads, headdim)\n    rotary_dim must be <= headdim\n    Apply rotary embedding to the first rotary_dim of x.\n    \"\"\"\n    return ApplyRotaryEmb.apply(\n        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen\n    )\n\n\n# For backward compatibility\napply_rotary_emb_func = apply_rotary_emb\n\n\ndef _apply_rotary_emb_qkv(\n    qkv,\n    cos,\n    sin,\n    cos_k=None,\n    sin_k=None,\n    interleaved=False,\n    inplace=False,\n    conjugate=False,\n    seqlen_offsets: Union[int, Tensor] = 0,\n    num_heads_q: Optional[int] = None,\n):\n    apply_rotary_fn = partial(\n        apply_rotary,\n        interleaved=interleaved,\n        inplace=inplace,\n        conjugate=conjugate,\n        seqlen_offsets=seqlen_offsets\n    )\n    if cos_k is None and sin_k is None and qkv.is_contiguous():\n        # Call 1 kernel instead of 2 kernels\n        # We need qkv to be contiguous so that when we reshape to combine (3, nheads)\n        # dimensions, we get the same tensor\n        if qkv.dim() == 5:\n            batch, seqlen, three, nheads, headdim = qkv.shape\n            
assert three == 3\n            # qk = rearrange(qkv[:, :, :2], \"b s t h d -> b s (t h) d\")\n            qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)\n            qk = apply_rotary_fn(qk, cos, sin)\n        else:\n            assert qkv.dim() == 4\n            assert num_heads_q is not None\n            num_heads_k = (qkv.shape[2] - num_heads_q) // 2\n            assert qkv.shape[2] == num_heads_q + 2 * num_heads_k\n            qk = qkv[:, :, :num_heads_q + num_heads_k]\n            qk = apply_rotary_fn(qk, cos, sin)\n        if not inplace:\n            if qkv.dim() == 5:\n                qkv = torch.cat([rearrange(qk, \"b s (t h) d -> b s t h d\", t=2), qkv[:, :, 2:]], dim=2)\n            else:\n                qkv = torch.cat([qk, qkv[:, :, num_heads_q + num_heads_k :]], dim=2)\n    else:\n        cos_k = cos if cos_k is None else cos_k\n        sin_k = sin if sin_k is None else sin_k\n        if qkv.dim() == 5:\n            batch, seqlen, three, nheads, headdim = qkv.shape\n            assert three == 3\n            q, k = qkv[:, :, 0], qkv[:, :, 1]\n        else:\n            assert qkv.dim() == 4\n            assert num_heads_q is not None\n            num_heads_k = (qkv.shape[2] - num_heads_q) // 2\n            assert qkv.shape[2] == num_heads_q + 2 * num_heads_k\n            q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k]\n        q = apply_rotary_fn(q, cos, sin)\n        k = apply_rotary_fn(k, cos_k, sin_k)\n        if not inplace:\n            if qkv.dim() == 5:\n                qkv = torch.stack([q, k, qkv[:, :, 2]], dim=2)\n            else:\n                qkv = torch.cat([q, k, qkv[:, :, num_heads_q + num_heads_k:]], dim=2)\n    return qkv\n\n\nclass ApplyRotaryEmbQKV_(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        qkv,\n        cos,\n        sin,\n        cos_k=None,\n        sin_k=None,\n        interleaved=False,\n        seqlen_offsets: Union[int, torch.Tensor] = 
0,\n        num_heads_q: Optional[int] = None,\n    ):\n        # apply_rotary_emb_qkv_inplace(\n        qkv = _apply_rotary_emb_qkv(\n            qkv, cos, sin, cos_k, sin_k, interleaved=interleaved, inplace=True,\n            seqlen_offsets=seqlen_offsets, num_heads_q=num_heads_q,\n        )\n        if isinstance(seqlen_offsets, int):\n            ctx.save_for_backward(cos, sin, cos_k, sin_k)\n            ctx.seqlen_offsets = seqlen_offsets\n        else:\n            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)\n            ctx.seqlen_offsets = None\n        ctx.interleaved = interleaved\n        ctx.num_heads_q = num_heads_q\n        return qkv\n\n    @staticmethod\n    def backward(ctx, dqkv):\n        seqlen_offsets = ctx.seqlen_offsets\n        if seqlen_offsets is None:\n            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors\n        else:\n            cos, sin, cos_k, sin_k = ctx.saved_tensors\n        dqkv = _apply_rotary_emb_qkv(\n            dqkv, cos, sin, cos_k, sin_k, interleaved=ctx.interleaved, inplace=True,\n            seqlen_offsets=seqlen_offsets, num_heads_q=ctx.num_heads_q, conjugate=True,\n        )\n        return dqkv, None, None, None, None, None, None, None\n\n\ndef apply_rotary_emb_qkv_(\n    qkv,\n    cos,\n    sin,\n    cos_k=None,\n    sin_k=None,\n    interleaved=False,\n    seqlen_offsets: Union[int, torch.Tensor] = 0,\n    num_heads_q: Optional[int] = None,\n):\n    \"\"\"\n    Arguments:\n        qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim).\n            If qkv has shape (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. 
MQA / GQA),\n            then num_heads_q must be provided.\n        cos, sin: (seqlen, rotary_dim / 2)\n        cos_k, sin_k: (seqlen, rotary_dim / 2), optional\n        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of\n            1st half and 2nd half (GPT-NeoX style).\n        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.\n            Most commonly used in inference when we have KV cache.\n    Return:\n        qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim)\n    rotary_dim must be <= headdim\n    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.\n    \"\"\"\n    return ApplyRotaryEmbQKV_.apply(\n        qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, num_heads_q\n    )\n\n\nclass ApplyRotaryEmbKV_(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):\n        batch, seqlen, two, nheads, headdim = kv.shape\n        assert two == 2\n        k = kv[:, :, 0]\n        apply_rotary(\n            k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True\n        )\n        if isinstance(seqlen_offsets, int):\n            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward\n            ctx.seqlen_offsets = seqlen_offsets\n        else:\n            ctx.save_for_backward(cos, sin, seqlen_offsets)\n            ctx.seqlen_offsets = None\n        ctx.interleaved = interleaved\n        return kv\n\n    @staticmethod\n    def backward(ctx, dkv):\n        seqlen_offsets = ctx.seqlen_offsets\n        if seqlen_offsets is None:\n            cos, sin, seqlen_offsets = ctx.saved_tensors\n        else:\n            cos, sin = ctx.saved_tensors\n        apply_rotary(\n            dkv[:, :, 0],\n            cos,\n            sin,\n            seqlen_offsets=seqlen_offsets,\n  
          interleaved=ctx.interleaved,\n            inplace=True,\n            conjugate=True,\n        )\n        return dkv, None, None, None, None\n\n\ndef apply_rotary_emb_kv_(\n    kv,\n    cos,\n    sin,\n    interleaved=False,\n    seqlen_offsets: Union[int, torch.Tensor] = 0,\n):\n    \"\"\"\n    Arguments:\n        kv: (batch_size, seqlen, 2, nheads, headdim)\n        cos, sin: (seqlen, rotary_dim / 2)\n        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of\n            1st half and 2nd half (GPT-NeoX style).\n        seqlen_offsets: (batch_size,) or int. Each sequence in K is shifted by this amount.\n            Most commonly used in inference when we have KV cache.\n    Return:\n        kv: (batch_size, seqlen, 2, nheads, headdim)\n    rotary_dim must be <= headdim\n    Apply rotary embedding *inplace* to the first rotary_dim of K.\n    \"\"\"\n    return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)\n\n\nclass RotaryEmbedding(torch.nn.Module):\n    \"\"\"\n    The rotary position embeddings from RoFormer_ (Su et al.).\n    A crucial insight from the method is that the queries and keys are\n    transformed by rotation matrices which depend on the relative positions.\n\n    Other implementations are available in the Rotary Transformer repo_ and in\n    GPT-NeoX_; GPT-NeoX was an inspiration.\n\n    .. _RoFormer: https://arxiv.org/abs/2104.09864\n    .. _repo: https://github.com/ZhuiyiTechnology/roformer\n    .. 
_GPT-NeoX: https://github.com/EleutherAI/gpt-neox\n\n    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).\n    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96\n    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        base=10000.0,\n        interleaved=False,\n        scale_base=None,\n        device=None,\n    ):\n        \"\"\"\n        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead\n            of 1st half and 2nd half (GPT-NeoX style).\n        \"\"\"\n        super().__init__()\n        self.dim = dim\n        self.base = float(base)\n        # Generate and save the inverse frequency buffer (non trainable)\n        inv_freq = self._compute_inv_freq(device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.interleaved = interleaved\n        self.scale_base = scale_base\n        scale = (\n            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)\n            if scale_base is not None\n            else None\n        )\n        self.register_buffer(\"scale\", scale, persistent=False)\n\n        self._seq_len_cached = 0\n        self._cos_cached = None\n        self._sin_cached = None\n        self._cos_k_cached = None\n        self._sin_k_cached = None\n\n    def _compute_inv_freq(self, device=None):\n        return 1.0 / (\n            self.base\n            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)\n        )\n\n    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):\n        # Reset the tables if the sequence length has changed,\n        # if we're on a new device (possibly due to tracing for instance),\n        # or if we're switching from inference mode to training\n    
    if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached is None\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n            or (self.training and self._cos_cached.is_inference())\n        ):\n            self._seq_len_cached = seqlen\n            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16\n            # And the output of arange can be quite large, so bf16 would lose a lot of precision.\n            t = torch.arange(seqlen, device=device, dtype=torch.float32)\n            # We want fp32 here as well since inv_freq will be multiplied with t, and the output\n            # will be large. Having it in bf16 will lose a lot of precision and cause the\n            # cos & sin output to change significantly.\n            # We want to recompute self.inv_freq if it was not loaded in fp32\n            if self.inv_freq.dtype != torch.float32:\n                inv_freq = self._compute_inv_freq(device=device)\n            else:\n                inv_freq = self.inv_freq\n            # Don't do einsum, it converts fp32 to bf16 under AMP\n            # freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n            freqs = torch.outer(t, inv_freq)\n            if self.scale is None:\n                self._cos_cached = torch.cos(freqs).to(dtype)\n                self._sin_cached = torch.sin(freqs).to(dtype)\n            else:\n                power = (\n                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)\n                    - seqlen // 2\n                ) / self.scale_base\n                scale = self.scale.to(device=power.device) ** rearrange(power, \"s -> s 1\")\n                # We want the multiplication by scale to happen in fp32\n                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)\n                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)\n                self._cos_k_cached = 
(torch.cos(freqs) / scale).to(dtype)\n                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)\n\n    def forward(\n        self,\n        qkv: torch.Tensor,\n        kv: Optional[torch.Tensor] = None,\n        seqlen_offset: Union[int, torch.Tensor] = 0,\n        max_seqlen: Optional[int] = None,\n        num_heads_q: Optional[int] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"\n        qkv: (batch, seqlen, 3, nheads, headdim) or (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)\n            if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim).\n            If qkv has shape (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA),\n            then num_heads_q must be provided.\n        kv: (batch, seqlen, 2, nheads, headdim)\n        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.\n            Most commonly used in inference when we have KV cache.\n            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one\n            should pass in max_seqlen, which will update the cos / sin cache up to that length.\n        Apply rotary embedding *inplace* to qkv and / or kv.\n        \"\"\"\n        seqlen = qkv.shape[1]\n        if max_seqlen is not None:\n            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)\n        elif isinstance(seqlen_offset, int):\n            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)\n        if kv is None:\n            return apply_rotary_emb_qkv_(\n                qkv,\n                self._cos_cached,\n                self._sin_cached,\n                self._cos_k_cached if self.scale is not None else None,\n                self._sin_k_cached if self.scale is not None else None,\n                interleaved=self.interleaved,\n                seqlen_offsets=seqlen_offset,\n                
num_heads_q=num_heads_q,\n            )\n        else:\n            q = qkv\n            q = apply_rotary_emb_func(\n                q,\n                self._cos_cached,\n                self._sin_cached,\n                interleaved=self.interleaved,\n                inplace=True,\n                seqlen_offsets=seqlen_offset,\n            )\n            kv = apply_rotary_emb_kv_(\n                kv,\n                self._cos_cached if self.scale is None else self._cos_k_cached,\n                self._sin_cached if self.scale is None else self._sin_k_cached,\n                interleaved=self.interleaved,\n                seqlen_offsets=seqlen_offset,\n            )\n            return q, kv\n"
  },
  {
    "path": "flash_attn/losses/__init__.py",
    "content": ""
  },
  {
    "path": "flash_attn/losses/cross_entropy.py",
    "content": "# Copyright (c) 2024, Tri Dao.\n\nimport torch\nimport torch.nn as nn\n\nfrom flash_attn.ops.triton.cross_entropy import cross_entropy_loss\n\n\nclass CrossEntropyLoss(nn.Module):\n    def __init__(\n        self,\n        ignore_index=-100,\n        reduction=\"mean\",\n        label_smoothing=0.0,\n        logit_scale=1.0,\n        lse_square_scale=0.0,\n        inplace_backward=False,\n        process_group=None,\n        return_z_loss=False,\n    ):\n        \"\"\"\n        Arguments:\n            ignore_index: int. If labels == ignore_index, the loss is set to 0.0.\n            label_smoothing: float\n            lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.\n                This is also referred to as \"z-loss\".\n            inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.\n                This saves memory.\n            process_group: if not None, we're doing Tensor Parallel: each process is responsible for\n                one part of the vocab. The loss will be aggregated across processes.\n            return_z_loss: bool. If True, we return the component of the loss contributed by\n                the lse_square_scale value. 
This value is only for logging and does not support\n                backprop.\n        \"\"\"\n        super().__init__()\n        if reduction not in [\"mean\", \"none\", \"sum\"]:\n            raise NotImplementedError(\"Only support reduction = 'mean' or 'none' or 'sum'\")\n        self.ignore_index = ignore_index\n        self.reduction = reduction\n        self.label_smoothing = label_smoothing\n        self.logit_scale = logit_scale\n        self.lse_square_scale = lse_square_scale\n        self.inplace_backward = inplace_backward\n        self.process_group = process_group\n        self.return_z_loss = return_z_loss\n\n    def forward(self, input, target, precomputed_lse=None):\n        \"\"\"\n        Arguments:\n            input: (batch, vocab_size)\n            target: (batch,)\n        Returns:\n            losses: (batch,) if reduction is 'none', else (1,), dtype float\n            z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)\n        \"\"\"\n        assert input.is_cuda and target.is_cuda, \"Only support CUDA tensors\"\n        loss, z_loss = cross_entropy_loss(\n            input,\n            target,\n            precomputed_lse=precomputed_lse,\n            label_smoothing=self.label_smoothing,\n            logit_scale=self.logit_scale,\n            lse_square_scale=self.lse_square_scale,\n            ignore_index=self.ignore_index,\n            inplace_backward=self.inplace_backward,\n            process_group=self.process_group,\n        )\n        if self.reduction == \"mean\":\n            loss = loss.sum() / (target != self.ignore_index).sum()\n        elif self.reduction == \"sum\":\n            loss = loss.sum()\n        else:\n            loss = loss\n\n        if not self.return_z_loss:\n            return loss\n\n        if self.reduction == \"mean\":\n            z_loss = z_loss.sum() / (target != self.ignore_index).sum()\n        elif self.reduction == \"sum\":\n            z_loss = 
z_loss.sum()\n        else:\n            z_loss = z_loss\n\n        return loss, z_loss\n"
  },
  {
    "path": "flash_attn/models/__init__.py",
    "content": ""
  },
  {
    "path": "flash_attn/models/baichuan.py",
    "content": "# Copyright (c) 2023, GGGGGGXY, Tri Dao.\n\nimport math\nimport json\nimport re\nfrom pathlib import Path\n\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn.functional as F\n\nfrom einops import rearrange\nfrom transformers import GPT2Config, AutoConfig, PretrainedConfig\n\n\ndef remap_state_dict_hf_baichuan(state_dict, config):\n    def key_mapping_layers(key):\n        return re.sub(r\"^model.\", \"transformer.\", key)\n\n    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())\n\n    # Word embedding\n    def key_mapping_emb(key):\n        return re.sub(\n            r\"^transformer.embed_tokens.\",\n            \"transformer.embeddings.word_embeddings.\",\n            key,\n        )\n\n    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())\n    word_embeddings = state_dict.pop(\"transformer.embeddings.word_embeddings.weight\")\n    # It's possible that vocab_size is padded to be a multiple of 8, for example.\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = (\n        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)\n        * pad_vocab_size_multiple\n    )\n    state_dict[\"transformer.embeddings.word_embeddings.weight\"] = F.pad(\n        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])\n    )\n    if getattr(config, \"tie_word_embeddings\"):\n        state_dict[\"lm_head.weight\"] = state_dict[\n            \"transformer.embeddings.word_embeddings.weight\"\n        ]\n    else:\n        output_embeddings = state_dict.pop(\"lm_head.weight\")\n        # Need to recompute vocab_size since Baichuan shards the word embeddings and output embeddings\n        # differently.\n        vocab_size = (\n            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)\n            * pad_vocab_size_multiple\n        )\n        # It's possible that vocab_size is padded to be a multiple of 
8, for example.\n        state_dict[\"lm_head.weight\"] = F.pad(\n            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])\n        )\n\n    # LayerNorm\n    def key_mapping_ln(key):\n        key = re.sub(r\"^transformer.norm.\", r\"transformer.ln_f.\", key)\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).input_layernorm.\",\n            r\"transformer.layers.\\1.norm1.\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).post_attention_layernorm.\",\n            r\"transformer.layers.\\1.norm2.\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    # MLP\n    for l in range(config.n_layer):\n        w1 = state_dict.pop(f\"transformer.layers.{l}.mlp.gate_proj.weight\")\n        w3 = state_dict.pop(f\"transformer.layers.{l}.mlp.up_proj.weight\")\n        # Our ordering is different\n        state_dict[f\"transformer.layers.{l}.mlp.fc1.weight\"] = torch.cat(\n            [w3, w1], dim=0\n        )\n\n    def key_mapping_mlp(key):\n        return re.sub(\n            r\"^transformer.layers.(\\d+).mlp.down_proj.\",\n            r\"transformer.layers.\\1.mlp.fc2.\",\n            key,\n        )\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    # Attention\n    def key_mapping_attn(key):\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).self_attn.W_pack.\",\n            r\"transformer.layers.\\1.mixer.Wqkv.\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).self_attn.o_proj.\",\n            r\"transformer.layers.\\1.mixer.out_proj.\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n    for l in range(config.n_layer):\n        # pop rotary_emb.inv_freq from state dict\n        
state_dict.pop(f\"transformer.layers.{l}.self_attn.rotary_emb.inv_freq\", None)\n    return state_dict\n\n\ndef baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:\n    # HACK: the config doesn't say whether it's rotary or alibi.\n    # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).\n    # HACK: the config doesn't say whether it uses norm head.\n    # So we have to infer from the vocab size\n    # (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head).\n    use_rotary = baichuan_config.hidden_size < 5000\n    return GPT2Config(\n        vocab_size=baichuan_config.vocab_size,\n        n_positions=0,  # No absolute position embedding\n        n_embd=baichuan_config.hidden_size,\n        n_layer=baichuan_config.num_hidden_layers,\n        n_head=baichuan_config.num_attention_heads,\n        n_inner=baichuan_config.intermediate_size,\n        activation_function=\"swiglu\",  # Hardcode since HF calls it 'silu'\n        # Baichuan doesn't use dropout, possibly because only the inference code was released\n        resid_pdrop=0.0,\n        embd_pdrop=0.0,\n        attn_pdrop=0.0,\n        layer_norm_epsilon=baichuan_config.rms_norm_eps,\n        initializer_range=baichuan_config.initializer_range,\n        bos_token_id=baichuan_config.bos_token_id,\n        eos_token_id=baichuan_config.eos_token_id,\n        # These are new arguments not in the original GPT2Config\n        pad_token_id=baichuan_config.pad_token_id,  # Unclear whether this has any effect\n        rms_norm=True,\n        rotary_emb_fraction=1.0 if use_rotary else 0.0,\n        rotary_emb_interleaved=False,\n        use_alibi=not use_rotary,\n        use_flash_attn=not use_rotary,  # Alibi code path requires flash_attn\n        tie_word_embeddings=False,\n        norm_head=baichuan_config.vocab_size > 70000,\n        qkv_proj_bias=False,\n        out_proj_bias=False,\n        mlp_fc1_bias=False,\n        mlp_fc2_bias=False,\n    )\n"
  },
  {
    "path": "flash_attn/models/bert.py",
    "content": "# Copyright (c) 2022, Tri Dao.\n# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.\n# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py\n# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py\n\n# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py\n\nimport logging\nimport re\nfrom collections import OrderedDict\nfrom collections.abc import Sequence\nfrom functools import partial\nfrom typing import Any, Mapping\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom transformers import BertConfig, PretrainedConfig\nfrom transformers.models.bert.modeling_bert import (\n    BaseModelOutputWithPoolingAndCrossAttentions,\n    BertForPreTrainingOutput,\n)\n\nfrom flash_attn.bert_padding import (\n    index_first_axis,\n    index_first_axis_residual,\n    pad_input,\n    unpad_input,\n)\nfrom flash_attn.modules.block import Block\nfrom flash_attn.modules.embedding import BertEmbeddings\nfrom flash_attn.modules.mha import MHA\nfrom flash_attn.modules.mlp import FusedMLP, Mlp\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\n\ntry:\n    from flash_attn.ops.fused_dense import FusedDense\nexcept ImportError:\n    FusedDense = None\n\ntry:\n    from flash_attn.ops.triton.layer_norm import layer_norm_fn\nexcept ImportError:\n    layer_norm_fn = None\n\n\ntry:\n    from flash_attn.losses.cross_entropy import CrossEntropyLoss\nexcept ImportError:\n    CrossEntropyLoss = None\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef create_mixer_cls(config, cross_attn=False, return_residual=False):\n    use_flash_attn = getattr(config, \"use_flash_attn\", False)\n    fused_bias_fc = getattr(config, \"fused_bias_fc\", False)\n    
rotary_kwargs = {}\n    if config.position_embedding_type == \"rotary\":\n        rotary_kwargs[\"rotary_emb_dim\"] = getattr(config, \"rotary_emb_dim\", config.hidden_size)\n        rotary_kwargs[\"rotary_emb_base\"] = getattr(config, \"rotary_emb_base\", 10000.0)\n        rotary_kwargs[\"rotary_emb_scale_base\"] = getattr(config, \"rotary_emb_scale_base\", None)\n        rotary_kwargs[\"rotary_emb_interleaved\"] = getattr(config, \"rotary_emb_interleaved\", False)\n    mixer_cls = partial(\n        MHA,\n        num_heads=config.num_attention_heads,\n        cross_attn=cross_attn,\n        dropout=config.attention_probs_dropout_prob,\n        causal=False,\n        fused_bias_fc=fused_bias_fc,\n        use_flash_attn=use_flash_attn,\n        return_residual=return_residual,\n        **rotary_kwargs,\n    )\n    return mixer_cls\n\n\ndef create_mlp_cls(config, layer_idx=None, return_residual=False):\n    inner_dim = config.intermediate_size\n    fused_mlp = getattr(config, \"fused_mlp\", False)\n    if fused_mlp:\n        assert config.hidden_act in [\"gelu_new\", \"gelu_fast\", \"gelu_pytorch_tanh\"], (\n            \"fused_mlp only \" \"supports approximate gelu\"\n        )\n    if not fused_mlp:\n        approximate = (\n            \"tanh\"\n            if config.hidden_act in [\"gelu_new\", \"gelu_fast\", \"gelu_pytorch_tanh\"]\n            else \"none\"\n        )\n        mlp_cls = partial(\n            Mlp,\n            hidden_features=inner_dim,\n            activation=partial(F.gelu, approximate=approximate),\n            return_residual=return_residual,\n        )\n    else:\n        if FusedMLP is None:\n            raise ImportError(\"fused_dense is not installed\")\n        mlp_checkpoint_lvl = getattr(config, \"mlp_checkpoint_lvl\", 0)\n        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer\n        if isinstance(mlp_checkpoint_lvl, Sequence):\n            assert layer_idx is not None\n            
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]\n        mlp_cls = partial(\n            FusedMLP,\n            hidden_features=inner_dim,\n            checkpoint_lvl=mlp_checkpoint_lvl,\n            return_residual=return_residual,\n        )\n    return mlp_cls\n\n\ndef create_block(config, layer_idx=None):\n    last_layer_subset = getattr(config, \"last_layer_subset\", False)\n    cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1\n    # TD [2022-12-19]: For cross attention (last layer), we actually want to return the\n    # residual x_kv, not residual x. But it's annoying to change the API (and it only affects\n    # one layer) so we just choose not to return residual in this case.\n    return_residual = not cross_attn\n    mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)\n    mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)\n    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)\n    block = Block(\n        config.hidden_size,\n        mixer_cls,\n        mlp_cls,\n        norm_cls=norm_cls,\n        prenorm=False,\n        resid_dropout1=config.hidden_dropout_prob,\n        resid_dropout2=config.hidden_dropout_prob,\n        fused_dropout_add_ln=getattr(config, \"fused_dropout_add_ln\", False),\n        return_residual=return_residual,\n    )\n    return block\n\n\n# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748\ndef _init_weights(module, initializer_range=0.02):\n    if isinstance(module, nn.Linear):\n        nn.init.normal_(module.weight, std=initializer_range)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n    elif isinstance(module, nn.Embedding):\n        nn.init.normal_(module.weight, std=initializer_range)\n        if module.padding_idx is not None:\n            nn.init.zeros_(module.weight[module.padding_idx])\n\n\nclass 
BertEncoder(nn.Module):\n    def __init__(self, config: BertConfig):\n        super().__init__()\n        self.use_flash_attn = getattr(config, \"use_flash_attn\", False)\n        self.layers = nn.ModuleList(\n            [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]\n        )\n\n    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):\n        \"\"\"If subset_mask is not None, we only want output for the subset of the sequence.\n        This means that we only compute the last layer output for these tokens.\n        subset_mask: (batch, seqlen), dtype=torch.bool\n        \"\"\"\n        if key_padding_mask is None or not self.use_flash_attn:\n            mixer_kwargs = (\n                {\"key_padding_mask\": key_padding_mask} if key_padding_mask is not None else None\n            )\n            for layer in self.layers:\n                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)\n            if subset_mask is not None:\n                hidden_states = hidden_states[subset_mask]\n        else:\n            batch, seqlen = hidden_states.shape[:2]\n            hidden_states, indices, cu_seqlens, max_seqlen_in_batch, _ = unpad_input(\n                hidden_states, key_padding_mask\n            )\n            mixer_kwargs = {\"cu_seqlens\": cu_seqlens, \"max_seqlen\": max_seqlen_in_batch}\n            if subset_mask is None:\n                for layer in self.layers:\n                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)\n                hidden_states = pad_input(hidden_states, indices, batch, seqlen)\n            else:\n                for layer in self.layers[:-1]:\n                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)\n                if key_padding_mask is not None:\n                    subset_idx = torch.nonzero(\n                        subset_mask[key_padding_mask], as_tuple=False\n                    ).flatten()\n             
       subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)\n                    subset_cu_seqlens = F.pad(\n                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)\n                    )\n                else:\n                    subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()\n                    subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)\n                    subset_cu_seqlens = F.pad(\n                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)\n                    )\n                hidden_states_subset, hidden_states = index_first_axis_residual(\n                    hidden_states, subset_idx\n                )\n                # It's ok to set max_seqlen_q to be much larger\n                mixer_kwargs = {\n                    \"x_kv\": hidden_states,\n                    \"cu_seqlens\": subset_cu_seqlens,\n                    \"max_seqlen\": max_seqlen_in_batch,\n                    \"cu_seqlens_k\": cu_seqlens,\n                    \"max_seqlen_k\": max_seqlen_in_batch,\n                }\n                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)\n        return hidden_states\n\n\nclass BertPooler(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        fused_bias_fc = getattr(config, \"fused_bias_fc\", False)\n        if fused_bias_fc and FusedDense is None:\n            raise ImportError(\"fused_dense is not installed\")\n        linear_cls = nn.Linear if not fused_bias_fc else FusedDense\n        self.dense = linear_cls(config.hidden_size, config.hidden_size)\n        self.activation = nn.Tanh()\n\n    def forward(self, hidden_states, pool=True):\n        # We \"pool\" the model by simply taking the hidden state corresponding\n        # to the first token.\n        first_token_tensor = hidden_states[:, 0] if pool else hidden_states\n        pooled_output = 
self.dense(first_token_tensor)\n        pooled_output = self.activation(pooled_output)\n        return pooled_output\n\n\nclass BertPredictionHeadTransform(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        fused_bias_fc = getattr(config, \"fused_bias_fc\", False)\n        if fused_bias_fc and FusedDense is None:\n            raise ImportError(\"fused_dense is not installed\")\n        self.fused_dropout_add_ln = getattr(config, \"fused_dropout_add_ln\", False)\n        if self.fused_dropout_add_ln and layer_norm_fn is None:\n            raise ImportError(\"Triton is not installed\")\n        linear_cls = nn.Linear if not fused_bias_fc else FusedDense\n        self.dense = linear_cls(config.hidden_size, config.hidden_size)\n        approximate = (\n            \"tanh\"\n            if config.hidden_act in [\"gelu_new\", \"gelu_fast\", \"gelu_pytorch_tanh\"]\n            else \"none\"\n        )\n        self.transform_act_fn = nn.GELU(approximate=approximate)\n        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.transform_act_fn(hidden_states)\n        if not self.fused_dropout_add_ln:\n            hidden_states = self.layer_norm(hidden_states)\n        else:\n            hidden_states = layer_norm_fn(\n                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps\n            )\n        return hidden_states\n\n\nclass BertLMPredictionHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        fused_bias_fc = getattr(config, \"fused_bias_fc\", False)\n        if fused_bias_fc and FusedDense is None:\n            raise ImportError(\"fused_dense is not installed\")\n        linear_cls = nn.Linear if not fused_bias_fc else FusedDense\n\n        self.transform = 
BertPredictionHeadTransform(config)\n\n        # The output weights are the same as the input embeddings, but there is\n        # an output-only bias for each token.\n        self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)\n\n    def forward(self, hidden_states):\n        hidden_states = self.transform(hidden_states)\n        hidden_states = self.decoder(hidden_states)\n        return hidden_states\n\n\nclass BertPreTrainingHeads(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.predictions = BertLMPredictionHead(config)\n        self.seq_relationship = nn.Linear(config.hidden_size, 2)\n\n    def forward(self, sequence_output, pooled_output):\n        prediction_scores = self.predictions(sequence_output)\n        seq_relationship_score = self.seq_relationship(pooled_output)\n        return prediction_scores, seq_relationship_score\n\n\nclass BertPreTrainedModel(nn.Module):\n    \"\"\"An abstract class to handle weights initialization and\n    a simple interface for downloading and loading pretrained models.\n    \"\"\"\n\n    def __init__(self, config, *inputs, **kwargs):\n        super().__init__()\n        if not isinstance(config, BertConfig):\n            raise ValueError(\n                \"Parameter config in `{}(config)` should be an instance of class `BertConfig`. 
\"\n                \"To create a model from a Google pretrained model use \"\n                \"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`\".format(\n                    self.__class__.__name__, self.__class__.__name__\n                )\n            )\n        self.config = config\n\n    @classmethod\n    def from_pretrained(cls, model_name, config, *inputs, **kwargs):\n        \"\"\"\n        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.\n        Download and cache the pre-trained model file if needed.\n\n        Params:\n            pretrained_model_name_or_path: either:\n                - a path or url to a pretrained model archive containing:\n                    . `bert_config.json` a configuration file for the model\n                    . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance\n                - a path or url to a pretrained model archive containing:\n                    . `bert_config.json` a configuration file for the model\n                    . 
`model.chkpt` a TensorFlow checkpoint\n            *inputs, **kwargs: additional input for the specific Bert class\n                (ex: num_labels for BertForSequenceClassification)\n        \"\"\"\n        # Instantiate model.\n        model = cls(config, *inputs, **kwargs)\n        load_return = model.load_state_dict(\n            remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False\n        )\n        logger.info(load_return)\n        return model\n\n\nclass BertModel(BertPreTrainedModel):\n    def __init__(self, config: BertConfig, add_pooling_layer=True):\n        super().__init__(config)\n        self.pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n        if config.vocab_size % self.pad_vocab_size_multiple != 0:\n            config.vocab_size += self.pad_vocab_size_multiple - (\n                config.vocab_size % self.pad_vocab_size_multiple\n            )\n        self.fused_dropout_add_ln = getattr(config, \"fused_dropout_add_ln\", False)\n        if self.fused_dropout_add_ln and layer_norm_fn is None:\n            raise ImportError(\"Triton is not installed\")\n        assert config.hidden_act in [\"gelu\", \"gelu_new\", \"gelu_fast\", \"gelu_pytorch_tanh\"]\n\n        self.embeddings = BertEmbeddings(\n            config.hidden_size,\n            config.vocab_size,\n            config.max_position_embeddings,\n            config.type_vocab_size,\n            padding_idx=config.pad_token_id,\n        )\n        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)\n        self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.encoder = BertEncoder(config)\n        self.pooler = BertPooler(config) if add_pooling_layer else None\n\n        self.apply(partial(_init_weights, initializer_range=config.initializer_range))\n\n    def forward(\n        self,\n        input_ids,\n        position_ids=None,\n        token_type_ids=None,\n        attention_mask=None,\n        
masked_tokens_mask=None,\n    ):\n        \"\"\"If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),\n        we only want the output for the masked tokens. This means that we only compute the last\n        layer output for these tokens.\n        masked_tokens_mask: (batch, seqlen), dtype=torch.bool\n        \"\"\"\n        hidden_states = self.embeddings(\n            input_ids, position_ids=position_ids, token_type_ids=token_type_ids\n        )\n        # TD [2022-12-18]: Don't need to force residual in fp32\n        # BERT puts embedding LayerNorm before embedding dropout.\n        if not self.fused_dropout_add_ln:\n            hidden_states = self.emb_ln(hidden_states)\n        else:\n            hidden_states = layer_norm_fn(\n                hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps\n            )\n        hidden_states = self.emb_drop(hidden_states)\n\n        if masked_tokens_mask is not None:\n            batch_size, seqlen = input_ids.shape[:2]\n            # We also need the first column for the CLS token\n            first_col_mask = torch.zeros(\n                batch_size, seqlen, dtype=torch.bool, device=input_ids.device\n            )\n            first_col_mask[:, 0] = True\n            subset_mask = masked_tokens_mask | first_col_mask\n        else:\n            subset_mask = None\n\n        sequence_output = self.encoder(\n            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask\n        )\n\n        if masked_tokens_mask is None:\n            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None\n        else:\n            # TD [2022-03-01]: the indexing here is very tricky.\n            if attention_mask is not None:\n                subset_idx = subset_mask[attention_mask]\n                pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]\n                sequence_output = 
sequence_output[masked_tokens_mask[attention_mask][subset_idx]]\n            else:\n                pool_input = sequence_output[first_col_mask[subset_mask]]\n                sequence_output = sequence_output[masked_tokens_mask[subset_mask]]\n            pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None\n\n        return BaseModelOutputWithPoolingAndCrossAttentions(\n            last_hidden_state=sequence_output,\n            pooler_output=pooled_output,\n        )\n\n\nclass BertForPreTraining(BertPreTrainedModel):\n    def __init__(self, config: BertConfig):\n        super().__init__(config)\n        # If dense_seq_output, we only need to pass the hidden states for the masked out tokens\n        # (around 15%) to the classifier heads.\n        self.dense_seq_output = getattr(config, \"dense_seq_output\", False)\n        # If last_layer_subset, we only need to compute the last layer for a subset of tokens\n        # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).\n        self.last_layer_subset = getattr(config, \"last_layer_subset\", False)\n        if self.last_layer_subset:\n            assert self.dense_seq_output, \"last_layer_subset requires dense_seq_output\"\n        use_xentropy = getattr(config, \"use_xentropy\", False)\n        if use_xentropy and CrossEntropyLoss is None:\n            raise ImportError(\"xentropy_cuda is not installed\")\n        loss_cls = (\n            nn.CrossEntropyLoss\n            if not use_xentropy\n            else partial(CrossEntropyLoss, inplace_backward=True)\n        )\n\n        self.bert = BertModel(config)\n        self.cls = BertPreTrainingHeads(config)\n        self.mlm_loss = loss_cls(ignore_index=0)\n        self.nsp_loss = loss_cls(ignore_index=-1)\n\n        # Initialize weights and apply final processing\n        self.apply(partial(_init_weights, initializer_range=config.initializer_range))\n        self.tie_weights()\n\n    def 
tie_weights(self):\n        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight\n\n    def forward(\n        self,\n        input_ids,\n        position_ids=None,\n        token_type_ids=None,\n        attention_mask=None,\n        labels=None,\n        next_sentence_label=None,\n    ):\n        \"\"\"\n        If labels are provided, they must be 0 for masked out tokens (as specified in the attention\n        mask).\n        Outputs:\n            if `labels` and `next_sentence_label` are not `None`:\n                Outputs the total_loss which is the sum of the masked language modeling loss and the next\n                sentence classification loss.\n            if `labels` or `next_sentence_label` is `None`:\n                Outputs a tuple comprising\n                - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and\n                - the next sentence classification logits of shape [batch_size, 2].\n\n        \"\"\"\n        masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None\n        outputs = self.bert(\n            input_ids,\n            position_ids=position_ids,\n            token_type_ids=token_type_ids,\n            attention_mask=attention_mask.bool() if attention_mask is not None else None,\n            masked_tokens_mask=masked_tokens_mask,\n        )\n        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output\n        if self.dense_seq_output and labels is not None:\n            masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()\n            if not self.last_layer_subset:\n                sequence_output = index_first_axis(\n                    rearrange(sequence_output, \"b s d -> (b s) d\"), masked_token_idx\n                )\n        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)\n\n        total_loss = None\n        if labels is not 
None and next_sentence_label is not None:\n            if (\n                self.dense_seq_output and labels is not None\n            ):  # prediction_scores are already flattened\n                masked_lm_loss = self.mlm_loss(\n                    prediction_scores, labels.flatten()[masked_token_idx]\n                )\n            else:\n                masked_lm_loss = self.mlm_loss(\n                    rearrange(prediction_scores, \"... v -> (...) v\"),\n                    rearrange(labels, \"... -> (...)\"),\n                )\n            next_sentence_loss = self.nsp_loss(\n                rearrange(seq_relationship_score, \"... t -> (...) t\"),\n                rearrange(next_sentence_label, \"... -> (...)\"),\n            )\n            total_loss = masked_lm_loss.float() + next_sentence_loss.float()\n\n        return BertForPreTrainingOutput(\n            loss=total_loss,\n            prediction_logits=prediction_scores,\n            seq_relationship_logits=seq_relationship_score,\n        )\n\n\ndef remap_state_dict(state_dict, config: PretrainedConfig):\n    \"\"\"\n    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.\n    \"\"\"\n\n    # LayerNorm\n    def key_mapping_ln_gamma_beta(key):\n        key = re.sub(r\"LayerNorm.gamma$\", \"LayerNorm.weight\", key)\n        key = re.sub(r\"LayerNorm.beta$\", \"LayerNorm.bias\", key)\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())\n\n    # Layers\n    def key_mapping_layers(key):\n        return re.sub(r\"^bert.encoder.layer.\", \"bert.encoder.layers.\", key)\n\n    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())\n\n    # LayerNorm\n    def key_mapping_ln(key):\n        key = re.sub(r\"^bert.embeddings.LayerNorm.\", \"bert.emb_ln.\", key)\n        key = re.sub(\n            r\"^bert.encoder.layers.(\\d+).attention.output.LayerNorm.(weight|bias)\",\n            
r\"bert.encoder.layers.\\1.norm1.\\2\",\n            key,\n        )\n        key = re.sub(\n            r\"^bert.encoder.layers.(\\d+).output.LayerNorm.(weight|bias)\",\n            r\"bert.encoder.layers.\\1.norm2.\\2\",\n            key,\n        )\n        key = re.sub(\n            r\"^cls.predictions.transform.LayerNorm.(weight|bias)\",\n            r\"cls.predictions.transform.layer_norm.\\1\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    # MLP\n    def key_mapping_mlp(key):\n        key = re.sub(\n            r\"^bert.encoder.layers.(\\d+).intermediate.dense.(weight|bias)\",\n            r\"bert.encoder.layers.\\1.mlp.fc1.\\2\",\n            key,\n        )\n        key = re.sub(\n            r\"^bert.encoder.layers.(\\d+).output.dense.(weight|bias)\",\n            r\"bert.encoder.layers.\\1.mlp.fc2.\\2\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    # Attention\n    last_layer_subset = getattr(config, \"last_layer_subset\", False)\n    for d in range(config.num_hidden_layers):\n        Wq = state_dict.pop(f\"bert.encoder.layers.{d}.attention.self.query.weight\")\n        Wk = state_dict.pop(f\"bert.encoder.layers.{d}.attention.self.key.weight\")\n        Wv = state_dict.pop(f\"bert.encoder.layers.{d}.attention.self.value.weight\")\n        bq = state_dict.pop(f\"bert.encoder.layers.{d}.attention.self.query.bias\")\n        bk = state_dict.pop(f\"bert.encoder.layers.{d}.attention.self.key.bias\")\n        bv = state_dict.pop(f\"bert.encoder.layers.{d}.attention.self.value.bias\")\n        if not (last_layer_subset and d == config.num_hidden_layers - 1):\n            state_dict[f\"bert.encoder.layers.{d}.mixer.Wqkv.weight\"] = torch.cat(\n                [Wq, Wk, Wv], dim=0\n            )\n            state_dict[f\"bert.encoder.layers.{d}.mixer.Wqkv.bias\"] = 
torch.cat([bq, bk, bv], dim=0)\n        else:\n            state_dict[f\"bert.encoder.layers.{d}.mixer.Wq.weight\"] = Wq\n            state_dict[f\"bert.encoder.layers.{d}.mixer.Wkv.weight\"] = torch.cat([Wk, Wv], dim=0)\n            state_dict[f\"bert.encoder.layers.{d}.mixer.Wq.bias\"] = bq\n            state_dict[f\"bert.encoder.layers.{d}.mixer.Wkv.bias\"] = torch.cat([bk, bv], dim=0)\n\n    def key_mapping_attn(key):\n        return re.sub(\n            r\"^bert.encoder.layers.(\\d+).attention.output.dense.(weight|bias)\",\n            r\"bert.encoder.layers.\\1.mixer.out_proj.\\2\",\n            key,\n        )\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n\n    def key_mapping_decoder_bias(key):\n        return re.sub(r\"^cls.predictions.bias\", \"cls.predictions.decoder.bias\", key)\n\n    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())\n\n    # Word embedding\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    if pad_vocab_size_multiple > 1:\n        word_embeddings = state_dict[\"bert.embeddings.word_embeddings.weight\"]\n        state_dict[\"bert.embeddings.word_embeddings.weight\"] = F.pad(\n            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])\n        )\n        decoder_weight = state_dict[\"cls.predictions.decoder.weight\"]\n        state_dict[\"cls.predictions.decoder.weight\"] = F.pad(\n            decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])\n        )\n        # If the vocab was padded, we want to set the decoder bias for those padded indices to be\n        # strongly negative (i.e. 
the decoder shouldn't predict those indices).\n        # TD [2022-05-09]: I don't think it affects the MLPerf training.\n        decoder_bias = state_dict[\"cls.predictions.decoder.bias\"]\n        state_dict[\"cls.predictions.decoder.bias\"] = F.pad(\n            decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0\n        )\n\n    return state_dict\n\n\ndef inv_remap_state_dict(state_dict, config: PretrainedConfig):\n    \"\"\"\n    Map the state_dict of a flash_attn model to be Huggingface BERT compatible.\n\n    This function is meant to be the inverse of remap_state_dict.\n    \"\"\"\n    # Word embedding\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    if pad_vocab_size_multiple > 1:\n        word_embeddings = state_dict[\"bert.embeddings.word_embeddings.weight\"]\n        decoder_weight = state_dict[\"cls.predictions.decoder.weight\"]\n        decoder_bias = state_dict[\"cls.predictions.decoder.bias\"]\n        # unpad embeddings\n        state_dict[\"bert.embeddings.word_embeddings.weight\"] = word_embeddings[\n            : config.orig_vocab_size, :\n        ]\n        state_dict[\"cls.predictions.decoder.weight\"] = decoder_weight[: config.orig_vocab_size, :]\n        state_dict[\"cls.predictions.decoder.bias\"] = decoder_bias[: config.orig_vocab_size]\n\n    for d in range(config.num_hidden_layers):\n        last_layer_subset = getattr(config, \"last_layer_subset\", False)\n        if not last_layer_subset or d != (config.num_hidden_layers - 1):\n            Wqkv_weights = state_dict.pop(f\"bert.encoder.layers.{d}.mixer.Wqkv.weight\")\n            Wqkv_biases = state_dict.pop(f\"bert.encoder.layers.{d}.mixer.Wqkv.bias\")\n            state_dict[f\"bert.encoder.layers.{d}.attention.self.query.weight\"] = Wqkv_weights[\n                : Wqkv_weights.shape[0] // 3, :\n            ]\n            state_dict[f\"bert.encoder.layers.{d}.attention.self.key.weight\"] = Wqkv_weights[\n                
Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :\n            ]\n            state_dict[f\"bert.encoder.layers.{d}.attention.self.value.weight\"] = Wqkv_weights[\n                2 * Wqkv_weights.shape[0] // 3 :, :\n            ]\n            state_dict[f\"bert.encoder.layers.{d}.attention.self.query.bias\"] = Wqkv_biases[\n                : Wqkv_biases.shape[0] // 3\n            ]\n            state_dict[f\"bert.encoder.layers.{d}.attention.self.key.bias\"] = Wqkv_biases[\n                Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3\n            ]\n            state_dict[f\"bert.encoder.layers.{d}.attention.self.value.bias\"] = Wqkv_biases[\n                2 * Wqkv_biases.shape[0] // 3 :\n            ]\n        else:\n            Wq_weight = state_dict.pop(f\"bert.encoder.layers.{d}.mixer.Wq.weight\")\n            Wkv_weights = state_dict.pop(f\"bert.encoder.layers.{d}.mixer.Wkv.weight\")\n            Wq_bias = state_dict.pop(f\"bert.encoder.layers.{d}.mixer.Wq.bias\")\n            Wkv_biases = state_dict.pop(f\"bert.encoder.layers.{d}.mixer.Wkv.bias\")\n            state_dict[f\"bert.encoder.layers.{d}.attention.self.query.weight\"] = Wq_weight\n            state_dict[f\"bert.encoder.layers.{d}.attention.self.key.weight\"] = Wkv_weights[\n                : Wkv_weights.shape[0] // 2, :\n            ]\n            state_dict[f\"bert.encoder.layers.{d}.attention.self.value.weight\"] = Wkv_weights[\n                Wkv_weights.shape[0] // 2 :, :\n            ]\n            state_dict[f\"bert.encoder.layers.{d}.attention.self.query.bias\"] = Wq_bias\n            state_dict[f\"bert.encoder.layers.{d}.attention.self.key.bias\"] = Wkv_biases[\n                : Wkv_biases.shape[0] // 2\n            ]\n            state_dict[f\"bert.encoder.layers.{d}.attention.self.value.bias\"] = Wkv_biases[\n                Wkv_biases.shape[0] // 2 :\n            ]\n\n    def inv_key_mapping_ln(key):\n        key = re.sub(r\"bert.emb_ln.\", 
\"bert.embeddings.LayerNorm.\", key)\n        key = re.sub(\n            r\"bert.encoder.layers.(\\d+).norm1.(weight|bias)\",\n            r\"bert.encoder.layers.\\1.attention.output.LayerNorm.\\2\",\n            key,\n        )\n        key = re.sub(\n            r\"bert.encoder.layers.(\\d+).norm2.(weight|bias)\",\n            r\"bert.encoder.layers.\\1.output.LayerNorm.\\2\",\n            key,\n        )\n        key = re.sub(\n            r\"cls.predictions.transform.layer_norm.(weight|bias)\",\n            r\"cls.predictions.transform.LayerNorm.\\1\",\n            key,\n        )\n        return key\n\n    def inv_key_mapping_ln_gamma_beta(key):\n        key = re.sub(r\"LayerNorm.weight$\", \"LayerNorm.gamma\", key)\n        key = re.sub(r\"LayerNorm.bias$\", \"LayerNorm.beta\", key)\n        return key\n\n    def inv_key_mapping_layers(key):\n        return re.sub(r\"bert.encoder.layers.\", \"bert.encoder.layer.\", key)\n\n    def inv_key_mapping_mlp(key):\n        key = re.sub(\n            r\"bert.encoder.layer.(\\d+).mlp.fc1.(weight|bias)\",\n            r\"bert.encoder.layer.\\1.intermediate.dense.\\2\",\n            key,\n        )\n        key = re.sub(\n            r\"bert.encoder.layer.(\\d+).mlp.fc2.(weight|bias)\",\n            r\"bert.encoder.layer.\\1.output.dense.\\2\",\n            key,\n        )\n        return key\n\n    def inv_key_mapping_attn(key):\n        return re.sub(\n            r\"bert.encoder.layer.(\\d+).mixer.out_proj.(weight|bias)\",\n            r\"bert.encoder.layer.\\1.attention.output.dense.\\2\",\n            key,\n        )\n\n    def inv_key_mapping_decoder_bias(key):\n        return re.sub(r\"cls.predictions.decoder.bias\", \"cls.predictions.bias\", key)\n\n    state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())\n    state_dict = OrderedDict(\n        (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()\n    )\n    state_dict = OrderedDict(\n        
(inv_key_mapping_layers(key), value) for key, value in state_dict.items()\n    )\n    state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())\n    state_dict = OrderedDict(\n        (inv_key_mapping_attn(key), value) for key, value in state_dict.items()\n    )\n    state_dict = OrderedDict(\n        (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()\n    )\n\n    return state_dict\n"
  },
  {
    "path": "flash_attn/models/bigcode.py",
    "content": "import math\nimport re\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn.functional as F\nfrom transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig\n\n\ndef remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):\n    \"\"\"\n    Map the state_dict of a Huggingface BigCode model to be flash_attn compatible.\n    \"\"\"\n\n    # Word embedding and position embedding\n    def key_mapping_pos_emb(key):\n        return re.sub(r\"^transformer.wpe.\", \"transformer.embeddings.position_embeddings.\", key)\n\n    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())\n    word_embeddings = state_dict.pop(\"transformer.wte.weight\")\n    # It's possible that vocab_size is padded to be a multiple of 8, for example.\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    state_dict[\"transformer.embeddings.word_embeddings.weight\"] = F.pad(\n        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])\n    )\n    state_dict[\"lm_head.weight\"] = state_dict[\"transformer.embeddings.word_embeddings.weight\"]\n\n    # LayerNorm\n    def key_mapping_ln(key):\n        key = re.sub(r\"^transformer.ln_f.(weight|bias)\", r\"transformer.ln_f.\\1\", key)\n        key = re.sub(\n            r\"^transformer.h.(\\d+).ln_(1|2).(weight|bias)\",\n            r\"transformer.layers.\\1.norm\\2.\\3\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    def key_mapping_mlp(key):\n        key = re.sub(\n            r\"^transformer.h.(\\d+).mlp.c_fc.weight\",\n            r\"transformer.layers.\\1.mlp.fc1.weight\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.h.(\\d+).mlp.c_proj.weight\",\n            r\"transformer.layers.\\1.mlp.fc2.weight\",\n 
           key,\n        )\n        key = re.sub(\n            r\"^transformer.h.(\\d+).mlp.c_fc.bias\",\n            r\"transformer.layers.\\1.mlp.fc1.bias\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.h.(\\d+).mlp.c_proj.bias\",\n            r\"transformer.layers.\\1.mlp.fc2.bias\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    # TODO: add support for multi-head attention\n    assert config.multi_query, \"Only multi-query attention is supported\"\n\n    # Attention\n    for d in range(config.num_hidden_layers):\n        embed_dim = config.n_embd\n        head_dim = embed_dim // config.n_head\n\n        c_attn_weight = state_dict.pop(f\"transformer.h.{d}.attn.c_attn.weight\")\n        # with multi-query attention, the weights have shape (embed_dim, embed_dim + head_dim + head_dim)\n        # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112\n        # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183\n        # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)\n        q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)\n        # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)\n        k = torch.tile(k, (config.n_head, 1))\n        v = torch.tile(v, (config.n_head, 1))\n        state_dict[f\"transformer.layers.{d}.mixer.Wqkv.weight\"] = torch.cat((q, k, v), dim=0)\n\n        # same deal with the bias\n        c_attn_bias = state_dict.pop(f\"transformer.h.{d}.attn.c_attn.bias\")\n        # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)\n        q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], 
dim=0)\n        # duplicate k, v along the first axis (head_dim,) -> (n_heads * head_dim,)\n        k = torch.tile(k, (config.n_head,))\n        v = torch.tile(v, (config.n_head,))\n        state_dict[f\"transformer.layers.{d}.mixer.Wqkv.bias\"] = torch.cat((q, k, v), dim=0)\n\n    def key_mapping_attn(key):\n        key = re.sub(\n            r\"^transformer.h.(\\d+).attn.c_proj.weight\",\n            r\"transformer.layers.\\1.mixer.out_proj.weight\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.h.(\\d+).attn.c_proj.bias\",\n            r\"transformer.layers.\\1.mixer.out_proj.bias\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n\n    return state_dict\n\n\ndef inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):\n    \"\"\"\n    Map the state_dict of a flash_attn model to be Huggingface BigCode compatible.\n\n    This function is meant to be the inverse of remap_state_dict_hf_bigcode.\n    \"\"\"\n\n    # Word embedding and position embeddings\n    def inv_key_mapping_pos_emb(key):\n        return re.sub(r\"^transformer.embeddings.position_embeddings.\", \"transformer.wpe.\", key)\n\n    state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items())\n    word_embeddings = state_dict.pop(\"transformer.embeddings.word_embeddings.weight\")\n\n    # Unpad along the vocab dimension (dim 0), mirroring the F.pad in remap_state_dict_hf_bigcode\n    word_embeddings = word_embeddings[: config.vocab_size, :]\n    state_dict[\"transformer.wte.weight\"] = word_embeddings\n    state_dict[\"lm_head.weight\"] = word_embeddings\n\n    # LayerNorm\n    def inv_key_mapping_ln(key):\n        key = re.sub(r\"^transformer.ln_f.(weight|bias)\", r\"transformer.ln_f.\\1\", key)\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).norm(1|2).(weight|bias)\",\n            r\"transformer.h.\\1.ln_\\2.\\3\",\n            key,\n        )\n        return key\n\n    state_dict = 
OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    # MLPs\n    def inv_key_mapping_mlp(key):\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mlp.fc1.weight\",\n            r\"transformer.h.\\1.mlp.c_fc.weight\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mlp.fc2.weight\",\n            r\"transformer.h.\\1.mlp.c_proj.weight\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mlp.fc1.bias\",\n            r\"transformer.h.\\1.mlp.c_fc.bias\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mlp.fc2.bias\",\n            r\"transformer.h.\\1.mlp.c_proj.bias\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    # Attention\n    for d in range(config.num_hidden_layers):\n        embed_dim = config.n_embd\n        head_dim = embed_dim // config.n_head\n\n        Wqkv_weight = state_dict.pop(f\"transformer.layers.{d}.mixer.Wqkv.weight\")\n        q, k, v = torch.split(\n            Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0\n        )\n        c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)\n        state_dict[f\"transformer.h.{d}.attn.c_attn.weight\"] = c_attn_weight\n\n        # Same deal with the bias\n        Wqkv_bias = state_dict.pop(f\"transformer.layers.{d}.mixer.Wqkv.bias\")\n        q, k, v = torch.split(\n            Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0\n        )\n        c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)\n        state_dict[f\"transformer.h.{d}.attn.c_attn.bias\"] = c_attn_bias\n\n    def inv_key_mapping_attn(key):\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mixer.out_proj.weight\",\n            
r\"transformer.h.\\1.attn.c_proj.weight\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mixer.out_proj.bias\",\n            r\"transformer.h.\\1.attn.c_proj.bias\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items())\n\n    return state_dict\n\n\ndef bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config:\n    return GPT2Config(\n        activation_function=bigcode_config.activation_function,\n        attn_pdrop=bigcode_config.attn_pdrop,\n        bos_token_id=bigcode_config.bos_token_id,\n        embd_pdrop=bigcode_config.embd_pdrop,\n        eos_token_id=bigcode_config.eos_token_id,\n        initializer_range=bigcode_config.initializer_range,\n        layer_norm_epsilon=bigcode_config.layer_norm_epsilon,\n        max_batch_size=bigcode_config.max_batch_size,\n        max_sequence_length=bigcode_config.max_sequence_length,\n        model_type=bigcode_config.model_type,\n        multi_query=bigcode_config.multi_query,\n        n_embd=bigcode_config.n_embd,\n        n_head=bigcode_config.n_head,\n        n_inner=bigcode_config.n_inner,\n        n_layer=bigcode_config.n_layer,\n        n_positions=bigcode_config.n_positions,\n        resid_pdrop=bigcode_config.resid_pdrop,\n        scale_attn_weights=bigcode_config.scale_attn_weights,\n        summary_activation=bigcode_config.summary_activation,\n        summary_first_dropout=bigcode_config.summary_first_dropout,\n        summary_proj_to_labels=bigcode_config.summary_proj_to_labels,\n        summary_type=bigcode_config.summary_type,\n        summary_use_proj=bigcode_config.summary_use_proj,\n        use_cache=bigcode_config.use_cache,\n        vocab_size=bigcode_config.vocab_size,\n    )\n"
  },
  {
    "path": "flash_attn/models/btlm.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport math\nimport json\nimport re\nfrom pathlib import Path\n\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn.functional as F\n\nfrom einops import rearrange\nfrom transformers import GPT2Config, AutoConfig, PretrainedConfig\n\n\ndef remap_state_dict_hf_btlm(state_dict, config):\n    # Word embedding and position embedding\n    def key_mapping_pos_emb(key):\n        return re.sub(r\"^transformer.wpe.\", \"transformer.embeddings.position_embeddings.\", key)\n\n    if \"transformer.wpe.weight\" in state_dict:\n        state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())\n    word_embeddings = state_dict.pop(\"transformer.wte.weight\")\n    # It's possible that vocab_size is padded to be a multiple of 8, for example.\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    state_dict[\"transformer.embeddings.word_embeddings.weight\"] = F.pad(\n        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])\n    )\n    state_dict[\"lm_head.weight\"] = state_dict[\"transformer.embeddings.word_embeddings.weight\"]\n\n    # LayerNorm\n    def key_mapping_ln(key):\n        key = re.sub(r\"^transformer.ln_f.(weight|bias)\", r\"transformer.ln_f.\\1\", key)\n        key = re.sub(r\"^transformer.h.(\\d+).ln_(1|2).(weight|bias)\", r\"transformer.layers.\\1.norm\\2.\\3\", key)\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    # MLP\n    for d in range(config.num_hidden_layers):\n        W1 = state_dict.pop(f\"transformer.h.{d}.mlp.c_fc.weight\")\n        W3 = state_dict.pop(f\"transformer.h.{d}.mlp.c_fc2.weight\")\n        state_dict[f\"transformer.layers.{d}.mlp.fc1.weight\"] = torch.cat([W1.t(), W3.t()], dim=0)\n        b1 = 
state_dict.pop(f\"transformer.h.{d}.mlp.c_fc.bias\")\n        b3 = state_dict.pop(f\"transformer.h.{d}.mlp.c_fc2.bias\")\n        state_dict[f\"transformer.layers.{d}.mlp.fc1.bias\"] = torch.cat([b1, b3], dim=0)\n        W2 = state_dict.pop(f\"transformer.h.{d}.mlp.c_proj.weight\")\n        state_dict[f\"transformer.layers.{d}.mlp.fc2.weight\"] = W2.t()\n\n    def key_mapping_mlp(key):\n        key = re.sub(r\"^transformer.h.(\\d+).mlp.c_proj.bias\", r\"transformer.layers.\\1.mlp.fc2.bias\", key)\n        return key\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    # Attention\n    for d in range(config.num_hidden_layers):\n        Wqkv = state_dict.pop(f\"transformer.h.{d}.attn.c_attn.weight\")\n        state_dict[f\"transformer.layers.{d}.mixer.Wqkv.weight\"] = Wqkv.t()\n        Wout = state_dict.pop(f\"transformer.h.{d}.attn.c_proj.weight\")\n        state_dict[f\"transformer.layers.{d}.mixer.out_proj.weight\"] = Wout.t()\n    state_dict.pop(f\"transformer.relative_pe.slopes\")  # We don't store the Alibi slopes\n\n    def key_mapping_attn(key):\n        key = re.sub(r\"^transformer.h.(\\d+).attn.c_attn.bias\", r\"transformer.layers.\\1.mixer.Wqkv.bias\", key)\n        key = re.sub(\n            r\"^transformer.h.(\\d+).attn.c_proj.bias\", r\"transformer.layers.\\1.mixer.out_proj.bias\", key\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n\n    return state_dict\n\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n    return GPT2Config(\n        vocab_size=btlm_config.vocab_size,\n        n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n        n_embd=btlm_config.hidden_size,\n        n_layer=btlm_config.num_hidden_layers,\n        n_head=btlm_config.num_attention_heads,\n        n_inner=btlm_config.n_inner,\n        
activation_function=btlm_config.activation_function,\n        resid_pdrop=btlm_config.resid_pdrop,\n        embd_pdrop=btlm_config.embd_pdrop,\n        attn_pdrop=btlm_config.attn_pdrop,\n        layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n        initializer_range=btlm_config.initializer_range,\n        bos_token_id=btlm_config.bos_token_id,\n        eos_token_id=btlm_config.eos_token_id,\n        # These are new arguments not in the original GPT2Config\n        use_alibi=btlm_config.position_embedding_type == \"alibi\",\n        use_flash_attn=btlm_config.position_embedding_type == \"alibi\",  # Alibi code path requires flash_attn\n        mup_width_scale=btlm_config.mup_width_scale,\n        mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n        mup_output_multiplier=btlm_config.mup_output_alpha,\n        mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n        mlp_multiple_of=1,\n    )\n"
  },
  {
    "path": "flash_attn/models/falcon.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport math\nimport re\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom transformers import FalconConfig, GPT2Config\n\n\ndef remap_state_dict_hf_falcon(state_dict, config):\n    def key_mapping_layers(key):\n        return re.sub(r\"^transformer.h.\", \"transformer.layers.\", key)\n\n    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())\n    # Word embedding\n    def key_mapping_emb(key):\n        return re.sub(\n            r\"^transformer.word_embeddings.\", \"transformer.embeddings.word_embeddings.\", key\n        )\n\n    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())\n    word_embeddings = state_dict.pop(\"transformer.embeddings.word_embeddings.weight\")\n    # It's possible that vocab_size is padded to be a multiple of 8, for example.\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    state_dict[\"transformer.embeddings.word_embeddings.weight\"] = F.pad(\n        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])\n    )\n    if getattr(config, \"tie_word_embeddings\"):\n        state_dict[\"lm_head.weight\"] = state_dict[\"transformer.embeddings.word_embeddings.weight\"]\n    else:\n        output_embeddings = state_dict.pop(\"lm_head.weight\")\n        # It's possible that vocab_size is padded to be a multiple of 8, for example.\n        state_dict[\"lm_head.weight\"] = F.pad(\n            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])\n        )\n        output_embeddings_bias = state_dict.pop(\"lm_head.bias\")\n        state_dict[\"lm_head.bias\"] = F.pad(\n            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])\n        )\n\n    # LayerNorm\n    def 
key_mapping_ln(key):\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).input_layernorm.\", r\"transformer.layers.\\1.norm1.\", key\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).post_attention_layernorm.\",\n            r\"transformer.layers.\\1.norm2.\",\n            key,\n        )\n        key = re.sub(r\"^transformer.layers.(\\d+).ln_attn.\", r\"transformer.layers.\\1.norm1.\", key)\n        key = re.sub(r\"^transformer.layers.(\\d+).ln_mlp.\", r\"transformer.layers.\\1.norm2.\", key)\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    # MLP\n    def key_mapping_mlp(key):\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mlp.dense_h_to_4h.\", r\"transformer.layers.\\1.mlp.fc1.\", key\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mlp.dense_4h_to_h.\", r\"transformer.layers.\\1.mlp.fc2.\", key\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    def key_mapping_attn(key):\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).self_attention.query_key_value.\",\n            r\"transformer.layers.\\1.mixer.Wqkv.\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).self_attention.dense.\",\n            r\"transformer.layers.\\1.mixer.out_proj.\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n    n_head = config.n_head\n    n_head_kv = getattr(config, \"n_head_kv\", 1)\n    headdim = config.hidden_size // n_head\n    for l in range(config.n_layer):\n        # The weights are stored in a different layout compared to our implementation\n        Wqkv = rearrange(\n            state_dict.pop(f\"transformer.layers.{l}.mixer.Wqkv.weight\"),\n            \"(group ratio headdim) ... 
-> group ratio headdim ...\",\n            ratio=n_head // n_head_kv + 2,\n            headdim=headdim,\n        )\n        Wq = rearrange(Wqkv[:, :-2], \"group ratio headdim ... -> (group ratio headdim) ...\")\n        Wk = rearrange(Wqkv[:, [-2]], \"group ratio headdim ... -> (group ratio headdim) ...\")\n        Wv = rearrange(Wqkv[:, [-1]], \"group ratio headdim ... -> (group ratio headdim) ...\")\n        state_dict[f\"transformer.layers.{l}.mixer.Wqkv.weight\"] = torch.cat([Wq, Wk, Wv], dim=0)\n\n    return state_dict\n\n\ndef falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:\n    # The 40b config uses \"n_head_kv\" instead of \"num_kv_heads\"\n    n_head_kv = getattr(\n        falcon_config,\n        \"n_head_kv\",\n        1 if getattr(falcon_config, \"multi_query\", False) else falcon_config.n_head,\n    )\n    # HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.\n    # So we have to infer it from the number of heads in the key/value block\n    parallel_block_tied_norm = n_head_kv == 1\n    return GPT2Config(\n        vocab_size=falcon_config.vocab_size,\n        n_positions=0,  # No absolute position embedding\n        n_embd=falcon_config.hidden_size,\n        n_layer=falcon_config.n_layer,\n        n_head=falcon_config.n_head,\n        n_inner=falcon_config.hidden_size * 4,\n        activation_function=\"gelu\",\n        resid_pdrop=falcon_config.hidden_dropout,\n        embd_pdrop=0.0,  # There doesn't seem to be any embedding dropout\n        attn_pdrop=falcon_config.attention_dropout,\n        layer_norm_epsilon=falcon_config.layer_norm_epsilon,\n        initializer_range=falcon_config.initializer_range,\n        bos_token_id=falcon_config.bos_token_id,\n        eos_token_id=falcon_config.eos_token_id,\n        # These are new arguments not in the original GPT2Config\n        parallel_block=falcon_config.parallel_attn,\n        n_head_kv=n_head_kv,\n        
parallel_block_tied_norm=parallel_block_tied_norm,\n        rotary_emb_fraction=1.0,\n        rotary_emb_interleaved=False,\n        tie_word_embeddings=True,\n        qkv_proj_bias=falcon_config.bias,\n        out_proj_bias=falcon_config.bias,\n        mlp_fc1_bias=falcon_config.bias,\n        mlp_fc2_bias=falcon_config.bias,\n        lm_head_bias=False,\n    )\n"
  },
  {
    "path": "flash_attn/models/gpt.py",
    "content": "# Copyright (c) 2024, Tri Dao.\n\nimport logging\nimport math\nimport re\nfrom collections import OrderedDict, namedtuple\nfrom collections.abc import Sequence\nfrom functools import partial\nfrom typing import Dict, List\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom transformers import GPT2Config\n\nfrom flash_attn.models.bigcode import remap_state_dict_hf_bigcode\nfrom flash_attn.models.falcon import remap_state_dict_hf_falcon\nfrom flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox\nfrom flash_attn.models.gptj import remap_state_dict_hf_gptj\nfrom flash_attn.models.llama import remap_state_dict_hf_llama\nfrom flash_attn.models.opt import remap_state_dict_hf_opt\nfrom flash_attn.modules.block import Block, ParallelBlock\nfrom flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings\nfrom flash_attn.modules.mha import MHA, ParallelMHA\nfrom flash_attn.modules.mlp import (\n    FusedMLP,\n    GatedMlp,\n    Mlp,\n    ParallelFusedMLP,\n    ParallelGatedMlp,\n    ParallelMLP,\n)\nfrom flash_attn.ops.activations import sqrelu_fwd\nfrom flash_attn.utils.distributed import (\n    all_gather,\n    all_gather_raw,\n    get_dim_for_local_rank,\n    sync_shared_params,\n)\nfrom flash_attn.utils.generation import GenerationMixin\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\n\ntry:\n    from flash_attn.ops.fused_dense import ColumnParallelLinear\nexcept ImportError:\n    ColumnParallelLinear = None\n\ntry:\n    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense\nexcept ImportError:\n    FusedDenseSqreluDense = None\n\ntry:\n    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm\nexcept ImportError:\n    layer_norm_fn, RMSNorm = None, None\n\nlogger = logging.getLogger(__name__)\n\n\ndef create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):\n    factory_kwargs = {\"device\": device, 
\"dtype\": dtype}\n    head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n    attn_scale_power = 0.5 if not getattr(config, \"mup_scale_qk_dot_by_d\", False) else 1.0\n    softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power))\n    softmax_scale *= getattr(config, \"mup_attn_multiplier\", 1.0)\n    if config.scale_attn_by_inverse_layer_idx:\n        assert layer_idx is not None\n        softmax_scale /= float(layer_idx + 1)\n    dwconv = getattr(config, \"attn_dwconv\", False)\n    if dwconv:\n        assert process_group is None, \"TensorParallel MHA does not support dwconv yet\"\n    qkv_proj_bias = getattr(config, \"qkv_proj_bias\", True)\n    out_proj_bias = getattr(config, \"out_proj_bias\", True)\n    rotary_emb_dim = int(getattr(config, \"rotary_emb_fraction\", 0.0) * head_dim)\n    rotary_emb_base = getattr(config, \"rotary_emb_base\", 10000.0)\n    rotary_emb_scale_base = getattr(config, \"rotary_emb_scale_base\", None)\n    rotary_emb_interleaved = getattr(config, \"rotary_emb_interleaved\", False)\n    use_alibi = getattr(config, \"use_alibi\", False)\n    window_size = getattr(config, \"window_size\", (-1, -1))\n    use_flash_attn = getattr(config, \"use_flash_attn\", False)\n    fused_bias_fc = getattr(config, \"fused_bias_fc\", False)\n    if not fused_bias_fc:\n        assert process_group is None, \"TensorParallel MHA requires fused_bias_fc\"\n    mha_cls = MHA if process_group is None else ParallelMHA\n    serial_kwargs = (\n        {\"fused_bias_fc\": fused_bias_fc, \"dwconv\": dwconv} if process_group is None else {}\n    )\n    parallel_kwargs = (\n        {\n            \"process_group\": process_group,\n            \"sequence_parallel\": getattr(config, \"sequence_parallel\", True),\n        }\n        if process_group is not None\n        else {}\n    )\n    num_heads_kv = getattr(config, \"n_head_kv\", None)\n    mixer_cls = partial(\n        mha_cls,\n        
num_heads=config.num_attention_heads,\n        num_heads_kv=num_heads_kv,\n        qkv_proj_bias=qkv_proj_bias,\n        out_proj_bias=out_proj_bias,\n        dropout=config.attn_pdrop,\n        softmax_scale=softmax_scale,\n        causal=True,\n        layer_idx=layer_idx,\n        rotary_emb_dim=rotary_emb_dim,\n        rotary_emb_base=rotary_emb_base,\n        rotary_emb_scale_base=rotary_emb_scale_base,\n        rotary_emb_interleaved=rotary_emb_interleaved,\n        use_alibi=use_alibi,\n        window_size=window_size,\n        use_flash_attn=use_flash_attn,\n        **serial_kwargs,\n        **parallel_kwargs,\n        **factory_kwargs,\n    )\n    return mixer_cls\n\n\ndef create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n    mlp_fc1_bias = getattr(config, \"mlp_fc1_bias\", True)\n    mlp_fc2_bias = getattr(config, \"mlp_fc2_bias\", True)\n    fused_mlp = getattr(config, \"fused_mlp\", False)\n    if fused_mlp:\n        assert config.activation_function in [\n            \"gelu_new\",\n            \"gelu_fast\",\n            \"gelu_approx\",\n            \"gelu_pytorch_tanh\",\n            \"relu\",\n            \"sqrelu\",\n        ]\n    fused_dense_sqrelu_dense = getattr(config, \"fused_dense_sqrelu_dense\", False)\n    if fused_dense_sqrelu_dense:\n        assert config.activation_function == \"sqrelu\", (\n            \"fused_dense_sqrelu_dense only \" \"supports approximate activation_function sqrelu\"\n        )\n    assert not (fused_dense_sqrelu_dense and fused_mlp)\n    if not fused_mlp and not fused_dense_sqrelu_dense:\n        assert config.activation_function in [\n            \"gelu\",\n            \"gelu_new\",\n            \"gelu_fast\",\n            \"gelu_approx\",\n            \"gelu_pytorch_tanh\",\n            \"relu\",\n            \"sqrelu\",\n            \"glu\",\n            \"swiglu\",\n            \"geglu\",\n        ]\n        if 
config.activation_function in [\"glu\", \"swiglu\", \"geglu\"]:\n            activation = (\n                F.sigmoid\n                if config.activation_function == \"glu\"\n                else (F.silu if config.activation_function == \"swiglu\" else F.gelu)\n            )\n            mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp\n            parallel_kwargs = (\n                {\n                    \"process_group\": process_group,\n                    \"sequence_parallel\": getattr(config, \"sequence_parallel\", True),\n                }\n                if process_group is not None\n                else {}\n            )\n            mlp_multiple_of = getattr(config, \"mlp_multiple_of\", 128)\n            mlp_cls = partial(\n                mlp_cls,\n                hidden_features=config.n_inner,\n                activation=activation,\n                bias1=mlp_fc1_bias,\n                bias2=mlp_fc2_bias,\n                multiple_of=mlp_multiple_of,\n                **parallel_kwargs,\n                **factory_kwargs,\n            )\n        else:\n            if config.activation_function == \"relu\":\n                activation = partial(F.relu, inplace=True)\n            elif config.activation_function == \"sqrelu\":\n                activation = sqrelu_fwd\n            else:\n                approximate = (\n                    \"tanh\"\n                    if config.activation_function\n                    in [\"gelu_new\", \"gelu_fast\", \"gelu_approx\", \"gelu_pytorch_tanh\"]\n                    else \"none\"\n                )\n                activation = partial(F.gelu, approximate=approximate)\n            mlp_cls = Mlp if process_group is None else ParallelMLP\n            parallel_kwargs = (\n                {\n                    \"process_group\": process_group,\n                    \"sequence_parallel\": getattr(config, \"sequence_parallel\", True),\n                }\n                if process_group is not 
None\n                else {}\n            )\n            mlp_cls = partial(\n                mlp_cls,\n                hidden_features=config.n_inner,\n                activation=activation,\n                bias1=mlp_fc1_bias,\n                bias2=mlp_fc2_bias,\n                **parallel_kwargs,\n                **factory_kwargs,\n            )\n    else:\n        mlp_checkpoint_lvl = getattr(config, \"mlp_checkpoint_lvl\", 0)\n        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer\n        if isinstance(mlp_checkpoint_lvl, Sequence):\n            assert layer_idx is not None\n            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]\n        if fused_mlp:\n            if FusedMLP is None:\n                raise ImportError(\"fused_dense is not installed\")\n            activation = (\n                \"gelu_approx\"\n                if config.activation_function\n                in [\"gelu_new\", \"gelu_fast\", \"gelu_approx\", \"gelu_pytorch_tanh\"]\n                else config.activation_function\n            )\n            mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP\n            parallel_kwargs = (\n                {\n                    \"process_group\": process_group,\n                    \"sequence_parallel\": getattr(config, \"sequence_parallel\", True),\n                }\n                if process_group is not None\n                else {}\n            )\n            mlp_cls = partial(\n                mlp_cls,\n                hidden_features=config.n_inner,\n                activation=activation,\n                checkpoint_lvl=mlp_checkpoint_lvl,\n                bias1=mlp_fc1_bias,\n                bias2=mlp_fc2_bias,\n                **parallel_kwargs,\n                **factory_kwargs,\n            )\n        elif fused_dense_sqrelu_dense:\n            if process_group is not None:\n                assert fused_mlp, \"Tensor Parallel is not implemented for 
FusedDenseSqreluDense\"\n            assert FusedDenseSqreluDense is not None\n            mlp_cls = partial(\n                FusedDenseSqreluDense,\n                hidden_features=config.n_inner,\n                checkpoint_lvl=mlp_checkpoint_lvl,\n                **factory_kwargs,\n            )\n        else:\n            raise RuntimeError(\"MLP type not supported\")\n    return mlp_cls\n\n\ndef create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):\n    factory_kwargs = {\"device\": device, \"dtype\": dtype}\n    sequence_parallel = getattr(config, \"sequence_parallel\", True)\n    mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)\n    mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)\n    use_rms_norm = getattr(config, \"rms_norm\", False)\n    norm_cls = partial(\n        nn.LayerNorm if not use_rms_norm else RMSNorm,\n        eps=config.layer_norm_epsilon,\n        **factory_kwargs,\n    )\n    # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable\n    residual_in_fp32 = getattr(config, \"residual_in_fp32\", False)\n    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop\n    prenorm = getattr(config, \"prenorm\", True)\n    parallel_block = getattr(config, \"parallel_block\", False)\n    if not parallel_block:\n        block = Block(\n            config.hidden_size,\n            mixer_cls,\n            mlp_cls,\n            norm_cls=norm_cls,\n            prenorm=prenorm,\n            resid_dropout1=resid_dropout1,\n            resid_dropout2=config.resid_pdrop,\n            fused_dropout_add_ln=getattr(config, \"fused_dropout_add_ln\", False),\n            residual_in_fp32=residual_in_fp32,\n            sequence_parallel=sequence_parallel and process_group is not None,\n            mark_shared_params=process_group is not None,\n        )\n    else:\n        assert 
prenorm\n        block = ParallelBlock(\n            config.hidden_size,\n            mixer_cls,\n            mlp_cls,\n            norm_cls=norm_cls,\n            resid_dropout1=resid_dropout1,\n            resid_dropout2=config.resid_pdrop,\n            tied_norm=getattr(config, \"parallel_block_tied_norm\", False),\n            fused_dropout_add_ln=getattr(config, \"fused_dropout_add_ln\", False),\n            residual_in_fp32=residual_in_fp32,\n            sequence_parallel=sequence_parallel and process_group is not None,\n            mark_shared_params=process_group is not None,\n        )\n    block.layer_idx = layer_idx\n    return block\n\n\nclass GPTPreTrainedModel(nn.Module):\n    \"\"\"An abstract class to handle weights initialization and\n    a simple interface for downloading and loading pretrained models.\n    \"\"\"\n\n    def __init__(self, config, *inputs, **kwargs):\n        super().__init__()\n        if not isinstance(config, GPT2Config):\n            raise ValueError(\n                \"Parameter config in `{}(config)` should be an instance of class `GPT2Config`. 
\"\n                \"To create a model from a Google pretrained model use \"\n                \"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`\".format(\n                    self.__class__.__name__, self.__class__.__name__\n                )\n            )\n        self.config = config\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        model_name,\n        config,\n        *args,\n        strict=True,\n        device=None,\n        dtype=None,\n        world_size=1,\n        rank=0,\n        **kwargs,\n    ):\n        \"\"\"\n        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.\n        Download and cache the pre-trained model file if needed.\n        \"\"\"\n        # Instantiate model.\n        model = cls(config, *args, device=device, dtype=dtype, **kwargs)\n        # Load state_dict in cpu because we already initialized the model in GPU, and we don't\n        # want extra stuff taking up more GPU memory\n        state_dict = state_dict_from_pretrained(model_name, device=\"cpu\", dtype=dtype)\n        if model_name.startswith(\"gpt2\"):\n            state_dict = remap_state_dict_hf_gpt2(state_dict, config)\n        elif model_name.startswith(\"facebook/opt\"):\n            state_dict = remap_state_dict_hf_opt(state_dict, config)\n        elif model_name.startswith(\"EleutherAI/gpt-j-\") or model_name.startswith(\n            \"togethercomputer/GPT-JT-\"\n        ):\n            state_dict = remap_state_dict_hf_gptj(state_dict, config)\n        elif (\n            model_name.startswith(\"EleutherAI/gpt-neox-\")\n            or model_name.startswith(\"EleutherAI/pythia-\")\n            or model_name.startswith(\"togethercomputer/RedPajama-INCITE-\")\n        ):\n            state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)\n        elif model_name.startswith(\"tiiuae/falcon-\"):\n            state_dict = remap_state_dict_hf_falcon(state_dict, config)\n        elif 
model_name.startswith(\"meta-llama/Llama-\"):\n            state_dict = remap_state_dict_hf_llama(state_dict, config)\n        elif model_name.startswith(\"bigcode/\") or model_name.startswith(\"WizardLM/\"):\n            state_dict = remap_state_dict_hf_bigcode(state_dict, config)\n        else:\n            raise NotImplementedError(f\"Model {model_name} not supported\")\n        if world_size > 1:\n            state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)\n        load_return = model.load_state_dict(state_dict, strict=strict)\n        logger.info(load_return)\n        return model\n\n\n# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454\ndef _init_weights(\n    module, n_layer, initializer_range=0.02, mup_width_scale=1.0, rescale_prenorm_residual=True\n):\n    mup_init_scale = math.sqrt(mup_width_scale)\n    if isinstance(module, nn.Linear):\n        nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)\n        optim_cfg = getattr(module.weight, \"_optim\", {})\n        optim_cfg.update({\"lr_multiplier\": mup_width_scale})\n        setattr(module.weight, \"_optim\", optim_cfg)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n    elif isinstance(module, nn.Embedding):\n        nn.init.normal_(module.weight, std=initializer_range)\n\n    if rescale_prenorm_residual:\n        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:\n        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. 
Scale\n        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.\n        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/\n        #\n        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py\n        for name, p in module.named_parameters():\n            if name in [\"out_proj.weight\", \"fc2.weight\"]:\n                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block\n                nn.init.normal_(\n                    p, mean=0.0, std=initializer_range * mup_init_scale / math.sqrt(2 * n_layer)\n                )\n\n\nclass GPTModel(GPTPreTrainedModel):\n    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):\n        super().__init__(config)\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        self.process_group = process_group\n        self.sequence_parallel = getattr(config, \"sequence_parallel\", True)\n        assert config.activation_function in [\n            \"gelu\",\n            \"gelu_new\",\n            \"gelu_fast\",\n            \"gelu_approx\",\n            \"gelu_pytorch_tanh\",\n            \"relu\",\n            \"sqrelu\",\n            \"glu\",\n            \"swiglu\",\n            \"geglu\",\n        ]\n        pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n        vocab_size = (\n            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple\n        )\n        self.embeddings_multiplier = getattr(config, \"mup_embeddings_multiplier\", 1.0)\n        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable\n        self.residual_in_fp32 = getattr(config, \"residual_in_fp32\", False)\n        # These 2 options are for OPT-350m\n        self.prenorm = getattr(config, \"prenorm\", True)\n        use_rms_norm = getattr(config, \"rms_norm\", 
False)\n        word_embed_proj_dim = getattr(config, \"word_embed_proj_dim\", None)\n        # For GPT-J, GPT-NeoX\n        self.parallel_block = getattr(config, \"parallel_block\", False)\n\n        if process_group is None:\n            self.embeddings = GPT2Embeddings(\n                config.hidden_size,\n                vocab_size,\n                config.max_position_embeddings,\n                word_embed_proj_dim=word_embed_proj_dim,\n                **factory_kwargs,\n            )\n        else:\n            self.embeddings = ParallelGPT2Embeddings(\n                config.hidden_size,\n                vocab_size,\n                config.max_position_embeddings,\n                process_group=process_group,\n                sequence_parallel=self.sequence_parallel,\n                **factory_kwargs,\n            )\n\n        # We change the order of dropout, residual and layer norm:\n        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:\n        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and\n        # the main branch (output of MLP). 
The model definition is unchanged, but the mapping of the\n        # nn.Dropout probabilities are changed.\n        # This is for performance reason: we can fuse dropout + add + layer_norm.\n        self.layers = nn.ModuleList(\n            [\n                create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)\n                for i in range(config.num_hidden_layers)\n            ]\n        )\n        rotary_emb_fraction = getattr(config, \"rotary_emb_fraction\", 0.0)\n        if rotary_emb_fraction > 0.0:  # Tie all the RotaryEmbedding modules to share the same cos/sin cache\n            for layer in self.layers[1:]:\n                layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb\n\n        self.fused_dropout_add_ln = getattr(config, \"fused_dropout_add_ln\", False)\n        if self.fused_dropout_add_ln:\n            if layer_norm_fn is None:\n                raise ImportError(\"Triton is not installed\")\n        if self.prenorm:\n            self.drop_f = nn.Dropout(config.resid_pdrop)\n            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm\n            self.ln_f = norm_cls(\n                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs\n            )\n        if process_group is not None:\n            for p in self.ln_f.parameters():\n                # Mark the norm parameters as \"shared_params\" so that we sync their values at init.\n                p._shared_params = True\n                # Mark the norm params as \"sequence_parallel\" so we run all-reduce on their grads.\n                if self.sequence_parallel:\n                    p._sequence_parallel = True\n\n        self.apply(\n            partial(\n                _init_weights,\n                n_layer=config.num_hidden_layers,\n                initializer_range=config.initializer_range,\n                mup_width_scale=getattr(config, \"mup_width_scale\", 1.0),\n            )\n        )\n        self.tie_weights()\n\n    
def tie_weights(self):\n        if self.process_group is not None:\n            sync_shared_params(self, self.process_group)\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n        return {\n            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n            for i, layer in enumerate(self.layers)\n        }\n\n    def forward(self, input_ids, position_ids=None, inference_params=None):\n        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen\n        # dimensions so that we can split on it easily, in case of small batch size.\n        # Only the attention layers need to know the seqlen.\n        embedding_kwargs = (\n            {\"combine_batch_seqlen_dim\": True}\n            if self.process_group is not None and self.sequence_parallel\n            else {}\n        )\n        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)\n        if self.embeddings_multiplier != 1.0:\n            hidden_states = hidden_states * self.embeddings_multiplier\n        if self.parallel_block:\n            hidden_states2 = None\n        residual = None\n        mixer_kwargs = (\n            {\"seqlen\": input_ids.shape[1]}\n            if self.process_group is not None and self.sequence_parallel\n            else {}\n        )\n        if inference_params is not None:\n            mixer_kwargs[\"inference_params\"] = inference_params\n        for layer in self.layers:\n            if self.prenorm:\n                if not self.parallel_block:\n                    hidden_states, residual = layer(\n                        hidden_states, residual, mixer_kwargs=mixer_kwargs\n                    )\n                else:\n                    hidden_states, hidden_states2, residual = layer(\n                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs\n                    )\n            else:\n             
   hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)\n        if self.prenorm:\n            if not self.fused_dropout_add_ln:\n                dropped = self.drop_f(hidden_states)\n                if not self.parallel_block:\n                    residual = (dropped + residual) if residual is not None else dropped\n                else:\n                    dropped2 = self.drop_f(hidden_states2)\n                    residual = (\n                        (residual + dropped + dropped2)\n                        if residual is not None\n                        else dropped + dropped2\n                    )\n                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))\n            else:\n                # Set prenorm=False here since we don't need the residual\n                hidden_states = layer_norm_fn(\n                    hidden_states,\n                    self.ln_f.weight,\n                    self.ln_f.bias,\n                    residual=residual,\n                    x1=None if not self.parallel_block else hidden_states2,\n                    eps=self.ln_f.eps,\n                    dropout_p=self.drop_f.p if self.training else 0.0,\n                    prenorm=False,\n                    is_rms_norm=isinstance(self.ln_f, RMSNorm)\n                )\n        return hidden_states\n\n\nclass GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):\n    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__(config)\n        self.process_group = process_group\n        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)\n        self.tie_word_embeddings = getattr(config, \"tie_word_embeddings\", True)\n        lm_head_bias = getattr(config, \"lm_head_bias\", False)\n        pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n        vocab_size = (\n            
math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple\n        )\n        # This option is for OPT-350m\n        word_embed_proj_dim = getattr(config, \"word_embed_proj_dim\", None)\n        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim\n        if word_embed_proj_dim is not None:\n            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)\n        else:\n            self.project_out = None\n        mup_width_scale = getattr(config, \"mup_width_scale\", 1.0)\n        mup_output_multiplier = getattr(config, \"mup_output_multiplier\", 1.0)\n        self.output_scale = mup_output_multiplier * mup_width_scale\n        if process_group is None:\n            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)\n        else:\n            if ColumnParallelLinear is None:\n                raise ImportError(\"fused_dense_lib is not installed\")\n            self.lm_head = ColumnParallelLinear(\n                embed_dim,\n                vocab_size,\n                process_group,\n                bias=lm_head_bias,\n                sequence_parallel=getattr(config, \"sequence_parallel\", True),\n                **factory_kwargs,\n            )\n        self.norm_head = getattr(config, \"norm_head\", False)\n        # Initialize weights and apply final processing\n        self.apply(\n            partial(\n                _init_weights,\n                n_layer=config.num_hidden_layers,\n                initializer_range=config.initializer_range,\n                mup_width_scale=mup_width_scale,\n            )\n        )\n        self.tie_weights()\n\n    def tie_weights(self):\n        if self.tie_word_embeddings:\n            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight\n        if self.process_group is not None:\n            sync_shared_params(self, self.process_group)\n\n    def allocate_inference_cache(self, 
batch_size, max_seqlen, dtype=None, **kwargs):\n        return self.transformer.allocate_inference_cache(\n            batch_size, max_seqlen, dtype=dtype, **kwargs\n        )\n\n    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):\n        \"\"\"\n        input_ids: (batch, seqlen) int tensor\n        inference_params: for generation. Adapted from Megatron-LM (and Apex)\n        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470\n        num_last_tokens: if > 0, only return the logits for the last n tokens\n        \"\"\"\n        assert (\n            input_ids.ndim == 2\n        ), f\"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}\"\n        b, slen = input_ids.shape\n        hidden_states = self.transformer(\n            input_ids, position_ids=position_ids, inference_params=inference_params\n        )\n        if inference_params is not None:\n            assert hidden_states.ndim == 3, \"sequence_parallel is not supported in generation mode\"\n        if num_last_tokens > 0:\n            hidden_states = hidden_states[:, -num_last_tokens:]\n        if self.project_out is not None:\n            hidden_states = self.project_out(hidden_states)\n        if self.output_scale != 1.0:\n            hidden_states = hidden_states * self.output_scale\n        if not self.norm_head:\n            lm_logits = self.lm_head(hidden_states)\n        else:\n            lm_head_weight = F.normalize(self.lm_head.weight)\n            if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:\n                hidden_states = all_gather(hidden_states, self.lm_head.process_group)\n            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)\n        # During inference, we want the full logit for sampling\n        if isinstance(self.lm_head, ColumnParallelLinear) and 
inference_params is not None:\n            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)\n            lm_logits = rearrange(lm_logits, \"(n b) ... d -> b ... (n d)\", b=b)\n        CausalLMOutput = namedtuple(\"CausalLMOutput\", [\"logits\"])\n        return CausalLMOutput(logits=lm_logits)\n\n    def load_state_dict(self, state_dict, strict=True):\n        # Remapping from our checkpoints that used a different ordering of layers in the block\n        # Previous: Attn / MLP -> Dropout -> Add -> LN\n        # Current: Dropout -> Add -> LN -> Attn / MLP\n        if \"transformer.ln_0.weight\" in state_dict:\n            n_layers = len(self.transformer.layers)\n            ln_weight = state_dict.pop(f\"transformer.layers.{n_layers - 1}.norm2.weight\")\n            ln_bias = state_dict.pop(f\"transformer.layers.{n_layers - 1}.norm2.bias\")\n            state_dict[\"transformer.ln_f.weight\"] = ln_weight\n            state_dict[\"transformer.ln_f.bias\"] = ln_bias\n            for l in reversed(range(n_layers)):\n                ln_weight = state_dict.pop(f\"transformer.layers.{l}.norm1.weight\")\n                ln_bias = state_dict.pop(f\"transformer.layers.{l}.norm1.bias\")\n                state_dict[f\"transformer.layers.{l}.norm2.weight\"] = ln_weight\n                state_dict[f\"transformer.layers.{l}.norm2.bias\"] = ln_bias\n                if l > 0:\n                    ln_weight = state_dict.pop(f\"transformer.layers.{l - 1}.norm2.weight\")\n                    ln_bias = state_dict.pop(f\"transformer.layers.{l - 1}.norm2.bias\")\n                    state_dict[f\"transformer.layers.{l}.norm1.weight\"] = ln_weight\n                    state_dict[f\"transformer.layers.{l}.norm1.bias\"] = ln_bias\n            ln_weight = state_dict.pop(\"transformer.ln_0.weight\")\n            ln_bias = state_dict.pop(\"transformer.ln_0.bias\")\n            state_dict[f\"transformer.layers.0.norm1.weight\"] = ln_weight\n            
state_dict[f\"transformer.layers.0.norm1.bias\"] = ln_bias\n        return super().load_state_dict(state_dict, strict=strict)\n\n\ndef shard_state_dict_tp(state_dict, config, world_size, rank):\n    \"\"\"Convert the state_dict of a standard GPT model to the state_dict of a GPT model\n    with tensor parallel.\n\n    This function modifies state_dict in place.\n    \"\"\"\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    assert vocab_size % world_size == 0\n    assert config.hidden_size % world_size == 0\n    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size\n    assert inner_dim % world_size == 0\n\n    n_head = config.n_head\n    n_head_kv = getattr(config, \"n_head_kv\", n_head)\n\n    embed_dim = config.hidden_size\n    head_dim = embed_dim // n_head\n\n    def shard_first_dim(state_dict, key):\n        if key in state_dict:\n            x = state_dict[key]\n            dim = x.shape[0] // world_size\n            state_dict[key] = x[rank * dim : (rank + 1) * dim]\n\n    def shard_last_dim(state_dict, key, multiple_of=1):\n        if key in state_dict:\n            x = state_dict[key]\n            dim_each_rank = [\n                get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)\n                for local_rank in range(world_size)\n            ]\n            beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))\n            state_dict[key] = x[..., beg:end]\n\n    def shard_gatedmlp_fc1_dim(state_dict, key):\n        if key in state_dict:\n            x = state_dict[key]\n            dim = x.shape[0] // world_size // 2\n            state_dict[key] = rearrange(\n                rearrange(x, \"(two o) ... -> two o ...\", two=2)[:, rank * dim : (rank + 1) * dim],\n                \"two o ... 
-> (two o) ...\",\n            )\n\n    def shard_qkv_headdim(state_dict, key):\n        if key in state_dict:\n            n_head_each_rank = [\n                get_dim_for_local_rank(n_head, world_size, local_rank)\n                for local_rank in range(world_size)\n            ]\n            n_head_kv_each_rank = [\n                get_dim_for_local_rank(n_head_kv, world_size, local_rank)\n                for local_rank in range(world_size)\n            ]\n\n            beg_n_head = sum(n_head_each_rank[:rank])\n            end_n_head = sum(n_head_each_rank[: rank + 1])\n\n            beg_n_head_kv = sum(n_head_kv_each_rank[:rank])\n            end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])\n\n            if n_head_kv == n_head:\n                x = rearrange(state_dict[key], \"(three d) ... -> three d ...\", three=3)\n                state_dict[key] = rearrange(\n                    x[:, beg_n_head * head_dim : end_n_head * head_dim],\n                    \"three d ... -> (three d) ...\",\n                )\n            else:\n                x = rearrange(\n                    state_dict[key],\n                    \"(nheadqkv headdim) ... -> nheadqkv headdim ...\",\n                    nheadqkv=n_head + 2 * n_head_kv,\n                )\n                state_dict[key] = rearrange(\n                    torch.cat(\n                        [\n                            x[beg_n_head:end_n_head],\n                            x[n_head + beg_n_head_kv : n_head + end_n_head_kv],\n                            x[\n                                n_head\n                                + n_head_kv\n                                + beg_n_head_kv : n_head\n                                + n_head_kv\n                                + end_n_head_kv\n                            ],\n                        ],\n                        dim=0,\n                    ),\n                    \"nheadqkv headdim ... 
-> (nheadqkv headdim) ...\",\n                )\n\n    shard_first_dim(state_dict, \"transformer.embeddings.word_embeddings.weight\")\n    if \"lm_head.weight\" in state_dict:\n        shard_first_dim(state_dict, \"lm_head.weight\")\n    if \"transformer.embeddings.position_embeddings.weight\" in state_dict:\n        shard_last_dim(state_dict, \"transformer.embeddings.position_embeddings.weight\")\n    for i in range(config.num_hidden_layers):\n        shard_qkv_headdim(state_dict, f\"transformer.layers.{i}.mixer.Wqkv.weight\")\n        shard_qkv_headdim(state_dict, f\"transformer.layers.{i}.mixer.Wqkv.bias\")\n        shard_last_dim(\n            state_dict, f\"transformer.layers.{i}.mixer.out_proj.weight\", multiple_of=head_dim\n        )\n        if rank != 0:\n            state_dict.pop(f\"transformer.layers.{i}.mixer.out_proj.bias\", None)\n        if config.activation_function in [\"glu\", \"swiglu\", \"geglu\"]:\n            shard_gatedmlp_fc1_dim(state_dict, f\"transformer.layers.{i}.mlp.fc1.weight\")\n            shard_gatedmlp_fc1_dim(state_dict, f\"transformer.layers.{i}.mlp.fc1.bias\")\n        else:\n            shard_first_dim(state_dict, f\"transformer.layers.{i}.mlp.fc1.weight\")\n            shard_first_dim(state_dict, f\"transformer.layers.{i}.mlp.fc1.bias\")\n        shard_last_dim(state_dict, f\"transformer.layers.{i}.mlp.fc2.weight\")\n        if rank != 0:\n            state_dict.pop(f\"transformer.layers.{i}.mlp.fc2.bias\", None)\n    return state_dict\n\n\ndef combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):\n    \"\"\"Convert the list of sharded state_dict of a GPT model with tensor parallel to\n    the state_dict of a standard GPT model.\n\n    This function is meant to be the \"reverse\" of shard_state_dict_tp.\n\n    Precondition:\n        - state_dicts should be ordered in the same way as the shards were created.\n    \"\"\"\n    world_size = len(state_dicts)\n    keys = state_dicts[0].keys()\n    
pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    assert vocab_size % world_size == 0\n    assert config.hidden_size % world_size == 0\n    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size\n    assert inner_dim % world_size == 0\n    assert config.hidden_size % config.n_head == 0\n    headdim = config.hidden_size // config.n_head\n\n    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.\n    # vocab_size // world_size coordinates are nonzero.\n    def combine_word_embeddings(state_dicts, state_dict, key):\n        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1\n        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)\n\n    def combine_dim(state_dicts, state_dict, key, dim=-1):\n        if key in state_dict:\n            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)\n\n    def combine_qkv_headdim(state_dicts, state_dict, key):\n        n_head = config.n_head\n        n_head_kv = getattr(config, \"n_head_kv\", n_head)\n        if key in state_dict:\n            if n_head_kv == n_head:\n                xs = [\n                    rearrange(s[key], \"(three d) ... -> three d ...\", three=3) for s in state_dicts\n                ]\n                state_dict[key] = rearrange(torch.cat(xs, dim=1), \"three d ... 
-> (three d) ...\")\n            else:\n                n_head_each_rank = [\n                    get_dim_for_local_rank(n_head, world_size, local_rank)\n                    for local_rank in range(world_size)\n                ]\n                n_head_kv_each_rank = [\n                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)\n                    for local_rank in range(world_size)\n                ]\n                xs = [\n                    rearrange(\n                        s[key],\n                        \"(nheadqkv headdim) ... -> nheadqkv headdim ...\",\n                        nheadqkv=rank_n_head + 2 * rank_n_head_kv,\n                        headdim=headdim,\n                    )\n                    for s, rank_n_head, rank_n_head_kv in zip(\n                        state_dicts, n_head_each_rank, n_head_kv_each_rank\n                    )\n                ]\n                wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)\n                wk = torch.cat(\n                    [\n                        x[\n                            n_head_each_rank[rank] : n_head_each_rank[rank]\n                            + n_head_kv_each_rank[rank]\n                        ]\n                        for rank, x in enumerate(xs)\n                    ],\n                    dim=0,\n                )\n                wv = torch.cat(\n                    [\n                        x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]\n                        for rank, x in enumerate(xs)\n                    ],\n                    dim=0,\n                )\n                wqkv = torch.cat(\n                    [wq, wk, wv],\n                    dim=0,\n                )\n                state_dict[key] = rearrange(\n                    wqkv,\n                    \"nheadqkv headdim ... 
-> (nheadqkv headdim) ...\",\n                )\n\n    def combine_gated_mlp(state_dicts, state_dict, key):\n        if key in state_dict:\n            xs = [rearrange(s[key], \"(two d) ... -> two d ...\", two=2) for s in state_dicts]\n            state_dict[key] = rearrange(torch.cat(xs, dim=1), \"two d ... -> (two d) ...\")\n\n    state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace\n    combine_word_embeddings(\n        state_dicts, state_dict, \"transformer.embeddings.word_embeddings.weight\"\n    )\n    if \"lm_head.weight\" in state_dict:\n        combine_word_embeddings(state_dicts, state_dict, \"lm_head.weight\")\n    if \"transformer.embeddings.position_embeddings.weight\" in state_dict:\n        combine_dim(\n            state_dicts, state_dict, \"transformer.embeddings.position_embeddings.weight\", -1\n        )\n    mlp_combine_fn = (\n        combine_gated_mlp\n        if config.activation_function in [\"glu\", \"swiglu\", \"geglu\"]\n        else partial(combine_dim, dim=0)\n    )\n    for i in range(config.num_hidden_layers):\n        combine_qkv_headdim(state_dicts, state_dict, f\"transformer.layers.{i}.mixer.Wqkv.weight\")\n        combine_qkv_headdim(state_dicts, state_dict, f\"transformer.layers.{i}.mixer.Wqkv.bias\")\n        combine_dim(state_dicts, state_dict, f\"transformer.layers.{i}.mixer.out_proj.weight\", -1)\n        mlp_combine_fn(state_dicts, state_dict, f\"transformer.layers.{i}.mlp.fc1.weight\")\n        combine_dim(state_dicts, state_dict, f\"transformer.layers.{i}.mlp.fc1.bias\", 0)\n        combine_dim(state_dicts, state_dict, f\"transformer.layers.{i}.mlp.fc2.weight\", -1)\n    return state_dict\n\n\ndef remap_state_dict_hf_gpt2(state_dict, config):\n    # Word embedding and position embedding\n    def key_mapping_pos_emb(key):\n        return re.sub(r\"^wpe.\", \"transformer.embeddings.position_embeddings.\", key)\n\n    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())\n    
word_embeddings = state_dict.pop(\"wte.weight\")\n    # It's possible that vocab_size is padded to be a multiple of 8, for example.\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    state_dict[\"transformer.embeddings.word_embeddings.weight\"] = F.pad(\n        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])\n    )\n    state_dict[\"lm_head.weight\"] = state_dict[\"transformer.embeddings.word_embeddings.weight\"]\n\n    # LayerNorm\n    def key_mapping_ln(key):\n        key = re.sub(r\"^ln_f.(weight|bias)\", r\"transformer.ln_f.\\1\", key)\n        key = re.sub(r\"^h.(\\d+).ln_(1|2).(weight|bias)\", r\"transformer.layers.\\1.norm\\2.\\3\", key)\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    # MLP\n    for d in range(config.num_hidden_layers):\n        W1 = state_dict.pop(f\"h.{d}.mlp.c_fc.weight\")\n        state_dict[f\"transformer.layers.{d}.mlp.fc1.weight\"] = W1.t()\n        W2 = state_dict.pop(f\"h.{d}.mlp.c_proj.weight\")\n        state_dict[f\"transformer.layers.{d}.mlp.fc2.weight\"] = W2.t()\n\n    def key_mapping_mlp(key):\n        key = re.sub(r\"^h.(\\d+).mlp.c_fc.bias\", r\"transformer.layers.\\1.mlp.fc1.bias\", key)\n        key = re.sub(r\"^h.(\\d+).mlp.c_proj.bias\", r\"transformer.layers.\\1.mlp.fc2.bias\", key)\n        return key\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    # Attention\n    for d in range(config.num_hidden_layers):\n        state_dict.pop(f\"h.{d}.attn.bias\", None)  # We don't store this bias\n        Wqkv = state_dict.pop(f\"h.{d}.attn.c_attn.weight\")\n        state_dict[f\"transformer.layers.{d}.mixer.Wqkv.weight\"] = Wqkv.t()\n        Wout = state_dict.pop(f\"h.{d}.attn.c_proj.weight\")\n        state_dict[f\"transformer.layers.{d}.mixer.out_proj.weight\"] = 
Wout.t()\n\n    def key_mapping_attn(key):\n        key = re.sub(r\"^h.(\\d+).attn.c_attn.bias\", r\"transformer.layers.\\1.mixer.Wqkv.bias\", key)\n        key = re.sub(\n            r\"^h.(\\d+).attn.c_proj.bias\", r\"transformer.layers.\\1.mixer.out_proj.bias\", key\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n\n    return state_dict\n\n\ndef remap_state_dict_megatron(state_dict, config):\n    def key_mapping_transformer(key):\n        key = re.sub(r\"^language_model.encoder.\", \"transformer.\", key)\n        key = re.sub(r\"^language_model.\", \"transformer.\", key)\n        return key\n\n    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())\n\n    # Word embedding and position embedding\n    def key_mapping_pos_emb(key):\n        return re.sub(r\"^wpe.\", \"transformer.embeddings.position_embeddings.\", key)\n\n    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())\n    word_embeddings = state_dict.pop(\"transformer.embedding.word_embeddings.weight\")\n    # It's possible that vocab_size is padded to be a multiple of 8, for example.\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = (\n        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    )\n    state_dict[\"transformer.embeddings.word_embeddings.weight\"] = F.pad(\n        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])\n    )\n    state_dict[\"lm_head.weight\"] = state_dict[\"transformer.embeddings.word_embeddings.weight\"]\n\n    # LayerNorm\n    def key_mapping_ln(key):\n        key = re.sub(r\"^transformer.final_layernorm.(weight|bias)\", r\"transformer.ln_f.\\1\", key)\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).input_layernorm.(weight|bias)\",\n            r\"transformer.layers.\\1.norm1.\\2\",\n            key,\n        )\n  
      key = re.sub(\n            r\"^transformer.layers.(\\d+).post_attention_layernorm.(weight|bias)\",\n            r\"transformer.layers.\\1.norm2.\\2\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    # MLP\n    def key_mapping_mlp(key):\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mlp.dense_h_to_4h.(weight|bias)\",\n            r\"transformer.layers.\\1.mlp.fc1.\\2\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mlp.dense_4h_to_h.(weight|bias)\",\n            r\"transformer.layers.\\1.mlp.fc2.\\2\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    # Attention\n    def key_mapping_attn(key):\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).self_attention.rotary_emb.inv_freq\",\n            r\"transformer.layers.\\1.mixer.rotary_emb.inv_freq\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).self_attention.query_key_value.(weight|bias)\",\n            r\"transformer.layers.\\1.mixer.Wqkv.\\2\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).self_attention.dense.(weight|bias)\",\n            r\"transformer.layers.\\1.mixer.out_proj.\\2\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)\n    # while we store Wqkv as ((3 nheads headdim), hidden_dim)\n    headdim = config.hidden_size // config.num_attention_heads\n    for d in range(config.num_hidden_layers):\n        Wqkv = state_dict.pop(f\"transformer.layers.{d}.mixer.Wqkv.weight\")\n        state_dict[f\"transformer.layers.{d}.mixer.Wqkv.weight\"] = rearrange(\n            Wqkv,\n            
\"(nheads three headdim) ... -> (three nheads headdim) ...\",\n            three=3,\n            headdim=headdim,\n        )\n        bqkv = state_dict.pop(f\"transformer.layers.{d}.mixer.Wqkv.bias\")\n        state_dict[f\"transformer.layers.{d}.mixer.Wqkv.bias\"] = rearrange(\n            bqkv, \"(nheads three headdim) -> (three nheads headdim)\", three=3, headdim=headdim\n        )\n\n    return state_dict\n"
  },
  {
    "path": "flash_attn/models/gpt_neox.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport math\nimport re\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom transformers import GPT2Config, GPTNeoXConfig\n\n\ndef remap_state_dict_hf_gpt_neox(state_dict, config):\n    def key_mapping_layers(key):\n        return re.sub(r\"^gpt_neox.\", \"transformer.\", key)\n\n    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())\n    # Word embedding\n    def key_mapping_emb(key):\n        return re.sub(r\"^transformer.embed_in.\", \"transformer.embeddings.word_embeddings.\", key)\n\n    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())\n    word_embeddings = state_dict.pop(\"transformer.embeddings.word_embeddings.weight\")\n    # It's possible that vocab_size is padded to be a multiple of 8, for example.\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    state_dict[\"transformer.embeddings.word_embeddings.weight\"] = F.pad(\n        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])\n    )\n    if getattr(config, \"tie_word_embeddings\", False):\n        state_dict[\"lm_head.weight\"] = state_dict[\"transformer.embeddings.word_embeddings.weight\"]\n    else:\n        output_embeddings = state_dict.pop(\"embed_out.weight\")\n        # It's possible that vocab_size is padded to be a multiple of 8, for example.\n        state_dict[\"lm_head.weight\"] = F.pad(\n            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])\n        )\n\n    # LayerNorm\n    def key_mapping_ln(key):\n        key = re.sub(r\"^transformer.final_layer_norm.\", r\"transformer.ln_f.\", key)\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).input_layernorm.\", r\"transformer.layers.\\1.norm1.\", key\n        )\n        key = 
re.sub(\n            r\"^transformer.layers.(\\d+).post_attention_layernorm.\",\n            r\"transformer.layers.\\1.norm2.\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    # MLP\n    def key_mapping_mlp(key):\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mlp.dense_h_to_4h.\", r\"transformer.layers.\\1.mlp.fc1.\", key\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mlp.dense_4h_to_h.\", r\"transformer.layers.\\1.mlp.fc2.\", key\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    # Attention\n    for l in range(config.n_layer):\n        # We don't store these biases\n        state_dict.pop(f\"transformer.layers.{l}.attention.bias\")\n        state_dict.pop(f\"transformer.layers.{l}.attention.masked_bias\")\n        # We don't store these\n        state_dict.pop(f\"transformer.layers.{l}.attention.rotary_emb.inv_freq\", None)\n        # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)\n        # while we store Wqkv as ((3 nheads headdim), hidden_dim)\n        headdim = config.hidden_size // config.num_attention_heads\n        Wqkv = state_dict.pop(f\"transformer.layers.{l}.attention.query_key_value.weight\")\n        state_dict[f\"transformer.layers.{l}.mixer.Wqkv.weight\"] = rearrange(\n            Wqkv,\n            \"(nheads three headdim) ... 
-> (three nheads headdim) ...\",\n            three=3,\n            headdim=headdim,\n        )\n        bqkv = state_dict.pop(f\"transformer.layers.{l}.attention.query_key_value.bias\")\n        state_dict[f\"transformer.layers.{l}.mixer.Wqkv.bias\"] = rearrange(\n            bqkv, \"(nheads three headdim) -> (three nheads headdim)\", three=3, headdim=headdim\n        )\n\n    def key_mapping_attn(key):\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).attention.dense.\",\n            r\"transformer.layers.\\1.mixer.out_proj.\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n\n    return state_dict\n\n\ndef gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:\n    assert gpt_neox_config.rotary_emb_base == 10000\n    return GPT2Config(\n        vocab_size=gpt_neox_config.vocab_size,\n        n_positions=0,  # No absolute position embedding\n        n_embd=gpt_neox_config.hidden_size,\n        n_layer=gpt_neox_config.num_hidden_layers,\n        n_head=gpt_neox_config.num_attention_heads,\n        n_inner=gpt_neox_config.intermediate_size,\n        activation_function=gpt_neox_config.hidden_act,\n        resid_pdrop=0.0,  # No dropout\n        embd_pdrop=0.0,\n        attn_pdrop=0.0,\n        layer_norm_epsilon=gpt_neox_config.layer_norm_eps,\n        initializer_range=gpt_neox_config.initializer_range,\n        bos_token_id=gpt_neox_config.bos_token_id,\n        eos_token_id=gpt_neox_config.eos_token_id,\n        # These are new arguments not in the original GPT2Config\n        prenorm=True,\n        parallel_block=gpt_neox_config.use_parallel_residual,\n        parallel_block_tied_norm=False,\n        rotary_emb_fraction=gpt_neox_config.rotary_pct,\n        tie_word_embeddings=gpt_neox_config.tie_word_embeddings,\n    )\n"
  },
  {
    "path": "flash_attn/models/gptj.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport math\nimport re\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn.functional as F\nfrom transformers import GPT2Config, GPTJConfig\n\n\ndef remap_state_dict_hf_gptj(state_dict, config):\n    def key_mapping_layers(key):\n        return re.sub(r\"^transformer.h.\", \"transformer.layers.\", key)\n\n    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())\n    # Word embedding\n    def key_mapping_emb(key):\n        return re.sub(r\"^transformer.wte.\", \"transformer.embeddings.word_embeddings.\", key)\n\n    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())\n    word_embeddings = state_dict.pop(\"transformer.embeddings.word_embeddings.weight\")\n    # It's possible that vocab_size is padded to be a multiple of 8, for example.\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    state_dict[\"transformer.embeddings.word_embeddings.weight\"] = F.pad(\n        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])\n    )\n    if getattr(config, \"tie_word_embeddings\"):\n        state_dict[\"lm_head.weight\"] = state_dict[\"transformer.embeddings.word_embeddings.weight\"]\n    else:\n        output_embeddings = state_dict.pop(\"lm_head.weight\")\n        # It's possible that vocab_size is padded to be a multiple of 8, for example.\n        state_dict[\"lm_head.weight\"] = F.pad(\n            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])\n        )\n        output_embeddings_bias = state_dict.pop(\"lm_head.bias\")\n        state_dict[\"lm_head.bias\"] = F.pad(\n            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])\n        )\n\n    # LayerNorm\n    def key_mapping_ln(key):\n        return re.sub(r\"^transformer.layers.(\\d+).ln_1.\", 
r\"transformer.layers.\\1.norm1.\", key)\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    # MLP\n    def key_mapping_mlp(key):\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mlp.fc_in.\", r\"transformer.layers.\\1.mlp.fc1.\", key\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).mlp.fc_out.\", r\"transformer.layers.\\1.mlp.fc2.\", key\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    # Attention\n    for l in range(config.n_layer):\n        Wq = state_dict.pop(f\"transformer.layers.{l}.attn.q_proj.weight\")\n        Wk = state_dict.pop(f\"transformer.layers.{l}.attn.k_proj.weight\")\n        Wv = state_dict.pop(f\"transformer.layers.{l}.attn.v_proj.weight\")\n        state_dict[f\"transformer.layers.{l}.mixer.Wqkv.weight\"] = torch.cat([Wq, Wk, Wv], dim=0)\n        # We don't store these biases\n        state_dict.pop(f\"transformer.layers.{l}.attn.bias\")\n        state_dict.pop(f\"transformer.layers.{l}.attn.masked_bias\")\n\n    def key_mapping_attn(key):\n        return re.sub(\n            r\"^transformer.layers.(\\d+).attn.out_proj.\",\n            r\"transformer.layers.\\1.mixer.out_proj.\",\n            key,\n        )\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n\n    return state_dict\n\n\ndef gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:\n    headdim = gptj_config.n_embd // gptj_config.n_head\n    return GPT2Config(\n        vocab_size=gptj_config.vocab_size,\n        n_positions=0,  # No absolute position embedding\n        n_embd=gptj_config.n_embd,\n        n_layer=gptj_config.n_layer,\n        n_head=gptj_config.n_head,\n        n_inner=gptj_config.n_inner,\n        activation_function=gptj_config.activation_function,\n        resid_pdrop=gptj_config.resid_pdrop,\n        embd_pdrop=gptj_config.embd_pdrop,\n        
attn_pdrop=gptj_config.attn_pdrop,\n        layer_norm_epsilon=gptj_config.layer_norm_epsilon,\n        initializer_range=gptj_config.initializer_range,\n        bos_token_id=gptj_config.bos_token_id,\n        eos_token_id=gptj_config.eos_token_id,\n        # These are new arguments not in the original GPT2Config\n        prenorm=True,\n        parallel_block=True,\n        parallel_block_tied_norm=True,\n        rotary_emb_fraction=gptj_config.rotary_dim / headdim,\n        rotary_emb_interleaved=True,\n        tie_word_embeddings=False,\n        qkv_proj_bias=False,\n        out_proj_bias=False,\n        lm_head_bias=True,\n    )\n"
  },
  {
    "path": "flash_attn/models/llama.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport json\nimport math\nimport os\nimport re\nfrom collections import OrderedDict\nfrom pathlib import Path\nfrom typing import Dict, List, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom sentencepiece import SentencePieceProcessor\nfrom transformers import GPT2Config, LlamaConfig\n\nfrom einops import rearrange\n\n\ndef remap_state_dict_meta_llama(\n    state_dict: Dict[str, torch.Tensor], config: GPT2Config\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Convert the state_dict in Meta format to standard GPT format.\n\n    This function modifies state_dict in place.\n    \"\"\"\n\n    def key_mapping_layers(key):\n        return f\"transformer.{key}\" if not key.startswith(\"output.\") else key\n\n    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())\n\n    # Word embedding\n    def key_mapping_emb(key):\n        return re.sub(\n            r\"^transformer.tok_embeddings.\", \"transformer.embeddings.word_embeddings.\", key\n        )\n\n    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())\n    word_embeddings = state_dict.pop(\"transformer.embeddings.word_embeddings.weight\")\n    # It's possible that vocab_size is padded to be a multiple of 8, for example.\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = (\n        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    )\n    state_dict[\"transformer.embeddings.word_embeddings.weight\"] = F.pad(\n        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])\n    )\n    if getattr(config, \"tie_word_embeddings\"):\n        state_dict[\"lm_head.weight\"] = state_dict[\"transformer.embeddings.word_embeddings.weight\"]\n    else:\n        output_embeddings = state_dict.pop(\"output.weight\")\n        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings\n       
 # differently.\n        vocab_size = (\n            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)\n            * pad_vocab_size_multiple\n        )\n        # It's possible that vocab_size is padded to be a multiple of 8, for example.\n        state_dict[\"lm_head.weight\"] = F.pad(\n            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])\n        )\n\n    # LayerNorm\n    def key_mapping_ln(key):\n        key = re.sub(r\"^transformer.norm.\", r\"transformer.ln_f.\", key)\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).attention_norm.\",\n            r\"transformer.layers.\\1.norm1.\",\n            key,\n        )\n        key = re.sub(r\"^transformer.layers.(\\d+).ffn_norm.\", r\"transformer.layers.\\1.norm2.\", key)\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    # MLP\n    for l in range(config.n_layer):\n        w1 = state_dict.pop(f\"transformer.layers.{l}.feed_forward.w1.weight\")\n        w3 = state_dict.pop(f\"transformer.layers.{l}.feed_forward.w3.weight\")\n        # Our ordering is different\n        state_dict[f\"transformer.layers.{l}.mlp.fc1.weight\"] = torch.cat([w3, w1], dim=0)\n\n    def key_mapping_mlp(key):\n        return re.sub(\n            r\"^transformer.layers.(\\d+).feed_forward.w2.\",\n            r\"transformer.layers.\\1.mlp.fc2.\",\n            key,\n        )\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    # Attention\n    for l in range(config.n_layer):\n        Wq = state_dict.pop(f\"transformer.layers.{l}.attention.wq.weight\")\n        Wk = state_dict.pop(f\"transformer.layers.{l}.attention.wk.weight\")\n        Wv = state_dict.pop(f\"transformer.layers.{l}.attention.wv.weight\")\n        state_dict[f\"transformer.layers.{l}.mixer.Wqkv.weight\"] = torch.cat([Wq, Wk, Wv], dim=0)\n        # We don't store these\n        
state_dict.pop(f\"transformer.layers.{l}.attention.inner_attention.rope.freqs\", None)\n\n    def key_mapping_attn(key):\n        return re.sub(\n            r\"^transformer.layers.(\\d+).attention.wo.\",\n            r\"transformer.layers.\\1.mixer.out_proj.\",\n            key,\n        )\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n\n    state_dict.pop(\"transformer.rope.freqs\", None)\n\n    return state_dict\n\n\ndef remap_state_dict_hf_llama(\n    state_dict: Dict[str, torch.Tensor], config: GPT2Config\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Convert the state_dict in Hugging Face format to standard GPT format.\n\n    This function modifies state_dict in place.\n    \"\"\"\n\n    # Embedding\n    def key_mapping_emb(key):\n        return re.sub(r\"^model.embed_tokens.\", \"transformer.embeddings.word_embeddings.\", key)\n\n    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())\n    word_embeddings = state_dict.pop(\"transformer.embeddings.word_embeddings.weight\")\n    # It's possible that vocab_size is padded to be a multiple of 8, for example.\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = (\n        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    )\n    state_dict[\"transformer.embeddings.word_embeddings.weight\"] = F.pad(\n        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])\n    )\n\n    # LM head\n    if getattr(config, \"tie_word_embeddings\"):\n        state_dict[\"lm_head.weight\"] = state_dict[\"transformer.embeddings.word_embeddings.weight\"]\n    else:\n        output_embeddings = state_dict.pop(\"lm_head.weight\")\n        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings\n        # differently.\n        vocab_size = (\n            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)\n            * 
pad_vocab_size_multiple\n        )\n        # It's possible that vocab_size is padded to be a multiple of 8, for example.\n        state_dict[\"lm_head.weight\"] = F.pad(\n            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])\n        )\n\n    # MLP\n    for l in range(config.n_layer):\n        # Fusing weights this way based on difference in the following:\n        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220\n        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115\n        w1 = state_dict.pop(f\"model.layers.{l}.mlp.gate_proj.weight\")\n        w3 = state_dict.pop(f\"model.layers.{l}.mlp.up_proj.weight\")\n        state_dict[f\"transformer.layers.{l}.mlp.fc1.weight\"] = torch.cat([w3, w1], dim=0)\n\n    def key_mapping_mlp(key):\n        return re.sub(\n            r\"^model.layers.(\\d+).mlp.down_proj.\",\n            r\"transformer.layers.\\1.mlp.fc2.\",\n            key,\n        )\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    # LayerNorm\n    def key_mapping_ln(key):\n        key = re.sub(r\"^model.norm.\", r\"transformer.ln_f.\", key)\n        key = re.sub(\n            r\"^model.layers.(\\d+).input_layernorm.\",\n            r\"transformer.layers.\\1.norm1.\",\n            key,\n        )\n        key = re.sub(\n            r\"^model.layers.(\\d+).post_attention_layernorm.\",\n            r\"transformer.layers.\\1.norm2.\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    def inv_permute(w):\n        # Inverse of permute implemented in:\n        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114\n        return 
rearrange(\n            w, \"(h two d) n -> (h d two) n\", d=config.n_embd // config.n_head // 2, two=2\n        )\n\n    # Attention\n    for l in range(config.n_layer):\n        Wq = state_dict.pop(f\"model.layers.{l}.self_attn.q_proj.weight\")\n        Wk = state_dict.pop(f\"model.layers.{l}.self_attn.k_proj.weight\")\n        Wv = state_dict.pop(f\"model.layers.{l}.self_attn.v_proj.weight\")\n\n        state_dict[f\"transformer.layers.{l}.mixer.Wqkv.weight\"] = torch.cat(\n            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0\n        )\n        # We don't store these\n        state_dict.pop(f\"model.layers.{l}.self_attn.rotary_emb.inv_freq\", None)\n\n    def key_mapping_attn(key):\n        return re.sub(\n            r\"^model.layers.(\\d+).self_attn.o_proj.\",\n            r\"transformer.layers.\\1.mixer.out_proj.\",\n            key,\n        )\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n    return state_dict\n\n\ndef inv_remap_state_dict_hf_llama(\n    state_dict: Dict[str, torch.Tensor], config: GPT2Config\n) -> Dict[str, torch.Tensor]:\n    \"\"\"Convert the state_dict in standard GPT format to Hugging Face format.\n\n    This function is meant to be the inverse of remap_state_dict_hf_llama, up to a\n    multiplier pad in the embedding and lm_head. 
That is, if the original embedding size\n    isn't a multiple of pad_vocab_size_multiple, then\n    inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict.\n\n    This function modifies state_dict in place.\n    \"\"\"\n\n    # Embedding\n    def key_mapping_emb(key):\n        return re.sub(r\"^transformer.embeddings.word_embeddings.\", \"model.embed_tokens.\", key)\n\n    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())\n    word_embeddings = state_dict.pop(\"model.embed_tokens.weight\")\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = (\n        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    )\n    state_dict[\"model.embed_tokens.weight\"] = F.pad(\n        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])\n    )\n\n    # LM head\n    if getattr(config, \"tie_word_embeddings\"):\n        state_dict[\"lm_head.weight\"] = state_dict[\"model.embed_tokens.weight\"]\n    else:\n        output_embeddings = state_dict.pop(\"lm_head.weight\")\n        vocab_size = (\n            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)\n            * pad_vocab_size_multiple\n        )\n        state_dict[\"lm_head.weight\"] = F.pad(\n            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])\n        )\n\n    # MLP\n    for l in range(config.n_layer):\n        w3, w1 = torch.chunk(\n            state_dict.pop(f\"transformer.layers.{l}.mlp.fc1.weight\"), chunks=2, dim=0\n        )\n        state_dict[f\"model.layers.{l}.mlp.gate_proj.weight\"] = w1\n        state_dict[f\"model.layers.{l}.mlp.up_proj.weight\"] = w3\n\n    def key_mapping_mlp(key):\n        return re.sub(\n            r\"^transformer.layers.(\\d+).mlp.fc2.\",\n            r\"model.layers.\\1.mlp.down_proj.\",\n            key,\n        )\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in 
state_dict.items())\n\n    # LayerNorm\n    def key_mapping_ln(key):\n        key = re.sub(r\"^transformer.ln_f.\", r\"model.norm.\", key)\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).norm1.\",\n            r\"model.layers.\\1.input_layernorm.\",\n            key,\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).norm2.\",\n            r\"model.layers.\\1.post_attention_layernorm.\",\n            key,\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    def permute(w):\n        return rearrange(\n            w, \"(h d two) n -> (h two d) n\", d=config.n_embd // config.n_head // 2, two=2\n        )\n\n    n_head = config.n_head\n    n_head_kv = getattr(config, \"n_head_kv\", n_head)\n\n    embed_dim = config.hidden_size\n    head_dim = embed_dim // n_head\n\n    q_dim = n_head * head_dim\n    k_dim = v_dim = n_head_kv * head_dim\n\n    # Attention\n    for l in range(config.n_layer):\n        Wqkv = state_dict.pop(f\"transformer.layers.{l}.mixer.Wqkv.weight\")\n        Wq = Wqkv[:q_dim]\n        Wk = Wqkv[q_dim : q_dim + k_dim]\n        Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]\n        state_dict[f\"model.layers.{l}.self_attn.q_proj.weight\"] = permute(Wq)\n        state_dict[f\"model.layers.{l}.self_attn.k_proj.weight\"] = permute(Wk)\n        state_dict[f\"model.layers.{l}.self_attn.v_proj.weight\"] = Wv\n        state_dict.pop(f\"transformer.layers.{l}.attention.inner_attention.rope.freqs\", None)\n\n    def key_mapping_attn(key):\n        return re.sub(\n            r\"^transformer.layers.(\\d+).mixer.out_proj.\",\n            r\"model.layers.\\1.self_attn.o_proj.\",\n            key,\n        )\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n    return state_dict\n\n\ndef config_from_meta_checkpoint(\n    checkpoint_path: Union[str, os.PathLike], model_name: str\n) -> LlamaConfig:\n    
\"\"\"Load a LlamaConfig from a checkpoint path.\"\"\"\n    with open(Path(checkpoint_path) / model_name / \"params.json\") as f:\n        params = json.load(f)\n    config = LlamaConfig(\n        hidden_size=params[\"dim\"],\n        intermediate_size=None,\n        num_attention_heads=params[\"n_heads\"],\n        num_hidden_layers=params[\"n_layers\"],\n        rms_norm_eps=params[\"norm_eps\"],\n        num_key_value_heads=params.get(\"n_kv_heads\", None),\n    )\n    multiple_of = params.get(\"multiple_of\", 1)\n    ffn_dim_multiplier = params.get(\"ffn_dim_multiplier\", None)\n\n    # Compute the hidden dimension of the MLP\n    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224\n    intermediate_size = 4 * config.hidden_size\n    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199\n    intermediate_size = int(2 * intermediate_size / 3)\n    # custom dim factor multiplier\n    if ffn_dim_multiplier is not None:\n        intermediate_size = int(ffn_dim_multiplier * intermediate_size)\n    intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)\n\n    config.intermediate_size = intermediate_size\n    if \"rope_theta\" in params:\n        config.rotary_emb_base = params[\"rope_theta\"]\n    config.vocab_size = 32000\n    # some CodeLLaMa have vocab_size 32000, some 32016\n    # Sadly it's not specified in the `params.json` file :(\n    tokenizer = Path(checkpoint_path) / model_name / \"tokenizer.model\"\n    if tokenizer.is_file():\n        config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size()\n    return config\n\n\ndef config_from_hf_checkpoint(\n    checkpoint_path: Union[str, os.PathLike], model_name: str\n) -> LlamaConfig:\n    return LlamaConfig.from_pretrained(Path(checkpoint_path) / f\"{model_name}-hf\" / \"config.json\")\n\n\ndef config_from_checkpoint(\n    checkpoint_path: 
Union[str, os.PathLike], model_name: str, checkpoint_format=\"meta\"\n) -> LlamaConfig:\n    if checkpoint_format == \"meta\":\n        return config_from_meta_checkpoint(checkpoint_path, model_name)\n    else:\n        return config_from_hf_checkpoint(checkpoint_path, model_name)\n\n\ndef state_dicts_from_checkpoint(\n    checkpoint_path: Union[str, os.PathLike], model_name: str\n) -> List[dict]:\n    # Need to sort, otherwise we mess up the ordering and the weights are wrong\n    return [\n        torch.load(path, map_location=\"cpu\")\n        for path in sorted((Path(checkpoint_path) / model_name).glob(\"consolidated.*.pth\"))\n    ]\n\n\ndef llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:\n    return GPT2Config(\n        vocab_size=llama_config.vocab_size,\n        n_positions=0,  # No absolute position embedding\n        n_embd=llama_config.hidden_size,\n        n_layer=llama_config.num_hidden_layers,\n        n_head=llama_config.num_attention_heads,\n        n_inner=llama_config.intermediate_size,\n        activation_function=\"swiglu\",  # Hardcode since HF calls it 'silu'\n        # Llama doesn't have dropout, idk if it's because they only release the inference code\n        resid_pdrop=0.0,\n        embd_pdrop=0.0,\n        attn_pdrop=0.0,\n        layer_norm_epsilon=llama_config.rms_norm_eps,\n        initializer_range=llama_config.initializer_range,\n        bos_token_id=llama_config.bos_token_id,\n        eos_token_id=llama_config.eos_token_id,\n        # These are new arguments not in the original GPT2Config\n        pad_token_id=llama_config.pad_token_id,  # Idk if this does anything\n        rms_norm=True,\n        rotary_emb_fraction=1.0,\n        rotary_emb_interleaved=True,\n        tie_word_embeddings=False,\n        qkv_proj_bias=False,\n        out_proj_bias=False,\n        mlp_fc1_bias=False,\n        mlp_fc2_bias=False,\n        rotary_emb_base=getattr(llama_config, \"rotary_emb_base\", 10000.0),\n        
n_head_kv=llama_config.num_key_value_heads,\n    )\n"
  },
  {
    "path": "flash_attn/models/opt.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport math\nimport re\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn.functional as F\nfrom transformers import GPT2Config, OPTConfig\n\n\ndef remap_state_dict_hf_opt(state_dict, config):\n    def key_mapping_model(key):\n        key = re.sub(r\"^model.decoder.\", \"transformer.\", key)\n        # The OPT-350m model uses '^decoder' instead of '^model.decoder'\n        key = re.sub(r\"^decoder.\", \"transformer.\", key)\n        return key\n\n    state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())\n    # Word embedding and position embedding\n    def key_mapping_emb(key):\n        key = re.sub(r\"^transformer.embed_tokens.\", \"transformer.embeddings.word_embeddings.\", key)\n        # The OPT-350m model uses has project_in and project_out\n        key = re.sub(r\"^transformer.project_in.\", \"transformer.embeddings.project_in.\", key)\n        key = re.sub(r\"^transformer.project_out.\", \"project_out.\", key)\n        key = re.sub(\n            r\"^transformer.embed_positions.\", \"transformer.embeddings.position_embeddings.\", key\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())\n    # OPT uses the first 2 indices of pos_emb for padding tokens\n    pos_embeddings = state_dict.pop(\"transformer.embeddings.position_embeddings.weight\")\n    state_dict[\"transformer.embeddings.position_embeddings.weight\"] = pos_embeddings[2:]\n    word_embeddings = state_dict.pop(\"transformer.embeddings.word_embeddings.weight\")\n    # It's possible that vocab_size is padded to be a multiple of 8, for example.\n    pad_vocab_size_multiple = getattr(config, \"pad_vocab_size_multiple\", 1)\n    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple\n    state_dict[\"transformer.embeddings.word_embeddings.weight\"] = F.pad(\n        word_embeddings, (0, 0, 0, vocab_size - 
word_embeddings.shape[0])\n    )\n    state_dict[\"lm_head.weight\"] = state_dict[\"transformer.embeddings.word_embeddings.weight\"]\n\n    # LayerNorm\n    def key_mapping_ln(key):\n        key = re.sub(r\"^transformer.final_layer_norm.\", r\"transformer.ln_f.\", key)\n        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'\n        key = re.sub(r\"^transformer.layer_norm.\", r\"transformer.ln_f.\", key)\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).self_attn_layer_norm.\", r\"transformer.layers.\\1.norm1.\", key\n        )\n        key = re.sub(\n            r\"^transformer.layers.(\\d+).final_layer_norm.\", r\"transformer.layers.\\1.norm2.\", key\n        )\n        return key\n\n    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())\n\n    # MLP\n    def key_mapping_mlp(key):\n        return re.sub(\n            r\"^transformer.layers.(\\d+).fc(1|2).\", r\"transformer.layers.\\1.mlp.fc\\2.\", key\n        )\n\n    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())\n\n    # Attention\n    for l in range(config.n_layer):\n        Wq = state_dict.pop(f\"transformer.layers.{l}.self_attn.q_proj.weight\")\n        Wk = state_dict.pop(f\"transformer.layers.{l}.self_attn.k_proj.weight\")\n        Wv = state_dict.pop(f\"transformer.layers.{l}.self_attn.v_proj.weight\")\n        bq = state_dict.pop(f\"transformer.layers.{l}.self_attn.q_proj.bias\")\n        bk = state_dict.pop(f\"transformer.layers.{l}.self_attn.k_proj.bias\")\n        bv = state_dict.pop(f\"transformer.layers.{l}.self_attn.v_proj.bias\")\n        state_dict[f\"transformer.layers.{l}.mixer.Wqkv.weight\"] = torch.cat([Wq, Wk, Wv], dim=0)\n        state_dict[f\"transformer.layers.{l}.mixer.Wqkv.bias\"] = torch.cat([bq, bk, bv], dim=0)\n\n    def key_mapping_attn(key):\n        return re.sub(\n            r\"^transformer.layers.(\\d+).self_attn.out_proj.\",\n            
r\"transformer.layers.\\1.mixer.out_proj.\",\n            key,\n        )\n\n    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n\n    return state_dict\n\n\ndef opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:\n    assert opt_config.layerdrop == 0.0\n    assert opt_config.layer_norm_elementwise_affine\n    word_embed_proj_dim = (\n        None\n        if opt_config.word_embed_proj_dim == opt_config.hidden_size\n        else opt_config.word_embed_proj_dim\n    )\n    return GPT2Config(\n        vocab_size=opt_config.vocab_size,\n        n_positions=opt_config.max_position_embeddings,\n        n_embd=opt_config.hidden_size,\n        n_layer=opt_config.num_hidden_layers,\n        n_head=opt_config.num_attention_heads,\n        n_inner=opt_config.ffn_dim,\n        activation_function=opt_config.activation_function,\n        resid_pdrop=opt_config.dropout,\n        # HF's implementation of OPT doesn't seem to have embedding dropout\n        embd_pdrop=opt_config.dropout,\n        attn_pdrop=opt_config.attention_dropout,\n        initializer_range=opt_config.init_std,\n        bos_token_id=opt_config.bos_token_id,\n        eos_token_id=opt_config.eos_token_id,\n        # These are new arguments not in the original GPT2Config\n        prenorm=opt_config.do_layer_norm_before,\n        word_embed_proj_dim=word_embed_proj_dim,\n    )\n"
  },
  {
    "path": "flash_attn/models/vit.py",
    "content": "# Copyright (c) 2022, Tri Dao.\n# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\nimport math\nimport re\nfrom collections import OrderedDict\nfrom copy import deepcopy\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom timm.models.helpers import named_apply\nfrom torch.nn.init import trunc_normal_\nfrom torchvision.ops import StochasticDepth\n\nfrom flash_attn.layers.patch_embed import PatchEmbed\nfrom flash_attn.modules.block import Block\nfrom flash_attn.modules.mha import MHA\nfrom flash_attn.modules.mlp import FusedMLP, Mlp\n\ntry:\n    from flash_attn.ops.triton.layer_norm import layer_norm_fn\nexcept ImportError:\n    layer_norm_fn = None\n\n\ndef create_mixer_cls(\n    num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False\n):\n    mixer_cls = partial(\n        MHA,\n        num_heads=num_heads,\n        cross_attn=cross_attn,\n        qkv_proj_bias=qkv_bias,\n        dropout=attn_drop,\n        fused_bias_fc=fused_bias_fc,\n        use_flash_attn=use_flash_attn,\n    )\n    return mixer_cls\n\n\ndef create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):\n    inner_dim = int(embed_dim * mlp_ratio)\n    if not fused_mlp:\n        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())\n    else:\n        mlp_cls = partial(FusedMLP, hidden_features=inner_dim)\n    return mlp_cls\n\n\ndef create_block(\n    embed_dim,\n    num_heads,\n    mlp_ratio,\n    qkv_bias,\n    drop_rate,\n    attn_drop_rate,\n    drop_path1,\n    drop_path2,\n    norm_layer,\n    act_layer,\n    use_flash_attn,\n    fused_bias_fc,\n    fused_mlp,\n    fused_dropout_add_ln,\n    layer_idx=None,\n    n_layer=None,\n    last_layer_subset=False,\n):\n    mixer_cls = create_mixer_cls(\n        num_heads,\n        qkv_bias,\n        attn_drop_rate,\n        
use_flash_attn,\n        fused_bias_fc,\n        cross_attn=(last_layer_subset and layer_idx == n_layer - 1),\n    )\n    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)\n    # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed\n    block = Block(\n        embed_dim,\n        mixer_cls,\n        mlp_cls,\n        norm_cls=norm_layer,\n        prenorm=True,\n        resid_dropout1=drop_rate,\n        resid_dropout2=drop_rate,\n        drop_path1=drop_path1,\n        drop_path2=drop_path2,\n        fused_dropout_add_ln=fused_dropout_add_ln,\n        residual_in_fp32=True,\n    )\n    return block\n\n\nclass VisionTransformer(nn.Module):\n    \"\"\"Vision Transformer\n    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`\n        - https://arxiv.org/abs/2010.11929\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        num_classes=1000,\n        global_pool=\"token\",\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        init_values=None,\n        class_token=True,\n        no_embed_class=False,\n        pre_norm=False,\n        fc_norm=None,\n        drop_rate=0.0,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.0,\n        weight_init=\"\",\n        embed_layer=PatchEmbed,\n        norm_layer=None,\n        act_layer=None,\n        use_flash_attn=False,\n        fused_bias_fc=False,\n        fused_mlp=False,\n        fused_dropout_add_ln=False,\n    ):\n        \"\"\"\n        Args:\n            img_size (int, tuple): input image size\n            patch_size (int, tuple): patch size\n            in_chans (int): number of input channels\n            num_classes (int): number of classes for classification head\n            global_pool (str): type of global pooling for final sequence (default: 'token')\n            embed_dim (int): embedding 
dimension\n            depth (int): depth of transformer\n            num_heads (int): number of attention heads\n            mlp_ratio (float): ratio of mlp hidden dim to embedding dim\n            qkv_bias (bool): enable bias for qkv if True\n            init_values (float): layer-scale init values\n            class_token (bool): use class token\n            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)\n            drop_rate (float): dropout rate\n            attn_drop_rate (float): attention dropout rate\n            drop_path_rate (float): stochastic depth rate\n            weight_init (str): weight init scheme\n            embed_layer (nn.Module): patch embedding layer\n            norm_layer (nn.Module): normalization layer\n            act_layer (nn.Module): MLP activation layer\n        \"\"\"\n        super().__init__()\n        assert global_pool == \"token\", \"Only support pooling with CLS token\"\n        assert class_token\n        assert init_values is None, \"LayerScale is not supported yet\"\n        assert weight_init == \"\"\n        assert fc_norm is None\n        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk\n        assert not pre_norm\n        use_fc_norm = global_pool == \"avg\" if fc_norm is None else fc_norm\n        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)\n        act_layer = act_layer or nn.GELU\n\n        self.num_classes = num_classes\n        self.global_pool = global_pool\n        self.num_features = (\n            self.embed_dim\n        ) = embed_dim  # num_features for consistency with other models\n        self.num_prefix_tokens = 1 if class_token else 0\n        self.no_embed_class = no_embed_class\n\n        patch_embed_extra_kwargs = (\n            {\"fused_bias_fc\": fused_bias_fc} if embed_layer is PatchEmbed else {}\n        )\n        self.patch_embed = embed_layer(\n            img_size=img_size,\n      
      patch_size=patch_size,\n            in_chans=in_chans,\n            embed_dim=embed_dim,\n            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)\n            **patch_embed_extra_kwargs,\n        )\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None\n        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens\n        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)\n\n        dpr = [\n            x.item() for x in torch.linspace(0, drop_path_rate, depth)\n        ]  # stochastic depth decay rule\n\n        # We change the order of dropout, residual and layer norm:\n        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:\n        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and\n        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the\n        # nn.Dropout probabilities is changed.\n        # This is for performance reasons: we can fuse dropout + add + layer_norm.\n        self.blocks = nn.ModuleList(\n            [\n                create_block(\n                    embed_dim,\n                    num_heads,\n                    mlp_ratio,\n                    qkv_bias,\n                    drop_rate,\n                    attn_drop_rate,\n                    drop_path1=dpr[i - 1] if i > 0 else 0.0,\n                    drop_path2=dpr[i],\n                    norm_layer=norm_layer,\n                    act_layer=act_layer,\n                    use_flash_attn=use_flash_attn,\n                    fused_bias_fc=fused_bias_fc,\n                    fused_mlp=fused_mlp,\n                    fused_dropout_add_ln=fused_dropout_add_ln,\n                    layer_idx=i,\n                    n_layer=depth,\n                    last_layer_subset=(global_pool == \"token\"),\n                )\n           
     for i in range(depth)\n            ]\n        )\n\n        self.dropout = nn.Dropout(p=drop_rate)\n        self.drop_path = StochasticDepth(p=dpr[-1], mode=\"row\")\n        self.norm = norm_layer(embed_dim)\n\n        self.fused_dropout_add_ln = fused_dropout_add_ln\n        if self.fused_dropout_add_ln and layer_norm_fn is None:\n            raise ImportError(\"Triton is not installed\")\n\n        # Classifier Head\n        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n        self.init_weights(weight_init)\n\n    def init_weights(self, mode=\"\"):\n        assert mode == \"\"\n        trunc_normal_(self.pos_embed, std=0.02)\n        if self.cls_token is not None:\n            nn.init.normal_(self.cls_token, std=1e-6)\n        named_apply(init_weights_vit_timm, self)\n\n    def _init_weights(self, m):\n        # this fn left here for compat with downstream users\n        init_weights_vit_timm(m)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {\"pos_embed\", \"cls_token\"}\n\n    def _pos_embed(self, x):\n        if self.no_embed_class:\n            # deit-3, updated JAX (big vision)\n            # position embedding does not overlap with class token, add then concat\n            x = x + self.pos_embed\n            if self.cls_token is not None:\n                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)\n        else:\n            # original timm, JAX, and deit vit impl\n            # pos_embed has entry for class token, concat then add\n            if self.cls_token is not None:\n                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)\n            x = x + self.pos_embed\n        return x\n\n    def forward_features(self, x, all_tokens=True):\n        \"\"\"\n        If all_tokens==False and self.global_pool == 'token', we only return the features for the\n        cls token.\n        \"\"\"\n        x = self.patch_embed(x)\n        
hidden_states = self._pos_embed(x)\n        residual = None\n        if self.global_pool != \"token\" or all_tokens:\n            # if True:\n            for block in self.blocks:\n                hidden_states, residual = block(hidden_states, residual)\n        else:\n            for block in self.blocks[:-1]:\n                hidden_states, residual = block(hidden_states, residual)\n            # For the last layer, we only want the 1st token of the output. So we do cross-attention\n            # where the query is the 1st token and the key/value is the whole sequence.\n            hidden_states, residual = self.blocks[-1](\n                hidden_states, residual, mixer_subset=slice(0, 1)\n            )\n        if not self.fused_dropout_add_ln:\n            residual = self.drop_path(self.dropout(hidden_states)) + residual\n            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))\n        else:\n            if self.drop_path.p == 0 or not self.training:\n                rowscale = None\n            else:\n                rowscale = self.drop_path(\n                    torch.ones(\n                        hidden_states.shape[:-1],\n                        device=hidden_states.device,\n                        dtype=hidden_states.dtype,\n                    )\n                )\n            # Set prenorm=False here since we don't need to the residual\n            hidden_states = layer_norm_fn(\n                hidden_states,\n                self.norm.weight,\n                self.norm.bias,\n                residual=residual,\n                eps=self.norm.eps,\n                dropout_p=self.dropout.p if self.training else 0.0,\n                rowscale=rowscale,\n                prenorm=False,\n            )\n        return hidden_states\n\n    def forward_head(self, x, pre_logits: bool = False):\n        if self.global_pool:\n            x = x[:, self.num_prefix_tokens :].mean(dim=1) if self.global_pool == \"avg\" else x[:, 0]\n        
return x if pre_logits else self.head(x)\n\n    def forward(self, x):\n        x = self.forward_features(x, all_tokens=False)\n        x = self.forward_head(x)\n        return x\n\n    def load_state_dict(self, state_dict, strict=True):\n        patch_embed_weight = state_dict[\"patch_embed.proj.weight\"]\n        if patch_embed_weight.dim() == 4:\n            # convert from Conv2d to Linear\n            state_dict[\"patch_embed.proj.weight\"] = rearrange(\n                patch_embed_weight, \"o c h w -> o (c h w)\"\n            )\n\n        def key_mapping_attn(key):\n            key = re.sub(r\"^blocks.(\\d+).attn.qkv.\", r\"blocks.\\1.mixer.Wqkv.\", key)\n            key = re.sub(r\"^blocks.(\\d+).attn.proj.\", r\"blocks.\\1.mixer.out_proj.\", key)\n            return key\n\n        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())\n        n_layer = len(self.blocks)\n        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)\n        if (\n            self.blocks[-1].mixer.cross_attn\n            and f\"blocks.{n_layer - 1}.mixer.Wqkv.weight\" in state_dict\n        ):\n            Wqkv = state_dict.pop(f\"blocks.{n_layer - 1}.mixer.Wqkv.weight\")\n            bqkv = state_dict.pop(f\"blocks.{n_layer - 1}.mixer.Wqkv.bias\")\n            state_dict[f\"blocks.{n_layer - 1}.mixer.Wq.weight\"] = Wqkv[: self.embed_dim]\n            state_dict[f\"blocks.{n_layer - 1}.mixer.Wkv.weight\"] = Wqkv[self.embed_dim :]\n            state_dict[f\"blocks.{n_layer - 1}.mixer.Wq.bias\"] = bqkv[: self.embed_dim]\n            state_dict[f\"blocks.{n_layer - 1}.mixer.Wkv.bias\"] = bqkv[self.embed_dim :]\n        return super().load_state_dict(state_dict, strict=strict)\n\n\ndef init_weights_vit_timm(module: nn.Module, name: str = \"\"):\n    \"\"\"ViT weight initialization, original timm impl (for reproducibility)\"\"\"\n    if isinstance(module, nn.Linear):\n        trunc_normal_(module.weight, std=0.02)\n        if module.bias is 
not None:\n            nn.init.zeros_(module.bias)\n    elif hasattr(module, \"init_weights\"):\n        module.init_weights()\n\n\ndef vit_base_patch16_224(pretrained=False, **kwargs):\n    \"\"\"ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).\n    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.\n    \"\"\"\n    assert not pretrained\n    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)\n    model = VisionTransformer(**model_kwargs)\n    return model\n"
  },
  {
    "path": "flash_attn/modules/__init__.py",
    "content": ""
  },
  {
    "path": "flash_attn/modules/block.py",
    "content": "# Copyright (c) 2024, Tri Dao.\n\nfrom functools import partial\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torchvision.ops import StochasticDepth\n\nfrom flash_attn.modules.mha import MHA\nfrom flash_attn.modules.mlp import Mlp\n\ntry:\n    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm\nexcept ImportError:\n    layer_norm_fn, RMSNorm = None, None\n\n\nclass Block(nn.Module):\n    def __init__(\n        self,\n        dim,\n        mixer_cls=None,\n        mlp_cls=None,\n        norm_cls=nn.LayerNorm,\n        dropout_cls=nn.Dropout,\n        prenorm=True,\n        resid_dropout1=0.0,\n        resid_dropout2=0.0,\n        drop_path1=0.0,\n        drop_path2=0.0,\n        fused_dropout_add_ln=False,\n        return_residual=False,\n        residual_in_fp32=False,\n        sequence_parallel=False,\n        mark_shared_params=False,\n    ):\n        \"\"\"\n        For prenorm=True, this Block has a slightly different structure compared to a regular\n        prenorm Transformer block.\n        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.\n        [Ref: https://arxiv.org/abs/2002.04745]\n        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both\n        the hidden_states (output of the MLP) and the residual.\n        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.\n        The residual needs to be provided (except for the very first block).\n\n        For prenorm=False, this Block has the same structure as a regular postnorm Transformer\n        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.\n\n        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.\n        This is for performance reason: for post-norm architecture, returning the input allows us\n        to fuse the backward 
of nn.Linear with the residual connection.\n        \"\"\"\n        super().__init__()\n        self.prenorm = prenorm\n        self.fused_dropout_add_ln = fused_dropout_add_ln\n        self.return_residual = return_residual\n        self.residual_in_fp32 = residual_in_fp32\n        if self.residual_in_fp32:\n            assert self.prenorm, \"residual_in_fp32 is only compatible with prenorm=True\"\n        if mixer_cls is None:\n            mixer_cls = partial(MHA, num_heads=dim // 64)\n        if mlp_cls is None:\n            mlp_cls = partial(Mlp, hidden_features=4 * dim)\n        self.mixer = mixer_cls(dim)\n        self.dropout1 = dropout_cls(resid_dropout1)\n        self.drop_path1 = StochasticDepth(drop_path1, mode=\"row\")\n        self.norm1 = norm_cls(dim)\n        self.mlp = mlp_cls(dim)\n        if not isinstance(self.mlp, nn.Identity):\n            self.dropout2 = dropout_cls(resid_dropout2)\n            self.drop_path2 = StochasticDepth(drop_path2, mode=\"row\")\n            self.norm2 = norm_cls(dim)\n\n        if self.fused_dropout_add_ln:\n            assert layer_norm_fn is not None, \"Triton is not installed\"\n            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(\n                self.dropout1, nn.Dropout\n            )\n\n        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,\n        # then the input to each worker in the tensor parallel group will be different.\n        # This would produce wrong outputs? 
Somehow we'd need to sync the RNG state across workers.\n        # For now this is not an issue because we always use sequence_parallel=True during training\n        # and only use sequence_parallel=False during inference.\n\n        # Mark the norm parameters as \"sequence_parallel\" so that we run all-reduce on their grads.\n        if sequence_parallel:\n            for p in self.norm1.parameters():\n                p._sequence_parallel = True\n            if hasattr(self, \"norm2\"):\n                for p in self.norm2.parameters():\n                    p._sequence_parallel = True\n        # Mark the norm parameters as \"shared_params\" so that we sync their values at init.\n        if mark_shared_params:\n            for p in self.norm1.parameters():\n                p._shared_params = True\n            if hasattr(self, \"norm2\"):\n                for p in self.norm2.parameters():\n                    p._shared_params = True\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n\n    def forward(\n        self,\n        hidden_states: Tensor,\n        residual: Optional[Tensor] = None,\n        mixer_subset=None,\n        mixer_kwargs=None,\n    ):\n        r\"\"\"Pass the input through the encoder layer.\n\n        Args:\n            hidden_states: the sequence to the encoder layer (required).\n            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))\n            mixer_subset: for cross-attention only. If not None, will take a subset of x\n                before applying the query projection. 
Useful for e.g., ViT where we only care\n                about the CLS token in the last layer.\n        \"\"\"\n        if self.prenorm:\n            if not self.fused_dropout_add_ln:\n                dropped = self.drop_path1(self.dropout1(hidden_states))\n                residual = (dropped + residual) if residual is not None else dropped\n                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))\n                if self.residual_in_fp32:\n                    residual = residual.to(torch.float32)\n            else:\n                if self.drop_path1.p == 0 or not self.training:\n                    rowscale1 = None\n                else:\n                    rowscale1 = self.drop_path1(\n                        torch.ones(\n                            hidden_states.shape[:-1],\n                            device=hidden_states.device,\n                            dtype=hidden_states.dtype,\n                        )\n                    )\n                hidden_states, residual = layer_norm_fn(\n                    hidden_states,\n                    self.norm1.weight,\n                    self.norm1.bias,\n                    residual=residual,\n                    eps=self.norm1.eps,\n                    dropout_p=self.dropout1.p if self.training else 0.0,\n                    rowscale=rowscale1,\n                    prenorm=True,\n                    residual_in_fp32=self.residual_in_fp32,\n                    is_rms_norm=isinstance(self.norm1, RMSNorm)\n                )\n            if mixer_kwargs is None:\n                mixer_kwargs = {}\n            if mixer_subset is not None:\n                mixer_kwargs[\"mixer_subset\"] = mixer_subset\n            hidden_states = self.mixer(hidden_states, **mixer_kwargs)\n            if mixer_subset is not None:\n                residual = residual[:, mixer_subset]\n            if not isinstance(self.mlp, nn.Identity):\n                if not self.fused_dropout_add_ln:\n               
     dropped = self.drop_path2(self.dropout2(hidden_states))\n                    residual = (dropped + residual) if residual is not None else dropped\n                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))\n                    if self.residual_in_fp32:\n                        residual = residual.to(torch.float32)\n                else:\n                    if self.drop_path2.p == 0 or not self.training:\n                        rowscale2 = None\n                    else:\n                        rowscale2 = self.drop_path2(\n                            torch.ones(\n                                hidden_states.shape[:-1],\n                                device=hidden_states.device,\n                                dtype=hidden_states.dtype,\n                            )\n                        )\n                    hidden_states, residual = layer_norm_fn(\n                        hidden_states,\n                        self.norm2.weight,\n                        self.norm2.bias,\n                        residual=residual,\n                        eps=self.norm2.eps,\n                        dropout_p=self.dropout2.p if self.training else 0.0,\n                        rowscale=rowscale2,\n                        prenorm=True,\n                        residual_in_fp32=self.residual_in_fp32,\n                        is_rms_norm=isinstance(self.norm2, RMSNorm)\n                    )\n                hidden_states = self.mlp(hidden_states)\n            return hidden_states, residual\n        else:\n            assert residual is None\n            mixer_out = self.mixer(\n                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})\n            )\n            if self.return_residual:  # mixer out is actually a pair here\n                mixer_out, hidden_states = mixer_out\n            if not self.fused_dropout_add_ln:\n                hidden_states = self.norm1(\n                    
(self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(\n                        dtype=self.norm1.weight.dtype\n                    )\n                )\n            else:\n                if self.drop_path1.p == 0 or not self.training:\n                    rowscale1 = None\n                else:\n                    rowscale1 = self.drop_path1(\n                        torch.ones(\n                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype\n                        )\n                    )\n                hidden_states = layer_norm_fn(\n                    mixer_out,\n                    self.norm1.weight,\n                    self.norm1.bias,\n                    residual=hidden_states,\n                    eps=self.norm1.eps,\n                    dropout_p=self.dropout1.p if self.training else 0.0,\n                    rowscale=rowscale1,\n                    prenorm=False,\n                    is_rms_norm=isinstance(self.norm1, RMSNorm)\n                )\n            if not isinstance(self.mlp, nn.Identity):\n                mlp_out = self.mlp(hidden_states)\n                if self.return_residual:  # mlp out is actually a pair here\n                    mlp_out, hidden_states = mlp_out\n                if not self.fused_dropout_add_ln:\n                    hidden_states = self.norm2(\n                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(\n                            dtype=self.norm2.weight.dtype\n                        )\n                    )\n                else:\n                    if self.drop_path2.p == 0 or not self.training:\n                        rowscale2 = None\n                    else:\n                        rowscale2 = self.drop_path2(\n                            torch.ones(\n                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype\n                            )\n                        )\n                    hidden_states = 
layer_norm_fn(\n                        mlp_out,\n                        self.norm2.weight,\n                        self.norm2.bias,\n                        residual=hidden_states,\n                        eps=self.norm2.eps,\n                        dropout_p=self.dropout2.p if self.training else 0.0,\n                        rowscale=rowscale2,\n                        prenorm=False,\n                        is_rms_norm=isinstance(self.norm2, RMSNorm)\n                    )\n            return hidden_states\n\n\nclass ParallelBlock(nn.Module):\n    \"\"\"The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,\n    and PaLM.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim,\n        mixer_cls=None,\n        mlp_cls=None,\n        norm_cls=nn.LayerNorm,\n        dropout_cls=nn.Dropout,\n        resid_dropout1=0.0,\n        resid_dropout2=0.0,\n        tied_norm=False,\n        fused_dropout_add_ln=False,\n        residual_in_fp32=False,\n        sequence_parallel=False,\n        mark_shared_params=False,\n    ):\n        \"\"\"\n        This Block has a slightly different structure compared to a regular\n        prenorm Transformer block.\n        The standard block is: LN -> MHA / MLP -> Dropout -> Add.\n        [Ref: https://arxiv.org/abs/2002.04745]\n        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both\n        the hidden_states (output1 of the MHA / MLP) and the residual.\n        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.\n        The residual needs to be provided (except for the very first block).\n        \"\"\"\n        super().__init__()\n        self.tied_norm = tied_norm\n        self.fused_dropout_add_ln = fused_dropout_add_ln\n        self.residual_in_fp32 = residual_in_fp32\n        if mixer_cls is None:\n            mixer_cls = partial(MHA, num_heads=dim // 64)\n        if mlp_cls is None:\n            mlp_cls = partial(Mlp, hidden_features=4 * 
dim)\n        self.mixer = mixer_cls(dim)\n        self.dropout1 = dropout_cls(resid_dropout1)\n        self.norm1 = norm_cls(dim)\n        self.mlp = mlp_cls(dim)\n        self.dropout2 = dropout_cls(resid_dropout2)\n        if not self.tied_norm:\n            self.norm2 = norm_cls(dim)\n\n        if self.fused_dropout_add_ln:\n            assert layer_norm_fn is not None, \"Triton is not installed\"\n            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(\n                self.dropout1, nn.Dropout\n            )\n\n        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,\n        # then the input to each worker in the tensor parallel group will be different.\n        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.\n        # For now this is not an issue because we always use sequence_parallel=True during training\n        # and only use sequence_parallel=False during inference.\n\n        # Mark the norm parameters as \"sequence_parallel\" so that we run all-reduce on their grads.\n        if sequence_parallel:\n            for p in self.norm1.parameters():\n                p._sequence_parallel = True\n            if hasattr(self, \"norm2\"):\n                for p in self.norm2.parameters():\n                    p._sequence_parallel = True\n        # Mark the norm parameters as \"shared_params\" so that we sync their values at init.\n        if mark_shared_params:\n            for p in self.norm1.parameters():\n                p._shared_params = True\n            if hasattr(self, \"norm2\"):\n                for p in self.norm2.parameters():\n                    p._shared_params = True\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)\n\n    def forward(\n        self,\n        hidden_states1: Tensor,\n        
hidden_states2: Optional[Tensor] = None,\n        residual: Optional[Tensor] = None,\n        mixer_kwargs=None,\n    ):\n        r\"\"\"Pass the input through the encoder layer.\n\n        Args:\n            hidden_states1: the output of the previous attention (mixer) or embedding layer.\n            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).\n            residual: the residual stream carried between blocks (None for the very first block).\n        \"\"\"\n        # TODO: Ideally we should only do the allgather / allreduce once for\n        # the Linear to MLP & Attention\n        if not self.fused_dropout_add_ln:\n            dropped1 = self.dropout1(hidden_states1)\n            # For the very 1st block, we only want 1 dropout, not two different dropouts\n            if hidden_states2 is not None:\n                dropped2 = self.dropout2(hidden_states2)\n                residual = (\n                    (residual + dropped1 + dropped2)\n                    if residual is not None\n                    else dropped1 + dropped2\n                )\n            else:\n                residual = (residual + dropped1) if residual is not None else dropped1\n            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))\n            hidden_states2 = (\n                self.norm2(residual.to(dtype=self.norm2.weight.dtype))\n                if not self.tied_norm\n                else hidden_states1\n            )\n            if self.residual_in_fp32:\n                residual = residual.to(torch.float32)\n        else:\n            weight2, bias2 = (\n                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)\n            )\n            hidden_states1, *rest, residual = layer_norm_fn(\n                hidden_states1,\n                self.norm1.weight,\n                self.norm1.bias,\n                residual=residual,\n                x1=hidden_states2,\n                weight1=weight2,\n                bias1=bias2,\n                
eps=self.norm1.eps,\n                dropout_p=self.dropout1.p if self.training else 0.0,\n                prenorm=True,\n                residual_in_fp32=self.residual_in_fp32,\n                is_rms_norm=isinstance(self.norm1, RMSNorm),\n            )\n            if self.tied_norm:\n                hidden_states2 = hidden_states1\n            else:\n                (hidden_states2,) = rest\n        if mixer_kwargs is None:\n            mixer_kwargs = {}\n        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)\n        hidden_states2 = self.mlp(hidden_states2)\n        return hidden_states1, hidden_states2, residual\n"
  },
  {
    "path": "flash_attn/modules/embedding.py",
    "content": "# Copyright (c) 2022, Tri Dao.\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange\nfrom torch import Tensor\n\nfrom flash_attn.utils.distributed import all_reduce, reduce_scatter\n\n\nclass GPT2Embeddings(nn.Module):\n    def __init__(\n        self,\n        embed_dim,\n        vocab_size,\n        max_position_embeddings,\n        padding_idx=None,\n        word_embed_proj_dim=None,\n        device=None,\n        dtype=None,\n    ):\n        \"\"\"\n        If max_position_embeddings <= 0, there's no position embeddings\n        If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension\n            the project up to embed_dim\n        \"\"\"\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        if word_embed_proj_dim is None:\n            self.word_embeddings = nn.Embedding(\n                vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs\n            )\n            self.project_in = None\n        else:\n            self.word_embeddings = nn.Embedding(\n                vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs\n            )\n            self.project_in = nn.Linear(\n                word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs\n            )\n        self.max_position_embeddings = max_position_embeddings\n        if self.max_position_embeddings > 0:\n            self.position_embeddings = nn.Embedding(\n                max_position_embeddings, embed_dim, **factory_kwargs\n            )\n\n    def forward(self, input_ids, position_ids=None):\n        \"\"\"\n        input_ids: (batch, seqlen)\n        position_ids: (batch, seqlen)\n        \"\"\"\n        batch_size, seqlen = input_ids.shape\n        embeddings = self.word_embeddings(input_ids)\n        if self.project_in is not None:\n            embeddings = self.project_in(embeddings)\n        if self.max_position_embeddings > 0:\n            if 
position_ids is None:\n                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)\n            position_embeddings = self.position_embeddings(position_ids)\n            embeddings = embeddings + position_embeddings\n        return embeddings\n\n\nclass BertEmbeddings(nn.Module):\n    def __init__(\n        self,\n        embed_dim,\n        vocab_size,\n        max_position_embeddings,\n        type_vocab_size,\n        padding_idx=None,\n        device=None,\n        dtype=None,\n    ):\n        \"\"\"\n        If max_position_embeddings <= 0, there's no position embeddings\n        If type_vocab_size <= 0, there's no token type embeddings\n        \"\"\"\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.word_embeddings = nn.Embedding(\n            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs\n        )\n        self.max_position_embeddings = max_position_embeddings\n        self.type_vocab_size = type_vocab_size\n        if self.max_position_embeddings > 0:\n            self.position_embeddings = nn.Embedding(\n                max_position_embeddings, embed_dim, **factory_kwargs\n            )\n        if self.type_vocab_size > 0:\n            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)\n\n    def forward(self, input_ids, position_ids=None, token_type_ids=None):\n        \"\"\"\n        input_ids: (batch, seqlen)\n        position_ids: (batch, seqlen)\n        token_type_ids: (batch, seqlen)\n        \"\"\"\n        batch_size, seqlen = input_ids.shape\n        embeddings = self.word_embeddings(input_ids)\n        if self.max_position_embeddings > 0:\n            if position_ids is None:\n                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)\n            position_embeddings = self.position_embeddings(position_ids)\n            embeddings = embeddings + 
position_embeddings\n        if self.type_vocab_size > 0:\n            if token_type_ids is None:\n                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)\n            token_type_embeddings = self.token_type_embeddings(token_type_ids)\n            embeddings = embeddings + token_type_embeddings\n        return embeddings\n\n\nclass VocabParallelEmbedding(nn.Embedding):\n    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):\n        self.process_group = process_group\n        if process_group is not None:\n            world_size = torch.distributed.get_world_size(process_group)\n            if num_embeddings % world_size != 0:\n                raise ValueError(\n                    f\"num_embeddings ({num_embeddings}) must be divisible by \"\n                    f\"world_size ({world_size})\"\n                )\n            if world_size > 1 and padding_idx is not None:\n                raise RuntimeError(\"ParallelEmbedding does not support padding_idx\")\n        else:\n            world_size = 1\n        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)\n\n    def forward(self, input: Tensor) -> Tensor:\n        if self.process_group is None:\n            return super().forward(input)\n        else:\n            rank = torch.distributed.get_rank(self.process_group)\n            vocab_size = self.num_embeddings\n            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size\n            # Create a mask of valid vocab ids (1 means it needs to be masked).\n            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)\n            input = input - vocab_start_index\n            input[input_ids_mask] = 0\n            embeddings = super().forward(input)\n            embeddings[input_ids_mask] = 0.0\n            return embeddings\n\n\nclass ColumnParallelEmbedding(nn.Embedding):\n    def 
__init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):\n        self.process_group = process_group\n        if process_group is not None:\n            world_size = torch.distributed.get_world_size(process_group)\n            if embedding_dim % world_size != 0:\n                raise ValueError(\n                    f\"embedding_dim ({embedding_dim}) must be divisible by \"\n                    f\"world_size ({world_size})\"\n                )\n        else:\n            world_size = 1\n        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)\n\n\nclass ParallelGPT2Embeddings(nn.Module):\n    def __init__(\n        self,\n        embed_dim,\n        vocab_size,\n        max_position_embeddings,\n        process_group,\n        padding_idx=None,\n        sequence_parallel=True,\n        device=None,\n        dtype=None,\n    ):\n        \"\"\"\n        If max_position_embeddings <= 0, there's no position embeddings\n        \"\"\"\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.process_group = process_group\n        self.sequence_parallel = sequence_parallel\n        self.word_embeddings = VocabParallelEmbedding(\n            vocab_size,\n            embed_dim,\n            padding_idx=padding_idx,\n            process_group=process_group,\n            **factory_kwargs,\n        )\n        self.max_position_embeddings = max_position_embeddings\n        if self.max_position_embeddings > 0:\n            self.position_embeddings = ColumnParallelEmbedding(\n                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs\n            )\n\n    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):\n        \"\"\"\n        input_ids: (batch, seqlen)\n        position_ids: (batch, seqlen)\n        \"\"\"\n        batch_size, seqlen = input_ids.shape\n        world_size = 
torch.distributed.get_world_size(self.process_group)\n        embeddings = self.word_embeddings(input_ids)\n        if self.max_position_embeddings > 0:\n            if position_ids is None:\n                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)\n            position_embeddings = self.position_embeddings(position_ids)\n            if world_size <= 1:\n                embeddings = embeddings + position_embeddings\n            else:\n                partition_dim = self.position_embeddings.embedding_dim\n                rank = torch.distributed.get_rank(self.process_group)\n                embeddings[\n                    ..., rank * partition_dim : (rank + 1) * partition_dim\n                ] += position_embeddings\n        if combine_batch_seqlen_dim:\n            embeddings = rearrange(embeddings, \"b s d -> (b s) d\")\n        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce\n        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)\n"
  },
  {
    "path": "flash_attn/modules/mha.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport math\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange, repeat\n\nfrom flash_attn.utils.distributed import get_dim_for_local_rank\n\ntry:\n    from flash_attn import (\n        flash_attn_kvpacked_func,\n        flash_attn_qkvpacked_func,\n        flash_attn_varlen_kvpacked_func,\n        flash_attn_varlen_qkvpacked_func,\n        flash_attn_with_kvcache,\n    )\nexcept ImportError:\n    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None\n    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None\n    flash_attn_with_kvcache = None\n\ntry:\n    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear\nexcept ImportError:\n    ColumnParallelLinear, RowParallelLinear = None, None\n\ntry:\n    from flash_attn.layers.rotary import RotaryEmbedding\nexcept ImportError:\n    RotaryEmbedding = None\n\n\n# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742\ndef get_alibi_slopes(nheads):\n    def get_slopes_power_of_2(nheads):\n        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))\n        ratio = start\n        return [start * ratio**i for i in range(nheads)]\n\n    if math.log2(nheads).is_integer():\n        return get_slopes_power_of_2(nheads)\n    else:\n        closest_power_of_2 = 2 ** math.floor(math.log2(nheads))\n        return (\n            get_slopes_power_of_2(closest_power_of_2)\n            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]\n        )\n\n\nclass FlashSelfAttention(nn.Module):\n    \"\"\"Implement the scaled dot product attention with softmax.\n    Arguments\n    ---------\n        softmax_scale: The temperature to use for the softmax attention.\n                      (default: 1/sqrt(d_keys) where d_keys is computed at\n                      runtime)\n       
 attention_dropout: The dropout rate to apply to the attention\n                           (default: 0.0)\n    \"\"\"\n\n    def __init__(\n        self,\n        causal=False,\n        softmax_scale=None,\n        attention_dropout=0.0,\n        window_size=(-1, -1),\n        alibi_slopes=None,\n        deterministic=False,\n    ):\n        super().__init__()\n        assert flash_attn_varlen_qkvpacked_func is not None, \"FlashAttention is not installed\"\n        assert flash_attn_qkvpacked_func is not None, \"FlashAttention is not installed\"\n        self.causal = causal\n        self.softmax_scale = softmax_scale\n        self.drop = nn.Dropout(attention_dropout)\n        self.register_buffer(\"alibi_slopes\", alibi_slopes, persistent=False)\n        self.window_size = window_size\n        self.deterministic = deterministic\n\n    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):\n        \"\"\"Implements the multihead softmax attention.\n        Arguments\n        ---------\n            qkv: The tensor containing the query, key, and value.\n                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).\n                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape\n                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.\n            causal: if passed, will override self.causal\n            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths\n                of the sequences in the batch, used to index into qkv.\n            max_seqlen: int. 
Maximum sequence length in the batch.\n        Returns:\n        --------\n            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,\n                else (B, S, H, D).\n        \"\"\"\n        assert qkv.dtype in [torch.float16, torch.bfloat16]\n        assert qkv.is_cuda\n        causal = self.causal if causal is None else causal\n        unpadded = cu_seqlens is not None\n        if self.alibi_slopes is not None:\n            self.alibi_slopes = self.alibi_slopes.to(torch.float32)\n        if unpadded:\n            assert cu_seqlens.dtype == torch.int32\n            assert max_seqlen is not None\n            assert isinstance(max_seqlen, int)\n            return flash_attn_varlen_qkvpacked_func(\n                qkv,\n                cu_seqlens,\n                max_seqlen,\n                self.drop.p if self.training else 0.0,\n                softmax_scale=self.softmax_scale,\n                causal=causal,\n                alibi_slopes=self.alibi_slopes,\n                window_size=self.window_size,\n                deterministic=self.deterministic,\n            )\n        else:\n            return flash_attn_qkvpacked_func(\n                qkv,\n                self.drop.p if self.training else 0.0,\n                softmax_scale=self.softmax_scale,\n                causal=causal,\n                alibi_slopes=self.alibi_slopes,\n                window_size=self.window_size,\n                deterministic=self.deterministic,\n            )\n\n\nclass FlashCrossAttention(nn.Module):\n    \"\"\"Implement the scaled dot product attention with softmax.\n    Arguments\n    ---------\n        softmax_scale: The temperature to use for the softmax attention.\n                      (default: 1/sqrt(d_keys) where d_keys is computed at\n                      runtime)\n        attention_dropout: The dropout rate to apply to the attention\n                           (default: 0.0)\n    \"\"\"\n\n    def __init__(\n        self,\n        
causal=False,\n        softmax_scale=None,\n        attention_dropout=0.0,\n        alibi_slopes=None,\n        window_size=(-1, -1),\n        deterministic=False,\n    ):\n        super().__init__()\n        assert flash_attn_varlen_kvpacked_func is not None, \"FlashAttention is not installed\"\n        assert flash_attn_kvpacked_func is not None, \"FlashAttention is not installed\"\n        self.causal = causal\n        self.softmax_scale = softmax_scale\n        self.drop = nn.Dropout(attention_dropout)\n        self.register_buffer(\"alibi_slopes\", alibi_slopes, persistent=False)\n        self.window_size = window_size\n        self.deterministic = deterministic\n\n    def forward(\n        self,\n        q,\n        kv,\n        causal=None,\n        cu_seqlens=None,\n        max_seqlen=None,\n        cu_seqlens_k=None,\n        max_seqlen_k=None,\n    ):\n        \"\"\"Implements the multihead softmax attention.\n        Arguments\n        ---------\n            q: The tensor containing the query. (B, Sq, H, D)\n            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)\n            causal: if passed, will override self.causal\n            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths\n                of the sequences in the batch, used to index into q.\n            max_seqlen: int. Maximum sequence length in the batch of q.\n            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths\n                of the sequences in the batch, used to index into kv.\n            max_seqlen_k: int. 
Maximum sequence length in the batch of k and v.\n        \"\"\"\n        assert q.dtype in [torch.float16, torch.bfloat16]\n        assert q.is_cuda and kv.is_cuda\n        causal = self.causal if causal is None else causal\n        unpadded = cu_seqlens is not None\n        if self.alibi_slopes is not None:\n            self.alibi_slopes = self.alibi_slopes.to(torch.float32)\n        if unpadded:\n            assert cu_seqlens.dtype == torch.int32\n            assert max_seqlen is not None\n            assert isinstance(max_seqlen, int)\n            assert cu_seqlens_k is not None\n            assert cu_seqlens_k.dtype == torch.int32\n            assert max_seqlen_k is not None\n            assert isinstance(max_seqlen_k, int)\n            return flash_attn_varlen_kvpacked_func(\n                q,\n                kv,\n                cu_seqlens,\n                cu_seqlens_k,\n                max_seqlen,\n                max_seqlen_k,\n                self.drop.p if self.training else 0.0,\n                softmax_scale=self.softmax_scale,\n                causal=causal,\n                alibi_slopes=self.alibi_slopes,\n                window_size=self.window_size,\n                deterministic=self.deterministic,\n            )\n        else:\n            batch_size, seqlen_q = q.shape[0], q.shape[1]\n            seqlen_k = kv.shape[1]\n            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]\n            return flash_attn_kvpacked_func(\n                q,\n                kv,\n                self.drop.p if self.training else 0.0,\n                causal=causal,\n                softmax_scale=self.softmax_scale,\n                alibi_slopes=self.alibi_slopes,\n                window_size=self.window_size,\n                deterministic=self.deterministic,\n            )\n\n\nclass SelfAttention(nn.Module):\n    \"\"\"Implement the scaled dot product attention with softmax.\n    Arguments\n    ---------\n        softmax_scale: The 
temperature to use for the softmax attention.\n                      (default: 1/sqrt(d_keys) where d_keys is computed at\n                      runtime)\n        attention_dropout: The dropout rate to apply to the attention\n                           (default: 0.0)\n    \"\"\"\n\n    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):\n        super().__init__()\n        self.causal = causal\n        self.softmax_scale = softmax_scale\n        self.drop = nn.Dropout(attention_dropout)\n\n    def forward(self, qkv, causal=None, key_padding_mask=None):\n        \"\"\"Implements the multihead softmax attention.\n        Arguments\n        ---------\n            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)\n            causal: if passed, will override self.causal\n            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,\n                False means to mask out. (B, S)\n        \"\"\"\n        batch_size, seqlen = qkv.shape[0], qkv.shape[1]\n        causal = self.causal if causal is None else causal\n        q, k, v = qkv.unbind(dim=2)\n        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])\n        scores = torch.einsum(\"bthd,bshd->bhts\", q, k * softmax_scale)\n        if key_padding_mask is not None:\n            padding_mask = torch.full(\n                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device\n            )\n            padding_mask.masked_fill_(key_padding_mask, 0.0)\n            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)\n            scores = scores + rearrange(padding_mask, \"b s -> b 1 1 s\")\n        if causal:\n            # \"triu_tril_cuda_template\" not implemented for 'BFloat16'\n            # So we have to construct the mask in float\n            causal_mask = torch.triu(\n                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1\n         
   )\n            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)\n            scores = scores + causal_mask.to(dtype=scores.dtype)\n        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)\n        attention_drop = self.drop(attention)\n        output = torch.einsum(\"bhts,bshd->bthd\", attention_drop, v)\n        return output\n\n\nclass CrossAttention(nn.Module):\n    \"\"\"Implement the scaled dot product attention with softmax.\n    Arguments\n    ---------\n        softmax_scale: The temperature to use for the softmax attention.\n                      (default: 1/sqrt(d_keys) where d_keys is computed at\n                      runtime)\n        attention_dropout: The dropout rate to apply to the attention\n                           (default: 0.0)\n    \"\"\"\n\n    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):\n        super().__init__()\n        self.causal = causal\n        self.softmax_scale = softmax_scale\n        self.drop = nn.Dropout(attention_dropout)\n\n    def forward(self, q, kv, causal=None, key_padding_mask=None):\n        \"\"\"Implements the multihead softmax attention.\n        Arguments\n        ---------\n            q: The tensor containing the query. (B, Sq, H, D)\n            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)\n            causal: if passed, will override self.causal\n            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,\n                False means to mask out. (B, Sk)\n        \"\"\"\n        batch_size, seqlen_q = q.shape[0], q.shape[1]\n        causal = self.causal if causal is None else causal\n        seqlen_k = kv.shape[1]\n        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]\n        if kv.shape[3] != q.shape[2]:  # MQA/GQA\n            kv = repeat(kv, \"... hkv d -> ... 
(hkv g) d\", g=q.shape[2] // kv.shape[3])\n        k, v = kv.unbind(dim=2)\n        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])\n        scores = torch.einsum(\"bthd,bshd->bhts\", q, k * softmax_scale)\n        if key_padding_mask is not None:\n            padding_mask = torch.full(\n                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device\n            )\n            padding_mask.masked_fill_(key_padding_mask, 0.0)\n            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)\n            scores = scores + rearrange(padding_mask, \"b s -> b 1 1 s\")\n        if causal:\n            # causal mask needs to take into account the difference between seqlen_q and seqlen_k\n            row_idx = rearrange(\n                torch.arange(seqlen_q, device=q.device, dtype=torch.long), \"s -> s 1\"\n            )\n            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)\n            sk = (\n                seqlen_k\n                if key_padding_mask is None\n                else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n            )\n            causal_mask = col_idx > row_idx + sk - seqlen_q\n            scores = scores.masked_fill(causal_mask, -10000.0)\n        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)\n        attention_drop = self.drop(attention)\n        output = torch.einsum(\"bhts,bshd->bthd\", attention_drop, v)\n        return output\n\n\ndef _update_kv_cache(kv, inference_params, layer_idx):\n    \"\"\"kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)\"\"\"\n    # Pre-allocate memory for key-values for inference.\n    num_heads, head_dim = kv.shape[-2:]\n    if layer_idx not in inference_params.key_value_memory_dict:\n        kv_cache = torch.empty(\n            inference_params.max_batch_size,\n            inference_params.max_seqlen,\n            2,\n            num_heads,\n   
         head_dim,\n            dtype=kv.dtype,\n            device=kv.device,\n        )\n        inference_params.key_value_memory_dict[layer_idx] = kv_cache\n    else:\n        kv_cache = inference_params.key_value_memory_dict[layer_idx]\n    # Adjust key and value for inference\n    batch_start = inference_params.batch_size_offset\n    batch_end = batch_start + kv.shape[0]\n    sequence_start = inference_params.seqlen_offset\n    sequence_end = sequence_start + kv.shape[1]\n    assert batch_end <= kv_cache.shape[0]\n    assert sequence_end <= kv_cache.shape[1]\n    assert kv_cache is not None\n    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv\n    return kv_cache[batch_start:batch_end, :sequence_end, ...]\n\n\nclass MHA(nn.Module):\n    \"\"\"Multi-head self-attention and cross-attention\"\"\"\n\n    def __init__(\n        self,\n        embed_dim,\n        num_heads,\n        num_heads_kv=None,\n        cross_attn=False,\n        qkv_proj_bias=True,\n        out_proj_bias=True,\n        dropout=0.0,\n        softmax_scale=None,\n        causal=False,\n        layer_idx=None,\n        dwconv=False,\n        rotary_emb_dim=0,\n        rotary_emb_base=10000.0,\n        rotary_emb_scale_base=None,\n        rotary_emb_interleaved=False,\n        use_alibi=False,\n        window_size=(-1, -1),\n        fused_bias_fc=False,\n        use_flash_attn=False,\n        return_residual=False,\n        checkpointing=False,\n        device=None,\n        dtype=None,\n    ) -> None:\n        \"\"\"\n        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.\n        return_residual: whether to return the input x along with the output. 
This is for\n            performance reason: for post-norm architecture, returning the input allows us\n            to fuse the backward of nn.Linear with the residual connection.\n        \"\"\"\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.cross_attn = cross_attn\n        self.causal = causal\n        self.layer_idx = layer_idx\n        self.dwconv = dwconv\n        self.rotary_emb_dim = rotary_emb_dim\n        self.use_flash_attn = use_flash_attn\n        self.return_residual = return_residual\n        self.checkpointing = checkpointing\n        if use_alibi:\n            assert use_flash_attn, \"ALiBi code path requires flash_attn\"\n            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)\n        else:\n            alibi_slopes = None\n        if window_size != (-1, -1):\n            assert use_flash_attn, \"Local (sliding window) attention code path requires flash_attn\"\n\n        self.num_heads = num_heads\n        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads\n        assert (\n            self.num_heads % self.num_heads_kv == 0\n        ), \"num_heads must be divisible by num_heads_kv\"\n        assert self.embed_dim % num_heads == 0, \"embed_dim must be divisible by num_heads\"\n        self.head_dim = self.embed_dim // num_heads\n        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)\n        kv_dim = 2 * self.head_dim * self.num_heads_kv\n\n        if self.rotary_emb_dim > 0:\n            assert not cross_attn, \"MHA with rotary embedding does not support cross-attention yet\"\n            assert RotaryEmbedding is not None, \"rotary_emb is not installed\"\n            self.rotary_emb = RotaryEmbedding(\n                self.rotary_emb_dim,\n                base=rotary_emb_base,\n                scale_base=rotary_emb_scale_base,\n                
interleaved=rotary_emb_interleaved,\n                device=device,\n            )\n\n        inner_attn_cls = (\n            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)\n            if use_flash_attn\n            else SelfAttention\n        )\n        inner_cross_attn_cls = (\n            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)\n            if use_flash_attn\n            else CrossAttention\n        )\n        if not self.cross_attn:\n            self.Wqkv = nn.Linear(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)\n        else:\n            self.Wq = nn.Linear(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)\n            self.Wkv = nn.Linear(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)\n        if self.dwconv:\n            if self.num_heads_kv == self.num_heads:\n                self.dwconv_qkv = nn.Conv1d(\n                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim\n                )\n            else:\n                self.dwconv_q = nn.Conv1d(\n                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim\n                )\n                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)\n        self.inner_attn = inner_attn_cls(\n            causal=causal,\n            softmax_scale=softmax_scale,\n            attention_dropout=dropout,\n        )\n        self.inner_cross_attn = inner_cross_attn_cls(\n            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout\n        )\n        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):\n        dtype = self.out_proj.weight.dtype if dtype is None else dtype\n        device = self.out_proj.weight.device\n        return torch.empty(\n            batch_size,\n            max_seqlen,\n            
2,\n            self.num_heads_kv,\n            self.head_dim,\n            dtype=dtype,\n            device=device,\n        )\n\n    def _update_kv_cache(self, kv, inference_params):\n        \"\"\"kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)\"\"\"\n        assert not self.dwconv, \"Generation does not support dwconv yet\"\n        assert self.layer_idx is not None, \"Generation requires layer_idx in the constructor\"\n        return _update_kv_cache(kv, inference_params, self.layer_idx)\n\n    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):\n        \"\"\"\n        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.\n        q: (batch_size, seqlen_q, nheads, head_dim)\n        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)\n        \"\"\"\n        assert inference_params is not None and inference_params.seqlen_offset > 0\n        assert self.use_flash_attn\n        if self.rotary_emb_dim > 0:\n            assert self.rotary_emb.scale is None, \"This code path does not support xPos\"\n            self.rotary_emb._update_cos_sin_cache(\n                inference_params.max_seqlen, device=q.device, dtype=q.dtype\n            )\n            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached\n        else:\n            rotary_cos, rotary_sin = None, None\n        batch = q.shape[0]\n        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]\n        cache_seqlens = (\n            inference_params.lengths_per_sample[:batch]\n            if inference_params.lengths_per_sample is not None\n            else inference_params.seqlen_offset\n        )\n        alibi_slopes = getattr(self.inner_cross_attn, \"alibi_slopes\", None)\n        context = flash_attn_with_kvcache(\n            q,\n            kv_cache[:, :, 0],\n            kv_cache[:, :, 1],\n            kv[:, :, 0],\n            kv[:, :, 1],\n      
      rotary_cos=rotary_cos,\n            rotary_sin=rotary_sin,\n            cache_seqlens=cache_seqlens,\n            softmax_scale=self.inner_cross_attn.softmax_scale,\n            causal=self.inner_cross_attn.causal,\n            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,\n            alibi_slopes=alibi_slopes,\n        )\n        return context\n\n    def _update_kvcache_attention(self, q, kv, inference_params):\n        \"\"\"Write kv to inference_params, then do attention\"\"\"\n        if (\n            inference_params.seqlen_offset == 0\n            or flash_attn_with_kvcache is None\n            or not self.use_flash_attn\n        ):\n            # TODO: this only uses seqlen_offset and not lengths_per_sample.\n            kv = self._update_kv_cache(kv, inference_params)\n            return self.inner_cross_attn(q, kv)\n        else:\n            batch = q.shape[0]\n            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]\n            cache_seqlens = (\n                inference_params.lengths_per_sample[:batch]\n                if inference_params.lengths_per_sample is not None\n                else inference_params.seqlen_offset\n            )\n            alibi_slopes = getattr(self.inner_cross_attn, \"alibi_slopes\", None)\n            return flash_attn_with_kvcache(\n                q,\n                kv_cache[:, :, 0],\n                kv_cache[:, :, 1],\n                kv[:, :, 0],\n                kv[:, :, 1],\n                cache_seqlens=cache_seqlens,\n                softmax_scale=self.inner_cross_attn.softmax_scale,\n                causal=self.inner_cross_attn.causal,\n                alibi_slopes=alibi_slopes,\n            )\n\n    def forward(\n        self,\n        x,\n        x_kv=None,\n        key_padding_mask=None,\n        cu_seqlens=None,\n        max_seqlen=None,\n        mixer_subset=None,\n        inference_params=None,\n        **kwargs,\n    ):\n       
 \"\"\"\n        Arguments:\n            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if\n                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total\n                is the is the sum of the sequence lengths in the batch.\n            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.\n            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths\n                of the sequences in the batch, used to index into x. Only applicable when using\n                FlashAttention.\n            max_seqlen: int. Maximum sequence length in the batch.\n            key_padding_mask: boolean mask, True means to keep, False means to mask out.\n                (batch, seqlen). Only applicable when not using FlashAttention.\n            mixer_subset: for cross-attention only. If not None, will take a subset of x\n                before applying the query projection. Useful for e.g., ViT where we only care\n                about the CLS token in the last layer.\n            inference_params: for generation. 
Adapted from Megatron-LM (and Apex)\n            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470\n        \"\"\"\n        if cu_seqlens is not None:\n            assert max_seqlen is not None\n            assert key_padding_mask is None\n            assert self.use_flash_attn\n            assert not self.dwconv\n            assert self.rotary_emb_dim == 0\n        if key_padding_mask is not None:\n            assert cu_seqlens is None\n            assert max_seqlen is None\n            assert not self.use_flash_attn\n        if inference_params is not None:\n            assert key_padding_mask is None\n            assert cu_seqlens is None and max_seqlen is None\n            assert not self.dwconv\n\n        kwargs = (\n            {\"cu_seqlens\": cu_seqlens, \"max_seqlen\": max_seqlen, **kwargs}\n            if self.use_flash_attn\n            else {\"key_padding_mask\": key_padding_mask, **kwargs}\n        )\n        seqlen_offset = (\n            0\n            if inference_params is None\n            else (\n                inference_params.lengths_per_sample\n                if inference_params.lengths_per_sample is not None\n                else inference_params.seqlen_offset\n            )\n        )\n        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None\n        batch, seqlen = x.shape[:2]\n        if not self.cross_attn and self.num_heads_kv == self.num_heads:\n            assert x_kv is None and mixer_subset is None\n            qkv = self.Wqkv(x)\n            if self.dwconv:\n                qkv = rearrange(\n                    self.dwconv_qkv(rearrange(qkv, \"b s d -> b d s\"))[..., :-2], \"b d s -> b s d\"\n                ).contiguous()\n            qkv = rearrange(qkv, \"... (three h d) -> ... 
three h d\", three=3, d=self.head_dim)\n            if (\n                inference_params is None\n                or inference_params.seqlen_offset == 0\n                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)\n                or not self.use_flash_attn\n            ):\n                if self.rotary_emb_dim > 0:\n                    qkv = self.rotary_emb(\n                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen\n                    )\n                if inference_params is None:\n                    if not self.checkpointing:\n                        context = self.inner_attn(qkv, **kwargs)\n                    else:\n                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)\n                else:\n                    context = self._update_kvcache_attention(\n                        qkv[:, :, 0], qkv[:, :, 1:], inference_params\n                    )\n            else:\n                context = self._apply_rotary_update_kvcache_attention(\n                    qkv[:, :, 0], qkv[:, :, 1:], inference_params\n                )\n        else:\n            if self.cross_attn:\n                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])\n                kv = self.Wkv(x_kv if x_kv is not None else x)\n            else:\n                assert self.num_heads_kv != self.num_heads\n                qkv = self.Wqkv(x)\n                q = qkv[..., : self.num_heads * self.head_dim]\n                kv = qkv[..., self.num_heads * self.head_dim :]\n            q = rearrange(q, \"... (h d) -> ... h d\", d=self.head_dim)\n            kv = rearrange(kv, \"... (two hkv d) -> ... 
two hkv d\", two=2, d=self.head_dim)\n            if self.dwconv:\n                q = rearrange(\n                    self.dwconv_q(rearrange(q, \"b s d -> b d s\"))[..., :-2], \"b d s -> b s d\"\n                ).contiguous()\n                kv = rearrange(\n                    self.dwconv_kv(rearrange(kv, \"b s d -> b d s\"))[..., :-2], \"b d s -> b s d\"\n                ).contiguous()\n            if (\n                inference_params is None\n                or inference_params.seqlen_offset == 0\n                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)\n                or not self.use_flash_attn\n            ):\n                if self.rotary_emb_dim > 0:\n                    q, kv = self.rotary_emb(\n                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen\n                    )\n                if inference_params is None:\n                    if not self.checkpointing:\n                        context = self.inner_cross_attn(q, kv, **kwargs)\n                    else:\n                        context = torch.utils.checkpoint.checkpoint(\n                            self.inner_cross_attn, q, kv, **kwargs\n                        )\n                else:\n                    context = self._update_kvcache_attention(q, kv, inference_params)\n            else:\n                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)\n        out = self.out_proj(rearrange(context, \"... h d -> ... 
(h d)\"))\n        return out if not self.return_residual else (out, x)\n\n\nclass ParallelMHA(nn.Module):\n    \"\"\"Multi-head self-attention and cross-attention\"\"\"\n\n    def __init__(\n        self,\n        embed_dim,\n        num_heads,\n        process_group,\n        num_heads_kv=None,\n        qkv_proj_bias=True,\n        out_proj_bias=True,\n        dropout=0.0,\n        softmax_scale=None,\n        causal=False,\n        layer_idx=None,\n        rotary_emb_dim=0,\n        rotary_emb_base=10000.0,\n        rotary_emb_scale_base=None,\n        rotary_emb_interleaved=False,\n        use_alibi=False,\n        window_size=(-1, -1),\n        use_flash_attn=False,\n        checkpointing=False,\n        sequence_parallel=True,\n        device=None,\n        dtype=None,\n    ) -> None:\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.causal = causal\n        self.layer_idx = layer_idx\n        self.rotary_emb_dim = rotary_emb_dim\n        self.use_flash_attn = use_flash_attn\n        self.checkpointing = checkpointing\n        self.process_group = process_group\n        self.world_size = process_group.size()\n        self.local_rank = torch.distributed.get_rank(process_group)\n\n        self.num_heads = num_heads\n        assert self.embed_dim % self.num_heads == 0, \"embed_dim must be divisible by num_heads\"\n\n        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads\n        assert (\n            self.num_heads % self.num_heads_kv == 0\n        ), \"num_heads must be divisible by num_heads_kv\"\n\n        self.num_heads_per_rank = get_dim_for_local_rank(\n            self.num_heads, self.world_size, self.local_rank\n        )\n        self.num_heads_kv_per_rank = get_dim_for_local_rank(\n            self.num_heads_kv, self.world_size, self.local_rank\n        )\n        self.head_dim = self.embed_dim // num_heads\n        qkv_dim = 
self.head_dim * (self.num_heads + 2 * self.num_heads_kv)\n\n        if use_alibi:\n            assert use_flash_attn, \"ALiBi code path requires flash_attn\"\n            num_heads_local = math.ceil(self.num_heads / self.world_size)\n            alibi_slopes = torch.tensor(\n                get_alibi_slopes(num_heads)[\n                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local\n                ],\n                device=device,\n            )\n        else:\n            alibi_slopes = None\n        if window_size != (-1, -1):\n            assert use_flash_attn, \"Local (sliding window) attention code path requires flash_attn\"\n\n        if self.rotary_emb_dim > 0:\n            assert RotaryEmbedding is not None, \"rotary_emb is not installed\"\n            self.rotary_emb = RotaryEmbedding(\n                self.rotary_emb_dim,\n                base=rotary_emb_base,\n                scale_base=rotary_emb_scale_base,\n                interleaved=rotary_emb_interleaved,\n                device=device,\n            )\n\n        if ColumnParallelLinear is None or RowParallelLinear is None:\n            raise ImportError(\"fused_dense is not installed\")\n        self.Wqkv = ColumnParallelLinear(\n            embed_dim,\n            qkv_dim,\n            process_group,\n            bias=qkv_proj_bias,\n            sequence_parallel=sequence_parallel,\n            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),\n            **factory_kwargs,\n        )\n        inner_attn_cls = (\n            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)\n            if use_flash_attn\n            else SelfAttention\n        )\n        inner_cross_attn_cls = (\n            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)\n            if use_flash_attn\n            else CrossAttention\n        )\n        self.inner_attn = inner_attn_cls(\n            
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout\n        )\n        self.inner_cross_attn = inner_cross_attn_cls(\n            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout\n        )\n        self.out_proj = RowParallelLinear(\n            embed_dim,\n            embed_dim,\n            process_group,\n            bias=out_proj_bias,\n            sequence_parallel=sequence_parallel,\n            multiple_of=self.head_dim,\n            **factory_kwargs,\n        )\n\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):\n        dtype = self.out_proj.weight.dtype if dtype is None else dtype\n        device = self.out_proj.weight.device\n        return torch.empty(\n            batch_size,\n            max_seqlen,\n            2,\n            self.num_heads_kv_per_rank,\n            self.head_dim,\n            dtype=dtype,\n            device=device,\n        )\n\n    def _update_kv_cache(self, kv, inference_params):\n        \"\"\"kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)\"\"\"\n        assert self.layer_idx is not None, \"Generation requires layer_idx in the constructor\"\n        return _update_kv_cache(kv, inference_params, self.layer_idx)\n\n    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):\n        \"\"\"\n        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.\n        q: (batch_size, seqlen_q, nheads, head_dim)\n        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)\n        \"\"\"\n        assert inference_params is not None and inference_params.seqlen_offset > 0\n        assert self.use_flash_attn\n        if self.rotary_emb_dim > 0:\n            assert self.rotary_emb.scale is None, \"This code path does not support xPos\"\n            self.rotary_emb._update_cos_sin_cache(\n                inference_params.max_seqlen, device=q.device, dtype=q.dtype\n            )\n  
          rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached\n        else:\n            rotary_cos, rotary_sin = None, None\n        batch = q.shape[0]\n        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]\n        cache_seqlens = (\n            inference_params.lengths_per_sample[:batch]\n            if inference_params.lengths_per_sample is not None\n            else inference_params.seqlen_offset\n        )\n        alibi_slopes = getattr(self.inner_cross_attn, \"alibi_slopes\", None)\n        context = flash_attn_with_kvcache(\n            q,\n            kv_cache[:, :, 0],\n            kv_cache[:, :, 1],\n            kv[:, :, 0],\n            kv[:, :, 1],\n            rotary_cos=rotary_cos,\n            rotary_sin=rotary_sin,\n            cache_seqlens=cache_seqlens,\n            softmax_scale=self.inner_cross_attn.softmax_scale,\n            causal=self.inner_cross_attn.causal,\n            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,\n            alibi_slopes=alibi_slopes,\n        )\n        return context\n\n    def _update_kvcache_attention(self, q, kv, inference_params):\n        \"\"\"Write kv to inference_params, then do attention\"\"\"\n        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:\n            # TODO: this only uses seqlen_offset and not lengths_per_sample.\n            kv = self._update_kv_cache(kv, inference_params)\n            return self.inner_cross_attn(q, kv)\n        else:\n            batch = q.shape[0]\n            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]\n            cache_seqlens = (\n                inference_params.lengths_per_sample[:batch]\n                if inference_params.lengths_per_sample is not None\n                else inference_params.seqlen_offset\n            )\n            alibi_slopes = getattr(self.inner_cross_attn, \"alibi_slopes\", None)\n            
context = flash_attn_with_kvcache(\n                q,\n                kv_cache[:, :, 0],\n                kv_cache[:, :, 1],\n                kv[:, :, 0],\n                kv[:, :, 1],\n                cache_seqlens=cache_seqlens,\n                softmax_scale=self.inner_cross_attn.softmax_scale,\n                causal=self.inner_cross_attn.causal,\n                alibi_slopes=alibi_slopes,\n            )\n            return context\n\n    def forward(self, x, seqlen=None, inference_params=None, **kwargs):\n        \"\"\"\n        Arguments:\n            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.\n                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we\n                split x during sequence parallel, we split the batch * seqlen dimension\n                (in case batch is small).\n        \"\"\"\n        qkv = self.Wqkv(x)\n        if seqlen is not None:\n            qkv = rearrange(qkv, \"(b s) ... 
-> b s ...\", s=seqlen)\n        seqlen_offset = (\n            0\n            if inference_params is None\n            else (\n                inference_params.lengths_per_sample\n                if inference_params.lengths_per_sample is not None\n                else inference_params.seqlen_offset\n            )\n        )\n        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None\n        if self.num_heads_kv == self.num_heads:\n            qkv = rearrange(qkv, \"b s (three h d) -> b s three h d\", three=3, d=self.head_dim)\n            if (\n                inference_params is None\n                or inference_params.seqlen_offset == 0\n                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)\n                or not self.use_flash_attn\n            ):\n                if self.rotary_emb_dim > 0:\n                    qkv = self.rotary_emb(\n                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen\n                    )\n                if inference_params is None:\n                    if not self.checkpointing:\n                        context = self.inner_attn(qkv, **kwargs)\n                    else:\n                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)\n                else:\n                    context = self._update_kvcache_attention(\n                        qkv[:, :, 0], qkv[:, :, 1:], inference_params\n                    )\n            else:\n                context = self._apply_rotary_update_kvcache_attention(\n                    qkv[:, :, 0], qkv[:, :, 1:], inference_params\n                )\n        else:\n            q = rearrange(\n                qkv[..., : self.num_heads_per_rank * self.head_dim],\n                \"... (h d) -> ... 
h d\",\n                d=self.head_dim,\n            )\n            kv = rearrange(\n                qkv[..., self.num_heads_per_rank * self.head_dim :],\n                \"... (two hkv d) -> ... two hkv d\",\n                two=2,\n                d=self.head_dim,\n            )\n            if (\n                inference_params is None\n                or inference_params.seqlen_offset == 0\n                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)\n                or not self.use_flash_attn\n            ):\n                if self.rotary_emb_dim > 0:\n                    q, kv = self.rotary_emb(\n                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen\n                    )\n                if inference_params is None:\n                    if not self.checkpointing:\n                        context = self.inner_cross_attn(q, kv, **kwargs)\n                    else:\n                        context = torch.utils.checkpoint.checkpoint(\n                            self.inner_cross_attn, q, kv, **kwargs\n                        )\n                else:\n                    context = self._update_kvcache_attention(q, kv, inference_params)\n            else:\n                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)\n        context = rearrange(context, \"b s h d -> b s (h d)\")\n        if seqlen is not None:\n            context = rearrange(context, \"b s d -> (b s) d\")\n        out = self.out_proj(context)\n        return out\n"
  },
  {
    "path": "flash_attn/modules/mlp.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.distributed import ProcessGroup\n\n\ntry:\n    from flash_attn.ops.activations import swiglu\nexcept ImportError:\n    swiglu = None\n\ntry:\n    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear\nexcept ImportError:\n    ColumnParallelLinear, RowParallelLinear = None, None\n\ntry:\n    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP\nexcept ImportError:\n    FusedMLP, ParallelFusedMLP = None, None\n\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        activation=F.gelu,\n        bias1=True,\n        bias2=True,\n        return_residual=False,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        out_features = out_features if out_features is not None else in_features\n        hidden_features = hidden_features if hidden_features is not None else in_features * 4\n        self.return_residual = return_residual\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)\n        self.activation = activation\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)\n\n    def forward(self, x):\n        y = self.fc1(x)\n        y = self.activation(y)\n        y = self.fc2(y)\n        return y if not self.return_residual else (y, x)\n\n\nclass ParallelMLP(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        activation=F.gelu,\n        process_group: ProcessGroup = None,\n        sequence_parallel=True,\n        bias1=True,\n        bias2=True,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n    
    super().__init__()\n        assert ColumnParallelLinear is not None, \"Need to install fused_dense\"\n        assert RowParallelLinear is not None, \"Need to install fused_dense\"\n        out_features = out_features if out_features is not None else in_features\n        hidden_features = hidden_features if hidden_features is not None else in_features * 4\n        self.fc1 = ColumnParallelLinear(\n            in_features,\n            hidden_features,\n            process_group,\n            bias=bias1,\n            sequence_parallel=sequence_parallel,\n            **factory_kwargs,\n        )\n        self.activation = activation\n        self.fc2 = RowParallelLinear(\n            hidden_features,\n            out_features,\n            process_group,\n            bias=bias2,\n            sequence_parallel=sequence_parallel,\n            **factory_kwargs,\n        )\n\n    def forward(self, x):\n        y = self.fc1(x)\n        y = self.activation(y)\n        y = self.fc2(y)\n        return y\n\n\nclass GatedMlp(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        activation=F.sigmoid,\n        bias1=True,\n        bias2=True,\n        multiple_of=128,\n        return_residual=False,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        out_features = out_features if out_features is not None else in_features\n        hidden_features = (\n            hidden_features if hidden_features is not None else int(8 * in_features / 3)\n        )\n        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of\n        self.return_residual = return_residual\n        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)\n        self.activation = activation\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, 
**factory_kwargs)\n\n    def forward(self, x):\n        y = self.fc1(x)\n        if self.activation == F.sigmoid:  # Special case for GLU\n            y = F.glu(y, dim=-1)\n        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU\n            y, gate = y.chunk(2, dim=-1)\n            y = swiglu(gate, y)\n        else:\n            y, gate = y.chunk(2, dim=-1)\n            y = y * self.activation(gate)\n        y = self.fc2(y)\n        return y if not self.return_residual else (y, x)\n\n\nclass ParallelGatedMlp(nn.Module):\n    \"\"\"Parallel GatedMlp\"\"\"\n\n    def __init__(\n        self,\n        in_features,\n        process_group,\n        hidden_features=None,\n        out_features=None,\n        activation=F.sigmoid,\n        bias1=True,\n        bias2=True,\n        multiple_of=128,\n        sequence_parallel=True,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        out_features = out_features if out_features is not None else in_features\n        hidden_features = (\n            hidden_features if hidden_features is not None else int(8 * in_features / 3)\n        )\n        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of\n        if ColumnParallelLinear is None or RowParallelLinear is None:\n            raise ImportError(\"fused_dense is not installed\")\n        self.fc1 = ColumnParallelLinear(\n            in_features,\n            2 * hidden_features,\n            process_group,\n            bias=bias1,\n            sequence_parallel=sequence_parallel,\n            **factory_kwargs,\n        )\n        self.activation = activation\n        self.fc2 = RowParallelLinear(\n            hidden_features,\n            out_features,\n            process_group,\n            bias=bias2,\n            sequence_parallel=sequence_parallel,\n            **factory_kwargs,\n        )\n\n    def 
forward(self, x):\n        y = self.fc1(x)\n        if self.activation == F.sigmoid:  # Special case for GLU\n            y = F.glu(y, dim=-1)\n        else:\n            y, gate = y.chunk(2, dim=-1)\n            y = y * self.activation(gate)\n        y = self.fc2(y)\n        return y\n"
  },
  {
    "path": "flash_attn/ops/__init__.py",
    "content": ""
  },
  {
    "path": "flash_attn/ops/activations.py",
    "content": "# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n# 1/sqrt(2*pi)-> 0.3989423\n# 1/sqrt(2)   -> 0.70710678\n# sqrt(2/pi)  -> 0.79788456\n\n# this function is tanh approximation of gelu\n# actual gelu is:\n# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))\n@torch.jit.script\ndef bias_gelu(y, bias):\n    x = bias + y\n    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)\n\n\n# gradient of tanh approximation of gelu\n# gradient of actual gelu is:\n# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)\n@torch.jit.script\ndef bias_gelu_back(g, y, bias):\n    \"\"\"Assume that y has shape (B, D) and bias has shape (D)\"\"\"\n    x = bias + y\n    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))\n    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243\n    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (\n        1 + tanh_out\n    )\n    grad_y = ff * g\n    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)\n\n\nclass GeLUFunction(torch.autograd.Function):\n    @staticmethod\n    # bias is an optional argument\n    def forward(ctx, input, bias):\n        ctx.save_for_backward(input, bias)\n        return bias_gelu(input, bias)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input, bias = ctx.saved_tensors\n        tmp = bias_gelu_back(grad_output, input, bias)\n        return tmp, tmp\n\n\nbias_gelu_impl = GeLUFunction.apply\n\n# this function is tanh approximation of gelu\n# actual gelu is:\n# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))\n@torch.jit.script\ndef gelu_fwd(x):\n    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)\n\n\n# gradient of tanh approximation of 
gelu\n# gradient of actual gelu is:\n# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)\n@torch.jit.script\ndef gelu_bwd(g, x):\n    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))\n    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243\n    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (\n        1 + tanh_out\n    )\n    return (ff * g).to(dtype=x.dtype)\n\n\nclass FastGeLUFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input):\n        ctx.save_for_backward(input)\n        return gelu_fwd(input)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        (input,) = ctx.saved_tensors\n        tmp = gelu_bwd(grad_output, input)\n        return tmp\n\n\nfast_gelu_impl = FastGeLUFunction.apply\n\n\n@torch.jit.script\ndef relu_bwd(g, x):\n    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)\n\n\n@torch.jit.script\ndef sqrelu_fwd(x):\n    r = F.relu(x)\n    return (r * r).to(dtype=x.dtype)\n\n\n@torch.jit.script\ndef sqrelu_bwd(g, x):\n    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)\n\n\nswiglu_fwd_codestring = \"\"\"\ntemplate <typename T> T swiglu_fwd(T x, T y) {\n    return float(x) * float(y) / (1.0f + ::exp(-float(x)));\n}\n\"\"\"\nswiglu_bwd_codestring = \"\"\"\ntemplate <typename T> void swiglu_bwd(T x, T y, T g, T& dx, T& dy) {\n    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));\n    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);\n    dy = float(x) * x_sigmoid * float(g);\n}\n\"\"\"\nswiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)\nswiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)\n\n\nclass SwiGLUFunction(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, x, y):\n        ctx.save_for_backward(x, y)\n        return swiglu_fwd(x, y)\n\n    @staticmethod\n    def backward(ctx, 
dout):\n        x, y = ctx.saved_tensors\n        return swiglu_bwd(x, y, dout)\n\nswiglu = SwiGLUFunction.apply\n"
  },
  {
    "path": "flash_attn/ops/fused_dense.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py\n# We make it work with pytorch amp and with bfloat16.\n# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py\nfrom functools import partial\nfrom typing import Optional\n\n# import fused_dense_cuda  # from apex\nimport fused_dense_lib as fused_dense_cuda\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.distributed import ProcessGroup\n\nfrom flash_attn.utils.torch import custom_fwd, custom_bwd\nfrom flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd\nfrom flash_attn.utils.distributed import (\n    all_gather_raw,\n    all_reduce,\n    all_reduce_raw,\n    reduce_scatter,\n    reduce_scatter_raw,\n)\n\n\nclass FusedDenseFunc(torch.autograd.Function):\n    @staticmethod\n    @custom_fwd\n    def forward(\n        ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True\n    ):\n        \"\"\"\n        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel\n        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.\n        \"\"\"\n        ctx.compute_weight_gradient = weight.requires_grad\n        ctx.return_residual = return_residual\n        ctx.process_group = process_group\n        ctx.sequence_parallel = sequence_parallel\n\n        if torch.is_autocast_enabled():\n            x = x.to(dtype=torch.get_autocast_gpu_dtype())\n        x = x.contiguous()\n        if process_group is not None and sequence_parallel:\n            # We want to kick off the all_gather early, before weight dtype conversion\n            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)\n        else:\n            total_x = x\n\n        if torch.is_autocast_enabled():\n            
weight = weight.to(dtype=torch.get_autocast_gpu_dtype())\n            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None\n        weight = weight.contiguous()\n        if process_group is not None and sequence_parallel:\n            handle_x.wait()\n        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]\n        batch_dim = batch_shape.numel()\n        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174\n        if min(batch_dim, n, *weight.shape) > 65535 * 32:\n            raise RuntimeError(\"fused_dense only supports matrix dims <= 2M\")\n        output = F.linear(total_x, weight, bias)\n        if ctx.compute_weight_gradient:\n            ctx.save_for_backward(x, weight)\n        else:\n            ctx.save_for_backward(weight)\n        return output if not return_residual else (output, x)\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_output, *args):\n        grad_output = grad_output.contiguous()\n        if ctx.return_residual:\n            (grad_input,) = args\n            grad_input = grad_input.contiguous()\n        process_group = ctx.process_group\n        sequence_parallel = ctx.sequence_parallel\n        if ctx.compute_weight_gradient:\n            x, weight = ctx.saved_tensors\n            if process_group is not None and sequence_parallel:\n                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)\n            else:\n                total_x = x\n        else:\n            (weight,) = ctx.saved_tensors\n            total_x = None\n        batch_shape = grad_output.shape[:-1]\n        batch_dim = batch_shape.numel()\n        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])\n        if ctx.needs_input_grad[0]:\n            if not ctx.return_residual:\n                grad_input = F.linear(grad_output, weight.t())\n            else:\n                grad_input = torch.addmm(\n      
              grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight\n                )\n            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])\n            if process_group is not None:\n                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw\n                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)\n        else:\n            grad_input = None\n        if ctx.needs_input_grad[1]:\n            assert ctx.compute_weight_gradient\n            if process_group is not None and sequence_parallel:\n                handle_x.wait()\n            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(\n                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]\n            )\n        else:\n            grad_weight = None\n            grad_bias = grad_output if ctx.needs_input_grad[2] else None\n        if process_group is not None and ctx.needs_input_grad[0]:\n            handle_grad_input.wait()\n        return grad_input, grad_weight, grad_bias, None, None, None\n\n\ndef fused_dense_func(\n    x: Tensor,\n    weight: Tensor,\n    bias: Optional[Tensor] = None,\n    return_residual: bool = False,\n    process_group: Optional[ProcessGroup] = None,\n    sequence_parallel: bool = True,\n):\n    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (\n        x.dtype == torch.float32 and torch.is_autocast_enabled()\n    )\n    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:\n        return FusedDenseFunc.apply(\n            x, weight, bias, return_residual, process_group, sequence_parallel\n        )\n    else:\n        assert process_group is None\n        out = F.linear(x, weight, bias)\n        return out if not return_residual else (out, x)\n\n\nclass FusedDense(nn.Linear):\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int,\n   
     bias: bool = True,\n        return_residual: bool = False,\n        device=None,\n        dtype=None,\n    ) -> None:\n        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)\n        self.return_residual = return_residual\n\n    def forward(self, x, process_group=None):\n        \"\"\"\n        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:\n        we do an all_gather of x before doing the matmul.\n        \"\"\"\n        return fused_dense_func(\n            x,\n            self.weight,\n            self.bias,\n            return_residual=self.return_residual,\n            process_group=process_group,\n        )\n\n\nclass ColumnParallelLinear(nn.Linear):\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int,\n        process_group: ProcessGroup,\n        bias: bool = True,\n        sequence_parallel=True,\n        multiple_of=1,\n        device=None,\n        dtype=None,\n    ) -> None:\n        world_size = torch.distributed.get_world_size(process_group)\n        if out_features % multiple_of:\n            raise ValueError(f\"out_features ({out_features}) must be a multiple of {multiple_of}\")\n        multiple = out_features // multiple_of\n        # We want to split @multiple across world_size, but it could be an uneven split\n        div = multiple // world_size\n        mod = multiple % world_size\n        # The first @mod ranks get @div + 1 copies, the rest get @div copies\n        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)\n        super().__init__(\n            in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype\n        )\n        self.process_group = process_group\n        self.sequence_parallel = sequence_parallel\n\n    def forward(self, x):\n        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:\n        # we do an all_gather of x 
before doing the matmul.\n        # If not, then the input is already gathered.\n        return fused_dense_func(\n            x,\n            self.weight,\n            self.bias,\n            process_group=self.process_group,\n            sequence_parallel=self.sequence_parallel,\n        )\n\n\nclass RowParallelLinear(nn.Linear):\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int,\n        process_group: ProcessGroup,\n        bias: bool = True,\n        sequence_parallel=True,\n        multiple_of=1,\n        device=None,\n        dtype=None,\n    ) -> None:\n        world_size = torch.distributed.get_world_size(process_group)\n        rank = torch.distributed.get_rank(process_group)\n        if in_features % multiple_of:\n            raise ValueError(f\"in_features ({in_features}) must be a multiple of {multiple_of}\")\n        multiple = in_features // multiple_of\n        # We want to split @multiple across world_size, but it could be an uneven split\n        div = multiple // world_size\n        mod = multiple % world_size\n        # The first @mod ranks get @div + 1 copies, the rest get @div copies\n        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)\n        # Only rank 0 will have bias\n        super().__init__(\n            local_multiple * multiple_of,\n            out_features,\n            bias=bias and rank == 0,\n            device=device,\n            dtype=dtype,\n        )\n        self.process_group = process_group\n        self.sequence_parallel = sequence_parallel\n\n    def forward(self, x):\n        \"\"\"\n        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then\n        a reduce_scatter of the result.\n        \"\"\"\n        out = fused_dense_func(x, self.weight, self.bias)\n        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce\n        return reduce_fn(out, self.process_group)\n\n\nclass 
FusedMLPFunc(torch.autograd.Function):\n    @staticmethod\n    @custom_fwd\n    def forward(\n        ctx,\n        x,\n        weight1,\n        bias1,\n        weight2,\n        bias2,\n        activation=\"gelu_approx\",\n        save_pre_act=True,\n        return_residual=False,\n        checkpoint_lvl=0,\n        heuristic=0,\n        process_group=None,\n        sequence_parallel=True,\n    ):\n        \"\"\"\n        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel\n        with sequence parallelism: we do an all_gather of x before doing the matmul.\n        If sequence_parallel=False, then the input is already gathered.\n\n        checkpoint_lvl:\n        0: no recomputation in the bwd\n        1: recompute gelu_out / relu_out in the bwd\n        2: recompute pre_act and gelu_out / relu_out in the bwd\n        \"\"\"\n        assert -1 <= heuristic <= 4\n        assert activation in [\"gelu_approx\", \"relu\", \"sqrelu\"]\n        if activation == \"sqrelu\":\n            assert heuristic == -1\n        if not save_pre_act:\n            checkpoint_lvl = 2\n        assert checkpoint_lvl in [0, 1, 2]\n        ctx.return_residual = return_residual\n        ctx.process_group = process_group\n        ctx.sequence_parallel = sequence_parallel\n        ctx.checkpoint_lvl = checkpoint_lvl\n        ctx.activation = activation\n        ctx.heuristic = heuristic\n\n        if torch.is_autocast_enabled():\n            x = x.to(dtype=torch.get_autocast_gpu_dtype())\n        x = x.contiguous()\n        if process_group is not None and sequence_parallel:\n            # We want to kick off the all_gather early, before weight dtype conversion\n            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)\n        else:\n            total_x = x\n\n        if torch.is_autocast_enabled():\n            dtype = torch.get_autocast_gpu_dtype()\n            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]\n     
       bias1 = bias1.to(dtype=dtype) if bias1 is not None else None\n            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None\n        weight1 = weight1.contiguous()\n        bias1 = bias1.contiguous() if bias1 is not None else None\n        weight2 = weight2.contiguous()\n        bias2 = bias2.contiguous() if bias2 is not None else None\n        if process_group is not None and sequence_parallel:\n            handle_x.wait()\n        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]\n        batch_dim = batch_shape.numel()\n        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174\n        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:\n            raise RuntimeError(\"fused_dense only supports matrix dims <= 2M\")\n        if heuristic == -1:\n            pre_act = F.linear(total_x, weight1, bias1)\n            activation_fn = (\n                partial(F.gelu, approximate=\"tanh\")\n                if activation == \"gelu_approx\"\n                else (sqrelu_fwd if activation == \"sqrelu\" else F.relu)\n            )\n            with torch.jit.fuser(\"fuser2\"):\n                output1 = activation_fn(pre_act)\n            # This is before adding bias1\n            # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)\n            # with torch.jit.fuser('fuser2'):\n            #     output1 = bias_gelu(pre_act, bias1)\n        else:\n            is_gelu = activation == \"gelu_approx\"\n            output1, *rest = fused_dense_cuda.linear_act_forward(\n                total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic\n            )\n            if save_pre_act:\n                pre_act = rest[0]\n        output2 = F.linear(output1, weight2, bias2)\n        if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == \"relu\"):\n            # For RELU the pre_act is very small (just a bit-mask) so we just 
save it\n            ctx.save_for_backward(x, weight1, weight2, pre_act, output1)\n        elif checkpoint_lvl == 1:\n            ctx.save_for_backward(x, weight1, weight2, pre_act)\n        elif checkpoint_lvl == 2:\n            ctx.save_for_backward(x, weight1, weight2, bias1)\n        output2 = output2.reshape(*batch_shape, output2.shape[-1])\n        return output2 if not return_residual else (output2, x)\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_output, *args):\n        grad_output = grad_output.contiguous()\n        checkpoint_lvl = ctx.checkpoint_lvl\n        activation = ctx.activation\n        activation_fn = (\n            partial(F.gelu, approximate=\"tanh\")\n            if activation == \"gelu_approx\"\n            else (sqrelu_fwd if activation == \"sqrelu\" else F.relu)\n        )\n        if ctx.return_residual:\n            (grad_input,) = args\n            grad_input = grad_input.contiguous()\n        process_group = ctx.process_group\n        sequence_parallel = ctx.sequence_parallel\n        x, weight1, weight2, *rest = ctx.saved_tensors\n        if process_group is None or not sequence_parallel:\n            total_x = x\n        batch_shape = grad_output.shape[:-1]\n        batch_dim = batch_shape.numel()\n        if checkpoint_lvl in [0, 1]:\n            if process_group is not None and sequence_parallel:\n                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)\n            if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == \"relu\"):\n                pre_act, output1 = rest\n            elif checkpoint_lvl == 1:\n                (pre_act,) = rest\n                with torch.jit.fuser(\"fuser2\"):\n                    output1 = activation_fn(pre_act)\n        elif checkpoint_lvl == 2:\n            (bias1,) = rest\n            if process_group is not None and sequence_parallel:\n                total_x, _ = all_gather_raw(x, process_group)\n            if ctx.heuristic == -1:\n   
             pre_act = F.linear(total_x, weight1, bias1)\n                with torch.jit.fuser(\"fuser2\"):\n                    output1 = activation_fn(pre_act)\n            else:\n                output1, pre_act = fused_dense_cuda.linear_act_forward(\n                    total_x.reshape(batch_dim, total_x.shape[-1]),\n                    weight1,\n                    bias1,\n                    activation == \"gelu_approx\",\n                    True,\n                    ctx.heuristic,\n                )\n\n        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])\n        output1 = output1.reshape(batch_dim, output1.shape[-1])\n        pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])\n        if ctx.needs_input_grad[3]:\n            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(\n                output1, grad_output, ctx.needs_input_grad[4]\n            )\n        else:\n            grad_weight2 = None\n            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None\n        if ctx.heuristic == -1:\n            # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)\n            grad_output1 = F.linear(grad_output, weight2.t())\n            activation_grad_fn = (\n                gelu_bwd\n                if activation == \"gelu_approx\"\n                else (sqrelu_bwd if activation == \"sqrelu\" else relu_bwd)\n            )\n            with torch.jit.fuser(\"fuser2\"):\n                grad_pre_act = activation_grad_fn(grad_output1, pre_act)\n        else:\n            # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't\n            # just compute gelu/relu grad\n            grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(\n                weight2, grad_output, pre_act, activation == \"gelu_approx\", ctx.heuristic\n            )\n            if not ctx.needs_input_grad[2]:\n                grad_bias1 = None\n        if ctx.needs_input_grad[0]:\n  
          if not ctx.return_residual:\n                grad_input = F.linear(grad_pre_act, weight1.t())\n            else:\n                grad_input = torch.addmm(\n                    grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1\n                )\n            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])\n            if process_group is not None:\n                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw\n                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)\n        else:\n            grad_input = None\n        if ctx.heuristic == -1:\n            if ctx.needs_input_grad[1]:\n                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:\n                    handle_x.wait()\n                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(\n                    total_x.reshape(batch_dim, total_x.shape[-1]),\n                    grad_pre_act,\n                    ctx.needs_input_grad[2],\n                )\n            else:\n                grad_weight1 = None\n                grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None\n        else:\n            if ctx.needs_input_grad[1]:\n                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:\n                    handle_x.wait()\n                grad_weight1 = F.linear(\n                    grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()\n                )\n            else:\n                grad_weight1 = None\n        if process_group is not None and ctx.needs_input_grad[0]:\n            handle_grad_input.wait()\n        return (\n            grad_input,\n            grad_weight1,\n            grad_bias1,\n            grad_weight2,\n            grad_bias2,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            
None,\n        )\n\n\ndef fused_mlp_func(\n    x: Tensor,\n    weight1: Tensor,\n    weight2: Tensor,\n    bias1: Optional[Tensor] = None,\n    bias2: Optional[Tensor] = None,\n    activation: str = \"gelu_approx\",\n    save_pre_act: bool = True,\n    return_residual: bool = False,\n    checkpoint_lvl: int = 0,\n    heuristic: int = 0,\n    process_group: Optional[ProcessGroup] = None,\n    sequence_parallel: bool = True,\n):\n    assert activation in [\"gelu_approx\", \"relu\", \"sqrelu\"]\n    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (\n        x.dtype == torch.float32 and torch.is_autocast_enabled()\n    )\n    # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)\n    dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == \"relu\" else 8) == 0)\n    if (\n        x.is_cuda\n        and weight1.is_cuda\n        and weight2.is_cuda\n        and (bias1 is None or bias1.is_cuda)\n        and (bias2 is None or bias2.is_cuda)\n        and dtype_eligible\n        and dim_eligible\n    ):\n        return FusedMLPFunc.apply(\n            x,\n            weight1,\n            bias1,\n            weight2,\n            bias2,\n            activation,\n            save_pre_act,\n            return_residual,\n            checkpoint_lvl,\n            heuristic,\n            process_group,\n            sequence_parallel,\n        )\n    else:\n        # Unfused fallback (e.g. CPU tensors, fp32 without autocast, or ineligible dims)\n        assert process_group is None\n        pre_act = F.linear(x, weight1, bias1)\n        activation_fn = (\n            partial(F.gelu, approximate=\"tanh\")\n            if activation == \"gelu_approx\"\n            else (sqrelu_fwd if activation == \"sqrelu\" else partial(F.relu, inplace=True))\n        )\n        output1 = activation_fn(pre_act)\n        output2 = F.linear(output1, weight2, bias2)\n        return output2 if not return_residual else (output2, x)\n\n\nclass FusedMLP(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n  
      bias1=True,\n        bias2=True,\n        activation=\"gelu_approx\",\n        return_residual=False,\n        checkpoint_lvl=0,\n        heuristic=\"auto\",\n        device=None,\n        dtype=None,\n    ):\n        \"\"\"\n        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:\n        we do an all_gather of x before doing the matmul, gelu, then matmul.\n        Finally we do a reduce_scatter of the output.\n\n        checkpoint_lvl (increasing lvl means slower but more memory saving):\n            0: no recomputation in the bwd\n            1: recompute gelu_out in the bwd\n            2: recompute pre_act and gelu_out in the bwd\n        heuristic:\n            -1: don't fuse gemm + gelu (separate kernel)\n            0..4: use this heuristic for the algo section in the fused gemm + gelu\n            'auto': heuristic will be picked automatically:\n                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.\n                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.\n                For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation\n                is slower than the unfused version.\n        return_residual: whether to return the input x along with the output. 
This is for\n            performance reason: for post-norm architecture, returning the input allows us\n            to fuse the backward of nn.Linear with the residual connection.\n        \"\"\"\n        assert checkpoint_lvl in [0, 1, 2]\n        assert activation in [\"gelu_approx\", \"relu\", \"sqrelu\"]\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features * 4\n        self.activation = activation\n        self.return_residual = return_residual\n        self.checkpoint_lvl = checkpoint_lvl\n        self.heuristic = heuristic if activation != \"sqrelu\" else -1\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)\n\n    def forward(self, x, process_group=None):\n        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()\n        if self.heuristic == \"auto\":\n            if self.activation == \"gelu_approx\":\n                if torch.cuda.get_device_capability(\"cuda\") == (9, 0):\n                    heuristic = -1\n                else:\n                    cuda_ver = tuple(map(int, torch.version.cuda.split(\".\")))\n                    heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)\n            else:\n                heuristic = 0\n        else:\n            heuristic = self.heuristic\n        out = fused_mlp_func(\n            x,\n            self.fc1.weight,\n            self.fc2.weight,\n            self.fc1.bias,\n            self.fc2.bias,\n            activation=self.activation,\n            save_pre_act=self.training,\n            return_residual=self.return_residual,\n            checkpoint_lvl=self.checkpoint_lvl,\n            heuristic=heuristic,\n            process_group=process_group,\n        )\n        
if self.return_residual:\n            out, x = out\n        if process_group is not None:\n            out = reduce_scatter(out, process_group)\n        return out if not self.return_residual else (out, x)\n\n\nclass ParallelFusedMLP(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        activation=\"gelu_approx\",\n        process_group: ProcessGroup = None,\n        bias1=True,\n        bias2=True,\n        sequence_parallel=True,\n        checkpoint_lvl=0,\n        heuristic=\"auto\",\n        device=None,\n        dtype=None,\n    ):\n        \"\"\"\n        process_group is required. We're doing Tensor Parallel with sequence parallelism:\n        we do an all_gather of x before doing the matmul, gelu, then matmul.\n        Finally we do a reduce_scatter of the output.\n\n        checkpoint_lvl (increasing lvl means slower but more memory saving):\n            0: no recomputation in the bwd\n            1: recompute gelu_out in the bwd\n            2: recompute pre_act and gelu_out in the bwd\n        heuristic:\n            -1: don't fuse gemm + gelu (separate kernel)\n            0..4: use this heuristic for the algo section in the fused gemm + gelu\n            'auto': heuristic will be picked automatically:\n                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.\n                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.\n        \"\"\"\n        assert checkpoint_lvl in [0, 1, 2]\n        assert activation in [\"gelu_approx\", \"relu\", \"sqrelu\"]\n        assert process_group is not None\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features * 4\n        self.activation = activation\n        self.process_group = process_group\n        self.sequence_parallel = 
sequence_parallel\n        self.checkpoint_lvl = checkpoint_lvl\n        self.heuristic = heuristic if activation != \"sqrelu\" else -1\n        self.fc1 = ColumnParallelLinear(\n            in_features, hidden_features, process_group, bias=bias1, **factory_kwargs\n        )\n        self.fc2 = RowParallelLinear(\n            hidden_features, out_features, process_group, bias=bias2, **factory_kwargs\n        )\n\n    def forward(self, x):\n        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()\n        if self.heuristic == \"auto\":\n            if self.activation == \"gelu_approx\":\n                cuda_ver = tuple(map(int, torch.version.cuda.split(\".\")))\n                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)\n            else:\n                heuristic = 0\n        else:\n            heuristic = self.heuristic\n        out = fused_mlp_func(\n            x,\n            self.fc1.weight,\n            self.fc2.weight,\n            self.fc1.bias,\n            self.fc2.bias,\n            activation=self.activation,\n            save_pre_act=self.training,\n            checkpoint_lvl=self.checkpoint_lvl,\n            heuristic=heuristic,\n            process_group=self.process_group,\n            sequence_parallel=self.sequence_parallel,\n        )\n        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce\n        return reduce_fn(out, self.process_group)\n"
  },
  {
    "path": "flash_attn/ops/layer_norm.py",
    "content": "# Copyright (c) 2022, Tri Dao.\n# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py\n\nimport dropout_layer_norm\nimport torch\nfrom torch.nn import init\n\n\ndef maybe_align(x, alignment_in_bytes=16):\n    \"\"\"Assume that x already has last dim divisible by alignment_in_bytes\"\"\"\n    # TD [2023-07-04] I'm not 100% sure that clone will align the memory\n    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440\n    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()\n\n\ndef _dropout_add_layer_norm_forward(\n    x0,\n    residual,\n    gamma,\n    beta,\n    rowscale,\n    colscale,\n    dropout_p,\n    epsilon,\n    residual_in_fp32=False,\n    is_rms_norm=False,\n):\n    \"\"\"Assume that arguments are contiguous and aligned to 16 bytes\"\"\"\n    hidden_size = gamma.numel()\n    x0mat = x0.view((-1, hidden_size))\n    residualmat = residual.view((-1, hidden_size)) if residual is not None else None\n    rowscale = rowscale.view(-1) if rowscale is not None else None\n    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(\n        x0mat,\n        residualmat,\n        gamma,\n        beta,\n        rowscale,\n        colscale,\n        None,\n        None,\n        dropout_p,\n        epsilon,\n        1.0,\n        0,\n        None,\n        residual_in_fp32,\n        is_rms_norm,\n    )\n    # dmask is None if dropout_p == 0.0\n    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype\n    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma\n\n\ndef _dropout_add_layer_norm_backward(\n    dz,\n    dx,\n    x,\n    x0,\n    dmask,\n    mu,\n    rsigma,\n    gamma,\n    rowscale,\n    colscale,\n    dropout_p,\n    has_residual,\n    is_rms_norm=False,\n):\n    \"\"\"Assume that arguments are contiguous and aligned to 16 bytes\n    dx == None means that it was a post-norm 
architecture\n    (x = drop(x0) + residual was not returned in the fwd).\n    x0 must not be None if we have colscale.\n    \"\"\"\n    hidden_size = gamma.numel()\n    xmat = x.view((-1, hidden_size))\n    dzmat = dz.view(xmat.shape)\n    dxmat = dx.view(xmat.shape) if dx is not None else None\n    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None\n    rowscale = rowscale.view(-1) if rowscale is not None else None\n    if colscale is not None:\n        assert x0 is not None, \"x0 is required to compute the gradient of colscale\"\n    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(\n        dzmat,\n        dxmat,\n        xmat,\n        x0mat,\n        dmask,\n        mu,\n        rsigma,\n        gamma,\n        rowscale,\n        colscale,\n        None,\n        None,\n        dropout_p,\n        1.0,\n        0,\n        has_residual,\n        is_rms_norm,\n    )\n    # dresidualmat is None if not has_residual\n    if colscale is None:\n        return dx0mat, dresidualmat, dgamma, dbeta\n    else:\n        dcolscale = rest[0]\n        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale\n\n\ndef _dropout_add_layer_norm_subset_forward(\n    x0,\n    residual,\n    gamma,\n    beta,\n    colscale,\n    x0_subset,\n    out_subset,\n    dropout_p,\n    epsilon,\n    rowscale_const,\n    out_numrows,\n    residual_in_fp32=False,\n    is_rms_norm=False,\n):\n    \"\"\"Assume that arguments are contiguous and aligned to 16 bytes\"\"\"\n    hidden_size = gamma.numel()\n    x0mat = x0.view((-1, hidden_size))\n    residualmat = residual.view((-1, hidden_size)) if residual is not None else None\n    x0_subset = x0_subset.view(-1) if x0_subset is not None else None\n    out_subset = out_subset.view(-1) if out_subset is not None else None\n    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(\n        x0mat,\n        residualmat,\n        gamma,\n        beta,\n        None,\n        colscale,\n    
    x0_subset,\n        out_subset,\n        dropout_p,\n        epsilon,\n        rowscale_const,\n        out_numrows,\n        None,\n        residual_in_fp32,\n        is_rms_norm,\n    )\n    # dmask is None if dropout_p == 0.0\n    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype\n    return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma\n\n\ndef _dropout_add_layer_norm_subset_backward(\n    dz,\n    dx,\n    x,\n    x0,\n    dmask,\n    mu,\n    rsigma,\n    gamma,\n    colscale,\n    x0_subset,\n    out_subset,\n    dropout_p,\n    rowscale_const,\n    x0_numrows,\n    has_residual,\n    is_rms_norm=False,\n):\n    \"\"\"Assume that arguments are contiguous and aligned to 16 bytes\n    dx == None means that it was a post-norm architecture\n    (x = drop(x0) + residual was not returned in the fwd).\n    x0 must not be None if we have colscale.\n    \"\"\"\n    hidden_size = gamma.numel()\n    xmat = x.view((-1, hidden_size))\n    dzmat = dz.view(-1, hidden_size)\n    dxmat = dx.view(xmat.shape) if dx is not None else None\n    x0mat = x0.view((-1, hidden_size)) if x0 is not None else None\n    x0_subset = x0_subset.view(-1) if x0_subset is not None else None\n    out_subset = out_subset.view(-1) if out_subset is not None else None\n    if colscale is not None:\n        assert x0 is not None, \"x0 is required to compute the gradient of colscale\"\n    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(\n        dzmat,\n        dxmat,\n        xmat,\n        x0mat,\n        dmask,\n        mu,\n        rsigma,\n        gamma,\n        None,\n        colscale,\n        x0_subset,\n        out_subset,\n        dropout_p,\n        rowscale_const,\n        x0_numrows,\n        has_residual,\n        is_rms_norm,\n    )\n    # dresidualmat is None if not has_residual\n    if colscale is None:\n        return dx0mat, dresidualmat, dgamma, dbeta\n    else:\n        dcolscale 
= rest[0]\n        return dx0mat, dresidualmat, dgamma, dbeta, dcolscale\n\n\ndef _dropout_add_layer_norm_parallel_residual_forward(\n    x0,\n    x1,\n    residual,\n    gamma0,\n    beta0,\n    gamma1,\n    beta1,\n    dropout_p,\n    epsilon,\n    residual_in_fp32=False,\n    is_rms_norm=False,\n):\n    \"\"\"Assume that arguments are contiguous and aligned to 16 bytes\"\"\"\n    hidden_size = gamma0.numel()\n    x0mat = x0.view((-1, hidden_size))\n    x1mat = x1.view((-1, hidden_size)) if x1 is not None else None\n    residualmat = residual.view((-1, hidden_size)) if residual is not None else None\n    (\n        z0mat,\n        z1mat,\n        xmat,\n        dmask0,\n        dmask1,\n        mu,\n        rsigma,\n    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(\n        x0mat,\n        x1mat,\n        residualmat,\n        gamma0,\n        beta0,\n        gamma1,\n        beta1,\n        dropout_p,\n        epsilon,\n        None,\n        residual_in_fp32,\n        is_rms_norm,\n    )\n    # dmask0 and dmask1 are None if dropout_p == 0.0\n    # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype\n    return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma\n\n\ndef _dropout_add_layer_norm_parallel_residual_backward(\n    dz0,\n    dz1,\n    dx,\n    x,\n    dmask0,\n    dmask1,\n    mu,\n    rsigma,\n    gamma0,\n    gamma1,\n    dropout_p,\n    has_x1,\n    has_residual,\n    is_rms_norm=False,\n):\n    \"\"\"Assume that arguments are contiguous and aligned to 16 bytes\n    dx == None means that it was a post-norm architecture\n    (x = drop(x0) + residual was not returned in the fwd).\n    \"\"\"\n    hidden_size = gamma0.numel()\n    xmat = x.view((-1, hidden_size))\n    dz0mat = dz0.view(xmat.shape)\n    dz1mat = dz1.view(xmat.shape) if dz1 is not None else None\n    dxmat = dx.view(xmat.shape) if dx is not None else None\n    (\n        dx0mat,\n        dx1mat,\n        
dresidualmat,\n        dgamma0,\n        dbeta0,\n        dgamma1,\n        dbeta1,\n        *rest,\n    ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(\n        dz0mat,\n        dz1mat,\n        dxmat,\n        xmat,\n        dmask0,\n        dmask1,\n        mu,\n        rsigma,\n        gamma0,\n        gamma1,\n        dropout_p,\n        has_x1,\n        has_residual,\n        is_rms_norm,\n    )\n    # dresidualmat is None if not has_residual\n    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1\n\n\nclass DropoutAddLayerNormFn(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        x0,\n        residual,\n        gamma,\n        beta,\n        rowscale,\n        colscale,\n        dropout_p,\n        epsilon,\n        residual_in_fp32=False,\n        prenorm=False,\n        is_rms_norm=False,\n        return_dmask=False,\n    ):\n        x0 = maybe_align(x0.contiguous(), 16)\n        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None\n        gamma = maybe_align(gamma.contiguous(), 16)\n        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None\n        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None\n        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None\n        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(\n            x0,\n            residual,\n            gamma,\n            beta,\n            rowscale,\n            colscale,\n            dropout_p,\n            epsilon,\n            residual_in_fp32,\n            is_rms_norm,\n        )\n        # Only need to save x0 if we need to compute gradient wrt colscale\n        x0_saved = x0 if colscale is not None else None\n        ctx.save_for_backward(\n            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale\n        )\n        ctx.prenorm = prenorm\n        
ctx.dropout_p = dropout_p\n        ctx.has_residual = residual is not None\n        ctx.is_rms_norm = is_rms_norm\n        ctx.has_beta = beta is not None\n        if not return_dmask:\n            return (\n                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))\n            )\n        else:\n            dmask = (\n                dmask.view(x0.shape)\n                if dropout_p > 0.0\n                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)\n            )\n            ctx.mark_non_differentiable(dmask)\n            return (\n                (zmat.view(x0.shape), dmask)\n                if not prenorm\n                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)\n            )\n\n    @staticmethod\n    def backward(ctx, dz, *args):\n        # assert dz.is_contiguous()\n        dz = maybe_align(dz.contiguous(), 16)  # this happens!\n        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None\n        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors\n        # x0 is None if colscale is None\n        dropout_p = ctx.dropout_p\n        has_residual = ctx.has_residual\n        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(\n            dz,\n            dx,\n            x,\n            x0,\n            dmask,\n            mu,\n            rsigma,\n            gamma,\n            rowscale,\n            colscale,\n            dropout_p,\n            has_residual,\n            ctx.is_rms_norm,\n        )\n        dx0 = dx0mat.view(x.shape)\n        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None\n        dcolscale = rest[0] if colscale is not None else None\n        return (\n            dx0,\n            dresidual,\n            dgamma,\n            dbeta if ctx.has_beta else None,\n            None,\n            dcolscale,\n            None,\n            None,\n            None,\n            
None,\n            None,\n            None,\n        )\n\n\nclass DropoutAddLayerNormSubsetFn(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        x0,\n        residual,\n        gamma,\n        beta,\n        colscale,\n        x0_subset,\n        out_subset,\n        dropout_p,\n        epsilon,\n        rowscale_const,\n        out_numrows,\n        residual_in_fp32=False,\n        prenorm=False,\n        is_rms_norm=False,\n        return_dmask=False,\n    ):\n        x0 = maybe_align(x0.contiguous(), 16)\n        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None\n        gamma = maybe_align(gamma.contiguous(), 16)\n        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None\n        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None\n        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(\n            x0,\n            residual,\n            gamma,\n            beta,\n            colscale,\n            x0_subset,\n            out_subset,\n            dropout_p,\n            epsilon,\n            rowscale_const,\n            out_numrows,\n            residual_in_fp32,\n            is_rms_norm,\n        )\n        # Only need to save x0 if we need to compute gradient wrt colscale\n        x0_saved = x0 if colscale is not None else None\n        x_shape = (-1, *x0.shape[1:])\n        ctx.save_for_backward(\n            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset\n        )\n        ctx.prenorm = prenorm\n        ctx.dropout_p = dropout_p\n        ctx.rowscale_const = rowscale_const\n        ctx.x0_numrows = x0.shape[:-1].numel()\n        ctx.has_residual = residual is not None\n        ctx.is_rms_norm = is_rms_norm\n        ctx.has_beta = beta is not None\n        z_shape = (-1, *x0.shape[1:])\n        if not return_dmask:\n            return zmat.view(z_shape) if not prenorm 
else (zmat.view(z_shape), xmat.view(x0.shape))\n        else:\n            z = zmat.view(z_shape)\n            dmask = (\n                dmask.view(x0.shape)\n                if dropout_p > 0.0\n                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)\n            )\n            ctx.mark_non_differentiable(dmask)\n            return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)\n\n    @staticmethod\n    def backward(ctx, dz, *args):\n        # assert dz.is_contiguous()\n        dz = maybe_align(dz.contiguous(), 16)  # this happens!\n        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None\n        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors\n        # x0 is None if colscale is None\n        dropout_p = ctx.dropout_p\n        has_residual = ctx.has_residual\n        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(\n            dz,\n            dx,\n            x,\n            x0,\n            dmask,\n            mu,\n            rsigma,\n            gamma,\n            colscale,\n            x0_subset,\n            out_subset,\n            dropout_p,\n            ctx.rowscale_const,\n            ctx.x0_numrows,\n            has_residual,\n            ctx.is_rms_norm,\n        )\n        dx0 = dx0mat.view(-1, *x.shape[1:])\n        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None\n        dcolscale = rest[0] if colscale is not None else None\n        return (\n            dx0,\n            dresidual,\n            dgamma,\n            dbeta if ctx.has_beta else None,\n            dcolscale,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n        )\n\n\nclass DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        
ctx,\n        x0,\n        x1,\n        residual,\n        gamma0,\n        beta0,\n        gamma1,\n        beta1,\n        dropout_p,\n        epsilon,\n        residual_in_fp32=False,\n        prenorm=False,\n        is_rms_norm=False,\n        return_dmask=False,\n    ):\n        x0 = maybe_align(x0.contiguous(), 16)\n        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None\n        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None\n        gamma0 = maybe_align(gamma0.contiguous(), 16)\n        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None\n        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None\n        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None\n        (\n            z0mat,\n            z1mat,\n            xmat,\n            dmask0,\n            dmask1,\n            mu,\n            rsigma,\n        ) = _dropout_add_layer_norm_parallel_residual_forward(\n            x0,\n            x1,\n            residual,\n            gamma0,\n            beta0,\n            gamma1,\n            beta1,\n            dropout_p,\n            epsilon,\n            residual_in_fp32,\n            is_rms_norm,\n        )\n        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)\n        ctx.prenorm = prenorm\n        ctx.dropout_p = dropout_p\n        ctx.has_x1 = x1 is not None\n        ctx.has_residual = residual is not None\n        ctx.is_rms_norm = is_rms_norm\n        ctx.has_beta = beta0 is not None\n        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)\n        if not return_dmask:\n            return z if not prenorm else (*z, xmat.view(x0.shape))\n        else:\n            dmask0 = (\n                dmask0.view(x0.shape)\n                if dropout_p > 0.0\n                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)\n            )\n      
      dmask1 = (\n                dmask1.view(x0.shape)\n                if dropout_p > 0.0 and x1 is not None\n                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)\n            )\n            ctx.mark_non_differentiable(dmask0)\n            ctx.mark_non_differentiable(dmask1)\n            return (\n                (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)\n            )\n\n    @staticmethod\n    def backward(ctx, dz0, dz1, *args):\n        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!\n        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None\n        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None\n        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors\n        dropout_p = ctx.dropout_p\n        has_x1 = ctx.has_x1\n        has_residual = ctx.has_residual\n        (\n            dx0mat,\n            dx1mat,\n            dresidualmat,\n            dgamma0,\n            dbeta0,\n            dgamma1,\n            dbeta1,\n        ) = _dropout_add_layer_norm_parallel_residual_backward(\n            dz0,\n            dz1,\n            dx,\n            x,\n            dmask0,\n            dmask1,\n            mu,\n            rsigma,\n            gamma0,\n            gamma1,\n            dropout_p,\n            has_x1,\n            has_residual,\n            ctx.is_rms_norm,\n        )\n        dx0 = dx0mat.view(x.shape)\n        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None\n        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None\n        return (\n            dx0,\n            dx1,\n            dresidual,\n            dgamma0,\n            dbeta0 if ctx.has_beta else None,\n            dgamma1,\n            dbeta1 if ctx.has_beta else None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n        )\n\n\ndef layer_norm(x, 
weight, bias, epsilon):\n    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)\n\n\ndef dropout_add_layer_norm(\n    x0,\n    residual,\n    weight,\n    bias,\n    dropout_p,\n    epsilon,\n    rowscale=None,\n    layerscale=None,\n    prenorm=False,\n    residual_in_fp32=False,\n    return_dropout_mask=False,\n):\n    \"\"\"residual_in_fp32 only has an effect if residual is None.\n    Otherwise residual dtype is residual.dtype.\n    \"\"\"\n    return DropoutAddLayerNormFn.apply(\n        x0,\n        residual,\n        weight,\n        bias,\n        rowscale,\n        layerscale,\n        dropout_p,\n        epsilon,\n        residual_in_fp32,\n        prenorm,\n        False,\n        return_dropout_mask,\n    )\n\n\ndef dropout_add_layer_norm_subset(\n    x0,\n    residual,\n    weight,\n    bias,\n    dropout_p,\n    epsilon,\n    layerscale=None,\n    x0_subset=None,\n    out_subset=None,\n    rowscale_const=1.0,\n    out_numrows=0,\n    prenorm=False,\n    residual_in_fp32=False,\n    return_dropout_mask=False,\n):\n    \"\"\"residual_in_fp32 only has an effect if residual is None.\n    Otherwise residual dtype is residual.dtype.\n    \"\"\"\n    return DropoutAddLayerNormSubsetFn.apply(\n        x0,\n        residual,\n        weight,\n        bias,\n        layerscale,\n        x0_subset,\n        out_subset,\n        dropout_p,\n        epsilon,\n        rowscale_const,\n        out_numrows,\n        residual_in_fp32,\n        prenorm,\n        False,\n        return_dropout_mask,\n    )\n\n\ndef dropout_add_layer_norm_parallel_residual(\n    x0,\n    x1,\n    residual,\n    weight0,\n    bias0,\n    weight1,\n    bias1,\n    dropout_p,\n    epsilon,\n    prenorm=False,\n    residual_in_fp32=False,\n    return_dropout_mask=False,\n):\n    \"\"\"residual_in_fp32 only has an effect if residual is None.\n    Otherwise residual dtype is residual.dtype.\n    \"\"\"\n    return 
DropoutAddLayerNormParallelResidualFn.apply(\n        x0,\n        x1,\n        residual,\n        weight0,\n        bias0,\n        weight1,\n        bias1,\n        dropout_p,\n        epsilon,\n        residual_in_fp32,\n        prenorm,\n        False,\n        return_dropout_mask,\n    )\n\n\nclass DropoutAddLayerNorm(torch.nn.Module):\n    def __init__(\n        self,\n        hidden_size,\n        prenorm=False,\n        p=0.0,\n        eps=1e-5,\n        residual_in_fp32=False,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.prenorm = prenorm\n        self.p = p\n        self.eps = eps\n        self.residual_in_fp32 = residual_in_fp32\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        init.ones_(self.weight)\n        init.zeros_(self.bias)\n\n    def forward(self, x0, residual=None):\n        return dropout_add_layer_norm(\n            x0,\n            residual,\n            self.weight,\n            self.bias,\n            self.p if self.training else 0.0,\n            self.eps,\n            prenorm=self.prenorm,\n            residual_in_fp32=self.residual_in_fp32,\n        )\n"
  },
  {
    "path": "flash_attn/ops/rms_norm.py",
    "content": "# Copyright (c) 2022, Tri Dao.\n# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py\n\nimport torch\nfrom torch.nn import init\n\nfrom flash_attn.ops.layer_norm import (\n    DropoutAddLayerNormFn,\n    DropoutAddLayerNormParallelResidualFn,\n    DropoutAddLayerNormSubsetFn,\n)\n\n\ndef rms_norm(x, weight, epsilon):\n    return DropoutAddLayerNormFn.apply(\n        x, None, weight, None, None, None, 0.0, epsilon, False, False, True\n    )\n\n\ndef dropout_add_rms_norm(\n    x0,\n    residual,\n    weight,\n    bias,\n    dropout_p,\n    epsilon,\n    rowscale=None,\n    layerscale=None,\n    prenorm=False,\n    residual_in_fp32=False,\n    return_dropout_mask=False,\n):\n    \"\"\"residual_in_fp32 only has an effect if residual is None.\n    Otherwise residual dtype is residual.dtype.\n    \"\"\"\n    return DropoutAddLayerNormFn.apply(\n        x0,\n        residual,\n        weight,\n        bias,\n        rowscale,\n        layerscale,\n        dropout_p,\n        epsilon,\n        residual_in_fp32,\n        prenorm,\n        True,\n        return_dropout_mask,\n    )\n\n\ndef dropout_add_rms_norm_subset(\n    x0,\n    residual,\n    weight,\n    bias,\n    dropout_p,\n    epsilon,\n    layerscale=None,\n    x0_subset=None,\n    out_subset=None,\n    rowscale_const=1.0,\n    out_numrows=0,\n    prenorm=False,\n    residual_in_fp32=False,\n    return_dropout_mask=False,\n):\n    \"\"\"residual_in_fp32 only has an effect if residual is None.\n    Otherwise residual dtype is residual.dtype.\n    \"\"\"\n    return DropoutAddLayerNormSubsetFn.apply(\n        x0,\n        residual,\n        weight,\n        bias,\n        layerscale,\n        x0_subset,\n        out_subset,\n        dropout_p,\n        epsilon,\n        rowscale_const,\n        out_numrows,\n        residual_in_fp32,\n        prenorm,\n        True,\n        return_dropout_mask,\n    )\n\n\ndef dropout_add_rms_norm_parallel_residual(\n    
x0,\n    x1,\n    residual,\n    weight0,\n    bias0,\n    weight1,\n    bias1,\n    dropout_p,\n    epsilon,\n    prenorm=False,\n    residual_in_fp32=False,\n    return_dropout_mask=False,\n):\n    \"\"\"residual_in_fp32 only has an effect if residual is None.\n    Otherwise residual dtype is residual.dtype.\n    \"\"\"\n    return DropoutAddLayerNormParallelResidualFn.apply(\n        x0,\n        x1,\n        residual,\n        weight0,\n        bias0,\n        weight1,\n        bias1,\n        dropout_p,\n        epsilon,\n        residual_in_fp32,\n        prenorm,\n        True,\n        return_dropout_mask,\n    )\n\n\nclass RMSNorm(torch.nn.Module):\n    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.eps = eps\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.register_parameter(\"bias\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        init.ones_(self.weight)\n\n    def forward(self, x):\n        return rms_norm(x, self.weight, self.eps)\n\n\nclass DropoutAddRMSNorm(torch.nn.Module):\n    def __init__(\n        self,\n        hidden_size,\n        prenorm=False,\n        p=0.0,\n        eps=1e-5,\n        residual_in_fp32=False,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.prenorm = prenorm\n        self.p = p\n        self.eps = eps\n        self.residual_in_fp32 = residual_in_fp32\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.register_parameter(\"bias\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        init.ones_(self.weight)\n\n    def forward(self, x0, residual=None):\n        return dropout_add_rms_norm(\n            x0,\n            residual,\n   
         self.weight,\n            None,\n            self.p if self.training else 0.0,\n            self.eps,\n            prenorm=self.prenorm,\n            residual_in_fp32=self.residual_in_fp32,\n        )\n"
  },
  {
    "path": "flash_attn/ops/triton/__init__.py",
    "content": "\n"
  },
  {
    "path": "flash_attn/ops/triton/cross_entropy.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nfrom typing import Tuple, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\n\nimport triton\nimport triton.language as tl\n\n# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for\n# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent\n# version of PyTorch. The following 2 lines are for backward compatibility with\n# older PyTorch.\nif \"all_gather_into_tensor\" not in dir(torch.distributed):\n    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base\n\n\n@triton.heuristics(\n    {\n        \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n    }\n)\n@triton.jit\ndef cross_entropy_fwd_kernel(\n    loss_ptr,  # data ptrs\n    lse_ptr,\n    z_loss_ptr,\n    logits_ptr,\n    labels_ptr,\n    smoothing,\n    logit_scale,\n    lse_square_scale,\n    ignore_index,\n    total_classes,\n    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes\n    n_cols,  # shapes\n    logits_row_stride,  # strides\n    BLOCK_SIZE: tl.constexpr,\n    HAS_SMOOTHING: tl.constexpr,\n    # if SPLIT (e.g. 
tensor parallel), don't include the LSE in the loss since it's not the final LSE\n    SPLIT: tl.constexpr,\n    PRECOMPUTED_LSE: tl.constexpr,  # If LSE is already computed (also no smoothing and logit_scale == 1.0)\n):\n    row_idx = tl.program_id(0)\n    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n    sum_logits = 0.0  # For smoothing\n    if not PRECOMPUTED_LSE:\n        # Statistics for online softmax\n        m_i = -float(\"inf\")\n        l_i = 0.0\n        for col_offset in range(0, n_cols, BLOCK_SIZE):\n            cols = col_offset + tl.arange(0, BLOCK_SIZE)\n            logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float(\"inf\")).to(\n                tl.float32\n            ) * logit_scale\n            if HAS_SMOOTHING:\n                sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))\n            m_i_new = tl.maximum(m_i, tl.max(logits))\n            l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))\n            m_i = m_i_new\n        lse = tl.log(l_i) + m_i\n        tl.store(lse_ptr + row_idx, lse)\n    else:\n        lse = tl.load(lse_ptr + row_idx)\n    label_idx = tl.load(labels_ptr + row_idx)\n    if label_idx == ignore_index:\n        loss = 0.0\n        z_loss = 0.0\n    else:\n        label_idx -= class_start_idx\n        if label_idx >= 0 and label_idx < n_cols:\n            logits_label = tl.load(logits_ptr + label_idx) * logit_scale\n            if HAS_SMOOTHING:\n                loss = (\n                    (lse if not SPLIT else 0.0)\n                    - smoothing * sum_logits / total_classes\n                    - (1 - smoothing) * logits_label\n                )\n            else:\n                loss = (lse if not SPLIT else 0.0) - logits_label\n        else:\n            # If label is out of bounds, we set the CE loss to 0.0. 
But we still want the smoothing loss\n            if HAS_SMOOTHING:\n                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)\n            else:\n                loss = 0.0\n        if not SPLIT:\n            z_loss = lse_square_scale * lse * lse\n            loss += z_loss\n        else:\n            z_loss = 0.0\n    tl.store(loss_ptr + row_idx, loss)\n    if not SPLIT:\n        tl.store(z_loss_ptr + row_idx, z_loss)\n\n\n@triton.heuristics(\n    {\n        \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n    }\n)\n@triton.jit\ndef cross_entropy_bwd_kernel(\n    dlogits_ptr,  # data ptrs\n    dloss_ptr,\n    logits_ptr,\n    lse_ptr,\n    labels_ptr,\n    smoothing,\n    logit_scale,\n    lse_square_scale,\n    ignore_index,\n    total_classes,\n    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes\n    n_cols,  # shapes\n    logits_row_stride,  # strides\n    dlogits_row_stride,\n    dloss_row_stride,\n    BLOCK_SIZE: tl.constexpr,\n    HAS_SMOOTHING: tl.constexpr,\n):\n    row_idx = tl.program_id(0)\n    col_block_idx = tl.program_id(1)\n    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)\n    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    label_idx = tl.load(labels_ptr + row_idx)\n    if label_idx != ignore_index:\n        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)\n    else:\n        dloss = 0.0\n    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float(\"inf\")).to(\n        tl.float32\n    ) * logit_scale\n    lse = tl.load(lse_ptr + row_idx)\n    probs = tl.exp(logits - lse)\n    probs += 2.0 * lse_square_scale * lse * probs\n    label_idx -= class_start_idx\n    if HAS_SMOOTHING:\n        smooth_positive = 1.0 - smoothing\n        smooth_negative = smoothing / total_classes\n        probs = 
tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative\n    else:\n        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)\n    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)\n\n\nclass CrossEntropyLoss(torch.autograd.Function):\n\n    @staticmethod\n    def forward(\n        ctx,\n        logits,\n        labels,\n        precomputed_lse=None,\n        smoothing=0.0,\n        logit_scale=1.0,\n        lse_square_scale=0.0,\n        ignore_index=-100,\n        inplace_backward=False,\n        process_group=None,\n    ):\n        # For some reason Triton generates wrong code when labels has dtype long and its address\n        # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index.\n        if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:\n            labels = F.pad(labels, (0, 1))[..., :-1]\n            assert labels.data_ptr() % 16 == 0\n        assert logit_scale > 0.0\n        n_rows, n_cols = logits.shape\n        assert labels.shape == (n_rows,)\n        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)\n        total_classes = world_size * n_cols\n        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)\n        class_start_idx = rank * n_cols\n        use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0\n\n        if logits.stride(-1) != 1:\n            logits = logits.contiguous()\n        MAX_BLOCK_SIZE = 16 * 1024\n        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)\n        num_warps = (\n            4\n            if BLOCK_SIZE < 2048\n            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))\n        )\n        losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)\n        if use_precomputed_lse:\n            assert precomputed_lse.shape == 
(n_rows,)\n            lse = precomputed_lse.contiguous()\n        else:\n            lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)\n        z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)\n        # Need this, otherwise Triton tries to launch from cuda:0 and we get\n        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)\n        with torch.cuda.device(logits.device.index):\n            cross_entropy_fwd_kernel[(n_rows,)](\n                losses,  # data ptrs\n                lse,\n                z_losses,\n                logits,\n                labels,\n                smoothing,\n                logit_scale,\n                lse_square_scale,\n                ignore_index,\n                total_classes,\n                class_start_idx,\n                n_cols,  # shapes\n                logits.stride(0),  # strides\n                BLOCK_SIZE=BLOCK_SIZE,  # constants\n                SPLIT=world_size > 1,\n                PRECOMPUTED_LSE=use_precomputed_lse,\n                num_warps=num_warps,\n            )\n\n        if world_size > 1:\n            # If there's no smoothing, if labels are in the vocab of this partition, losses contains\n            # - predicted logit, and 0 otherwise.\n            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains\n            # -0.9 * predicted logit - 0.1 * sum logit / total_classes.\n            # For labels not in the vocab of this partition, losses contains\n            # -0.1 * sum logit / total_classes.\n            if world_size > 1:\n                lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)\n                torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)\n                handle_losses = torch.distributed.all_reduce(\n                    losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True\n       
         )\n                lse = torch.logsumexp(lse_allgather, dim=0)\n                handle_losses.wait()\n            # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,\n            # we just have to add the (global) lse.\n            # If there's smoothing=0.1, the total losses are\n            # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.\n            # Again, we just have to add the (global) lse.\n            losses += lse\n            if lse_square_scale != 0.0:\n                z_losses = lse_square_scale * lse.square()\n                z_losses.masked_fill_(labels == ignore_index, 0.0)\n                losses += z_losses\n            else:\n                z_losses = torch.zeros_like(losses)\n            losses.masked_fill_(labels == ignore_index, 0.0)\n\n        ctx.save_for_backward(logits, lse, labels)\n        ctx.mark_non_differentiable(z_losses)\n        ctx.smoothing = smoothing\n        ctx.logit_scale = logit_scale\n        ctx.lse_square_scale = lse_square_scale\n        ctx.ignore_index = ignore_index\n        ctx.total_classes = total_classes\n        ctx.class_start_idx = class_start_idx\n        ctx.inplace_backward = inplace_backward\n        return losses, z_losses\n\n    @staticmethod\n    def backward(ctx, grad_losses, grad_z_losses):\n        del grad_z_losses  # z_losses are only for logging.\n\n        logits, lse, labels = ctx.saved_tensors\n        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)\n        n_rows, n_cols = logits.shape\n        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)\n        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)\n        grid = lambda META: (n_rows, triton.cdiv(n_cols, META[\"BLOCK_SIZE\"]))  # noqa\n        # Need this, otherwise Triton tries to launch from cuda:0 and we get\n        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)\n        with 
torch.cuda.device(logits.device.index):\n            cross_entropy_bwd_kernel[grid](\n                dlogits,  # data ptrs\n                grad_losses,\n                logits,\n                lse,\n                labels,\n                ctx.smoothing,\n                ctx.logit_scale,\n                ctx.lse_square_scale,\n                ctx.ignore_index,\n                ctx.total_classes,\n                ctx.class_start_idx,\n                n_cols,  # shapes\n                logits.stride(0),  # strides\n                dlogits.stride(0),\n                grad_losses.stride(0),\n                BLOCK_SIZE=BLOCK_SIZE,  # constants\n                num_warps=num_warps,\n            )\n        return dlogits, None, None, None, None, None, None, None, None, None\n\n\ndef cross_entropy_loss(\n    logits: torch.Tensor,\n    labels: torch.Tensor,\n    precomputed_lse: Optional[torch.Tensor] = None,\n    label_smoothing: float = 0.0,\n    logit_scale: float = 1.0,\n    lse_square_scale: float = 0.0,\n    ignore_index=-100,\n    inplace_backward: bool = False,\n    process_group=None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Arguments:\n        logits: (batch, vocab_size)\n        labels: (batch,)\n        label_smoothing: float\n        logit_scale: float. Multiply logits by this scale before calculating the loss.\n        lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.\n            This is also referred to as \"z-loss\".\n        ignore_index: int. If labels == ignore_index, the loss is set to 0.0.\n        inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.\n            This saves memory.\n        process_group: if not None, we're doing Tensor Parallel: each process is responsible for\n            one part of the vocab. 
The loss will be aggregated across processes.\n    Returns:\n        losses: (batch,), float\n        z_losses: (batch,), float\n    \"\"\"\n    return CrossEntropyLoss.apply(\n        logits,\n        labels,\n        precomputed_lse,\n        label_smoothing,\n        logit_scale,\n        lse_square_scale,\n        ignore_index,\n        inplace_backward,\n        process_group,\n    )\n"
  },
  {
    "path": "flash_attn/ops/triton/k_activations.py",
    "content": "# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py\n# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.\n#\n# This source code is licensed under the BSD license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport math\nfrom enum import Enum\nfrom typing import Optional\n\nimport triton\nimport triton.language as tl\n\n_sqrt2pi = math.sqrt(2.0 / math.pi)\n_sqrt1_2 = math.sqrt(1.0 / 2)\n_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)\n\n\nclass Activation(str, Enum):\n    SquaredReLU = \"squared_relu\"\n    GeLU = \"gelu\"\n    GeLUApprox = \"gelu_approx\"\n    LeakyReLU = \"leaky_relu\"\n    ReLU = \"relu\"\n\n\ndef get_triton_activation_kernel(activation: Optional[Activation]):\n    return (\n        {\n            Activation.ReLU: relu,\n            Activation.LeakyReLU: leaky_relu,\n            Activation.GeLU: gelu,\n            Activation.GeLUApprox: gelu_approx,\n            Activation.SquaredReLU: squared_relu,\n        }[activation]\n        if activation\n        else None\n    )\n\n\ndef get_triton_activation_bwd_kernel(activation: Optional[Activation]):\n    return (\n        {\n            Activation.ReLU: relu_grad,\n            Activation.LeakyReLU: leaky_relu_grad,\n            Activation.GeLU: gelu_grad,\n            Activation.GeLUApprox: gelu_approx_grad,\n            Activation.SquaredReLU: squared_relu_grad,\n        }[activation]\n        if activation\n        else None\n    )\n\n\n@triton.jit\ndef tanh(x):\n    # Tanh is just a scaled sigmoid\n    return 2 * tl.sigmoid(2 * x) - 1\n\n\n@triton.jit\ndef cosh(x):\n    exp_x = tl.exp(x)\n    return (exp_x + 1.0 / exp_x) * 0.5\n\n\n# a Triton implementation of the most used activations\n# See for instance http://arxiv.org/abs/1606.08415 for an overview\n\n# ReLU\n@triton.jit\ndef relu(x):\n    \"\"\"\n    ReLU_ activation function\n\n    .. 
_ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html\n    \"\"\"\n    zero = 0.0\n    return tl.where(x >= 0, x, zero.to(x.dtype))\n\n\n@triton.jit\ndef relu_grad(x):\n    # ReLU is different from other activations\n    # in that it does not require the input to retrospectively compute its gradient\n    # here the input is the downstream gradient, and we return the upstream gradient directly\n    zero = 0.0\n    one = 1.0\n    return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))\n\n\n@triton.jit\ndef squared_relu(x):\n    \"\"\"\n    Squared ReLU activation, as proposed in the Primer_ paper.\n\n    .. _Primer: https://arxiv.org/abs/2109.08668\n    \"\"\"\n    x_ = relu(x)\n    return (x_ * x_).to(x.dtype)\n\n\n@triton.jit\ndef squared_relu_grad(x):\n    return tl.where(x >= 0, 2.0 * x, 0.0)\n\n\n# Leaky ReLU\n@triton.jit\ndef leaky_relu(x):\n    \"\"\"\n    LeakyReLU_ activation\n\n    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html\n    \"\"\"\n    scale = 0.01 + 0.0\n    scale = scale.to(x.dtype)\n    return tl.where(x >= 0, x, scale * x)\n\n\n@triton.jit\ndef leaky_relu_grad(x):\n    min_grad = 0.01\n    max_grad = 1\n\n    min_grad = min_grad.to(x.dtype)\n    max_grad = max_grad.to(x.dtype)\n\n    return tl.where(x >= 0, max_grad, min_grad)\n\n\n@triton.jit\ndef gelu(x):\n    \"\"\"Gaussian Error Linear Unit (GELU)\"\"\"\n    return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))\n\n\n@triton.jit\ndef gelu_grad(x):\n    cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))\n    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization\n    return cdf + x * pdf\n\n\n@triton.jit\ndef gelu_approx(x):\n    \"\"\"\n    GeLU_ activation - Gaussian error linear unit, with tanh approximation\n\n    .. 
_GeLU: https://arxiv.org/pdf/1606.08415.pdf\n    \"\"\"\n    return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))\n\n\n@triton.jit\ndef gelu_approx_grad(x):\n    # CREDITS: Fast implementation proposed in\n    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30\n    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))\n    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (\n        1 + tanh_out\n    )\n"
  },
  {
    "path": "flash_attn/ops/triton/layer_norm.py",
    "content": "# Copyright (c) 2024, Tri Dao.\n# Implement dropout + residual + layer_norm / rms_norm.\n\n# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.\n# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.\n# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.\n\nimport math\nfrom typing import Optional, List\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nimport triton\nimport triton.language as tl\n\nfrom flash_attn.utils.torch import custom_fwd, custom_bwd\nfrom flash_attn.utils.library import triton_op\n\n\ndef maybe_contiguous_lastdim(x):\n    return x.contiguous() if x is not None and x.stride(-1) != 1 else x\n\n\ndef maybe_contiguous(x):\n    return x.contiguous() if x is not None else None\n\n\ndef triton_autotune_configs():\n    # Return configs with a valid warp count for the current device\n    configs = []\n    # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024\n    max_threads_per_block = 1024\n    # Default to warp size 32 if not defined by device\n    warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), \"warp_size\", 32)\n    # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit\n    return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32]\n            if warp_count * warp_size <= max_threads_per_block]\n    # return [triton.Config({}, num_warps=8)]\n\n\ndef layer_norm_ref(\n    x,\n    weight,\n    bias,\n    residual=None,\n    x1=None,\n    weight1=None,\n    bias1=None,\n    eps=1e-6,\n    dropout_p=0.0,\n    rowscale=None,\n    prenorm=False,\n    zero_centered_weight=False,\n    dropout_mask=None,\n    dropout_mask1=None,\n    
upcast=False,\n):\n    dtype = x.dtype\n    if upcast:\n        x = x.float()\n        weight = weight.float()\n        bias = bias.float() if bias is not None else None\n        residual = residual.float() if residual is not None else residual\n        x1 = x1.float() if x1 is not None else None\n        weight1 = weight1.float() if weight1 is not None else None\n        bias1 = bias1.float() if bias1 is not None else None\n    if zero_centered_weight:\n        weight = weight + 1.0\n        if weight1 is not None:\n            weight1 = weight1 + 1.0\n    if x1 is not None:\n        assert rowscale is None, \"rowscale is not supported with parallel LayerNorm\"\n    if rowscale is not None:\n        x = x * rowscale[..., None]\n    if dropout_p > 0.0:\n        if dropout_mask is not None:\n            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)\n        else:\n            x = F.dropout(x, p=dropout_p)\n        if x1 is not None:\n            if dropout_mask1 is not None:\n                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)\n            else:\n                x1 = F.dropout(x1, p=dropout_p)\n    if x1 is not None:\n        x = x + x1\n    if residual is not None:\n        x = (x + residual).to(x.dtype)\n    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(\n        dtype\n    )\n    if weight1 is None:\n        return out if not prenorm else (out, x)\n    else:\n        out1 = F.layer_norm(\n            x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps\n        ).to(dtype)\n        return (out, out1) if not prenorm else (out, out1, x)\n\n\ndef rms_norm_ref(\n    x,\n    weight,\n    bias,\n    residual=None,\n    x1=None,\n    weight1=None,\n    bias1=None,\n    eps=1e-6,\n    dropout_p=0.0,\n    rowscale=None,\n    prenorm=False,\n    zero_centered_weight=False,\n    dropout_mask=None,\n    dropout_mask1=None,\n    upcast=False,\n):\n    dtype = x.dtype\n    if 
upcast:\n        x = x.float()\n        weight = weight.float()\n        bias = bias.float() if bias is not None else None\n        residual = residual.float() if residual is not None else residual\n        x1 = x1.float() if x1 is not None else None\n        weight1 = weight1.float() if weight1 is not None else None\n        bias1 = bias1.float() if bias1 is not None else None\n    if zero_centered_weight:\n        weight = weight + 1.0\n        if weight1 is not None:\n            weight1 = weight1 + 1.0\n    if x1 is not None:\n        assert rowscale is None, \"rowscale is not supported with parallel LayerNorm\"\n    if rowscale is not None:\n        x = x * rowscale[..., None]\n    if dropout_p > 0.0:\n        if dropout_mask is not None:\n            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)\n        else:\n            x = F.dropout(x, p=dropout_p)\n        if x1 is not None:\n            if dropout_mask1 is not None:\n                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)\n            else:\n                x1 = F.dropout(x1, p=dropout_p)\n    if x1 is not None:\n        x = x + x1\n    if residual is not None:\n        x = (x + residual).to(x.dtype)\n    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)\n    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)\n    if weight1 is None:\n        return out if not prenorm else (out, x)\n    else:\n        out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(\n            dtype\n        )\n        return (out, out1) if not prenorm else (out, out1, x)\n\n\n@triton.autotune(\n    configs=triton_autotune_configs(),\n    key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\", \"HAS_X1\", \"HAS_W1\", \"HAS_B1\"],\n)\n# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel\n# @triton.heuristics({\"HAS_BIAS\": lambda args: 
args[\"B\"] is not None})\n# @triton.heuristics({\"HAS_RESIDUAL\": lambda args: args[\"RESIDUAL\"] is not None})\n# @triton.heuristics({\"HAS_X1\": lambda args: args[\"X1\"] is not None})\n# @triton.heuristics({\"HAS_W1\": lambda args: args[\"W1\"] is not None})\n# @triton.heuristics({\"HAS_B1\": lambda args: args[\"B1\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n    X,  # pointer to the input\n    Y,  # pointer to the output\n    W,  # pointer to the weights\n    B,  # pointer to the biases\n    RESIDUAL,  # pointer to the residual\n    X1,\n    W1,\n    B1,\n    Y1,\n    RESIDUAL_OUT,  # pointer to the residual\n    ROWSCALE,\n    SEEDS,  # Dropout seeds for each row\n    DROPOUT_MASK,\n    DROPOUT_MASK1,\n    Mean,  # pointer to the mean\n    Rstd,  # pointer to the 1/std\n    stride_x_row,  # how much to increase the pointer when moving by 1 row\n    stride_y_row,\n    stride_res_row,\n    stride_res_out_row,\n    stride_x1_row,\n    stride_y1_row,\n    M,  # number of rows in X\n    N,  # number of columns in X\n    eps,  # epsilon to avoid division by zero\n    dropout_p,  # Dropout probability\n    zero_centered_weight,  # If true, add 1.0 to the weight\n    IS_RMS_NORM: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    HAS_RESIDUAL: tl.constexpr,\n    STORE_RESIDUAL_OUT: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n    HAS_DROPOUT: tl.constexpr,\n    STORE_DROPOUT_MASK: tl.constexpr,\n    HAS_ROWSCALE: tl.constexpr,\n    HAS_X1: tl.constexpr,\n    HAS_W1: tl.constexpr,\n    HAS_B1: tl.constexpr,\n):\n    # Map the program id to the row of X and Y it should compute.\n    row = tl.program_id(0)\n    X += row * stride_x_row\n    Y += row * stride_y_row\n    if HAS_RESIDUAL:\n        RESIDUAL += row * stride_res_row\n    if STORE_RESIDUAL_OUT:\n        RESIDUAL_OUT += row * stride_res_out_row\n    if HAS_X1:\n        X1 += row * stride_x1_row\n    if HAS_W1:\n        Y1 += row * stride_y1_row\n    # Compute mean and variance\n    cols = 
tl.arange(0, BLOCK_N)\n    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n    if HAS_ROWSCALE:\n        rowscale = tl.load(ROWSCALE + row).to(tl.float32)\n        x *= rowscale\n    if HAS_DROPOUT:\n        # Compute dropout mask\n        # 7 rounds is good enough, and reduces register pressure\n        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)\n        if STORE_DROPOUT_MASK:\n            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)\n    if HAS_X1:\n        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)\n        if HAS_ROWSCALE:\n            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)\n            x1 *= rowscale\n        if HAS_DROPOUT:\n            # Compute dropout mask\n            # 7 rounds is good enough, and reduces register pressure\n            keep_mask = (\n                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n            )\n            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)\n            if STORE_DROPOUT_MASK:\n                tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N)\n        x += x1\n    if HAS_RESIDUAL:\n        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n        x += residual\n    if STORE_RESIDUAL_OUT:\n        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n    if not IS_RMS_NORM:\n        mean = tl.sum(x, axis=0) / N\n        tl.store(Mean + row, mean)\n        xbar = tl.where(cols < N, x - mean, 0.0)\n        var = tl.sum(xbar * xbar, axis=0) / N\n    else:\n        xbar = tl.where(cols < N, x, 0.0)\n        var = tl.sum(xbar * xbar, axis=0) / N\n    rstd = 1 / tl.sqrt(var + eps)\n    tl.store(Rstd + row, rstd)\n    # Normalize and apply linear transformation\n    mask = cols < N\n    w = tl.load(W + cols, mask=mask).to(tl.float32)\n    if zero_centered_weight:\n   
     w += 1.0\n    if HAS_BIAS:\n        b = tl.load(B + cols, mask=mask).to(tl.float32)\n    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n    y = x_hat * w + b if HAS_BIAS else x_hat * w\n    # Write output\n    tl.store(Y + cols, y, mask=mask)\n    if HAS_W1:\n        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)\n        if zero_centered_weight:\n            w1 += 1.0\n        if HAS_B1:\n            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)\n        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1\n        tl.store(Y1 + cols, y1, mask=mask)\n\n\ndef _layer_norm_fwd(\n    x: Tensor,\n    weight: Tensor,\n    bias: Tensor,\n    eps: float,\n    residual: Optional[Tensor] = None,\n    x1: Optional[Tensor] = None,\n    weight1: Optional[Tensor] = None,\n    bias1: Optional[Tensor] = None,\n    dropout_p: float = 0.0,\n    rowscale: Optional[Tensor] = None,\n    out_dtype: Optional[torch.dtype] = None,\n    residual_dtype: Optional[torch.dtype] = None,\n    zero_centered_weight: bool = False,\n    is_rms_norm: bool = False,\n    return_dropout_mask: bool = False,\n    out: Optional[Tensor] = None,\n    residual_out: Optional[Tensor] = None\n) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):\n    # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library\n    # and torch.compile unhappy. 
Also allocate memory for out and residual_out if they are None\n    # so that _layer_norm_fwd_impl doesn't have to return them.\n    if out is None:\n        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n    if residual is not None:\n        residual_dtype = residual.dtype\n    if residual_out is None and (\n        residual is not None\n        or (residual_dtype is not None and residual_dtype != x.dtype)\n        or dropout_p > 0.0\n        or rowscale is not None\n        or x1 is not None\n    ):\n        residual_out = torch.empty_like(\n            x, dtype=residual_dtype if residual_dtype is not None else x.dtype\n        )\n    else:\n        residual_out = None\n    y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl(\n        x,\n        weight,\n        bias,\n        eps,\n        out,\n        residual=residual,\n        x1=x1,\n        weight1=weight1,\n        bias1=bias1,\n        dropout_p=dropout_p,\n        rowscale=rowscale,\n        zero_centered_weight=zero_centered_weight,\n        is_rms_norm=is_rms_norm,\n        return_dropout_mask=return_dropout_mask,\n        residual_out=residual_out,\n    )\n    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0\n    if residual_out is None:\n        residual_out = x\n    return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1\n\n\n# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema\n# since we're returning a tuple of tensors\n@triton_op(\"flash_attn::layer_norm_fwd_impl\", mutates_args={\"out\", \"residual_out\"},\n           schema=\"(Tensor x, Tensor weight, Tensor bias, float eps, Tensor(a!) out, Tensor? residual, Tensor? x1, Tensor? weight1, Tensor? bias1, float dropout_p, Tensor? rowscale, bool zero_centered_weight, bool is_rms_norm, bool return_dropout_mask, Tensor(a!)? 
residual_out) -> (Tensor y1, Tensor mean, Tensor rstd, Tensor seeds, Tensor dropout_mask, Tensor dropout_mask1)\")\ndef _layer_norm_fwd_impl(\n    x: Tensor,\n    weight: Tensor,\n    bias: Tensor,\n    eps: float,\n    out: Tensor,\n    residual: Optional[Tensor] = None,\n    x1: Optional[Tensor] = None,\n    weight1: Optional[Tensor] = None,\n    bias1: Optional[Tensor] = None,\n    dropout_p: float = 0.0,\n    rowscale: Optional[Tensor] = None,\n    zero_centered_weight: bool = False,\n    is_rms_norm: bool = False,\n    return_dropout_mask: bool = False,\n    residual_out: Optional[Tensor] = None\n) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):\n    M, N = x.shape\n    assert x.stride(-1) == 1\n    if residual is not None:\n        assert residual.stride(-1) == 1\n        assert residual.shape == (M, N)\n    assert weight.shape == (N,)\n    assert weight.stride(-1) == 1\n    if bias is not None:\n        assert bias.stride(-1) == 1\n        assert bias.shape == (N,)\n    if x1 is not None:\n        assert x1.shape == x.shape\n        assert rowscale is None\n        assert x1.stride(-1) == 1\n    if weight1 is not None:\n        assert weight1.shape == (N,)\n        assert weight1.stride(-1) == 1\n    if bias1 is not None:\n        assert bias1.shape == (N,)\n        assert bias1.stride(-1) == 1\n    if rowscale is not None:\n        assert rowscale.is_contiguous()\n        assert rowscale.shape == (M,)\n    assert out.shape == x.shape\n    assert out.stride(-1) == 1\n    if residual_out is not None:\n        assert residual_out.shape == x.shape\n        assert residual_out.stride(-1) == 1\n    if weight1 is not None:\n        y1 = torch.empty_like(out)\n        assert y1.stride(-1) == 1\n    else:\n        y1 = None\n    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n    if dropout_p > 0.0:\n        seeds = torch.randint(\n            2**32, 
(M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64\n        )\n    else:\n        seeds = None\n    if return_dropout_mask and dropout_p > 0.0:\n        dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool)\n        if x1 is not None:\n            dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool)\n        else:\n            dropout_mask1 = None\n    else:\n        dropout_mask, dropout_mask1 = None, None\n    # Less than 64KB per feature: enqueue fused kernel\n    MAX_FUSED_SIZE = 65536 // x.element_size()\n    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n    if N > BLOCK_N:\n        raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n    with torch.cuda.device(x.device.index):\n        torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](\n            x,\n            out,\n            weight,\n            bias,\n            residual,\n            x1,\n            weight1,\n            bias1,\n            y1,\n            residual_out,\n            rowscale,\n            seeds,\n            dropout_mask,\n            dropout_mask1,\n            mean,\n            rstd,\n            x.stride(0),\n            out.stride(0),\n            residual.stride(0) if residual is not None else 0,\n            residual_out.stride(0) if residual_out is not None else 0,\n            x1.stride(0) if x1 is not None else 0,\n            y1.stride(0) if y1 is not None else 0,\n            M,\n            N,\n            eps,\n            dropout_p,\n            # Passing a bool makes torch inductor very unhappy since it then tries to compare to int_max\n            int(zero_centered_weight),\n            is_rms_norm,\n            BLOCK_N,\n            residual is not None,\n            residual_out is not None,\n            bias is not None,\n            dropout_p > 0.0,\n            dropout_mask is not None,\n            rowscale is not None,\n            HAS_X1=x1 is not None,\n            
HAS_W1=weight1 is not None,\n            HAS_B1=bias1 is not None,\n        )\n    return y1, mean, rstd, seeds, dropout_mask, dropout_mask1\n\n\n@triton.autotune(\n    configs=triton_autotune_configs(),\n    key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\", \"HAS_DROPOUT\"],\n)\n# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel\n# @triton.heuristics({\"HAS_BIAS\": lambda args: args[\"B\"] is not None})\n# @triton.heuristics({\"HAS_DRESIDUAL\": lambda args: args[\"DRESIDUAL\"] is not None})\n# @triton.heuristics({\"STORE_DRESIDUAL\": lambda args: args[\"DRESIDUAL_IN\"] is not None})\n# @triton.heuristics({\"HAS_ROWSCALE\": lambda args: args[\"ROWSCALE\"] is not None})\n# @triton.heuristics({\"HAS_DY1\": lambda args: args[\"DY1\"] is not None})\n# @triton.heuristics({\"HAS_DX1\": lambda args: args[\"DX1\"] is not None})\n# @triton.heuristics({\"HAS_B1\": lambda args: args[\"DB1\"] is not None})\n# @triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n    X,  # pointer to the input\n    W,  # pointer to the weights\n    B,  # pointer to the biases\n    Y,  # pointer to the output to be recomputed\n    DY,  # pointer to the output gradient\n    DX,  # pointer to the input gradient\n    DW,  # pointer to the partial sum of weights gradient\n    DB,  # pointer to the partial sum of biases gradient\n    DRESIDUAL,\n    W1,\n    DY1,\n    DX1,\n    DW1,\n    DB1,\n    DRESIDUAL_IN,\n    ROWSCALE,\n    SEEDS,\n    Mean,  # pointer to the mean\n    Rstd,  # pointer to the 1/std\n    stride_x_row,  # how much to increase the pointer when moving by 1 row\n    stride_y_row,\n    stride_dy_row,\n    stride_dx_row,\n    stride_dres_row,\n    stride_dy1_row,\n    stride_dx1_row,\n    stride_dres_in_row,\n    M,  # number of rows in X\n    N,  # number of columns in X\n    eps,  # epsilon to avoid division by zero\n    dropout_p,\n   
 zero_centered_weight,\n    rows_per_program,\n    IS_RMS_NORM: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    HAS_DRESIDUAL: tl.constexpr,\n    STORE_DRESIDUAL: tl.constexpr,\n    HAS_BIAS: tl.constexpr,\n    HAS_DROPOUT: tl.constexpr,\n    HAS_ROWSCALE: tl.constexpr,\n    HAS_DY1: tl.constexpr,\n    HAS_DX1: tl.constexpr,\n    HAS_B1: tl.constexpr,\n    RECOMPUTE_OUTPUT: tl.constexpr,\n):\n    # Map the program id to the elements of X, DX, and DY it should compute.\n    row_block_id = tl.program_id(0)\n    row_start = row_block_id * rows_per_program\n    # Do not early exit if row_start >= M, because we need to write DW and DB\n    cols = tl.arange(0, BLOCK_N)\n    mask = cols < N\n    X += row_start * stride_x_row\n    if HAS_DRESIDUAL:\n        DRESIDUAL += row_start * stride_dres_row\n    if STORE_DRESIDUAL:\n        DRESIDUAL_IN += row_start * stride_dres_in_row\n    DY += row_start * stride_dy_row\n    DX += row_start * stride_dx_row\n    if HAS_DY1:\n        DY1 += row_start * stride_dy1_row\n    if HAS_DX1:\n        DX1 += row_start * stride_dx1_row\n    if RECOMPUTE_OUTPUT:\n        Y += row_start * stride_y_row\n    w = tl.load(W + cols, mask=mask).to(tl.float32)\n    if zero_centered_weight:\n        w += 1.0\n    if RECOMPUTE_OUTPUT and HAS_BIAS:\n        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n    if HAS_DY1:\n        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)\n        if zero_centered_weight:\n            w1 += 1.0\n    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n    if HAS_BIAS:\n        db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n    if HAS_DY1:\n        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)\n        if HAS_B1:\n            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)\n    row_end = min((row_block_id + 1) * rows_per_program, M)\n    for row in range(row_start, row_end):\n        # Load data to SRAM\n        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n        dy = tl.load(DY + cols, mask=mask, 
other=0).to(tl.float32)\n        if HAS_DY1:\n            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)\n        if not IS_RMS_NORM:\n            mean = tl.load(Mean + row)\n        rstd = tl.load(Rstd + row)\n        # Compute dx\n        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n        xhat = tl.where(mask, xhat, 0.0)\n        if RECOMPUTE_OUTPUT:\n            y = xhat * w + b if HAS_BIAS else xhat * w\n            tl.store(Y + cols, y, mask=mask)\n        wdy = w * dy\n        dw += dy * xhat\n        if HAS_BIAS:\n            db += dy\n        if HAS_DY1:\n            wdy += w1 * dy1\n            dw1 += dy1 * xhat\n            if HAS_B1:\n                db1 += dy1\n        if not IS_RMS_NORM:\n            c1 = tl.sum(xhat * wdy, axis=0) / N\n            c2 = tl.sum(wdy, axis=0) / N\n            dx = (wdy - (xhat * c1 + c2)) * rstd\n        else:\n            c1 = tl.sum(xhat * wdy, axis=0) / N\n            dx = (wdy - xhat * c1) * rstd\n        if HAS_DRESIDUAL:\n            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)\n            dx += dres\n        # Write dx\n        if STORE_DRESIDUAL:\n            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)\n        if HAS_DX1:\n            if HAS_DROPOUT:\n                keep_mask = (\n                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n                )\n                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)\n            else:\n                dx1 = dx\n            tl.store(DX1 + cols, dx1, mask=mask)\n        if HAS_DROPOUT:\n            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)\n        if HAS_ROWSCALE:\n            rowscale = tl.load(ROWSCALE + row).to(tl.float32)\n            dx *= rowscale\n        tl.store(DX + cols, dx, mask=mask)\n\n        X += stride_x_row\n        if 
HAS_DRESIDUAL:\n            DRESIDUAL += stride_dres_row\n        if STORE_DRESIDUAL:\n            DRESIDUAL_IN += stride_dres_in_row\n        if RECOMPUTE_OUTPUT:\n            Y += stride_y_row\n        DY += stride_dy_row\n        DX += stride_dx_row\n        if HAS_DY1:\n            DY1 += stride_dy1_row\n        if HAS_DX1:\n            DX1 += stride_dx1_row\n    tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n    if HAS_BIAS:\n        tl.store(DB + row_block_id * N + cols, db, mask=mask)\n    if HAS_DY1:\n        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)\n        if HAS_B1:\n            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)\n\n\ndef _layer_norm_bwd(\n    dy: Tensor,\n    x: Tensor,\n    weight: Tensor,\n    bias: Tensor,\n    eps: float,\n    mean: Tensor,\n    rstd: Tensor,\n    dresidual: Optional[Tensor] = None,\n    dy1: Optional[Tensor] = None,\n    weight1: Optional[Tensor] = None,\n    bias1: Optional[Tensor] = None,\n    seeds: Optional[Tensor] = None,\n    dropout_p: float = 0.0,\n    rowscale: Optional[Tensor] = None,\n    has_residual: bool = False,\n    has_x1: bool = False,\n    zero_centered_weight: bool = False,\n    is_rms_norm: bool = False,\n    x_dtype: Optional[torch.dtype] = None,\n    recompute_output: bool = False,\n) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):\n    # Need to wrap to handle the case where dresidual_in or dx1 are aliases of x,\n    # which makes torch.library unhappy\n    dx, dw, db, dresidual_in, dx1, dw1, db1, y = _layer_norm_bwd_impl(\n        dy,\n        x,\n        weight,\n        bias,\n        eps,\n        mean,\n        rstd,\n        dresidual,\n        dy1,\n        weight1,\n        bias1,\n        seeds,\n        dropout_p,\n        rowscale,\n        has_residual,\n        has_x1,\n        zero_centered_weight,\n        is_rms_norm,\n        x_dtype=x_dtype,\n        recompute_output=recompute_output,\n    )\n    # Don't need to compute 
dresidual_in separately in this case\n    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:\n        dresidual_in = dx\n    if has_x1 and dropout_p == 0.0:\n        dx1 = dx\n    return dx, dw, db, dresidual_in, dx1, dw1, db1, y\n\n\n\n@triton_op(\"flash_attn::layer_norm_bwd_impl\", mutates_args={},\n           schema=\"(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)\",\n           allow_decomposition=False,  # Don't let torch.compile trace inside\n           )\ndef _layer_norm_bwd_impl(\n    dy: Tensor,\n    x: Tensor,\n    weight: Tensor,\n    bias: Tensor,\n    eps: float,\n    mean: Tensor,\n    rstd: Tensor,\n    dresidual: Optional[Tensor] = None,\n    dy1: Optional[Tensor] = None,\n    weight1: Optional[Tensor] = None,\n    bias1: Optional[Tensor] = None,\n    seeds: Optional[Tensor] = None,\n    dropout_p: float = 0.0,\n    rowscale: Optional[Tensor] = None,\n    has_residual: bool = False,\n    has_x1: bool = False,\n    zero_centered_weight: bool = False,\n    is_rms_norm: bool = False,\n    x_dtype: Optional[torch.dtype] = None,\n    recompute_output: bool = False,\n) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):\n    M, N = x.shape\n    assert x.stride(-1) == 1\n    dy = maybe_contiguous_lastdim(dy)\n    assert dy.stride(-1) == 1\n    assert dy.shape == (M, N)\n    if dresidual is not None:\n        dresidual = maybe_contiguous_lastdim(dresidual)\n        assert dresidual.stride(-1) == 1\n        assert dresidual.shape == (M, N)\n    assert weight.shape == (N,)\n    assert weight.stride(-1) == 1\n    if bias is not None:\n      
  assert bias.stride(-1) == 1\n        assert bias.shape == (N,)\n    if dy1 is not None:\n        dy1 = maybe_contiguous_lastdim(dy1)\n        assert weight1 is not None\n        assert dy1.shape == dy.shape\n        assert dy1.stride(-1) == 1\n    if weight1 is not None:\n        assert weight1.shape == (N,)\n        assert weight1.stride(-1) == 1\n    if bias1 is not None:\n        assert bias1.shape == (N,)\n        assert bias1.stride(-1) == 1\n    if seeds is not None:\n        assert seeds.is_contiguous()\n        assert seeds.shape == (M if not has_x1 else M * 2,)\n    if rowscale is not None:\n        assert rowscale.is_contiguous()\n        assert rowscale.shape == (M,)\n    # allocate output\n    dx = (\n        torch.empty_like(x)\n        if x_dtype is None\n        else torch.empty(M, N, dtype=x_dtype, device=x.device)\n    )\n    dresidual_in = (\n        torch.empty_like(x)\n        if has_residual\n        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)\n        else None\n    )\n    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None\n    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None\n    if recompute_output:\n        assert weight1 is None, \"recompute_output is not supported with parallel LayerNorm\"\n\n    # Less than 64KB per feature: enqueue fused kernel\n    MAX_FUSED_SIZE = 65536 // x.element_size()\n    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n    if N > BLOCK_N:\n        raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n    # Increasing the multiple (e.g. 
8) will allow more thread blocks to be launched and hide the\n    # latency of the gmem reads/writes, but will increase the time of summing up dw / db.\n    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8\n    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)\n    _db = (\n        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)\n        if bias is not None\n        else None\n    )\n    _dw1 = torch.empty_like(_dw) if weight1 is not None else None\n    _db1 = torch.empty_like(_db) if bias1 is not None else None\n    rows_per_program = math.ceil(M / sm_count)\n    grid = (sm_count,)\n    with torch.cuda.device(x.device.index):\n        torch.library.wrap_triton(_layer_norm_bwd_kernel)[grid](\n            x,\n            weight,\n            bias,\n            y,\n            dy,\n            dx,\n            _dw,\n            _db,\n            dresidual,\n            weight1,\n            dy1,\n            dx1,\n            _dw1,\n            _db1,\n            dresidual_in,\n            rowscale,\n            seeds,\n            mean,\n            rstd,\n            x.stride(0),\n            0 if not recompute_output else y.stride(0),\n            dy.stride(0),\n            dx.stride(0),\n            dresidual.stride(0) if dresidual is not None else 0,\n            dy1.stride(0) if dy1 is not None else 0,\n            dx1.stride(0) if dx1 is not None else 0,\n            dresidual_in.stride(0) if dresidual_in is not None else 0,\n            M,\n            N,\n            eps,\n            dropout_p,\n            # Passing a bool makes torch inductor very unhappy since it then tries to compare to int_max\n            int(zero_centered_weight),\n            rows_per_program,\n            is_rms_norm,\n            BLOCK_N,\n            dresidual is not None,\n            dresidual_in is not None,\n            bias is not None,\n            dropout_p > 0.0,\n            HAS_ROWSCALE=rowscale 
is not None,\n            HAS_DY1=dy1 is not None,\n            HAS_DX1=dx1 is not None,\n            HAS_B1=bias1 is not None,\n            RECOMPUTE_OUTPUT=y is not None,\n        )\n    dw = _dw.sum(0).to(weight.dtype)\n    db = _db.sum(0).to(bias.dtype) if bias is not None else None\n    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None\n    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None\n    # dresidual_in and dx1 could be None, the wrapper will handle assigning them from dx\n    return dx, dw, db, dresidual_in, dx1, dw1, db1, y\n\n\nclass LayerNormFn(torch.autograd.Function):\n\n    @staticmethod\n    def forward(\n        ctx,\n        x,\n        weight,\n        bias,\n        residual=None,\n        x1=None,\n        weight1=None,\n        bias1=None,\n        eps=1e-6,\n        dropout_p=0.0,\n        rowscale=None,\n        prenorm=False,\n        residual_in_fp32=False,\n        zero_centered_weight=False,\n        is_rms_norm=False,\n        return_dropout_mask=False,\n        out_dtype=None,\n        out=None,\n        residual_out=None\n    ):\n        x_shape_og = x.shape\n        # reshape input data into 2D tensor\n        x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))\n        if residual is not None:\n            assert residual.shape == x_shape_og\n            residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))\n        if x1 is not None:\n            assert x1.shape == x_shape_og\n            assert rowscale is None, \"rowscale is not supported with parallel LayerNorm\"\n            x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))\n        weight = weight.contiguous()\n        bias = maybe_contiguous(bias)\n        weight1 = maybe_contiguous(weight1)\n        bias1 = maybe_contiguous(bias1)\n        if rowscale is not None:\n            rowscale = rowscale.reshape(-1).contiguous()\n        residual_dtype = (\n            residual.dtype\n            if residual 
is not None\n            else (torch.float32 if residual_in_fp32 else None)\n        )\n        if out is not None:\n            out = out.reshape(-1, out.shape[-1])\n        if residual_out is not None:\n            residual_out = residual_out.reshape(-1, residual_out.shape[-1])\n        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(\n            x,\n            weight,\n            bias,\n            eps,\n            residual,\n            x1,\n            weight1,\n            bias1,\n            dropout_p=dropout_p,\n            rowscale=rowscale,\n            out_dtype=out_dtype,\n            residual_dtype=residual_dtype,\n            zero_centered_weight=zero_centered_weight,\n            is_rms_norm=is_rms_norm,\n            return_dropout_mask=return_dropout_mask,\n            out=out,\n            residual_out=residual_out,\n        )\n        ctx.save_for_backward(\n            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd\n        )\n        ctx.x_shape_og = x_shape_og\n        ctx.eps = eps\n        ctx.dropout_p = dropout_p\n        ctx.is_rms_norm = is_rms_norm\n        ctx.has_residual = residual is not None\n        ctx.has_x1 = x1 is not None\n        ctx.prenorm = prenorm\n        ctx.x_dtype = x.dtype\n        ctx.zero_centered_weight = zero_centered_weight\n        y = y.reshape(x_shape_og)\n        y1 = y1.reshape(x_shape_og) if y1 is not None else None\n        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None\n        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None\n        dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None\n        if not return_dropout_mask:\n            if weight1 is None:\n                return y if not prenorm else (y, residual_out)\n            else:\n                return (y, y1) if not prenorm else (y, y1, residual_out)\n        else:\n    
        if weight1 is None:\n                return (\n                    (y, dropout_mask, dropout_mask1)\n                    if not prenorm\n                    else (y, residual_out, dropout_mask, dropout_mask1)\n                )\n            else:\n                return (\n                    (y, y1, dropout_mask, dropout_mask1)\n                    if not prenorm\n                    else (y, y1, residual_out, dropout_mask, dropout_mask1)\n                )\n\n    @staticmethod\n    def backward(ctx, dy, *args):\n        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors\n        dy = dy.reshape(-1, dy.shape[-1])\n        if weight1 is not None:\n            dy1, args = args[0], args[1:]\n            dy1 = dy1.reshape(-1, dy1.shape[-1])\n            assert dy1.shape == x.shape\n        else:\n            dy1 = None\n        if ctx.prenorm:\n            dresidual = args[0]\n            dresidual = dresidual.reshape(-1, dresidual.shape[-1])\n            assert dresidual.shape == x.shape\n        else:\n            dresidual = None\n        dx, dw, db, dresidual_in, dx1, dw1, db1, _ = _layer_norm_bwd(\n            dy,\n            x,\n            weight,\n            bias,\n            ctx.eps,\n            mean,\n            rstd,\n            dresidual,\n            dy1,\n            weight1,\n            bias1,\n            seeds,\n            ctx.dropout_p,\n            rowscale,\n            ctx.has_residual,\n            ctx.has_x1,\n            ctx.zero_centered_weight,\n            ctx.is_rms_norm,\n            x_dtype=ctx.x_dtype,\n            recompute_output=False,\n        )\n        return (\n            dx.reshape(ctx.x_shape_og),\n            dw,\n            db,\n            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,\n            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,\n            dw1,\n            db1,\n            None,\n            None,\n            None,\n           
 None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n        )\n\n\ndef layer_norm_fn(\n    x,\n    weight,\n    bias,\n    residual=None,\n    x1=None,\n    weight1=None,\n    bias1=None,\n    eps=1e-6,\n    dropout_p=0.0,\n    rowscale=None,\n    prenorm=False,\n    residual_in_fp32=False,\n    zero_centered_weight=False,\n    is_rms_norm=False,\n    return_dropout_mask=False,\n    out_dtype=None,\n    out=None,\n    residual_out=None\n):\n    return LayerNormFn.apply(\n        x,\n        weight,\n        bias,\n        residual,\n        x1,\n        weight1,\n        bias1,\n        eps,\n        dropout_p,\n        rowscale,\n        prenorm,\n        residual_in_fp32,\n        zero_centered_weight,\n        is_rms_norm,\n        return_dropout_mask,\n        out_dtype,\n        out,\n        residual_out\n    )\n\n\ndef rms_norm_fn(\n    x,\n    weight,\n    bias,\n    residual=None,\n    x1=None,\n    weight1=None,\n    bias1=None,\n    eps=1e-6,\n    dropout_p=0.0,\n    rowscale=None,\n    prenorm=False,\n    residual_in_fp32=False,\n    zero_centered_weight=False,\n    return_dropout_mask=False,\n    out_dtype=None,\n    out=None,\n    residual_out=None\n):\n    return LayerNormFn.apply(\n        x,\n        weight,\n        bias,\n        residual,\n        x1,\n        weight1,\n        bias1,\n        eps,\n        dropout_p,\n        rowscale,\n        prenorm,\n        residual_in_fp32,\n        zero_centered_weight,\n        True,\n        return_dropout_mask,\n        out_dtype,\n        out,\n        residual_out\n    )\n\n\nclass RMSNorm(torch.nn.Module):\n\n    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,\n                 device=None, dtype=None):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.eps = eps\n        if dropout_p > 0.0:\n            self.drop = 
torch.nn.Dropout(dropout_p)\n        else:\n            self.drop = None\n        self.zero_centered_weight = zero_centered_weight\n        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))\n        self.register_parameter(\"bias\", None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        if not self.zero_centered_weight:\n            torch.nn.init.ones_(self.weight)\n        else:\n            torch.nn.init.zeros_(self.weight)\n\n    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):\n        return rms_norm_fn(\n            x,\n            self.weight,\n            self.bias,\n            residual=residual,\n            eps=self.eps,\n            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,\n            prenorm=prenorm,\n            residual_in_fp32=residual_in_fp32,\n            zero_centered_weight=self.zero_centered_weight,\n        )\n\n\nclass LayerNormLinearFn(torch.autograd.Function):\n\n    @staticmethod\n    @custom_fwd\n    def forward(\n        ctx,\n        x,\n        norm_weight,\n        norm_bias,\n        linear_weight,\n        linear_bias,\n        residual=None,\n        eps=1e-6,\n        prenorm=False,\n        residual_in_fp32=False,\n        is_rms_norm=False,\n    ):\n        x_shape_og = x.shape\n        # reshape input data into 2D tensor\n        x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))\n        if residual is not None:\n            assert residual.shape == x_shape_og\n            residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))\n        norm_weight = norm_weight.contiguous()\n        norm_bias = maybe_contiguous(norm_bias)\n        residual_dtype = (\n            residual.dtype\n            if residual is not None\n            else (torch.float32 if residual_in_fp32 else None)\n        )\n        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(\n            x,\n            
norm_weight,\n            norm_bias,\n            eps,\n            residual,\n            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype(\"cuda\"),\n            residual_dtype=residual_dtype,\n            is_rms_norm=is_rms_norm,\n        )\n        y = y.reshape(x_shape_og)\n        dtype = torch.get_autocast_dtype(\"cuda\") if torch.is_autocast_enabled() else y.dtype\n        linear_weight = linear_weight.to(dtype)\n        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None\n        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)\n        # We don't store y, will be recomputed in the backward pass to save memory\n        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)\n        ctx.x_shape_og = x_shape_og\n        ctx.eps = eps\n        ctx.is_rms_norm = is_rms_norm\n        ctx.has_residual = residual is not None\n        ctx.prenorm = prenorm\n        ctx.x_dtype = x.dtype\n        ctx.linear_bias_is_none = linear_bias is None\n        return out if not prenorm else (out, residual_out.reshape(x_shape_og))\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, dout, *args):\n        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors\n        dout = dout.reshape(-1, dout.shape[-1])\n        dy = F.linear(dout, linear_weight.t())\n        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)\n        dy = maybe_contiguous_lastdim(dy)\n        assert dy.shape == x.shape\n        if ctx.prenorm:\n            dresidual = args[0]\n            dresidual = maybe_contiguous_lastdim(dresidual.reshape(-1, dresidual.shape[-1]))\n            assert dresidual.shape == x.shape\n        else:\n            dresidual = None\n        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(\n            dy,\n            x,\n            norm_weight,\n            norm_bias,\n            ctx.eps,\n            mean,\n 
           rstd,\n            dresidual=dresidual,\n            has_residual=ctx.has_residual,\n            is_rms_norm=ctx.is_rms_norm,\n            x_dtype=ctx.x_dtype,\n            recompute_output=True,\n        )\n        dlinear_weight = torch.einsum(\"bo,bi->oi\", dout, y)\n        return (\n            dx.reshape(ctx.x_shape_og),\n            dnorm_weight,\n            dnorm_bias,\n            dlinear_weight,\n            dlinear_bias,\n            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,\n            None,\n            None,\n            None,\n            None,\n        )\n\n\ndef layer_norm_linear_fn(\n    x,\n    norm_weight,\n    norm_bias,\n    linear_weight,\n    linear_bias,\n    residual=None,\n    eps=1e-6,\n    prenorm=False,\n    residual_in_fp32=False,\n    is_rms_norm=False,\n):\n    return LayerNormLinearFn.apply(\n        x,\n        norm_weight,\n        norm_bias,\n        linear_weight,\n        linear_bias,\n        residual,\n        eps,\n        prenorm,\n        residual_in_fp32,\n        is_rms_norm,\n    )\n"
  },
  {
    "path": "flash_attn/ops/triton/linear.py",
    "content": "# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py\n# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py\nfrom typing import Optional\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time\n\nfrom flash_attn.ops.triton.k_activations import (\n    gelu,\n    gelu_approx,\n    gelu_approx_grad,\n    gelu_grad,\n    squared_relu,\n    squared_relu_grad,\n)\n\n# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications\n\n\ndef init_to_zero(name):\n    return lambda nargs: nargs[name].zero_()\n\n\ndef get_configs_io_bound():\n    configs = []\n    for num_stages in [2, 3, 4, 5, 6]:\n        for block_m in [16, 32]:\n            for block_k in [32, 64]:\n                for block_n in [32, 64, 128, 256]:\n                    num_warps = 2 if block_n <= 64 else 4\n                    configs.append(\n                        triton.Config(\n                            {\n                                \"BLOCK_M\": block_m,\n                                \"BLOCK_N\": block_n,\n                                \"BLOCK_K\": block_k,\n                                \"SPLIT_K\": 1,\n                            },\n                            num_stages=num_stages,\n                            num_warps=num_warps,\n                        )\n                    )\n                    # split_k not used\n                    # for split_k in [2, 4, 8, 16]:\n                    #     configs.append(triton.Config(\n                    #         {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},\n                    #         num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))\n    return configs\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, 
\"SPLIT_K\": 1}, num_stages=3, num_warps=8\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2\n        ),\n        # good for int8\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n            num_stages=3,\n            num_warps=8,\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n            num_stages=3,\n            num_warps=8,\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n         
   num_stages=4,\n            num_warps=4,\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=5, num_warps=2\n        ),\n    ]\n    + get_configs_io_bound(),\n    key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"CACHE_KEY_K\"],\n    prune_configs_by={\n        \"early_config_prune\": early_config_prune,\n        \"perf_model\": estimate_matmul_time,\n        \"top_k\": 10,\n    },\n)\n@triton.heuristics(\n    {\n        \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n    }\n)\n@triton.jit\ndef kernel_fwd(\n    C,  # Pointers to matrices\n    ACT_INPUT,\n    A,\n    B,\n    bias,\n    # Matrix dimensions\n    M,\n    N,\n    K,\n    CACHE_KEY_M,\n    CACHE_KEY_N,\n    CACHE_KEY_K,\n    # The stride variables represent how much to increase the ptr by when moving by 1\n    # element in a particular dimension. E.g. 
stride_am is how much to increase a_ptr\n    # by to get the element one row down (A has M rows)\n    stride_cm,\n    # stride_cn,  # Assume that stride_cn == 1\n    stride_am,\n    stride_ak,\n    stride_bn,\n    stride_bk,\n    # Meta-parameters\n    BLOCK_M: tl.constexpr,\n    GROUP_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    # split k not used, not performant with activation, kept because early_config_prune is expecting it\n    SPLIT_K: tl.constexpr,\n    EVEN_K: tl.constexpr,\n    A_ROWMAJOR: tl.constexpr,\n    B_COLMAJOR: tl.constexpr,\n    BIAS: tl.constexpr,\n    SAVE_ACT_INPUT: tl.constexpr,\n    ACTIVATION: tl.constexpr,\n):\n\n    \"\"\"\n    Kernel for computing Out = activation(A x W + C)\n    - Input has shape (M, K)\n    - Weight has shape (K, N)\n    - Bias has shape (N,)\n    - Output has shape (M, N)\n    - ActInputs (optional) has shape (M, N)\n    'ActInputs' optionally saves the A x W + C intermediate for backward computations\n    This kernel will consolidate over K\n    \"\"\"\n\n    pid = tl.program_id(axis=0)\n\n    grid_m = (M + BLOCK_M - 1) // BLOCK_M\n    grid_n = (N + BLOCK_N - 1) // BLOCK_N\n    # re-order program ID for better L2 performance\n    width = GROUP_M * grid_n\n    group_id = pid // width\n    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n    pid_m = group_id * GROUP_M + (pid % group_size)\n    pid_n = (pid % width) // (group_size)\n\n    # now compute the block that each program will go through\n    # rm (resp. rn) denotes a range of indices\n    # for rows (resp. 
col) of C\n    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n    # trick to avoid masking on M and N axis\n    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n    rk = tl.arange(0, BLOCK_K)\n\n    if A_ROWMAJOR:\n        A = A + (ram[:, None] * stride_am + rk[None, :])\n    else:\n        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n    if B_COLMAJOR:\n        B = B + (rk[:, None] + rbn[None, :] * stride_bn)\n    else:\n        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n    for k in range(K, 0, -BLOCK_K):\n        if EVEN_K:\n            a = tl.load(A)\n            b = tl.load(B)\n        else:\n            a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n            b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n        acc += tl.dot(a, b)\n\n        if A_ROWMAJOR:\n            A += BLOCK_K\n        else:\n            A += BLOCK_K * stride_ak\n        if B_COLMAJOR:\n            B += BLOCK_K\n        else:\n            B += BLOCK_K * stride_bk\n\n    # Putting bias after the matmul (instead of before) is faster, idk why\n    if BIAS:\n        bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)\n        acc += bias[None, :]\n\n    # optional: save the activation inputs\n    if SAVE_ACT_INPUT:\n        # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn\n        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n        tl.store(act_in_ptrs, acc)\n\n    # optional: fused activation (while the data is in shared memory)\n    if ACTIVATION == \"gelu\":\n        acc = gelu(acc)\n    elif ACTIVATION == \"gelu_approx\":\n        acc = gelu_approx(acc)\n    elif ACTIVATION == \"squared_relu\":\n        acc = squared_relu(acc)\n    # rematerialize rm and rn to save registers\n    rm 
= pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n    # write back result\n    # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn\n    C = C + rm[:, None] * stride_cm + rn[None, :]\n    mask = (rm < M)[:, None] & (rn < N)[None, :]\n    tl.store(C, acc, mask=mask)\n\n\ndef triton_linear_act(\n    x: torch.Tensor,\n    weight: torch.Tensor,\n    bias: Optional[torch.Tensor] = None,\n    activation: str = \"id\",\n    save_act_input: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Compute e = activation(x @ weight.T + bias).\n    This wrapper launches the `kernel_fwd` Triton kernel.\n    :param x: input tensor\n    :param weight: weight matrix\n    :param bias: an optional bias tensor\n    :param activation: Activation name. Needs to be a Triton kernel.\n    :param save_act_input: if True, also return the activation inputs (for backward)\n    :return: result tensor, or (result, activation inputs) if save_act_input is True\n    \"\"\"\n    # if torch.is_autocast_enabled():\n    #     dtype = torch.get_autocast_gpu_dtype()\n    #     x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]\n\n    assert activation in [\"id\", \"gelu\", \"gelu_approx\", \"squared_relu\"]\n\n    batch_shape, n = x.shape[:-1], x.shape[-1]\n    batch_dim = batch_shape.numel()\n    x_reshaped = x.reshape(batch_dim, n)\n\n    if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:\n        x_reshaped = x_reshaped.contiguous()\n    if weight.stride(0) > 1 and weight.stride(1) > 1:\n        weight = weight.contiguous()\n    bias = bias.contiguous() if bias is not None else None\n\n    assert (\n        x.dtype == weight.dtype\n    ), f\"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}\"\n    if bias is not None:\n        assert (\n            x.dtype == bias.dtype\n        ), f\"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}\"\n    assert (\n        x_reshaped.shape[1] == weight.shape[1]\n    ), f\"Incompatible dimensions: {x_reshaped.shape} - 
{weight.shape}\"\n\n    assert (\n        bias is None or bias.shape[0] == weight.shape[0]\n    ), \"Incompatible dimensions in between weight and bias\"\n\n    M, K = x_reshaped.shape\n    N, K = weight.shape\n\n    output = torch.empty((M, N), device=x.device, dtype=x.dtype)\n    act_input = torch.empty_like(output) if save_act_input else None\n\n    # 1D launch kernel where each block gets its own program.\n    grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),)  # noqa\n\n    kernel_fwd[grid](\n        output,\n        act_input,\n        x_reshaped,\n        weight,  # data ptrs\n        bias if bias is not None else x,  # auto skip bias if not present\n        M,  # shapes\n        N,\n        K,\n        M // 32,  # key for triton cache (limit number of compilations)\n        N // 32,\n        K // 32,\n        stride_cm=output.stride(0),  # strides\n        # stride_cn=output.stride(1),\n        stride_am=x_reshaped.stride(0),\n        stride_ak=x_reshaped.stride(1),\n        stride_bk=weight.stride(1),\n        stride_bn=weight.stride(0),\n        BIAS=bias is not None,  # optional fused bias\n        SAVE_ACT_INPUT=save_act_input,  # optional save activation inputs\n        ACTIVATION=activation,  # optional fused activation\n        A_ROWMAJOR=x_reshaped.stride(1) == 1,\n        B_COLMAJOR=weight.stride(1) == 1,\n        GROUP_M=8,  # speed optimization: group the programs\n    )\n\n    if not save_act_input:\n        return output.reshape(*batch_shape, output.shape[-1])\n    else:\n        return (\n            output.reshape(*batch_shape, output.shape[-1]),\n            act_input.reshape(*batch_shape, act_input.shape[-1]),\n        )\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, 
\"SPLIT_K\": 1}, num_stages=3, num_warps=8\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2\n        ),\n        # good for int8\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n            num_stages=3,\n            num_warps=8,\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n            num_stages=3,\n            num_warps=8,\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1},\n            num_stages=4,\n            num_warps=4,\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, 
\"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4\n        ),\n        triton.Config(\n            {\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=5, num_warps=2\n        ),\n    ]\n    + get_configs_io_bound(),\n    key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"CACHE_KEY_K\"],\n    prune_configs_by={\n        \"early_config_prune\": early_config_prune,\n        \"perf_model\": estimate_matmul_time,\n        \"top_k\": 10,\n    },\n)\n@triton.heuristics(\n    {\n        \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n    }\n)\n@triton.jit\ndef kernel_bwd(\n    C,  # Pointers to matrices\n    ACT_INPUT,\n    A,\n    B,\n    # Matrix dimensions\n    M,\n    N,\n    K,\n    CACHE_KEY_M,\n    CACHE_KEY_N,\n    CACHE_KEY_K,\n    # The stride variables represent how much to increase the ptr by when moving by 1\n    # element in a particular dimension. E.g. 
stride_am is how much to increase a_ptr\n    # by to get the element one row down (A has M rows)\n    stride_cm,\n    # stride_cn,  # Assume that stride_cn == 1\n    stride_am,\n    stride_ak,\n    stride_bk,\n    stride_bn,\n    # Meta-parameters\n    BLOCK_M: tl.constexpr,\n    GROUP_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    BLOCK_K: tl.constexpr,\n    # split k not used, not performant with activation, kept because early_config_prune is expecting it\n    SPLIT_K: tl.constexpr,\n    EVEN_K: tl.constexpr,\n    ACTIVATION: tl.constexpr,\n):\n\n    \"\"\"\n    Kernel for computing Out = (A x W) * activation_grad(ActInputs)\n    - Input has shape (M, K)\n    - Weight has shape (K, N)\n    - Output has shape (M, N)\n    - ActInputs (optional) has shape (M, N)\n    'ActInputs' holds the activation inputs saved by the forward pass (unused for \"id\")\n    This kernel reduces over K\n    \"\"\"\n\n    pid = tl.program_id(axis=0)\n\n    grid_m = (M + BLOCK_M - 1) // BLOCK_M\n    grid_n = (N + BLOCK_N - 1) // BLOCK_N\n    # re-order program ID for better L2 performance\n    width = GROUP_M * grid_n\n    group_id = pid // width\n    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n    pid_m = group_id * GROUP_M + (pid % group_size)\n    pid_n = (pid % width) // (group_size)\n\n    # now compute the block that each program will go through\n    # rm (resp. rn) denotes a range of indices\n    # for rows (resp. 
col) of C\n    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n    # trick to avoid masking on M and N axis\n    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n    rk = tl.arange(0, BLOCK_K)\n\n    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n    for k in range(K, 0, -BLOCK_K):\n        if EVEN_K:\n            a = tl.load(A)\n            b = tl.load(B)\n        else:\n            a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n            b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n        acc += tl.dot(a, b)\n\n        A += BLOCK_K * stride_ak\n        B += BLOCK_K * stride_bk\n\n    # optional: fused activation (while the data is in shared memory)\n    if ACTIVATION != \"id\":\n        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]\n        act_input = tl.load(act_in_ptrs).to(acc.dtype)\n    if ACTIVATION == \"gelu\":\n        acc *= gelu_grad(act_input)\n    elif ACTIVATION == \"gelu_approx\":\n        acc *= gelu_approx_grad(act_input)\n    elif ACTIVATION == \"squared_relu\":\n        acc *= squared_relu_grad(act_input)\n\n    # rematerialize rm and rn to save registers\n    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n    # write back result\n    C = C + rm[:, None] * stride_cm + rn[None, :]\n    mask = (rm < M)[:, None] & (rn < N)[None, :]\n    tl.store(C, acc, mask=mask)\n\n\ndef triton_dgrad_act(\n    grad_output: torch.Tensor,\n    weight: torch.Tensor,\n    activation: str = \"id\",\n    act_input: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    \"\"\"\n    Compute grad_input = (grad_output @ weight) * activation_grad(act_input).\n    This wrapper launches the `kernel_bwd` Triton kernel.\n    :param grad_output: 
input tensor\n    :param weight: weight matrix\n    :param activation: Activation name. One of \"id\", \"gelu\", \"gelu_approx\", \"squared_relu\".\n    :param act_input: the activation inputs saved in the forward pass (required unless \"id\")\n    :return: result tensor\n    \"\"\"\n    assert activation in [\"id\", \"gelu\", \"gelu_approx\", \"squared_relu\"]\n\n    batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]\n    batch_dim = batch_shape.numel()\n    grad_output_reshaped = grad_output.reshape(batch_dim, n)\n\n    if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:\n        grad_output_reshaped = grad_output_reshaped.contiguous()\n    if weight.stride(0) > 1 and weight.stride(1) > 1:\n        weight = weight.contiguous()\n\n    assert (\n        grad_output.dtype == weight.dtype\n    ), f\"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}\"\n    assert (\n        grad_output_reshaped.shape[1] == weight.shape[0]\n    ), f\"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}\"\n    if activation != \"id\":\n        assert act_input is not None, f\"act_input is required for activation {activation}\"\n\n    # M, N, K in bwd are different from M, N, K in fwd\n    M, K = grad_output_reshaped.shape\n    K, N = weight.shape\n\n    grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)\n\n    # 1D launch kernel where each block gets its own program.\n    grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),)  # noqa\n\n    kernel_bwd[grid](\n        grad_input,\n        act_input,\n        grad_output_reshaped,\n        weight,  # data ptrs\n        M,  # shapes\n        N,\n        K,\n        M // 32,  # key for triton cache (limit number of compilations)\n        N // 32,\n        K // 32,\n        stride_cm=grad_input.stride(0),  # strides\n        # stride_cn=grad_input.stride(1),\n        stride_am=grad_output_reshaped.stride(0),\n     
   stride_ak=grad_output_reshaped.stride(1),\n        stride_bk=weight.stride(0),\n        stride_bn=weight.stride(1),\n        ACTIVATION=activation,  # optional fused activation\n        GROUP_M=8,  # speed optimization: group the programs\n    )\n\n    return grad_input.reshape(*batch_shape, grad_input.shape[-1])\n"
  },
  {
    "path": "flash_attn/ops/triton/mlp.py",
    "content": "# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared\n# to naive implementation.\nimport fused_dense_lib as fused_dense_cuda\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom flash_attn.utils.torch import custom_fwd, custom_bwd\nfrom flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd\nfrom flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act\n\n\nclass FusedDenseSqreluDenseFunc(torch.autograd.Function):\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):\n        \"\"\"checkpoint_lvl:\n        0: no recomputation in the bwd\n        1: recompute gelu_out in the bwd\n        2: recompute act_input and gelu_out in the bwd\n        \"\"\"\n        if torch.is_autocast_enabled():\n            dtype = torch.get_autocast_gpu_dtype()\n            x, weight1, bias1, weight2, bias2 = [\n                a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2]\n            ]\n        is_bf16 = x.dtype == torch.bfloat16\n        assert checkpoint_lvl in [0, 1, 2]\n        x = x.contiguous()\n        weight1 = weight1.contiguous()\n        bias1 = bias1.contiguous()\n        weight2 = weight2.contiguous()\n        bias2 = bias2.contiguous()\n        batch_shape, n = x.shape[:-1], x.shape[-1]\n        batch_dim = batch_shape.numel()\n        if is_bf16:\n            act_input = fused_dense_cuda.linear_bias_forward(\n                x.reshape(batch_dim, n), weight1, bias1\n            )\n            output1 = sqrelu_fwd(act_input)\n        else:\n            save_act_input = checkpoint_lvl != 2\n            result = triton_linear_act(\n                x.reshape(batch_dim, n),\n                weight1,\n                bias1,\n                activation=\"squared_relu\",\n                save_act_input=save_act_input,\n            )\n            if save_act_input:\n                output1, act_input = result\n         
   else:\n                output1 = result\n        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)\n        ctx.checkpoint_lvl = checkpoint_lvl\n        if checkpoint_lvl == 0:\n            ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1)\n        elif checkpoint_lvl == 1:\n            ctx.save_for_backward(x, weight1, bias1, weight2, act_input)\n        elif checkpoint_lvl == 2:\n            ctx.save_for_backward(x, weight1, bias1, weight2)\n        return output2.reshape(*batch_shape, output2.shape[-1])\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_output):\n        grad_output = grad_output.contiguous()\n        checkpoint_lvl = ctx.checkpoint_lvl\n        x, weight1, bias1, weight2, *rest = ctx.saved_tensors\n        batch_shape, n = x.shape[:-1], x.shape[-1]\n        batch_dim = batch_shape.numel()\n        is_bf16 = x.dtype == torch.bfloat16\n        if checkpoint_lvl == 0:\n            act_input, output1 = rest\n        elif checkpoint_lvl == 1:\n            (act_input,) = rest\n            output1 = sqrelu_fwd(act_input)\n        elif checkpoint_lvl == 2:\n            if is_bf16:\n                act_input = fused_dense_cuda.linear_bias_forward(\n                    x.reshape(batch_dim, n), weight1, bias1\n                )\n                output1 = sqrelu_fwd(act_input)\n            else:\n                output1, act_input = triton_linear_act(\n                    x.reshape(batch_dim, n),\n                    weight1,\n                    bias1,\n                    activation=\"squared_relu\",\n                    save_act_input=True,\n                )\n\n        if is_bf16:\n            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])\n            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)\n            grad_output1 = grad_output @ weight2\n            grad_act_input = sqrelu_bwd(grad_output1, act_input)\n            grad_input, 
grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(\n                x.reshape(batch_dim, n), weight1, grad_act_input\n            )\n        else:\n            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])\n            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)\n            grad_act_input = triton_dgrad_act(\n                grad_output, weight2, activation=\"squared_relu\", act_input=act_input\n            )\n            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(\n                x.reshape(batch_dim, n), weight1, grad_act_input\n            )\n        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None\n\n\nfused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply\n\n\nclass FusedDenseSqreluDense(nn.Module):\n    def __init__(\n        self,\n        in_features,\n        hidden_features=None,\n        out_features=None,\n        bias1=True,\n        bias2=True,\n        checkpoint_lvl=0,\n        device=None,\n        dtype=None,\n    ):\n        \"\"\"\n        checkpoint_lvl (increasing lvl means slower but more memory saving):\n            0: no recomputation in the bwd\n            1: recompute gelu_out in the bwd\n            2: recompute gelu_in and gelu_out in the bwd\n        \"\"\"\n        assert checkpoint_lvl in [0, 1, 2]\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features * 4\n        assert bias1 == True, \"DenseSqreluDense module without bias is currently not supported\"\n        assert bias2 == True, \"DenseSqreluDense module without bias is currently not supported\"\n        self.checkpoint_lvl = checkpoint_lvl\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)\n        self.fc2 = 
nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)\n\n    def forward(self, x):\n        assert x.is_cuda\n        return fused_dense_sqrelu_dense_function(\n            x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl\n        )\n"
  },
  {
    "path": "flash_attn/ops/triton/rotary.py",
    "content": "# Copyright (c) 2025, Tri Dao.\n# As of 2025-04-23, we require triton >= 3.0\n\nfrom typing import Optional, Union\n\nimport torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef rotary_kernel(\n    OUT,  # Pointers to matrices\n    X,\n    COS,\n    SIN,\n    CU_SEQLENS,\n    SEQLEN_OFFSETS,  # this could be int or a pointer\n    # Matrix dimensions\n    seqlen,\n    nheads,\n    seqlen_ro,\n    # strides\n    stride_out_batch,\n    stride_out_seqlen,\n    stride_out_nheads,\n    stride_out_headdim,\n    stride_x_batch,\n    stride_x_seqlen,\n    stride_x_nheads,\n    stride_x_headdim,\n    # Meta-parameters\n    # We want ROTARY_DIM to be constexpr, otherwise the triton compiler doesn't know that\n    # the mask is constant every 8 elements, and it will generate LDG.16 instead of LDG.128\n    ROTARY_DIM: tl.constexpr,\n    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,\n    IS_VARLEN: tl.constexpr,\n    INTERLEAVED: tl.constexpr,\n    CONJUGATE: tl.constexpr,\n    BLOCK_H: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n):\n    BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM)\n    ROTARY_DIM_HALF = ROTARY_DIM // 2\n    pid_head = tl.program_id(axis=0)\n    pid_m = tl.program_id(axis=1)\n    pid_batch = tl.program_id(axis=2)\n\n    if not IS_VARLEN:\n        X = X + pid_batch * stride_x_batch\n        OUT = OUT + pid_batch * stride_out_batch\n    else:\n        start_idx = tl.load(CU_SEQLENS + pid_batch)\n        seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx\n        X = X + start_idx * stride_x_seqlen\n        OUT = OUT + start_idx * stride_out_seqlen\n\n    if pid_m * BLOCK_M >= seqlen:\n        return\n\n    rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H)\n    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    if not IS_SEQLEN_OFFSETS_TENSOR:\n        rm_cs = rm + SEQLEN_OFFSETS\n    else:\n        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)\n\n    rk_half = tl.arange(0, BLOCK_K // 2)\n    COS = COS + (rm_cs[:, 
None] * ROTARY_DIM_HALF + rk_half[None, :])\n    SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :])\n    mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF)\n    cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32)\n    sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32)\n    if CONJUGATE:\n        sin = -sin\n\n    if not INTERLEAVED:\n        # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT\n        X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk_half[None, None, :] * stride_x_headdim)\n        OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk_half[None, None, :] * stride_out_headdim)\n        mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk_half[None, None, :] < ROTARY_DIM_HALF)\n        x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32)\n        x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0,).to(tl.float32)\n        o0 = x0 * cos - x1 * sin\n        o1 = x0 * sin + x1 * cos\n        tl.store(OUT, o0, mask=mask)\n        tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask)\n    else:\n        rk = tl.arange(0, BLOCK_K)\n        X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk[None, None, :] * stride_x_headdim)\n        OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk[None, None, :] * stride_out_headdim)\n        mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk[None, None, :] < ROTARY_DIM)\n        x = tl.load(X, mask=mask, other=0.0).to(tl.float32)\n        x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2]))\n        o0 = x0 * cos - x1 * sin\n        o1 = x0 * sin + x1 * cos\n        o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K])\n        tl.store(OUT, o, mask=mask)\n\n\ndef 
apply_rotary(\n    x: torch.Tensor,\n    cos: torch.Tensor,\n    sin: torch.Tensor,\n    seqlen_offsets: Union[int, torch.Tensor] = 0,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    max_seqlen: Optional[int] = None,\n    interleaved=False,\n    inplace=False,\n    conjugate=False,\n) -> torch.Tensor:\n    \"\"\"\n    Arguments:\n        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None\n            else (total_seqlen, nheads, headdim).\n        cos: (seqlen_ro, rotary_dim / 2)\n        sin: (seqlen_ro, rotary_dim / 2)\n        seqlen_offsets: integer or integer tensor of size (batch,)\n        cu_seqlens: (batch + 1,) or None\n        max_seqlen: int\n    Returns:\n        y: (batch, seqlen, nheads, headdim)\n    \"\"\"\n    is_varlen = cu_seqlens is not None\n    if not is_varlen:\n        batch, seqlen, nheads, headdim = x.shape\n    else:\n        assert max_seqlen is not None, \"If cu_seqlens is passed in, then max_seqlen must be passed\"\n        total_seqlen, nheads, headdim = x.shape\n        batch_p_1 = cu_seqlens.shape[0]\n        batch = batch_p_1 - 1\n        seqlen = max_seqlen\n    seqlen_ro, rotary_dim = cos.shape\n    assert sin.shape == cos.shape\n    rotary_dim *= 2\n    assert rotary_dim <= headdim, \"rotary_dim must be <= headdim\"\n    assert headdim <= 256, \"Only support headdim <= 256\"\n    assert seqlen_ro >= seqlen, \"seqlen_ro must be >= seqlen\"\n\n    cos, sin = cos.contiguous(), sin.contiguous()\n    if isinstance(seqlen_offsets, torch.Tensor):\n        assert seqlen_offsets.shape == (batch,)\n        assert seqlen_offsets.dtype in [torch.int32, torch.int64]\n        seqlen_offsets = seqlen_offsets.contiguous()\n    else:\n        assert seqlen_offsets + seqlen <= seqlen_ro\n\n    output = torch.empty_like(x) if not inplace else x\n    if rotary_dim < headdim and not inplace:\n        output[..., rotary_dim:].copy_(x[..., rotary_dim:])\n\n    grid = lambda META: (triton.cdiv(nheads, META[\"BLOCK_H\"]), triton.cdiv(seqlen, 
META[\"BLOCK_M\"]), batch)  # noqa\n    BLOCK_M = 8 if rotary_dim <= 128 else 4\n\n    # Need this, otherwise Triton tries to launch from cuda:0 and we get\n    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)\n    with torch.cuda.device(x.device.index):\n        torch.library.wrap_triton(rotary_kernel)[grid](\n            output,  # data ptrs\n            x,\n            cos,\n            sin,\n            cu_seqlens,\n            seqlen_offsets,\n            seqlen,  # shapes\n            nheads,\n            seqlen_ro,\n            output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0\n            output.stride(-3),  # seqlen_stride or total_seqlen_stride\n            output.stride(-2),  # nheads_stride\n            output.stride(-1),  # headdim_stride\n            x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0\n            x.stride(-3),  # seqlen stride or total_seqlen_stride\n            x.stride(-2),  # nheads stride\n            x.stride(-1),  # headdim stride\n            rotary_dim,\n            isinstance(seqlen_offsets, torch.Tensor),\n            is_varlen,\n            interleaved,\n            conjugate,\n            BLOCK_M=BLOCK_M,\n            BLOCK_H=2,\n        )\n    return output\n"
  },
  {
    "path": "flash_attn/pyproject.toml",
    "content": "[tool.black]\nline-length = 100\ntarget-version = ['py39']\n[tool.ruff]\nline-length = 100\ntarget-version = 'py39'"
  },
  {
    "path": "flash_attn/utils/__init__.py",
    "content": ""
  },
  {
    "path": "flash_attn/utils/benchmark.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\"\"\" Useful functions for writing test code. \"\"\"\n\nimport torch\nimport torch.utils.benchmark as benchmark\n\n\ndef benchmark_forward(\n    fn, *inputs, repeats=10, desc=\"\", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs\n):\n    \"\"\"Use Pytorch Benchmark on the forward pass of an arbitrary function.\"\"\"\n    if verbose:\n        print(desc, \"- Forward pass\")\n\n    def amp_wrapper(*inputs, **kwinputs):\n        with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n            fn(*inputs, **kwinputs)\n\n    t = benchmark.Timer(\n        stmt=\"fn_amp(*inputs, **kwinputs)\",\n        globals={\"fn_amp\": amp_wrapper, \"inputs\": inputs, \"kwinputs\": kwinputs},\n        num_threads=torch.get_num_threads(),\n    )\n    m = t.timeit(repeats)\n    if verbose:\n        print(m)\n    return t, m\n\n\ndef benchmark_backward(\n    fn,\n    *inputs,\n    grad=None,\n    repeats=10,\n    desc=\"\",\n    verbose=True,\n    amp=False,\n    amp_dtype=torch.float16,\n    **kwinputs,\n):\n    \"\"\"Use Pytorch Benchmark on the backward pass of an arbitrary function.\"\"\"\n    if verbose:\n        print(desc, \"- Backward pass\")\n    with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n        y = fn(*inputs, **kwinputs)\n        if type(y) is tuple:\n            y = y[0]\n    if grad is None:\n        grad = torch.randn_like(y)\n    else:\n        if grad.shape != y.shape:\n            raise RuntimeError(\"Grad shape does not match output shape\")\n\n    def f(*inputs, y, grad):\n        # Set .grad to None to avoid extra operation of gradient accumulation\n        for x in inputs:\n            if isinstance(x, torch.Tensor):\n                x.grad = None\n        y.backward(grad, retain_graph=True)\n\n    t = benchmark.Timer(\n        stmt=\"f(*inputs, y=y, grad=grad)\",\n        globals={\"f\": f, \"inputs\": inputs, \"y\": y, \"grad\": grad},\n        
num_threads=torch.get_num_threads(),\n    )\n    m = t.timeit(repeats)\n    if verbose:\n        print(m)\n    return t, m\n\n\ndef benchmark_combined(\n    fn,\n    *inputs,\n    grad=None,\n    repeats=10,\n    desc=\"\",\n    verbose=True,\n    amp=False,\n    amp_dtype=torch.float16,\n    **kwinputs,\n):\n    \"\"\"Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.\"\"\"\n    if verbose:\n        print(desc, \"- Forward + Backward pass\")\n    with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n        y = fn(*inputs, **kwinputs)\n        if type(y) is tuple:\n            y = y[0]\n    if grad is None:\n        grad = torch.randn_like(y)\n    else:\n        if grad.shape != y.shape:\n            raise RuntimeError(\"Grad shape does not match output shape\")\n\n    def f(grad, *inputs, **kwinputs):\n        for x in inputs:\n            if isinstance(x, torch.Tensor):\n                x.grad = None\n        with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n            y = fn(*inputs, **kwinputs)\n            if type(y) is tuple:\n                y = y[0]\n        y.backward(grad, retain_graph=True)\n\n    t = benchmark.Timer(\n        stmt=\"f(grad, *inputs, **kwinputs)\",\n        globals={\"f\": f, \"fn\": fn, \"inputs\": inputs, \"grad\": grad, \"kwinputs\": kwinputs},\n        num_threads=torch.get_num_threads(),\n    )\n    m = t.timeit(repeats)\n    if verbose:\n        print(m)\n    return t, m\n\n\ndef benchmark_fwd_bwd(\n    fn,\n    *inputs,\n    grad=None,\n    repeats=10,\n    desc=\"\",\n    verbose=True,\n    amp=False,\n    amp_dtype=torch.float16,\n    **kwinputs,\n):\n    \"\"\"Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.\"\"\"\n    return (\n        benchmark_forward(\n            fn,\n            *inputs,\n            repeats=repeats,\n            desc=desc,\n            verbose=verbose,\n            amp=amp,\n            
amp_dtype=amp_dtype,\n            **kwinputs,\n        ),\n        benchmark_backward(\n            fn,\n            *inputs,\n            grad=grad,\n            repeats=repeats,\n            desc=desc,\n            verbose=verbose,\n            amp=amp,\n            amp_dtype=amp_dtype,\n            **kwinputs,\n        ),\n    )\n\n\ndef benchmark_all(\n    fn,\n    *inputs,\n    grad=None,\n    repeats=10,\n    desc=\"\",\n    verbose=True,\n    amp=False,\n    amp_dtype=torch.float16,\n    **kwinputs,\n):\n    \"\"\"Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.\"\"\"\n    return (\n        benchmark_forward(\n            fn,\n            *inputs,\n            repeats=repeats,\n            desc=desc,\n            verbose=verbose,\n            amp=amp,\n            amp_dtype=amp_dtype,\n            **kwinputs,\n        ),\n        benchmark_backward(\n            fn,\n            *inputs,\n            grad=grad,\n            repeats=repeats,\n            desc=desc,\n            verbose=verbose,\n            amp=amp,\n            amp_dtype=amp_dtype,\n            **kwinputs,\n        ),\n        benchmark_combined(\n            fn,\n            *inputs,\n            grad=grad,\n            repeats=repeats,\n            desc=desc,\n            verbose=verbose,\n            amp=amp,\n            amp_dtype=amp_dtype,\n            **kwinputs,\n        ),\n    )\n\n\ndef pytorch_profiler(\n    fn,\n    *inputs,\n    trace_filename=None,\n    backward=False,\n    amp=False,\n    amp_dtype=torch.float16,\n    cpu=False,\n    verbose=True,\n    **kwinputs,\n):\n    \"\"\"Wrap benchmark functions in Pytorch profiler to see CUDA information.\"\"\"\n    if backward:\n        with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n            out = fn(*inputs, **kwinputs)\n            if type(out) is tuple:\n                out = out[0]\n            g = torch.randn_like(out)\n    for _ in range(30):  # Warm up\n        if 
backward:\n            for x in inputs:\n                if isinstance(x, torch.Tensor):\n                    x.grad = None\n        with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n            out = fn(*inputs, **kwinputs)\n            if type(out) is tuple:\n                out = out[0]\n        # Backward should be done outside autocast\n        if backward:\n            out.backward(g, retain_graph=True)\n    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [\n        torch.profiler.ProfilerActivity.CUDA\n    ]\n    with torch.profiler.profile(\n        activities=activities,\n        record_shapes=True,\n        # profile_memory=True,\n        with_stack=True,\n    ) as prof:\n        if backward:\n            for x in inputs:\n                if isinstance(x, torch.Tensor):\n                    x.grad = None\n        with torch.autocast(device_type=\"cuda\", dtype=amp_dtype, enabled=amp):\n            out = fn(*inputs, **kwinputs)\n            if type(out) is tuple:\n                out = out[0]\n        if backward:\n            out.backward(g, retain_graph=True)\n    if verbose:\n        # print(prof.key_averages().table(sort_by=\"self_cuda_time_total\", row_limit=50))\n        print(prof.key_averages().table(row_limit=50))\n    if trace_filename is not None:\n        prof.export_chrome_trace(trace_filename)\n\n\ndef benchmark_memory(fn, *inputs, desc=\"\", verbose=True, **kwinputs):\n    torch.cuda.empty_cache()\n    torch.cuda.reset_peak_memory_stats()\n    torch.cuda.synchronize()\n    fn(*inputs, **kwinputs)\n    torch.cuda.synchronize()\n    mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)\n    if verbose:\n        print(f\"{desc} max memory: {mem}GB\")\n    torch.cuda.empty_cache()\n    return mem\n"
  },
  {
    "path": "flash_attn/utils/distributed.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.distributed import ProcessGroup\n\n# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for\n# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent\n# version of PyTorch. The following 4 lines are for backward compatibility with\n# older PyTorch.\nif \"all_gather_into_tensor\" not in dir(torch.distributed):\n    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base\nif \"reduce_scatter_tensor\" not in dir(torch.distributed):\n    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base\n\n\n# Raw operation, does not support autograd, but does support async\ndef all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):\n    world_size = torch.distributed.get_world_size(process_group)\n    output = torch.empty(\n        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device\n    )\n    handle = torch.distributed.all_gather_into_tensor(\n        output, input_.contiguous(), group=process_group, async_op=async_op\n    )\n    return output, handle\n\n\n# Raw operation, does not support autograd, but does support async\ndef reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):\n    world_size = torch.distributed.get_world_size(process_group)\n    assert input_.shape[0] % world_size == 0\n    output = torch.empty(\n        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device\n    )\n    handle = torch.distributed.reduce_scatter_tensor(\n        output, input_.contiguous(), group=process_group, async_op=async_op\n    )\n    return output, handle\n\n\n# Raw operation, does not support autograd, but does support async\ndef all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):\n    input_ = input_.contiguous()\n    handle = 
torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)\n    return input_, handle\n\n\nclass AllGatherFunc(torch.autograd.Function):\n    \"\"\"Gather the input from sequence parallel region and concatenate.\"\"\"\n\n    @staticmethod\n    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:\n        ctx.process_group = process_group\n        output, _ = all_gather_raw(input_, process_group)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output: Tensor):\n        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)\n        return grad_input, None\n\n\n# Supports autograd, but does not support async\nall_gather = AllGatherFunc.apply\n\n\nclass ReduceScatterFunc(torch.autograd.Function):\n    \"\"\"Reduce the input across the sequence parallel region and scatter the result.\"\"\"\n\n    @staticmethod\n    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:\n        ctx.process_group = process_group\n        output, _ = reduce_scatter_raw(input_, process_group)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output: Tensor):\n        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)\n        return grad_input, None\n\n\n# Supports autograd, but does not support async\nreduce_scatter = ReduceScatterFunc.apply\n\n\nclass AllReduceFunc(torch.autograd.Function):\n    \"\"\"All-reduce the input across the process group.\"\"\"\n\n    @staticmethod\n    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:\n        ctx.process_group = process_group\n        output, _ = all_reduce_raw(input_, process_group)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output: Tensor):\n        return grad_output, None\n\n\n# Supports autograd, but does not support async\nall_reduce = AllReduceFunc.apply\n\n\ndef sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):\n    # We 
want to iterate over parameters with _shared_params=True in the same order,\n    # as different ranks might have a different number of parameters (e.g., only rank 0 has bias).\n    params_shared = {\n        name: p for name, p in model.named_parameters() if getattr(p, \"_shared_params\", False)\n    }\n    for _, p in sorted(params_shared.items()):\n        with torch.no_grad():\n            # Broadcast needs src to be global rank, not group rank\n            torch.distributed.broadcast(\n                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group\n            )\n\n\n# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256\ndef allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):\n    # We want to iterate over parameters with _sequence_parallel=True in the same order,\n    # as different ranks might have a different number of parameters (e.g., only rank 0 has bias).\n    params_seqparallel = {\n        name: p for name, p in model.named_parameters() if getattr(p, \"_sequence_parallel\", False)\n    }\n    grads = [p.grad for _, p in sorted(params_seqparallel.items())]\n    if grads:\n        with torch.no_grad():\n            coalesced = torch._utils._flatten_dense_tensors(grads)\n            torch.distributed.all_reduce(coalesced, group=process_group)\n            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):\n                buf.copy_(synced)\n\n\ndef get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:\n    \"\"\"Get the dim for the local rank derived from splitting dim on world_size processes.\n\n    The split may not be even across the world_size processes.\n    \"\"\"\n    multiple = dim // multiple_of\n    div = multiple // world_size\n    mod = multiple % world_size\n    local_multiple = div + int(local_rank < mod)\n    return 
local_multiple * multiple_of\n"
  },
  {
    "path": "flash_attn/utils/generation.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31\nimport gc\nimport time\nfrom collections import namedtuple\nfrom dataclasses import dataclass, field\nfrom functools import partial\nfrom typing import Callable, Optional, Sequence, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom torch import Tensor\nfrom torch.profiler import ProfilerActivity, profile, record_function\n\ntry:\n    from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput\nexcept ImportError:\n    GreedySearchDecoderOnlyOutput = namedtuple(\"GreedySearchDecoderOnlyOutput\", [\"sequences\", \"scores\"])\n    SampleDecoderOnlyOutput = namedtuple(\"SampleDecoderOnlyOutput\", [\"sequences\", \"scores\"])\n\n\n@dataclass\nclass InferenceParams:\n    \"\"\"Inference parameters that are passed to the main model in order\n    to efficiently calculate and store the context during inference.\"\"\"\n\n    max_seqlen: int\n    max_batch_size: int\n    seqlen_offset: int = 0\n    batch_size_offset: int = 0\n    key_value_memory_dict: dict = field(default_factory=dict)\n    lengths_per_sample: Optional[Tensor] = None\n\n    def reset(self, max_seqlen, max_batch_size):\n        self.max_seqlen = max_seqlen\n        self.max_batch_size = max_batch_size\n        self.seqlen_offset = 0\n        if self.lengths_per_sample is not None:\n            self.lengths_per_sample.zero_()\n\n\n# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py\n# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231\ndef modify_logits_for_top_k_filtering(logits, top_k):\n    \"\"\"Set the logits for non-top-k values to -inf. 
Done in-place.\"\"\"\n    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]\n    logits.masked_fill_(indices_to_remove, float(\"-Inf\"))\n\n\n# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py\n# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170\ndef modify_logits_for_top_p_filtering(logits, top_p):\n    \"\"\"Set the logits for non-top-p values to -inf. Done in-place.\"\"\"\n    if top_p <= 0.0 or top_p >= 1.0:\n        return\n    # First sort and calculate cumulative sum of probabilities.\n    sorted_logits, sorted_indices = torch.sort(logits, descending=False)\n    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)\n    # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)\n    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)\n    # Scatter sorted tensors back to the original indexing along the vocab (last) dim,\n    # so that both 2D and 3D logits are handled\n    indices_to_remove = sorted_indices_to_remove.scatter(\n        -1, sorted_indices, sorted_indices_to_remove\n    )\n    logits.masked_fill_(indices_to_remove, float(\"-inf\"))\n\n\ndef sample(logits, top_k=1, top_p=0.0, temperature=1.0):\n    \"\"\"Sample from top-k logits.\n    Arguments:\n        logits: Tensor of shape (batch_size, vocab_size)\n    \"\"\"\n    if top_k == 1:  # Short-circuit for greedy decoding\n        return logits.argmax(dim=-1)\n    else:\n        if top_p > 0.0:\n            assert top_p <= 1.0, \"top-p should be in (0, 1].\"\n        if top_k > 0:\n            top_k = min(top_k, logits.size(-1))  # Safety check\n            logits_top, indices = torch.topk(logits, top_k, dim=-1)\n            if temperature != 1.0:\n                logits_top /= temperature\n            modify_logits_for_top_p_filtering(logits_top, top_p)\n            return indices[\n                torch.arange(indices.shape[0], device=indices.device),\n      
          torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),\n            ]\n        else:\n            # Clone so that when we modify for top_p we don't change the original logits\n            logits_top = logits / temperature if temperature != 1.0 else logits.clone()\n            modify_logits_for_top_p_filtering(logits_top, top_p)\n            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(\n                dim=-1\n            )\n\n\n@torch.inference_mode()\ndef decode(\n    input_ids,\n    model,\n    max_length,\n    top_k=1,\n    top_p=0.0,\n    temperature=1.0,\n    eos_token_id=None,\n    teacher_outputs=None,\n    vocab_size=None,\n    tensor_parallel=1,\n    cg=False,\n    enable_timing=False,\n):\n    \"\"\"Decoding, either greedy or with top-k or top-p sampling.\n    If top-k = 0, don't limit the number of candidates (pure sampling).\n    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,\n    then top-p.\n    We assume that all sequences in the same batch have the same length.\n\n    Arguments:\n        input_ids: (batch, seq_len)\n        max_length: int\n        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the\n            logits, the next token is taken from the teacher_outputs. 
Useful for testing.\n    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:\n        sequences: (batch, max_length)\n        scores: tuples of (batch, vocab_size)\n    \"\"\"\n    batch_size, seqlen_og = input_ids.shape\n    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0\n    if cg:\n        if not hasattr(model, \"_decoding_cache\"):\n            model._decoding_cache = None\n        model._decoding_cache = update_graph_cache(\n            model,\n            model._decoding_cache,\n            batch_size,\n            seqlen_og,\n            max_length,\n            tensor_parallel=tensor_parallel,\n        )\n        inference_params = model._decoding_cache.inference_params\n        inference_params.reset(max_length, batch_size)\n    else:\n        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)\n\n    def get_logits(input_ids, inference_params):\n        decoding = inference_params.seqlen_offset > 0\n        if decoding:\n            position_ids = torch.full(\n                (batch_size, 1),\n                inference_params.seqlen_offset,\n                dtype=torch.long,\n                device=input_ids.device,\n            )\n        else:\n            position_ids = None\n        if not cg or not decoding:\n            logits = model(\n                input_ids,\n                position_ids=position_ids,\n                inference_params=inference_params,\n                num_last_tokens=1,\n            ).logits.squeeze(dim=1)\n        else:\n            logits = model._decoding_cache.run(\n                input_ids, position_ids, inference_params.seqlen_offset\n            ).squeeze(dim=1)\n        return logits[..., :vocab_size] if vocab_size is not None else logits\n\n    def sample_tokens(logits, inference_params):\n        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:\n            token = 
sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)\n        else:\n            token = teacher_outputs[:, inference_params.seqlen_offset]\n        # return rearrange(token, \"b -> b 1\")\n        return token.unsqueeze(1)\n\n    def should_stop(current_token, inference_params):\n        if inference_params.seqlen_offset == 0:\n            return False\n        if eos_token_id is not None and (current_token == eos_token_id).all():\n            return True\n        if inference_params.seqlen_offset >= max_length - 1:\n            return True\n        return False\n\n    start = torch.cuda.Event(enable_timing=enable_timing)\n    end = torch.cuda.Event(enable_timing=enable_timing)\n\n    if enable_timing:\n        if tensor_parallel > 1:\n            torch.distributed.barrier()\n        start.record()\n    scores, sequences = [], [input_ids]\n    while not should_stop(sequences[-1], inference_params):\n        scores.append(get_logits(sequences[-1], inference_params))\n        inference_params.seqlen_offset += sequences[-1].shape[1]\n        sequences.append(sample_tokens(scores[-1], inference_params))\n    if enable_timing:\n        end.record()\n        if tensor_parallel > 1:\n            torch.distributed.barrier()\n        torch.cuda.synchronize()\n        print(f\"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms\")\n    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput\n    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))\n\n\ndef sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):\n    \"\"\"Algorithm 1 from [1]\n    [1] Fast Inference from Transformers via Speculative Decoding\n    Yaniv Leviathan, Matan Kalman, Yossi Matias\n    https://arxiv.org/abs/2211.17192\n\n    Arguments:\n        logits: Tensor of shape (batch_size, seqlen + 1, vocab_size)\n        logits_draft: Tensor of shape (batch_size, seqlen, vocab_size)\n  
      tokens_draft: Tensor of shape (batch_size, seqlen)\n    Return:\n        tokens: Tensor of shape (batch_size, seqlen + 1)\n        num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1].\n            For each sequence in the batch, the number of valid tokens that were sampled by\n            speculative sampling.\n    \"\"\"\n    batch, seqlen_p_1, vocab_size = logits.shape\n    seqlen = seqlen_p_1 - 1\n    assert logits_draft.shape == (batch, seqlen, vocab_size)\n    assert tokens_draft.shape == (batch, seqlen)\n    assert tokens_draft.dtype in [torch.int64, torch.int32]\n    # TODO: if top_k = 1 we can simplify things and only work with indices\n    if top_p > 0.0:\n        assert top_p <= 1.0, \"top-p should be in (0, 1].\"\n    # Clone so that when we modify for top_p we don't change the original logits\n    logits = logits / temperature if temperature != 1.0 else logits.clone()\n    logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone()\n    if top_k > 0:\n        top_k = min(top_k, logits.size(-1))  # Safety check\n        modify_logits_for_top_k_filtering(logits, top_k)\n        modify_logits_for_top_k_filtering(logits_draft, top_k)\n    modify_logits_for_top_p_filtering(logits, top_p)\n    modify_logits_for_top_p_filtering(logits_draft, top_p)\n    probs = torch.softmax(logits, dim=-1)\n    probs_draft = torch.softmax(logits_draft, dim=-1)\n    gather = lambda probs, tokens: rearrange(\n        probs.gather(dim=-1, index=rearrange(tokens, \"... -> ... 1\")), \"... 
1 -> ...\"\n    )\n    # (batch, seqlen)\n    accepted = torch.rand(batch, seqlen, device=probs.device) * gather(\n        probs_draft, tokens_draft\n    ) <= gather(probs[:, :-1], tokens_draft)\n    accepted_all = accepted.all(dim=-1)\n    # (batch,)\n    first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1))\n    probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0)\n    # torch.multinomial can deal with unnormalized probabilities\n    # probs_diff /= probs_diff.sum(dim=-1, keepdim=True)\n    resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1)\n    resample_probs = rearrange(\n        resample_probs.gather(dim=1, index=repeat(first_rejected_idx, \"b -> b 1 d\", d=vocab_size)),\n        \"b 1 d -> b d\",\n    )\n    resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1)  # (batch,)\n    tokens = F.pad(tokens_draft, (0, 1))\n    # Index per row so that each sequence gets its own resampled token (matters when batch > 1)\n    tokens[torch.arange(batch, device=tokens.device), first_rejected_idx] = resample\n    return tokens, first_rejected_idx + 1\n\n\n@torch.inference_mode()\ndef decode_speculative(\n    input_ids,\n    model,\n    model_draft,\n    max_length,\n    speculative_lookahead=3,\n    top_k=1,\n    top_p=0.0,\n    temperature=1.0,\n    eos_token_id=None,\n    vocab_size=None,\n    tensor_parallel=1,\n    cg=False,\n    enable_timing=False,\n    debug=False,\n):\n    \"\"\"\n    TD: WIP, for my own understanding, lightly tested. Only supports batch_size == 1 for now.\n\n    Speculative decoding, either greedy or with top-k or top-p sampling.\n    If top-k = 0, don't limit the number of candidates (pure sampling).\n    Top-k and top-p can be used together. 
If top_k > 0 and top_p > 0, then top-k is applied first,\n    then top-p.\n    We assume that all sequences in the same batch have the same length.\n\n    Arguments:\n        input_ids: (batch, seq_len)\n        max_length: int\n    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:\n        sequences: (batch, max_length)\n        scores: tuples of (batch, vocab_size)\n    \"\"\"\n    batch_size, seqlen_og = input_ids.shape\n    assert batch_size == 1, \"Speculative decoding implementation only supports batch_size=1\"\n    assert eos_token_id is None, \"Speculative decoding implementation doesn't support eos_token_id\"\n    if cg:\n        if not hasattr(model_draft, \"_decoding_cache\"):\n            model_draft._decoding_cache = None\n        model_draft._decoding_cache = update_graph_cache(\n            model_draft,\n            model_draft._decoding_cache,\n            batch_size,\n            seqlen_og,\n            max_length,\n            # draft model needs to process either 1 or 2 tokens at a time\n            decoding_seqlens=(1, 2),\n            tensor_parallel=tensor_parallel,\n        )\n        inference_params_draft = model_draft._decoding_cache.inference_params\n        inference_params_draft.reset(max_length, batch_size)\n        if not hasattr(model, \"_decoding_cache\"):\n            model._decoding_cache = None\n        model._decoding_cache = update_graph_cache(\n            model,\n            model._decoding_cache,\n            batch_size,\n            seqlen_og,\n            max_length,\n            decoding_seqlens=range(1, speculative_lookahead + 2),\n            tensor_parallel=tensor_parallel,\n        )\n        inference_params = model._decoding_cache.inference_params\n        inference_params.reset(max_length, batch_size)\n    else:\n        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)\n        inference_params = 
InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)\n\n    def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False):\n        decoding = inference_params.seqlen_offset > 0\n        if decoding:\n            seqlen = input_ids.shape[1]\n            # if inference_params.lengths_per_sample is None:\n            # TODO: in the case of batched decoding where each sequence has a different length,\n            # we need to compute the position_ids for each sequence using lengths_per_sample\n            if True:\n                cache_seqlens = torch.full(\n                    (input_ids.shape[0],),\n                    inference_params.seqlen_offset,\n                    dtype=torch.int32,\n                    device=input_ids.device,\n                )\n            else:\n                cache_seqlens = inference_params.lengths_per_sample\n            position_ids = cache_seqlens[:, None] + torch.arange(\n                seqlen, dtype=torch.long, device=input_ids.device\n            )\n        else:\n            position_ids = None\n        if not cg or not decoding:\n            logits = model(\n                input_ids,\n                position_ids=position_ids,\n                inference_params=inference_params,\n                num_last_tokens=num_last_tokens,\n            ).logits\n        else:\n            # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1].\n            # This might not be compatible with the num_last_tokens used here.\n            assert num_last_tokens <= input_ids.shape[1]\n            logits = model._decoding_cache.run(\n                input_ids, position_ids, inference_params.seqlen_offset\n            )[:, -num_last_tokens:]\n        return logits[..., :vocab_size] if vocab_size is not None else logits\n\n    def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1):\n        \"\"\"Sample `num_tokens` tokens from the model, given the previous 
logits.\n        Also return the logits of the sampled tokens.\n        Arguments:\n            input_ids: (batch, seqlen)\n        Return:\n            tokens: (batch, num_tokens)\n            scores: (batch, num_tokens), which contains @previous_logits and the logits of the next\n                (num_tokens - 1) tokens. The logits of the last token isn't computed.\n        \"\"\"\n        assert num_tokens >= 1\n        sequences, scores = [input_ids], []\n        for i in range(num_tokens):\n            scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1])\n            inference_params.seqlen_offset += sequences[-1].shape[1]\n            sequences.append(sample_fn(scores[-1]).unsqueeze(1))\n        return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1)\n\n    sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)\n    sample_fn = partial(sample, **sampling_kwargs)\n    get_logits_main = partial(get_logits, model=model, cg=cg)\n    get_logits_draft = partial(get_logits, model=model_draft, cg=cg)\n    sample_tokens_main = partial(\n        sample_tokens,\n        get_logits_fn=get_logits_main,\n        sample_fn=sample_fn,\n        inference_params=inference_params,\n    )\n    sample_tokens_draft = partial(\n        sample_tokens,\n        get_logits_fn=get_logits_draft,\n        sample_fn=sample_fn,\n        inference_params=inference_params_draft,\n    )\n\n    if debug:\n        from transformers import AutoTokenizer\n\n        tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n    if enable_timing:\n        if tensor_parallel > 1:\n            torch.distributed.barrier()\n        torch.cuda.synchronize()\n        start = time.time()\n\n    sequences, scores = [input_ids], []\n    num_main_model_calls = 0\n    num_draft_tokens = 0\n    num_accepted_tokens_history = []\n    if seqlen_og >= max_length - 1:\n        # Don't do speculative sampling, just sample 1 token from the model\n        tokens, scores_new = 
sample_tokens_main(input_ids, num_tokens=1)\n        sequences.append(tokens)\n        scores.append(scores_new)\n    else:\n        # Sample from draft model, which produces @n_spec_tokens, and @model\n        # will then use to produce between 1 and 1 + @n_spec_tokens tokens.\n        # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.\n        n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)\n        tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens)\n        num_draft_tokens += n_spec_tokens\n        if debug:\n            scores_draft_ref = model_draft(\n                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1\n            ).logits\n            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())\n\n        # Evaluate the draft tokens with the model\n        logits = get_logits_main(\n            torch.cat([input_ids, tokens_draft], dim=1),\n            inference_params,\n            num_last_tokens=n_spec_tokens + 1,\n        )\n        num_main_model_calls += 1\n        if debug:\n            logits_ref = model(\n                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1\n            ).logits\n            print((logits - logits_ref).abs().max())\n            # breakpoint()\n        tokens, num_generated_tokens = sample_speculative(\n            logits, scores_draft, tokens_draft, **sampling_kwargs\n        )\n        num_accepted_tokens_history.append(num_generated_tokens - 1)\n        if debug:\n            print(tokens)\n            print(num_generated_tokens)\n            # breakpoint()\n        # TODO: we're using the fact that batch_size == 1\n        # TODO: check eos_token_id\n        sequences.append(tokens[:1, : num_generated_tokens[0]])\n        scores.append(logits[:1, : num_generated_tokens[0]])\n        # Note that @model has not evaluated the last sampled token yet, so we'll need to pass\n        
# that in the next time we call @model.\n        num_generated = num_generated_tokens[0].item()\n        inference_params.seqlen_offset = seqlen_og + num_generated - 1\n        inference_params_draft.seqlen_offset = (\n            inference_params.seqlen_offset - 1\n            if num_generated > 1\n            else inference_params.seqlen_offset\n        )\n        if debug:\n            cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)\n            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits\n            print((scores[-1] - scores_ref[:, :-1]).abs().max())\n            # breakpoint()\n\n    while True:\n        # seqlen_offset is total length generated - 1\n        if inference_params.seqlen_offset >= max_length - 1:\n            break\n        if inference_params.seqlen_offset >= max_length - 2:\n            # Don't do speculative sampling, just sample 1 token from the model\n            tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)\n            sequences.append(tokens)\n            scores.append(scores_new)\n            break\n        # Sample from draft model\n        n_spec_tokens = min(\n            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2\n        )\n        # If the main model accepts all the draft tokens, plus it samples one new token,\n        # then at the next iteration the draft model needs to evaluate the logits of the last draft\n        # token and the logits of the newly sampled token. 
So here we pass in the last 2 tokens\n        # of sequences[-1].\n        # The exception is when the main model rejects all the draft tokens, in which case we\n        # will only have 1 token to pass in.\n        tokens_draft, scores_draft = sample_tokens_draft(\n            sequences[-1][:, -2:], num_tokens=n_spec_tokens\n        )\n        num_draft_tokens += n_spec_tokens\n        if debug:\n            scores_draft_ref = model_draft(\n                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1\n            ).logits\n            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())\n            # breakpoint()\n        # Evaluate the draft tokens with the model\n        logits = get_logits_main(\n            torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),\n            inference_params,\n            num_last_tokens=n_spec_tokens + 1,\n        )  # (batch, n_spec_tokens + 1, vocab_size)\n        num_main_model_calls += 1\n        if debug:\n            logits_ref = model(\n                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1\n            ).logits\n            print((logits - logits_ref).abs().max())\n            # breakpoint()\n        tokens, num_generated_tokens = sample_speculative(\n            logits, scores_draft, tokens_draft, **sampling_kwargs\n        )\n        num_accepted_tokens_history.append(num_generated_tokens - 1)\n        if debug:\n            print(tokens)\n            print(num_generated_tokens)\n            # breakpoint()\n        sequences.append(tokens[:1, : num_generated_tokens[0]])\n        scores.append(logits[:1, : num_generated_tokens[0]])\n        # We've evaluated 1 token from sequences[-1][:, -1:] above, plus\n        # num_generated_tokens[0].item() - 1 tokens from the draft model.\n        num_generated = num_generated_tokens[0].item()\n        inference_params.seqlen_offset += num_generated\n        inference_params_draft.seqlen_offset = 
(\n            inference_params.seqlen_offset - 1\n            if num_generated > 1\n            else inference_params.seqlen_offset\n        )\n        if debug:\n            cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)\n            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits\n            print((scores[-1] - scores_ref[:, :-1]).abs().max())\n            # breakpoint()\n\n    if enable_timing:\n        if tensor_parallel > 1:\n            torch.distributed.barrier()\n        torch.cuda.synchronize()\n        print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n        print(f\"Number of calls to main model: {num_main_model_calls}\")\n        print(\n            f\"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%\"\n        )\n    sequences = torch.cat(sequences, dim=1)\n    scores = torch.cat(scores, dim=1)\n    if debug:\n        scores_ref = model(sequences).logits\n        print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max())\n    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput\n    return output_cls(sequences=sequences, scores=scores)\n\n\nclass GenerationMixin:\n    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):\n        raise NotImplementedError\n\n    def generate(\n        self,\n        input_ids,\n        max_length,\n        top_k=1,\n        top_p=0.0,\n        temperature=1.0,\n        return_dict_in_generate=False,\n        output_scores=False,\n        **kwargs,\n    ):\n        output = decode(\n            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs\n        )\n        if not output_scores:\n            output.scores = None\n        return output if return_dict_in_generate else output.sequences\n\n\ndef allocate_inference_cache(\n    max_batch_size,\n    max_seqlen,\n    nheads,\n    
headdim,\n    layers: Union[int, Sequence],\n    device,\n    dtype=torch.float16,\n):\n    assert dtype in [torch.float16, torch.bfloat16, torch.float32]\n    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)\n    if isinstance(layers, int):\n        layers = range(layers)\n    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}\n\n\n@dataclass\nclass DecodingCGCache:\n    max_batch_size: int = 0\n    max_seqlen: int = 0\n    device = None\n    dtype = None\n    callables: dict = field(default_factory=dict)\n    mempool = None\n    inference_params: Optional[InferenceParams] = None\n    run: Optional[Callable] = None\n\n\n@torch.inference_mode()\ndef update_graph_cache(\n    model,\n    cache,\n    batch_size,\n    seqlen_og,\n    max_seqlen,\n    decoding_seqlens=(1,),\n    tensor_parallel=1,\n    dtype=None,\n    n_warmups=2,\n):\n    if cache is None:\n        cache = DecodingCGCache()\n    param_example = next(iter(model.parameters()))\n    device = param_example.device\n    if dtype is None:\n        dtype = param_example.dtype\n    if (\n        (device, dtype) != (cache.device, cache.dtype)\n        or batch_size > cache.max_batch_size\n        or max_seqlen > cache.max_seqlen\n    ):  # Invalidate the cache\n        cache.callables = {}\n        cache.mempool = None\n        cache.inference_params = None\n        gc.collect()\n        cache.device, cache.dtype = device, dtype\n        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen\n        if hasattr(model, \"allocate_inference_cache\"):\n            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)\n        else:\n            headdim = getattr(\n                model.config,\n                \"head_dim\",\n                model.config.hidden_size // model.config.num_attention_heads,\n            )\n            inf_cache = allocate_inference_cache(\n                batch_size,\n                max_seqlen,\n             
   model.config.num_attention_heads // tensor_parallel,\n                headdim,\n                model.config.num_hidden_layers,\n                device,\n                dtype,\n            )\n        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)\n        cache.inference_params = InferenceParams(\n            max_seqlen=max_seqlen,\n            max_batch_size=batch_size,\n            seqlen_offset=seqlen_og,\n            key_value_memory_dict=inf_cache,\n            lengths_per_sample=lengths_per_sample,\n        )\n        cache.mempool = torch.cuda.graphs.graph_pool_handle()\n    for decoding_seqlen in decoding_seqlens:\n        if (batch_size, decoding_seqlen) not in cache.callables:\n            cache.callables[batch_size, decoding_seqlen] = capture_graph(\n                model,\n                cache.inference_params,\n                batch_size,\n                max_seqlen,\n                decoding_seqlen=decoding_seqlen,\n                mempool=cache.mempool,\n                n_warmups=n_warmups,\n            )\n\n    def dispatch(input_ids, position_ids, seqlen):\n        batch_size, decoding_seqlen = input_ids.shape[:2]\n        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)\n\n    cache.run = dispatch\n    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing\n    return cache\n\n\ndef capture_graph(\n    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2\n):\n    device = next(iter(model.parameters())).device\n    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)\n    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)\n    seqlen_offset_og = inference_params.seqlen_offset\n    inference_params.seqlen_offset = max_seqlen - decoding_seqlen\n    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset\n\n    # 
Warmup before capture\n    s = torch.cuda.Stream()\n    s.wait_stream(torch.cuda.current_stream())\n    with torch.cuda.stream(s):\n        for _ in range(n_warmups):\n            logits = model(\n                input_ids,\n                position_ids=position_ids,\n                inference_params=inference_params,\n                num_last_tokens=decoding_seqlen,\n            ).logits\n        s.synchronize()\n        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,\n        # which requires that graph launches and non-captured launches do not overlap (I think;\n        # that's how I interpret the documentation). I'm not sure if this is required.\n        if torch.distributed.is_initialized():\n            torch.distributed.barrier()\n    torch.cuda.current_stream().wait_stream(s)\n    # Capture the graph. To allow capture, torch.cuda.graph automatically sets a side stream\n    # as the current stream in the context.\n    graph = torch.cuda.CUDAGraph()\n    with torch.cuda.graph(graph, pool=mempool):\n        logits = model(\n            input_ids,\n            position_ids=position_ids,\n            inference_params=inference_params,\n            num_last_tokens=decoding_seqlen,\n        ).logits\n\n    def run(new_input_ids, new_position_ids, seqlen):\n        inference_params.lengths_per_sample[:] = seqlen\n        input_ids.copy_(new_input_ids)\n        position_ids.copy_(new_position_ids)\n        graph.replay()\n        return logits.clone()\n\n    inference_params.seqlen_offset = seqlen_offset_og\n    return run\n"
  },
  {
    "path": "flash_attn/utils/library.py",
    "content": "# Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/_library/triton.py\n# The PyTorch implementation simply ignores the schema argument, we simply modify it to use schema.\n\nfrom typing import Optional, Callable, Iterable, Union\n\nfrom torch.library import custom_op, CustomOpDef\nfrom torch._library.triton import set_wrap_triton_enabled\n\n\ndef triton_op(\n    name: str,\n    fn: Optional[Callable] = None,\n    /,\n    *,\n    mutates_args: Union[str, Iterable[str]],\n    schema: Optional[str] = None,\n    # If allow_decomposition=True, this matches torch.library.triton_op behavior. If set to False,\n    # then it behaves like torch.library.custom_op instead, which doesn't decompose the operator\n    # and so inductor can't trace inside.\n    allow_decomposition=True,\n) -> Callable:\n    def dec(fn: Callable[..., object]) -> CustomOpDef:\n        def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]\n            # Optimization: we're passing regular Tensors into the triton kernel, so\n            # no need to go through HOP dispatch\n            with set_wrap_triton_enabled(False):\n                return fn(*args, **kwargs)\n\n        result = custom_op(\n            name,\n            backend_fn,\n            mutates_args=mutates_args,\n            # This is the only difference with the PyTorch implementation\n            schema=schema,\n        )\n        from torch._subclasses.functional_tensor import FunctionalTensorMode\n\n        # We require that the user pass us a function that is make_fx traceable,\n        # so we can just register it as the Fake/meta kernel.\n        result.register_fake(fn)\n\n        if allow_decomposition:\n            # We decompose the operator when FunctionalTensorMode is active.\n            # The goal is to decompose the operator in AOTDispatcher.\n            # - With torch.compile, this means that the backend (usually Inductor)\n            #   can see a call to the triton 
kernel(s) and so it can directly optimize\n            #   them by inlining them into the lowering process.\n            def functional_decomp(  # type: ignore[no-untyped-def]\n                mode, op, types, args, kwargs\n            ):\n                from torch.export._trace import custom_triton_ops_decomposition_disabled\n\n                if custom_triton_ops_decomposition_disabled():\n                    return mode.__torch_dispatch__(op, types, args, kwargs)\n                else:\n                    with mode:\n                        return fn(*args, **kwargs)\n\n            result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)\n\n        return result\n\n    if fn is None:\n        return dec\n    else:\n        return dec(fn)\n"
  },
  {
    "path": "flash_attn/utils/pretrained.py",
    "content": "import os\nfrom functools import partial\n\nimport torch\nfrom safetensors.torch import load_file as safe_load_file\nfrom transformers.utils import (\n    SAFE_WEIGHTS_INDEX_NAME,\n    SAFE_WEIGHTS_NAME,\n    WEIGHTS_INDEX_NAME,\n    WEIGHTS_NAME,\n)\nfrom transformers.utils.hub import cached_file, get_checkpoint_shard_files\n\n\ndef state_dict_from_pretrained(model_name, device=None, dtype=None):\n    # If not fp32, then we don't want to load directly to the GPU\n    mapped_device = \"cpu\" if dtype not in [torch.float32, None] else device\n    is_sharded = False\n    load_safe = False\n    resolved_archive_file = None\n\n    weights_path = os.path.join(model_name, WEIGHTS_NAME)\n    weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)\n    safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)\n    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)\n\n    if os.path.isfile(weights_path):\n        resolved_archive_file = cached_file(\n            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False\n        )\n    elif os.path.isfile(weights_index_path):\n        resolved_archive_file = cached_file(\n            model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False\n        )\n        is_sharded = True\n    elif os.path.isfile(safe_weights_path):\n        resolved_archive_file = cached_file(\n            model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False\n        )\n        load_safe = True\n    elif os.path.isfile(safe_weights_index_path):\n        resolved_archive_file = cached_file(\n            model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False\n        )\n        is_sharded = True\n        load_safe = True\n    else:  # Try loading from HF hub instead of from local files\n        resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,\n                                            
_raise_exceptions_for_missing_entries=False)\n        if resolved_archive_file is None:\n            resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,\n                                                _raise_exceptions_for_missing_entries=False)\n            if resolved_archive_file is not None:\n                is_sharded = True\n\n    if resolved_archive_file is None:\n        raise EnvironmentError(f\"Model name {model_name} was not found.\")\n\n    if load_safe:\n        loader = partial(safe_load_file, device=mapped_device)\n    else:\n        loader = partial(torch.load, map_location=mapped_device)\n\n    if is_sharded:\n        # resolved_archive_file becomes a list of files that point to the different\n        # checkpoint shards in this case.\n        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(\n            model_name, resolved_archive_file\n        )\n        state_dict = {}\n        for sharded_file in resolved_archive_file:\n            state_dict.update(loader(sharded_file))\n    else:\n        state_dict = loader(resolved_archive_file)\n    # Convert dtype before moving to GPU to save memory\n    if dtype is not None:\n        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}\n    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}\n    return state_dict\n"
  },
  {
    "path": "flash_attn/utils/testing.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\nimport math\nfrom typing import Optional\n\nimport torch\nfrom einops import rearrange, repeat\n\nfrom flash_attn.bert_padding import pad_input, unpad_input\n\n\ndef generate_random_padding_mask(max_seqlen, batch_size, device, mode=\"random\", zero_lengths=False):\n    assert mode in [\"full\", \"random\", \"third\"]\n    if mode == \"full\":\n        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)\n    elif mode == \"random\":\n        lengths = torch.randint(\n            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device\n        )\n    elif mode == \"third\":\n        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)\n\n    if zero_lengths:\n        # Generate zero-lengths every 5 batches and the last batch.\n        for i in range(batch_size):\n            if i % 5 == 0:\n                lengths[i] = 0\n        lengths[-1] = 0\n    padding_mask = (\n        repeat(torch.arange(max_seqlen, device=device), \"s -> b s\", b=batch_size) < lengths\n    )\n    return padding_mask\n\n\ndef generate_qkv(\n    q, k, v, query_padding_mask=None, key_padding_mask=None, qv=None, kvpacked=False, qkvpacked=False,\n    query_unused_mask=None, key_unused_mask=None,\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, d)\n        k: (batch_size, seqlen_k, nheads_k, d)\n        v: (batch_size, seqlen_k, nheads_k, d_v)\n        query_padding_mask: (batch_size, seqlen), bool\n        key_padding_mask: (batch_size, seqlen), bool\n    \"\"\"\n    assert not (kvpacked and qkvpacked)\n    batch_size, seqlen_q, nheads, d = q.shape\n    d_v = v.shape[-1]\n    _, seqlen_k, nheads_k, _ = k.shape\n    assert k.shape == (batch_size, seqlen_k, nheads_k, d)\n    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)\n    if 
query_unused_mask is not None or key_unused_mask is not None:\n        assert not kvpacked\n        assert not qkvpacked\n\n    if query_padding_mask is not None:\n        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(\n            q, query_padding_mask, query_unused_mask\n        )\n        output_pad_fn = lambda output_unpad: pad_input(\n            output_unpad, indices_q, batch_size, seqlen_q\n        )\n        qv_unpad = rearrange(qv, \"b s ... -> (b s) ...\")[indices_q] if qv is not None else None\n    else:\n        q_unpad = rearrange(q, \"b s h d -> (b s) h d\")\n        cu_seqlens_q = torch.arange(\n            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device\n        )\n        seqused_q = None\n        max_seqlen_q = seqlen_q\n        output_pad_fn = lambda output_unpad: rearrange(\n            output_unpad, \"(b s) h d -> b s h d\", b=batch_size\n        )\n        qv_unpad = rearrange(qv, \"b s ... -> (b s) ...\") if qv is not None else None\n\n    if key_padding_mask is not None:\n        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(\n            k, key_padding_mask, key_unused_mask\n        )\n        v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask)\n    else:\n        k_unpad = rearrange(k, \"b s h d -> (b s) h d\")\n        v_unpad = rearrange(v, \"b s h d -> (b s) h d\")\n        cu_seqlens_k = torch.arange(\n            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device\n        )\n        seqused_k = None\n        max_seqlen_k = seqlen_k\n\n    if qkvpacked:\n        assert (query_padding_mask == key_padding_mask).all()\n        assert nheads == nheads_k\n        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)\n        qkv = torch.stack([q, k, v], dim=2)\n        if query_padding_mask is not None:\n            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, 
seqlen_q)\n        else:\n            dqkv_pad_fn = lambda dqkv_unpad: rearrange(\n                dqkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n            qkv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            max_seqlen_q,\n            qkv.detach().requires_grad_(),\n            output_pad_fn,\n            dqkv_pad_fn,\n        )\n    elif kvpacked:\n        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)\n        kv = torch.stack([k, v], dim=2)\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)\n        else:\n            dkv_pad_fn = lambda dkv_unpad: rearrange(\n                dkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n            q_unpad.detach().requires_grad_(),\n            kv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            kv.detach().requires_grad_(),\n            output_pad_fn,\n            dq_pad_fn,\n            dkv_pad_fn,\n        )\n    else:\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)\n        else:\n            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, \"(b s) h d -> b s h d\", b=batch_size)\n        return (\n            q_unpad.detach().requires_grad_(),\n            k_unpad.detach().requires_grad_(),\n            v_unpad.detach().requires_grad_(),\n            qv_unpad.detach()  if qv is not None else None,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            
k.detach().requires_grad_(),\n            v.detach().requires_grad_(),\n            qv.detach() if qv is not None else None,\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        )\n\n\ndef construct_local_mask(\n    seqlen_q,\n    seqlen_k,\n    window_size=(None, None),\n    sink_token_length=0,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    key_leftpad=None,\n    device=None,\n):\n    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n    if key_leftpad is not None:\n        key_leftpad = rearrange(key_leftpad, \"b -> b 1 1 1\")\n        col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)\n    sk = (\n        seqlen_k\n        if key_padding_mask is None\n        else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sq = (\n        seqlen_q\n        if query_padding_mask is None\n        else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    if window_size[0] is None:\n        return col_idx > row_idx + sk - sq + window_size[1]\n    else:\n        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk\n        return torch.logical_or(\n            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),\n            torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length),\n        )\n\n\ndef construct_chunk_mask(\n    seqlen_q,\n    seqlen_k,\n    attention_chunk,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    key_leftpad=None,\n    device=None,\n):\n    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n    if key_leftpad is not None:\n        key_leftpad = rearrange(key_leftpad, 
\"b -> b 1 1 1\")\n        col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)\n    sk = (\n        seqlen_k\n        if key_padding_mask is None\n        else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sq = (\n        seqlen_q\n        if query_padding_mask is None\n        else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk\n    # Subtract remainder instead of divide and then multiply to take care of negative values\n    col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk\n    return torch.logical_or(\n        col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk\n    )\n\n\ndef attention_ref(\n    q,\n    k,\n    v,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    key_leftpad=None,\n    attn_bias=None,\n    dropout_p=0.0,\n    dropout_mask=None,\n    causal=False,\n    qv=None,\n    q_descale=None, k_descale=None, v_descale=None,\n    window_size=(None, None),\n    attention_chunk=0,\n    sink_token_length=0,\n    learnable_sink: Optional[torch.Tensor] = None,\n    softcap=0.0,\n    upcast=True,\n    reorder_ops=False,\n    intermediate_dtype=None,\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, head_dim)\n        k: (batch_size, seqlen_k, nheads, head_dim)\n        v: (batch_size, seqlen_k, nheads, head_dim_v)\n        qv: (batch_size, seqlen_q, nheads, head_dim_v)\n        query_padding_mask: (batch_size, seqlen_q)\n        key_padding_mask: (batch_size, seqlen_k)\n        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)\n        dropout_p: float\n        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)\n        causal: whether to apply causal masking\n        upcast: whether to cast all inputs to fp32, do all computation 
in fp32, then cast\n            output back to fp16/bf16.\n        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)\n            without changing the math. This is to estimate the numerical error from operation\n            reordering.\n    Output:\n        output: (batch_size, seqlen_q, nheads, head_dim_v)\n        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    dtype_og = q.dtype\n    if upcast:\n        q, k, v = q.float(), k.float(), v.float()\n        qv = qv.float() if qv is not None else None\n    if q_descale is not None:\n        q_descale = repeat(q_descale, \"b h -> b 1 (h g) 1\", g=q.shape[2] // k.shape[2])\n        q = (q.float() * q_descale).to(q.dtype)\n        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None\n    if k_descale is not None:\n        k = (k.float() * rearrange(k_descale, \"b h -> b 1 h 1\")).to(dtype=k.dtype)\n    if v_descale is not None:\n        v = (v.float() * rearrange(v_descale, \"b h -> b 1 h 1\")).to(dtype=v.dtype)\n    seqlen_q, seqlen_k = q.shape[1], k.shape[1]\n    k = repeat(k, \"b s h d -> b s (h g) d\", g=q.shape[2] // k.shape[2])\n    v = repeat(v, \"b s h d -> b s (h g) d\", g=q.shape[2] // v.shape[2])\n    d = q.shape[-1]\n    dv = v.shape[-1]\n    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)\n    if not reorder_ops:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q * softmax_scale, k)\n    else:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q, k * softmax_scale)\n    if qv is not None:\n        scores = scores + torch.einsum(\"bthd,bshd->bhts\", qv * softmax_scale, v)\n    if softcap > 0:\n        scores = torch.tanh(scores / softcap) * softcap\n    if key_padding_mask is not None:\n        scores.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n    local_mask = None\n    if window_size[0] is 
not None or window_size[1] is not None:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            sink_token_length,\n            query_padding_mask,\n            key_padding_mask,\n            key_leftpad=key_leftpad,\n            device=q.device,\n        )\n    if attention_chunk > 0:\n        chunk_mask = construct_chunk_mask(\n            seqlen_q,\n            seqlen_k,\n            attention_chunk,\n            query_padding_mask,\n            key_padding_mask,\n            key_leftpad=key_leftpad,\n            device=q.device,\n        )\n        local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask\n    if local_mask is not None:\n        scores.masked_fill_(local_mask, float(\"-inf\"))\n    if attn_bias is not None:\n        scores = scores + attn_bias\n    if learnable_sink is None:\n        attention = torch.softmax(scores, dim=-1).to(v.dtype)\n    else:\n        scores_fp32 = scores.to(torch.float32)\n        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)\n        learnable_sink = rearrange(learnable_sink, \"h -> h 1 1\")\n        logits_or_sinks_max = torch.maximum(learnable_sink, logits_max)\n        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)\n        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(learnable_sink - logits_or_sinks_max)\n        attention = (unnormalized_scores / normalizer).to(v.dtype)\n    # We want to mask here so that the attention matrix doesn't have any NaNs\n    # Otherwise we'll get NaN in dV\n    if query_padding_mask is not None:\n        attention = attention.masked_fill(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), 0.0)\n    # Without this we might get NaN in dv\n    if key_padding_mask is not None:\n        attention = attention.masked_fill(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), 0.0)\n    # Some rows might be completely masked out so we 
fill them with zero instead of NaN\n    if local_mask is not None:\n        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)\n    dropout_scaling = 1.0 / (1 - dropout_p)\n    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling\n    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)\n    if dropout_mask is not None:\n        attention_drop = attention.masked_fill(~dropout_mask, 0.0)\n    else:\n        attention_drop = attention\n    if intermediate_dtype is not None:\n        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)\n    output = torch.einsum(\"bhts,bshd->bthd\", attention_drop, v * dropout_scaling)\n    if query_padding_mask is not None:\n        output.masked_fill_(rearrange(~query_padding_mask, \"b s -> b s 1 1\"), 0.0)\n    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)\n"
  },
  {
    "path": "flash_attn/utils/torch.py",
    "content": "import torch\nfrom typing import Callable\n\n\ndef custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):\n    def decorator(*args, **kwargs):\n        if cuda_amp_deprecated:\n            kwargs[\"device_type\"] = \"cuda\"\n        return dec(*args, **kwargs)\n    return decorator\n\n\nif hasattr(torch.amp, \"custom_fwd\"): # type: ignore[attr-defined]\n    deprecated = True\n    from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]\nelse:\n    deprecated = False\n    from torch.cuda.amp import custom_fwd, custom_bwd\n\ncustom_fwd = custom_amp_decorator(custom_fwd, deprecated)\ncustom_bwd = custom_amp_decorator(custom_bwd, deprecated)\n"
  },
  {
    "path": "hopper/__init__.py",
    "content": "__version__ = \"3.0.0\"\n"
  },
  {
    "path": "hopper/benchmark_attn.py",
    "content": "from collections import namedtuple\nfrom functools import partial\nimport math\nimport os\nfrom typing import NamedTuple\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport time\n\ntry:\n    import cudnn\nexcept ImportError:\n    cudnn = None\n# cudnn = None\n\nTiming = NamedTuple('timing', [('mean', float)])\n\n\nfrom einops import rearrange, repeat\n\n# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler\nfrom flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler\nfrom flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func\nfrom flash_attn_interface import flash_attn_func as flash_attn_func_v3\n# from flash_attn_interface import flash_attn_with_kvcache as flash_attn_func_v3\nfrom flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3\n\nfrom triton.testing import do_bench\n\ntry:\n    from triton_fused_attention import attention as triton_attention\nexcept ImportError:\n    triton_attention = None\ntriton_attention = None\n\nDISABLE_BACKWARD = os.getenv(\"FLASH_ATTENTION_DISABLE_BACKWARD\", \"FALSE\") == \"TRUE\"\n\n\ndef time_fwd(func, *args, repeats=30, verbose=True, desc=\"\", **kwargs):\n    # # Warmup\n    # for _ in range(5):\n    #     func(*args, **kwargs)\n    # time.sleep(1)\n    # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1]\n    # s = torch.cuda.Stream()\n    # s.wait_stream(torch.cuda.current_stream())\n    # with torch.cuda.stream(s):\n    #     for _ in range(2):\n    #         out = func(*args, **kwargs)\n    # torch.cuda.current_stream().wait_stream(s)\n    # graph = torch.cuda.CUDAGraph()\n    # with torch.cuda.graph(graph):\n    #     out = func(*args, **kwargs)\n    # time_f = benchmark_forward(lambda: graph.replay(), 
repeats=repeats, verbose=verbose, desc=desc)\n    # # return time_f[1].mean\n    # return time_f[1]\n    return Timing(do_bench(lambda: func(*args, **kwargs), warmup=3, rep=repeats) * 1e-3)\n\n\ndef flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)):\n    if causal:\n        avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2\n    else:\n        if window_size == (-1, -1):\n            avg_seqlen = seqlen_k\n        else:\n            row_idx = torch.arange(seqlen_q, device='cuda')\n            col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))\n            col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1))\n            avg_seqlen = (col_right - col_left + 1).float().mean().item()\n    return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)\n\n\ndef convert_to_cudnn_type(torch_type):\n    if torch_type == torch.float16:\n        return cudnn.data_type.HALF\n    elif torch_type == torch.bfloat16:\n        return cudnn.data_type.BFLOAT16\n    elif torch_type == torch.float32:\n        return cudnn.data_type.FLOAT\n    elif torch_type == torch.int32:\n        return cudnn.data_type.INT32\n    elif torch_type == torch.int64:\n        return cudnn.data_type.INT64\n    else:\n        raise ValueError(\"Unsupported tensor data type.\")\n\n\ndef cudnn_spda_setup(q, k, v, causal=False, window_size_left=-1):\n    b, nheads, seqlen_q, headdim = q.shape\n    _, nheads_k, seqlen_k, _ = k.shape\n    assert v.shape == (b, nheads_k, seqlen_k, headdim)\n    assert cudnn is not None, 'CUDNN is not available'\n    q_gpu, k_gpu, v_gpu = q, k, v\n    o_gpu = torch.empty_like(q_gpu)\n    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)\n    graph = cudnn.pygraph(\n        io_data_type=convert_to_cudnn_type(q.dtype),\n        intermediate_data_type=cudnn.data_type.FLOAT,\n        
compute_data_type=cudnn.data_type.FLOAT,\n    )\n    q = graph.tensor_like(q_gpu.detach())\n    k = graph.tensor_like(k_gpu.detach())\n    v = graph.tensor_like(v_gpu.detach())\n\n    o, stats = graph.sdpa(\n        name=\"sdpa\",\n        q=q,\n        k=k,\n        v=v,\n        is_inference=False,\n        attn_scale=1.0 / math.sqrt(headdim),\n        # use_causal_mask_bottom_right=causal or window_size_left >= 0,\n        use_causal_mask=causal or window_size_left >= 0,\n        sliding_window_length=window_size_left if window_size_left >= 0 and not causal else None,\n    )\n\n    o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())\n    stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n\n    graph.validate()\n    graph.build_operation_graph()\n    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n    graph.check_support()\n    graph.build_plans()\n\n    variant_pack = {\n        q: q_gpu,\n        k: k_gpu,\n        v: v_gpu,\n        o: o_gpu,\n        stats: stats_gpu,\n    }\n\n    workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n\n    def run(*args, **kwargs):\n        graph.execute(variant_pack, workspace)\n        return o_gpu\n\n    return run\n\n\ndef cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=-1):\n    b, nheads, seqlen_q, headdim = q.shape\n    _, nheads_k, seqlen_k, _ = k.shape\n    assert v.shape == (b, nheads_k, seqlen_k, headdim)\n    assert g.shape == (b, nheads, seqlen_q, headdim)\n    assert o.shape == (b, nheads, seqlen_q, headdim)\n    assert lse.shape == (b, nheads, seqlen_q, 1)\n    assert cudnn is not None, 'CUDNN is not available'\n    q_gpu, k_gpu, v_gpu, o_gpu, g_gpu = q, k, v, o, g\n    dq_gpu = torch.empty_like(q_gpu)\n    dk_gpu = torch.empty_like(k_gpu)\n    dv_gpu = torch.empty_like(v_gpu)\n    graph = cudnn.pygraph(\n        io_data_type=convert_to_cudnn_type(q.dtype),\n        
intermediate_data_type=cudnn.data_type.FLOAT,\n        compute_data_type=cudnn.data_type.FLOAT,\n    )\n    q = graph.tensor_like(q_gpu.detach())\n    k = graph.tensor_like(k_gpu.detach())\n    v = graph.tensor_like(v_gpu.detach())\n    o = graph.tensor_like(o_gpu.detach())\n    g = graph.tensor_like(g_gpu.detach())\n    stats = graph.tensor_like(lse.detach())\n\n    dq, dk, dv = graph.sdpa_backward(\n        name=\"sdpa_backward\",\n        q=q,\n        k=k,\n        v=v,\n        o=o,\n        dO=g,\n        stats=stats,\n        attn_scale=1.0 / math.sqrt(headdim),\n        # use_causal_mask_bottom_right=causal or window_size_left >= 0,\n        use_causal_mask=causal or window_size_left >= 0,\n        sliding_window_length=window_size_left if window_size_left >= 0 and not causal else None,\n    )\n\n    dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride())\n    dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride())\n    dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride())\n\n    graph.validate()\n    graph.build_operation_graph()\n    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n    graph.check_support()\n    graph.build_plans()\n\n    variant_pack = {\n        q: q_gpu,\n        k: k_gpu,\n        v: v_gpu,\n        o: o_gpu,\n        g: g_gpu,\n        stats: lse,\n        dq: dq_gpu,\n        dk: dk_gpu,\n        dv: dv_gpu,\n    }\n\n    workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n\n    def run(*args, **kwargs):\n        graph.execute(variant_pack, workspace)\n        return dq_gpu, dk_gpu, dv_gpu\n\n    return run\n\n\ntorch.manual_seed(0)\nrepeats = 10\ndropout_p = 0.0\ncausal = False\ndtype = torch.bfloat16\n# dtype = torch.float8_e4m3fn\ndtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\ndevice = 'cuda'\nverbose = True\nvarlen = False\npage_size = None\nsoftcap = 0.0\nV_colmajor = False\ndeterministic = 
False\nbatch_size = 2\n# seqlen = 2048\nseqlen = 8192\n# seqlen = 4096\n# seqlen = 2047\ndim = 2048\n# headdim = 128\n# headdim = 64\nheaddim = 256\n# for headdim in [64, 128, 256]:\n# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]\n# bs_seqlen_vals = [(16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]\n# bs_seqlen_vals = [(32, 512), (16, 1024)]\n# bs_seqlen_vals = [(2, 64 * 132)]\nbs_seqlen_vals = [(2, 8192)]\n# bs_seqlen_vals = [(1, 16 * 1024)]\ntime_f = {}\ntime_b = {}\n\n# for headdim in [64, 128, 256]:\n# for headdim in [64, 96, 128, 192]:\n# for headdim in [64, 96, 128, 192, 256]:\n# for headdim in [64, 96, 128]:\n# for headdim in [64, 128, 256]:\n# for headdim in [64, 96, 128, 192, 256]:\nfor headdim in [128]:\n    nheads = dim // headdim\n    # nheads = 128\n    # headdim = 64\n    # batch_size = 64\n    # seqlen = 512\n    # nheads = 8\n    # headdim = 128\n    nheads_kv = nheads\n    # nheads_kv = nheads // 4\n    # nheads_kv = 1\n    headdim_v = headdim\n    # headdim_v = 512\n    has_qv = headdim == 64 and headdim_v == 512\n    # has_qv = False\n\n    for batch_size, seqlen in bs_seqlen_vals:\n        num_splits = 0\n        window_size = (-1, -1)\n        # window_size = (seqlen // 2 - 1, 0)\n        pack_gqa = None\n        # seqlen_q = 64\n        seqlen_q = seqlen\n        leftpad_k = None\n        # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32)\n        q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)\n        k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True)\n        v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True)\n        q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]]\n        v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_()\n        
v_fa3 = v if not V_colmajor else v_colmajor\n        qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None\n        # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)\n        # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype)\n        # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype)\n        g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True)\n        o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True)\n        stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32)\n        if varlen:\n            q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), \"b s h d -> (b s) h d\").requires_grad_() for x in [q, k, v]]\n            cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q\n            cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen\n            # cu_seqlens_q = torch.tensor([0, 248, 249, 250, 251, 252, 253, 254, 255, 256], device=device, dtype=torch.int32)\n            # q_unpad = q_unpad[:256]\n            # seqlen_q = 256\n            # cu_seqlens_q = torch.tensor([0, 376, 377, 378, 379, 380, 381, 382, 383, 384], device=device, dtype=torch.int32)\n            # q_unpad = q_unpad[:384]\n            # seqlen_q = 384\n        if page_size is not None:\n            assert seqlen % page_size == 0\n            k_paged, v_paged = [rearrange(x, \"b (n p) h d -> (b n) p h d\", p=page_size) for x in [k, v]]\n            page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32),\n                                   \"(b s) -> b s\", s=seqlen // page_size)\n        else:\n     
       page_table = None\n\n        for causal in [False, True]:\n        # for causal in [True]:\n            print(f\"\\n### {headdim = }, {causal = }, {seqlen = } ###\")\n            nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size)\n            if cudnn is not None:\n            # if False:\n                if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:\n                    cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0])\n                    cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0])\n            # _, m0 = benchmark_forward(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')\n            if dtype != torch.float8_e4m3fn and headdim == headdim_v:\n            # if False:\n                if not varlen:\n                    m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')\n                else:\n                    m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')\n                time_f[(causal, headdim, batch_size, seqlen), \"Flash2\"] = m0.mean\n                time.sleep(1)\n                if not varlen:\n                    _, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,\n                                                repeats=repeats, verbose=False, 
desc='Fav2')\n                else:\n                    _, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,\n                                                repeats=repeats, verbose=False, desc='Fav2')\n                time_b[(causal, headdim, batch_size, seqlen), \"Flash2\"] = m0b.mean\n            # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True)\n            if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:\n                if triton_attention is not None:\n                    qt, kt, vt = [x.detach().transpose(1, 2).contiguous().requires_grad_() for x in [q, k, v]]\n                    time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark\n                    m3 = time_fwd(triton_attention, qt, kt, vt, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')\n                    time_f[(causal, headdim, batch_size, seqlen), \"Triton\"] = m3.mean\n                    # if causal: # triton bwd only works w causal for now\n                    #     time.sleep(1)\n                    #     _, m3b = benchmark_backward(triton_attention, qt, kt, vt, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')\n                    #     time_b[(causal, headdim, batch_size, seqlen), \"Triton\"] = m3b.mean\n                    # # pytorch_profiler(triton_attention, q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), v.transpose(1, 2).contiguous(), causal, 1 / math.sqrt(headdim), backward=True)\n            if cudnn is not None:\n            # if False:\n                if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:\n                    time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark\n       
             m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN')\n                    time_f[(causal, headdim, batch_size, seqlen), \"cuDNN\"] = m2.mean\n                    time.sleep(1)\n                    m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')\n                    time_b[(causal, headdim, batch_size, seqlen), \"cuDNN\"] = m2b.mean\n                # pytorch_profiler(cudnn_spda, backward=False)\n                # pytorch_profiler(cudnn_spda_bwd, backward=False)\n\n            time.sleep(1)\n            if not varlen:\n                # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')\n                m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, qv=qv, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')\n                # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)\n            else:\n                m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')\n                # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)\n            
time_f[(causal, headdim, batch_size, seqlen), \"Flash3\"] = m1.mean\n            if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD:\n                time.sleep(1)\n                if not varlen:\n                    _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,\n                                                repeats=repeats, verbose=False, desc='Fav3')\n                else:\n                    _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,\n                                                repeats=repeats, verbose=False, desc='Fav3')\n                time_b[(causal, headdim, batch_size, seqlen), \"Flash3\"] = m1b.mean\n                # time.sleep(1)\n                # if not varlen:\n                #     pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True)\n                # else:\n                #     pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)\n            # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')\n\n            if dtype != torch.float8_e4m3fn and headdim == headdim_v:\n            # if False:\n                print(f'Fav2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS')\n                print(f'Fav2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS')\n            if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v:\n                if triton_attention is not None:\n                    print(f'Triton fwd: {m3.mean * 1e3:.3f}ms, {(nFLOPS / m3.mean * 1e-12):.1f} TFLOPS')\n       
             # if causal:\n                    #     print(f'Triton bwd: {m3b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m3b.mean * 1e-12):.1f} TFLOPS')\n                if cudnn is not None:\n                    print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')\n                    print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')\n            print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')\n            if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD:\n                print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')\n            # benchmark_forward(torch.square, k)\n            # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS')\n    # print(time_f)\n    # print(time_b)\n\n    # import pickle\n    # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp:\n    # # with open(f'flash3_attn_time_h100_cudnn_triton_20241208.plk', 'wb') as fp:\n    # with open(f'flash3_attn_time_h100_fa3_20250313.plk', 'wb') as fp:\n    # # with open(f'flash3_attn_time_h100_fa3_fp8_20250313.plk', 'wb') as fp:\n    # # with open(f'flash3_attn_time_h100_fp8_hdim{headdim}.plk', 'wb') as fp:\n    # # with open(f'flash3_attn_time_h100_hdim{headdim}_1031.plk', 'wb') as fp:\n    #     pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)\n"
  },
  {
    "path": "hopper/benchmark_flash_attention_fp8.py",
    "content": "# Install the newest triton version with\n# pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"\nimport pickle\nimport math\nimport time\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom einops import rearrange, repeat\n\nfrom flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward\nfrom flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined\n\nfrom flash_attn import flash_attn_qkvpacked_func\nfrom flash_attn_interface import flash_attn_func, _flash_attn_forward\n\ntry:\n    from triton_fused_attention import attention as attention_triton\nexcept ImportError:\n    attention_triton = None\n\ntry:\n    import xformers.ops as xops\nexcept ImportError:\n    xops = None\n\ntry:\n    import cudnn\nexcept ImportError:\n    cudnn = None\n\n\ndef convert_to_cudnn_type(torch_type):\n    if torch_type == torch.float16:\n        return cudnn.data_type.HALF\n    elif torch_type == torch.bfloat16:\n        return cudnn.data_type.BFLOAT16\n    elif torch_type == torch.float32:\n        return cudnn.data_type.FLOAT\n    elif torch_type == torch.int32:\n        return cudnn.data_type.INT32\n    elif torch_type == torch.int64:\n        return cudnn.data_type.INT64\n    elif torch_type == torch.float8_e4m3fn:\n        return cudnn.data_type.FP8_E4M3\n    elif torch_type == torch.float8_e5m2:\n        return cudnn.data_type.FP8_E5M2\n    else:\n        raise ValueError(\"Unsupported tensor data type.\")\n\ndef cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False):\n    b, _, _, nheads, headdim = qkv.shape\n    assert cudnn is not None, 'CUDNN is not available'\n    o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device)\n    o_gpu_transposed = torch.as_strided(\n        o_gpu,\n        [b, nheads, seqlen_q, headdim],\n        [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1],\n    )\n    stats_gpu = torch.empty(b, 
nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device)\n    amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)\n    amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)\n    graph = cudnn.pygraph(\n        io_data_type=convert_to_cudnn_type(qkv.dtype),\n        intermediate_data_type=cudnn.data_type.FLOAT,\n        compute_data_type=cudnn.data_type.FLOAT,\n    )\n    new_q = torch.as_strided(\n        qkv,\n        [b, nheads, seqlen_q, headdim],\n        [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],\n        storage_offset=0,\n    )\n    q = graph.tensor(\n        name = \"Q\",\n        dim = list(new_q.shape),\n        stride = list(new_q.stride()),\n        data_type=convert_to_cudnn_type(qkv.dtype)\n    )\n    new_k = torch.as_strided(\n        qkv,\n        [b, nheads, seqlen_k, headdim],\n        [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],\n        storage_offset=nheads * headdim,\n    )\n    k = graph.tensor(\n        name = \"K\",\n        dim = list(new_k.shape),\n        stride = list(new_k.stride()),\n        data_type=convert_to_cudnn_type(qkv.dtype)\n    )\n    new_v = torch.as_strided(\n        qkv,\n        [b, nheads, seqlen_k, headdim],\n        [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],\n        storage_offset=nheads * headdim * 2,\n    )\n    v = graph.tensor(\n        name = \"V\",\n        dim = list(new_v.shape),\n        stride = list(new_v.stride()),\n        data_type=convert_to_cudnn_type(qkv.dtype)\n    )\n\n    def get_default_scale_tensor():\n        return graph.tensor(\n            dim = [1, 1, 1, 1],\n            stride = [1, 1, 1, 1],\n            data_type=cudnn.data_type.FLOAT\n        )\n\n    default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device=\"cuda\")\n    descale_q = get_default_scale_tensor()\n    descale_k = get_default_scale_tensor()\n    descale_v = 
get_default_scale_tensor()\n    descale_s = get_default_scale_tensor()\n    scale_s = get_default_scale_tensor()\n    scale_o = get_default_scale_tensor()\n\n    o, _, amax_s, amax_o = graph.sdpa_fp8(\n        q=q,\n        k=k,\n        v=v,\n        descale_q=descale_q,\n        descale_k=descale_k,\n        descale_v=descale_v,\n        descale_s=descale_s,\n        scale_s=scale_s,\n        scale_o=scale_o,\n        is_inference=True,\n        attn_scale=1.0 / math.sqrt(headdim),\n        use_causal_mask=causal,\n        name=\"sdpa\",\n    )\n\n    o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride())\n\n    amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride())\n    amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride())\n    # stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)\n\n    graph.validate()\n    graph.build_operation_graph()\n    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n    graph.check_support()\n    graph.build_plans()\n\n    variant_pack = {\n        q: new_q,\n        k: new_k,\n        v: new_v,\n        descale_q: default_scale_gpu,\n        descale_k: default_scale_gpu,\n        descale_v: default_scale_gpu,\n        descale_s: default_scale_gpu,\n        scale_s: default_scale_gpu,\n        scale_o: default_scale_gpu,\n        o: o_gpu_transposed,\n        amax_s: amax_s_gpu,\n        amax_o: amax_o_gpu,\n    }\n\n    workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n\n    def run(*args, **kwargs):\n        graph.execute(variant_pack, workspace)\n        return o_gpu, amax_o_gpu\n\n    return run\n\n\ndef attention_pytorch(qkv, dropout_p=0.0, causal=True):\n    \"\"\"\n    Arguments:\n        qkv: (batch_size, seqlen, 3, nheads, head_dim)\n        dropout_p: float\n    Output:\n        output: (batch_size, seqlen, nheads, head_dim)\n    \"\"\"\n    batch_size, seqlen, _, 
nheads, d = qkv.shape\n    q, k, v = qkv.unbind(dim=2)\n    q = rearrange(q, 'b t h d -> (b h) t d')\n    k = rearrange(k, 'b s h d -> (b h) d s')\n    softmax_scale = 1.0 / math.sqrt(d)\n    # Preallocate attn_weights for `baddbmm`\n    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)\n    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),\n                       '(b h) t s -> b h t s', h=nheads)\n    if causal:\n        # \"triu_tril_cuda_template\" not implemented for 'BFloat16'\n        # So we have to construct the mask in float\n        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)\n        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)\n        scores = scores + causal_mask.to(dtype=scores.dtype)\n    attention = torch.softmax(scores, dim=-1)\n    attention_drop = F.dropout(attention, dropout_p)\n    output = torch.einsum('bhts,bshd->bthd', attention_drop , v)\n    return output.to(dtype=qkv.dtype)\n\ndef flops(batch, seqlen, headdim, nheads, causal, mode=\"fwd\"):\n    assert mode in [\"fwd\", \"bwd\", \"fwd_bwd\"]\n    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)\n    return f if mode == \"fwd\" else (2.5 * f if mode == \"bwd\" else 3.5 * f)\n\ndef efficiency(flop, time):\n    return (flop / time / 10**12) if not math.isnan(time) else 0.0\n\ndef time_fwd(func, *args, **kwargs):\n    time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark\n    time_f = benchmark_forward(func, *args, **kwargs)\n    return time_f[1].mean\n\n\ntorch.manual_seed(0)\n\nrepeats = 30\ndevice = 'cuda'\n# dtype = torch.float16\ndtype = torch.float8_e4m3fn\n\n# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]\nbs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]\n# bs_seqlen_vals = [(4, 4096), (2, 
8192), (1, 8192 * 2)]\n# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)]\ncausal_vals = [False, True]\nheaddim_vals = [64, 128, 256]\ndim = 2048\n# dim = 256\ndropout_p = 0.0\n\nmethods = ([\"Pytorch\", \"Flash3\"]\n        + ([\"cuDNN\"] if cudnn is not None else [])\n        # + ([\"Triton\"] if attention_triton is not None else [])\n        #    + ([\"xformers.c\"] if xops is not None else [])\n        #    + ([\"xformers.f\"] if xops is not None else [])\n           )\n\ntime_f = {}\ntime_b = {}\ntime_f_b = {}\nspeed_f = {}\nspeed_b = {}\nspeed_f_b = {}\nfor causal in causal_vals:\n    for headdim in headdim_vals:\n        for batch_size, seqlen in bs_seqlen_vals:\n            torch.cuda.empty_cache()\n            config = (causal, headdim, batch_size, seqlen)\n            nheads = dim // headdim\n            q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False) for _ in range(3)]\n            \n            qkv = torch.stack([q, k, v], dim=2)\n            qkv = qkv.to(torch.bfloat16)\n            f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)\n            time_f[config, \"Pytorch\"] = f\n            res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)\n\n            if attention_triton is not None:\n                q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)\n                k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)\n                v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)\n                scale = 1 / math.sqrt(headdim)\n                f = time_fwd(\n                    attention_triton, q_transposed, k_transposed, v_transposed,\n                    causal, scale, repeats=5, verbose=False, desc='Triton'\n                )\n                f = time_fwd(\n                    attention_triton, q_transposed, k_transposed, v_transposed,\n  
                  causal, scale, repeats=repeats, verbose=False, desc='Triton'\n                )\n                time_f[config, \"Triton\"] = f\n                res = attention_triton(\n                    q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2),\n                    causal, scale\n                ).half().transpose(1, 2)\n                torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)\n\n            # out = torch.empty_like(q)\n            q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)\n            softmax_scale = q.shape[-1] ** (-0.5)\n            descale_q = torch.tensor([1.0], dtype=torch.float32, device='cuda')\n            descale_k = torch.tensor([1.0], dtype=torch.float32, device='cuda')\n            descale_v = torch.tensor([1.0], dtype=torch.float32, device='cuda')\n\n            # f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)\n            f = time_fwd(\n                _flash_attn_forward,\n                q, \n                k, \n                v, \n                softmax_scale, \n                causal=causal,\n                window_size=(-1,-1),\n                descale_q=descale_q, \n                descale_k=descale_k, \n                descale_v=descale_v, \n                repeats=repeats, \n                verbose=False\n            )\n\n            # res = flash_attn_func(q, k, v, causal=causal)\n            # torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05)\n\n            time_f[config, \"Flash3\"] = f\n\n            if cudnn is not None:\n                qkv_fp8 = qkv.to(dtype)\n                time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark\n                f = time_fwd(\n                    cudnn_spda_setup(\n                        qkv_fp8, seqlen, seqlen,\n                        causal=causal\n                    ),\n                    repeats=repeats, verbose=False\n                )\n      
          time_f[config, \"cuDNN\"] = f\n                # res, amax_o = cudnn_spda_setup(\n                #     qkv_fp8, seqlen, seqlen,\n                #     causal=causal\n                # )()\n                # res = res.half()\n                # TODO: CUDNN has numerics issues when\n                # num_heads=16, dim=128, seq_len=1024, batch_size=2\n                # or larger sizes.\n                # res_cpu = res.cpu().reshape(-1)\n                # res_baseline_cpu = res_baseline.cpu().reshape(-1)\n                # print(amax_o)\n                # print(res)\n                # print(res_baseline)\n                # for i in range(len(res_cpu)):\n                #     item = res_cpu[i]\n                #     item_baseline = res_baseline_cpu[i]\n                #     if abs(item - item_baseline) > 0.5:\n                #         print(i)\n                #         print(item)\n                #         print(item_baseline)\n                # torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05)\n\n            print(f\"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###\")\n            for method in methods:\n                speed_f[config, method] = efficiency(\n                    flops(batch_size, seqlen, headdim, nheads, causal, mode=\"fwd\"),\n                    time_f[config, method]\n                )\n                #print (time_f[config,method])\n                print(\n                    f\"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, \"\n                )\n\n\n# with open('flash3_attn_time.plk', 'wb') as fp:\n#     pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)\n"
  },
  {
    "path": "hopper/benchmark_mla_decode.py",
    "content": "# Copyright (c) 2025, Ted Zadouri, Tri Dao.\n\n# We recommend locking GPU clocks before running the benchmark to ensure consistent results.\n# This can be done using the following commands (1830 MHz is the clock for H100):\n# sudo nvidia-smi -i 0 -pm 1\n# sudo nvidia-smi -i 0 --lock-gpu-clocks 1830,1830\n# See more here: https://github.com/triton-lang/triton/blob/d9f10ebdc5da53f73eb852fde73d8d7d80b679d1/python/triton/testing.py#L487\n\nimport time\nimport torch\nimport torch.nn.functional as F\n\nfrom triton.testing import do_bench, do_bench_cudagraph\n\nfrom einops import rearrange\n\nfrom flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata\n\ntry:\n    from flash_mla import flash_mla_with_kvcache, get_mla_metadata\nexcept ImportError:\n    flash_mla_with_kvcache, get_mla_metadata = None, None\n\ntry:\n    from flash_attn.utils.benchmark import pytorch_profiler\nexcept ImportError:\n    pytorch_profiler = None\n\n\ndevice = \"cuda\"\ndtype = torch.bfloat16\nseqlen = 8192\nseqlen_q = 1\n# nheads_q = 16\nnheads_q = 128\n\nuse_bench_cudagraph = False\n\nattn_variants = [\"mha\", \"gqa\", \"mqa\", \"mla\", \"gla\"]\n# for attn_variant in attn_variants:\nfor attn_variant in attn_variants[3:5]:\n    nheads_kv = nheads_q if attn_variant == \"mha\" else (max(nheads_q // 8, 1) if attn_variant == \"gqa\" else (1 if attn_variant == \"mla\" else 2))\n    headdim = 64 if attn_variant in [\"mla\", \"gla\"] else 128\n    headdim_v = 512 if attn_variant == \"mla\" else (256 if attn_variant == \"gla\" else headdim)\n    has_qv = headdim == 64 and headdim_v > 64\n    # page_size = None\n    page_size = 64 if attn_variant in [\"mla\", \"gla\"] else 128\n\n    should_run_flashmla = attn_variant == \"mla\" and page_size == 64 and flash_mla_with_kvcache is not None\n\n    torch.manual_seed(0)\n\n    batch_size = 128\n    cache_seqlens = None\n    # cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int)\n    # 
cache_seqlens = torch.tensor([seqlen - 1, 1024, 1024, 1024], device=device, dtype=torch.int32)\n    # cache_seqlens = torch.tensor([1024] * batch_size, device=device, dtype=torch.int)\n    # cache_seqlens = torch.tensor([4500, 45000, 1800, 1800], dtype=torch.int32, device=device)\n\n    print(f\"\\n{attn_variant.upper()}, nheads_q = {nheads_q}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}\")\n\n    for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]:\n    # for seqlen in [s * 1024 for s in [8]]:\n        cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int)\n        num_splits = 0\n        q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, dtype=dtype, device=device)\n        try:\n            v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device)\n            k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device)\n            if page_size is not None:\n                assert seqlen % page_size == 0\n                k_cache, v_cache = [rearrange(x, \"b (n p) h d -> (b n) p h d\", p=page_size) for x in [k_cache, v_cache]]\n                page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32),\n                                    \"(b s) -> b s\", s=seqlen // page_size)\n            else:\n                page_table = None\n        except torch.OutOfMemoryError:\n            continue\n        qv = torch.randn(batch_size, seqlen_q, nheads_q, headdim_v, dtype=dtype, device=device) if has_qv else None\n\n        # Precomputing this saves ~2us\n        scheduler_metadata = get_scheduler_metadata(\n            batch_size, seqlen_q, seqlen, nheads_q, nheads_kv, headdim,\n            cache_seqlens, q.dtype, headdim_v=headdim_v, page_size=page_size, causal=True\n        )\n        # scheduler_metadata = None\n        # breakpoint()\n        fn0 = lambda: 
flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True, scheduler_metadata=scheduler_metadata)\n        time.sleep(1)  # to avoid power throttling\n        # Time in ms\n        if not use_bench_cudagraph:\n            t0 = do_bench(fn0, warmup=1, rep=10)\n        else:\n            torch.cuda.synchronize()  # Gotta wait, otherwise e.g. k_cache might not be ready\n            with torch.cuda.stream(torch.cuda.Stream()):\n                t0 = do_bench_cudagraph(fn0, rep=10)\n        # exit(0)\n        if should_run_flashmla:\n            # Separate out the preprocessing since this can be done once and reused for all layers\n            mla_metadata = get_mla_metadata(cache_seqlens, seqlen_q * nheads_q // nheads_kv, nheads_kv)\n            q_concat = torch.concat([q, qv], dim=-1) if has_qv else q\n            kv_cache_concat = torch.concat([v_cache, k_cache], dim=-1)\n            fn1 = lambda: flash_mla_with_kvcache(q_concat, kv_cache_concat, page_table, cache_seqlens, headdim_v, *mla_metadata, causal=True)\n            time.sleep(1)  # to avoid power throttling\n            if not use_bench_cudagraph:\n                t1 = do_bench(fn1, warmup=1, rep=10)\n            else:\n                torch.cuda.synchronize()  # Gotta wait, otherwise e.g. 
k_cache might not be ready\n                with torch.cuda.stream(torch.cuda.Stream()):\n                    t1 = do_bench_cudagraph(fn1, rep=10)\n\n        total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item()\n        mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2  # last term is for the output\n        flops = seqlen_q * total_seqlen * nheads_q * (headdim + headdim_v * (2 if has_qv else 1)) * 2\n        ideal_h100_time_mem = mem_io / 3.35e12 * 1e6  # us, at H100 SXM HBM3 peak bandwidth (3.35 TB/s)\n        ideal_h100_time_flop = flops / 989e12 * 1e6  # us, at H100 SXM dense BF16/FP16 peak (989 TFLOPS)\n        ideal_h100_time = max(ideal_h100_time_mem, ideal_h100_time_flop)\n        print(f\"Seqlen = {seqlen}, FA3 time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.1f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS\")\n        if should_run_flashmla:\n            print(f\"Seqlen = {seqlen}, FlashMLA time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t1 * 1e3:.1f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t1 * 1e-3):.0f} TFLOPS\")\n        print(f\"Arithmetic intensity: {flops / mem_io:.1f}\")\n        print(f\"Ideal time: {ideal_h100_time:.0f} us\")\n\n        # if pytorch_profiler is not None:\n        #     time.sleep(1)  # to avoid power throttling\n        #     pytorch_profiler(fn0)\n        #     if should_run_flashmla:\n        #         time.sleep(1)  # to avoid power throttling\n        #         pytorch_profiler(fn1)\n"
  },
  {
    "path": "hopper/benchmark_split_kv.py",
    "content": "import torch\nimport flash_attn\nimport flash_attn_interface\nimport itertools\nimport time\nimport math\n\nimport torch.utils.benchmark as benchmark\n\ndef round_up_to_power_of_2(x):\n    if x <= 1:\n        return 1\n    return 1 << (x - 1).bit_length()\n\ndef timeit(fn, *args, **kwargs):\n    torch.cuda.synchronize()\n\n    # Warmup\n    for _ in range(5):\n        fn(*args, **kwargs)\n\n    # Benchmark using PyTorch Timer\n    t = benchmark.Timer(\n        stmt='fn(*args, **kwargs)',\n        globals={'fn': fn, 'args': args, 'kwargs': kwargs}\n    )\n\n    # Measure execution time\n    measurement = t.timeit(20)  # Runs the function 20 times\n    # measurement = t.blocked_autorange(min_run_time=1)\n    avg_time = measurement.mean  # Average time in seconds\n\n    return avg_time\n\ndef main():\n    num_sms = torch.cuda.get_device_properties(\n        torch.cuda.current_device()\n    ).multi_processor_count\n\n    max_splits = 129\n    check_all_splits = True\n\n    causal = True\n    # causal = False\n    # dtype=torch.float16\n    dtype=torch.bfloat16\n    tp_degree = 1\n\n    torch.manual_seed(42)\n\n    model_configs = [\n        # (\"Gemma-2-2B\", 8, 4, 256),\n        # (\"Gemma-2-9B\", 16, 8, 256),\n        # (\"Gemma-2-27B\", 32, 16, 128),\n        # (\"Qwen-2.5-0.5B\", 14, 2, 64),\n        # (\"Qwen-2.5-1.5B\", 12, 2, 128),\n        # (\"Qwen-2.5-7B\", 28, 4, 128),\n        # (\"Llama-3.1-8B\", 32, 8, 128),\n        (\"Llama-3.1-70B\", 64, 8, 128),\n        # (\"Mistral Large\", 96, 8, 128),\n        # (\"Llama-3.1-405B\", 128, 8, 128),\n        # (\"Llama-3.2-1B\", 32, 8, 64),\n        # (\"Llama-3.2-3B\", 24, 8, 128),\n        # (\"Nemotron-4-15B\", 48, 8, 128),\n    ]\n\n    all_batch_configs = []\n\n    all_batch_configs.extend(itertools.product(\n        # [1024, 2048, 4096, 8192, 16384, 32768, 131072],  # context_seqlen\n        # [4096, 16384, 65536],  # context_seqlen\n        [131072],  # context_seqlen\n        # [i for i in 
range(1, (num_sms) + 1)], # num_requests\n        [1, 4, 8, 16],  # num_requests\n        # [1],  # num_requests\n        # [1, 4, 8, 16],  # query_seqlen\n        [1],  # query_seqlen\n    ))\n\n    num_caches = max(reqs for _, reqs, _ in all_batch_configs)\n    cache_seqlen = max(seqlen for seqlen, _, _ in all_batch_configs)\n\n    for model_name, nheads_q, nheads_kv, headdim in model_configs:\n        assert nheads_kv % tp_degree == 0\n        print(f\"***{model_name}***\")\n        print(f\"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}, TP:{tp_degree}\")\n        nheads_q //= tp_degree\n        nheads_kv //= tp_degree\n\n        k_cache = torch.randn(\n            (num_caches, cache_seqlen, nheads_kv, headdim), device=\"cuda\", dtype=dtype\n        )\n        v_cache = torch.randn(\n            (num_caches, cache_seqlen, nheads_kv, headdim), device=\"cuda\", dtype=dtype\n        )\n\n        if check_all_splits is False:\n            print(f\"{'CONTEXT':<9}{'BSZ':<5}{'QLEN':<6}{'FA2':<10}{'FA3':<9}{'RATIO':<7}{'GB/s':<10}\")\n\n        for context_seqlen, num_requests, query_seqlen in all_batch_configs:\n            bytes_kv = (context_seqlen * num_requests * nheads_kv * headdim * 4)\n            bytes_q = (query_seqlen * num_requests * nheads_q * headdim * 4)\n            blockH = round_up_to_power_of_2(nheads_q//nheads_kv)\n            blockM = 128 # true for hdim 128 causal and hdim 64\n            blockM_div_H = blockM//blockH\n            num_work_tiles = nheads_kv * num_requests * math.ceil(query_seqlen/blockM_div_H)\n\n            q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device=\"cuda\", dtype=dtype)\n            cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device=\"cuda\")[:num_requests]\n            cache_seqlens = torch.tensor(\n                [context_seqlen] * num_requests, dtype=torch.int32, device=\"cuda\"\n            )\n\n            fa2_time_heuristic = timeit(\n                
flash_attn.flash_attn_with_kvcache,\n                q=q,\n                k_cache=k_cache,\n                v_cache=v_cache,\n                cache_seqlens=cache_seqlens,\n                cache_batch_idx=cache_idxs,\n                causal=causal,\n            ) * 1000. * 1000.\n            # fastest_splitk_time = float(\"inf\")\n            # fastest_splitk = 0\n            # for i in range(1, max_splits):\n            #     t = timeit(\n            #         flash_attn.flash_attn_with_kvcache,\n            #         q=q,\n            #         k_cache=k_cache,\n            #         v_cache=v_cache,\n            #         cache_seqlens=cache_seqlens,\n            #         cache_batch_idx=cache_idxs,\n            #         causal=causal,\n            #         num_splits=i,\n            #     ) * 1000. * 1000.\n            #     if t < fastest_splitk_time:\n            #         fastest_splitk_time = t\n            #         fastest_splitk = i\n\n            fa3_time_one_split = timeit(\n                flash_attn_interface.flash_attn_with_kvcache,\n                q=q,\n                k_cache=k_cache,\n                v_cache=v_cache,\n                cache_seqlens=cache_seqlens,\n                cache_batch_idx=cache_idxs,\n                causal=causal,\n                pack_gqa=False,\n                num_splits=1,\n            ) * 1000. * 1000.\n\n            fa3_time_gqa_heuristic = timeit(\n                flash_attn_interface.flash_attn_with_kvcache,\n                q=q,\n                k_cache=k_cache,\n                v_cache=v_cache,\n                cache_seqlens=cache_seqlens,\n                cache_batch_idx=cache_idxs,\n                causal=causal,\n                pack_gqa=True,\n                num_splits=0,\n                # max_seqlen_k_hint=context_seqlen\n            ) * 1000. 
* 1000.\n\n            if check_all_splits:\n\n                fa3_fastest_num_splits = 0\n                fa3_fastest_splitk_time = float(\"inf\")\n\n                for num_splits in range(1, max_splits):\n                    t = timeit(\n                        flash_attn_interface.flash_attn_with_kvcache,\n                        q=q,\n                        k_cache=k_cache,\n                        v_cache=v_cache,\n                        cache_seqlens=cache_seqlens,\n                        cache_batch_idx=cache_idxs,\n                        causal=causal,\n                        pack_gqa=False,\n                        num_splits=num_splits\n                    ) * 1000. * 1000.\n\n                    out0 = flash_attn_interface.flash_attn_with_kvcache(\n                        q=q,\n                        k_cache=k_cache,\n                        v_cache=v_cache,\n                        cache_seqlens=cache_seqlens,\n                        cache_batch_idx=cache_idxs,\n                        causal=causal,\n                        pack_gqa=False,\n                        num_splits=num_splits\n                    )\n\n                    out1 = flash_attn_interface.flash_attn_with_kvcache(\n                        q=q,\n                        k_cache=k_cache,\n                        v_cache=v_cache,\n                        cache_seqlens=cache_seqlens,\n                        cache_batch_idx=cache_idxs,\n                        causal=causal,\n                        pack_gqa=False,\n                        num_splits=1\n                    )\n\n                    max_diff = (out0 - out1).abs().max().item()\n                    mean_diff = (out0 - out1).abs().mean().item()\n                    # print (f\"splits {num_splits}, out diff-max, {max_diff}, out diff-mean, {mean_diff}, time {t:.2f}\")\n                    # print (f\"splits {num_splits}, time {t:.2f}\")\n\n                    if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 
2e-3 or mean_diff > 1e-4:\n                        print(f\"Numerical error too high: Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}\")\n\n                    if t < fa3_fastest_splitk_time:\n                        fa3_fastest_splitk_time = t\n                        fa3_fastest_num_splits = num_splits\n\n                fa3_fastest_num_splits_gqa = 0\n                fa3_fastest_splitk_time_gqa = float(\"inf\")\n                for num_splits in range(1, max_splits):\n\n                    t = timeit(\n                        flash_attn_interface.flash_attn_with_kvcache,\n                        q=q,\n                        k_cache=k_cache,\n                        v_cache=v_cache,\n                        cache_seqlens=cache_seqlens,\n                        cache_batch_idx=cache_idxs,\n                        causal=causal,\n                        pack_gqa=True,\n                        num_splits=num_splits\n                    ) * 1000. * 1000.\n\n                    out0 = flash_attn_interface.flash_attn_with_kvcache(\n                        q=q,\n                        k_cache=k_cache,\n                        v_cache=v_cache,\n                        cache_seqlens=cache_seqlens,\n                        cache_batch_idx=cache_idxs,\n                        causal=causal,\n                        pack_gqa=True,\n                        num_splits=num_splits\n                    )\n\n                    out1 = flash_attn_interface.flash_attn_with_kvcache(\n                        q=q,\n                        k_cache=k_cache,\n                        v_cache=v_cache,\n                        cache_seqlens=cache_seqlens,\n                        cache_batch_idx=cache_idxs,\n                        causal=causal,\n                        pack_gqa=True,\n                        num_splits=1\n                    )\n\n                    max_diff = (out0 - out1).abs().max().item()\n                    mean_diff = (out0 - out1).abs().mean().item()\n   
                 # print (f\"gqa splits {num_splits}, out gqa diff-max {max_diff}, out gqa diff-mean {mean_diff}, time {t:.2f}\")\n                    # print (f\"gqa splits {num_splits}, time {t:.2f}\")\n\n                    if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:\n                        print(f\"Numerical error too high (gqa): Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}\")\n\n                    if t < fa3_fastest_splitk_time_gqa:\n                        fa3_fastest_splitk_time_gqa = t\n                        fa3_fastest_num_splits_gqa = num_splits\n\n                efficiency = (num_work_tiles * fa3_fastest_num_splits_gqa)/num_sms\n                heuristic_ratio = fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa\n                # remeasure to smooth anomalies\n                if heuristic_ratio > 1.1:\n\n                    fa3_time_gqa_heuristic = timeit(\n                        flash_attn_interface.flash_attn_with_kvcache,\n                        q=q,\n                        k_cache=k_cache,\n                        v_cache=v_cache,\n                        cache_seqlens=cache_seqlens,\n                        cache_batch_idx=cache_idxs,\n                        causal=causal,\n                        pack_gqa=True,\n                        # num_splits=num_splits_select,\n                        # num_splits=1,\n                        num_splits=0,\n                        # max_seqlen_k_hint=context_seqlen\n                    ) * 1000. 
* 1000.\n\n                    fa3_fastest_splitk_time_gqa = timeit(\n                        flash_attn_interface.flash_attn_with_kvcache,\n                        q=q,\n                        k_cache=k_cache,\n                        v_cache=v_cache,\n                        cache_seqlens=cache_seqlens,\n                        cache_batch_idx=cache_idxs,\n                        causal=causal,\n                        pack_gqa=True,\n                        num_splits=fa3_fastest_num_splits_gqa\n                    ) * 1000. * 1000.\n\n            if check_all_splits is True:\n                print(\n                    f\"CONTEXT:{context_seqlen}, BSZ:{num_requests}, QLEN:{query_seqlen}, \"\n                    f\"FA2:{fa2_time_heuristic:.2f}, \"\n                    # f\"FA2 MANUAL:{fastest_splitk_time:.2f}, \"\n                    # f\"FA2 NUM SPLITS:{fastest_splitk}, \"\n                    # f\"FA3 NOGQA NOSPLIT:{fa3_time_one_split:.2f}, \"\n                    # f\"FA3 NOGQA SPLIT MANUAL:{fa3_fastest_splitk_time:.2f}, \"\n                    # f\"FA3 NOSPLIT:{fa3_time_one_split_gqa:.2f}, \"\n                    f\"FA3 SPLIT MANUAL:{fa3_fastest_splitk_time_gqa:.2f}, \"\n                    f\"FA3:{fa3_time_gqa_heuristic:.2f}, \"\n                    # f\"FA3 RATIO (NONSPLIT/SPLIT):{fa3_time_one_split_gqa/fa3_time_gqa_heuristic:.2f}, \"\n                    # f\"FA2 NUM SPLITS:{fastest_splitk}, \"\n                    # f\"FA3 NOGQA NUM SPLITS:{fa3_fastest_num_splits}, \"\n                    f\"FA3 NUM SPLITS:{fa3_fastest_num_splits_gqa}, \"\n                    # f\"RATIO (FA2/3):{fa2_time_heuristic/fa3_time_gqa_heuristic:.2f}, \"\n                    f\"RATIO:{fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa:.2f}, \"\n                    f\"EFF:{efficiency:.2f}, \"\n                    f\"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}\"\n                )\n\n            if check_all_splits is False:\n                print(\n                    
f\"{context_seqlen:<9}{num_requests:<5}{query_seqlen:<6}\"\n                    f\"{fa2_time_heuristic:<10.2f}{fa3_time_gqa_heuristic:<9.2f}\"\n                    f\"{fa2_time_heuristic/fa3_time_gqa_heuristic:<7.2f}\"\n                    f\"{bytes_kv/fa3_time_gqa_heuristic * 1e-3:<10.2f}\"\n                )\n\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "hopper/block.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\nnamespace flash {\n\ntemplate <class SeqlenInfo_t, int kBlockM, int kBlockN, bool Is_causal, bool Is_local, bool PackGQA=false, bool Split=false>\nstruct BlockMN {\n\n    static\n    CUTLASS_DEVICE\n    cute::tuple<int, int> get_n_block_min_max(\n            SeqlenInfo_t const& seqlen_info,\n            int const m_block, int const bidb, int const split_idx, int const num_splits,\n            int const window_size_left, int const window_size_right,\n            cutlass::FastDivmod const& attention_chunk_divmod,\n            cutlass::FastDivmod const& qhead_per_khead_divmod) {\n\n        int const seqlen_k = seqlen_info.seqlen_k;\n        int const seqlen_q = seqlen_info.seqlen_q;\n        int n_block_max = cute::ceil_div(seqlen_k, kBlockN);\n        if constexpr (Is_causal || Is_local) {\n            int m_idx_max = (m_block + 1) * kBlockM;\n            // TODO: check off-by-1 error\n            if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }\n            int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;\n            int n_idx_right = !Is_local ? 
n_idx : n_idx + window_size_right;\n            if (Is_local && attention_chunk_divmod.divisor > 0) {\n                n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));\n            }\n            n_block_max = std::min(n_block_max, cute::ceil_div(n_idx_right, kBlockN));\n        }\n        int n_block_min = 0;\n        if constexpr (Is_local) {\n            int m_idx_min = m_block * kBlockM;\n            if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }\n            int const n_idx = m_idx_min + seqlen_k - seqlen_q;\n            int n_idx_left = n_idx - window_size_left;\n            if (attention_chunk_divmod.divisor > 0) {\n                n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));\n            }\n            n_block_min = std::max(int(0), n_idx_left / kBlockN);\n        }\n        // if (threadIdx.x == 128) { printf(\"Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\\n\", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }\n        if constexpr (Split) {\n            uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits\n            int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);\n            int split_idx_actual = split_idx & 0x0000FFFF;\n            int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;\n            int num_n_blocks_per_split = n_block_max <= n_block_min ? 
0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual);\n            n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split;\n            n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);\n            // if (threadIdx.x == 128) { printf(\"Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\\n\", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }\n        }\n        // if (threadIdx.x == 128) { printf(\"After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\\n\", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }\n        return {n_block_min, n_block_max};\n    }\n\n    static\n    CUTLASS_DEVICE\n    cute::tuple<int, int> get_n_block_k_new_min_max(\n            SeqlenInfo_t const& seqlen_info,\n            int const m_block, int const bidb, int const split_idx, int const num_splits,\n            int const window_size_left, int const window_size_right,\n            cutlass::FastDivmod const& attention_chunk_divmod,\n            cutlass::FastDivmod const& qhead_per_khead_divmod) {\n\n        auto [n_block_min, n_block_max] = get_n_block_min_max(\n            seqlen_info, m_block, bidb, split_idx, num_splits,\n            window_size_left, window_size_right, attention_chunk_divmod, qhead_per_khead_divmod);\n        int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);\n        int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);\n        int const n_block_new_min = idx_k_new_min / kBlockN;\n        int const n_block_new_max = idx_k_new_max > idx_k_new_min ? 
cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;\n        // if (threadIdx.x == 128 && m_block == 0) { printf(\"bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\\n\", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}\n        return {n_block_new_min, n_block_new_max};\n    }\n\n    static\n    CUTLASS_DEVICE\n    cute::tuple<int, int> get_m_block_min_max(\n            SeqlenInfo_t const& seqlen_info,\n            int const n_block, int const bidb,\n            int const window_size_left, int const window_size_right, int const sink_token_length) {\n        // TODO: support attention_chunk\n        int const seqlen_q = seqlen_info.seqlen_q;\n        int const seqlen_k = seqlen_info.seqlen_k;\n        int m_block_max = cute::ceil_div(seqlen_q, kBlockM);\n        if constexpr (Is_local) {\n            if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) {\n                m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));\n            }\n        }\n        int m_block_min = 0;\n        if constexpr (Is_causal || Is_local) {\n            m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);\n        }\n        return {m_block_min, m_block_max};\n    }\n\n    // If we have separate iterations with causal or local masking at the start, where do we stop\n    static\n    CUTLASS_DEVICE\n    int get_n_block_min_causal_local_mask(\n            SeqlenInfo_t const& seqlen_info,\n            int const m_block, int const n_block_min, int const window_size_right,\n            cutlass::FastDivmod const& attention_chunk_divmod,\n            cutlass::FastDivmod const& qhead_per_khead_divmod) {\n        int const m_idx_min = !PackGQA ? 
m_block * kBlockM : qhead_per_khead_divmod.divide(m_block * kBlockM);\n        int const n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q;\n        int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right;\n        if (Is_local && attention_chunk_divmod.divisor > 0) {\n            n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));\n        }\n        return std::max(n_block_min, n_idx_right / kBlockN);\n    }\n\n    // If we have separate iterations with local masking at the end, where do we stop the non-masked iterations\n    static\n    CUTLASS_DEVICE\n    int get_n_block_min_before_local_mask(\n            SeqlenInfo_t const& seqlen_info,\n            int const m_block, int const n_block_min, int const window_size_left,\n            cutlass::FastDivmod const& attention_chunk_divmod,\n            cutlass::FastDivmod const& qhead_per_khead_divmod) {\n        int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;\n        int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;\n        int n_idx_left = !Is_local ? n_idx : n_idx - window_size_left;\n        if (Is_local && attention_chunk_divmod.divisor > 0) {\n            n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));\n        }\n        return !Is_local ? n_block_min : std::max(n_block_min, cute::ceil_div(n_idx_left, kBlockN));\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/copy_sm90_bulk_reduce.hpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include<cute/arch/copy_sm90_tma.hpp>\n\nnamespace cute\n{\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct SM90_BULK_REDUCE_ADD\n{\n  CUTE_HOST_DEVICE static void\n  copy(float const* smem_ptr,\n       float      * gmem_ptr, int32_t store_bytes)\n  {\n#if defined(CUTE_ARCH_TMA_SM90_ENABLED)\n    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);\n    asm volatile(\"cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\\n\"\n                     :\n                     : \"l\"(gmem_ptr), \"r\"(smem_int_ptr), \"r\"(store_bytes)\n                     : \"memory\");\n#else\n    CUTE_INVALID_CONTROL_PATH(\"Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.\");\n#endif\n  }\n\n  CUTE_HOST_DEVICE static void\n  copy(float const* smem_ptr,\n       float      * gmem_ptr, int32_t store_bytes, uint64_t cache_hint)\n  {\n#if defined(CUTE_ARCH_TMA_SM90_ENABLED)\n    uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);\n    asm volatile(\"cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\\n\"\n                     :\n                     : \"l\"(gmem_ptr), \"r\"(smem_int_ptr), \"r\"(store_bytes), \"l\"(cache_hint)\n                     : \"memory\");\n#else\n    CUTE_INVALID_CONTROL_PATH(\"Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.\");\n#endif\n  }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n} // end namespace cute\n"
  },
  {
    "path": "hopper/cuda_check.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <assert.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n#include <cutlass/cutlass.h>\n\n#define CHECK_CUDA(call)                                                                                  \\\n    do {                                                                                                  \\\n        cudaError_t status_ = (call);                                                                     \\\n        if (status_ != cudaSuccess) {                                                                     \\\n            fprintf(stderr, \"CUDA error (%s:%d): %s\\n\", __FILE__, __LINE__, cudaGetErrorString(status_)); \\\n            exit(1);                                                                                      \\\n        }                                                                                                 \\\n    } while(0)\n\n#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())\n\n#define CHECK_CUTLASS(call)                                                                               \\\n    do {                                                                                                  \\\n        cutlass::Status status_ = (call);                                                                 \\\n        if (status_ != cutlass::Status::kSuccess) {                                                        \\\n            fprintf(stderr, \"CUTLASS error (%s:%d): %s\\n\", __FILE__, __LINE__, cutlass::cutlassGetStatusString(status_)); \\\n            exit(1);                                                                                      \\\n        }                                                                                                 \\\n    } while(0)\n"
  },
  {
    "path": "hopper/epilogue_bwd.hpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/barrier.h\"\n#include \"cute/tensor.hpp\"\n\n#include \"cutlass/gemm/collective/builders/sm90_common.inl\"\n\n#include \"seqlen.h\"\n#include \"named_barrier.hpp\"\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <class TileShape_MNK_, class Element_, class ArchTag_,\n          int NumEpilogueThreads_, bool Varlen_, bool dKV_swapAB_, int AtomLayoutKdKV=1>\nstruct CollectiveEpilogueBwd {\n\n    using TileShape_MNK = TileShape_MNK_;\n    using Element = Element_;\n    using ArchTag = ArchTag_;\n    static constexpr int NumEpilogueThreads = NumEpilogueThreads_;\n    static constexpr bool Varlen = Varlen_;\n    static constexpr bool dKV_swapAB = dKV_swapAB_;\n    static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90;\n\n    static_assert(ArchTag::kMinComputeCapability >= 80);\n\n    using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;\n\n    // These are for storing the output tensor without TMA (e.g., for setting output to zero)\n    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);\n    static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, \"Headdim must be a multiple of kGmemElemsPerLoad\");\n    static constexpr int kHeadDim = get<2>(TileShape_MNK{});\n    static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);\n    static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, \"NumEpilogueThreads must be a multiple of kGmemThreadsPerRow\");\n    using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                       
           Stride<Int<kGmemThreadsPerRow>, _1>>;\n    using GmemTiledCopydKV = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per store\n\n    using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,\n                                          // TODO: do we have to change this if dKV_swapAB is true?\n                                          decltype(cute::get<1>(TileShape_MNK{})), Int<CUTE_STATIC_V(cute::get<2>(TileShape_MNK{})) / AtomLayoutKdKV>>());\n    using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));\n    using SmemLayoutdKVtTMA =\n        decltype(cute::composition(SmemLayoutdKVTMA{},\n                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),\n                                               make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));\n\n    // If we don't use TMA\n    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);\n    static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 
2 : 1);\n    using SmemLayoutAtomdKVSTG =\n        decltype(composition(Swizzle<kSwizzle, 3, 3>{},\n                             Layout<Shape<Int<8>, Int<kBlockKSmem>>,\n                             Stride<Int<kBlockKSmem>, _1>>{}));\n\n    using SmemLayoutAtomdKV = std::conditional_t<Use_TMA, SmemLayoutAtomdKVTMA, SmemLayoutAtomdKVSTG>;\n    using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));\n    using SmemLayoutdKVt =\n        decltype(cute::composition(SmemLayoutdKV{},\n                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),\n                                               make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));\n\n    using SmemCopyAtomdKV = Copy_Atom<\n        std::conditional_t<\n            ArchTag::kMinComputeCapability >= 90,\n            std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,\n            AutoVectorizingCopyWithAssumedAlignment<128>\n        >,\n        Element>;\n\n    static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? 
cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128;\n    static_assert(SmemAlignmentdKV >= 128, \"Require at least 128B alignment\");\n\n    struct TensorStorage : cute::aligned_struct<SmemAlignmentdKV> {\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dk;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dv;\n    };\n\n    using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_k, d, head, batch)\n    using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;\n\n    using TMA_dKV = std::conditional_t<\n        Use_TMA,\n        decltype(make_tma_copy(\n            GmemTiledCopydKVTMA{},\n            make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapedKV{}, StridedKV{}),\n            SmemLayoutdKVTMA{},\n            select<1, 2>(TileShape_MNK{}),\n            _1{})),  // no mcast for dKV\n        std::nullptr_t\n        >;\n\n    // Host side kernel arguments\n    struct Arguments {\n        Element* ptr_dK;\n        ShapedKV const shape_dK;\n        StridedKV const stride_dK;\n        Element* ptr_dV;\n        ShapedKV const shape_dV;\n        StridedKV const stride_dV;\n        int const num_batch;\n        int const num_heads_q;\n        int* dk_semaphore;\n        int* dv_semaphore;\n        int const* cu_seqlens;\n        int const* seqused;\n    };\n\n    // Device side kernel params\n    struct Params {\n        Element* ptr_dK;\n        ShapedKV const shape_dK;\n        StridedKV const stride_dK;\n        Element* ptr_dV;\n        ShapedKV const shape_dV;\n        StridedKV const stride_dV;\n        TMA_dKV tma_store_dK, tma_store_dV;\n        int const* cu_seqlens = nullptr;\n        int const* seqused = nullptr;\n    };\n\n    static Params\n    to_underlying_arguments(Arguments const& args) {\n        Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);\n        Tensor mdV = 
make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV);\n        TMA_dKV tma_store_dK = [&] {\n            if constexpr (Use_TMA) {\n                return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV\n            } else {\n                return nullptr;\n            }\n        }();\n        TMA_dKV tma_store_dV = [&] {\n            if constexpr (Use_TMA) {\n                return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV\n            } else {\n                return nullptr;\n            }\n        }();\n        return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.shape_dV, args.stride_dV,\n                tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};\n    }\n\n    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance\n    CUTLASS_DEVICE\n    static void prefetch_tma_descriptors(Params const& params) {\n        if constexpr (Use_TMA) {\n            cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());\n            cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());\n        }\n    }\n\n    template <typename SharedStorage, typename FrgTensorO, typename TiledMma>\n    CUTLASS_DEVICE void\n    store(Params const& params,\n          FrgTensorO const& tdKrdK,\n          FrgTensorO const& tdVrdV,\n          SharedStorage& shared_storage,\n          TiledMma tiled_mma,\n          int thread_idx,\n          cute::tuple<int32_t, int32_t, int32_t> const& block_coord\n          ) {\n\n        auto [n_block, bidh, bidb] = block_coord;\n        Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{}));\n        Tensor sdV = 
cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{}));\n        Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{}));\n        Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{}));\n        auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);\n        auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);\n\n        Tensor tdVrdV_out = make_tensor_like<Element>(tdVrdV);\n        flash::convert_type_out(tdVrdV, tdVrdV_out);\n        Tensor tdKrdK_out = make_tensor_like<Element>(tdKrdK);\n        flash::convert_type_out(tdKrdK, tdKrdK_out);\n        Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out);        // ((Atom,AtomNum), MMA_M, MMA_N)\n        Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out);        // ((Atom,AtomNum), MMA_M, MMA_N)\n        // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf(\"\\n\"); print(sdKt); printf(\"\\n\"); }\n        Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdK, sdKt));     // ((Atom,AtomNum),PIPE_M,PIPE_N)\n        Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdV, sdVt));     // ((Atom,AtomNum),PIPE_M,PIPE_N)\n\n        // Make sure all WGs have finished reading K and V\n        flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);\n        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);\n        if constexpr (Use_TMA) {\n            cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA\n            
cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,\n                                                cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n\n            Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);\n            Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dV);\n            Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)\n            Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)\n            auto block_tma_dK = params.tma_store_dK.get_slice(_0{});\n            auto block_tma_dV = params.tma_store_dV.get_slice(_0{});\n            Tensor tdKgdK = block_tma_dK.partition_D(gdK);  // (TMA, TMA_M, TMA_K)\n            Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)\n            Tensor tdVgdV = block_tma_dV.partition_D(gdV);  // (TMA, TMA_M, TMA_K)\n            Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)\n            int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);\n            if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {\n                cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,\n                                                cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n                if (cute::elect_one_sync()) {\n                    cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);\n                    cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);\n                    tma_store_arrive();\n                }\n            }\n            tma_store_wait<0>();\n            // // Tell warp 0 that smem_k and smem_v are ready\n            // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) 
/*id*/);\n\n        } else {\n            flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n            static constexpr int kBlockN = get<1>(TileShape_MNK{});\n            flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};\n            bool const is_varlen = Varlen && params.cu_seqlens;\n            Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);\n            Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)\n            Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);\n            Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)\n\n            GmemTiledCopydKV gmem_tiled_copy_dKV;\n            auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);\n            Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);\n            Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K)\n            Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);\n            Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)\n            Tensor tdKVrdV = make_fragment_like(tdKVgdV);\n            Tensor tdKVrdK = make_fragment_like(tdKVgdK);\n            Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));  // (BLK_N,BLK_K) -> (blk_n,blk_k)\n            // Repeat the partitioning with identity layouts\n            Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);\n            Tensor tdKVpdV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));\n            Tensor tdKVpdK = 
make_tensor<bool>(make_shape(size<2>(tdKVgdK)));\n            #pragma unroll\n            for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); }\n            #pragma unroll\n            for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }\n            // Need to check OOB when reading from smem if kBlockN isn't evenly tiled\n            static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;\n            flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(\n                gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdV, kBlockN);\n            flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(\n                gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdK, kBlockN);\n            // // Tell warp 0 that smem_k and smem_v are ready\n            // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v\n            // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);\n            // Construct identity layout for gdKV\n            // Clear_OOB_K must be false since we don't want to write zeros to gmem\n            flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n                gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)\n            );\n            flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n                gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdK, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)\n            );\n        }\n    }\n\n    CUTLASS_DEVICE void\n    store_tail() {\n        // if constexpr 
(Use_TMA) { tma_store_wait<0>(); }\n    }\n\n    // Write 0 to dK and dV\n    CUTLASS_DEVICE void\n    store_zero(\n         Params const& params,\n         int thread_idx,\n         cute::tuple<int32_t, int32_t, int32_t> const& block_coord\n         ) {\n        static constexpr int kBlockN = get<1>(TileShape_MNK{});\n        auto [n_block, bidh, bidb] = block_coord;\n        flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};\n        bool const is_varlen = Varlen && params.cu_seqlens;\n        Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);\n        Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)\n        Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);\n        Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)\n\n        GmemTiledCopydKV gmem_tiled_copy_dKV;\n        auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);\n        Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);\n        Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);\n        Tensor tdKVrdKV = make_fragment_like(tdKVgdK);\n        clear(tdKVrdKV);\n        // Construct identity layout for gdKV\n        Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));  // (BLK_M,BLK_K) -> (blk_m,blk_k)\n        // Repeat the partitioning with identity layouts\n        Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);\n        Tensor tdKVpdK = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));\n        Tensor tdKVpdV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));\n        #pragma unroll\n        for (int k = 0; k < 
size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }\n        #pragma unroll\n        for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); }\n        // Clear_OOB_K must be false since we don't want to write zeros to gmem\n        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n            gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdK, seqlen_info.seqlen - n_block * kBlockN\n        );\n        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n            gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdV, seqlen_info.seqlen - n_block * kBlockN\n        );\n    }\n\n};\n\ntemplate <class TileShape_MNK_, class ElementAccum, class ArchTag_,\n          int NumEpilogueThreads_, bool Varlen_, bool Deterministic>\nstruct CollectiveEpilogueBwdGQA {\n\n    using TileShape_MNK = TileShape_MNK_;\n    using Element = ElementAccum;\n    using ArchTag = ArchTag_;\n    static constexpr int NumEpilogueThreads = NumEpilogueThreads_;\n    static constexpr bool Varlen = Varlen_;\n    static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90;\n\n    static_assert(ArchTag::kMinComputeCapability >= 80);\n\n    static constexpr int kBlockN = get<1>(TileShape_MNK{});\n    static constexpr int kHeadDim = get<2>(TileShape_MNK{});\n    static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, \"NumEpilogueThreads must be a multiple of NumThreadsPerWarp\");\n    static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup;\n    // Thread layout, 256 or 384 threads per row\n    // We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ\n    using R2SLayoutAtomdKVaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumWarpGroups>>>;\n    using R2STiledCopydKVaccum = 
decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdKVaccum{},\n                                                         Layout<Shape < _4>>{}));  // Val layout, 4 vals per store\n    // For Sm80\n    using R2GLayoutAtomdKVaccum = Layout<Shape<Int<NumEpilogueThreads>>>;\n    using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2GLayoutAtomdKVaccum{},\n                                                         Layout<Shape < _1>>{}));  // Val layout, 1 val per store\n\n    using SmemLayoutdKVaccum = Layout<Shape<Int<kBlockN * kHeadDim / NumWarpGroups>, Int<NumWarpGroups>>>;\n    using SmemLayoutdKVaccumFlat = Layout<Shape<Int<kBlockN * kHeadDim>>>;\n\n    // Strangely, without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we\n    // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue.\n    static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 
512 : 256);\n    struct TensorStorageTMA : cute::aligned_struct<SmemAlignment> {\n        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdKVaccum>, SmemAlignment> smem_dkv;\n    };\n    struct TensorStorageSTG {\n        cute::array<ElementAccum, 0> smem_dkv;\n    };\n    using TensorStorage = std::conditional_t<Use_TMA, TensorStorageTMA, TensorStorageSTG>;\n\n    using ShapedKV = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_k_rounded * d, head, batch)\n    using StridedKV = cute::Stride<_1, int64_t, int64_t>;\n\n    // Host side kernel arguments\n    struct Arguments {\n        ElementAccum* ptr_dKaccum;\n        ShapedKV const shape_dKaccum;\n        StridedKV const stride_dKaccum;\n        ElementAccum* ptr_dVaccum;\n        ShapedKV const shape_dVaccum;\n        StridedKV const stride_dVaccum;\n        int const num_batch;\n        int const num_heads_q;\n        int* dk_semaphore;\n        int* dv_semaphore;\n        int const* cu_seqlens;\n        int const* seqused;\n    };\n\n    // Device side kernel params\n    struct Params {\n        ElementAccum* ptr_dKaccum;\n        ShapedKV const shape_dKaccum;\n        StridedKV const stride_dKaccum;\n        ElementAccum* ptr_dVaccum;\n        ShapedKV const shape_dVaccum;\n        StridedKV const stride_dVaccum;\n        cutlass::FastDivmod qhead_per_khead_divmod;\n        int* dk_semaphore;\n        int* dv_semaphore;\n        int const num_batch;\n        int const* cu_seqlens = nullptr;\n        int const* seqused = nullptr;\n    };\n\n    static Params\n    to_underlying_arguments(Arguments const& args) {\n        if constexpr (Deterministic) {\n            assert(args.dk_semaphore != nullptr);\n            assert(args.dv_semaphore != nullptr);\n        }\n        return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum,\n                cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))),\n    
            args.dk_semaphore, args.dv_semaphore,\n                args.num_batch, args.cu_seqlens, args.seqused};\n    }\n\n    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance\n    CUTLASS_DEVICE\n    static void prefetch_tma_descriptors(Params const& params) {\n    }\n\n    template <typename SharedStorage, typename FrgTensorO, typename TiledMma>\n    CUTLASS_DEVICE void\n    store(Params const& params,\n          FrgTensorO const& tdKrdK,\n          FrgTensorO const& tdVrdV,\n          SharedStorage& shared_storage,\n          TiledMma tiled_mma,\n          int thread_idx,\n          cute::tuple<int32_t, int32_t, int32_t> const& block_coord\n          ) {\n\n        auto [n_block, bidh, bidb] = block_coord;\n        int bidh_idx_in_group;\n        int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh);\n        Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{});\n        Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{});\n        static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum);\n\n        flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused};\n        bool const is_varlen = Varlen && params.cu_seqlens;\n        Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0);\n        Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? 
bidb : 0);\n        Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block));  // (M * K)\n        Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block));  // (M * K)\n\n        R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum;\n        auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);\n        Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV);\n\n        // Only used if !Use_TMA\n        R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum;\n        auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx);\n\n        // Make sure all WGs have finished reading K and V, otherwise we get racy dQ\n        // because smem_q could be changed.\n        flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n        if constexpr (Use_TMA) {\n            Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N)\n            cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum);\n        }\n\n        int const num_batch = params.num_batch;\n        // int const num_batch = get<2>(params.shape_dKaccum); // erroneously returns 1 for varlen\n        int const num_head_kv = get<1>(params.shape_dKaccum);\n        int *lock_ptr = !Deterministic ? 
nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv;\n        using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;\n\n        // if (thread_idx == 0) { printf(\"blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\\n\", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}\n\n        if constexpr (Deterministic) {\n            Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);\n        }\n        // if (thread_idx == 0) { printf(\"After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\\n\", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);}\n        if constexpr (Use_TMA) {\n            cutlass::arch::fence_view_async_shared();\n            cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n            if (thread_idx == 0) {\n                SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));\n                tma_store_arrive();\n                tma_store_wait<0>();\n            }\n        } else {\n            Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV);\n            Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum);\n            static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic)));\n            #pragma unroll\n            for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); }\n        }\n        if constexpr (Deterministic) {\n            Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch 
* num_head_kv);\n        }\n\n        if constexpr (Use_TMA) {\n            cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n            Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N)\n            cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum);\n        }\n        lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv;\n        // if (thread_idx == 0) { printf(\"blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\\n\", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}\n\n        if constexpr (Deterministic) {\n            Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);\n        }\n        // if (thread_idx == 0) { printf(\"After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\\n\", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);}\n        if constexpr (Use_TMA) {\n            cutlass::arch::fence_view_async_shared();\n            cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n            if (thread_idx == 0) {\n                SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));\n                tma_store_arrive();\n                tma_store_wait<0>();\n            }\n        } else {\n            Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK);\n            Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum);\n            
static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic)));\n            #pragma unroll\n            for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); }\n        }\n        if constexpr (Deterministic) {\n            Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);\n        }\n        // // Tell warp 0 that smem_k and smem_v are ready\n        // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);\n    }\n\n    CUTLASS_DEVICE void\n    store_tail() {\n    }\n\n    // Write 0 to dK and dV\n    CUTLASS_DEVICE void\n    store_zero(\n         Params const& params,\n         int thread_idx,\n         cute::tuple<int32_t, int32_t, int32_t> const& block_coord\n         ) {\n        // Don't need to do anything since dKaccum and dVaccum are already zero-initialized\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/epilogue_fwd.hpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cutlass/cutlass.h>\n#include <cutlass/fast_math.h>  // For FastDivMod\n#include \"cute/tensor.hpp\"\n\n#include \"cutlass/gemm/collective/builders/sm90_common.inl\"\n#include \"cutlass/epilogue/collective/builders/sm90_common.inl\"\n\n#include \"seqlen.h\"\n#include \"named_barrier.hpp\"\n#include \"pack_gqa.h\"\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <class TileShape_MNK_PV_, class ClusterShape_, class Element_, class ArchTag_,\n          int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false>\nstruct CollectiveEpilogueFwd {\n\n    using TileShape_MNK_PV = TileShape_MNK_PV_;\n    using ClusterShape = ClusterShape_;\n    using Element = Element_;\n    using ElementPartial = float;\n    using ArchTag = ArchTag_;\n    static constexpr int NumEpilogueThreads = NumEpilogueThreads_;\n    static constexpr bool Varlen = Varlen_;\n    static constexpr bool PackGQA = PackGQA_;\n    static constexpr bool Split = Split_;\n    static constexpr bool Use_smem = !(Split && !Varlen);\n    static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;\n\n    static_assert(ArchTag::kMinComputeCapability >= 80);\n    static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);\n    static_assert(sizeof(Element) <= 2);\n\n    static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});\n    static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});\n\n    static constexpr bool LargeHeadDimV = kHeadDimV > 256;\n\n    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;\n\n    // These are for storing the output tensor without TMA 
(e.g., for setting output to zero)\n    static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);\n    static_assert(kHeadDimV % kGmemElemsPerStore == 0, \"Headdim must be a multiple of kGmemElemsPerStore\");\n    // We want each \"row\" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements\n    // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times\n    // we need to call divmod.\n    static constexpr int kBytePerRow = kHeadDimV * sizeof(Element);\n    static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);\n    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;\n    // If PackGQA, we split the work of computing O_ptr among threads in the same row, so the threads of a row must be within one warp\n    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);\n    static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, \"NumEpilogueThreads must be a multiple of kGmemThreadsPerRow\");\n    using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                                  Stride<Int<kGmemThreadsPerRow>, _1>>;\n    static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, \"kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow\");\n    using GmemTiledCopyO = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerStore>>>{}));  // Val layout, 8 or 16 vals per store\n\n    using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,\n        decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>());\n    using SmemLayoutOTMA = 
decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{})));\n    static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));\n    static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);\n    using SmemLayoutAtomO = decltype(\n        composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},\n                    Layout<Shape<_8, Int<kBlockKGmem>>,\n                           Stride<Int<kBlockKGmem>, _1>>{}));\n    using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{})));\n    using SmemLayoutO = std::conditional_t<ArchTag::kMinComputeCapability >= 90, SmemLayoutOTMA, SmemLayoutOSTS>;\n\n    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch, num_splits)\n    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;\n    using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>;            // (seqlen_q, head, batch, num_splits)\n    // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)\n    using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;\n    using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;\n    // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)\n    using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;\n    using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;\n\n    using EpilogueTile_MN = decltype(select<0, 1>(TileShape_MNK_PV{}));\n    using CopyOpR2S = std::conditional_t<\n        ArchTag::kMinComputeCapability >= 90,\n        // 
cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)\n        decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element, EpilogueTile_MN>()),\n        AutoVectorizingCopyWithAssumedAlignment<128>\n    >;\n    using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;\n\n    // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{});\n    // static_assert(SmemAlignmentO >= 128, \"Require at least 128B alignment\");\n    // struct TensorStorage : cute::aligned_struct<SmemAlignmentO> {\n    //     cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0, SmemAlignmentO> smem_o;\n    // };\n    struct TensorStorage : cute::aligned_struct<128> {\n        cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;\n    };\n\n    using TMA_O = std::conditional_t<\n        Use_TMA_O,\n        decltype(make_tma_copy(\n            GmemTiledCopyOTMA{},\n            make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),\n            SmemLayoutOTMA{},\n            select<0, 1>(TileShape_MNK_PV{}),\n            _1{})),  // no mcast for O\n        std::nullptr_t\n    >;\n\n    // Host side kernel arguments\n    struct Arguments {\n        Element* ptr_O;\n        ShapeO const shape_O;\n        StrideO const stride_O;\n        ElementPartial* ptr_O_partial;\n        StrideO const stride_O_partial;\n        float* ptr_LSE;\n        StrideLSE const stride_LSE;\n        float* ptr_LSE_partial;\n        StrideLSE const stride_LSE_partial;\n        int32_t const nheads_kv;\n        int const* cu_seqlens = nullptr;\n        int const* seqused = nullptr;\n    };\n\n    // Device side kernel params\n    struct Params {\n        Element* ptr_O;\n        ShapeO const shape_O;\n        StrideO const stride_O;\n        ShapeOPacked const shape_O_packed;\n        StrideOPacked const stride_O_packed;\n        ElementPartial* ptr_O_partial;\n     
   StrideO const stride_O_partial;\n        StrideOPacked const stride_O_partial_packed;\n        float* ptr_LSE;\n        StrideLSE const stride_LSE;\n        ShapeLSEPacked const shape_LSE_packed;\n        StrideLSEPacked const stride_LSE_packed;\n        float* ptr_LSE_partial;\n        StrideLSE const stride_LSE_partial;\n        StrideLSEPacked const stride_LSE_partial_packed;\n        cutlass::FastDivmod qhead_per_khead_divmod;\n        TMA_O tma_store_O;\n        int const* cu_seqlens = nullptr;\n        int const* seqused = nullptr;\n    };\n\n    static Params\n    to_underlying_arguments(Arguments const& args) {\n        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);\n        TMA_O tma_store_O = [&]{\n            if constexpr (Use_TMA_O) {\n                return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast\n            } else {\n                return nullptr;\n            }\n        }();\n        // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)\n        int const qhead_per_khead = !PackGQA ? 
1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);\n        auto const shape_O_packed = cute::conditional_return<!PackGQA>(\n            args.shape_O,\n            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))\n        );\n        auto const stride_O_packed = cute::conditional_return<!PackGQA>(\n            args.stride_O,\n            make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))\n        );\n        auto const stride_O_partial_packed = cute::conditional_return<!PackGQA>(\n            args.stride_O_partial,\n            make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial))\n        );\n        // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)\n        auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(\n            select<0, 2, 3, 4>(args.shape_O),\n            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))\n        );\n        auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(\n            args.stride_LSE,\n            make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))\n        );\n        auto const stride_LSE_partial_packed = cute::conditional_return<!PackGQA>(\n            args.stride_LSE_partial,\n            make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), 
get<3>(args.stride_LSE_partial))\n        );\n        return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,\n                args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed,\n                args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,\n                args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed,\n                cutlass::FastDivmod(qhead_per_khead),\n                tma_store_O, args.cu_seqlens, args.seqused};\n    }\n\n    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance\n    CUTLASS_DEVICE\n    static void prefetch_tma_descriptors(Params const& params) {\n        if constexpr (Use_TMA_O) {\n            cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());\n        }\n    }\n\n    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>\n    CUTLASS_DEVICE void\n    store(Params const& params,\n          FrgTensorO& tOrO,\n          FrgTensorLSE const& lse,\n          SharedStorage& shared_storage,\n          TiledMma tiled_mma,\n          int thread_idx,\n          cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord\n          ) {\n\n        auto [m_block, bidh, bidb, split_idx] = block_coord;\n        int num_splits = get<4>(params.shape_O_packed);\n        if constexpr (Split && Varlen) {\n            uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits\n            int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);\n            num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;\n            split_idx &= 0x0000FFFF;  // Only use the lower 16 bits of split_idx\n        }\n        bool const is_split = !Split ? false : (!Varlen ? 
true : num_splits > 1);\n\n        Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});\n        // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);\n\n        static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4);\n        // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion.\n        // Otherwise we can permute after conversion.\n        if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); }\n        Tensor tOrO_out = make_tensor_like<Element>(tOrO);\n        flash::convert_type_out(tOrO, tOrO_out);\n        if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }\n\n        // Make sure all WGs have finished reading V\n        // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that\n        // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with\n        // cp.async if we need).\n        flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n\n        // Step 1: Write O from rmem -> smem\n        if constexpr (Use_smem) {\n            auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);\n            auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);\n            Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);        // ((Atom,AtomNum), MMA_M, MMA_N)\n            Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)\n            // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi);     // ((Atom,AtomNum),PIPE_M,PIPE_N)\n            cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);\n            if constexpr (Use_TMA_O) {\n                cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA\n                
cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,\n                                                    cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n            } else {\n                flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n            }\n        } else {\n            if constexpr (ArchTag::kMinComputeCapability >= 90) {\n                #pragma unroll\n                for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {\n                    shared_storage.pipelines.barrier_O.arrive(cta_id);\n                }\n            }\n        }\n\n        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};\n        bool is_varlen = Varlen && params.cu_seqlens;\n        int offset_o = seqlen_info.offset;\n        int seqlen_o = seqlen_info.seqlen;\n        int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);\n\n        // Step 2: Write LSE from rmem -> gmem\n        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);\n        // (MMA,MMA_M,MMA_K)\n        Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));\n        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);\n        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);\n        Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));\n        Tensor taccOcO_row = taccOcO_rowcol(_, _0{});\n        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M\n\n        using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;\n        using PackGQApartial_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>;\n\n        Tensor 
mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),\n                                  params.shape_LSE_packed,\n                                  !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);\n        // if (thread_idx == 0) { printf(\"Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\\n\", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf(\"\\n\"); }\n        if (!LargeHeadDimV || warp_group_idx == 0) {\n            if constexpr (!PackGQA) {\n                #pragma unroll\n                for (int mi = 0; mi < size(lse); ++mi) {\n                    int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));\n                    if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }\n                }\n            } else {\n                PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);\n            }\n        }\n\n        // Step 3: Write O from smem -> gmem\n        if constexpr (Use_TMA_O) {\n            Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);\n            Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));  // (M, K)\n            auto block_tma_O = params.tma_store_O.get_slice(_0{});\n            Tensor tOgO = block_tma_O.partition_D(gO);  // (TMA, TMA_M, TMA_K)\n            Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)\n            int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);\n            if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {\n                cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + 
cutlass::NumThreadsPerWarp,\n                                                  cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);\n                if (cute::elect_one_sync()) {\n                    cute::copy(params.tma_store_O, tOsO, tOgO);\n                    tma_store_arrive();\n                    tma_store_wait<0>();\n                    #pragma unroll\n                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {\n                        shared_storage.pipelines.barrier_O.arrive(cta_id);\n                    }\n                }\n            }\n        } else {  // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence\n            if (!is_split) {\n                Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});\n                Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));  // (M, K)\n                // if (thread_idx == 0) { printf(\"Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\\n\", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }\n                GmemTiledCopyO gmem_tiled_copy_O;\n                auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);\n                Tensor tOsO = gmem_thr_copy_O.partition_S(sO);        // ((Atom,AtomNum),ATOM_M,ATOM_N)\n                // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi);        // ((Atom,AtomNum),ATOM_M,ATOM_N)\n                Tensor tOrO = make_fragment_like(tOsO);\n                cute::copy(gmem_tiled_copy_O, tOsO, tOrO);\n                if constexpr (ArchTag::kMinComputeCapability >= 90) {\n                    cutlass::arch::fence_view_async_shared(); // ensure smem reads are done 
before next TMA to smem_v\n                    #pragma unroll\n                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {\n                        shared_storage.pipelines.barrier_O.arrive(cta_id);\n                    }\n                }\n                if constexpr (!PackGQA) {\n                    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n                    Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));\n                    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));\n                    #pragma unroll\n                    for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }\n                    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);\n                    // Clear_OOB_K must be false since we don't want to write zeros to gmem\n                    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n                        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM\n                    );\n                } else {\n                    // If PackGQA, we split the work of compute O_ptr among threads in the same row\n                    PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);\n                }\n            } else {\n                Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? 
bidb : 0, split_idx);\n                Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));  // (M, K)\n                // We already arrived on barrier_O earlier if !Use_smem\n                if constexpr (Use_smem) {\n                    if constexpr (ArchTag::kMinComputeCapability >= 90) {\n                        #pragma unroll\n                        for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {\n                            shared_storage.pipelines.barrier_O.arrive(cta_id);\n                        }\n                    }\n                }\n                if constexpr (!PackGQA) {\n                    static constexpr int kGmemElemsPerStoreDirect = 2;\n                    cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial> gmem_copy_direct;\n                    // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))\n                    Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));\n                    Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});\n                    Tensor tOgO = thread_mma.partition_C(gOpartial);\n                    Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));\n                    Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});\n                    Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);\n                    #pragma unroll\n                    for (int m = 0; m < size(taccOcO_row); ++m) {\n                        if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {\n                            #pragma unroll\n                            for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {\n                                if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < 
get<1>(params.shape_O)) {\n                                    cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));\n                                }\n                            }\n                        }\n                    }\n                } else {\n                    PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);\n                }\n            }\n        }\n    }\n\n    CUTLASS_DEVICE void\n    store_tail() {\n        // Don't need to do tma_store_wait<0>() here since we already did in @store\n    }\n\n    // Write 0 to output and -inf to LSE\n    CUTLASS_DEVICE void\n    store_zero(\n         Params const& params,\n         int thread_idx,\n         cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord\n         ) {\n        static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});\n        auto [m_block, bidh, bidb, split_idx] = block_coord;\n        int num_splits = get<4>(params.shape_O_packed);\n        if constexpr (Split && Varlen) {\n            uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits\n            int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);\n            num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;\n            split_idx &= 0x0000FFFF;  // Only use the lower 16 bits of split_idx\n        }\n        bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);\n\n        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};\n        bool const is_varlen = Varlen && params.cu_seqlens;\n        int offset_o = seqlen_info.offset;\n        int seqlen_o = seqlen_info.seqlen;\n        int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;\n        Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? 
params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),\n                                  params.shape_LSE_packed,\n                                  !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);\n        Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));\n\n        static_assert(kBlockM <= NumEpilogueThreads);\n        if (thread_idx < kBlockM) {\n            const int row = m_block * kBlockM + thread_idx;\n            if constexpr (!PackGQA) {\n                if (row < seqlen_o) { mLSE(row) = -INFINITY; }\n            } else {\n                if (row < seqlen_o * qhead_per_khead) {\n                    int m_idx, h_idx;\n                    m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);\n                    // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 \"make_coord\"\n                    mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;\n                }\n            }\n        }\n\n        // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used,\n        // since it will not use the value of O if LSE is -inf.\n        if (!is_split) {\n            Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? 
bidb : 0, _0{});\n\n            GmemTiledCopyO gmem_tiled_copy_O;\n            auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);\n            Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));\n            if constexpr (!PackGQA) {\n                Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));\n                #pragma unroll\n                for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }\n                Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{}));  // (M, K)\n                Tensor tOgO = gmem_thr_copy_O.partition_D(gO);\n                Tensor tOrO = make_fragment_like(tOgO);\n                cute::clear(tOrO);\n                // Clear_OOB_K must be false since we don't want to write zeros to gmem\n                flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n                    gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM\n                );\n            } else {\n                // If PackGQA, we split the work of computing O_ptr among threads in the same row\n                using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;\n                Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));\n                cute::clear(tOrO);\n                PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);\n            }\n        }\n\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/flash.h",
    "content": "/******************************************************************************\n * Copyright (c) 2023, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cuda.h>\n#include <vector>\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Qkv_params {\n    using index_t = int64_t;\n    // The QKV matrices.\n    void *__restrict__ q_ptr;\n    void *__restrict__ k_ptr;\n    void *__restrict__ v_ptr;\n\n    // The stride between rows of the Q, K and V matrices.\n    index_t q_batch_stride;\n    index_t k_batch_stride;\n    index_t v_batch_stride;\n    index_t q_row_stride;\n    index_t k_row_stride;\n    index_t v_row_stride;\n    index_t q_head_stride;\n    index_t k_head_stride;\n    index_t v_head_stride;\n    index_t v_dim_stride;\n\n    // The number of heads.\n    int h, h_k;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Flash_fwd_params : public Qkv_params {\n    using index_t = int64_t;\n\n    // The O matrix (output).\n    void * __restrict__ o_ptr;\n    void * __restrict__ oaccum_ptr;\n\n    // The stride between rows of O.\n    index_t o_batch_stride;\n    index_t o_row_stride;\n    index_t o_head_stride;\n\n    // The pointer to the softmax sum.\n    void * __restrict__ softmax_lse_ptr;\n    void * __restrict__ softmax_lseaccum_ptr;\n\n    // For FP8 scaling\n    float * __restrict__ q_descale_ptr;\n    float * __restrict__ k_descale_ptr;\n    float * __restrict__ v_descale_ptr;\n    index_t q_descale_batch_stride;\n    index_t q_descale_head_stride;\n    index_t k_descale_batch_stride;\n    index_t k_descale_head_stride;\n    index_t v_descale_batch_stride;\n    index_t v_descale_head_stride;\n\n    // The dimensions.\n    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;\n    int total_q, 
total_k, total_knew;\n    int b_k;  // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q\n    int dv, dv_rounded;  // For the case where V headdim is different from Q/K headdim\n\n    // The scaling factors for the kernel.\n    float scale_softmax;\n    float softcap;\n\n    // array of length b+1 holding starting offset of each sequence.\n    int * __restrict__ cu_seqlens_q;\n    int * __restrict__ cu_seqlens_k;\n    int * __restrict__ cu_seqlens_knew;\n    int * __restrict__ leftpad_k;\n\n    // If provided, the actual length of each q/k sequence.\n    int *__restrict__ seqused_q;\n    int *__restrict__ seqused_k;\n\n    // The stride between rows of Oaccum.\n    index_t oaccum_split_stride;\n    index_t oaccum_batch_stride;\n    index_t oaccum_row_stride;\n    index_t oaccum_head_stride;\n\n    // The stride between rows of LSEaccum.\n    index_t lseaccum_split_stride;\n    index_t lseaccum_batch_stride;\n    index_t lseaccum_head_stride;\n\n    // The K_new and V_new matrices.\n    void * __restrict__ knew_ptr;\n    void * __restrict__ vnew_ptr;\n\n    // The stride between rows of the Q, K and V matrices.\n    index_t knew_batch_stride;\n    index_t vnew_batch_stride;\n    index_t knew_row_stride;\n    index_t vnew_row_stride;\n    index_t knew_head_stride;\n    index_t vnew_head_stride;\n\n    void *__restrict__ qv_ptr;\n    index_t qv_batch_stride;\n    index_t qv_row_stride;\n    index_t qv_head_stride;\n\n    // The cos and sin matrices for rotary embedding.\n    void * __restrict__ rotary_cos_ptr;\n    void * __restrict__ rotary_sin_ptr;\n    int *__restrict__ seqlens_rotary;\n\n    // The indices to index into the KV cache.\n    int * __restrict__ kv_batch_idx;\n\n    // Paged KV cache\n    int * __restrict__ page_table;\n    index_t page_table_batch_stride;\n    int page_size;\n    int num_pages;\n    bool pagedkv_tma;\n\n    // The dropout probability (probability of keeping an activation).\n    float p_dropout;\n 
   // uint32_t p_dropout_in_uint;\n    // uint16_t p_dropout_in_uint16_t;\n    uint8_t p_dropout_in_uint8_t;\n\n    // Scale factor of 1 / (1 - p_dropout).\n    float rp_dropout;\n\n    // Local window size\n    int window_size_left, window_size_right;\n    int attention_chunk;\n\n    // Pointer to the RNG seed (idx 0) and offset (idx 1).\n    uint64_t * rng_state;\n\n    bool is_bf16;\n    bool is_fp32;\n    bool is_e4m3;\n    bool is_causal;\n    bool is_local;\n\n    bool is_rotary_interleaved;\n\n    int num_splits;  // For split-KV version\n    bool pack_gqa;\n\n    int * __restrict__ tile_count_semaphore;\n    int * __restrict__ num_m_blocks_ptr;\n    // int * __restrict__ num_n_blocks_ptr;\n    int * __restrict__ num_splits_dynamic_ptr;\n    int * __restrict__ varlen_batch_idx_ptr; // virtual -> actual\n    int * __restrict__ num_nheads_in_l2_ptr;\n    bool skip_scheduler_metadata_computation;\n    bool varlen_sort_batches;\n    int tile_count_semaphore_offset;\n    bool head_swizzle;\n    bool prepare_varlen_pdl;\n\n    int arch;\n    int num_sm;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nstruct Flash_bwd_params : public Flash_fwd_params {\n    using index_t = int64_t;\n\n    // The dO and dQKV matrices.\n    void *__restrict__ do_ptr;\n    void *__restrict__ dq_ptr;\n    void *__restrict__ dk_ptr;\n    void *__restrict__ dv_ptr;\n\n    // To accumulate dQ\n    void *__restrict__ dq_accum_ptr;\n    void *__restrict__ dk_accum_ptr;\n    void *__restrict__ dv_accum_ptr;\n\n    // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q\n    // dimension void *__restrict__ dk_accum_ptr; void *__restrict__\n    // dv_accum_ptr;\n\n    // The stride between rows of the dO, dQ, dK and dV matrices.\n    index_t do_batch_stride;\n    index_t do_row_stride;\n    index_t do_head_stride;\n    index_t dq_batch_stride;\n    index_t dk_batch_stride;\n    index_t dv_batch_stride;\n    
index_t dq_row_stride;\n    index_t dk_row_stride;\n    index_t dv_row_stride;\n    index_t dq_head_stride;\n    index_t dk_head_stride;\n    index_t dv_head_stride;\n\n    // The pointer to the softmax d sum.\n    void *__restrict__ dsoftmax_sum;\n    void *__restrict__ softmax_lse_log2_ptr;\n\n    int *__restrict__ dq_semaphore;\n    int *__restrict__ dk_semaphore;\n    int *__restrict__ dv_semaphore;\n\n    bool deterministic;\n    index_t dq_accum_split_stride;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>\nvoid run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);\nvoid prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl);\ntemplate <int Arch, typename T, int kHeadDim, bool Has_softcap>\nvoid run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);\ntemplate <typename T, typename Tpartial, int kBlockK>\nvoid run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);\n"
  },
  {
    "path": "hopper/flash_api.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#include <Python.h>\n#include <torch/nn/functional/padding.h>\n#include <ATen/cuda/CUDAContextLight.h>\n#include <c10/cuda/CUDAGuard.h>\n\n#include <cutlass/numeric_types.h>\n\n#include \"flash.h\"\n#include \"static_switch.h\"\n#include \"tile_size.h\"\n#include \"heuristics.h\"\n#include \"cuda_check.h\"\n\n\nextern \"C\" {\n/* Creates a dummy empty _C module that can be imported from Python.\n    The import from Python will load the .so consisting of this file\n    in this extension, so that the TORCH_LIBRARY static initializers\n    below are run. */\nPyObject* PyInit__C(void)\n{\n    static struct PyModuleDef module_def = {\n        PyModuleDef_HEAD_INIT,\n        \"_C\",   /* name of module */\n        NULL,   /* module documentation, may be NULL */\n        -1,     /* size of per-interpreter state of the module,\n                    or -1 if the module keeps state in global variables. */\n        NULL,   /* methods */\n    };\n    return PyModule_Create(&module_def);\n}\n}\n\n#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x \" must be on CUDA\")\n#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x \" must have shape (\" #__VA_ARGS__ \")\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n\n#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992\n\nnamespace {\ninline at::cuda::CUDAGuard make_cuda_guard_from_tensor(const at::Tensor& t) {\n  return at::cuda::CUDAGuard(static_cast<c10::DeviceIndex>(t.get_device()));\n}\n} // namespace\n\nvoid set_params_fprop(Flash_fwd_params &params,\n                      // sizes\n                      const size_t b,\n                      const size_t seqlen_q,\n                      const size_t seqlen_k,\n                      const size_t seqlen_q_rounded,\n                      const size_t seqlen_k_rounded,\n                      const size_t h,\n                      const size_t h_k,\n                      const size_t d,\n                      const size_t d_rounded,\n                      // device pointers\n                      const at::Tensor q,\n                      const at::Tensor k,\n                      const at::Tensor v,\n                      at::Tensor out,\n                      void *cu_seqlens_q_d,\n                      void *cu_seqlens_k_d,\n                      void *seqused_q,\n                      void *seqused_k,\n                      void *softmax_lse_d,\n                      float p_dropout,\n                      float softmax_scale,\n                      int window_size_left,\n                      int window_size_right,\n                      int attention_chunk,\n                      const float softcap=0.f,\n                      const int sm_margin=0) {\n\n    // Reset the parameters\n    params = {};\n\n    params.is_bf16 = q.dtype() == torch::kBFloat16;\n    params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;\n\n    // Set the pointers and strides.\n    params.q_ptr = q.data_ptr();\n    params.k_ptr = k.data_ptr();\n    params.v_ptr = v.data_ptr();\n    // All strides are in elements, 
not bytes.\n    params.q_row_stride = q.stride(-3);\n    params.k_row_stride = k.stride(-3);\n    params.v_row_stride = v.stride(-3);\n    params.q_head_stride = q.stride(-2);\n    params.k_head_stride = k.stride(-2);\n    params.v_head_stride = v.stride(-2);\n    params.v_dim_stride = v.stride(-1);\n    params.o_ptr = out.data_ptr();\n    params.o_row_stride = out.stride(-3);\n    params.o_head_stride = out.stride(-2);\n\n    if (cu_seqlens_q_d == nullptr) {\n        params.q_batch_stride = q.stride(0);\n        params.o_batch_stride = out.stride(0);\n    }\n    if (cu_seqlens_k_d == nullptr) {\n        params.k_batch_stride = k.stride(0);\n        params.v_batch_stride = v.stride(0);\n    }\n\n    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);\n    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);\n    params.seqused_q = static_cast<int *>(seqused_q);\n    params.seqused_k = static_cast<int *>(seqused_k);\n\n    // Softmax sum\n    params.softmax_lse_ptr = softmax_lse_d;\n\n    // Set the dimensions.\n    params.b = b;\n    params.h = h;\n    params.h_k = h_k;\n    params.seqlen_q = seqlen_q;\n    params.seqlen_k = seqlen_k;\n    params.seqlen_q_rounded = seqlen_q_rounded;\n    params.seqlen_k_rounded = seqlen_k_rounded;\n    params.d = d;\n    params.d_rounded = d_rounded;\n\n    // Set the different scale values.\n    params.scale_softmax = softmax_scale;\n    params.softcap = softcap;\n\n    // Set this to probability of keeping an element to simplify things.\n    params.p_dropout = 1.f - p_dropout;\n    // Convert p from float to int so we don't have to convert the random uint to float to compare.\n    // [Minor] We want to round down since when we do the comparison we use <= instead of <\n    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));\n    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));\n    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 
255.0));\n    params.rp_dropout = 1.f / params.p_dropout;\n    TORCH_CHECK(p_dropout < 1.f);\n    #ifdef FLASHATTENTION_DISABLE_DROPOUT\n        TORCH_CHECK(p_dropout == 0.0f, \"This flash attention build does not support dropout.\");\n    #endif\n\n    // Causal is the special case where window_size_right == 0 and window_size_left < 0.\n    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.\n    params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;\n    params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;\n\n    // TODO: check this\n    if (window_size_left < 0) { window_size_left = seqlen_k - 1; }\n    if (window_size_right < 0) { window_size_right = seqlen_q - 1; }\n    if (attention_chunk > 0) {\n        window_size_left = std::min(window_size_left, attention_chunk - 1);\n        window_size_right = std::min(window_size_right, attention_chunk - 1);\n    }\n    params.window_size_left = window_size_left;\n    params.window_size_right = window_size_right;\n    params.attention_chunk = attention_chunk;\n\n    params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;\n    params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;\n\n    #ifdef FLASHATTENTION_DISABLE_LOCAL\n        TORCH_CHECK(!params.is_local, \"This flash attention build does not support local attention.\");\n    #endif\n}\n\nvoid set_params_dgrad(Flash_bwd_params &params,\n                      // sizes\n                      const size_t b,\n                      const size_t seqlen_q,\n                      const size_t seqlen_k,\n                      const size_t seqlen_q_rounded,\n                      const size_t seqlen_k_rounded,\n                      const size_t h,\n                      const size_t h_k,\n                      const size_t d,\n                      
const size_t d_rounded,\n                      // device pointers\n                      const at::Tensor q,\n                      const at::Tensor k,\n                      const at::Tensor v,\n                      const at::Tensor out,\n                      const at::Tensor dout,\n                      at::Tensor dq,\n                      at::Tensor dk,\n                      at::Tensor dv,\n                      void *cu_seqlens_q_d,\n                      void *cu_seqlens_k_d,\n                      void *seqused_q,\n                      void *seqused_k,\n                      void *dq_accum_d,\n                      void *dk_accum_d,\n                      void *dv_accum_d,\n                      void *softmax_lse_d,\n                      void *dsoftmax_sum_d,\n                      float p_dropout,\n                      float softmax_scale,\n                      int window_size_left,\n                      int window_size_right,\n                      int attention_chunk,\n                      const float softcap=0.f,\n                      bool deterministic=false,\n                      int const sm_margin=0) {\n\n    set_params_fprop(params,\n                     b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,\n                     q, k, v, out,\n                     cu_seqlens_q_d,\n                     cu_seqlens_k_d,\n                     seqused_q,\n                     seqused_k,\n                     softmax_lse_d,\n                     p_dropout,\n                     softmax_scale,\n                     window_size_left,\n                     window_size_right,\n                     attention_chunk,\n                     softcap,\n                     sm_margin);\n\n    // Set the pointers and strides.\n    params.do_ptr = dout.data_ptr();\n    params.do_row_stride = dout.stride(-3);\n    params.do_head_stride = dout.stride(-2);\n    params.dq_ptr = dq.data_ptr();\n    params.dk_ptr = dk.data_ptr();\n    
params.dv_ptr = dv.data_ptr();\n    params.dq_row_stride = dq.stride(-3);\n    params.dk_row_stride = dk.stride(-3);\n    params.dv_row_stride = dv.stride(-3);\n    params.dq_head_stride = dq.stride(-2);\n    params.dk_head_stride = dk.stride(-2);\n    params.dv_head_stride = dv.stride(-2);\n\n    if (cu_seqlens_q_d == nullptr) {\n        params.do_batch_stride = dout.stride(0);\n        params.dq_batch_stride = dq.stride(0);\n        params.dk_batch_stride = dk.stride(0);\n        params.dv_batch_stride = dv.stride(0);\n    }\n\n    params.dq_accum_ptr = dq_accum_d;\n    params.dk_accum_ptr = dk_accum_d;\n    params.dv_accum_ptr = dv_accum_d;\n\n    // Softmax sum\n    params.dsoftmax_sum = dsoftmax_sum_d;\n\n    params.deterministic = deterministic;\n}\n\ntemplate <int Arch, int Split, bool PagedKVNonTMA, bool PackGQA, bool Has_softcap>\nvoid run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) {\n    if (!params.is_e4m3) {\n        if (params.is_bf16) {\n            #ifndef FLASHATTENTION_DISABLE_HDIM64\n            if (params.d <= 64) {\n                #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64\n                if constexpr (Arch == 90) {\n                    if (params.dv > 256) {\n                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                    } else if (params.dv > 64) {\n                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                    }\n                }\n                #endif\n                return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n            }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM96\n            if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n      
      #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM128\n            if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM192\n            if (params.d <= 192) {\n                #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192\n                if constexpr (Arch == 90) {\n                    if (params.dv <= 128) {\n                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                    }\n                }\n                #endif\n                return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n            }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM256\n            if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n            #endif\n        } else {\n            #ifndef FLASHATTENTION_DISABLE_FP16\n            #ifndef FLASHATTENTION_DISABLE_HDIM64\n            if (params.d <= 64) {\n                #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64\n                if constexpr (Arch == 90) {\n                    if (params.dv > 256) {\n                        return run_mha_fwd_<Arch, cutlass::half_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                    } else if (params.dv > 64) {\n                        return run_mha_fwd_<Arch, cutlass::half_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                    }\n                }\n                #endif\n                return run_mha_fwd_<Arch, cutlass::half_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n            }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM96\n     
       if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::half_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM128\n            if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::half_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM192\n            if (params.d <= 192) {\n                #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192\n                if constexpr (Arch == 90) {\n                    if (params.dv <= 128) {\n                        return run_mha_fwd_<Arch, cutlass::half_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                    }\n                }\n                #endif\n                return run_mha_fwd_<Arch, cutlass::half_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n            }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM256\n            if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::half_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n            #endif\n            #else\n            TORCH_CHECK(false, \"This flash attention build does not support FP16.\");\n            #endif\n        }\n    } else {\n        #ifndef FLASHATTENTION_DISABLE_FP8\n        #ifndef FLASHATTENTION_DISABLE_HDIM64\n        if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM96\n        if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM128\n        if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, 
Has_softcap, PackGQA>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM192\n        if (params.d <= 192) {\n            #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192\n            if constexpr (Arch == 90) {\n                if (params.dv <= 128) {\n                    return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                }\n            }\n            #endif\n            return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n        }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM256\n        if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n        #endif\n        #else\n        TORCH_CHECK(false, \"This flash attention build does not support FP8.\");\n        #endif\n    }\n}\n\nvoid run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {\n    // HEADDIM_SWITCH(params.d, [&] {\n    //     run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);\n    // });\n    TORCH_CHECK(params.num_splits >= 1);\n    ARCH_SWITCH(params.arch, Arch, [&] {\n        SPLIT_SWITCH(params.num_splits > 1, Split, [&] {\n            PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] {\n                PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {\n                    // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation\n                    static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split;\n                    SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {\n                        run_mha_fwd_constexpr<Arch, Split, PagedKVNonTMA, PackGQA, Has_softcap>(params, stream);\n                    });\n                });\n            });\n        });\n    });\n}\n\nvoid run_mha_fwd_combine(Flash_fwd_params &params, 
cudaStream_t stream, bool enable_pdl=false) {\n    #ifndef FLASHATTENTION_DISABLE_SPLIT\n    // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively\n    // so that kBlockM is smaller and we have more parallelism.\n    if (params.is_fp32) {\n        if (params.dv <= 64) {\n            run_mha_fwd_combine_<float, float, 64>(params, stream, enable_pdl);\n        } else {\n            run_mha_fwd_combine_<float, float, 128>(params, stream, enable_pdl);\n        }\n    } else if (params.is_bf16) {\n        if (params.dv <= 64) {\n            run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream, enable_pdl);\n        } else {\n            run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream, enable_pdl);\n        }\n    } else {\n        if (params.dv <= 64) {\n            run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream, enable_pdl);\n        } else {\n            run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream, enable_pdl);\n        }\n    }\n    #else\n    TORCH_CHECK(false, \"This flash attention build does not support combine kernels.\");\n    #endif\n}\n\ninline bool get_pagedkv_tma(Flash_fwd_params const& params) {\n    if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; }\n    // This needs to match the kernel configs\n    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f);\n    int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);\n    int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90);\n    // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower,\n    // at least for MLA.\n    return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM;\n}\n\ninline bool get_pack_gqa(Flash_fwd_params const& params) {\n    // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size.\n    // Has little effect on speed.\n    if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }\n    #ifdef FLASHATTENTION_DISABLE_PACKGQA\n    return false;\n    #else\n    // params.page_table must already be set\n    if (params.h == params.h_k) { return false; }\n    // This needs to match the kernel configs\n    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);\n    int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);\n    return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);\n    #endif\n}\n\ninline int get_num_splits(Flash_fwd_params const& params) {\n    #ifdef FLASHATTENTION_DISABLE_SPLIT\n    return 1;\n    #else\n    // Always enable PackGQA for Split\n    // params.page_table must already be set\n    // This needs to match the kernel configs\n    bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;\n    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);\n    // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits\n    // has not been set here. It's OK though because we might just underestimate kBlockN a bit\n    auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);\n    int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);\n    int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);\n    int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);\n    // If is_local, we're not going to load all of seqlen_k\n    int const seqlen_k_loaded = !params.is_local\n        ? params.seqlen_k\n        : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));\n    int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;\n    int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;\n    int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2);\n    // Always enable PackGQA for Split\n    // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits.\n    // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending\n    // that batch = 1.\n    int total_mblocks = (params.num_splits_dynamic_ptr ? 
1 : params.b) * params.h_k * num_m_blocks;\n    return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128);\n    #endif\n}\n\ninline int get_max_headdim() {\n    #ifndef FLASHATTENTION_DISABLE_HDIM256\n    return 256;\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM192\n    return 192;\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM128\n    return 128;\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM96\n    return 96;\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM64\n    return 64;\n    #endif\n    return 0;\n}\n\ninline int round_up_headdim(int head_size) {\n    #ifndef FLASHATTENTION_DISABLE_HDIM64\n    if (head_size <= 64) { return 64; }\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM96\n    if (head_size <= 96) { return 96; }\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM128\n    if (head_size <= 128) { return 128; }\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM192\n    if (head_size <= 192) { return 192; }\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM256\n    if (head_size <= 256) { return 256; }\n    #endif\n    return 256;\n}\n\ninline int round_up_headdimv(int head_size) {\n    if (head_size <= 64) { return 64; }\n    if (head_size <= 96) { return 96; }\n    if (head_size <= 128) { return 128; }\n    if (head_size <= 192) { return 192; }\n    if (head_size <= 256) { return 256; }\n    return 512;\n}\n\n// Only applicable to the case where seqused_k (i.e. 
cache_seqlens) is available\nat::Tensor\nmha_fwd_get_scheduler_metadata(\n        int64_t batch_size,\n        int64_t max_seqlen_q,\n        int64_t max_seqlen_k,\n        int64_t num_heads,\n        int64_t num_heads_k,\n        int64_t headdim,\n        int64_t headdim_v,\n        at::ScalarType qkv_dtype,\n        at::Tensor seqused_k, // b\n        std::optional<at::Tensor> cu_seqlens_q_,  // b+1\n        std::optional<at::Tensor> cu_seqlens_k_,  // b+1\n        std::optional<at::Tensor> cu_seqlens_k_new_,  // b+1\n        std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.\n        std::optional<at::Tensor> leftpad_k_, // b\n        std::optional<int64_t> page_size,\n        int64_t max_seqlen_k_new,  // 0 means we're not appending new KV\n        bool is_causal,\n        int64_t window_size_left,\n        int64_t window_size_right,\n        int64_t attention_chunk,\n        bool has_softcap,\n        int64_t num_splits,\n        std::optional<bool> pack_gqa_,\n        int64_t sm_margin) {\n\n    TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,\n                \"FlashAttention only supports fp16, bf16, and fp8_e4m3 data type\");\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n\n    // Reset the parameters\n    Flash_fwd_params params{};\n    params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16;\n    params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn;\n    params.b = batch_size;\n    params.seqlen_q = max_seqlen_q;\n    params.seqlen_k = max_seqlen_k;\n    params.h = num_heads;\n    params.h_k = num_heads_k;\n    params.d = headdim;\n    params.dv = headdim_v;\n    params.d_rounded = round_up_headdim(headdim);\n    params.dv_rounded = headdim_v == headdim ? 
params.d_rounded : round_up_headdimv(headdim_v);\n    params.seqlen_knew = max_seqlen_k_new;\n\n    bool const is_varlen_q = cu_seqlens_q_.has_value();\n    params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr<int>() : nullptr;\n    bool const is_varlen_k = cu_seqlens_k_.has_value();\n    params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr<int>() : nullptr;\n    params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr<int>() : nullptr;\n    params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr<int>() : nullptr;\n    params.seqused_k = seqused_k.data_ptr<int>();\n    params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr<int>() : nullptr;\n    params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast<int*>(1) : nullptr;\n    if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }\n    if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }\n    // causal=true is the same as causal=false in this case\n    if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {\n        // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA\n        if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) {\n            is_causal = false;\n        }\n    }\n    if (is_causal) { window_size_right = 0; }\n\n    params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;\n    params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;\n    if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; }\n    if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; }\n    if (attention_chunk > 0) {\n        window_size_left = std::min(window_size_left, attention_chunk - 1);\n        window_size_right = std::min(window_size_right, attention_chunk - 1);\n    }\n    
params.window_size_left = window_size_left;\n    params.window_size_right = window_size_right;\n    params.attention_chunk = attention_chunk;\n    params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;\n    params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;\n    params.softcap = has_softcap ? 1.0f : 0.0f;\n\n    params.page_size = page_size.has_value() ? page_size.value() : 1;\n    params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);\n\n    bool const use_prepare_varlen = true;\n    params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA;\n    params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast<int*>(1);\n\n    params.pagedkv_tma = get_pagedkv_tma(params);\n    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;\n    // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide\n    params.pack_gqa = pack_gqa_.has_value() ? 
pack_gqa_.value() : get_pack_gqa(params);\n\n    bool is_varlen = true;\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    // The helper casts the device index to c10::DeviceIndex to avoid a compiler warning about narrowing\n    auto device_guard = make_cuda_guard_from_tensor(seqused_k);\n\n    auto opts = seqused_k.options();\n    // This needs to be set after get_num_splits\n    at::Tensor tile_count_semaphore;  // Contains the semaphore and optionally num_splits_dynamic\n    bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template\n    params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template\n    if (scheduler_needs_semaphore || use_prepare_varlen) {\n        int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers\n        int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0;\n        if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; }\n        if(params.head_swizzle) { num_prepare_batch_vectors += 1; }\n        int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2);\n        int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors;\n        // printf(\"(Metadata) num prepare batch vectors = %d.\\n\", num_prepare_batch_vectors);\n        tile_count_semaphore = torch::empty(\n            {int(scheduler_needs_semaphore) + tile_count_semaphore_offset},\n            opts.dtype(torch::kInt32));\n        // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2}\n        params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr<int>() : nullptr;\n        params.num_m_blocks_ptr =  use_prepare_varlen ? 
tile_count_semaphore.data_ptr<int>() + b_rounded : nullptr;\n        params.varlen_batch_idx_ptr =  use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr<int>() + b_rounded * 2 : nullptr;\n        // params.num_n_blocks_ptr  = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr<int>() + head_swizzle_offset : nullptr;\n        params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr<int>() + head_swizzle_offset : nullptr;\n        if (scheduler_needs_semaphore) {\n            if (!use_prepare_varlen) { tile_count_semaphore.zero_(); }  // If varlen we'll manually do the zero-ing\n            params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>() + tile_count_semaphore_offset;\n        } else {\n            params.tile_count_semaphore = nullptr;\n        }\n    }\n\n    if (use_prepare_varlen) {\n        auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);\n        auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);\n        int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);\n        int const kBlockN = params.arch >= 90 ? 
std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);\n        auto stream = at::cuda::getCurrentCUDAStream().stream();\n        prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);\n        CHECK_CUDA_KERNEL_LAUNCH();\n    }\n    return tile_count_semaphore;\n}\n\n// b: batch_size\n// b_k: batch_size_k\n// s_q: seqlen_q\n// s_k: seqlen_k\n// s_k_new: seqlen_k_new\n// h: num_heads\n// h_k: num_heads_k\n// d: head_size\nstd::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>\nmha_fwd(at::Tensor q,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q\n        at::Tensor k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.\n        at::Tensor v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.\n        std::optional<at::Tensor> k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new\n        std::optional<at::Tensor> v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new\n        std::optional<at::Tensor> q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q\n        std::optional<at::Tensor> out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q\n        std::optional<at::Tensor> cu_seqlens_q_,  // b+1\n        std::optional<at::Tensor> cu_seqlens_k_,  // b+1\n        std::optional<at::Tensor> cu_seqlens_k_new_,  // b+1\n        std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.\n        std::optional<at::Tensor> seqused_k_, // b. 
If given, only this many elements of each batch element's keys are used.\n        std::optional<int64_t> max_seqlen_q_,\n        // TODO: check if we need max_seqlen_k\n        std::optional<int64_t> max_seqlen_k_,\n        std::optional<at::Tensor> page_table_, // (b_k, max_num_pages_per_seq)\n        std::optional<at::Tensor> kv_batch_idx_, // b. indices to index into the KV cache\n        std::optional<at::Tensor> leftpad_k_, // b\n        std::optional<at::Tensor> rotary_cos_, // seqlen_ro x (rotary_dim / 2)\n        std::optional<at::Tensor> rotary_sin_, // seqlen_ro x (rotary_dim / 2)\n        std::optional<at::Tensor> seqlens_rotary_, // b\n        std::optional<at::Tensor> q_descale_,  // (b, h_k), not (b, h)\n        std::optional<at::Tensor> k_descale_,  // (b, h_k)\n        std::optional<at::Tensor> v_descale_,  // (b, h_k)\n        std::optional<double> softmax_scale_,\n        bool is_causal,\n        int64_t window_size_left,\n        int64_t window_size_right,\n        int64_t attention_chunk,\n        double softcap,\n        bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2\n        std::optional<at::Tensor> scheduler_metadata_,  // (b + 1)\n        int64_t num_splits,\n        std::optional<bool> pack_gqa_,\n        int64_t sm_margin\n        ) {\n\n    auto dprops = at::cuda::getCurrentDeviceProperties();\n    bool is_sm8x = dprops->major >= 8;\n    TORCH_CHECK(is_sm8x, \"FlashAttention only supports Ampere GPUs or newer.\");\n\n    auto q_type = q.scalar_type();\n    TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,\n                \"FlashAttention only supports fp16, bf16, and fp8_e4m3 data type\");\n    if (dprops->major < 9) {\n        TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,\n                    \"FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data 
type\");\n    }\n    TORCH_CHECK(k.scalar_type() == q_type, \"query and key must have the same dtype\");\n    TORCH_CHECK(v.scalar_type() == q_type, \"query and value must have the same dtype\");\n\n    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);\n\n    TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(k.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(v.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n\n    at::Tensor page_table;\n    const bool paged_KV = page_table_.has_value();\n    if (paged_KV) {\n        page_table = page_table_.value();\n        CHECK_DEVICE(page_table);\n        TORCH_CHECK(page_table.dtype() == torch::kInt32, \"page_table must have dtype torch.int32\");\n        TORCH_CHECK(page_table.stride(-1) == 1, \"page_table must have contiguous last dimension\");\n    }\n\n    at::Tensor cu_seqlens_q;\n    bool const is_varlen_q = cu_seqlens_q_.has_value();\n    if (is_varlen_q) {\n        cu_seqlens_q = cu_seqlens_q_.value();\n        CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);\n        TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, \"cu_seqlens_q must have dtype torch.int32\");\n        TORCH_CHECK(max_seqlen_q_.has_value(), \"max_seqlen_q must be provided if cu_seqlens_q is provided\");\n    }\n    at::Tensor cu_seqlens_k;\n    bool const is_varlen_k = cu_seqlens_k_.has_value();\n    if (is_varlen_k) {\n        cu_seqlens_k = cu_seqlens_k_.value();\n        CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);\n        TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, \"cu_seqlens_k must have dtype torch.int32\");\n        TORCH_CHECK(max_seqlen_k_.has_value(), \"max_seqlen_k must be provided if cu_seqlens_k is provided\");\n        TORCH_CHECK(!paged_KV, \"If cu_seqlens_k is passed in, then page table is not supported\");\n        TORCH_CHECK(!kv_batch_idx_.has_value(), \"If cu_seqlens_k is 
passed in, then kv_batch_idx is not supported\");\n    }\n\n    auto const sizes = q.sizes();\n    const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;\n    int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();\n    int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];\n    int num_heads = q.size(-2);\n    int const head_size = q.size(-1);\n    int const head_size_v = v.size(-1);\n    int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);\n    int const num_pages = !paged_KV ? 0 : k.size(0);\n    int const page_size = !paged_KV ? 1 : k.size(1);\n    int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();\n    int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);\n    int const num_heads_k = k.size(-2);\n    int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);\n    double softmax_scale = 1.0 / sqrt(double(head_size));\n    if (softmax_scale_.has_value()) {\n        softmax_scale = softmax_scale_.value();\n    }\n    if (!kv_batch_idx_.has_value()) {\n        TORCH_CHECK(batch_size == batch_size_k, \"batch_size must be equal to batch_size_k\");\n    }\n    int const max_headdim = get_max_headdim();\n    TORCH_CHECK(head_size <= max_headdim, \"FlashAttention forward only supports head dimension at most \" + std::to_string(max_headdim));\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n    if (head_size_v != head_size) {\n        TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) ||\n                   (head_size <= 64 && head_size_v <= 512),\n                   \"If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], \"\n                   \"or (Q/K <= 64 and V <= 512).\");\n        TORCH_CHECK(dprops->major == 
9, \"Only Hopper supports different V headdim\");\n        if (head_size_v > 256) {\n            TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,\n                        \"HeaddimV > 256 requires fp16 and bf16 data type\");\n        }\n    }\n\n    // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM\n    // TODO: check this\n    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }\n    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }\n    // causal=true is the same as causal=false in this case\n    if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {\n        // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA\n        if ((head_size <= 64 || head_size > 128) || !paged_KV) {\n            is_causal = false;\n        }\n    }\n    if (is_causal) { window_size_right = 0; }\n\n    if (!is_varlen_q) {\n        CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);\n    } else {\n        CHECK_SHAPE(q, total_q, num_heads, head_size);\n        CHECK_SHAPE(cu_seqlens_q, batch_size + 1);\n    }\n    if (!paged_KV) {\n        if (!is_varlen_k) {\n            CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);\n            CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v);\n        } else {\n            CHECK_SHAPE(k, total_k, num_heads_k, head_size);\n            CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);\n            CHECK_SHAPE(cu_seqlens_k, batch_size + 1);\n        }\n    } else {\n        CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);\n        CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v);\n        CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);\n    }\n\n    if (seqused_q_.has_value()){\n        auto seqused_q = seqused_q_.value();\n        TORCH_CHECK(seqused_q.dtype() 
== torch::kInt32, \"seqused_q must have dtype int32\");\n        CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);\n        CHECK_SHAPE(seqused_q, batch_size);\n    }\n    if (seqused_k_.has_value()) {\n        auto seqused_k = seqused_k_.value();\n        TORCH_CHECK(seqused_k.dtype() == torch::kInt32, \"seqused_k must have dtype int32\");\n        CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);\n        CHECK_SHAPE(seqused_k, batch_size);\n    }\n\n    if (leftpad_k_.has_value()) {\n        auto leftpad_k = leftpad_k_.value();\n        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, \"leftpad_k must have dtype int32\");\n        CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);\n        CHECK_SHAPE(leftpad_k, batch_size);\n    }\n\n    // This is what we will template on\n    bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();\n    #ifdef FLASHATTENTION_DISABLE_VARLEN\n        TORCH_CHECK(!is_varlen, \"This flash attention build does not support varlen.\");\n    #endif\n\n    int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;\n    TORCH_CHECK(head_size % alignment == 0, \"head_size should be a multiple of \" + std::to_string(alignment));\n    TORCH_CHECK(head_size_v % alignment == 0, \"head_size_v should be a multiple of \" + std::to_string(alignment));\n\n    auto opts = q.options();\n    auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;\n    at::Tensor out;\n    if (out_.has_value()) {\n        out = out_.value();\n        TORCH_CHECK(out.scalar_type() == out_type, \"For FP16/BF16 input, output must have the same dtype as inputs. 
For FP8 input, output must have dtype BF16\");\n        CHECK_DEVICE(out);\n        TORCH_CHECK(out.stride(-1) == 1, \"Output tensor must have contiguous last dimension\");\n        if (!is_varlen_q) {\n            CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);\n        } else {\n            CHECK_SHAPE(out, total_q, num_heads, head_size_v);\n        }\n    } else {\n        out = !is_varlen_q\n            ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type))\n            : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type));\n    }\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    int const head_size_rounded = round_up_headdim(head_size);\n    int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v);\n    int const seqlen_q_rounded = round_multiple(seqlen_q, 128);\n    int const seqlen_k_rounded = round_multiple(seqlen_k, 128);\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    // Cast to char to avoid compiler warning about narrowing\n    auto device_guard = make_cuda_guard_from_tensor(q);\n\n    at::Tensor softmax_lse;\n    if (!is_varlen_q) {\n        softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));\n    } else {\n        softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));\n    }\n\n    Flash_fwd_params params;\n    set_params_fprop(params,\n                     batch_size,\n                     seqlen_q, seqlen_k,\n                     seqlen_q_rounded, seqlen_k_rounded,\n                     num_heads, num_heads_k,\n                     head_size, head_size_rounded,\n                     q, k, v, out,\n                     !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),\n                     !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),\n                     seqused_q_.has_value() ? 
seqused_q_.value().data_ptr() : nullptr,\n                     seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,\n                     softmax_lse.data_ptr(),\n                     /*p_dropout=*/0.f,\n                     softmax_scale,\n                     window_size_left,\n                     window_size_right,\n                     attention_chunk,\n                     softcap,\n                     sm_margin);\n    params.total_q = total_q;\n    params.total_k = total_k;\n    params.b_k = batch_size_k;\n    params.dv = head_size_v;\n    params.dv_rounded = head_size_v_rounded;\n    if (leftpad_k_.has_value()) {  // This needs to be set before get_pagedkv_tma\n        params.leftpad_k = static_cast<int *>(leftpad_k_.value().data_ptr());\n    }\n    if (paged_KV) {\n        params.page_table = page_table.data_ptr<int>();\n        params.page_table_batch_stride = page_table.stride(0);\n    }\n    params.page_size = page_size;\n    params.num_pages = num_pages;\n\n    if (k_new_.has_value()) {  // This needs to be set before get_pagedkv_tma\n        at::Tensor k_new, v_new;\n        TORCH_CHECK(v_new_.has_value(), \"If k_new is supplied, v_new must also be passed in\");\n        TORCH_CHECK(seqused_k_.has_value(), \"If k_new is supplied, seqlens_k must also be passed in\");\n        TORCH_CHECK(seqlen_q <= seqlen_k, \"If k_new is supplied, it must have seqlen <= the seqlen of the KV cache\");\n        at::Tensor cu_seqlens_k_new;\n        bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();\n        if (is_varlen_k_new) {\n            cu_seqlens_k_new = cu_seqlens_k_new_.value();\n            CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);\n            TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, \"cu_seqlens_k_new must have dtype torch.int32\");\n        }\n        k_new = k_new_.value();\n        v_new = v_new_.value();\n        TORCH_CHECK(k_new.dtype() == q_type, \"k_new must have the same dtype as 
query\");\n        TORCH_CHECK(v_new.dtype() == q_type, \"v_new must have the same dtype as query\");\n        CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);\n        TORCH_CHECK(k_new.stride(-1) == 1, \"k_new tensor must have contiguous last dimension\");\n        TORCH_CHECK(v_new.stride(-1) == 1, \"v_new tensor must have contiguous last dimension\");\n        // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new\n        int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;\n        int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0);\n        if (!is_varlen_k_new) {\n            CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);\n            CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v);\n        } else {\n            CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);\n            CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v);\n            CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);\n        }\n        params.seqlen_knew = seqlen_k_new;\n        params.total_knew = total_k_new;\n        params.knew_ptr = k_new.data_ptr();\n        params.vnew_ptr = v_new.data_ptr();\n        // All stride are in elements, not bytes.\n        params.knew_row_stride = k_new.stride(-3);\n        params.vnew_row_stride = v_new.stride(-3);\n        params.knew_head_stride = k_new.stride(-2);\n        params.vnew_head_stride = v_new.stride(-2);\n        if (!is_varlen_k_new) {\n            params.knew_batch_stride = k_new.stride(0);\n            params.vnew_batch_stride = v_new.stride(0);\n        }\n        if (is_varlen_k_new) {\n            params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());\n        }\n    }\n    \n    bool const use_prepare_varlen = is_varlen;\n    params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA;\n    // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks 
it\n    params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast<int*>(1);\n\n    params.pagedkv_tma = get_pagedkv_tma(params);\n    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;\n    // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide\n    params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);\n\n    // This needs to be set after get_num_splits\n    at::Tensor tile_count_semaphore;  // Contains the semaphore and optionally num_splits_dynamic\n    // We don't use the persistent scheduler if Split and not Varlen\n    bool const scheduler_needs_semaphore = params.arch >= 90\n        ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)\n        : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));\n    params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template\n    params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template\n    if (scheduler_needs_semaphore || use_prepare_varlen) {\n        int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers\n        int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0;\n        if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; }\n        if(params.head_swizzle) { num_prepare_batch_vectors += 1; }\n        int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 
3 : 2);\n        int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors;\n        int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset;\n        // printf(\"Num prepare batch vectors = %d, metadata_size = %d.\\n\", num_prepare_batch_vectors, metadata_size);\n        params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();\n        if (scheduler_metadata_.has_value()) {\n            at::Tensor scheduler_metadata = scheduler_metadata_.value();\n            CHECK_DEVICE(scheduler_metadata);\n            CHECK_SHAPE(scheduler_metadata, metadata_size);\n            CHECK_CONTIGUOUS(scheduler_metadata);\n            TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, \"scheduler_metadata must have dtype int32\");\n            tile_count_semaphore = scheduler_metadata;\n        } else {\n            tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32));\n        }\n        if (scheduler_needs_semaphore && !use_prepare_varlen) {\n            tile_count_semaphore.zero_();  // If varlen we'll manually do the zero-ing\n        }\n        // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2}\n        params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr<int>() : nullptr;\n        params.num_m_blocks_ptr =  use_prepare_varlen ? tile_count_semaphore.data_ptr<int>() + b_rounded : nullptr;\n        params.varlen_batch_idx_ptr =  use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr<int>() + b_rounded * 2 : nullptr;\n        // params.num_n_blocks_ptr  = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr<int>() + head_swizzle_offset : nullptr;\n        params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr<int>() + head_swizzle_offset : nullptr;\n        params.tile_count_semaphore = scheduler_needs_semaphore ? 
tile_count_semaphore.data_ptr<int>() + tile_count_semaphore_offset : nullptr;\n        params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later\n    }\n\n    if (q_v_.has_value()) {\n        TORCH_CHECK(head_size <= 64, \"q_v is only supported for head_size <= 64\");\n        TORCH_CHECK(head_size_v >= 256, \"q_v is only supported for hdim_v >= 256.\");\n        TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,\n                    \"q_v is only supported for fp16 and bf16 data type\");\n        TORCH_CHECK(params.arch == 90, \"q_v is only supported for Hopper GPUs\");\n        at::Tensor q_v = q_v_.value();\n        TORCH_CHECK(q_v.dtype() == q_type, \"q_v must have the same dtype as query\");\n        CHECK_DEVICE(q_v);\n        TORCH_CHECK(q_v.stride(-1) == 1, \"q_v tensor must have contiguous last dimension\");\n        if (!is_varlen_q) {\n            CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);\n        } else {\n            CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);\n        }\n        params.qv_ptr = q_v.data_ptr();\n        // All stride are in elements, not bytes.\n        params.qv_row_stride = q_v.stride(-3);\n        params.qv_head_stride = q_v.stride(-2);\n        if (!is_varlen_q) {\n            params.qv_batch_stride = q_v.stride(0);\n        }\n    }\n\n    if (rotary_cos_.has_value()) {\n        TORCH_CHECK(k_new_.has_value(), \"If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided\");\n        auto rotary_cos = rotary_cos_.value();\n        CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);\n        params.rotary_dim = rotary_cos.size(1) * 2;\n        TORCH_CHECK(params.rotary_dim <= head_size, \"rotary_dim must be <= headdim\");\n        TORCH_CHECK(params.rotary_dim % 16 == 0, \"Only rotary dimensions divisible by 16 are currently supported\");\n        const int seqlen_ro = 
rotary_cos.size(0);\n        if (paged_KV) {\n            TORCH_CHECK(seqlen_ro >= seqlen_k, \"cos/sin seqlen must be at least the seqlen of KV cache\");\n        }\n        CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);\n        TORCH_CHECK(rotary_cos.scalar_type() == q_type, \"rotary_cos must have the same dtype as query\");\n\n        TORCH_CHECK(rotary_sin_.has_value(), \"If rotary cos is provided, rotary sin must also be provided\");\n        auto rotary_sin = rotary_sin_.value();\n        CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);\n        CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);\n        TORCH_CHECK(rotary_sin.scalar_type() == q_type, \"rotary_sin must have the same dtype as query\");\n        params.rotary_cos_ptr = rotary_cos.data_ptr();\n        params.rotary_sin_ptr = rotary_sin.data_ptr();\n        params.is_rotary_interleaved = is_rotary_interleaved;\n        if (seqlens_rotary_.has_value()) {\n            at::Tensor seqlens_rotary = seqlens_rotary_.value();\n            CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary);\n            TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, \"seqlens_rotary must have dtype torch.int32\");\n            CHECK_SHAPE(seqlens_rotary, batch_size);\n            params.seqlens_rotary = seqlens_rotary.data_ptr<int>();\n        }\n    } else {\n        params.rotary_dim = 0;\n    }\n\n    if (kv_batch_idx_.has_value()) {\n        auto kv_batch_idx = kv_batch_idx_.value();\n        CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx);\n        TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, \"kv_batch_idx must have dtype int32\");\n        params.kv_batch_idx = reinterpret_cast<int *>(kv_batch_idx.data_ptr());\n    }\n\n    at::Tensor out_accum, softmax_lse_accum;\n    auto outaccum_type = at::ScalarType::Float;\n    if (params.num_splits > 1) {\n        TORCH_CHECK(params.num_splits <= 256, \"num_splits > 256 not supported\");\n        if 
(!is_varlen_q) {\n            out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type));\n            softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));\n            params.oaccum_batch_stride = out_accum.stride(1);\n            params.lseaccum_batch_stride = softmax_lse_accum.stride(1);\n        } else {\n            out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type));\n            softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));\n        }\n        params.is_fp32 = false;\n        params.oaccum_ptr = out_accum.data_ptr();\n        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();\n        params.oaccum_split_stride = out_accum.stride(0);\n        params.oaccum_row_stride = out_accum.stride(-2);\n        params.oaccum_head_stride = out_accum.stride(-3);\n        params.lseaccum_split_stride = softmax_lse_accum.stride(0);\n        params.lseaccum_head_stride = softmax_lse_accum.stride(-2);\n    }\n\n    if (q_type == at::ScalarType::Float8_e4m3fn) {\n        if (q_descale_.has_value()) {\n            auto q_descale = q_descale_.value();\n            CHECK_DEVICE(q_descale);\n            CHECK_SHAPE(q_descale, batch_size, num_heads_k);\n            params.q_descale_ptr = q_descale.data_ptr<float>();\n            params.q_descale_batch_stride = q_descale.stride(0);\n            params.q_descale_head_stride = q_descale.stride(1);\n        } else {\n            params.q_descale_ptr = nullptr;\n        }\n        if (k_descale_.has_value()) {\n            auto k_descale = k_descale_.value();\n            CHECK_DEVICE(k_descale);\n            CHECK_SHAPE(k_descale, batch_size, num_heads_k);\n            params.k_descale_ptr = k_descale.data_ptr<float>();\n            params.k_descale_batch_stride = k_descale.stride(0);\n            
params.k_descale_head_stride = k_descale.stride(1);\n        } else {\n            params.k_descale_ptr = nullptr;\n        }\n        if (v_descale_.has_value()) {\n            auto v_descale = v_descale_.value();\n            CHECK_DEVICE(v_descale);\n            CHECK_SHAPE(v_descale, batch_size, num_heads_k);\n            params.v_descale_ptr = v_descale.data_ptr<float>();\n            params.v_descale_batch_stride = v_descale.stride(0);\n            params.v_descale_head_stride = v_descale.stride(1);\n        } else {\n            params.v_descale_ptr = nullptr;\n        }\n    }\n\n    #ifdef FLASHATTENTION_DISABLE_LOCAL\n    TORCH_CHECK(!params.is_local, \"This flash attention build does not support local attention.\");\n    #endif\n    #ifdef FLASHATTENTION_DISABLE_SOFTCAP\n    TORCH_CHECK(params.softcap == 0.0, \"This flash attention build does not support tanh softcapping.\");\n    #endif\n    #ifdef FLASHATTENTION_DISABLE_SPLIT\n    TORCH_CHECK(params.num_splits == 1, \"This flash attention build does not support splits.\");\n    #endif\n    #ifdef FLASHATTENTION_DISABLE_PACKGQA\n    TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, \"This flash attention build does not support pack_gqa.\");\n    #endif\n    #ifdef FLASHATTENTION_DISABLE_PAGEDKV\n    TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), \"This flash attention build does not support paged KV.\");\n    #endif\n    #ifdef FLASHATTENTION_DISABLE_APPENDKV\n    TORCH_CHECK(!k_new_.has_value(), \"This flash attention build does not support appending KV.\");\n    #endif\n\n    if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {\n        auto stream = at::cuda::getCurrentCUDAStream().stream();\n        run_mha_fwd(params, stream);\n        if (params.num_splits > 1) {\n            if (out_type == at::ScalarType::BFloat16) {\n                // Since we want output in BF16. 
Otherwise fwd_combine will output to FP16\n                params.is_bf16 = true;\n            }\n            // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1\n            // and seqlen = total_q, and don't need to dispatch to Varlen there.\n            // However, with dynamic split, each row needs to know which batch it belongs to\n            // to read the number of splits, so we just use the varlen version of combine kernel.\n            // if (is_varlen_q && !seqused_q_.has_value()) {\n            // if (is_varlen_q) {\n            //     params.b = 1;\n            //     params.seqlen_q = total_q;\n            // }\n            // This will zero out the semaphore if needed\n            run_mha_fwd_combine(params, stream, true /*enable_pdl*/);\n        } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) {\n            // need to zero out the semaphore in this case\n            tile_count_semaphore.index({torch::indexing::Slice(params.tile_count_semaphore_offset, params.tile_count_semaphore_offset + 1)}).zero_();\n        }\n    } else if (total_q > 0 && num_heads_k > 0) {\n        // If seqlen_k == 0, then we have an empty tensor. 
We need to set the output to 0.\n        out.zero_();\n        softmax_lse.fill_(std::numeric_limits<float>::infinity());\n    }\n\n    // return {out, softmax_lse};\n    return {out, softmax_lse, out_accum, softmax_lse_accum};\n}\n\n#ifdef FLASHATTENTION_DISABLE_BACKWARD\nvoid run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {\n    TORCH_CHECK(false, \"Flash-Attention was built with backward disabled\");\n}\n#else\ntemplate <int Arch, bool Has_softcap>\nvoid run_mha_bwd_constexpr(Flash_bwd_params &params, cudaStream_t stream) {\n    if (!params.is_bf16) {\n        #ifndef FLASHATTENTION_DISABLE_FP16\n        #ifndef FLASHATTENTION_DISABLE_HDIM64\n        if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM96\n        if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM128\n        if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM192\n        if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM256\n        if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); }\n        #endif\n        #else\n        TORCH_CHECK(false, \"This flash attention build does not support FP16.\");\n        #endif\n    } else {\n        #ifndef FLASHATTENTION_DISABLE_HDIM64\n        if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM96\n        if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96, 
Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM128\n        if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM192\n        if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM256\n        if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); }\n        #endif\n    }\n}\n\nvoid run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {\n        // FP16_SWITCH(!params.is_bf16, [&] {\n        //     HEADDIM_SWITCH(params.d, [&] {\n        //         run_mha_bwd_<elem_type, kHeadDim>(params, stream);\n        //     });\n        // });\n    ARCH_SWITCH(params.arch, Arch, [&] {\n        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {\n            run_mha_bwd_constexpr<Arch, Has_softcap>(params, stream);\n        });\n    });\n}\n#endif\n\n\n// b: batch_size\n// s_q: seqlen_q\n// s_k: seqlen_k\n// h: num_heads\n// h_k: num_heads_k\n// d: head_size\nstd::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(\n    at::Tensor dout,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q\n    at::Tensor q,     // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q\n    at::Tensor k,     // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k\n    at::Tensor v,     // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k\n    at::Tensor out,   // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q\n    at::Tensor softmax_lse,    // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q\n    std::optional<at::Tensor> dq_,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q\n    std::optional<at::Tensor> dk_,   // (b, s_k, h_k, d) or (total_k, h_k, 
d) if there is cu_seqlens_k\n    std::optional<at::Tensor> dv_,   // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k\n    std::optional<at::Tensor> cu_seqlens_q_,   // b+1\n    std::optional<at::Tensor> cu_seqlens_k_,   // b+1\n    std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.\n    std::optional<at::Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.\n    std::optional<int64_t> max_seqlen_q_,\n    std::optional<int64_t> max_seqlen_k_,\n    std::optional<double> softmax_scale_,\n    bool is_causal,\n    int64_t window_size_left,\n    int64_t window_size_right,\n    double softcap,\n    bool deterministic,\n    int64_t sm_margin\n) {\n\n    #ifdef FLASHATTENTION_DISABLE_BACKWARD\n        TORCH_CHECK(false, \"This flash attention build does not support backward.\");\n    #endif\n\n    auto dprops = at::cuda::getCurrentDeviceProperties();\n    bool is_sm8x = dprops->major >= 8;\n    TORCH_CHECK(is_sm8x, \"FlashAttention only supports Ampere GPUs or newer.\");\n\n    auto q_type = q.dtype();\n    TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,\n                \"FlashAttention only support fp16 and bf16 data type\");\n    TORCH_CHECK(k.dtype() == q_type, \"query and key must have the same dtype\");\n    TORCH_CHECK(v.dtype() == q_type, \"query and value must have the same dtype\");\n    TORCH_CHECK(out.dtype() == q_type, \"query and out must have the same dtype\");\n    TORCH_CHECK(dout.dtype() == q_type, \"query and dout must have the same dtype\");\n\n    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);\n    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);\n\n    TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(k.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(v.stride(-1) == 1, \"Input 
tensor must have contiguous last dimension\");\n    TORCH_CHECK(out.stride(-1) == 1, \"out tensor must have contiguous last dimension\");\n    TORCH_CHECK(dout.stride(-1) == 1, \"dout tensor must have contiguous last dimension\");\n\n    at::Tensor cu_seqlens_q;\n    bool const is_varlen_q = cu_seqlens_q_.has_value();\n    if (is_varlen_q) {\n        cu_seqlens_q = cu_seqlens_q_.value();\n        CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);\n        TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, \"cu_seqlens_q must have dtype torch.int32\");\n        TORCH_CHECK(max_seqlen_q_.has_value(), \"max_seqlen_q must be provided if cu_seqlens_q is provided\");\n    }\n    at::Tensor cu_seqlens_k;\n    bool const is_varlen_k = cu_seqlens_k_.has_value();\n    if (is_varlen_k) {\n        cu_seqlens_k = cu_seqlens_k_.value();\n        CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);\n        TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, \"cu_seqlens_k must have dtype torch.int32\");\n        TORCH_CHECK(max_seqlen_k_.has_value(), \"max_seqlen_k must be provided if cu_seqlens_k is provided\");\n    }\n    // This is what we will template on\n    bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();\n    #ifdef FLASHATTENTION_DISABLE_VARLEN\n        TORCH_CHECK(!is_varlen, \"This flash attention build does not support varlen.\");\n    #endif\n\n    auto const sizes = q.sizes();\n    int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;\n    int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();\n    int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];\n    int const num_heads = q.size(-2);\n    int const head_size = q.size(-1);\n    int const head_size_v = v.size(-1);\n    int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value();\n    int const total_k = !is_varlen_k ? 
batch_size * k.size(1) : k.size(0);\n    int const num_heads_k = k.size(-2);\n    TORCH_CHECK(head_size % 8 == 0, \"head_size should be a multiple of 8\");\n    TORCH_CHECK(head_size_v % 8 == 0, \"head_size_v should be a multiple of 8\");\n    int const max_headdim = get_max_headdim();\n    TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, \"FlashAttention backward only supports head dimension at most \" + std::to_string(max_headdim));\n    TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n    double softmax_scale = 1.0 / sqrt(double(head_size));\n    if (softmax_scale_.has_value()) {\n        softmax_scale = softmax_scale_.value();\n    }\n\n    // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM\n    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }\n    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }\n    if (is_causal) { window_size_right = 0; }\n    // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_dgrad will set params.is_causal=true.\n    // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).\n    is_causal = window_size_left < 0 && window_size_right == 0;\n\n    int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;\n    int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v));\n    int const head_size_v_rounded = head_size_rounded;\n    TORCH_CHECK(!deterministic || head_size_rounded < 256, \"Deterministic backward not supported for hdim 256.\");\n    // Very important that these match the kernel configs\n    bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;\n    int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 
96 : 128)\n        : (head_size_rounded <= 96 ? 64\n           : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)\n              : 64));\n    int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;\n    int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;\n    int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);\n    int const kBlockN_sm90 = head_size_rounded <= 128\n        ? 128\n        : (head_size_rounded <= 192 ? 96 : 80);\n    int const kBlockN_sm80 = head_size_rounded <= 128\n        ? 128\n        : (head_size_rounded <= 192 ? 80 : 64);\n    int const kBlockN_sm86 = head_size_rounded <= 64 ? 128\n        : (head_size_rounded <= 96 ? 128\n           : (head_size_rounded <= 128 ? 96\n              : (head_size_rounded <= 192 ? 64 : 64)));\n    int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);\n    int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);\n    int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);\n    int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);\n\n    if (!is_varlen_q) {\n        CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);\n        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);\n        CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v);\n    } else {\n        CHECK_SHAPE(q, total_q, num_heads, head_size);\n        CHECK_SHAPE(out, total_q, num_heads, head_size_v);\n        CHECK_SHAPE(dout, total_q, num_heads, head_size_v);\n        CHECK_SHAPE(cu_seqlens_q, batch_size + 1);\n    }\n    if (!is_varlen_k) {\n        CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);\n        CHECK_SHAPE(v, batch_size, 
seqlen_k, num_heads_k, head_size_v);\n    } else {\n        CHECK_SHAPE(k, total_k, num_heads_k, head_size);\n        CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);\n        CHECK_SHAPE(cu_seqlens_k, batch_size + 1);\n    }\n\n    if (seqused_q_.has_value()){\n        auto seqused_q = seqused_q_.value();\n        TORCH_CHECK(seqused_q.dtype() == torch::kInt32, \"seqused_q must have dtype int32\");\n        CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);\n        CHECK_SHAPE(seqused_q, batch_size);\n    }\n    if (seqused_k_.has_value()){\n        auto seqused_k = seqused_k_.value();\n        TORCH_CHECK(seqused_k.dtype() == torch::kInt32, \"seqused_k must have dtype int32\");\n        CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);\n        CHECK_SHAPE(seqused_k, batch_size);\n    }\n\n    at::Tensor dq, dk, dv;\n    if (dq_.has_value()) {\n        dq = dq_.value();\n        TORCH_CHECK(dq.dtype() == q_type, \"dq must have the same dtype as q\");\n        CHECK_DEVICE(dq);\n        TORCH_CHECK(dq.stride(-1) == 1, \"dq must have contiguous last dimension\");\n        if (!is_varlen_q) {\n            CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);\n        } else {\n            CHECK_SHAPE(dq, total_q, num_heads, head_size);\n        }\n    } else {\n        dq = torch::empty_like(q);\n    }\n    if (dk_.has_value()) {\n        dk = dk_.value();\n        TORCH_CHECK(dk.dtype() == q_type, \"dk must have the same dtype as q\");\n        CHECK_DEVICE(dk);\n        TORCH_CHECK(dk.stride(-1) == 1, \"dk must have contiguous last dimension\");\n        if (!is_varlen_k) {\n            CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);\n        } else {\n            CHECK_SHAPE(dk, total_k, num_heads_k, head_size);\n        }\n    } else {\n        dk = torch::empty_like(k);\n    }\n    if (dv_.has_value()) {\n        dv = dv_.value();\n        TORCH_CHECK(dv.dtype() == q_type, \"dv must have the same dtype as q\");\n        
CHECK_DEVICE(dv);\n        TORCH_CHECK(dv.stride(-1) == 1, \"dv must have contiguous last dimension\");\n        if (!is_varlen_k) {\n            CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v);\n        } else {\n            CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v);\n        }\n    } else {\n        dv = torch::empty_like(v);\n    }\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    // Cast to char to avoid compiler warning about narrowing\n    auto device_guard = make_cuda_guard_from_tensor(q);\n\n    auto opts = q.options();\n    // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64\n    at::Tensor softmax_d, softmax_lse_log2;\n    if (!is_varlen) {\n        // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64\n        softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));\n        softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));\n    } else {\n        softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));\n        softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));\n    }\n    at::Tensor dq_accum, dk_accum, dv_accum;\n    if (!is_varlen) {\n        dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat));\n    } else {\n        dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat));\n    }\n    if (num_heads_k != num_heads) {  // MQA / GQA\n        if (!is_varlen) {\n            dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));\n            dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, opts.dtype(at::kFloat));\n        } 
else {\n            dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));\n            dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_v_rounded}, opts.dtype(at::kFloat));\n        }\n    }\n\n    Flash_bwd_params params;\n    set_params_dgrad(params,\n                     batch_size,\n                     seqlen_q, seqlen_k,\n                     seqlen_q_rounded, seqlen_k_rounded,\n                     num_heads, num_heads_k,\n                     head_size, head_size_rounded,\n                     q, k, v, out,\n                     dout, dq, dk, dv,\n                     !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),\n                     !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),\n                     seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,\n                     seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,\n                     dq_accum.data_ptr(),\n                     num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,\n                     num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,\n                     softmax_lse.data_ptr(),\n                     softmax_d.data_ptr(),\n                     /*p_dropout=*/0.f,\n                     softmax_scale,\n                     window_size_left,\n                     window_size_right,\n                     0,  // attention_chunk\n                     softcap,\n                     deterministic,\n                     sm_margin);\n    params.total_q = total_q;\n    params.total_k = total_k;\n    params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();\n    params.dv = head_size_v;\n    params.dv_rounded = head_size_v_rounded;\n\n    // auto tile_count_semaphore = (params.is_causal || params.is_local) ? 
torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));\n    // params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();\n    // Will be zero'ed out in the backward preprocess kernel\n    at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));\n    params.dq_semaphore = dq_semaphore.data_ptr<int>();\n    at::Tensor dk_semaphore, dv_semaphore;\n    if (num_heads_k != num_heads && params.deterministic) {\n        // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel\n        dk_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));\n        dv_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));\n        params.dk_semaphore = dk_semaphore.data_ptr<int>();\n        params.dv_semaphore = dv_semaphore.data_ptr<int>();\n    }\n\n    #ifdef FLASHATTENTION_DISABLE_LOCAL\n    TORCH_CHECK(!params.is_local, \"This flash attention build does not support local attention.\");\n    #endif\n    #ifdef FLASHATTENTION_DISABLE_SOFTCAP\n    TORCH_CHECK(params.softcap == 0.0, \"This flash attention build does not support tanh softcapping.\");\n    #endif\n\n    if (total_q > 0 && total_k > 0 && num_heads_k > 0) {\n        auto stream = at::cuda::getCurrentCUDAStream().stream();\n        run_mha_bwd(params, stream);\n    } else if (total_k > 0 && num_heads_k > 0) {\n        // If seqlen_q == 0, then we have an empty tensor. 
We need to set the output to 0.\n        dk.zero_();\n        dv.zero_();\n        softmax_d.zero_();\n    } else if (total_q > 0 && num_heads_k > 0) {\n        dq.zero_();\n        softmax_d.zero_();\n    }\n\n    return { softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };\n}\n\nstd::tuple<at::Tensor, at::Tensor>\nmha_combine(at::Tensor out_partial,         // num_splits x batch_size x seqlen x num_heads x head_size\n            at::Tensor lse_partial,         // num_splits x batch_size x seqlen x num_heads\n            std::optional<at::Tensor> out_,        // batch_size x seqlen x num_heads x head_size\n            std::optional<at::ScalarType> out_dtype_\n            ) {\n\n    auto dprops = at::cuda::getCurrentDeviceProperties();\n    bool is_sm8x = dprops->major >= 8;\n    TORCH_CHECK(is_sm8x, \"Attention combine function only supports Ampere GPUs or newer.\");\n\n    auto out_partial_type = out_partial.scalar_type();\n    TORCH_CHECK(out_partial_type == at::ScalarType::Float, \"Attention combine function only support fp32 data type\");\n    TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, \"Attention combine function only support fp32 data type\");\n\n    CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);\n\n    TORCH_CHECK(out_partial.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    TORCH_CHECK(lse_partial.stride(-2) == 1, \"LSE tensor must be contiguous in the seqlen dimension\");\n\n    const auto sizes = out_partial.sizes();\n\n    const int num_splits = sizes[0];\n    const int batch_size = sizes[1];\n    const int seqlen = sizes[2];\n    const int num_heads = sizes[3];\n    const int head_size_og = sizes[4];\n    TORCH_CHECK(num_splits <= 256, \"FlashAttention combine only supports num_splits at most 256\");\n\n    CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);\n    CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);\n\n    int const alignment = 
4;\n    at::Tensor out_partial_padded;\n    auto pad = [](at::Tensor x, int alignment) {\n        return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));\n    };\n    out_partial_padded = pad(out_partial, alignment);\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int head_size = round_multiple(head_size_og, alignment);\n\n    auto opts = out_partial.options();\n    at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());\n    TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, \"Output type must be FP32, FP16 or BF16\");\n    at::Tensor out;\n    if (out_.has_value()) {\n        out = out_.value();\n        TORCH_CHECK(out.scalar_type() == out_type);\n        CHECK_DEVICE(out);\n        TORCH_CHECK(out.stride(-1) == 1, \"Output tensor must have contiguous last dimension\");\n        CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);\n        if (head_size_og % alignment != 0) {\n            out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));\n        }\n    } else {\n        out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));\n    }\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    // Cast to char to avoid compiler warning about narrowing\n    at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};\n\n    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);\n\n    Flash_fwd_params params {};  // Need to reset the params to set everything to zero\n    params.is_fp32 = out_type == at::ScalarType::Float;\n    params.is_bf16 = out_type == at::ScalarType::BFloat16;\n    params.oaccum_ptr = out_partial_padded.data_ptr();\n    params.softmax_lseaccum_ptr = lse_partial.data_ptr();\n    
params.o_ptr = out.data_ptr();\n    params.softmax_lse_ptr = softmax_lse.data_ptr();\n    params.b = batch_size;\n    params.h = num_heads;\n    params.seqlen_q = seqlen;\n    params.dv = head_size;\n    params.num_splits = num_splits;\n    params.oaccum_split_stride = out_partial_padded.stride(0);\n    params.oaccum_row_stride = out_partial_padded.stride(2);\n    params.oaccum_head_stride = out_partial_padded.stride(3);\n    params.oaccum_batch_stride = out_partial_padded.stride(1);\n    params.lseaccum_split_stride = lse_partial.stride(0);\n    params.lseaccum_head_stride = lse_partial.stride(3);\n    params.lseaccum_batch_stride = lse_partial.stride(1);\n    params.o_row_stride = out.stride(1);\n    params.o_head_stride = out.stride(2);\n    params.o_batch_stride = out.stride(0);\n    params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;\n\n    if (seqlen > 0 && batch_size > 0) {\n        auto stream = at::cuda::getCurrentCUDAStream().stream();\n        run_mha_fwd_combine(params, stream, false /*enable_pdl*/);\n    }\n\n    at::Tensor out_padded = out;\n    if (head_size_og % alignment != 0) {\n        out = out.index({\"...\", torch::indexing::Slice(torch::indexing::None, head_size_og)});\n        // if (out_.has_value()) { out_.value().copy_(out); }\n    }\n\n    return {out, softmax_lse};\n}\n\nTORCH_LIBRARY(flash_attn_3, m) {\n    m.def(\"fwd(\"\n        \"Tensor q,\"\n        \"Tensor k,\"\n        \"Tensor v,\"\n        \"Tensor(k_new!)? k_new = None,\"\n        \"Tensor(v_new!)? v_new = None,\"\n        \"Tensor? q_v = None,\"\n        \"Tensor(out!)? out = None,\"\n        \"Tensor? cu_seqlens_q = None,\"\n        \"Tensor? cu_seqlens_k = None,\"\n        \"Tensor? cu_seqlens_k_new = None,\"\n        \"Tensor? seqused_q = None,\"\n        \"Tensor? seqused_k = None,\"\n        \"int? max_seqlen_q = None,\"\n        \"int? max_seqlen_k = None,\"\n        \"Tensor? 
page_table = None,\"\n        \"Tensor? kv_batch_idx = None,\"\n        \"Tensor? leftpad_k = None,\"\n        \"Tensor? rotary_cos = None,\"\n        \"Tensor? rotary_sin = None,\"\n        \"Tensor? seqlens_rotary = None,\"\n        \"Tensor? q_descale = None,\"\n        \"Tensor? k_descale = None,\"\n        \"Tensor? v_descale = None,\"\n        \"float? softmax_scale = None,\"\n        \"bool is_causal = False,\"\n        \"int window_size_left = -1,\"\n        \"int window_size_right = -1,\"\n        \"int attention_chunk = 0,\"\n        \"float softcap = 0.0,\"\n        \"bool is_rotary_interleaved = False,\"\n        \"Tensor? scheduler_metadata = None,\"\n        \"int num_splits = 0,\"\n        \"bool? pack_gqa = None,\"\n        \"int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)\");\n    m.def(\"bwd(\"\n        \"Tensor dout,\"\n        \"Tensor q,\"\n        \"Tensor k,\"\n        \"Tensor v,\"\n        \"Tensor out,\"\n        \"Tensor softmax_lse,\"\n        \"Tensor(dq!)? dq = None,\"\n        \"Tensor(dk!)? dk = None,\"\n        \"Tensor(dv!)? dv = None,\"\n        \"Tensor? cu_seqlens_q = None,\"\n        \"Tensor? cu_seqlens_k = None,\"\n        \"Tensor? seqused_q = None,\"\n        \"Tensor? seqused_k = None,\"\n        \"int? max_seqlen_q = None,\"\n        \"int? max_seqlen_k = None,\"\n        \"float? softmax_scale = None,\"\n        \"bool is_causal = False,\"\n        \"int window_size_left = -1,\"\n        \"int window_size_right = -1,\"\n        \"float softcap = 0.0,\"\n        \"bool deterministic = False,\"\n        \"int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)\");\n    m.def(\"fwd_combine(\"\n        \"Tensor out_partial,\"\n        \"Tensor lse_partial,\"\n        \"Tensor(out!)? out = None,\"\n        \"ScalarType? 
out_dtype = None) -> (Tensor(out!), Tensor)\");\n    m.def(\"get_scheduler_metadata(\"\n        \"int batch_size,\"\n        \"int max_seqlen_q,\"\n        \"int max_seqlen_k,\"\n        \"int num_heads,\"\n        \"int num_heads_k,\"\n        \"int headdim,\"\n        \"int headdim_v,\"\n        \"ScalarType qkv_dtype,\"\n        \"Tensor seqused_k,\"\n        \"Tensor? cu_seqlens_q = None,\"\n        \"Tensor? cu_seqlens_k = None,\"\n        \"Tensor? cu_seqlens_k_new = None,\"\n        \"Tensor? seqused_q = None,\"\n        \"Tensor? leftpad_k = None,\"\n        \"int? page_size = None,\"\n        \"int max_seqlen_k_new = 0,\"\n        \"bool is_causal = False,\"\n        \"int window_size_left = -1,\"\n        \"int window_size_right = -1,\"\n        \"int attention_chunk = 0,\"\n        \"bool has_softcap = False,\"\n        \"int num_splits = 0,\"\n        \"bool? pack_gqa = None,\"\n        \"int sm_margin = 0) -> Tensor\");\n}\n\nTORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) {\n    m.impl(\"fwd\", &mha_fwd);\n    m.impl(\"bwd\", &mha_bwd);\n    m.impl(\"fwd_combine\", &mha_combine);\n    m.impl(\"get_scheduler_metadata\", &mha_fwd_get_scheduler_metadata);\n}\n"
  },
  {
    "path": "hopper/flash_api_stable.cpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#include <Python.h>\n\n#include <cutlass/numeric_types.h>\n\n#include \"flash.h\"\n#include \"static_switch.h\"\n#include \"tile_size.h\"\n#include \"heuristics.h\"\n#include \"cuda_check.h\"\n\n#include <torch/csrc/stable/tensor.h>\n#include <torch/csrc/stable/library.h>\n#include <torch/csrc/stable/ops.h>\n#include <torch/csrc/stable/accelerator.h>\n#include <torch/csrc/inductor/aoti_torch/c/shim.h>\n\n// Declare the CUDA stream function that's behind #ifdef USE_CUDA in shim.h\nextern \"C\" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream);\n\n#include <torch/headeronly/core/ScalarType.h>\n#include <torch/headeronly/util/Exception.h>\n\n#include <cuda_runtime.h>\n#include <string>\n#include <vector>\n#include <deque>\n#include <mutex>\n\nusing torch::stable::Tensor;\nnamespace tsa = torch::stable::accelerator;\n\nnamespace {\ninline tsa::DeviceGuard make_device_guard(const Tensor& t) {\n  return tsa::DeviceGuard(static_cast<tsa::DeviceIndex>(t.get_device()));\n}\nstd::deque<std::once_flag> device_flags;\nstd::vector<cudaDeviceProp> device_properties;\n\nvoid initVectors() {\n  static bool init_flag [[maybe_unused]] = []() {\n    int device_count;\n    cudaError_t err = cudaGetDeviceCount(&device_count);\n    if (err != cudaSuccess) {\n      STD_TORCH_CHECK(false, \"cudaGetDeviceCount failed: \" +\n                                 std::string(cudaGetErrorString(err)));\n    }\n    device_flags.resize(device_count);\n    device_properties.resize(device_count);\n    return true;\n  }();\n}\n\nvoid initDeviceProperty(int device_index) {\n  cudaDeviceProp device_prop{};\n  cudaError_t err = cudaGetDeviceProperties(&device_prop, device_index);\n  if (err != 
cudaSuccess) {\n    STD_TORCH_CHECK(false, \"cudaGetDeviceProperties failed: \" +\n                               std::string(cudaGetErrorString(err)));\n  }\n  device_properties[device_index] = device_prop;\n}\n\n// Helper function to get device properties using raw CUDA APIs\ncudaDeviceProp* get_device_prop() {\n  initVectors();\n  int device_index;\n  cudaError_t err = cudaGetDevice(&device_index);\n  if (err != cudaSuccess) {\n    STD_TORCH_CHECK(false, \"cudaGetDevice failed: \" +\n                               std::string(cudaGetErrorString(err)));\n  }\n\n  std::call_once(device_flags[device_index], initDeviceProperty, device_index);\n  return &device_properties[device_index];\n}\n} // anonymous namespace\n\n\nextern \"C\" {\n/* Creates a dummy empty _C module that can be imported from Python.\n    The import from Python will load the .so consisting of this file\n    in this extension, so that the STABLE_TORCH_LIBRARY static initializers\n    below are run. */\nPyObject* PyInit__C(void)\n{\n    static struct PyModuleDef module_def = {\n        PyModuleDef_HEAD_INIT,\n        \"_C\",   /* name of module */\n        NULL,   /* module documentation, may be NULL */\n        -1,     /* size of per-interpreter state of the module,\n                    or -1 if the module keeps state in global variables. */\n        NULL,   /* methods */\n    };\n    return PyModule_Create(&module_def);\n}\n}\n\n#define CHECK_DEVICE(x) STD_TORCH_CHECK(x.is_cuda(), #x \" must be on CUDA\")\n#define CHECK_SHAPE(x, ...) 
\\\n    do { \\\n        auto expected_dims = std::vector<int64_t>{__VA_ARGS__}; \\\n        STD_TORCH_CHECK(x.dim() == static_cast<int64_t>(expected_dims.size()), #x \" must have \" + std::to_string(expected_dims.size()) + \" dimensions, got \" + std::to_string(x.dim())); \\\n        for (size_t i = 0; i < expected_dims.size(); ++i) { \\\n            STD_TORCH_CHECK(x.size(i) == expected_dims[i], #x \" dimension \" + std::to_string(i) + \" must have size \" + std::to_string(expected_dims[i]) + \", got \" + std::to_string(x.size(i))); \\\n        } \\\n    } while (0)\n#define CHECK_CONTIGUOUS(x) STD_TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n\n#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992\n\nvoid set_params_fprop(Flash_fwd_params &params,\n                      // sizes\n                      const size_t b,\n                      const size_t seqlen_q,\n                      const size_t seqlen_k,\n                      const size_t seqlen_q_rounded,\n                      const size_t seqlen_k_rounded,\n                      const size_t h,\n                      const size_t h_k,\n                      const size_t d,\n                      const size_t d_rounded,\n                      // device pointers\n                      const Tensor q,\n                      const Tensor k,\n                      const Tensor v,\n                      Tensor out,\n                      void *cu_seqlens_q_d,\n                      void *cu_seqlens_k_d,\n                      void *seqused_q,\n                      void *seqused_k,\n                      void *softmax_lse_d,\n                      float p_dropout,\n                      float softmax_scale,\n                      int window_size_left,\n                      int window_size_right,\n                      int attention_chunk,\n                      const float softcap=0.f,\n                      const int sm_margin=0) {\n\n    // Reset the parameters\n    params = {};\n\n    params.is_bf16 = 
q.scalar_type() == torch::headeronly::ScalarType::BFloat16;\n    params.is_e4m3 = q.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn;\n\n    // Set the pointers and strides.\n    params.q_ptr = q.data_ptr();\n    params.k_ptr = k.data_ptr();\n    params.v_ptr = v.data_ptr();\n    // All stride are in elements, not bytes.\n    params.q_row_stride = q.stride(-3);\n    params.k_row_stride = k.stride(-3);\n    params.v_row_stride = v.stride(-3);\n    params.q_head_stride = q.stride(-2);\n    params.k_head_stride = k.stride(-2);\n    params.v_head_stride = v.stride(-2);\n    params.v_dim_stride = v.stride(-1);\n    params.o_ptr = out.data_ptr();\n    params.o_row_stride = out.stride(-3);\n    params.o_head_stride = out.stride(-2);\n\n    if (cu_seqlens_q_d == nullptr) {\n        params.q_batch_stride = q.stride(0);\n        params.o_batch_stride = out.stride(0);\n    }\n    if (cu_seqlens_k_d == nullptr) {\n        params.k_batch_stride = k.stride(0);\n        params.v_batch_stride = v.stride(0);\n    }\n\n    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);\n    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);\n    params.seqused_q = static_cast<int *>(seqused_q);\n    params.seqused_k = static_cast<int *>(seqused_k);\n\n    // Softmax sum\n    params.softmax_lse_ptr = softmax_lse_d;\n\n    // Set the dimensions.\n    params.b = b;\n    params.h = h;\n    params.h_k = h_k;\n    params.seqlen_q = seqlen_q;\n    params.seqlen_k = seqlen_k;\n    params.seqlen_q_rounded = seqlen_q_rounded;\n    params.seqlen_k_rounded = seqlen_k_rounded;\n    params.d = d;\n    params.d_rounded = d_rounded;\n\n    // Set the different scale values.\n    params.scale_softmax = softmax_scale;\n    params.softcap = softcap;\n\n    // Set this to probability of keeping an element to simplify things.\n    params.p_dropout = 1.f - p_dropout;\n    // Convert p from float to int so we don't have to convert the random uint to float to compare.\n    // [Minor] We want 
to round down since when we do the comparison we use <= instead of <\n    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));\n    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));\n    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));\n    params.rp_dropout = 1.f / params.p_dropout;\n    STD_TORCH_CHECK(p_dropout < 1.f);\n    #ifdef FLASHATTENTION_DISABLE_DROPOUT\n        STD_TORCH_CHECK(p_dropout == 0.0f, \"This flash attention build does not support dropout.\");\n    #endif\n\n    // Causal is the special case where window_size_right == 0 and window_size_left < 0.\n    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.\n    params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;\n    params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;\n\n    // TODO: check this\n    if (window_size_left < 0) { window_size_left = seqlen_k - 1; }\n    if (window_size_right < 0) { window_size_right = seqlen_q - 1; }\n    if (attention_chunk > 0) {\n        window_size_left = std::min(window_size_left, attention_chunk - 1);\n        window_size_right = std::min(window_size_right, attention_chunk - 1);\n    }\n    params.window_size_left = window_size_left;\n    params.window_size_right = window_size_right;\n    params.attention_chunk = attention_chunk;\n\n    auto dprops = get_device_prop();\n    params.arch = dprops->major * 10 + dprops->minor;\n    params.num_sm = dprops->multiProcessorCount - sm_margin;\n\n    #ifdef FLASHATTENTION_DISABLE_LOCAL\n        STD_TORCH_CHECK(!params.is_local, \"This flash attention build does not support local attention.\");\n    #endif\n}\n\nvoid set_params_dgrad(Flash_bwd_params &params,\n                      // sizes\n                      const size_t b,\n                      const size_t seqlen_q,\n                      
const size_t seqlen_k,\n                      const size_t seqlen_q_rounded,\n                      const size_t seqlen_k_rounded,\n                      const size_t h,\n                      const size_t h_k,\n                      const size_t d,\n                      const size_t d_rounded,\n                      // device pointers\n                      const Tensor q,\n                      const Tensor k,\n                      const Tensor v,\n                      const Tensor out,\n                      const Tensor dout,\n                      Tensor dq,\n                      Tensor dk,\n                      Tensor dv,\n                      void *cu_seqlens_q_d,\n                      void *cu_seqlens_k_d,\n                      void *seqused_q,\n                      void *seqused_k,\n                      void *dq_accum_d,\n                      void *dk_accum_d,\n                      void *dv_accum_d,\n                      void *softmax_lse_d,\n                      void *dsoftmax_sum_d,\n                      float p_dropout,\n                      float softmax_scale,\n                      int window_size_left,\n                      int window_size_right,\n                      int attention_chunk,\n                      const float softcap=0.f,\n                      bool deterministic=false,\n                      int const sm_margin=0) {\n\n    set_params_fprop(params,\n                     b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,\n                     q, k, v, out,\n                     cu_seqlens_q_d,\n                     cu_seqlens_k_d,\n                     seqused_q,\n                     seqused_k,\n                     softmax_lse_d,\n                     p_dropout,\n                     softmax_scale,\n                     window_size_left,\n                     window_size_right,\n                     attention_chunk,\n                     softcap,\n                     sm_margin);\n\n    
// Set the pointers and strides.\n    params.do_ptr = dout.data_ptr();\n    params.do_row_stride = dout.stride(-3);\n    params.do_head_stride = dout.stride(-2);\n    params.dq_ptr = dq.data_ptr();\n    params.dk_ptr = dk.data_ptr();\n    params.dv_ptr = dv.data_ptr();\n    params.dq_row_stride = dq.stride(-3);\n    params.dk_row_stride = dk.stride(-3);\n    params.dv_row_stride = dv.stride(-3);\n    params.dq_head_stride = dq.stride(-2);\n    params.dk_head_stride = dk.stride(-2);\n    params.dv_head_stride = dv.stride(-2);\n\n    if (cu_seqlens_q_d == nullptr) {\n        params.do_batch_stride = dout.stride(0);\n        params.dq_batch_stride = dq.stride(0);\n        params.dk_batch_stride = dk.stride(0);\n        params.dv_batch_stride = dv.stride(0);\n    }\n\n    params.dq_accum_ptr = dq_accum_d;\n    params.dk_accum_ptr = dk_accum_d;\n    params.dv_accum_ptr = dv_accum_d;\n\n    // Softmax sum\n    params.dsoftmax_sum = dsoftmax_sum_d;\n\n    params.deterministic = deterministic;\n}\n\ntemplate <int Arch, int Split, bool PagedKVNonTMA, bool PackGQA, bool Has_softcap>\nvoid run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) {\n    if (!params.is_e4m3) {\n        if (params.is_bf16) {\n            #ifndef FLASHATTENTION_DISABLE_HDIM64\n            if (params.d <= 64) {\n                #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64\n                if constexpr (Arch == 90) {\n                    if (params.dv > 256) {\n                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                    } else if (params.dv > 64) {\n                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                    }\n                }\n                #endif\n                return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n      
      }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM96\n            if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM128\n            if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM192\n            if (params.d <= 192) {\n                #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192\n                if constexpr (Arch == 90) {\n                    if (params.dv <= 128) {\n                        return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                    }\n                }\n                #endif\n                return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n            }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM256\n            if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n            #endif\n        } else {\n            #ifndef FLASHATTENTION_DISABLE_FP16\n            #ifndef FLASHATTENTION_DISABLE_HDIM64\n            if (params.d <= 64) {\n                #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64\n                if constexpr (Arch == 90) {\n                    if (params.dv > 256) {\n                        return run_mha_fwd_<Arch, cutlass::half_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                    } else if (params.dv > 64) {\n                        return run_mha_fwd_<Arch, cutlass::half_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                    }\n                }\n     
           #endif\n                return run_mha_fwd_<Arch, cutlass::half_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n            }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM96\n            if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::half_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM128\n            if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::half_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM192\n            if (params.d <= 192) {\n                #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192\n                if constexpr (Arch == 90) {\n                    if (params.dv <= 128) {\n                        return run_mha_fwd_<Arch, cutlass::half_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                    }\n                }\n                #endif\n                return run_mha_fwd_<Arch, cutlass::half_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n            }\n            #endif\n            #ifndef FLASHATTENTION_DISABLE_HDIM256\n            if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::half_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n            #endif\n            #else\n            STD_TORCH_CHECK(false, \"This flash attention build does not support FP16.\");\n            #endif\n        }\n    } else {\n        #ifndef FLASHATTENTION_DISABLE_FP8\n        #ifndef FLASHATTENTION_DISABLE_HDIM64\n        if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM96\n        if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 
96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM128\n        if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM192\n        if (params.d <= 192) {\n            #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192\n            if constexpr (Arch == 90) {\n                if (params.dv <= 128) {\n                    return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n                }\n            }\n            #endif\n            return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);\n        }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM256\n        if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }\n        #endif\n        #else\n        STD_TORCH_CHECK(false, \"This flash attention build does not support FP8.\");\n        #endif\n    }\n}\n\nvoid run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {\n    // HEADDIM_SWITCH(params.d, [&] {\n    //     run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);\n    // });\n    STD_TORCH_CHECK(params.num_splits >= 1);\n    ARCH_SWITCH(params.arch, Arch, [&] {\n        SPLIT_SWITCH(params.num_splits > 1, Split, [&] {\n            PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] {\n                PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {\n                    // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation\n                    static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split;\n                    SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {\n             
           run_mha_fwd_constexpr<Arch, Split, PagedKVNonTMA, PackGQA, Has_softcap>(params, stream);\n                    });\n                });\n            });\n        });\n    });\n}\n\nvoid run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl=false) {\n    #ifndef FLASHATTENTION_DISABLE_SPLIT\n    // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively\n    // so that kBlockM is smaller and we have more parallelism.\n    if (params.is_fp32) {\n        if (params.dv <= 64) {\n            run_mha_fwd_combine_<float, float, 64>(params, stream, enable_pdl);\n        } else {\n            run_mha_fwd_combine_<float, float, 128>(params, stream, enable_pdl);\n        }\n    } else if (params.is_bf16) {\n        if (params.dv <= 64) {\n            run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream, enable_pdl);\n        } else {\n            run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream, enable_pdl);\n        }\n    } else {\n        if (params.dv <= 64) {\n            run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream, enable_pdl);\n        } else {\n            run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream, enable_pdl);\n        }\n    }\n    #else\n    STD_TORCH_CHECK(false, \"This flash attention build does not support combine kernels.\");\n    #endif\n}\n\ninline bool get_pagedkv_tma(Flash_fwd_params const& params) {\n    if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; }\n    // This needs to match the kernel configs\n    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f);\n    int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);\n    int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90);\n    // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower,\n    // at least for MLA.\n    return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM;\n}\n\ninline bool get_pack_gqa(Flash_fwd_params const& params) {\n    // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size.\n    // Has little effect on speed.\n    if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }\n    #ifdef FLASHATTENTION_DISABLE_PACKGQA\n    return false;\n    #else\n    // params.page_table must already be set\n    if (params.h == params.h_k) { return false; }\n    // This needs to match the kernel configs\n    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);\n    int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);\n    return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);\n    #endif\n}\n\ninline int get_num_splits(Flash_fwd_params const& params) {\n    #ifdef FLASHATTENTION_DISABLE_SPLIT\n    return 1;\n    #else\n    // Always enable PackGQA for Split\n    // params.page_table must already be set\n    // This needs to match the kernel configs\n    bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;\n    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);\n    // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits\n    // has not been set here. It's OK though because we might just underestimate kBlockN a bit\n    auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);\n    int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);\n    int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);\n    int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);\n    // If is_local, we're not going to load all of seqlen_k\n    int const seqlen_k_loaded = !params.is_local\n        ? params.seqlen_k\n        : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));\n    int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;\n    int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;\n    int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2);\n    // Always enable PackGQA for Split\n    // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits.\n    // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending\n    // that batch = 1.\n    int total_mblocks = (params.num_splits_dynamic_ptr ? 
1 : params.b) * params.h_k * num_m_blocks;\n    return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128);\n    #endif\n}\n\ninline int get_max_headdim() {\n    #ifndef FLASHATTENTION_DISABLE_HDIM256\n    return 256;\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM192\n    return 192;\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM128\n    return 128;\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM96\n    return 96;\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM64\n    return 64;\n    #endif\n    return 0;\n}\n\ninline int round_up_headdim(int head_size) {\n    #ifndef FLASHATTENTION_DISABLE_HDIM64\n    if (head_size <= 64) { return 64; }\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM96\n    if (head_size <= 96) { return 96; }\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM128\n    if (head_size <= 128) { return 128; }\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM192\n    if (head_size <= 192) { return 192; }\n    #endif\n    #ifndef FLASHATTENTION_DISABLE_HDIM256\n    if (head_size <= 256) { return 256; }\n    #endif\n    return 256;\n}\n\ninline int round_up_headdimv(int head_size) {\n    if (head_size <= 64) { return 64; }\n    if (head_size <= 96) { return 96; }\n    if (head_size <= 128) { return 128; }\n    if (head_size <= 192) { return 192; }\n    if (head_size <= 256) { return 256; }\n    return 512;\n}\n\n// Only applicable to the case where seqused_k (i.e. 
cache_seqlens) is available\nTensor\nmha_fwd_get_scheduler_metadata(\n        int64_t batch_size,\n        int64_t max_seqlen_q,\n        int64_t max_seqlen_k,\n        int64_t num_heads,\n        int64_t num_heads_k,\n        int64_t headdim,\n        int64_t headdim_v,\n        torch::headeronly::ScalarType qkv_dtype,\n        Tensor seqused_k, // b\n        std::optional<Tensor> cu_seqlens_q_,  // b+1\n        std::optional<Tensor> cu_seqlens_k_,  // b+1\n        std::optional<Tensor> cu_seqlens_k_new_,  // b+1\n        std::optional<Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.\n        std::optional<Tensor> leftpad_k_, // b\n        std::optional<int64_t> page_size,\n        int64_t max_seqlen_k_new,  // 0 means we're not appending new KV\n        bool is_causal,\n        int64_t window_size_left,\n        int64_t window_size_right,\n        int64_t attention_chunk,\n        bool has_softcap,\n        int64_t num_splits,\n        std::optional<bool> pack_gqa_,\n        int64_t sm_margin) {\n\n    STD_TORCH_CHECK(qkv_dtype == torch::headeronly::ScalarType::Half || qkv_dtype == torch::headeronly::ScalarType::BFloat16 || qkv_dtype == torch::headeronly::ScalarType::Float8_e4m3fn,\n                \"FlashAttention only supports fp16, bf16, and fp8_e4m3 data type\");\n    STD_TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n\n    // Reset the parameters\n    Flash_fwd_params params{};\n    params.is_bf16 = qkv_dtype == torch::headeronly::ScalarType::BFloat16;\n    params.is_e4m3 = qkv_dtype == torch::headeronly::ScalarType::Float8_e4m3fn;\n    params.b = batch_size;\n    params.seqlen_q = max_seqlen_q;\n    params.seqlen_k = max_seqlen_k;\n    params.h = num_heads;\n    params.h_k = num_heads_k;\n    params.d = headdim;\n    params.dv = headdim_v;\n    params.d_rounded = round_up_headdim(headdim);\n    params.dv_rounded = headdim_v == 
headdim ? params.d_rounded : round_up_headdimv(headdim_v);\n    params.seqlen_knew = max_seqlen_k_new;\n\n    bool const is_varlen_q = cu_seqlens_q_.has_value();\n    params.cu_seqlens_q = is_varlen_q ? static_cast<int*>(cu_seqlens_q_.value().data_ptr()) : nullptr;\n    bool const is_varlen_k = cu_seqlens_k_.has_value();\n    params.cu_seqlens_k = is_varlen_k ?  static_cast<int*>(cu_seqlens_k_.value().data_ptr()) : nullptr;\n    params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? static_cast<int*>(cu_seqlens_k_new_.value().data_ptr()): nullptr;\n    params.seqused_q = seqused_q_.has_value() ?  static_cast<int*>(seqused_q_.value().data_ptr()) : nullptr;\n    params.seqused_k = static_cast<int*>(seqused_k.data_ptr());\n    params.leftpad_k = leftpad_k_.has_value() ? static_cast<int*>(leftpad_k_.value().data_ptr()) : nullptr;\n    params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast<int*>(1) : nullptr;\n    if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }\n    if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }\n    // causal=true is the same as causal=false in this case\n    if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {\n        // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA\n        if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) {\n            is_causal = false;\n        }\n    }\n    if (is_causal) { window_size_right = 0; }\n\n    params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;\n    params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;\n    if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; }\n    if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; }\n    if (attention_chunk > 0) {\n        window_size_left = std::min(window_size_left, attention_chunk - 1);\n      
  window_size_right = std::min(window_size_right, attention_chunk - 1);\n    }\n    params.window_size_left = window_size_left;\n    params.window_size_right = window_size_right;\n    params.attention_chunk = attention_chunk;\n    auto dprops = get_device_prop();\n    params.arch = dprops->major * 10 + dprops->minor;\n    params.num_sm = dprops->multiProcessorCount - sm_margin;\n    params.softcap = has_softcap ? 1.0f : 0.0f;\n\n    params.page_size = page_size.has_value() ? page_size.value() : 1;\n    params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);\n\n    bool const use_prepare_varlen = true;\n    params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA;\n    params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast<int*>(1);\n\n    params.pagedkv_tma = get_pagedkv_tma(params);\n    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;\n    // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide\n    params.pack_gqa = pack_gqa_.has_value() ? 
pack_gqa_.value() : get_pack_gqa(params);\n\n    bool is_varlen = true;\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    // Cast to char to avoid compiler warning about narrowing\n    auto device_guard = make_device_guard(seqused_k);\n\n    // This needs to be set after get_num_splits\n    Tensor tile_count_semaphore;  // Contains the semaphore and optionally num_splits_dynamic\n    bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template\n    params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template\n    if (scheduler_needs_semaphore || use_prepare_varlen) {   \n        int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers \n        int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0;\n        if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; }\n        if(params.head_swizzle) { num_prepare_batch_vectors += 1; }\n        int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2);\n        int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors;\n        // printf(\"(Metadata) num prepare batch vectors = %d.\\n\", num_prepare_batch_vectors);\n        tile_count_semaphore = torch::stable::new_empty(\n            seqused_k,\n            {int(scheduler_needs_semaphore) + tile_count_semaphore_offset},\n            std::make_optional(torch::headeronly::ScalarType::Int));\n        // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2}\n        params.num_splits_dynamic_ptr = use_prepare_varlen ? static_cast<int*>(tile_count_semaphore.data_ptr()) : nullptr;\n        params.num_m_blocks_ptr =  use_prepare_varlen ? 
static_cast<int*>(tile_count_semaphore.data_ptr()) + b_rounded : nullptr;\n        params.varlen_batch_idx_ptr =  use_prepare_varlen && params.varlen_sort_batches ? static_cast<int*>(tile_count_semaphore.data_ptr()) + b_rounded * 2 : nullptr;\n        // params.num_n_blocks_ptr  = use_prepare_varlen && params.head_swizzle ? static_cast<int*>(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr;\n        params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? static_cast<int*>(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr;\n        if (scheduler_needs_semaphore) {\n            if (!use_prepare_varlen) { torch::stable::zero_(tile_count_semaphore); }  // If varlen we'll manually do the zero-ing\n            params.tile_count_semaphore = static_cast<int*>(tile_count_semaphore.data_ptr()) + tile_count_semaphore_offset;\n        } else {\n            params.tile_count_semaphore = nullptr;\n        }\n    }\n\n    if (use_prepare_varlen) {\n        auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);\n        auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);\n        int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);\n        int const kBlockN = params.arch >= 90 ? 
std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);\n        auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex();\n        void* stream_ptr = nullptr;\n        TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr));\n        cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);\n        prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);\n        CHECK_CUDA_KERNEL_LAUNCH();\n    }\n    return tile_count_semaphore;\n}\n\n// b: batch_size\n// b_k: batch_size_k\n// s_q: seqlen_q\n// s_k: seqlen_k\n// s_k_new: seqlen_k_new\n// h: num_heads\n// h_k: num_heads_k\n// d: head_size\nstd::tuple<Tensor, Tensor, Tensor, Tensor>\nmha_fwd(Tensor q,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q\n        Tensor k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.\n        Tensor v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.\n        std::optional<Tensor> k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new\n        std::optional<Tensor> v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new\n        std::optional<Tensor> q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q\n        std::optional<Tensor> out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q\n        std::optional<Tensor> cu_seqlens_q_,  // b+1\n        std::optional<Tensor> cu_seqlens_k_,  // b+1\n        std::optional<Tensor> cu_seqlens_k_new_,  // b+1\n        std::optional<Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.\n        std::optional<Tensor> seqused_k_, // b. 
If given, only this many elements of each batch element's keys are used.\n        std::optional<int64_t> max_seqlen_q_,\n        // TODO: check if we need max_seqlen_k\n        std::optional<int64_t> max_seqlen_k_,\n        std::optional<Tensor> page_table_, // (b_k, max_num_pages_per_seq)\n        std::optional<Tensor> kv_batch_idx_, // b. indices to index into the KV cache\n        std::optional<Tensor> leftpad_k_, // b\n        std::optional<Tensor> rotary_cos_, // seqlen_ro x (rotary_dim / 2)\n        std::optional<Tensor> rotary_sin_, // seqlen_ro x (rotary_dim / 2)\n        std::optional<Tensor> seqlens_rotary_, // b\n        std::optional<Tensor> q_descale_,  // (b, h_k), not (b, h)\n        std::optional<Tensor> k_descale_,  // (b, h_k)\n        std::optional<Tensor> v_descale_,  // (b, h_k)\n        std::optional<double> softmax_scale_,\n        bool is_causal,\n        int64_t window_size_left,\n        int64_t window_size_right,\n        int64_t attention_chunk,\n        double softcap,\n        bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2\n        std::optional<Tensor> scheduler_metadata_,  // (b + 1)\n        int64_t num_splits,\n        std::optional<bool> pack_gqa_,\n        int64_t sm_margin\n        ) {\n\n    auto dprops = get_device_prop();\n    bool is_sm8x = dprops->major >= 8;\n    STD_TORCH_CHECK(is_sm8x, \"FlashAttention only supports Ampere GPUs or newer.\");\n\n    auto q_type = q.scalar_type();\n    STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16 || q_type == torch::headeronly::ScalarType::Float8_e4m3fn,\n                \"FlashAttention only supports fp16, bf16, and fp8_e4m3 data type\");\n    if (dprops->major < 9) {\n        STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16,\n                    \"FlashAttention on Ampere/Ada cards only supports 
fp16 and bf16 data type\");\n    }\n    STD_TORCH_CHECK(k.scalar_type() == q_type, \"query and key must have the same dtype\");\n    STD_TORCH_CHECK(v.scalar_type() == q_type, \"query and value must have the same dtype\");\n\n    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);\n\n    STD_TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    STD_TORCH_CHECK(k.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    STD_TORCH_CHECK(v.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n\n    Tensor page_table;\n    const bool paged_KV = page_table_.has_value();\n    if (paged_KV) {\n        page_table = page_table_.value();\n        CHECK_DEVICE(page_table);\n        STD_TORCH_CHECK(page_table.scalar_type() == torch::headeronly::ScalarType::Int, \"page_table must have dtype torch.int32\");\n        STD_TORCH_CHECK(page_table.stride(-1) == 1, \"page_table must have contiguous last dimension\");\n    }\n\n    Tensor cu_seqlens_q;\n    bool const is_varlen_q = cu_seqlens_q_.has_value();\n    if (is_varlen_q) {\n        cu_seqlens_q = cu_seqlens_q_.value();\n        CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);\n        STD_TORCH_CHECK(cu_seqlens_q.scalar_type() == torch::headeronly::ScalarType::Int, \"cu_seqlens_q must have dtype torch.int32\");\n        STD_TORCH_CHECK(max_seqlen_q_.has_value(), \"max_seqlen_q must be provided if cu_seqlens_q is provided\");\n    }\n    Tensor cu_seqlens_k;\n    bool const is_varlen_k = cu_seqlens_k_.has_value();\n    if (is_varlen_k) {\n        cu_seqlens_k = cu_seqlens_k_.value();\n        CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);\n        STD_TORCH_CHECK(cu_seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, \"cu_seqlens_k must have dtype torch.int32\");\n        STD_TORCH_CHECK(max_seqlen_k_.has_value(), \"max_seqlen_k must be provided if cu_seqlens_k is provided\");\n        STD_TORCH_CHECK(!paged_KV, \"If 
cu_seqlens_k is passed in, then page table is not supported\");\n        STD_TORCH_CHECK(!kv_batch_idx_.has_value(), \"If cu_seqlens_k is passed in, then kv_batch_idx is not supported\");\n    }\n\n    const int batch_size = !is_varlen_q ? q.size(0) : cu_seqlens_q.size(0) - 1;\n    int seqlen_q = !is_varlen_q ? q.size(1) : max_seqlen_q_.value();\n    int total_q = !is_varlen_q ? batch_size * q.size(1) : q.size(0);\n    int num_heads = q.size(-2);\n    int const head_size = q.size(-1);\n    int const head_size_v = v.size(-1);\n    int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);\n    int const num_pages = !paged_KV ? 0 : k.size(0);\n    int const page_size = !paged_KV ? 1 : k.size(1);\n    int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();\n    int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);\n    int const num_heads_k = k.size(-2);\n    int const batch_size_k = !paged_KV ? (!is_varlen_k ? 
k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);\n    double softmax_scale = 1.0 / sqrt(double(head_size));\n    if (softmax_scale_.has_value()) {\n        softmax_scale = softmax_scale_.value();\n    }\n    if (!kv_batch_idx_.has_value()) {\n        STD_TORCH_CHECK(batch_size == batch_size_k, \"batch_size must be equal to batch_size_k\");\n    }\n    int const max_headdim = get_max_headdim();\n    STD_TORCH_CHECK(head_size <= max_headdim, \"FlashAttention forward only supports head dimension at most \" + std::to_string(max_headdim));\n    STD_TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n    if (head_size_v != head_size) {\n        STD_TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) ||\n                   (head_size <= 64 && head_size_v <= 512),\n                   \"If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], \"\n                   \"or (Q/K <= 64 and V <= 512).\");\n        STD_TORCH_CHECK(dprops->major == 9, \"Only Hopper supports different V headdim\");\n        if (head_size_v > 256) {\n            STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16,\n                        \"HeaddimV > 256 requires fp16 and bf16 data type\");\n        }\n    }\n\n    // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM\n    // TODO: check this\n    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }\n    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }\n    // causal=true is the same as causal=false in this case\n    if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {\n        // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA\n        if 
((head_size <= 64 || head_size > 128) || !paged_KV) {\n            is_causal = false;\n        }\n    }\n    if (is_causal) { window_size_right = 0; }\n\n    if (!is_varlen_q) {\n        CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);\n    } else {\n        CHECK_SHAPE(q, total_q, num_heads, head_size);\n        CHECK_SHAPE(cu_seqlens_q, batch_size + 1);\n    }\n    if (!paged_KV) {\n        if (!is_varlen_k) {\n            CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);\n            CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v);\n        } else {\n            CHECK_SHAPE(k, total_k, num_heads_k, head_size);\n            CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);\n            CHECK_SHAPE(cu_seqlens_k, batch_size + 1);\n        }\n    } else {\n        CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);\n        CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v);\n        CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);\n    }\n\n    if (seqused_q_.has_value()){\n        auto seqused_q = seqused_q_.value();\n        STD_TORCH_CHECK(seqused_q.scalar_type() == torch::headeronly::ScalarType::Int, \"seqused_q must have dtype int32\");\n        CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);\n        CHECK_SHAPE(seqused_q, batch_size);\n    }\n    if (seqused_k_.has_value()) {\n        auto seqused_k = seqused_k_.value();\n        STD_TORCH_CHECK(seqused_k.scalar_type() == torch::headeronly::ScalarType::Int, \"seqused_k must have dtype int32\");\n        CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);\n        CHECK_SHAPE(seqused_k, batch_size);\n    }\n\n    if (leftpad_k_.has_value()) {\n        auto leftpad_k = leftpad_k_.value();\n        STD_TORCH_CHECK(leftpad_k.scalar_type() == torch::headeronly::ScalarType::Int, \"leftpad_k must have dtype int32\");\n        CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);\n        CHECK_SHAPE(leftpad_k, batch_size);\n    
}\n\n    // This is what we will template on\n    bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();\n    #ifdef FLASHATTENTION_DISABLE_VARLEN\n        STD_TORCH_CHECK(!is_varlen, \"This flash attention build does not support varlen.\");\n    #endif\n\n    int const alignment = q_type == torch::headeronly::ScalarType::Float8_e4m3fn ? 16 : 8;\n    STD_TORCH_CHECK(head_size % alignment == 0, \"head_size should be a multiple of \" + std::to_string(alignment));\n    STD_TORCH_CHECK(head_size_v % alignment == 0, \"head_size_v should be a multiple of \" + std::to_string(alignment));\n\n    auto out_type = q_type == torch::headeronly::ScalarType::Float8_e4m3fn ? torch::headeronly::ScalarType::BFloat16 : q_type;\n    Tensor out;\n    if (out_.has_value()) {\n        out = out_.value();\n        STD_TORCH_CHECK(out.scalar_type() == out_type, \"For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16\");\n        CHECK_DEVICE(out);\n        STD_TORCH_CHECK(out.stride(-1) == 1, \"Output tensor must have contiguous last dimension\");\n        if (!is_varlen_q) {\n            CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);\n        } else {\n            CHECK_SHAPE(out, total_q, num_heads, head_size_v);\n        }\n    } else {\n        out = !is_varlen_q\n            ? torch::stable::new_empty(q, {batch_size, seqlen_q, num_heads, head_size_v}, std::make_optional(out_type))\n            : torch::stable::new_empty(q, {total_q, num_heads, head_size_v}, std::make_optional(out_type));\n    }\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    int const head_size_rounded = round_up_headdim(head_size);\n    int const head_size_v_rounded = head_size_v == head_size ? 
head_size_rounded : round_up_headdimv(head_size_v);\n    int const seqlen_q_rounded = round_multiple(seqlen_q, 128);\n    int const seqlen_k_rounded = round_multiple(seqlen_k, 128);\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    // Cast to char to avoid compiler warning about narrowing\n    auto device_guard = make_device_guard(q);\n\n    Tensor softmax_lse;\n    if (!is_varlen_q) {\n        softmax_lse = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q}, std::make_optional(torch::headeronly::ScalarType::Float));\n    } else {\n        softmax_lse = torch::stable::new_empty(q, {num_heads, total_q}, std::make_optional(torch::headeronly::ScalarType::Float));\n    }\n\n    Flash_fwd_params params;\n    set_params_fprop(params,\n                     batch_size,\n                     seqlen_q, seqlen_k,\n                     seqlen_q_rounded, seqlen_k_rounded,\n                     num_heads, num_heads_k,\n                     head_size, head_size_rounded,\n                     q, k, v, out,\n                     !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),\n                     !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),\n                     seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,\n                     seqused_k_.has_value() ? 
seqused_k_.value().data_ptr() : nullptr,\n                     softmax_lse.data_ptr(),\n                     /*p_dropout=*/0.f,\n                     softmax_scale,\n                     window_size_left,\n                     window_size_right,\n                     attention_chunk,\n                     softcap,\n                     sm_margin);\n    params.total_q = total_q;\n    params.total_k = total_k;\n    params.b_k = batch_size_k;\n    params.dv = head_size_v;\n    params.dv_rounded = head_size_v_rounded;\n    if (leftpad_k_.has_value()) {  // This needs to be set before get_pagedkv_tma\n        params.leftpad_k = static_cast<int *>(leftpad_k_.value().data_ptr());\n    }\n    if (paged_KV) {\n        params.page_table = static_cast<int*>(page_table.data_ptr());\n        params.page_table_batch_stride = page_table.stride(0);\n    }\n    params.page_size = page_size;\n    params.num_pages = num_pages;\n\n    if (k_new_.has_value()) {  // This needs to be set before get_pagedkv_tma\n        Tensor k_new, v_new;\n        STD_TORCH_CHECK(v_new_.has_value(), \"If k_new is supplied, v_new must also be passed in\");\n        STD_TORCH_CHECK(seqused_k_.has_value(), \"If k_new is supplied, seqlens_k must also be passed in\");\n        STD_TORCH_CHECK(seqlen_q <= seqlen_k, \"If k_new is supplied, it must have seqlen <= the seqlen of the KV cache\");\n        Tensor cu_seqlens_k_new;\n        bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();\n        if (is_varlen_k_new) {\n            cu_seqlens_k_new = cu_seqlens_k_new_.value();\n            CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);\n            STD_TORCH_CHECK(cu_seqlens_k_new.scalar_type() == torch::headeronly::ScalarType::Int, \"cu_seqlens_k_new must have dtype torch.int32\");\n        }\n        k_new = k_new_.value();\n        v_new = v_new_.value();\n        STD_TORCH_CHECK(k_new.scalar_type() == q_type, \"k_new must have the same dtype as query\");\n        
STD_TORCH_CHECK(v_new.scalar_type() == q_type, \"v_new must have the same dtype as query\");\n        CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);\n        STD_TORCH_CHECK(k_new.stride(-1) == 1, \"k_new tensor must have contiguous last dimension\");\n        STD_TORCH_CHECK(v_new.stride(-1) == 1, \"v_new tensor must have contiguous last dimension\");\n        // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new\n        int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;\n        int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0);\n        if (!is_varlen_k_new) {\n            CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);\n            CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v);\n        } else {\n            CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);\n            CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v);\n            CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);\n        }\n        params.seqlen_knew = seqlen_k_new;\n        params.total_knew = total_k_new;\n        params.knew_ptr = k_new.data_ptr();\n        params.vnew_ptr = v_new.data_ptr();\n        // All stride are in elements, not bytes.\n        params.knew_row_stride = k_new.stride(-3);\n        params.vnew_row_stride = v_new.stride(-3);\n        params.knew_head_stride = k_new.stride(-2);\n        params.vnew_head_stride = v_new.stride(-2);\n        if (!is_varlen_k_new) {\n            params.knew_batch_stride = k_new.stride(0);\n            params.vnew_batch_stride = v_new.stride(0);\n        }\n        if (is_varlen_k_new) {\n            params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());\n        }\n    }\n    \n    bool const use_prepare_varlen = is_varlen;\n    params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA;\n    // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks 
it\n    params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast<int*>(1);\n\n    params.pagedkv_tma = get_pagedkv_tma(params);\n    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;\n    // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide\n    params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);\n\n    // This needs to be set after get_num_splits\n    Tensor tile_count_semaphore;  // Contains the semaphore and optionally num_splits_dynamic\n    // We don't use the persistent scheduler if Split and not Varlen\n    bool const scheduler_needs_semaphore = params.arch >= 90\n        ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)\n        : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));\n    params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template\n    params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template\n    if (scheduler_needs_semaphore || use_prepare_varlen) {\n        int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers\n        int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0;\n        if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; }\n        if(params.head_swizzle) { num_prepare_batch_vectors += 1; }\n        int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 
3 : 2);\n        int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors;\n        int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset;\n        // printf(\"Num prepare batch vectors = %d, metadata_size = %d.\\n\", num_prepare_batch_vectors, metadata_size);\n        params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();\n        if (scheduler_metadata_.has_value()) {\n            Tensor scheduler_metadata = scheduler_metadata_.value();\n            CHECK_DEVICE(scheduler_metadata);\n            CHECK_SHAPE(scheduler_metadata, metadata_size);\n            CHECK_CONTIGUOUS(scheduler_metadata);\n            STD_TORCH_CHECK(scheduler_metadata.scalar_type() == torch::headeronly::ScalarType::Int, \"scheduler_metadata must have dtype int32\");\n            tile_count_semaphore = scheduler_metadata;\n        } else {\n            tile_count_semaphore = torch::stable::new_empty(q, {metadata_size}, torch::headeronly::ScalarType::Int);\n        }\n        if (scheduler_needs_semaphore && !use_prepare_varlen) {\n            torch::stable::zero_(tile_count_semaphore);  // If varlen we'll manually do the zero-ing\n        }\n        // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2}\n        params.num_splits_dynamic_ptr = use_prepare_varlen ? static_cast<int*>(tile_count_semaphore.data_ptr()) : nullptr;\n        params.num_m_blocks_ptr =  use_prepare_varlen ? static_cast<int*>(tile_count_semaphore.data_ptr()) + b_rounded : nullptr;\n        params.varlen_batch_idx_ptr =  use_prepare_varlen && params.varlen_sort_batches ? static_cast<int*>(tile_count_semaphore.data_ptr()) + b_rounded * 2 : nullptr;\n        // params.num_n_blocks_ptr  = use_prepare_varlen && params.head_swizzle ? static_cast<int*>(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr;\n        params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? 
static_cast<int*>(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr;\n        params.tile_count_semaphore = scheduler_needs_semaphore ? static_cast<int*>(tile_count_semaphore.data_ptr()) + tile_count_semaphore_offset : nullptr;\n        params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later\n    }\n\n    if (q_v_.has_value()) {\n        STD_TORCH_CHECK(head_size <= 64, \"q_v is only supported for head_size <= 64\");\n        STD_TORCH_CHECK(head_size_v >= 256, \"q_v is only supported for hdim_v >= 256.\");\n        STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16,\n                    \"q_v is only supported for fp16 and bf16 data type\");\n        STD_TORCH_CHECK(params.arch == 90, \"q_v is only supported for Hopper GPUs\");\n        Tensor q_v = q_v_.value();\n        STD_TORCH_CHECK(q_v.scalar_type() == q_type, \"q_v must have the same dtype as query\");\n        CHECK_DEVICE(q_v);\n        STD_TORCH_CHECK(q_v.stride(-1) == 1, \"q_v tensor must have contiguous last dimension\");\n        if (!is_varlen_q) {\n            CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);\n        } else {\n            CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);\n        }\n        params.qv_ptr = q_v.data_ptr();\n        // All stride are in elements, not bytes.\n        params.qv_row_stride = q_v.stride(-3);\n        params.qv_head_stride = q_v.stride(-2);\n        if (!is_varlen_q) {\n            params.qv_batch_stride = q_v.stride(0);\n        }\n    }\n\n    if (rotary_cos_.has_value()) {\n        STD_TORCH_CHECK(k_new_.has_value(), \"If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided\");\n        auto rotary_cos = rotary_cos_.value();\n        CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);\n        params.rotary_dim = rotary_cos.size(1) * 2;\n        
STD_TORCH_CHECK(params.rotary_dim <= head_size, \"rotary_dim must be <= headdim\");\n        STD_TORCH_CHECK(params.rotary_dim % 16 == 0, \"Only rotary dimensions divisible by 16 are currently supported\");\n        const int seqlen_ro = rotary_cos.size(0);\n        if (paged_KV) {\n            STD_TORCH_CHECK(seqlen_ro >= seqlen_k, \"cos/sin seqlen must be at least the seqlen of KV cache\");\n        }\n        CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);\n        STD_TORCH_CHECK(rotary_cos.scalar_type() == q_type, \"rotary_cos must have the same dtype as query\");\n\n        STD_TORCH_CHECK(rotary_sin_.has_value(), \"If rotary cos is provided, rotary sin must also be provided\");\n        auto rotary_sin = rotary_sin_.value();\n        CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);\n        CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);\n        STD_TORCH_CHECK(rotary_sin.scalar_type() == q_type, \"rotary_sin must have the same dtype as query\");\n        params.rotary_cos_ptr = rotary_cos.data_ptr();\n        params.rotary_sin_ptr = rotary_sin.data_ptr();\n        params.is_rotary_interleaved = is_rotary_interleaved;\n        if (seqlens_rotary_.has_value()) {\n            Tensor seqlens_rotary = seqlens_rotary_.value();\n            CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary);\n            STD_TORCH_CHECK(seqlens_rotary.scalar_type() == torch::headeronly::ScalarType::Int, \"seqlens_rotary must have dtype torch.int32\");\n            CHECK_SHAPE(seqlens_rotary, batch_size);\n            params.seqlens_rotary = static_cast<int*>(seqlens_rotary.data_ptr());\n        }\n    } else {\n        params.rotary_dim = 0;\n    }\n\n    if (kv_batch_idx_.has_value()) {\n        auto kv_batch_idx = kv_batch_idx_.value();\n        CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx);\n        STD_TORCH_CHECK(kv_batch_idx.scalar_type() == torch::headeronly::ScalarType::Int, \"kv_batch_idx must have dtype int32\");\n   
     params.kv_batch_idx = reinterpret_cast<int *>(kv_batch_idx.data_ptr());\n    }\n\n    Tensor out_accum, softmax_lse_accum;\n    auto outaccum_type = torch::headeronly::ScalarType::Float;\n    if (params.num_splits > 1) {\n        STD_TORCH_CHECK(params.num_splits <= 256, \"num_splits > 256 not supported\");\n        if (!is_varlen_q) {\n            out_accum = torch::stable::new_empty(q, {params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, std::make_optional(outaccum_type));\n            softmax_lse_accum = torch::stable::new_empty(q, {params.num_splits, batch_size, num_heads, seqlen_q}, std::make_optional(torch::headeronly::ScalarType::Float));\n            params.oaccum_batch_stride = out_accum.stride(1);\n            params.lseaccum_batch_stride = softmax_lse_accum.stride(1);\n        } else {\n            out_accum = torch::stable::new_empty(q, {params.num_splits, num_heads, total_q, head_size_v}, std::make_optional(outaccum_type));\n            softmax_lse_accum = torch::stable::new_empty(q, {params.num_splits, num_heads, total_q}, std::make_optional(torch::headeronly::ScalarType::Float));\n        }\n        params.is_fp32 = false;\n        params.oaccum_ptr = out_accum.data_ptr();\n        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();\n        params.oaccum_split_stride = out_accum.stride(0);\n        params.oaccum_row_stride = out_accum.stride(-2);\n        params.oaccum_head_stride = out_accum.stride(-3);\n        params.lseaccum_split_stride = softmax_lse_accum.stride(0);\n        params.lseaccum_head_stride = softmax_lse_accum.stride(-2);\n    }\n\n    if (q_type == torch::headeronly::ScalarType::Float8_e4m3fn) {\n        if (q_descale_.has_value()) {\n            auto q_descale = q_descale_.value();\n            CHECK_DEVICE(q_descale);\n            CHECK_SHAPE(q_descale, batch_size, num_heads_k);\n            params.q_descale_ptr = static_cast<float*>(q_descale.data_ptr());\n            params.q_descale_batch_stride = 
q_descale.stride(0);\n            params.q_descale_head_stride = q_descale.stride(1);\n        } else {\n            params.q_descale_ptr = nullptr;\n        }\n        if (k_descale_.has_value()) {\n            auto k_descale = k_descale_.value();\n            CHECK_DEVICE(k_descale);\n            CHECK_SHAPE(k_descale, batch_size, num_heads_k);\n            params.k_descale_ptr = static_cast<float*>(k_descale.data_ptr());\n            params.k_descale_batch_stride = k_descale.stride(0);\n            params.k_descale_head_stride = k_descale.stride(1);\n        } else {\n            params.k_descale_ptr = nullptr;\n        }\n        if (v_descale_.has_value()) {\n            auto v_descale = v_descale_.value();\n            CHECK_DEVICE(v_descale);\n            CHECK_SHAPE(v_descale, batch_size, num_heads_k);\n            params.v_descale_ptr = static_cast<float*>(v_descale.data_ptr());\n            params.v_descale_batch_stride = v_descale.stride(0);\n            params.v_descale_head_stride = v_descale.stride(1);\n        } else {\n            params.v_descale_ptr = nullptr;\n        }\n    }\n\n    #ifdef FLASHATTENTION_DISABLE_LOCAL\n    STD_TORCH_CHECK(!params.is_local, \"This flash attention build does not support local attention.\");\n    #endif\n    #ifdef FLASHATTENTION_DISABLE_SOFTCAP\n    STD_TORCH_CHECK(params.softcap == 0.0, \"This flash attention build does not support tanh softcapping.\");\n    #endif\n    #ifdef FLASHATTENTION_DISABLE_SPLIT\n    STD_TORCH_CHECK(params.num_splits == 1, \"This flash attention build does not support splits.\");\n    #endif\n    #ifdef FLASHATTENTION_DISABLE_PACKGQA\n    STD_TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, \"This flash attention build does not support pack_gqa.\");\n    #endif\n    #ifdef FLASHATTENTION_DISABLE_PAGEDKV\n    STD_TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), \"This flash attention build does not support 
paged KV.\");\n    #endif\n    #ifdef FLASHATTENTION_DISABLE_APPENDKV\n    STD_TORCH_CHECK(!k_new_.has_value(), \"This flash attention build does not support appending KV.\");\n    #endif\n\n    if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {\n        auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex();\n        void* stream_ptr = nullptr;\n        TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr));\n        cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);\n        run_mha_fwd(params, stream);\n        if (params.num_splits > 1) {\n            if (out_type == torch::headeronly::ScalarType::BFloat16) {\n                // Since we want output in BF16. Otherwise fwd_combine will output to FP16\n                params.is_bf16 = true;\n            }\n            // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1\n            // and seqlen = total_q, and don't need to dispatch to Varlen there.\n            // However, with dynamic split, each row needs to know which batch it belongs to\n            // to read the number of splits, so we just use the varlen version of combine kernel.\n            // if (is_varlen_q && !seqused_q_.has_value()) {\n            // if (is_varlen_q) {\n            //     params.b = 1;\n            //     params.seqlen_q = total_q;\n            // }\n            // This will zero out the semaphore if needed\n            run_mha_fwd_combine(params, stream, true /*enable_pdl*/);\n        } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) {\n            // need to zero out the semaphore in this case\n            auto slice = torch::stable::narrow(tile_count_semaphore, 0, params.tile_count_semaphore_offset, 1);\n            torch::stable::zero_(slice);\n        }\n    } else if (total_q > 0 && num_heads_k > 0) {\n        // If seqlen_k == 0, then we have an empty tensor. 
We need to set the output to 0.\n        torch::stable::zero_(out);\n        torch::stable::fill_(softmax_lse, std::numeric_limits<float>::infinity());\n    }\n\n    // return {out, softmax_lse};\n    return {out, softmax_lse, out_accum, softmax_lse_accum};\n}\n\n#ifdef FLASHATTENTION_DISABLE_BACKWARD\nvoid run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {\n    STD_TORCH_CHECK(false, \"Flash-Attention was built with backward disabled\");\n}\n#else\ntemplate <int Arch, bool Has_softcap>\nvoid run_mha_bwd_constexpr(Flash_bwd_params &params, cudaStream_t stream) {\n    if (!params.is_bf16) {\n        #ifndef FLASHATTENTION_DISABLE_FP16\n        #ifndef FLASHATTENTION_DISABLE_HDIM64\n        if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM96\n        if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM128\n        if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM192\n        if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM256\n        if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); }\n        #endif\n        #else\n        STD_TORCH_CHECK(false, \"This flash attention build does not support FP16.\");\n        #endif\n    } else {\n        #ifndef FLASHATTENTION_DISABLE_HDIM64\n        if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM96\n        if (params.d_rounded == 96) { return run_mha_bwd_<Arch, 
cutlass::bfloat16_t, 96, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM128\n        if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM192\n        if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); }\n        #endif\n        #ifndef FLASHATTENTION_DISABLE_HDIM256\n        if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); }\n        #endif\n    }\n}\n\nvoid run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {\n        // FP16_SWITCH(!params.is_bf16, [&] {\n        //     HEADDIM_SWITCH(params.d, [&] {\n        //         run_mha_bwd_<elem_type, kHeadDim>(params, stream);\n        //     });\n        // });\n    ARCH_SWITCH(params.arch, Arch, [&] {\n        SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {\n            run_mha_bwd_constexpr<Arch, Has_softcap>(params, stream);\n        });\n    });\n}\n#endif\n\n\n// b: batch_size\n// s_q: seqlen_q\n// s_k: seqlen_k\n// h: num_heads\n// h_k: num_heads_k\n// d: head_size\nstd::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> mha_bwd(\n    Tensor dout,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q\n    Tensor q,     // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q\n    Tensor k,     // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k\n    Tensor v,     // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k\n    Tensor out,   // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q\n    Tensor softmax_lse,    // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q\n    std::optional<Tensor> dq_,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q\n    std::optional<Tensor> dk_,   // (b, s_k, h_k, d) or (total_k, h_k, d) if there is 
cu_seqlens_k\n    std::optional<Tensor> dv_,   // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k\n    std::optional<Tensor> cu_seqlens_q_,   // b+1\n    std::optional<Tensor> cu_seqlens_k_,   // b+1\n    std::optional<Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.\n    std::optional<Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.\n    std::optional<int64_t> max_seqlen_q_,\n    std::optional<int64_t> max_seqlen_k_,\n    std::optional<double> softmax_scale_,\n    bool is_causal,\n    int64_t window_size_left,\n    int64_t window_size_right,\n    double softcap,\n    bool deterministic,\n    int64_t sm_margin\n) {\n\n    #ifdef FLASHATTENTION_DISABLE_BACKWARD\n        STD_TORCH_CHECK(false, \"This flash attention build does not support backward.\");\n    #endif\n\n    auto dprops = get_device_prop();\n    bool is_sm8x = dprops->major >= 8;\n    STD_TORCH_CHECK(is_sm8x, \"FlashAttention only supports Ampere GPUs or newer.\");\n\n    auto q_type = q.scalar_type();\n    STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16,\n                \"FlashAttention only support fp16 and bf16 data type\");\n    STD_TORCH_CHECK(k.scalar_type() == q_type, \"query and key must have the same dtype\");\n    STD_TORCH_CHECK(v.scalar_type() == q_type, \"query and value must have the same dtype\");\n    STD_TORCH_CHECK(out.scalar_type() == q_type, \"query and out must have the same dtype\");\n    STD_TORCH_CHECK(dout.scalar_type() == q_type, \"query and dout must have the same dtype\");\n\n    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);\n    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);\n\n    STD_TORCH_CHECK(q.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    STD_TORCH_CHECK(k.stride(-1) == 1, \"Input tensor must have contiguous last 
dimension\");\n    STD_TORCH_CHECK(v.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    STD_TORCH_CHECK(out.stride(-1) == 1, \"out tensor must have contiguous last dimension\");\n    STD_TORCH_CHECK(dout.stride(-1) == 1, \"dout tensor must have contiguous last dimension\");\n\n    Tensor cu_seqlens_q;\n    bool const is_varlen_q = cu_seqlens_q_.has_value();\n    if (is_varlen_q) {\n        cu_seqlens_q = cu_seqlens_q_.value();\n        CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);\n        STD_TORCH_CHECK(cu_seqlens_q.scalar_type() == torch::headeronly::ScalarType::Int, \"cu_seqlens_q must have dtype torch.int32\");\n        STD_TORCH_CHECK(max_seqlen_q_.has_value(), \"max_seqlen_q must be provided if cu_seqlens_q is provided\");\n    }\n    Tensor cu_seqlens_k;\n    bool const is_varlen_k = cu_seqlens_k_.has_value();\n    if (is_varlen_k) {\n        cu_seqlens_k = cu_seqlens_k_.value();\n        CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);\n        STD_TORCH_CHECK(cu_seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, \"cu_seqlens_k must have dtype torch.int32\");\n        STD_TORCH_CHECK(max_seqlen_k_.has_value(), \"max_seqlen_k must be provided if cu_seqlens_k is provided\");\n    }\n    // This is what we will template on\n    bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();\n    #ifdef FLASHATTENTION_DISABLE_VARLEN\n        STD_TORCH_CHECK(!is_varlen, \"This flash attention build does not support varlen.\");\n    #endif\n\n    // auto const sizes = q.sizes();\n    int const batch_size = !is_varlen_q ? q.size(0) : cu_seqlens_q.size(0) - 1;\n    int const seqlen_q = !is_varlen_q ? q.size(1) : max_seqlen_q_.value();\n    int const total_q = !is_varlen_q ? 
batch_size * q.size(1) : q.size(0);\n    int const num_heads = q.size(-2);\n    int const head_size = q.size(-1);\n    int const head_size_v = v.size(-1);\n    int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value();\n    int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);\n    int const num_heads_k = k.size(-2);\n    STD_TORCH_CHECK(head_size % 8 == 0, \"head_size should be a multiple of 8\");\n    STD_TORCH_CHECK(head_size_v % 8 == 0, \"head_size_v should be a multiple of 8\");\n    int const max_headdim = get_max_headdim();\n    STD_TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, \"FlashAttention backward only supports head dimension at most \" + std::to_string(max_headdim));\n    STD_TORCH_CHECK(num_heads % num_heads_k == 0, \"Number of heads in key/value must divide number of heads in query\");\n    double softmax_scale = 1.0 / sqrt(double(head_size));\n    if (softmax_scale_.has_value()) {\n        softmax_scale = softmax_scale_.value();\n    }\n\n    // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM\n    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }\n    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }\n    if (is_causal) { window_size_right = 0; }\n    // There's a case where is_causal=false, window_size=(-1, 0). 
Then set_params_bprop will set params.is_causal=true.\n    // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).\n    is_causal = window_size_left < 0 && window_size_right == 0;\n\n    int const arch = dprops->major * 10 + dprops->minor;\n    int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v));\n    int const head_size_v_rounded = head_size_rounded;\n    STD_TORCH_CHECK(!deterministic || head_size_rounded < 256, \"Deterministic backward not supported for hdim 256.\");\n    // Very important that these match the kernel configs\n    bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;\n    int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)\n        : (head_size_rounded <= 96 ? 64\n           : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)\n              : 64));\n    int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;\n    int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;\n    int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);\n    int const kBlockN_sm90 = head_size_rounded <= 128\n        ? 128\n        : (head_size_rounded <= 192 ? 96 : 80);\n    int const kBlockN_sm80 = head_size_rounded <= 128\n        ? 128\n        : (head_size_rounded <= 192 ? 80 : 64);\n    int const kBlockN_sm86 = head_size_rounded <= 64 ? 128\n        : (head_size_rounded <= 96 ? 128\n           : (head_size_rounded <= 128 ? 96\n              : (head_size_rounded <= 192 ? 64 : 64)));\n    int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? 
kBlockN_sm86 : kBlockN_sm80);\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);\n    int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);\n    int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);\n    int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);\n\n    if (!is_varlen_q) {\n        CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);\n        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);\n        CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v);\n    } else {\n        CHECK_SHAPE(q, total_q, num_heads, head_size);\n        CHECK_SHAPE(out, total_q, num_heads, head_size_v);\n        CHECK_SHAPE(dout, total_q, num_heads, head_size_v);\n        CHECK_SHAPE(cu_seqlens_q, batch_size + 1);\n    }\n    if (!is_varlen_k) {\n        CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);\n        CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v);\n    } else {\n        CHECK_SHAPE(k, total_k, num_heads_k, head_size);\n        CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);\n        CHECK_SHAPE(cu_seqlens_k, batch_size + 1);\n    }\n\n    if (seqused_q_.has_value()){\n        auto seqused_q = seqused_q_.value();\n        STD_TORCH_CHECK(seqused_q.scalar_type() == torch::headeronly::ScalarType::Int, \"seqused_q must have dtype int32\");\n        CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);\n        CHECK_SHAPE(seqused_q, batch_size);\n    }\n    if (seqused_k_.has_value()){\n        auto seqused_k = seqused_k_.value();\n        STD_TORCH_CHECK(seqused_k.scalar_type() == torch::headeronly::ScalarType::Int, \"seqused_k must have dtype int32\");\n        CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);\n        CHECK_SHAPE(seqused_k, batch_size);\n    }\n\n    Tensor dq, dk, dv;\n    if (dq_.has_value()) {\n        
dq = dq_.value();\n        STD_TORCH_CHECK(dq.scalar_type() == q_type, \"dq must have the same dtype as q\");\n        CHECK_DEVICE(dq);\n        STD_TORCH_CHECK(dq.stride(-1) == 1, \"dq must have contiguous last dimension\");\n        if (!is_varlen_q) {\n            CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);\n        } else {\n            CHECK_SHAPE(dq, total_q, num_heads, head_size);\n        }\n    } else {\n        dq = torch::stable::empty_like(q);\n    }\n    if (dk_.has_value()) {\n        dk = dk_.value();\n        STD_TORCH_CHECK(dk.scalar_type() == q_type, \"dk must have the same dtype as q\");\n        CHECK_DEVICE(dk);\n        STD_TORCH_CHECK(dk.stride(-1) == 1, \"dk must have contiguous last dimension\");\n        if (!is_varlen_k) {\n            CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);\n        } else {\n            CHECK_SHAPE(dk, total_k, num_heads_k, head_size);\n        }\n    } else {\n        dk = torch::stable::empty_like(k);\n    }\n    if (dv_.has_value()) {\n        dv = dv_.value();\n        STD_TORCH_CHECK(dv.scalar_type() == q_type, \"dv must have the same dtype as q\");\n        CHECK_DEVICE(dv);\n        STD_TORCH_CHECK(dv.stride(-1) == 1, \"dv must have contiguous last dimension\");\n        if (!is_varlen_k) {\n            CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v);\n        } else {\n            CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v);\n        }\n    } else {\n        dv = torch::stable::empty_like(v);\n    }\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    // Cast to char to avoid compiler warning about narrowing\n    auto device_guard = make_device_guard(q);\n\n    // auto opts = q.options();\n    // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64\n    Tensor softmax_d, softmax_lse_log2;\n    if (!is_varlen) {\n        // Need softmax_d to have seqlen_q_rounded since 
we want its address to be aligned by 16/8 bytes for TMA / LDG.64\n        softmax_d = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded}, std::make_optional(torch::headeronly::ScalarType::Float));\n        softmax_lse_log2 = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded}, std::make_optional(torch::headeronly::ScalarType::Float));\n    } else {\n        softmax_d = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded}, std::make_optional(torch::headeronly::ScalarType::Float));\n        softmax_lse_log2 = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded}, std::make_optional(torch::headeronly::ScalarType::Float));\n    }\n    Tensor dq_accum, dk_accum, dv_accum;\n    if (!is_varlen) {\n        dq_accum = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float));\n    } else {\n        dq_accum = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float));\n    }\n    if (num_heads_k != num_heads) {  // MQA / GQA\n        if (!is_varlen) {\n            dk_accum = torch::stable::new_empty(q, {batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float));\n            dk_accum = torch::stable::fill_(dk_accum, 0.0);\n            dv_accum = torch::stable::new_empty(q, {batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, std::make_optional(torch::headeronly::ScalarType::Float));\n            dv_accum = torch::stable::fill_(dv_accum, 0.0);\n        } else {\n            dk_accum = torch::stable::new_empty(q, {num_heads_k, total_k_padded_rounded, head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float));\n            dk_accum = torch::stable::fill_(dk_accum, 0.0);\n            dv_accum = torch::stable::new_empty(q, {num_heads_k, 
total_k_padded_rounded, head_size_v_rounded}, std::make_optional(torch::headeronly::ScalarType::Float));\n            dv_accum = torch::stable::fill_(dv_accum, 0.0);\n        }\n    }\n\n    Flash_bwd_params params;\n    set_params_dgrad(params,\n                     batch_size,\n                     seqlen_q, seqlen_k,\n                     seqlen_q_rounded, seqlen_k_rounded,\n                     num_heads, num_heads_k,\n                     head_size, head_size_rounded,\n                     q, k, v, out,\n                     dout, dq, dk, dv,\n                     !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),\n                     !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),\n                     seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,\n                     seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,\n                     dq_accum.data_ptr(),\n                     num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,\n                     num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,\n                     softmax_lse.data_ptr(),\n                     softmax_d.data_ptr(),\n                     /*p_dropout=*/0.f,\n                     softmax_scale,\n                     window_size_left,\n                     window_size_right,\n                     0,  // attention_chunk\n                     softcap,\n                     deterministic,\n                     sm_margin);\n    params.total_q = total_q;\n    params.total_k = total_k;\n    params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();\n    params.dv = head_size_v;\n    params.dv_rounded = head_size_v_rounded;\n\n    // auto tile_count_semaphore = (params.is_causal || params.is_local) ? 
torch::zeros({1}, opts.dtype(torch::headeronly::ScalarType::Int)) : torch::empty({1}, opts.dtype(torch::headeronly::ScalarType::Int));\n    // params.tile_count_semaphore = static_cast<int*>(tile_count_semaphore.data_ptr());\n    // Will be zero'ed out in the backward preprocess kernel\n    Tensor dq_semaphore = torch::stable::new_empty(q, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, std::make_optional(torch::headeronly::ScalarType::Int));\n    params.dq_semaphore = static_cast<int*>(dq_semaphore.data_ptr());\n    Tensor dk_semaphore, dv_semaphore;\n    if (num_heads_k != num_heads && params.deterministic) {\n        // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel\n        dk_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int));\n        dv_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int));\n        params.dk_semaphore = static_cast<int*>(dk_semaphore.data_ptr());\n        params.dv_semaphore = static_cast<int*>(dv_semaphore.data_ptr());\n    }\n\n    #ifdef FLASHATTENTION_DISABLE_LOCAL\n    STD_TORCH_CHECK(!params.is_local, \"This flash attention build does not support local attention.\");\n    #endif\n    #ifdef FLASHATTENTION_DISABLE_SOFTCAP\n    STD_TORCH_CHECK(params.softcap == 0.0, \"This flash attention build does not support tanh softcapping.\");\n    #endif\n\n    if (total_q > 0 && total_k > 0 && num_heads_k > 0) {\n        auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex();\n        void* stream_ptr = nullptr;\n        TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr));\n        cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);\n        run_mha_bwd(params, stream);\n    } else if (total_k > 0 && num_heads_k > 0) {\n    
    // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.\n        torch::stable::zero_(dk);\n        torch::stable::zero_(dv);\n        torch::stable::zero_(softmax_d);\n    } else if (total_q > 0 && num_heads_k > 0) {\n        torch::stable::zero_(dq);\n        torch::stable::zero_(softmax_d);\n    }\n\n    return { softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };\n}\n\nstd::tuple<Tensor, Tensor>\nmha_combine(Tensor out_partial,         // num_splits x batch_size x seqlen x num_heads x head_size\n            Tensor lse_partial,         // num_splits x batch_size x seqlen x num_heads\n            std::optional<Tensor> out_,        // batch_size x seqlen x num_heads x head_size\n            std::optional<torch::headeronly::ScalarType> out_dtype_\n            ) {\n\n    auto dprops = get_device_prop();\n    bool is_sm8x = dprops->major >= 8;\n    STD_TORCH_CHECK(is_sm8x, \"Attention combine function only supports Ampere GPUs or newer.\");\n\n    auto out_partial_type = out_partial.scalar_type();\n    STD_TORCH_CHECK(out_partial_type == torch::headeronly::ScalarType::Float, \"Attention combine function only support fp32 data type\");\n    STD_TORCH_CHECK(lse_partial.scalar_type() == torch::headeronly::ScalarType::Float, \"Attention combine function only support fp32 data type\");\n\n    CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);\n\n    STD_TORCH_CHECK(out_partial.stride(-1) == 1, \"Input tensor must have contiguous last dimension\");\n    STD_TORCH_CHECK(lse_partial.stride(-2) == 1, \"LSE tensor must be contiguous in the seqlen dimension\");\n\n    // const auto sizes = out_partial.sizes();\n\n    const int num_splits = out_partial.size(0);\n    const int batch_size = out_partial.size(1);\n    const int seqlen = out_partial.size(2);\n    const int num_heads = out_partial.size(3);\n    const int head_size_og = out_partial.size(4);\n    STD_TORCH_CHECK(num_splits <= 256, \"FlashAttention combine only supports 
num_splits at most 256\");\n\n    CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);\n    CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);\n\n    int const alignment = 4;\n    Tensor out_partial_padded;\n    auto pad = [](Tensor x, int alignment) {\n        return x.size(-1) % alignment == 0 ? x : torch::stable::pad(x, {0, alignment - x.size(-1) % alignment});\n    };\n    out_partial_padded = pad(out_partial, alignment);\n\n    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };\n    const int head_size = round_multiple(head_size_og, alignment);\n\n    // auto opts = out_partial.options();\n    torch::headeronly::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());\n    STD_TORCH_CHECK(out_type == torch::headeronly::ScalarType::Float || out_type == torch::headeronly::ScalarType::BFloat16 || out_type == torch::headeronly::ScalarType::Half, \"Output type must be FP32, FP16 or BF16\");\n    Tensor out;\n    if (out_.has_value()) {\n        out = out_.value();\n        STD_TORCH_CHECK(out.scalar_type() == out_type);\n        CHECK_DEVICE(out);\n        STD_TORCH_CHECK(out.stride(-1) == 1, \"Output tensor must have contiguous last dimension\");\n        CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);\n        if (head_size_og % alignment != 0) {\n            out = torch::stable::new_empty(out_partial, {batch_size, seqlen, num_heads, head_size}, std::make_optional(out_type));\n        }\n    } else {\n        out = torch::stable::new_empty(out_partial, {batch_size, seqlen, num_heads, head_size}, std::make_optional(out_type));\n    }\n\n    // Otherwise the kernel will be launched from cuda:0 device\n    // Cast to char to avoid compiler warning about narrowing\n    auto device_guard = make_device_guard(out_partial);\n\n    auto softmax_lse = torch::stable::new_empty(out_partial, {batch_size, num_heads, seqlen}, 
std::make_optional(torch::headeronly::ScalarType::Float));\n    softmax_lse = torch::stable::transpose(softmax_lse, 1, 2);\n\n    Flash_fwd_params params {};  // Need to reset the params to set everything to zero\n    params.is_fp32 = out_type == torch::headeronly::ScalarType::Float;\n    params.is_bf16 = out_type == torch::headeronly::ScalarType::BFloat16;\n    params.oaccum_ptr = out_partial_padded.data_ptr();\n    params.softmax_lseaccum_ptr = lse_partial.data_ptr();\n    params.o_ptr = out.data_ptr();\n    params.softmax_lse_ptr = softmax_lse.data_ptr();\n    params.b = batch_size;\n    params.h = num_heads;\n    params.seqlen_q = seqlen;\n    params.dv = head_size;\n    params.num_splits = num_splits;\n    params.oaccum_split_stride = out_partial_padded.stride(0);\n    params.oaccum_row_stride = out_partial_padded.stride(2);\n    params.oaccum_head_stride = out_partial_padded.stride(3);\n    params.oaccum_batch_stride = out_partial_padded.stride(1);\n    params.lseaccum_split_stride = lse_partial.stride(0);\n    params.lseaccum_head_stride = lse_partial.stride(3);\n    params.lseaccum_batch_stride = lse_partial.stride(1);\n    params.o_row_stride = out.stride(1);\n    params.o_head_stride = out.stride(2);\n    params.o_batch_stride = out.stride(0);\n    params.arch = dprops->major * 10 + dprops->minor;\n\n    if (seqlen > 0 && batch_size > 0) {\n        auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex();\n        void* stream_ptr = nullptr;\n        TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr));\n        cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);\n        run_mha_fwd_combine(params, stream, false /*enable_pdl*/);\n    }\n\n    Tensor out_padded = out;\n    if (head_size_og % alignment != 0) {\n        out = torch::stable::narrow(out, -1, 0, head_size_og);\n        // if (out_.has_value()) { out_.value().copy_(out); }\n    }\n\n    return {out, softmax_lse};\n}\n\nvoid boxed_mha_fwd(\n   
 StableIValue* stack,\n    uint64_t num_args,\n    uint64_t num_outputs\n) {\n    auto q = to<Tensor>(stack[0]);\n    auto k = to<Tensor>(stack[1]);\n    auto v = to<Tensor>(stack[2]);\n    auto k_new = to<std::optional<Tensor>>(stack[3]);\n    auto v_new = to<std::optional<Tensor>>(stack[4]);\n    auto q_v = to<std::optional<Tensor>>(stack[5]);\n    auto out = to<std::optional<Tensor>>(stack[6]);\n    auto cu_seqlens_q = to<std::optional<Tensor>>(stack[7]);\n    auto cu_seqlens_k = to<std::optional<Tensor>>(stack[8]);\n    auto cu_seqlens_k_new = to<std::optional<Tensor>>(stack[9]);\n    auto seqused_q = to<std::optional<Tensor>>(stack[10]);\n    auto seqused_k = to<std::optional<Tensor>>(stack[11]);\n    auto max_seqlen_q = to<std::optional<int64_t>>(stack[12]);\n    auto max_seqlen_k = to<std::optional<int64_t>>(stack[13]);\n    auto page_table = to<std::optional<Tensor>>(stack[14]);\n    auto kv_batch_idx = to<std::optional<Tensor>>(stack[15]);\n    auto leftpad_k = to<std::optional<Tensor>>(stack[16]);\n    auto rotary_cos = to<std::optional<Tensor>>(stack[17]);\n    auto rotary_sin = to<std::optional<Tensor>>(stack[18]);\n    auto seqlens_rotary = to<std::optional<Tensor>>(stack[19]);\n    auto q_descale = to<std::optional<Tensor>>(stack[20]);\n    auto k_descale = to<std::optional<Tensor>>(stack[21]);\n    auto v_descale = to<std::optional<Tensor>>(stack[22]);\n    auto softmax_scale = to<std::optional<double>>(stack[23]);\n    auto is_causal = to<bool>(stack[24]);\n    auto window_size_left = to<int64_t>(stack[25]);\n    auto window_size_right = to<int64_t>(stack[26]);\n    auto attention_chunk = to<int64_t>(stack[27]);\n    auto softcap = to<double>(stack[28]);\n    auto is_rotary_interleaved = to<bool>(stack[29]);\n    auto scheduler_metadata = to<std::optional<Tensor>>(stack[30]);\n    auto num_splits = to<int64_t>(stack[31]);\n    auto pack_gqa = to<std::optional<bool>>(stack[32]);\n    auto sm_margin = to<int64_t>(stack[33]);\n\n    auto [out_, 
softmax_lse, out_accum, softmax_lse_accum] = mha_fwd(q, k, v, k_new, v_new, q_v, out, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale, softmax_scale, is_causal, window_size_left, window_size_right, attention_chunk, softcap, is_rotary_interleaved, scheduler_metadata, num_splits, pack_gqa, sm_margin);\n\n\n    stack[0] = from(out_);\n    stack[1] = from(softmax_lse);\n    stack[2] = from(out_accum);\n    stack[3] = from(softmax_lse_accum);\n}\n\nvoid boxed_mha_bwd(\n    StableIValue* stack,\n    uint64_t num_args,\n    uint64_t num_outputs\n) {\n    auto dout = to<Tensor>(stack[0]);\n    auto q = to<Tensor>(stack[1]);\n    auto k = to<Tensor>(stack[2]);\n    auto v = to<Tensor>(stack[3]);\n    auto out = to<Tensor>(stack[4]);\n    auto softmax_lse = to<Tensor>(stack[5]);\n    auto dq = to<std::optional<Tensor>>(stack[6]);\n    auto dk = to<std::optional<Tensor>>(stack[7]);\n    auto dv = to<std::optional<Tensor>>(stack[8]);\n    auto cu_seqlens_q = to<std::optional<Tensor>>(stack[9]);\n    auto cu_seqlens_k = to<std::optional<Tensor>>(stack[10]);\n    auto seqused_q = to<std::optional<Tensor>>(stack[11]);\n    auto seqused_k = to<std::optional<Tensor>>(stack[12]);\n    auto max_seqlen_q = to<std::optional<int64_t>>(stack[13]);\n    auto max_seqlen_k = to<std::optional<int64_t>>(stack[14]);\n    auto softmax_scale = to<std::optional<double>>(stack[15]);\n    auto is_causal = to<bool>(stack[16]);\n    auto window_size_left = to<int64_t>(stack[17]);\n    auto window_size_right = to<int64_t>(stack[18]);\n    auto softcap = to<double>(stack[19]);\n    auto deterministic = to<bool>(stack[20]);\n    auto sm_margin = to<int64_t>(stack[21]);\n\n    auto [softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum] = mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, 
max_seqlen_q, max_seqlen_k, softmax_scale, is_causal, window_size_left, window_size_right, softcap, deterministic, sm_margin);\n\n    stack[0] = from(softmax_d);\n    stack[1] = from(softmax_lse_log2);\n    stack[2] = from(dq_accum);\n    stack[3] = from(dk_accum);\n    stack[4] = from(dv_accum);\n}\n\nvoid boxed_mha_combine(\n    StableIValue* stack,\n    uint64_t num_args,\n    uint64_t num_outputs\n) {\n    auto out_partial = to<Tensor>(stack[0]);\n    auto lse_partial = to<Tensor>(stack[1]);\n    auto out = to<std::optional<Tensor>>(stack[2]);\n    auto out_dtype = to<std::optional<torch::headeronly::ScalarType>>(stack[3]);\n\n    auto [out_, softmax_lse] = mha_combine(out_partial, lse_partial, out, out_dtype);\n\n    stack[0] = from(out_);\n    stack[1] = from(softmax_lse);\n}\n\nvoid boxed_mha_fwd_get_scheduler_metadata(\n    StableIValue* stack,\n    uint64_t num_args,\n    uint64_t num_outputs\n) {\n    auto batch_size = to<int64_t>(stack[0]);\n    auto max_seqlen_q = to<int64_t>(stack[1]);\n    auto max_seqlen_k = to<int64_t>(stack[2]);\n    auto num_heads = to<int64_t>(stack[3]);\n    auto num_heads_k = to<int64_t>(stack[4]);\n    auto headdim = to<int64_t>(stack[5]);\n    auto headdim_v = to<int64_t>(stack[6]);\n    auto qkv_dtype = to<torch::headeronly::ScalarType>(stack[7]);\n    auto seqused_k = to<Tensor>(stack[8]);\n    auto cu_seqlens_q = to<std::optional<Tensor>>(stack[9]);\n    auto cu_seqlens_k = to<std::optional<Tensor>>(stack[10]);\n    auto cu_seqlens_k_new = to<std::optional<Tensor>>(stack[11]);\n    auto seqused_q = to<std::optional<Tensor>>(stack[12]);\n    auto leftpad_k = to<std::optional<Tensor>>(stack[13]);\n    auto page_size = to<std::optional<int64_t>>(stack[14]);\n    auto max_seqlen_k_new = to<int64_t>(stack[15]);\n    auto is_causal = to<bool>(stack[16]);\n    auto window_size_left = to<int64_t>(stack[17]);\n    auto window_size_right = to<int64_t>(stack[18]);\n    auto attention_chunk = to<int64_t>(stack[19]);\n    auto 
has_softcap = to<bool>(stack[20]);\n    auto num_splits = to<int64_t>(stack[21]);\n    auto pack_gqa = to<std::optional<bool>>(stack[22]);\n    auto sm_margin = to<int64_t>(stack[23]);\n\n    auto scheduler_metadata = mha_fwd_get_scheduler_metadata(batch_size, max_seqlen_q, max_seqlen_k, num_heads, num_heads_k, headdim, headdim_v, qkv_dtype, seqused_k, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, seqused_q, leftpad_k, page_size, max_seqlen_k_new, is_causal, window_size_left, window_size_right, attention_chunk, has_softcap, num_splits, pack_gqa, sm_margin);\n\n    stack[0] = from(scheduler_metadata);\n}\n\nSTABLE_TORCH_LIBRARY(flash_attn_3, m) {\n    m.def(\"fwd(\"\n        \"Tensor q,\"\n        \"Tensor k,\"\n        \"Tensor v,\"\n        \"Tensor(k_new!)? k_new = None,\"\n        \"Tensor(v_new!)? v_new = None,\"\n        \"Tensor? q_v = None,\"\n        \"Tensor(out!)? out = None,\"\n        \"Tensor? cu_seqlens_q = None,\"\n        \"Tensor? cu_seqlens_k = None,\"\n        \"Tensor? cu_seqlens_k_new = None,\"\n        \"Tensor? seqused_q = None,\"\n        \"Tensor? seqused_k = None,\"\n        \"int? max_seqlen_q = None,\"\n        \"int? max_seqlen_k = None,\"\n        \"Tensor? page_table = None,\"\n        \"Tensor? kv_batch_idx = None,\"\n        \"Tensor? leftpad_k = None,\"\n        \"Tensor? rotary_cos = None,\"\n        \"Tensor? rotary_sin = None,\"\n        \"Tensor? seqlens_rotary = None,\"\n        \"Tensor? q_descale = None,\"\n        \"Tensor? k_descale = None,\"\n        \"Tensor? v_descale = None,\"\n        \"float? softmax_scale = None,\"\n        \"bool is_causal = False,\"\n        \"int window_size_left = -1,\"\n        \"int window_size_right = -1,\"\n        \"int attention_chunk = 0,\"\n        \"float softcap = 0.0,\"\n        \"bool is_rotary_interleaved = False,\"\n        \"Tensor? scheduler_metadata = None,\"\n        \"int num_splits = 0,\"\n        \"bool? 
pack_gqa = None,\"\n        \"int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)\");\n    m.def(\"bwd(\"\n        \"Tensor dout,\"\n        \"Tensor q,\"\n        \"Tensor k,\"\n        \"Tensor v,\"\n        \"Tensor out,\"\n        \"Tensor softmax_lse,\"\n        \"Tensor(dq!)? dq = None,\"\n        \"Tensor(dk!)? dk = None,\"\n        \"Tensor(dv!)? dv = None,\"\n        \"Tensor? cu_seqlens_q = None,\"\n        \"Tensor? cu_seqlens_k = None,\"\n        \"Tensor? seqused_q = None,\"\n        \"Tensor? seqused_k = None,\"\n        \"int? max_seqlen_q = None,\"\n        \"int? max_seqlen_k = None,\"\n        \"float? softmax_scale = None,\"\n        \"bool is_causal = False,\"\n        \"int window_size_left = -1,\"\n        \"int window_size_right = -1,\"\n        \"float softcap = 0.0,\"\n        \"bool deterministic = False,\"\n        \"int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)\");\n    m.def(\"fwd_combine(\"\n        \"Tensor out_partial,\"\n        \"Tensor lse_partial,\"\n        \"Tensor(out!)? out = None,\"\n        \"ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)\");\n    m.def(\"get_scheduler_metadata(\"\n        \"int batch_size,\"\n        \"int max_seqlen_q,\"\n        \"int max_seqlen_k,\"\n        \"int num_heads,\"\n        \"int num_heads_k,\"\n        \"int headdim,\"\n        \"int headdim_v,\"\n        \"ScalarType qkv_dtype,\"\n        \"Tensor seqused_k,\"\n        \"Tensor? cu_seqlens_q = None,\"\n        \"Tensor? cu_seqlens_k = None,\"\n        \"Tensor? cu_seqlens_k_new = None,\"\n        \"Tensor? seqused_q = None,\"\n        \"Tensor? leftpad_k = None,\"\n        \"int? page_size = None,\"\n        \"int max_seqlen_k_new = 0,\"\n        \"bool is_causal = False,\"\n        \"int window_size_left = -1,\"\n        \"int window_size_right = -1,\"\n        \"int attention_chunk = 0,\"\n        \"bool has_softcap = False,\"\n        \"int num_splits = 0,\"\n        \"bool? 
pack_gqa = None,\"\n        \"int sm_margin = 0) -> Tensor\");\n}\n\nSTABLE_TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) {\n    m.impl(\"fwd\", &boxed_mha_fwd);\n    m.impl(\"bwd\", &boxed_mha_bwd);\n    m.impl(\"fwd_combine\", &boxed_mha_combine);\n    m.impl(\"get_scheduler_metadata\", &boxed_mha_fwd_get_scheduler_metadata);\n}\n"
  },
  {
    "path": "hopper/flash_attn_interface.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nfrom typing import Optional, Union, List, Tuple\n\nimport os\nimport torch\nimport torch.nn as nn\nimport warnings\n\n\nUSE_TRITON_ROCM = os.getenv(\"FLASH_ATTENTION_TRITON_AMD_ENABLE\", \"FALSE\") == \"TRUE\"\nif not USE_TRITON_ROCM and getattr(torch.version, 'hip', None) is not None:\n    try:\n        import flash_attn_3._C\n    except ImportError:\n        warnings.warn(\"flash_attn_3._C (which has ROCm/HIP kernels) not found, falling back to Triton implementation\")\n        USE_TRITON_ROCM = True\n\nif USE_TRITON_ROCM:\n    from aiter.ops.triton._triton_kernels.flash_attn_triton_amd import flash_attn_3 as flash_attn_3_gpu\nelse:\n    # isort: off\n    # We need to import the CUDA kernels after importing torch\n    import flash_attn_3._C # Registers operators with PyTorch\n\n    # isort: on\n\n    flash_attn_3_gpu = torch.ops.flash_attn_3\n\ndef maybe_contiguous(x):\n    return x.contiguous() if x is not None and x.stride(-1) != 1 else x\n\n\ndef round_multiple(x, m):\n    return (x + m - 1) // m * m\n\n\ndef round_up_headdim(head_size: int) -> int:\n    from flash_attn_config import CONFIG\n\n    if not CONFIG[\"build_flags\"][\"FLASHATTENTION_DISABLE_HDIM64\"]:\n        if head_size <= 64:\n            return 64\n    if not CONFIG[\"build_flags\"][\"FLASHATTENTION_DISABLE_HDIM96\"]:\n        if head_size <= 96:\n            return 96\n    if not CONFIG[\"build_flags\"][\"FLASHATTENTION_DISABLE_HDIM128\"]:\n        if head_size <= 128:\n            return 128\n    if not CONFIG[\"build_flags\"][\"FLASHATTENTION_DISABLE_HDIM192\"]:\n        if head_size <= 192:\n            return 192\n    if not CONFIG[\"build_flags\"][\"FLASHATTENTION_DISABLE_HDIM256\"]:\n        if head_size <= 256:\n            return 256\n    return 256\n\n\n@torch.library.custom_op(\"flash_attn_3::_flash_attn_forward\", mutates_args=(), device_types=\"cuda\")\ndef _flash_attn_forward(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: 
torch.Tensor,\n    k_new: Optional[torch.Tensor] = None,\n    v_new: Optional[torch.Tensor] = None,\n    qv: Optional[torch.Tensor] = None,\n    out_: Optional[torch.Tensor] = None,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k: Optional[torch.Tensor] = None,\n    cu_seqlens_k_new: Optional[torch.Tensor] = None,\n    seqused_q: Optional[torch.Tensor] = None,\n    seqused_k: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    max_seqlen_k: Optional[int] = None,\n    page_table: Optional[torch.Tensor] = None,\n    kv_batch_idx: Optional[torch.Tensor] = None,\n    leftpad_k: Optional[torch.Tensor] = None,\n    rotary_cos: Optional[torch.Tensor] = None,\n    rotary_sin: Optional[torch.Tensor] = None,\n    seqlens_rotary: Optional[torch.Tensor] = None,\n    q_descale: Optional[torch.Tensor] = None,\n    k_descale: Optional[torch.Tensor] = None,\n    v_descale: Optional[torch.Tensor] = None,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    window_size_left: int = -1,\n    window_size_right: int = -1,\n    attention_chunk: int = 0,\n    softcap: float = 0.0,\n    rotary_interleaved: bool = True,\n    scheduler_metadata: Optional[torch.Tensor] = None,\n    num_splits: int = 1,\n    pack_gqa: Optional[bool] = None,\n    sm_margin: int = 0,\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]\n    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v\n    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [\n        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)\n    ]\n    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]\n    page_table, kv_batch_idx, leftpad_k = [\n        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)\n    ]\n    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]\n    
seqlens_rotary = maybe_contiguous(seqlens_rotary)\n    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_gpu.fwd(\n        q,\n        k,\n        v,\n        k_new,\n        v_new,\n        qv,\n        out_,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        cu_seqlens_k_new,\n        seqused_q,\n        seqused_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        page_table,\n        kv_batch_idx,\n        leftpad_k,\n        rotary_cos,\n        rotary_sin,\n        seqlens_rotary,\n        q_descale,\n        k_descale,\n        v_descale,\n        softmax_scale,\n        causal,\n        window_size_left,\n        window_size_right,\n        attention_chunk,\n        softcap,\n        rotary_interleaved,\n        scheduler_metadata,\n        num_splits,\n        pack_gqa,\n        sm_margin,\n    )\n\n    if out_accum is None:\n        out_accum = torch.tensor([], device=out.device)\n\n    if softmax_lse_accum is None:\n        softmax_lse_accum = torch.tensor([], device=out.device)\n\n    return out, softmax_lse, out_accum, softmax_lse_accum\n\n\n@torch.library.register_fake(\"flash_attn_3::_flash_attn_forward\")\ndef _flash_attn_forward_fake(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    k_new: Optional[torch.Tensor] = None,\n    v_new: Optional[torch.Tensor] = None,\n    qv: Optional[torch.Tensor] = None,\n    out_: Optional[torch.Tensor] = None,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k: Optional[torch.Tensor] = None,\n    cu_seqlens_k_new: Optional[torch.Tensor] = None,\n    seqused_q: Optional[torch.Tensor] = None,\n    seqused_k: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    max_seqlen_k: Optional[int] = None,\n    page_table: Optional[torch.Tensor] = None,\n    kv_batch_idx: Optional[torch.Tensor] = None,\n    leftpad_k: Optional[torch.Tensor] = None,\n    rotary_cos: Optional[torch.Tensor] = None,\n    rotary_sin: Optional[torch.Tensor] = None,\n    
seqlens_rotary: Optional[torch.Tensor] = None,\n    q_descale: Optional[torch.Tensor] = None,\n    k_descale: Optional[torch.Tensor] = None,\n    v_descale: Optional[torch.Tensor] = None,\n    softmax_scale: Optional[float] = None,\n    causal: bool = False,\n    window_size_left: int = -1,\n    window_size_right: int = -1,\n    attention_chunk: int = 0,\n    softcap: float = 0.0,\n    rotary_interleaved: bool = True,\n    scheduler_metadata: Optional[torch.Tensor] = None,\n    num_splits: int = 1,\n    pack_gqa: Optional[bool] = None,\n    sm_margin: int = 0,\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Symbolic fake implementation of flash attention forward.\n    Returns tensors with the correct shapes and dtypes without actual computation.\n    \"\"\"\n\n    # Determine if we're in varlen mode\n    is_varlen_q = cu_seqlens_q is not None\n\n    # Get dimensions from query tensor\n    if is_varlen_q:\n        # varlen mode: q is (total_q, num_heads, head_size)\n        total_q, num_heads, head_size = q.shape\n        batch_size = cu_seqlens_q.shape[0] - 1\n\n        if max_seqlen_q is None:\n            raise ValueError(\"max_seqlen_q must be provided if cu_seqlens_q is provided\")\n        seqlen_q = max_seqlen_q\n    else:\n        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)\n        batch_size, seqlen_q, num_heads, head_size = q.shape\n        total_q = batch_size * q.shape[1]\n    # Get value head dimension\n    head_size_v = v.shape[-1]\n\n    # Determine output dtype (FP8 inputs produce BF16 outputs)\n    q_type = q.dtype\n    if q_type == torch.float8_e4m3fn:\n        out_dtype = torch.bfloat16\n    else:\n        out_dtype = q_type\n\n    # Create output tensor\n    if out_ is not None:\n        # If out_ is provided, _flash_attn_forward becomes non-functional\n        raise TypeError(\"Tracing (torch.compile/torch.export) with pre-allocated output tensor is not supported.\")\n\n    if 
is_varlen_q:\n        out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)\n    else:\n        out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)\n\n    # Create softmax_lse tensor\n    if is_varlen_q:\n        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)\n    else:\n        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)\n\n    # TODO(guilhermeleobas): Implement \"get_num_splits\"\n    # There's a heuristic to compute num_splits when \"num_splits <= 0\"\n    # Require num_splits > 0 for now\n    if num_splits <= 0:\n        raise ValueError(f\"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}\")\n\n    if num_splits > 1:\n        if is_varlen_q:\n            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)\n            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)\n        else:\n            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)\n            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)\n    else:\n        # Accumulators are not set when num_splits == 1; return empty placeholders\n        out_accum = torch.tensor([], device=out.device)\n        softmax_lse_accum = torch.tensor([], device=out.device)\n\n    return out, softmax_lse, out_accum, softmax_lse_accum\n\n\n@torch.library.custom_op(\"flash_attn_3::_flash_attn_backward\", mutates_args=(\"dq\", \"dk\", \"dv\"), device_types=\"cuda\")\ndef _flash_attn_backward(\n    dout: torch.Tensor,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    out: torch.Tensor,\n    softmax_lse: torch.Tensor,\n    cu_seqlens_q: Optional[torch.Tensor] = 
None,\n    cu_seqlens_k: Optional[torch.Tensor] = None,\n    sequed_q: Optional[torch.Tensor] = None,\n    sequed_k: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    max_seqlen_k: Optional[int] = None,\n    dq: Optional[torch.Tensor] = None,\n    dk: Optional[torch.Tensor] = None,\n    dv: Optional[torch.Tensor] = None,\n    softmax_scale: Optional[float] = None,\n    is_causal: bool = False,\n    window_size_left: int = -1,\n    window_size_right: int = -1,\n    softcap: float = 0.0,\n    deterministic: bool = False,\n    sm_margin: int = 0,\n) -> torch.Tensor:\n    # dq, dk, dv are allocated by us so they should already be contiguous\n    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]\n    softmax_d, *rest = flash_attn_3_gpu.bwd(\n        dout,\n        q,\n        k,\n        v,\n        out,\n        softmax_lse,\n        dq,\n        dk,\n        dv,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        sequed_q,\n        sequed_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        softmax_scale,\n        is_causal,\n        window_size_left,\n        window_size_right,\n        softcap,\n        deterministic,\n        sm_margin,\n    )\n    return softmax_d\n\n\n@torch.library.register_fake(\"flash_attn_3::_flash_attn_backward\")\ndef _flash_attn_backward_fake(\n    dout: torch.Tensor,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    out: torch.Tensor,\n    softmax_lse: torch.Tensor,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k: Optional[torch.Tensor] = None,\n    sequed_q: Optional[torch.Tensor] = None,\n    sequed_k: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    max_seqlen_k: Optional[int] = None,\n    dq: Optional[torch.Tensor] = None,\n    dk: Optional[torch.Tensor] = None,\n    dv: Optional[torch.Tensor] = None,\n    softmax_scale: Optional[float] = None,\n    is_causal: bool = False,\n    window_size_left: int = -1,\n   
 window_size_right: int = -1,\n    softcap: float = 0.0,\n    deterministic: bool = False,\n    sm_margin: int = 0,\n) -> torch.Tensor:\n\n    is_varlen_q = cu_seqlens_q is not None\n    is_varlen_k = cu_seqlens_k is not None\n    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None\n\n    if not is_varlen_q:\n        batch_size = q.size(0)\n        seqlen_q = q.size(1)\n        seqlen_k = k.size(1)\n        total_q = batch_size * q.size(1)\n    else:\n        batch_size = cu_seqlens_q.size(0) - 1\n        total_q = q.size(0)\n        seqlen_q = max_seqlen_q\n        seqlen_k = max_seqlen_k\n\n    if window_size_left >= seqlen_k - 1:\n        window_size_left = -1\n\n    if window_size_right >= seqlen_q - 1:\n        window_size_right = -1\n\n    if is_causal:\n        window_size_right = 0\n\n    is_causal = window_size_left < 0 and window_size_right == 0\n\n    head_size = q.size(-1)\n    head_size_v = v.size(-1)\n    head_size_rounded = round_up_headdim(max(head_size, head_size_v))\n\n    # Hopper GPUs use CUDA compute capability 9.0\n    cap = torch.cuda.get_device_capability(q.device)\n    arch = cap[0] * 10 + cap[1]\n\n    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal\n\n    if head_size_rounded <= 64:\n        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128\n    elif head_size_rounded <= 96:\n        kBlockM_sm90 = 64\n    elif head_size_rounded <= 128:\n        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80\n    else:\n        kBlockM_sm90 = 64\n\n    kBlockM_sm80 = 128 if head_size_rounded <= 64 else 64\n    kBlockM_sm86 = 64 if head_size_rounded <= 192 else 32\n\n    if arch >= 90:\n        kBlockM = kBlockM_sm90\n    elif arch == 86 or arch == 89:\n        kBlockM = kBlockM_sm86\n    else:\n        kBlockM = kBlockM_sm80\n\n    num_heads = q.shape[-2]\n    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)\n\n    total_q_padded_rounded = 
round_multiple(total_q + batch_size * kBlockM, kBlockM)\n\n    dq = torch.empty_like(q) if dq is None else dq\n    dk = torch.empty_like(k) if dk is None else dk\n    dv = torch.empty_like(v) if dv is None else dv\n\n    if not is_varlen:\n        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)\n    else:\n        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)\n\n    return softmax_d\n\n\ndef setup_context(ctx, inputs, output):\n    q, k, v = inputs[:3]\n    out, softmax_lse, _, _ = output\n    ctx.save_for_backward(q, k, v, out, softmax_lse)\n    ctx.softmax_scale = inputs[-11]\n    ctx.causal = inputs[-10]\n    ctx.window_size = [inputs[-9], inputs[-8]]\n    ctx.attention_chunk = inputs[-7]\n    ctx.softcap = inputs[-6]\n    ctx.sm_margin = inputs[-1]\n\n\ndef _backward(ctx, dout, *grads):\n    q, k, v, out, softmax_lse = ctx.saved_tensors\n    dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)\n    _flash_attn_backward(\n        dout,\n        q,\n        k,\n        v,\n        out,\n        softmax_lse,\n        None, None, # cu_seqlens_q, cu_seqlens_k,\n        None, None, # sequed_q, sequed_k,\n        None, None, # max_seqlen_q, max_seqlen_k,\n        dq,\n        dk,\n        dv,\n        ctx.softmax_scale,\n        ctx.causal,\n        ctx.window_size[0],\n        ctx.window_size[1],\n        ctx.softcap,\n        False, # deterministic\n        ctx.sm_margin,\n    )\n    return dq, dk, dv, *((None,) * 21)\n\n\n_flash_attn_forward.register_autograd(_backward, setup_context=setup_context)\n\n\n\nclass FlashAttnQKVPackedFunc(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        qkv,\n        softmax_scale,\n        causal,\n        q_descale=None, k_descale=None, v_descale=None,\n        window_size=(-1, -1),\n        attention_chunk=0,\n        softcap=0.0,\n        deterministic=False,\n       
 num_heads_q=None,\n        sm_margin=0,\n        return_softmax=False,\n    ):\n        if softmax_scale is None:\n            softmax_scale = qkv.shape[-1] ** (-0.5)\n        if qkv.dim() == 5:\n            assert qkv.shape[-3] == 3\n            q, k, v = qkv.unbind(dim=-3)\n        else:\n            assert qkv.dim() == 4\n            assert num_heads_q is not None\n            num_heads_k = (qkv.shape[2] - num_heads_q) // 2\n            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]\n            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)\n        out, softmax_lse, *rest = _flash_attn_forward(\n            q,\n            k,\n            v,\n            None, None,  # k_new, v_new\n            None,  # qv\n            None,  # out\n            None, None, None,   # cu_seqlens_q/k/k_new\n            None, None,   # seqused_q/k\n            None, None,   # max_seqlen_q/k\n            None, None, None,   # page_table, kv_batch_idx, leftpad_k,\n            None, None, None,  # rotary_cos/sin, seqlens_rotary\n            q_descale, k_descale, v_descale,\n            softmax_scale,\n            causal=causal,\n            window_size_left=window_size[0],\n            window_size_right=window_size[1],\n            attention_chunk=attention_chunk,\n            softcap=softcap,\n            sm_margin=sm_margin,\n        )\n        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)\n        ctx.save_for_backward(q, k, v, out, softmax_lse)\n        ctx.softmax_scale = softmax_scale\n        ctx.causal = causal\n        ctx.window_size = window_size\n        ctx.attention_chunk = attention_chunk\n        ctx.softcap = softcap\n        ctx.deterministic = deterministic\n        ctx.ndim = qkv.dim()\n        ctx.sm_margin = sm_margin\n        return (out, softmax_lse) if return_softmax else out\n\n    @staticmethod\n    def backward(ctx, dout, *args):\n        q, k, v, out, softmax_lse = ctx.saved_tensors\n        assert 
ctx.attention_chunk == 0, \"FA3 backward does not support attention_chunk\"\n        if ctx.ndim == 5:\n            qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])\n            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)\n            dq, dk, dv = dqkv.unbind(dim=-3)\n        else:\n            num_heads_q = q.shape[2]\n            num_heads_k = k.shape[2]\n            qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])\n            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)\n            dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)\n        _flash_attn_backward(\n            dout,\n            q,\n            k,\n            v,\n            out,\n            softmax_lse,\n            None, None, # cu_seqlens_q, cu_seqlens_k,\n            None, None, # sequed_q, sequed_k,\n            None, None, # max_seqlen_q, max_seqlen_k,\n            dq,\n            dk,\n            dv,\n            ctx.softmax_scale,\n            ctx.causal,\n            ctx.window_size[0],\n            ctx.window_size[1],\n            ctx.softcap,\n            ctx.deterministic,\n            ctx.sm_margin,\n        )\n        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension\n        return dqkv, None, None, None, None, None, None, None, None, None, None, None, None\n\n\nclass FlashAttnFunc(torch.autograd.Function):\n\n    @staticmethod\n    def forward(\n        ctx,\n        q,\n        k,\n        v,\n        softmax_scale,\n        causal,\n        qv=None,\n        q_descale=None, k_descale=None, v_descale=None,\n        window_size=(-1, -1),\n        attention_chunk=0,\n        softcap=0.0,\n        num_splits=1,\n        pack_gqa=None,\n        deterministic=False,\n        sm_margin=0,\n        return_softmax=False,\n    ):\n        if softmax_scale is None:\n            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)\n        # out, q, k, 
v, out_padded, softmax_lse = _flash_attn_forward(\n        out, softmax_lse, *rest = _flash_attn_forward(\n            q,\n            k,\n            v,\n            None, None,  # k_new, v_new\n            qv,  # qv\n            None,  # out\n            None, None, None,   # cu_seqlens_q/k/k_new\n            None, None,   # seqused_q/k\n            None, None,   # max_seqlen_q/k\n            None, None, None,   # page_table, kv_batch_idx, leftpad_k,\n            None, None, None,  # rotary_cos/sin, seqlens_rotary\n            q_descale, k_descale, v_descale,\n            softmax_scale,\n            causal=causal,\n            window_size_left=window_size[0],\n            window_size_right=window_size[1],\n            attention_chunk=attention_chunk,\n            softcap=softcap,\n            num_splits=num_splits,\n            pack_gqa=pack_gqa,\n            sm_margin=sm_margin,\n        )\n        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)\n        ctx.save_for_backward(q, k, v, out, softmax_lse)\n        ctx.softmax_scale = softmax_scale\n        ctx.causal = causal\n        ctx.window_size = window_size\n        ctx.attention_chunk = attention_chunk\n        ctx.softcap = softcap\n        ctx.deterministic = deterministic\n        ctx.sm_margin = sm_margin\n        return (out, softmax_lse) if return_softmax else out\n\n    @staticmethod\n    def backward(ctx, dout, *args):\n        q, k, v, out, softmax_lse = ctx.saved_tensors\n        assert ctx.attention_chunk == 0, \"FA3 backward does not support attention_chunk\"\n        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)\n        _flash_attn_backward(\n            dout,\n            q,\n            k,\n            v,\n            out,\n            softmax_lse,\n            None, None, # cu_seqlens_q, cu_seqlens_k,\n            None, None, # sequed_q, sequed_k,\n            None, None, # max_seqlen_q, max_seqlen_k,\n            dq,\n            dk,\n            
dv,\n            ctx.softmax_scale,\n            ctx.causal,\n            ctx.window_size[0],\n            ctx.window_size[1],\n            ctx.softcap,\n            ctx.deterministic,\n            ctx.sm_margin,\n        )\n        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension\n        dk = dk[..., : k.shape[-1]]\n        dv = dv[..., : v.shape[-1]]\n        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None\n\n\nclass FlashAttnVarlenFunc(torch.autograd.Function):\n\n    @staticmethod\n    def forward(\n        ctx,\n        q,\n        k,\n        v,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        seqused_q,\n        seqused_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        softmax_scale,\n        causal,\n        qv=None,\n        q_descale=None, k_descale=None, v_descale=None,\n        window_size=(-1, -1),\n        attention_chunk=0,\n        softcap=0.0,\n        num_splits=1,\n        pack_gqa=None,\n        deterministic=False,\n        sm_margin=0,\n        return_softmax=False,\n    ):\n        if softmax_scale is None:\n            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)\n        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(\n        out, softmax_lse, *rest = _flash_attn_forward(\n            q,\n            k,\n            v,\n            None, None,  # k_new, v_new\n            qv,  # qv\n            None,  # out\n            cu_seqlens_q,\n            cu_seqlens_k,\n            None,   # cu_seqlens_k_new\n            seqused_q,\n            seqused_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            None, None, None,   # page_table, kv_batch_idx, leftpad_k,\n            None, None, None,  # rotary_cos/sin, seqlens_rotary\n            q_descale, k_descale, v_descale,\n            softmax_scale,\n            causal=causal,\n            window_size_left=window_size[0],\n            
window_size_right=window_size[1],\n            attention_chunk=attention_chunk,\n            softcap=softcap,\n            num_splits=num_splits,\n            pack_gqa=pack_gqa,\n            sm_margin=sm_margin,\n        )\n        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)\n        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)\n        ctx.max_seqlen_q = max_seqlen_q\n        ctx.max_seqlen_k = max_seqlen_k\n        ctx.softmax_scale = softmax_scale\n        ctx.causal = causal\n        ctx.window_size = window_size\n        ctx.attention_chunk = attention_chunk\n        ctx.softcap = softcap\n        ctx.deterministic = deterministic\n        ctx.sm_margin = sm_margin\n        return (out, softmax_lse) if return_softmax else out\n\n    @staticmethod\n    def backward(ctx, dout, *args):\n        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors\n        assert ctx.attention_chunk == 0, \"FA3 backward does not support attention_chunk\"\n        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)\n        _flash_attn_backward(\n            dout,\n            q,\n            k,\n            v,\n            out,\n            softmax_lse,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            ctx.max_seqlen_q,\n            ctx.max_seqlen_k,\n            dq,\n            dk,\n            dv,\n            ctx.softmax_scale,\n            ctx.causal,\n            ctx.window_size[0],\n            ctx.window_size[1],\n            ctx.softcap,\n            ctx.deterministic,\n            ctx.sm_margin,\n        )\n        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension\n        dk = dk[..., : k.shape[-1]]\n        dv = dv[..., : v.shape[-1]]\n        return dq, dk, dv, None, None, None, None, None, None, None, 
None, None, None, None, None, None, None, None, None, None, None, None, None\n\n\ndef flash_attn_qkvpacked_func(\n    qkv,\n    softmax_scale=None,\n    causal=False,\n    q_descale=None, k_descale=None, v_descale=None,\n    window_size=(-1, -1),\n    attention_chunk=0,\n    softcap=0.0,\n    deterministic=False,\n    num_heads_q=None,\n    sm_margin=0,\n    return_attn_probs=False,\n):\n    \"\"\"dropout_p should be set to 0.0 during evaluation\n    If Q, K, V are already stacked into 1 tensor, this function will be faster than\n    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation\n    of the gradients of Q, K, V.\n    For multi-query and grouped-query attention (MQA/GQA), please see\n    flash_attn_kvpacked_func and flash_attn_func.\n\n    If window_size != (-1, -1), implements sliding window local attention. Query at position i\n    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.\n\n    Arguments:\n        qkv: (batch_size, seqlen, 3, nheads, headdim)\n        dropout_p: float. Dropout probability.\n        softmax_scale: float. The scaling of QK^T before applying softmax.\n            Default to 1 / sqrt(headdim).\n        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n        window_size: (left, right). If not (-1, -1), implements sliding window local attention.\n        softcap: float. Anything > 0 activates softcapping attention.\n        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to\n            the attention score of query i and key j.\n        deterministic: bool. Whether to use the deterministic implementation of the backward pass,\n            which is slightly slower and uses more memory. The forward pass is always deterministic.\n        return_attn_probs: bool. Whether to return the attention probabilities. This option is for\n           testing only. 
The returned probabilities are not guaranteed to be correct\n           (they might not have the right scaling).\n    Return:\n        out: (batch_size, seqlen, nheads, headdim).\n        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The\n            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax\n            normalization factor).\n        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).\n            The output of softmax (possibly with different scaling). It also encodes the dropout\n            pattern (negative means that location was dropped, nonnegative means it was kept).\n    \"\"\"\n    return FlashAttnQKVPackedFunc.apply(\n        qkv,\n        softmax_scale,\n        causal,\n        q_descale, k_descale, v_descale,\n        window_size,\n        attention_chunk,\n        softcap,\n        deterministic,\n        num_heads_q,\n        sm_margin,\n        return_attn_probs,\n    )\n\n\ndef flash_attn_func(\n    q,\n    k,\n    v,\n    softmax_scale=None,\n    causal=False,\n    qv=None,\n    q_descale=None, k_descale=None, v_descale=None,\n    window_size=(-1, -1),\n    attention_chunk=0,\n    softcap=0.0,\n    num_splits=1,\n    pack_gqa=None,\n    deterministic=False,\n    sm_margin=0,\n    return_attn_probs=False,\n):\n    \"\"\"dropout_p should be set to 0.0 during evaluation\n    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads\n    than Q. 
Note that the number of heads in Q must be divisible by the number of heads in KV.\n    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head\n    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.\n\n    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.\n    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:\n        1 1 1 1 0\n        1 1 1 1 1\n    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:\n        0 0\n        0 0\n        0 0\n        1 0\n        1 1\n    If the row of the mask is all zero, the output will be zero.\n\n    If window_size != (-1, -1), implements sliding window local attention. Query at position i\n    will only attend to keys between\n    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.\n\n    Arguments:\n        q: (batch_size, seqlen, nheads, headdim)\n        k: (batch_size, seqlen, nheads_k, headdim)\n        v: (batch_size, seqlen, nheads_k, headdim)\n        dropout_p: float. Dropout probability.\n        softmax_scale: float. The scaling of QK^T before applying softmax.\n            Default to 1 / sqrt(headdim).\n        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n        window_size: (left, right). If not (-1, -1), implements sliding window local attention.\n        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of\n            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)\n            is added to the attention score of query i and key j.\n        deterministic: bool. Whether to use the deterministic implementation of the backward pass,\n            which is slightly slower and uses more memory. The forward pass is always deterministic.\n        return_attn_probs: bool. Whether to return the attention probabilities. This option is for\n           testing only. 
The returned probabilities are not guaranteed to be correct\n           (they might not have the right scaling).\n    Return:\n        out: (batch_size, seqlen, nheads, headdim).\n        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The\n            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax\n            normalization factor).\n    \"\"\"\n    return FlashAttnFunc.apply(\n        q,\n        k,\n        v,\n        softmax_scale,\n        causal,\n        qv,\n        q_descale, k_descale, v_descale,\n        window_size,\n        attention_chunk,\n        softcap,\n        num_splits,\n        pack_gqa,\n        deterministic,\n        sm_margin,\n        return_attn_probs,\n    )\n\n\ndef flash_attn_varlen_func(\n    q,\n    k,\n    v,\n    cu_seqlens_q,\n    cu_seqlens_k,\n    max_seqlen_q,\n    max_seqlen_k,\n    seqused_q=None,\n    seqused_k=None,\n    softmax_scale=None,\n    causal=False,\n    qv=None,\n    q_descale=None, k_descale=None, v_descale=None,\n    window_size=(-1, -1),\n    attention_chunk=0,\n    softcap=0.0,\n    num_splits=1,\n    pack_gqa=None,\n    deterministic=False,\n    sm_margin=0,\n    return_attn_probs=False,\n):\n    return FlashAttnVarlenFunc.apply(\n        q,\n        k,\n        v,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        seqused_q,\n        seqused_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        softmax_scale,\n        causal,\n        qv,\n        q_descale, k_descale, v_descale,\n        window_size,\n        attention_chunk,\n        softcap,\n        num_splits,\n        pack_gqa,\n        deterministic,\n        sm_margin,\n        return_attn_probs,\n    )\n\n\ndef flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):\n    return flash_attn_3_gpu.fwd_combine(out_partial, lse_partial, out, out_dtype)\n\n\ndef flash_attn_with_kvcache(\n    q,\n    k_cache,\n    v_cache,\n    k=None,\n    v=None,\n    qv=None,\n    
rotary_cos=None,\n    rotary_sin=None,\n    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,\n    cache_batch_idx: Optional[torch.Tensor] = None,\n    cache_leftpad: Optional[torch.Tensor] = None,\n    page_table: Optional[torch.Tensor] = None,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k_new: Optional[torch.Tensor] = None,\n    max_seqlen_q: Optional[int] = None,\n    rotary_seqlens: Optional[torch.Tensor] = None,\n    q_descale: Optional[torch.Tensor] = None,\n    k_descale: Optional[torch.Tensor] = None,\n    v_descale: Optional[torch.Tensor] = None,\n    softmax_scale=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite context window\n    attention_chunk=0,\n    softcap=0.0, # 0.0 means deactivated\n    rotary_interleaved=True,\n    scheduler_metadata=None,\n    num_splits=0,    # Can be tuned for speed\n    pack_gqa=None,   # Can be tuned for speed\n    sm_margin=0,     # Can be tuned if some SMs are used for communication\n    return_softmax_lse=False,\n):\n    \"\"\"\n    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from\n    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from\n    the previous step, and update them with the new keys/values from the current step, and do\n    attention with the updated cache, all in 1 kernel.\n\n    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.\n    For example, the KV cache could be pre-allocated with the max sequence length, and you can use\n    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.\n\n    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. 
The key @k will be\n    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.\n    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos\n    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.\n    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at\n    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).\n\n    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.\n\n    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads\n    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.\n    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head\n    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.\n\n    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.\n    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:\n        1 1 1 1 0\n        1 1 1 1 1\n    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:\n        0 0\n        0 0\n        0 0\n        1 0\n        1 1\n    If the row of the mask is all zero, the output will be zero.\n\n    If window_size != (-1, -1), implements sliding window local attention. Query at position i\n    will only attend to keys between\n    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.\n\n    Note: Does not support backward pass.\n\n    Arguments:\n        q: (batch_size, seqlen, nheads, headdim)\n        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,\n            or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. 
paged KV cache)\n            page_block_size can be arbitrary (e.g., 1, 2, 3, 64, etc.).\n        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,\n            or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)\n        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate\n            k with k_cache, starting at the indices specified by cache_seqlens.\n        v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.\n        qv [optional]: (batch_size, seqlen, nheads, headdim_v)\n        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding\n            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.\n        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.\n        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the\n            KV cache.\n        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.\n            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].\n            If the indices are not distinct, and k and v are provided, the values updated in the cache\n                 might come from any of the duplicate indices.\n        cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts. If None, assume 0.\n        page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.\n        softmax_scale: float. The scaling of QK^T before applying softmax.\n            Defaults to 1 / sqrt(headdim).\n        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).\n        window_size: (left, right). If not (-1, -1), implements sliding window local attention.\n        softcap: float.
Anything > 0 activates softcapping attention.\n        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.\n            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,\n            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1\n            (i.e. GPT-NeoX style).\n        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.\n           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic\n           to automatically determine the number of splits.\n           Don't change this unless you know what you are doing.\n        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.\n\n    Return:\n        out: (batch_size, seqlen, nheads, headdim).\n        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The\n            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax\n            normalization factor).\n    \"\"\"\n    assert k_cache.stride(-1) == 1, \"k_cache must have contiguous last dimension\"\n    assert v_cache.stride(-1) == 1, \"v_cache must have contiguous last dimension\"\n    if softmax_scale is None:\n        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)\n    if cache_seqlens is not None and isinstance(cache_seqlens, int):\n        cache_seqlens = torch.full(\n            (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device\n        )\n        cache_seqlens = maybe_contiguous(cache_seqlens)\n    out, softmax_lse, *rest = _flash_attn_forward(\n        q,\n        k_cache,\n        v_cache,\n        k,\n        v,\n        qv,\n        None,  # out\n        cu_seqlens_q,\n        None,  # cu_seqlens_k\n        cu_seqlens_k_new,\n        None,  # seqused_q\n        cache_seqlens,\n        max_seqlen_q,\n        None,  # max_seqlen_k\n        
page_table,\n        cache_batch_idx,\n        cache_leftpad,\n        rotary_cos,\n        rotary_sin,\n        rotary_seqlens,\n        q_descale, k_descale, v_descale,\n        softmax_scale,\n        causal=causal,\n        window_size_left=window_size[0],\n        window_size_right=window_size[1],\n        attention_chunk=attention_chunk,\n        softcap=softcap,\n        rotary_interleaved=rotary_interleaved,\n        scheduler_metadata=scheduler_metadata,\n        num_splits=num_splits,\n        pack_gqa=pack_gqa,\n        sm_margin=sm_margin,\n    )\n    # return (out, softmax_lse) if return_softmax_lse else out\n    return (out, softmax_lse, *rest) if return_softmax_lse else out\n\n\ndef get_scheduler_metadata(\n    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,\n    cache_seqlens: torch.Tensor,\n    qkv_dtype=torch.bfloat16,\n    headdim_v=None,\n    cu_seqlens_q: Optional[torch.Tensor] = None,\n    cu_seqlens_k_new: Optional[torch.Tensor] = None,\n    cache_leftpad: Optional[torch.Tensor] = None,\n    page_size: Optional[int] = None,\n    max_seqlen_k_new=0,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite context window\n    attention_chunk=0,\n    has_softcap=False,\n    num_splits=0,    # Can be tuned for speed\n    pack_gqa=None,   # Can be tuned for speed\n    sm_margin=0,     # Can be tuned if some SMs are used for communication\n):\n    cache_seqlens = maybe_contiguous(cache_seqlens)\n    if headdim_v is None:\n        headdim_v = headdim\n    scheduler_metadata = flash_attn_3_gpu.get_scheduler_metadata(\n        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,\n        qkv_dtype,\n        cache_seqlens,\n        cu_seqlens_q,\n        None,  # cu_seqlens_k\n        cu_seqlens_k_new,\n        None,  # seqused_q\n        cache_leftpad,\n        page_size,\n        max_seqlen_k_new,\n        causal,\n        window_size[0], window_size[1],\n        
attention_chunk,\n        has_softcap,\n        num_splits,\n        pack_gqa,\n        sm_margin,\n    )\n    return scheduler_metadata\n"
  },
  {
    "path": "hopper/flash_bwd_kernel_sm80.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cute/tensor.hpp\"\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n#include <cutlass/kernel_hardware_info.h>\n\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>\nclass FlashAttnBwdSm80 {\n\npublic:\n\n    // Type Aliases\n    static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;\n    static constexpr bool Is_local = CollectiveMainloop_::Is_local;\n    static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);\n    static constexpr bool Varlen = CollectiveMainloop_::Varlen;\n\n    // Mainloop derived types\n    using CollectiveMainloop = CollectiveMainloop_;\n    using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;\n    using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;\n    using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;\n    using ArchTag = typename CollectiveMainloop::ArchTag;\n    using MainloopArguments = typename CollectiveMainloop::Arguments;\n    using MainloopParams = typename CollectiveMainloop::Params;\n    static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;\n\n    // Epilogue derived types\n    using CollectiveEpilogue = CollectiveEpilogue_;\n    using EpilogueArguments = typename CollectiveEpilogue::Arguments;\n    using EpilogueParams = typename CollectiveEpilogue::Params;\n\n    static_assert(ArchTag::kMinComputeCapability >= 80);\n\n    using TileScheduler = TileScheduler_;\n    using TileSchedulerArguments = typename flash::TileSchedulerArguments;\n    using TileSchedulerParams = typename TileScheduler::Params;\n\n    static constexpr uint32_t NumThreads = 
CUTE_STATIC_V(size(TiledMmaSdP{}));\n    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{}));\n    static constexpr uint32_t MinBlocksPerMultiprocessor = 1;\n\n    // Kernel level shared memory storage\n    struct SharedStorage {\n        struct TensorStorage : cute::aligned_struct<128> {\n            union {\n                typename CollectiveMainloop::TensorStorage mainloop;\n                typename CollectiveEpilogue::TensorStorage epilogue;\n            };\n        } tensors;\n\n        alignas(16) typename TileScheduler::SharedStorage smem_scheduler;\n\n    };\n\n    static constexpr int SharedStorageSize = sizeof(SharedStorage);\n\n    // Device side arguments\n    struct Arguments {\n        MainloopArguments mainloop{};\n        EpilogueArguments epilogue{};\n        cutlass::KernelHardwareInfo hw_info{};\n        TileSchedulerArguments scheduler{};\n    };\n\n    // Kernel entry point API\n    struct Params {\n        MainloopParams mainloop{};\n        EpilogueParams epilogue{};\n        cutlass::KernelHardwareInfo hw_info{};\n        TileSchedulerParams scheduler{};\n    };\n\n    //\n    // Methods\n    //\n\n    // Convert to underlying arguments. 
In this case, a simple copy for the aliased type.\n    static\n    Params\n    to_underlying_arguments(Arguments const& args) {\n        CUTLASS_TRACE_HOST(\"to_underlying_arguments():\");\n\n        // Get SM count if needed, otherwise use user supplied SM count\n        int sm_count = args.hw_info.sm_count;\n        if (sm_count <= 0) {\n            CUTLASS_TRACE_HOST(\"  WARNING: Arguments do not include a valid SM count.\\n\"\n                \"  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.\");\n            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);\n        }\n\n        CUTLASS_TRACE_HOST(\"to_underlying_arguments(): Setting persistent grid SM count to \" << sm_count);\n\n        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};\n        return {\n            CollectiveMainloop::to_underlying_arguments(args.mainloop),\n            CollectiveEpilogue::to_underlying_arguments(args.epilogue),\n            hw_info,\n            TileScheduler::to_underlying_arguments(args.scheduler)\n        };\n    }\n\n    // Computes the kernel launch grid shape based on runtime parameters\n    static dim3\n    get_grid_shape(Params const& params) {\n        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);\n    }\n\n    static dim3\n    get_block_shape() {\n        return dim3(MaxThreadsPerBlock, 1, 1);\n    }\n\n    CUTLASS_DEVICE\n    void\n    operator()(Params const& params, char* smem_buf) {\n\n        static constexpr int kBlockM = get<0>(TileShape_MNK{});\n        static constexpr int kBlockN = get<1>(TileShape_MNK{});\n\n        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);\n\n        CollectiveMainloop mainloop;\n        CollectiveEpilogue epilogue;\n\n        TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));\n        // 
Initialize matmul objects.\n        TiledMmadKV tiled_mma_dKV;\n\n        scheduler.init_consumer();\n\n        int warp_idx = cutlass::canonical_warp_idx_sync();\n        CUTLASS_PRAGMA_NO_UNROLL\n        for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);\n             work_tile_info.is_valid(params.scheduler);\n             work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {\n\n            auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);\n            auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;\n            cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};\n\n            // dK and dV output accumulator.\n            Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));\n            Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));\n            bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x,\n                                           block_coord, shared_storage);\n            scheduler.prefetch_next_work(params.scheduler, work_tile_info);\n            if (tile_valid) {\n                epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,\n                               threadIdx.x, block_coord);\n            } else {\n                epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);\n            }\n        }\n\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/flash_bwd_kernel_sm90.h",
    "content": "\n/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cute/tensor.hpp\"\n\n#include <cutlass/cutlass.h>\n#include <cutlass/arch/reg_reconfig.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n#include <cutlass/numeric_conversion.h>\n#include <cutlass/kernel_hardware_info.h>\n#include \"cutlass/pipeline/pipeline.hpp\"\n\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>\nclass FlashAttnBwdSm90 {\n\npublic:\n\n    // Type Aliases\n    static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;\n    static constexpr bool Is_local = CollectiveMainloop_::Is_local;\n    static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);\n    static constexpr bool Varlen = CollectiveMainloop_::Varlen;\n\n    // Mainloop derived types\n    using CollectiveMainloop = CollectiveMainloop_;\n    using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;\n    using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;\n    using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;\n    using ArchTag = typename CollectiveMainloop::ArchTag;\n    using ClusterShape = typename CollectiveMainloop::ClusterShape;\n    using MainloopArguments = typename CollectiveMainloop::Arguments;\n    using MainloopParams = typename CollectiveMainloop::Params;\n    static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;\n\n    // Epilogue derived types\n    using CollectiveEpilogue = CollectiveEpilogue_;\n    using EpilogueArguments = typename CollectiveEpilogue::Arguments;\n    using EpilogueParams = typename CollectiveEpilogue::Params;\n\n    
static_assert(ArchTag::kMinComputeCapability >= 90);\n\n    using TileScheduler = TileScheduler_;\n    using TileSchedulerArguments = typename flash::TileSchedulerArguments;\n    using TileSchedulerParams = typename TileScheduler::Params;\n\n    static constexpr uint32_t NumLoadWarpGroups = 1;\n    static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup;\n    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);\n    static constexpr uint32_t MinBlocksPerMultiprocessor = 1;\n    static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);\n\n    /// Register requirement for Load and Math WGs\n    static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32;\n    static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160;\n    // If you want to print from the producer warp, you'd need to increase the number of registers\n    // Otherwise you'll get CUDA error.\n    // static constexpr uint32_t LoadRegisterRequirement = 40;\n    // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 
232 : 152;\n\n    // Kernel level shared memory storage\n    struct SharedStorage {\n        struct TensorStorage : cute::aligned_struct<128> {\n            union {\n                typename CollectiveMainloop::TensorStorage mainloop;\n                typename CollectiveEpilogue::TensorStorage epilogue;\n            };\n        } tensors;\n\n        struct PipelineStorage : cute::aligned_struct<16> {\n            alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV;\n            alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q;\n            alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do;\n            alignas(16) typename TileScheduler::SharedStorage smem_scheduler;\n        } pipelines;\n\n    };\n\n    static constexpr int SharedStorageSize = sizeof(SharedStorage);\n\n    // Device side arguments\n    struct Arguments {\n        MainloopArguments mainloop{};\n        EpilogueArguments epilogue{};\n        cutlass::KernelHardwareInfo hw_info{};\n        TileSchedulerArguments scheduler{};\n    };\n\n    // Kernel entry point API\n    struct Params {\n        MainloopParams mainloop{};\n        EpilogueParams epilogue{};\n        cutlass::KernelHardwareInfo hw_info{};\n        TileSchedulerParams scheduler{};\n    };\n\n    //\n    // Methods\n    //\n\n    // Convert to underlying arguments. 
In this case, a simple copy for the aliased type.\n    static\n    Params\n    to_underlying_arguments(Arguments const& args) {\n        CUTLASS_TRACE_HOST(\"to_underlying_arguments():\");\n\n        // Get SM count if needed, otherwise use user supplied SM count\n        int sm_count = args.hw_info.sm_count;\n        if (sm_count <= 0) {\n            CUTLASS_TRACE_HOST(\"  WARNING: Arguments do not include a valid SM count.\\n\"\n                \"  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.\");\n            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);\n        }\n\n        CUTLASS_TRACE_HOST(\"to_underlying_arguments(): Setting persistent grid SM count to \" << sm_count);\n\n        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};\n        return {\n            CollectiveMainloop::to_underlying_arguments(args.mainloop),\n            CollectiveEpilogue::to_underlying_arguments(args.epilogue),\n            hw_info,\n            TileScheduler::to_underlying_arguments(args.scheduler)\n        };\n    }\n\n    // Computes the kernel launch grid shape based on runtime parameters\n    static dim3\n    get_grid_shape(Params const& params) {\n        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);\n    }\n\n    static dim3\n    get_block_shape() {\n        return dim3(MaxThreadsPerBlock, 1, 1);\n    }\n\n    CUTLASS_DEVICE\n    void\n    operator()(Params const& params, char* smem_buf) {\n\n        static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;\n        static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;\n        static constexpr int kBlockM = get<0>(TileShape_MNK{});\n        static constexpr int kBlockN = get<1>(TileShape_MNK{});\n\n        using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;\n        using 
PipelineParams = typename MainloopPipeline::Params;\n        using PipelineState = typename MainloopPipeline::PipelineState;\n        using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO;\n        using PipelineParams_dO = typename MainloopPipeline_dO::Params;\n        using PipelineState_dO = typename MainloopPipeline_dO::PipelineState;\n        static constexpr bool Q_dO_same_stages = std::is_same_v<MainloopPipeline, MainloopPipeline_dO>;\n\n        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);\n\n        int const lane_predicate = cute::elect_one_sync();\n        int const warp_idx = cutlass::canonical_warp_idx_sync();\n\n        // Issue Tma Descriptor Prefetch from a single thread\n        if (warp_idx == 0 && lane_predicate) {\n            CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);\n            CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);\n        }\n\n        // Obtain warp index\n        int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;\n\n        PipelineParams pipeline_params;\n        pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE;\n        int warp_group_idx = cutlass::canonical_warp_group_idx();\n        pipeline_params.role = warp_group_idx == 0\n            ? 
MainloopPipeline::ThreadCategory::Producer\n            : MainloopPipeline::ThreadCategory::Consumer;\n        pipeline_params.is_leader = warp_group_thread_idx == 0;\n        pipeline_params.num_consumers = NumMmaThreads;\n\n        if (warp_idx == 0 && lane_predicate) {\n            shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/);\n        }\n        // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init();\n        MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{});\n        auto role_dO = warp_group_idx == 0\n            ? MainloopPipeline_dO::ThreadCategory::Producer\n            : MainloopPipeline_dO::ThreadCategory::Consumer;\n        PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers};\n        MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return<Q_dO_same_stages>(pipeline_params, pipeline_params_dO), ClusterShape{});\n\n        CollectiveMainloop mainloop;\n        CollectiveEpilogue epilogue;\n\n        // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster\n        if constexpr (size(ClusterShape{}) > 1) {\n            cute::cluster_arrive_relaxed();\n            cute::cluster_wait();\n        } else {\n            __syncthreads();\n        }\n\n        TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));\n\n        if (warp_group_idx == 0) {  // Producer\n            cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();\n\n            int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);\n            if (warp_idx_in_warpgroup == 0) {  // Load K, V, and do TMA on Q and dO\n                PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();\n          
      PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline_dO>();\n                for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler);\n                     work_tile_info.is_valid(params.scheduler);\n                     work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info)) {\n                    auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);\n                    auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;\n                    cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};\n                    auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {\n                        scheduler.prefetch_next_work(params.scheduler, work_tile_info);\n                    };\n                    mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write,\n                                  smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord);\n                }\n                mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do);\n            } else if (warp_idx_in_warpgroup == 1) {\n                for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);\n                     work_tile_info.is_valid(params.scheduler);\n                     work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {\n                    auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);\n                    auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;\n                    cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};\n                    mainloop.store_dq(params.mainloop, shared_storage, block_coord);\n                }\n            }\n        } else 
{  // Consumer\n            cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();\n            // Initialize matmul objects.\n            TiledMmadKV tiled_mma_dKV;\n\n            PipelineState smem_pipe_read;\n            PipelineState_dO smem_pipe_read_do;\n\n            mainloop.mma_init();\n            scheduler.init_consumer();\n\n            int work_idx = 0;\n            CUTLASS_PRAGMA_NO_UNROLL\n            for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);\n                 work_tile_info.is_valid(params.scheduler);\n                 work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {\n                auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);\n                auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;\n                cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};\n\n                // dK and dV output accumulator.\n                Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));\n                Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 
2 : 1>(TileShape_MNK{}));\n                bool tile_valid = mainloop.mma(\n                    params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do,\n                    tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage);\n                if (tile_valid) {\n                    epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,\n                                   threadIdx.x - NumCopyThreads, block_coord);\n                } else {\n                    epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);\n                }\n\n            }\n            epilogue.store_tail();\n        }\n\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/flash_bwd_launch_template.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cute/tensor.hpp\"\n\n#include \"cutlass/device_kernel.h\"  // For device_kernel\n#include \"cutlass/kernel_launch.h\"  // For kernel_launch\n#include \"cutlass/cluster_launch.hpp\"  // For ClusterLauncher\n\n#include \"cuda_check.h\"\n#include \"static_switch.h\"\n#include \"flash.h\"\n#include \"flash_bwd_preprocess_kernel.h\"\n#include \"flash_bwd_postprocess_kernel.h\"\n#include \"tile_scheduler.hpp\"\n#include \"mainloop_bwd_sm90_tma_gmma_ws.hpp\"\n#include \"mainloop_bwd_sm80.hpp\"\n#include \"epilogue_bwd.hpp\"\n#include \"flash_bwd_kernel_sm90.h\"\n#include \"flash_bwd_kernel_sm80.h\"\n\nusing namespace cute;\n\ntemplate <int Arch, int kHeadDim, int kBlockM, int kBlockN, typename Element,\n          bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,\n          int Stages_dO=2, int Stages_dS_or_QSm80=2,\n          bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,\n          int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,\n          bool V_in_regs=false>\nvoid run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {\n    static_assert(!(Is_causal && Is_local), \"Is_causal and Is_local cannot be true at the same time.\");\n    using ElementAccum = float;\n    using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;\n\n    int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM);\n    int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN);\n    bool const is_varlen_q = params.cu_seqlens_q;\n    bool const is_varlen_k = params.cu_seqlens_k;\n  
  int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;\n    int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k;\n    int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded;\n    int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded;\n    int batch_q = !is_varlen_q ? params.b : 1;\n    int batch_k = !is_varlen_k ? params.b : 1;\n\n    using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;\n    using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, ArchTag, /*Clear_dQaccum=*/true, Varlen>;\n    typename PreprocessKernel::Arguments preprocess_args {\n        static_cast<Element const*>(params.o_ptr),\n        {seqlen_q, params.dv, params.h, batch_q},  // shape_O\n        {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0},  // stride_O\n        static_cast<Element const*>(params.do_ptr),\n        {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0},  // stride_dO\n        static_cast<float*>(params.dsoftmax_sum),\n        {seqlen_q_rounded, params.h, batch_q},  // shape_dPsum\n        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_dPsum\n        static_cast<float*>(params.softmax_lse_ptr),\n        {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0},  // stride_LSE\n        static_cast<float*>(params.softmax_lse_log2_ptr),\n        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_LSE_log2\n        static_cast<ElementAccum*>(params.dq_accum_ptr),\n        {seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum\n        {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? 
params.d_rounded * seqlen_q_rounded * params.h : 0},  // stride_dQaccum\n        params.b,\n        params.dq_semaphore,\n        params.cu_seqlens_q,\n        params.seqused_q\n    };\n    typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);\n    int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);\n    dim3 grid_m(num_m_block, params.h, params.b);\n    CHECK_CUTLASS(cutlass::kernel_launch<PreprocessKernel>(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/));\n\n    using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;\n    using ClusterShape = cute::Shape<_1, Int<1>, _1>;  // Clusters are currently not supported\n    // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80\n    static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80;\n    static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1;\n    using CollectiveMainloop = std::conditional_t<\n        Arch >= 90,\n        flash::CollectiveMainloopBwdSm90<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,\n            Is_causal, Is_local, Has_softcap, Varlen, Deterministic,\n            SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,\n        flash::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,\n            Is_causal, Is_local, Has_softcap, Varlen, Deterministic,\n            SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>\n    >;\n    using CollectiveEpilogue = std::conditional_t<\n        !GQA,\n        flash::CollectiveEpilogueBwd<TileShape_MNK, Element, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, dKV_swapAB, NumMmaWarpGroups * (Arch >= 90 ?
1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>,\n        flash::CollectiveEpilogueBwdGQA<TileShape_MNK, ElementAccum, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, Deterministic>\n    >;\n    using Scheduler = std::conditional_t<\n        Is_causal,\n        flash::SingleTileBwdLPTScheduler<Varlen, kBlockN, Is_causal && Deterministic /*SPT*/>,\n        flash::SingleTileScheduler<Varlen, false /*Split*/, false /*PackGQA*/, kBlockN>\n    >;\n    using AttnKernel = std::conditional_t<\n        Arch >= 90,\n        flash::enable_sm90<flash::FlashAttnBwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,\n        flash::enable_sm80_to_sm89<flash::FlashAttnBwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>\n    >;\n\n    typename CollectiveMainloop::Arguments mainloop_args {\n        static_cast<Element const*>(params.q_ptr),\n        {seqlen_q, params.d, params.h, batch_q},  // shape_Q\n        {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0},  // stride_Q\n        static_cast<Element const*>(params.k_ptr),\n        {seqlen_k, params.d, params.h_k, batch_k},  // shape_K\n        {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0},  // stride_K\n        static_cast<Element const*>(params.v_ptr),\n        {seqlen_k, params.dv, params.h_k, batch_k},  // shape_V\n        {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0},  // stride_V\n        static_cast<Element const*>(params.do_ptr),\n        {seqlen_q, params.dv, params.h, batch_q},  // shape_dO\n        {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0},  // stride_dO\n        static_cast<ElementAccum*>(params.dq_accum_ptr),\n        {seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum\n        {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? 
params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum\n        static_cast<float*>(params.softmax_lse_log2_ptr),\n        {seqlen_q_rounded, params.h, batch_q},  // shape_LSE\n        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_LSE_log2\n        static_cast<float*>(params.dsoftmax_sum),\n        {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0},  // stride_dPsum\n        params.scale_softmax,\n        params.window_size_left, params.window_size_right, 0 /*attention_chunk*/,\n        params.softcap,\n        params.b,\n        params.dq_semaphore,\n        params.cu_seqlens_q, params.cu_seqlens_k,\n        params.seqused_q, params.seqused_k\n    };\n    // The case work with GQA is ugly but idk how to fix it.\n    typename CollectiveEpilogue::Arguments epilogue_args {\n        static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dk_ptr : params.dk_accum_ptr),\n        [&] {\n            if constexpr (!GQA) {\n                return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k};  // shape_dK\n            } else {\n                return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k};  // shape_dKaccum\n            }\n        }(),\n        [&] {\n            if constexpr (!GQA) {\n                return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0};  // stride_dK\n            } else {\n                return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0};  // stride_dKaccum\n            }\n        }(),\n        static_cast<typename CollectiveEpilogue::Element*>(!GQA ? 
params.dv_ptr : params.dv_accum_ptr),\n        [&] {\n            if constexpr (!GQA) {\n                return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k};  // shape_dV\n            } else {\n                return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k};  // shape_dVaccum\n            }\n        }(),\n        [&] {\n            if constexpr (!GQA) {\n                return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0};  // stride_dV\n            } else {\n                return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0};  // stride_dVaccum\n            }\n        }(),\n        params.b,\n        params.h,\n        params.dk_semaphore,\n        params.dv_semaphore,\n        params.cu_seqlens_k,\n        params.seqused_k,\n    };\n\n    int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));\n    num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));\n    typename flash::TileSchedulerArguments scheduler_args {\n        num_blocks_n, params.h, params.b, 1 /*num_splits*/,\n        params.h / params.h_k,\n        params.seqlen_k,\n        params.seqlen_q, params.d, params.dv, sizeof(Element),\n        params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k\n    };\n\n    int device;\n    cudaGetDevice(&device);\n    typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({\n        mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args\n    });\n\n    dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);\n    dim3 block_dims = AttnKernel::get_block_shape();\n    int smem_size = AttnKernel::SharedStorageSize;\n    // int smem_size_q = sizeof(decltype((typename 
CollectiveMainloop::TensorStorage{}).smem_q));\n    // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do));\n    // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds));\n    // int smem_size_dqacc = [&] {\n    //     if constexpr (Arch >= 90) {\n    //         return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc));\n    //     } else {\n    //         return 0;\n    //     }\n    // }();\n    // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));\n    // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));\n    // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse));\n    // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum));\n    // printf(\"smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\\n\", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum);\n    if constexpr (size(ClusterShape{}) > 1) {\n        void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;\n        if (smem_size >= 48 * 1024) {\n            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n        }\n        dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));\n        CHECK_CUTLASS(cutlass::ClusterLauncher::launch(\n            grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/));\n    } else {\n        if (smem_size >= 48 * 1024) {\n            CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<AttnKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n        }\n        CHECK_CUTLASS(cutlass::kernel_launch<AttnKernel>(grid_dims, 
block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/));\n    }\n\n    using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, ArchTag,\n        AttnKernel::CollectiveMainloop::NumMmaThreads,\n        typename AttnKernel::CollectiveMainloop::TiledMmadQ,\n        AttnKernel::CollectiveMainloop::dQ_swapAB\n        >;\n    typename PostprocessKernel::Arguments postprocess_args {\n        static_cast<ElementAccum const*>(params.dq_accum_ptr),\n        {seqlen_q_rounded * params.d_rounded, params.h, batch_q},  // shape_dQaccum\n        {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum\n        static_cast<Element*>(params.dq_ptr),\n        {seqlen_q, params.d, params.h, batch_q},  // shape_dQ\n        {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride},  // stride_dQ\n        params.scale_softmax,\n        params.cu_seqlens_q,\n        params.seqused_q\n    };\n    typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);\n    int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));\n    dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);\n    int smem_size_postprocess = PostprocessKernel::SharedStorageSize;\n    if (smem_size_postprocess >= 48 * 1024) {\n        CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));\n    }\n    CHECK_CUTLASS(cutlass::kernel_launch<PostprocessKernel>(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/));\n\n    if constexpr (GQA) {\n        using TileShape_NK = cute::Shape<Int<kBlockN>, Int<kHeadDim>>;\n        using PostprocessKerneldKV = 
flash::FlashAttnBwdPostprocessConvertdQ<TileShape_NK, Element, ElementAccum, ArchTag,\n            AttnKernel::CollectiveEpilogue::NumEpilogueThreads,\n            typename AttnKernel::CollectiveMainloop::TiledMmadKV,\n            AttnKernel::CollectiveMainloop::dKV_swapAB\n            >;\n        typename PostprocessKerneldKV::Arguments postprocess_dK_args {\n            static_cast<ElementAccum const*>(params.dk_accum_ptr),\n            {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k},  // shape_dKaccum\n            {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0},  // stride_dKaccum\n            static_cast<Element*>(params.dk_ptr),\n            {seqlen_k, params.d, params.h_k, batch_k},  // shape_dK\n            {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride},  // stride_dK\n            1.f,\n            params.cu_seqlens_k,\n            params.seqused_k\n        };\n        typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args);\n        typename PostprocessKerneldKV::Arguments postprocess_dV_args {\n            static_cast<ElementAccum const*>(params.dv_accum_ptr),\n            {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k},  // shape_dVaccum\n            {_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? 
params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0},  // stride_dVaccum\n            static_cast<Element*>(params.dv_ptr),\n            {seqlen_k, params.dv, params.h_k, batch_k},  // shape_dV\n            {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride},  // stride_dV\n            1.f,\n            params.cu_seqlens_k,\n            params.seqused_k\n        };\n        typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args);\n        int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{}));\n        dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b);\n        int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize;\n        if (smem_size_postprocess >= 48 * 1024) {\n            CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKerneldKV>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));\n        }\n        CHECK_CUTLASS(cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/));\n        CHECK_CUTLASS(cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/));\n    }\n\n}\n\ntemplate<int Arch, typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,\n         int Stages_dO=2, int Stages_dS_or_QSm80=2,\n         bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,\n         int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,\n         bool V_in_regs=false>\nvoid run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {\n    VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != 
nullptr, Varlen, [&] {\n        BOOL_SWITCH(params.h != params.h_k, GQA, [&] {\n            BOOL_SWITCH(params.deterministic, Deterministic_, [&] {\n                static constexpr bool Deterministic = Deterministic_ && kHeadDim < 256;\n                // run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);\n                run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, Deterministic /*Deterministic*/, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);\n            });\n        });\n    });\n}\n\n\ntemplate<int Arch, typename T, bool Has_softcap>\nvoid run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {\n    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {\n        if constexpr (Arch >= 90) {\n            if constexpr (Is_causal && Has_softcap) {\n                // register spill with 128 x 128\n                run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);\n            } else {\n                // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.\n                run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);\n            }\n        } else if constexpr (Arch == 86 || Arch == 89) {\n            run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);\n            // run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, 
false, true, true, 2, 2, 4, 4, false>(params, stream);\n            // run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);\n            // run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 1, 8, 4, false>(params, stream);\n        } else {\n            run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false>(params, stream);\n        }\n    });\n}\n\ntemplate<int Arch, typename T, bool Has_softcap>\nvoid run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {\n    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {\n        if constexpr (Arch >= 90) {\n            run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);\n        } else if constexpr (Arch == 86 || Arch == 89) {\n            run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);\n        } else {\n            run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false>(params, stream);\n        }\n    });\n}\n\ntemplate<int Arch, typename T, bool Has_softcap>\nvoid run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {\n    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {\n        if constexpr (Arch >= 90) {\n            if constexpr (Is_causal || Is_local || Has_softcap) {\n                run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);\n            } else {\n                run_mha_bwd_dispatch<Arch, T, 80, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 1, 
false>(params, stream);\n            }\n        } else if constexpr (Arch == 86 || Arch == 89) {\n            run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true>(params, stream);\n        } else {\n            run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false>(params, stream);\n        }\n    });\n}\n\ntemplate<int Arch, typename T, bool Has_softcap>\nvoid run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {\n    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {\n        if constexpr (Arch >= 90) {\n            run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);\n        } else if constexpr (Arch == 86 || Arch == 89) {\n            run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true>(params, stream);\n        } else {\n            run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false>(params, stream);\n        }\n    });\n}\n\ntemplate<int Arch, typename T, bool Has_softcap>\nvoid run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {\n    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {\n        if constexpr (Arch >= 90) {\n            run_mha_bwd_dispatch<Arch, T, 64, 80, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);\n        } else if constexpr (Arch == 86 || Arch == 89) {\n            run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true>(params, stream);\n            // run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 1, 2, 
true>(params, stream);\n        } else {\n            run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 2, 2, false>(params, stream);\n        }\n    });\n}\n"
  },
  {
    "path": "hopper/flash_bwd_postprocess_kernel.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cute/tensor.hpp\"\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n#include <cutlass/numeric_conversion.h>\n#include \"cutlass/arch/barrier.h\"\n\n#include \"seqlen.h\"\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class TiledMma, bool dQ_swapAB>\nclass FlashAttnBwdPostprocessConvertdQ {\n\npublic:\n\n    // Type Aliases\n    using TileShape_MK = TileShape_MK_;\n    using ArchTag = ArchTag_;\n\n    static_assert(ArchTag::kMinComputeCapability >= 75);\n    static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90;\n\n    static constexpr uint32_t MaxThreadsPerBlock = kNThreads;\n    static constexpr uint32_t MinBlocksPerMultiprocessor = 2;\n\n    static constexpr int kBlockM = get<0>(TileShape_MK{});\n    static constexpr int kHeadDim = get<1>(TileShape_MK{});\n    static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, \"kNThreads must be a multiple of NumThreadsPerWarpGroup\");\n    static constexpr int NumdQWarpGgroups = kNThreads / cutlass::NumThreadsPerWarpGroup;\n    using R2SLayoutAtomdQaccum = std::conditional_t<\n        IsSm90,\n        Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumdQWarpGgroups>>>,\n        Layout<Shape<Int<kNThreads>>>\n    >;\n    using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},\n                                                         Layout<Shape<Int<IsSm90 ? 
4 : 1>>>{}));  // Val layout, 1 or 4 vals per read\n    using G2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreads>>>;\n    // UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions\n    using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, ElementAccum>{}, G2SLayoutAtomdQaccum{},\n                                                         Layout<Shape<_4>>{}));  // Val layout, 4 vals per read\n    // We don't do bound checking for the gmem -> smem load so we just assert here.\n    static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0);\n    static constexpr int SmemdQaccumSize = size(TileShape_MK{});\n    using SmemLayoutdQaccumFlat = Layout<Shape<Int<SmemdQaccumSize>>>;\n    using SmemLayoutdQaccum = std::conditional_t<\n        IsSm90,\n        Layout<Shape<Int<kBlockM * kHeadDim / NumdQWarpGgroups>, Int<NumdQWarpGgroups>>>,\n        Layout<Shape<Int<kBlockM * kHeadDim>>>\n    >;\n\n    // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,\n    // then setting kBlockKSmem to 32 will cause \"Static shape_div failure\".\n    // We want to treat it as 64 x 48, so kBlockKSmem should be 16.\n    static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});\n    static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);\n    static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 
2 : 1);\n    using SmemLayoutAtomdQ =\n        decltype(composition(Swizzle<kSwizzle, 3, 3>{},\n                 Layout<Shape<Int<8>, Int<kBlockKSmem>>,\n                 Stride<Int<kBlockKSmem>, _1>>{}));\n    using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));\n    using SmemLayoutdQt =\n        decltype(cute::composition(SmemLayoutdQ{},\n                                   make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),\n                                               make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));\n\n    using SmemCopyAtomdQ = Copy_Atom<\n        std::conditional_t<\n            IsSm90,\n            std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,\n            AutoVectorizingCopyWithAssumedAlignment<128>\n        >,\n        Element>;\n\n    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);\n    static_assert(kHeadDim % kGmemElemsPerLoad == 0, \"Headdim must be a multiple of kGmemElemsPerLoad\");\n    static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));\n    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, \"MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow\");\n    using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                                  Stride<Int<kGmemThreadsPerRow>, _1>>;\n    using GmemTiledCopy = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load\n\n    struct SharedStorage : cute::aligned_struct<128> {\n        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>> smem_dqacc;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;\n  
      alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;\n    };\n\n    static constexpr int SharedStorageSize = sizeof(SharedStorage);\n\n    using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>;   // (seqlen_q, d, head, batch)\n    using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;\n    using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)\n    using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;\n\n    // Device side arguments\n    struct Arguments {\n        ElementAccum const* ptr_dQaccum;\n        ShapedQaccum const shape_dQaccum;\n        StridedQaccum const stride_dQaccum;\n        Element* ptr_dQ;\n        ShapedQ const shape_dQ;\n        StridedQ const stride_dQ;\n        float const softmax_scale;\n        int const* cu_seqlens = nullptr;\n        int const* seqused = nullptr;\n    };\n\n    // Kernel entry point API\n    struct Params {\n        ElementAccum const* ptr_dQaccum;\n        ShapedQaccum const shape_dQaccum;\n        StridedQaccum const stride_dQaccum;\n        Element* ptr_dQ;\n        ShapedQ const shape_dQ;\n        StridedQ const stride_dQ;\n        float const softmax_scale;\n        int const* cu_seqlens = nullptr;\n        int const* seqused = nullptr;\n    };\n\n    // Convert to underlying arguments. 
In this case, a simple copy for the aliased type.\n    static\n    Params\n    to_underlying_arguments(Arguments const& args) {\n        return {\n            args.ptr_dQaccum,\n            args.shape_dQaccum,\n            args.stride_dQaccum,\n            args.ptr_dQ,\n            args.shape_dQ,\n            args.stride_dQ,\n            args.softmax_scale,\n            args.cu_seqlens,\n            args.seqused\n        };\n    }\n\n    CUTLASS_DEVICE\n    void\n    operator()(Params const& params, char* smem_buf) {\n\n        static constexpr int kBlockM = get<0>(TileShape_MK{});\n        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);\n\n        Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});\n        Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{});\n        Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});\n        Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});\n\n        int const thread_idx = threadIdx.x;\n        int const m_block = blockIdx.x;\n        int const bidh = blockIdx.y;\n        int const bidb = blockIdx.z;\n\n        flash::SeqlenInfo<true /*Varlen*/, kBlockM> seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused);\n        bool const is_varlen = params.cu_seqlens;\n        if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; }\n\n        // Step 1: load dQaccum from gmem to smem\n        Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum const*>(params.ptr_dQaccum)),\n                                      params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? 
bidb : 0);\n        Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));  // (M * K)\n        if constexpr (IsSm90) {  // Use BulkCopy\n            static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v<ElementAccum> / 8);\n            auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};\n            // if (thread0()) { print(gdQaccum); printf(\"\\n\"); print(sdQaccum_flat); printf(\"\\n\"); }\n            if (thread_idx == 0) {\n                shared_storage.barrier_dQaccum.init(1 /*numThreads*/);\n                shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);\n                copy(bulk_copy.with(*reinterpret_cast<uint64_t*>(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat);\n            }\n            __syncthreads();\n            shared_storage.barrier_dQaccum.wait(0);\n        } else {\n            G2STiledCopydQaccum g2s_tiled_copy_dQaccum;\n            auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);\n            Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum);\n            Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum);\n            cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s);\n            __syncthreads();\n        }\n\n        // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }\n\n        // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16\n        R2STiledCopydQaccum s2r_tiled_copy_dQaccum;\n        auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);\n        Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);\n        TiledMma tiled_mma_dQ;\n        Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 
1 : 0>(TileShape_MK{}));\n        // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf(\"\\n\"); }\n        // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); }\n        // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); }\n        CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));\n        Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);\n        cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);\n        #pragma unroll\n        for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }\n        // Convert tdQrdQ from fp32 to fp16\n        Tensor rdQ = make_tensor_like<Element>(taccdQrdQaccum);\n        flash::convert_type_out(taccdQrdQaccum, rdQ);\n\n        // Step 3: Copy dQ from register to smem\n        auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);\n        auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);\n        Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)\n        // if (cute::thread0()) { print(smem_tiled_copy_dQ); }\n        // if (cute::thread0()) { print(smem_thr_copy_dQ); }\n        // if (cute::thread0()) { print(sdQ); }\n        Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return<!dQ_swapAB>(sdQ, sdQt));  // ((Atom,AtomNum),PIPE_M,PIPE_N)\n        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);\n        __syncthreads();\n\n        // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem\n        Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? 
bidb : 0);\n        Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{}));  // (M, K)\n        GmemTiledCopy gmem_tiled_copy_dQ;\n        auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);\n        Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)\n        Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);\n\n        Tensor tdQrdQ = make_fragment_like(tdQsdQ);\n        Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{}));\n        Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));\n        #pragma unroll\n        for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }\n        // Need to check OOB when reading from smem if kBlockM isn't evenly tiled\n        static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;\n        flash::copy</*Is_even_MN=*/EvenM, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(\n            gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM);\n\n        // Step 5: Copy dQ from register to gmem\n        // Clear_OOB_K must be false since we don't want to write zeros to gmem\n        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n            gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM)\n        );\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/flash_bwd_preprocess_kernel.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cute/tensor.hpp\"\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n#include <cutlass/numeric_conversion.h>\n\n#include \"seqlen.h\"\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, bool Clear_dQaccum, bool Varlen>\nclass FlashAttnBwdPreprocess {\n\npublic:\n\n    // Type Aliases\n    using TileShape_MK = TileShape_MK_;\n    using ArchTag = ArchTag_;\n\n    static_assert(std::is_same_v<Element, cutlass::half_t> && ArchTag::kMinComputeCapability >= 75 ||\n                  std::is_same_v<Element, cutlass::bfloat16_t> && ArchTag::kMinComputeCapability >= 80 ||\n                  std::is_same_v<Element, cutlass::float_e4m3_t> && ArchTag::kMinComputeCapability >= 89);\n\n    static constexpr uint32_t MaxThreadsPerBlock = 256;\n    static constexpr uint32_t MinBlocksPerMultiprocessor = 2;\n    static constexpr int SharedStorageSize = 0;\n\n    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);\n    static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, \"Headdim must be a multiple of kGmemElemsPerLoad\");\n    static constexpr int kBlockM = get<0>(TileShape_MK{});\n    static constexpr int kHeadDim = get<1>(TileShape_MK{});\n    // We want kBlockKGmem to be a power of 2 so that when we do the summing,\n    // it's just between threads in the same warp\n    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 
64 : 32);\n    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;\n    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, \"MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow\");\n    using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                                  Stride<Int<kGmemThreadsPerRow>, _1>>;\n    using GmemTiledCopy = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load\n\n    static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);\n    static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, \"MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum\");\n    using GmemLayoutAtomAccum = Layout<Shape<Int<MaxThreadsPerBlock>>>;\n    using GmemTiledCopyAccum = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},\n                        GmemLayoutAtomAccum{},\n                        Layout<Shape<Int<kGmemElemsPerLoadAccum>>>{}));  // Val layout, 4 vals per store\n\n    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch)\n    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;\n    using ShapedPsum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q, head, batch)\n    using StridedPsum = cute::Stride<_1, int64_t, int64_t>;\n    using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)\n    using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;\n\n    // Device side arguments\n    struct Arguments {\n        Element const* ptr_O;\n        ShapeO const shape_O;\n        StrideO const stride_O;\n        Element const* ptr_dO;\n  
      StrideO const stride_dO;\n        float* ptr_dPsum;\n        ShapedPsum const shape_dPsum;\n        StridedPsum const stride_dPsum;\n        float const* ptr_LSE;\n        StridedPsum const stride_LSE;\n        float *ptr_LSE_log2;\n        StridedPsum const stride_LSE_log2;\n        ElementAccum* ptr_dQaccum;\n        ShapedQaccum const shape_dQaccum;\n        StridedQaccum const stride_dQaccum;\n        int num_batch;  // We need this to know the size of dq_semaphore in case of varlen\n        int* dq_semaphore;\n        int const* cu_seqlens = nullptr;\n        int const* seqused = nullptr;\n    };\n\n    // Kernel entry point API\n    struct Params {\n        Element const* ptr_O;\n        ShapeO const shape_O;\n        StrideO const stride_O;\n        Element const* ptr_dO;\n        StrideO const stride_dO;\n        float* ptr_dPsum;\n        ShapedPsum const shape_dPsum;\n        StridedPsum const stride_dPsum;\n        float const* ptr_LSE;\n        StridedPsum const stride_LSE;\n        float* ptr_LSE_log2;\n        StridedPsum const stride_LSE_log2;\n        ElementAccum* ptr_dQaccum;\n        ShapedQaccum const shape_dQaccum;\n        StridedQaccum const stride_dQaccum;\n        int num_batch;\n        int* dq_semaphore;\n        int const* cu_seqlens = nullptr;\n        int const* seqused = nullptr;\n    };\n\n    // Convert to underlying arguments. 
In this case, a simple copy for the aliased type.\n    static\n    Params\n    to_underlying_arguments(Arguments const& args) {\n        return {\n            args.ptr_O,\n            args.shape_O,\n            args.stride_O,\n            args.ptr_dO,\n            args.stride_dO,\n            args.ptr_dPsum,\n            args.shape_dPsum,\n            args.stride_dPsum,\n            args.ptr_LSE,\n            args.stride_LSE,\n            args.ptr_LSE_log2,\n            args.stride_LSE_log2,\n            args.ptr_dQaccum,\n            args.shape_dQaccum,\n            args.stride_dQaccum,\n            args.num_batch,\n            args.dq_semaphore,\n            args.cu_seqlens,\n            args.seqused\n        };\n    }\n\n    CUTLASS_DEVICE\n    void\n    operator()(Params const& params, [[maybe_unused]] char* smem_buf) {\n\n        static constexpr int kBlockM = get<0>(TileShape_MK{});\n\n        int const thread_idx = threadIdx.x;\n        int const m_block = blockIdx.x;\n        int const bidh = blockIdx.y;\n        int const bidb = blockIdx.z;\n\n        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused);\n        bool const is_varlen = Varlen && params.cu_seqlens;\n        int const seqlen_o = seqlen_info.seqlen;\n        if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }\n\n        Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);\n        Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{}));  // (M, K)\n        Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? 
bidb : 0);\n        Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{}));  // (M, K)\n\n        auto shape_LSE = select<0, 2, 3>(params.shape_O);\n        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0);\n        Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape<Int<kBlockM>>{}, make_coord(m_block));\n        static_assert(kBlockM <= MaxThreadsPerBlock);\n        float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY;\n\n        GmemTiledCopy gmem_tiled_copy_O;\n        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);\n\n        Tensor tOgO = gmem_thr_copy_O.partition_S(gO);\n        Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO);\n        // Construct identity layout for gO\n        Tensor cO = cute::make_identity_tensor(TileShape_MK{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)\n        // Repeat the partitioning with identity layouts\n        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);\n        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));\n        #pragma unroll\n        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }\n\n        // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128)\n        Tensor tOrO = make_fragment_like(tOgO);\n        Tensor tOrdO = make_fragment_like(tOgdO);\n        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(\n            gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM\n        );\n        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(\n            gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM\n        );\n        // if (threadIdx.x 
== 222) { printf(\"bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\\n\", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));}\n\n        // Reshape from e.g. (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64))\n        Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout())));\n        Tensor tOrO_l = make_tensor(tOrO.data(), l);\n        Tensor o_fp32 = make_tensor_like<float>(tOrO_l);\n        flash::convert_type_out(tOrO_l, o_fp32);\n        Tensor tOrdO_l = make_tensor(tOrdO.data(), l);\n        Tensor do_fp32 = make_tensor_like<float>(tOrdO_l);\n        flash::convert_type_out(tOrdO_l, do_fp32);\n        // Sum across the last dimension\n        Tensor dP_sum = make_tensor<float>(make_shape(size<0>(o_fp32)));\n        #pragma unroll\n        for (int mi = 0; mi < size<0>(o_fp32); ++mi) {\n            float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);\n            #pragma unroll\n            for (int ni = 1; ni < size<1>(o_fp32); ni++) {\n                dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);\n            }\n            flash::SumOp<float> sum_op;\n            dP_sum(mi) = flash::Allreduce<kGmemThreadsPerRow>::run(dP_sum_cur, sum_op);\n        }\n\n        Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0);\n        Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape<Int<kBlockM>>{}, make_coord(m_block));\n        if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) {\n            #pragma unroll\n            for (int mi = 0; mi < size(dP_sum); ++mi) {\n                int const row = get<0>(tOcO(_0{}, mi, _0{}));\n                gdPsum(row) = row < seqlen_o - m_block * kBlockM ? 
dP_sum(mi) : 0;\n            }\n        }\n\n        int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM);\n        Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0);\n        Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape<Int<kBlockM>>{}, make_coord(m_block));\n        if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) {\n            gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E);\n        }\n\n        if constexpr (Clear_dQaccum) {\n            Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);\n            Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));\n            GmemTiledCopyAccum gmem_tiled_copy_dQaccum;\n            auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx);\n            Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);\n            Tensor zero = make_fragment_like(tdQgdQaccum);\n            clear(zero);\n            cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, zero, tdQgdQaccum);\n        }\n\n        if (params.dq_semaphore != nullptr && thread_idx == 0) {\n            int const num_batch = params.num_batch;\n            int const num_head = get<2>(params.shape_O);\n            params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0;\n        }\n\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/flash_fwd_combine.cu",
    "content": "// Copyright (c) 2024, Tri Dao.\n// Splitting the different head dimensions to different files to speed up compilation.\n\n#include \"flash_fwd_combine_launch_template.h\"\n\ntemplate void run_mha_fwd_combine_<float, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);\ntemplate void run_mha_fwd_combine_<float, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);\n\ntemplate void run_mha_fwd_combine_<cutlass::half_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);\ntemplate void run_mha_fwd_combine_<cutlass::half_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);\n\ntemplate void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);\ntemplate void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);\n"
  },
  {
    "path": "hopper/flash_fwd_combine_kernel.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cute/tensor.hpp\"\n\n#include <cutlass/cutlass.h>\n#include <cutlass/arch/memory.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n#include <cutlass/numeric_conversion.h>\n\n#include \"cutlass/arch/grid_dependency_control.h\"\n\n#include \"seqlen.h\"\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <class TileShape_MK_, int kLogMaxSplits_, int kNThreads, int AlignmentLSE_,\n          bool Is_even_K, bool Varlen, class Element, class ElementPartial, class ArchTag_>\nclass FlashAttnFwdCombine {\n\npublic:\n\n    // Type Aliases\n    using TileShape_MK = TileShape_MK_;\n    using ArchTag = ArchTag_;\n    static constexpr int kMaxSplits = 1 << kLogMaxSplits_;\n    static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float)));\n    static_assert(AlignmentLSE >= 1);\n    static constexpr int kStages = 4;\n\n    static_assert(ArchTag::kMinComputeCapability >= 75);\n    static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;\n\n    static constexpr uint32_t MaxThreadsPerBlock = kNThreads;\n    static constexpr uint32_t MinBlocksPerMultiprocessor = 2;\n\n    static constexpr int kBlockM = get<0>(TileShape_MK{});\n    static constexpr int kBlockK = get<1>(TileShape_MK{});\n\n    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial);\n    static_assert(kBlockK % kGmemElemsPerLoad == 0, \"kBlockK must be a multiple of kGmemElemsPerLoad\");\n    static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 
64 : 32);\n    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;\n    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, \"MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow\");\n    using GmemCopyAtom = std::conditional_t<\n        Has_cp_async,\n        cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, ElementPartial>,\n        cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>\n    >;\n    using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                                  Stride<Int<kGmemThreadsPerRow>, _1>>;\n    static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);\n    using GmemTiledCopyAccum = decltype(\n        make_tiled_copy(GmemCopyAtom{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 4 vals per load\n    using GmemTiledCopy = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 4 vals per load\n\n    using AlignmentTypeLSE = cute::uint_byte_t<static_cast<int>(sizeof(float)) * AlignmentLSE>;\n    static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float);\n    static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, \"kBlockM must be a multiple of kGmemElemsPerLoadLSE\");\n    static_assert(kBlockM % 8 == 0, \"kBlockM must be a multiple of 8\");\n    static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 
16 : 8)));\n    static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE;\n    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, \"MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE\");\n    using GmemLayoutAtomLSE = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRowLSE>, Int<kGmemThreadsPerRowLSE>>,\n                                     Stride<Int<kGmemThreadsPerRowLSE>, _1>>;\n    static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0);\n    using GmemCopyAtomLSE = std::conditional_t<\n        Has_cp_async,\n        cute::Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeLSE>, float>,\n        cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<AlignmentLSE * sizeof(float) * 8>, float>\n    >;\n    using GmemTiledCopyLSE = decltype(\n        make_tiled_copy(GmemCopyAtomLSE{},\n                        GmemLayoutAtomLSE{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoadLSE>>>{}));  // Val layout, 4 vals per load\n\n    // Otherwise we get IMA when some threads access sLSE, as we're not doing any masking\n    static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, \"kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE\");\n    // This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts\n    using SmemLSESwizzle = std::conditional_t<\n        kBlockMSmem == 8,\n        Swizzle<5, 0, 5>,\n        std::conditional_t<kBlockMSmem == 16, Swizzle<4, 0, 4>, Swizzle<3, 2, 3>>\n    >;\n    using SmemLayoutAtomLSE =\n        decltype(composition(SmemLSESwizzle{},\n                 Layout<Shape<Int<8>, Int<kBlockMSmem>>,\n                 Stride<Int<kBlockMSmem>, _1>>{}));\n    using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape<Int<kMaxSplits>, Int<kBlockM>>{}));\n\n    using SmemLayoutO = Layout<Shape<Int<kBlockM>, Int<kBlockK>, Int<kStages>>,\n                               Stride<Int<kBlockK>, _1, Int<kBlockM * kBlockK>>>;\n\n    
// We want each column (kMaxSplits) to be processed by threads in the same warp.\n    // To reduce the number of shuffles, we want as few threads on the same column as possible.\n    // E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 3) per column\n    // and have 64 such quads.\n    static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, \"MaxThreadsPerBlock must be a multiple of kBlockMSmem\");\n    static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem;\n    static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, \"kSmemThreadsPerColLSEt must divide NumThreadsPerWarp\");\n    using S2RLayoutAtomLSE = Layout<Shape<Int<kSmemThreadsPerColLSEt>, Int<MaxThreadsPerBlock / kSmemThreadsPerColLSEt>>>;\n    using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, float>{}, S2RLayoutAtomLSE{}, Layout<_1>{}));\n\n    using ShapeOPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>;  // (seqlen, d, num_splits, head, batch)\n    using StrideOPartial = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;\n    using ShapeLSEPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen, num_splits, head, batch)\n    using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>;  // (seqlen, num_splits, head, batch)\n    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen, d, head, batch)\n    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;\n    using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen, head, batch)\n    using StrideLSE = cute::Stride<_1, int64_t, int64_t>;  // (seqlen, head, batch)\n\n    struct SharedStorage : cute::aligned_struct<128> {\n        cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;\n        cute::array_aligned<int, kBlockM> smem_max_valid_split;\n        cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;\n   
 };\n\n    static constexpr int SharedStorageSize = sizeof(SharedStorage);\n\n    // Device side arguments\n    struct Arguments {\n        ElementPartial const* const ptr_O_partial;\n        ShapeOPartial const shape_O_partial;\n        StrideOPartial const stride_O_partial;\n        float const* const ptr_LSE_partial;\n        ShapeLSEPartial const shape_LSE_partial;\n        StrideLSEPartial const stride_LSE_partial;\n        Element* const ptr_O;\n        StrideO const stride_O;\n        float* const ptr_LSE;\n        StrideLSE const stride_LSE;\n        int const* const cu_seqlens = nullptr;\n        int const* const seqused = nullptr;\n        int const* const num_splits_dynamic_ptr = nullptr;\n        int const* const varlen_batch_idx_ptr = nullptr;\n        int* const semaphore_to_reset = nullptr;\n    };\n\n    // Kernel entry point API\n    struct Params {\n        ElementPartial const* const ptr_O_partial;\n        ShapeOPartial const shape_O_partial;\n        StrideOPartial const stride_O_partial;\n        float const* const ptr_LSE_partial;\n        ShapeLSEPartial const shape_LSE_partial;\n        StrideLSEPartial const stride_LSE_partial;\n        Element* const ptr_O;\n        StrideO const stride_O;\n        float* const ptr_LSE;\n        StrideLSE const stride_LSE;\n        cutlass::FastDivmod seqlen_divmod, head_divmod;\n        int const* const cu_seqlens = nullptr;\n        int const* const seqused = nullptr;\n        int const* const num_splits_dynamic_ptr = nullptr;\n        int const* const varlen_batch_idx_ptr = nullptr;\n        int* const semaphore_to_reset = nullptr;\n    };\n\n    // Convert to underlying arguments. 
In this case, a simple copy for the aliased type.\n    static\n    Params\n    to_underlying_arguments(Arguments const& args) {\n        assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);\n        return {\n            args.ptr_O_partial,\n            args.shape_O_partial,\n            args.stride_O_partial,\n            args.ptr_LSE_partial,\n            args.shape_LSE_partial,\n            args.stride_LSE_partial,\n            args.ptr_O,\n            args.stride_O,\n            args.ptr_LSE,\n            args.stride_LSE,\n            cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)),\n            args.cu_seqlens,\n            args.seqused,\n            args.num_splits_dynamic_ptr,\n            args.varlen_batch_idx_ptr,\n            args.semaphore_to_reset,\n            \n        };\n    }\n\n    CUTLASS_DEVICE\n    void\n    operator()(Params const& params, char* smem_buf) {\n\n        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);\n        Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});\n        Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});\n        Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});\n\n        int const thread_idx = threadIdx.x;\n        int const m_block = blockIdx.x;\n        int const k_block = blockIdx.y;\n        int const maybe_virtual_batch = blockIdx.z;\n        int const batch = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[maybe_virtual_batch] : maybe_virtual_batch;\n        int const num_splits = params.num_splits_dynamic_ptr ? 
params.num_splits_dynamic_ptr[maybe_virtual_batch] : get<1>(params.shape_LSE_partial);\n\n        if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {\n            cutlass::arch::wait_on_dependent_grids();\n            *params.semaphore_to_reset = 0;\n        }\n        if (num_splits <= 1) { return; }\n        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};\n        int const offset = seqlen_info.offset;\n        int const seqlen = seqlen_info.seqlen;\n        int max_idx = seqlen * get<2>(params.shape_LSE_partial);\n        if constexpr (Varlen) {\n            if (m_block * kBlockM >= max_idx) { return; }\n        }\n\n        cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);\n\n        // Step 1: load LSE_partial from gmem -> smem\n        Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)),\n                                         select<1, 0, 2, 3>(params.shape_LSE_partial),\n                                         select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? 
batch : 0);  // (num_splits, seqlen, head)\n        Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int<kGmemElemsPerLoadLSE>>{});\n        GmemTiledCopyLSE gmem_tiled_copy_LSE;\n        auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx);\n        Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE);\n\n        // Construct identity layout for sLSE\n        Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE)));    // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m)\n        // Repeat the partitioning with identity layouts\n        Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE);\n\n        cutlass::arch::wait_on_dependent_grids();\n\n        #pragma unroll\n        for (int m = 0; m < size<2>(tLSEcLSE); ++m) {\n            int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m)));\n            int idx = m_block * kBlockM + mi;\n            if (idx < max_idx) {\n                int m_idx, bidh;\n                if constexpr (!Varlen) {\n                    bidh = params.seqlen_divmod.divmod(m_idx, idx);\n                } else {\n                    bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);\n                }\n                Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh);\n                #pragma unroll\n                for (int s = 0; s < size<1>(tLSEcLSE); ++s) {\n                    int si = get<0>(tLSEcLSE(_0{}, s, _0{}));\n                    // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf(\"thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\\n\", thread_idx, m, s, reinterpret_cast<float *>(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast<int>(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);}\n                    if (si < num_splits) {\n                        cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m));\n                    } else {\n                        cute::fill(tLSEsLSE(_, s, m), -INFINITY);\n                    
}\n                }\n            } else {\n                // We don't need to zero out the rest of the LSEs, as we will not write the output to gmem\n                // cute::fill(tLSEsLSE(_, _, m), -INFINITY);\n            }\n        }\n        if constexpr (Has_cp_async) { cute::cp_async_fence(); }\n\n        // Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2.\n        // We want these async loads to be in flight as we compute the LSE.\n        GmemTiledCopyAccum gmem_tiled_copy_O_partial;\n        auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx);\n        // Construct identity layout for gO\n        Tensor cO = cute::make_identity_tensor(TileShape_MK{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)\n        // Repeat the partitioning with identity layouts\n        Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO);\n        Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)),\n                                       params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? 
batch : 0);  // (seqlen, d, num_splits, head)\n\n        // Precompute these values to avoid recomputing them in the loop\n        Tensor tOmidx = make_tensor<int>(make_shape(size<1>(tOcO)));\n        Tensor tObidh = make_tensor<int>(make_shape(size<1>(tOcO)));\n        Tensor tOrOptr = make_tensor<ElementPartial const*>(make_shape(size<1>(tOcO)));\n        #pragma unroll\n        for (int m = 0; m < size<1>(tOcO); ++m) {\n            int mi = get<0>(tOcO(_0{}, m, _0{}));\n            int idx = m_block * kBlockM + mi;\n            if constexpr (!Varlen) {\n                tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx);\n            } else {\n                tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx);\n            }\n            tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m));\n            if (idx >= max_idx) {\n                tObidh[m] = -1;\n            }\n        }\n\n        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));\n        if constexpr (!(Is_even_K)) {\n            #pragma unroll\n            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; }\n        }\n\n        Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO);\n\n        auto load_O_partial = [&] (int split, int stage) {\n            Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage);\n            #pragma unroll\n            for (int m = 0; m < size<1>(tOcO); ++m) {\n                if (tObidh(m) >= 0)  {\n                    Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout());\n                    Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape<Int<kGmemElemsPerLoad>>{});\n                    #pragma unroll\n                    for (int k = 0; k < size<2>(tOcO); ++k) {\n                        int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;\n                        if 
(Is_even_K || tOpO(k)) {\n                            cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k));\n                        }\n                    }\n                }\n            }\n        };\n\n        for (int s = 0; s < kStages - 1; ++s) {\n            if (s < num_splits) { load_O_partial(s, s); }\n            if constexpr (Has_cp_async) { cute::cp_async_fence(); }\n        }\n\n        // Step 3: load and transpose LSE_partial from smem -> rmem\n        if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }\n        __syncthreads();\n\n        S2RTiledCopyLSE s2r_tiled_copy_LSE;\n        auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx);\n        Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE);\n        Tensor ts2rrLSE = make_fragment_like(ts2rsLSE);\n        cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE);\n\n        // Step 4: compute the final LSE along the split dimension\n        Tensor lse_sum = make_tensor<float>(make_shape(size<2>(ts2rrLSE)));\n        Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE);\n        // We compute the max valid split for each row to short-circuit the computation later\n        Tensor max_valid_split = make_tensor<int>(make_shape(size<2>(ts2rrLSE)));\n        static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1);\n        #pragma unroll\n        for (int m = 0; m < size<2>(ts2rrLSE); ++m) {\n            float lse_max = ts2rrLSE(_0{}, _0{}, m);\n            #pragma unroll\n            for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); }\n            MaxOp<float> max_op;\n            lse_max = Allreduce<kSmemThreadsPerColLSEt>::run(lse_max, max_op);\n            int max_valid_idx = -1;\n            #pragma unroll\n            for (int s = 0; s < size<1>(ts2rrLSE); ++s) {\n                if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); 
}\n            }\n            MaxOp<int> max_int_op;\n            max_valid_split[m] = Allreduce<kSmemThreadsPerColLSEt>::run(max_valid_idx, max_int_op);\n            float lse_max_cur = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf\n            float lse_sum_cur = 0.f;\n            #pragma unroll\n            for (int s = 0; s < size<1>(ts2rrLSE); ++s) {\n                float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur);\n                lse_sum_cur += scale;\n                // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf(\"thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\\n\", thread_idx, m, s, reinterpret_cast<float *>(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast<int>(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);}\n                // ts2rsLSE(_0{}, m, s) = scale;\n                ts2rrLSE(_0{}, s, m) = scale;\n            }\n            SumOp<float> sum_op;\n            lse_sum_cur = Allreduce<kSmemThreadsPerColLSEt>::run(lse_sum_cur, sum_op);\n            lse_sum(m) = logf(lse_sum_cur) + lse_max;\n            float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 
0.f : 1.f / lse_sum_cur;\n            #pragma unroll\n            for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; }\n        }\n        // Store the scales exp(lse - lse_logsum) back to smem\n        cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE);\n\n        // Store max_valid_split to smem\n        #pragma unroll\n        for (int m = 0; m < size<2>(ts2rrLSE); ++m) {\n            if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) {  // Only the thread responsible for s=0 writes to smem\n                int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));\n                if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; }\n            }\n        }\n\n        // Step 5: store final LSE back to gmem\n        if (k_block == 0) {\n            auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial);\n            Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0);\n            #pragma unroll\n            for (int m = 0; m < size<2>(ts2rrLSE); ++m) {\n                if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) {  // Only the thread responsible for s=0 writes to gmem\n                    int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));\n                    int idx = m_block * kBlockM + mi;\n                    if (idx < max_idx) {\n                        int m_idx, bidh;\n                        if constexpr (!Varlen) {\n                            bidh = params.seqlen_divmod.divmod(m_idx, idx);\n                        } else {\n                            bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);\n                        }\n                        // printf(\"thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\\n\", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m));\n                        mLSE(m_idx, bidh) = lse_sum(m);\n                    }\n                }\n            }\n        
}\n\n        // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O\n        __syncthreads();\n        int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))];\n        #pragma unroll\n        for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); }\n        Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor<ElementPartial>(TileShape_MK{})).layout();\n        Tensor tOrOpartial = make_fragment_like<ElementPartial>(tOrOpartial_layout);\n        Tensor tOrO = make_fragment_like<float>(tOrOpartial);\n        clear(tOrO);\n        int stage_load = kStages - 1, stage_compute = 0;\n        #pragma unroll 4 // Already tuned for speed\n        for (int s = 0; s <= thr_max_valid_split; ++s) {\n            Tensor scale = make_tensor<float>(make_shape(size<1>(tOrOpartial)));\n            #pragma unroll\n            for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); }\n\n            if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); }\n            if constexpr (Has_cp_async) { cute::cp_async_fence(); }\n            stage_load = stage_load < kStages - 1 ? stage_load + 1 : 0;\n            if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }\n            // We don't need __syncthreads() because each thread is just reading its own data from smem\n            cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>{},\n                       tOsOpartial(_, _, _, stage_compute), tOrOpartial);\n            stage_compute = stage_compute < kStages - 1 ? 
stage_compute + 1 : 0;\n\n            #pragma unroll\n            for (int m = 0; m < size<1>(tOrOpartial); ++m) {\n                if (tObidh(m) >= 0 && scale(m) > 0.f) {\n                    #pragma unroll\n                    for (int k = 0; k < size<2>(tOrOpartial); ++k) {\n                        if (Is_even_K || tOpO(k)) {\n                            Tensor rOpartial = make_tensor_like<float>(tOrOpartial(_, m, k));\n                            flash::convert_type_out(tOrOpartial(_, m, k), rOpartial);\n                            #pragma unroll\n                            for (int i = 0; i < size<0>(tOrOpartial); ++i) {\n                                tOrO(i, m, k) += scale(m) * rOpartial[i];\n                            }\n                        }\n                    }\n                }\n            }\n        }\n\n        // Step 7: Write the final O to gmem\n        Tensor rO = make_tensor_like<Element>(tOrO);\n        flash::convert_type_out(tOrO, rO);\n        auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial));\n        Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)),\n                                shape_O, params.stride_O)(_, _, _, !Varlen ? 
batch : 0);\n        Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int<kGmemElemsPerLoad>>{});\n        GmemTiledCopy gmem_tiled_copy_O;\n        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);\n\n        #pragma unroll\n        for (int m = 0; m < size<1>(tOcO); ++m) {\n            if (tObidh(m) >= 0)  {\n                #pragma unroll\n                for (int k = 0; k < size<2>(tOcO); ++k) {\n                    int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;\n                    if (Is_even_K || tOpO(k)) {\n                        cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m)));\n                    }\n                }\n            }\n        }\n\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/flash_fwd_combine_launch_template.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cute/tensor.hpp\"\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/arch/arch.h\"  // For cutlass::arch::Sm80\n#include \"cutlass/device_kernel.h\"  // For device_kernel\n#include \"cutlass/kernel_launch.h\"  // For kernel_launch\n\n#include \"cuda_check.h\"\n#include \"static_switch.h\"\n#include \"flash.h\"\n#include \"flash_fwd_combine_kernel.h\"\n\nusing namespace cute;\n\ntemplate <int Arch, int kBlockM, int kBlockK, int kLogMaxSplits, bool IsEvenK, bool Varlen, typename Element, typename ElementPartial>\nvoid run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {\n    using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;\n    using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kBlockK>>;\n    using CombineKernel = flash::FlashAttnFwdCombine<TileShape_MK, kLogMaxSplits, 256 /*kNThreads*/, 1 /*AlignmentLSE*/,\n                                                     IsEvenK, Varlen, Element, ElementPartial, ArchTag>;\n\n    typename CombineKernel::Arguments args {\n        static_cast<ElementPartial const*>(params.oaccum_ptr),\n        {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_O_partial\n        {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0},  // stride_O_partial\n        static_cast<float*>(params.softmax_lseaccum_ptr),\n        {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? 
params.b : 1},  // shape_LSE_partial\n        {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0},  // stride_LSE_partial\n        static_cast<Element*>(params.o_ptr),\n        {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0},  // stride_O\n        static_cast<float*>(params.softmax_lse_ptr),\n        {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0},  // stride_LSE\n        params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.varlen_batch_idx_ptr, params.tile_count_semaphore\n    };\n\n    typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);\n    int num_blocks_k = cute::ceil_div(params.dv, kBlockK);\n    int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM);\n    dim3 grid_m(num_blocks_m, num_blocks_k, params.b);\n    auto kernel = cutlass::device_kernel<CombineKernel>;\n    int smem_size = CombineKernel::SharedStorageSize;\n    if (smem_size >= 48 * 1024) {\n        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n    }\n    // kernel<<<grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream>>>(kernel_params);\n    CHECK_CUTLASS(cutlass::kernel_launch<CombineKernel>(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/));\n}\n\ntemplate<typename T, typename Tpartial, int kBlockK>\nvoid run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {\n    // We want kBlockM to be as small as possible to maximize parallelism.\n    // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).\n    static_assert(kBlockK % 32 == 0, \"kBlockK must be a multiple of 32\");\n    static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 
16 : 32);\n    ARCH_SWITCH(params.arch, Arch, [&] {\n        BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] {\n            if constexpr (kBlockM >= 16) {  // If kBlockM == 8 then the minimum number of splits is 32.\n                if (params.num_splits <= 16) {\n                    run_flash_fwd_combine<Arch, kBlockM, kBlockK, 4, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);\n                    return;\n                }\n            }\n            if (params.num_splits <= 32) {\n                run_flash_fwd_combine<Arch, kBlockM, kBlockK, 5, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);\n            } else if (params.num_splits <= 64) {\n                run_flash_fwd_combine<Arch, kBlockM, kBlockK, 6, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);\n            } else if (params.num_splits <= 128) {\n                run_flash_fwd_combine<Arch, kBlockM, kBlockK, 7, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);\n            } else {\n                run_flash_fwd_combine<Arch, kBlockM, kBlockK, 8, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);\n            }\n        });\n    });\n}\n"
  },
  {
    "path": "hopper/flash_fwd_kernel_sm80.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cute/tensor.hpp\"\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n#include <cutlass/kernel_hardware_info.h>\n\n#include \"seqlen.h\"\n#include \"utils.h\"\n#include \"softmax.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>\nclass FlashAttnFwdSm80 {\n\npublic:\n\n    // Type Aliases\n    using CollectiveMainloop = CollectiveMainloop_;\n    using CollectiveEpilogue = CollectiveEpilogue_;\n    static constexpr bool Is_causal = CollectiveMainloop::Is_causal;\n    static constexpr bool Is_local = CollectiveMainloop::Is_local;\n    static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);\n    static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;\n    static constexpr bool Varlen = CollectiveMainloop::Varlen;\n    static constexpr bool PagedKV = CollectiveMainloop::PagedKV;\n    static constexpr bool Split = CollectiveMainloop::Split;\n    static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;\n    static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;\n    static constexpr bool AppendKV = CollectiveMainloop::AppendKV;\n    static constexpr bool PackGQA = CollectiveMainloop::PackGQA;\n    static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;\n    using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;\n\n    // Mainloop derived types\n    using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;\n    using TiledMma = typename CollectiveMainloop::TiledMma;\n    using ArchTag = typename CollectiveMainloop::ArchTag;\n    using MainloopArguments = typename CollectiveMainloop::Arguments;\n    using 
MainloopParams = typename CollectiveMainloop::Params;\n\n    // Epilogue derived types\n    using EpilogueArguments = typename CollectiveEpilogue::Arguments;\n    using EpilogueParams = typename CollectiveEpilogue::Params;\n\n    static_assert(ArchTag::kMinComputeCapability >= 80);\n\n    using TileScheduler = TileScheduler_;\n    using TileSchedulerArguments = typename flash::TileSchedulerArguments;\n    using TileSchedulerParams = typename TileScheduler::Params;\n\n    static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{}));\n    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));\n    static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1;\n\n    // Kernel level shared memory storage\n    // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k and not smem_q\n    // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k).\n    static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage))\n        - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)))\n        - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)));\n    static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 
0 : mainloop_smem_padding_;\n    struct SharedStorage {\n        struct TensorStorage : cute::aligned_struct<128> {\n            union {\n                struct {\n                    cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;\n                    typename CollectiveMainloop::TensorStorage mainloop;\n                };\n                // We want smem_o to line up with the start of smem_v\n                typename CollectiveEpilogue::TensorStorage epilogue;\n            };\n        } tensors;\n\n        alignas(16) typename TileScheduler::SharedStorage smem_scheduler;\n\n    };\n\n    static constexpr int SharedStorageSize = sizeof(SharedStorage);\n\n    // Device side arguments\n    struct Arguments {\n        MainloopArguments mainloop{};\n        EpilogueArguments epilogue{};\n        cutlass::KernelHardwareInfo hw_info{};\n        TileSchedulerArguments scheduler{};\n    };\n\n    // Kernel entry point API\n    struct Params {\n        MainloopParams mainloop{};\n        EpilogueParams epilogue{};\n        cutlass::KernelHardwareInfo hw_info{};\n        TileSchedulerParams scheduler{};\n    };\n\n    //\n    // Methods\n    //\n\n    // Convert to underlying arguments. 
In this case, a simple copy for the aliased type.\n    static\n    Params\n    to_underlying_arguments(Arguments const& args) {\n        CUTLASS_TRACE_HOST(\"to_underlying_arguments():\");\n\n        // Get SM count if needed, otherwise use user supplied SM count\n        int sm_count = args.hw_info.sm_count;\n        if (sm_count <= 0) {\n            CUTLASS_TRACE_HOST(\"  WARNING: Arguments do not include a valid SM count.\\n\"\n                \"  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.\");\n            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);\n        }\n\n        CUTLASS_TRACE_HOST(\"to_underlying_arguments(): Setting persistent grid SM count to \" << sm_count);\n\n        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};\n        return {\n            CollectiveMainloop::to_underlying_arguments(args.mainloop),\n            CollectiveEpilogue::to_underlying_arguments(args.epilogue),\n            hw_info,\n            TileScheduler::to_underlying_arguments(args.scheduler)\n        };\n    }\n\n    // Computes the kernel launch grid shape based on runtime parameters\n    static dim3\n    get_grid_shape(Params const& params) {\n        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor);\n    }\n\n    static dim3\n    get_block_shape() {\n        return dim3(MaxThreadsPerBlock, 1, 1);\n    }\n\n    CUTLASS_DEVICE\n    void\n    operator()(Params const& params, char* smem_buf) {\n\n        static constexpr int kBlockM = get<0>(TileShape_MNK{});\n\n        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);\n\n        CollectiveMainloop mainloop;\n        CollectiveEpilogue epilogue;\n\n        TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));\n        // Initialize matmul objects.\n        
TiledMma tiled_mma;\n\n        scheduler.init_consumer();\n\n        int warp_idx = cutlass::canonical_warp_idx_sync();\n        CUTLASS_PRAGMA_NO_UNROLL\n        for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);\n             work_tile_info.is_valid(params.scheduler);\n             work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {\n            // Attention output (GEMM-II) accumulator.\n            Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{}));\n            float softmax_scale_log2 = params.mainloop.softmax_scale_log2;\n            // If there's tanh softcap, the scaling will be done before tanh.\n            auto block_coord = work_tile_info.get_block_coord(params.scheduler);\n            int const bidb = get<2>(block_coord);\n            if constexpr (Is_FP8 && !Has_softcap) {\n                int const bidh = get<1>(block_coord);\n                int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;\n                float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];\n                float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];\n                softmax_scale_log2 *= q_descale * k_descale;\n            }\n            flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 
0 : 8> softmax(softmax_scale_log2);\n\n            SeqlenInfo_t seqlen_info{\n                bidb,\n                get<0>(params.mainloop.shape_Q),\n                !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),\n                get<0>(params.mainloop.shape_K_new),\n                params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,\n                params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,\n                params.mainloop.seqlens_rotary\n            };\n            if constexpr (AppendKV) {\n                bool tile_new_valid = mainloop.store_kv_new(\n                    params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord);\n                if (tile_new_valid) { __syncthreads(); }\n            }\n            bool tile_valid = mainloop.mma(\n                params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord,\n                shared_storage);\n            scheduler.prefetch_next_work(params.scheduler, work_tile_info);\n            if (tile_valid) {\n                // if (threadIdx.x == 128) { printf(\"Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\\n\", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }\n                epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma,\n                               threadIdx.x, block_coord);\n            } else {\n                // Write 0 to gO and -inf to gLSE.\n                epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);\n            }\n        }\n\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/flash_fwd_kernel_sm90.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cute/tensor.hpp\"\n\n#include <cutlass/cutlass.h>\n#include <cutlass/arch/reg_reconfig.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n#include <cutlass/numeric_conversion.h>\n#include <cutlass/kernel_hardware_info.h>\n#include \"cutlass/pipeline/pipeline.hpp\"\n\n#include \"cutlass/arch/grid_dependency_control.h\"\n\n#include \"seqlen.h\"\n#include \"utils.h\"\n#include \"softmax.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>\nclass FlashAttnFwdSm90 {\n\npublic:\n\n    // Type Aliases\n    using CollectiveMainloop = CollectiveMainloop_;\n    using CollectiveEpilogue = CollectiveEpilogue_;\n    static constexpr bool Is_causal = CollectiveMainloop::Is_causal;\n    static constexpr bool Is_local = CollectiveMainloop::Is_local;\n    static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);\n    static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;\n    static constexpr bool Varlen = CollectiveMainloop::Varlen;\n    static constexpr bool Split = CollectiveMainloop::Split;\n    static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;\n    static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;\n    static constexpr bool AppendKV = CollectiveMainloop::AppendKV;\n    static constexpr bool HasQv = CollectiveMainloop::HasQv;\n    static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q;\n    static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV;\n    static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O;\n    static constexpr bool PackGQA = CollectiveMainloop::PackGQA;\n    static 
constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;\n    static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim;\n    static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV;\n    static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);\n    using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;\n\n    // Mainloop derived types\n    using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;\n    using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;\n    using ArchTag = typename CollectiveMainloop::ArchTag;\n    using ClusterShape = typename CollectiveMainloop::ClusterShape;\n    using MainloopArguments = typename CollectiveMainloop::Arguments;\n    using MainloopParams = typename CollectiveMainloop::Params;\n    using BarrierQ = std::conditional_t<Use_TMA_Q, cutlass::arch::ClusterTransactionBarrier, cutlass::arch::ClusterBarrier>;\n\n    // Epilogue derived types\n    using EpilogueArguments = typename CollectiveEpilogue::Arguments;\n    using EpilogueParams = typename CollectiveEpilogue::Params;\n\n    static_assert(ArchTag::kMinComputeCapability >= 90);\n\n    using TileScheduler = TileScheduler_;\n    using TileSchedulerArguments = typename flash::TileSchedulerArguments;\n    using TileSchedulerParams = typename TileScheduler::Params;\n\n    static constexpr uint32_t NumLoadWarpGroups = 1;\n    static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup;\n    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);\n    static constexpr uint32_t MinBlocksPerMultiprocessor = 1;\n    static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);\n\n    /// Register requirement for Load and Math WGs\n    // If we use cp.async to load K and V, we need more registers for the producer WG.\n    static 
constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);\n    static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);\n    // If you want to print from the producer warp, you need to increase the number of registers;\n    // otherwise you'll get a CUDA error.\n    // static constexpr uint32_t LoadRegisterRequirement = 40;\n    // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;\n\n    // Kernel level shared memory storage\n    // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v\n    // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v).\n    static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)));\n    static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 
0 : mainloop_smem_padding_;\n    struct SharedStorage {\n        struct TensorStorage : cute::aligned_struct<128, _1> {\n            union {\n                struct {\n                    cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;\n                    typename CollectiveMainloop::TensorStorage mainloop;\n                };\n                // We want smem_o to line up with the start of smem_v\n                typename CollectiveEpilogue::TensorStorage epilogue;\n            };\n        } tensors;\n        struct PipelineStorage : cute::aligned_struct<16, _1> {\n            alignas(16) BarrierQ barrier_Q;\n            alignas(16) BarrierQ barrier_Qv;\n            alignas(16) cutlass::arch::ClusterBarrier barrier_O;\n            alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;\n            alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;\n            alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;\n            alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new;\n            alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new;\n            alignas(16) typename TileScheduler::SharedStorage smem_scheduler;\n        } pipelines;\n\n    };\n\n    static constexpr int SharedStorageSize = sizeof(SharedStorage);\n\n    // Device side arguments\n    struct Arguments {\n        MainloopArguments mainloop{};\n        EpilogueArguments epilogue{};\n        cutlass::KernelHardwareInfo hw_info{};\n        TileSchedulerArguments scheduler{};\n    };\n\n    // Kernel entry point API\n    struct Params {\n        MainloopParams mainloop{};\n        EpilogueParams epilogue{};\n        cutlass::KernelHardwareInfo hw_info{};\n        TileSchedulerParams scheduler{};\n    };\n\n    //\n    // Methods\n    //\n\n    // Convert to underlying arguments. 
In this case, a simple copy for the aliased type.\n    static\n    Params\n    to_underlying_arguments(Arguments const& args) {\n        CUTLASS_TRACE_HOST(\"to_underlying_arguments():\");\n\n        // Get SM count if needed, otherwise use user supplied SM count\n        int sm_count = args.hw_info.sm_count;\n        if (sm_count <= 0) {\n            CUTLASS_TRACE_HOST(\"  WARNING: Arguments do not include a valid SM count.\\n\"\n                \"  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.\");\n            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);\n        }\n\n        CUTLASS_TRACE_HOST(\"to_underlying_arguments(): Setting persistent grid SM count to \" << sm_count);\n\n        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};\n        return {\n            CollectiveMainloop::to_underlying_arguments(args.mainloop),\n            CollectiveEpilogue::to_underlying_arguments(args.epilogue),\n            hw_info,\n            TileScheduler::to_underlying_arguments(args.scheduler)\n        };\n    }\n\n    // Computes the kernel launch grid shape based on runtime parameters\n    static dim3\n    get_grid_shape(Params const& params) {\n        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);\n    }\n\n    static dim3\n    get_block_shape() {\n        return dim3(MaxThreadsPerBlock, 1, 1);\n    }\n\n    CUTLASS_DEVICE\n    void\n    operator()(Params const& params, char* smem_buf) {\n\n        static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;\n        static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;\n        static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});\n\n        using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;\n        using MainloopPipelineV = typename 
CollectiveMainloop::MainloopPipelineV;\n        using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;\n        using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew;\n        using PipelineState = typename CollectiveMainloop::PipelineState;\n        using PipelineParamsK = typename MainloopPipelineK::Params;\n        using PipelineParamsV = typename MainloopPipelineV::Params;\n        using PipelineParamsVt = typename MainloopPipelineVt::Params;\n        using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params;\n\n        SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);\n\n        int const lane_predicate = cute::elect_one_sync();\n        int const warp_idx = cutlass::canonical_warp_idx_sync();\n\n        // Issue Tma Descriptor Prefetch from a single thread\n        if (warp_idx == 0 && lane_predicate) {\n            CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);\n            CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);\n        }\n\n        // Obtain warp index\n        int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;\n        int warp_group_idx = cutlass::canonical_warp_group_idx();\n\n        if (warp_idx == 0 && lane_predicate) {\n            shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);\n            if constexpr (HasQv) {\n                shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);\n            }\n            shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/);\n        }\n\n        // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();\n        PipelineParamsK pipeline_params_k;\n        pipeline_params_k.role = warp_group_idx == 0\n            ? 
MainloopPipelineK::ThreadCategory::Producer\n            : MainloopPipelineK::ThreadCategory::Consumer;\n        if constexpr (Use_TMA_KV) {\n            pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;\n            pipeline_params_k.is_leader = warp_group_thread_idx == 0;\n            pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;\n        } else {\n            pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;\n            pipeline_params_k.producer_arv_count = NumProducerThreads;\n        }\n\n        static_assert(is_same_v<PipelineParamsK, PipelineParamsVt>);\n        PipelineParamsVt pipeline_params_vt = pipeline_params_k;\n        if constexpr (Use_TMA_KV && !SameHeadDim) {\n            pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;\n            if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; }\n        } else {\n            if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; }\n        }\n\n        MainloopPipelineK pipeline_k = [&] {\n            if constexpr (Use_TMA_KV) {\n                return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});\n            } else {\n                return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k);\n            }\n        }();\n        // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});\n        MainloopPipelineV pipeline_v = [&] {\n            if constexpr (!Transpose_V) {\n                static_assert(is_same_v<PipelineParamsK, PipelineParamsV>);\n                if constexpr (Use_TMA_KV) {\n                    return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{});\n                } else {\n                    
return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt);\n                }\n            } else {\n                PipelineParamsV pipeline_params_v;\n                pipeline_params_v.role = warp_group_idx == 0\n                    ? MainloopPipelineV::ThreadCategory::Producer\n                    : MainloopPipelineV::ThreadCategory::Consumer;\n                pipeline_params_v.producer_arv_count = NumProducerThreads;\n                pipeline_params_v.consumer_arv_count = NumMmaThreads;\n                return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v);\n            }\n        }();\n        // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then\n        // the producer WG will read from pipeline_vt and write to pipeline_v.\n        // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used.\n        // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.\n        // However, the thread role isn't used in the pipeline implementation.\n        MainloopPipelineVt pipeline_vt = [&] {\n            if constexpr (Use_TMA_KV) {\n                pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG\n                return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{});\n            } else {\n                pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG\n                return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt);\n            }\n        }();\n\n        PipelineParamsKVNew pipeline_params_kv_new;\n        pipeline_params_kv_new.role = warp_group_idx == 0\n            ? 
MainloopPipelineKVNew::ThreadCategory::Producer\n            : MainloopPipelineKVNew::ThreadCategory::Consumer;\n        pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;\n        pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0;\n        pipeline_params_kv_new.num_consumers = NumMmaThreads;\n        auto pipeline_k_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr);\n        if constexpr (!SameHeadDim) {\n            pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;\n        }\n        auto pipeline_v_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr);\n\n        CollectiveMainloop mainloop;\n        CollectiveEpilogue epilogue;\n\n        // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster\n        if constexpr (size(ClusterShape{}) > 1) {\n            cute::cluster_arrive_relaxed();\n            cute::cluster_wait();\n        } else {\n            __syncthreads();\n        }\n\n        TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));\n\n        if (warp_group_idx == 0) {  // Producer\n            cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();\n\n            // The pipelines for AppendKV and main attention are different, since e.g. main attention\n            // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load\n            // KV_new. 
Since the pipeline states are different, we have to manually sync to make\n            // sure the two pipelines don't race when accessing smem_k and smem_v.\n            PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();\n            PipelineState smem_pipe_write_new = cutlass::make_producer_start_state<MainloopPipelineKVNew>();\n            int work_idx = 0;\n            int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);\n            static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;\n            if constexpr (SingleProducerWarp) {\n                if (warp_idx_in_warpgroup != 0) { return; }\n            }\n            if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); }\n\n            cutlass::arch::wait_on_dependent_grids();\n\n            // Load Q, K, V\n            for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);\n                 work_tile_info.is_valid(params.scheduler);\n                 work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {\n\n                auto block_coord = work_tile_info.get_block_coord(params.scheduler);\n                SeqlenInfo_t seqlen_info{\n                    get<2>(block_coord) /*bidb*/,\n                    get<0>(params.mainloop.shape_Q),\n                    !params.mainloop.ptr_pagetable ? 
size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),\n                    get<0>(params.mainloop.shape_K_new),\n                    params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,\n                    params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,\n                    params.mainloop.seqlens_rotary\n                };\n                if constexpr (AppendKV) {\n                    bool tile_new_valid = mainloop.load_kv_new(\n                        params.mainloop, pipeline_k_new, pipeline_v_new,\n                        smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx);\n                    if (tile_new_valid) {\n                        // if (threadIdx.x == 0) { printf(\"Producer: Before sync\\n\"); }\n                        cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);\n                        // if (threadIdx.x == 0) { printf(\"Producer: After sync\\n\"); }\n                    }\n                }\n                auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {\n                    scheduler.prefetch_next_work(params.scheduler, work_tile_info);\n                };\n                // pipeline_vt won't be used if we don't need to transpose V.\n                mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write,\n                                         shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx);\n            }\n            mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx);\n        } else {  // Consumer\n            cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();\n\n            // Initialize matmul objects.\n            TiledMmaPV tiled_mma_pv;\n\n            PipelineState 
smem_pipe_read;\n            PipelineState smem_pipe_read_new;\n            // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v\n            // (like in Cutlass's gemm) because the read and release pipeline states are always the same.\n\n            scheduler.init_consumer();\n            mainloop.mma_init();\n\n            int work_idx = 0;\n            CUTLASS_PRAGMA_NO_UNROLL\n            for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);\n                 work_tile_info.is_valid(params.scheduler);\n                 // get_next_work will be called before the epilogue\n                 ) {\n                auto block_coord = work_tile_info.get_block_coord(params.scheduler);\n                int const bidb = get<2>(block_coord);\n                SeqlenInfo_t seqlen_info{\n                    bidb,\n                    get<0>(params.mainloop.shape_Q),\n                    !params.mainloop.ptr_pagetable ? 
size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),\n                    get<0>(params.mainloop.shape_K_new),\n                    params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,\n                    params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,\n                    params.mainloop.seqlens_rotary\n                };\n                if constexpr (AppendKV) {\n                    bool tile_new_valid = mainloop.store_kv_new(\n                        params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new,\n                        threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord);\n                    if (tile_new_valid) {\n                        // if (threadIdx.x == 128) { printf(\"Consumer: Before sync\\n\"); }\n                        // We need this sync so that the gmem write from the consumers is visible to the producer\n                        // that might do TMA read after that.\n                        asm volatile (\"fence.proxy.async.global;\");\n                        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);\n                        // arrive is enough, we don't need sync. 
The producer will sync, which means\n                        // after that sync we're guaranteed that the AppendKV pipeline has finished\n                        // loading and consuming smem_k and smem_v.\n                        // if (threadIdx.x == 128) { printf(\"Consumer: After sync\\n\"); }\n                    }\n                }\n                // If there's tanh softcap, the scaling will be done before tanh.\n                float softmax_scale_log2 = params.mainloop.softmax_scale_log2;\n                if constexpr (Is_FP8 && !Has_softcap) {\n                    int const bidh = get<1>(block_coord);\n                    int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;\n                    float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];\n                    float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];\n                    softmax_scale_log2 *= q_descale * k_descale;\n                }\n                flash::Softmax<!LargeHeadDimV ? 2 * (2 * kBlockM / NumMmaThreads) : 2, /*Max_offset=*/!Is_FP8 ? 
0 : 8> softmax(softmax_scale_log2);\n                // Attention output (GEMM-II) accumulator.\n                Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{}));\n                bool tile_valid;\n                if constexpr (!LargeHeadDimV) {\n                    tile_valid = mainloop.mma(\n                        params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,\n                        tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);\n                } else {  // mma_pv might not compile if !LargeHeadDimV\n                    if (warp_group_idx == 1) {\n                        tile_valid = mainloop.mma(\n                            params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,\n                            tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);\n                    } else {\n                        tile_valid = mainloop.mma_pv(\n                            params.mainloop, pipeline_v, smem_pipe_read,\n                            tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage);\n                    }\n                }\n                // Do this here before the epilogue so that the next tile is ready to go.\n                work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info);\n                if constexpr (Split && Varlen) {\n                    if (!work_tile_info.is_valid(params.scheduler)) {  // Last tile\n                        cutlass::arch::launch_dependent_grids();\n                    }\n                }\n                if (tile_valid) {\n                    // if (threadIdx.x == 128) { printf(\"Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\\n\", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }\n                    
epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv,\n                                   threadIdx.x - MmaThreadOffset, block_coord);\n                } else {\n                    // Write 0 to gO and -inf to gLSE.\n                    epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);\n                }\n            }\n            epilogue.store_tail();\n        }\n\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/flash_fwd_launch_template.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cute/tensor.hpp\"\n\n#include \"cutlass/cutlass.h\"\n#include \"cutlass/device_kernel.h\"  // For device_kernel\n#include <cutlass/kernel_hardware_info.h>\n#include \"cutlass/cluster_launch.hpp\"\n#include \"cutlass/kernel_launch.h\"\n\n#include \"cuda_check.h\"\n#include \"static_switch.h\"\n#include \"flash.h\"\n#include \"tile_size.h\"\n#include \"tile_scheduler.hpp\"\n#include \"flash_fwd_kernel_sm90.h\"\n#include \"flash_fwd_kernel_sm80.h\"\n#include \"mainloop_fwd_sm90_tma_gmma_ws.hpp\"\n#include \"mainloop_fwd_sm80.hpp\"\n#include \"epilogue_fwd.hpp\"\n\nusing namespace cute;\n\ntemplate <int Arch, int kHeadDim, int kHeadDimV, int ClusterM, typename Element, typename ElementOut,\n          bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKVNonTMA, bool AppendKV, bool HasQv,\n          bool PackGQA, bool Split, bool V_colmajor>\nvoid run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {\n    static_assert(!(Is_causal && Is_local), \"Causal and Local cannot be enabled at the same time\");\n    static_assert(!(AppendKV && V_colmajor), \"AppendKV and V_colmajor cannot be enabled at the same time\");\n    static_assert(!(AppendKV && !Varlen), \"AppendKV requires Varlen\");\n    static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;\n    static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;\n    using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;\n\n    // Can't use structured binding since it's not compatible with constexpr\n    static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = 
tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap);\n    static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);\n    static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);\n    static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);\n    static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);\n    static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap);\n    static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS);\n    static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS);\n    static constexpr bool Q_in_regs = Arch >= 90 ? 
false : std::get<4>(kBlockMN_kNWarps_Stages_RS);\n\n    using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;\n    using TileShape_MNK_PV = cute::Shape<Int<kBlockM>, Int<kHeadDimV>, Int<kBlockN>>;\n    using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;\n    using CollectiveMainloop = std::conditional_t<\n        Arch >= 90,\n        flash::CollectiveMainloopFwdSm90<kStages, ClusterShape, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, HasQv, MmaPV_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor>,\n        flash::CollectiveMainloopFwdSm80<kNWarps, kStages, Q_in_regs, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm80, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, PackGQA, Split>\n    >;\n    using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK_PV, ClusterShape, ElementOut, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, PackGQA, Split, FP8_TransposeV>;\n\n    static constexpr int NumProducerThreads = Arch >= 90 ? 
CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads;\n    static constexpr bool LPT = Is_causal || Is_local;\n    static constexpr bool Sort = !Is_local;\n    using SchedulerPersistent = std::conditional_t<Varlen,\n        flash::VarlenDynamicPersistentTileScheduler<kBlockM, kBlockN, CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/, LPT, Sort, true /*Prepared*/>,\n        std::conditional_t<!Is_causal && !Is_local,\n            flash::StaticPersistentTileScheduler<Split>,\n            flash::DynamicPersistentTileScheduler<CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>\n        >\n    >;\n    using SchedulerSingleTile = flash::SingleTileScheduler<Varlen, Split, PackGQA, kBlockM>;\n    // If Split then we probably don't have enough work for PersistentScheduler to be useful.\n    // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better\n    // since we'll avoid launching a bunch of thread blocks that immediately exit.\n    // On Sm80, noncausal persistent seems a bit slower.\n    static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split));\n    using Scheduler = std::conditional_t<!UsePersistentScheduler, SchedulerSingleTile, SchedulerPersistent>;\n    using AttnKernel = std::conditional_t<\n        Arch >= 90,\n        flash::enable_sm90<flash::FlashAttnFwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,\n        flash::enable_sm80_to_sm89<flash::FlashAttnFwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>\n    >;\n\n    bool const is_varlen_q = params.cu_seqlens_q;\n    bool const is_varlen_k = params.cu_seqlens_k;\n    bool const is_varlen_k_new = params.cu_seqlens_knew;\n    int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;\n    int batch_q = !is_varlen_q ? 
params.b : 1;\n    int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1;\n    typename CollectiveMainloop::StrideV v_strides =\n        cute::conditional_return<!V_colmajor>(\n            make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),\n            make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));\n    typename CollectiveMainloop::Arguments mainloop_args {\n        static_cast<Element const*>(params.q_ptr),\n        {seqlen_q, params.d, params.h, batch_q},  // shape_Q\n        {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0},  // stride_Q\n        static_cast<Element*>(params.k_ptr),\n        {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size,\n         params.d, params.h_k, !params.page_table ? batch_k : params.num_pages},  // shape_K\n        {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0},  // stride_K\n        static_cast<Element*>(params.v_ptr),\n        params.dv,  // headdim_v\n        v_strides,  // stride_V\n        static_cast<Element const*>(params.knew_ptr),\n        {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1},  // shape_K_new\n        {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0},  // stride_K_new\n        static_cast<Element const*>(params.vnew_ptr),\n        {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new\n        static_cast<Element const*>(params.qv_ptr),\n        {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? 
params.qv_batch_stride : 0},  // stride_Qv\n        static_cast<Element const*>(params.rotary_cos_ptr),\n        {params.seqlen_k, params.rotary_dim / 2},  // shape_rotary, the seqlen shape doesn't matter\n        {params.rotary_dim / 2, _1{}},  // stride_rotary_cos\n        static_cast<Element const*>(params.rotary_sin_ptr),\n        {params.rotary_dim / 2, _1{}},  // stride_rotary_sin\n        params.is_rotary_interleaved,\n        params.page_table,\n        // if page_size is not set, avoid dividing by zero\n        {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table\n        {params.page_table_batch_stride, _1{}},  // stride_page_table\n        params.scale_softmax,\n        params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr,\n        {params.q_descale_batch_stride, params.q_descale_head_stride},\n        {params.k_descale_batch_stride, params.k_descale_head_stride},\n        {params.v_descale_batch_stride, params.v_descale_head_stride},\n        params.window_size_left, params.window_size_right, params.attention_chunk,\n        params.softcap,\n        params.num_splits,\n        params.kv_batch_idx,\n        params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,\n        params.seqused_q, params.seqused_k,\n        params.leftpad_k, params.seqlens_rotary\n    };\n    typename CollectiveEpilogue::Arguments epilogue_args {\n        static_cast<ElementOut*>(params.o_ptr),\n        {seqlen_q, params.dv, params.h, batch_q, params.num_splits},  // shape_O\n        {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O\n        static_cast<float*>(params.oaccum_ptr),\n        {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? 
params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial\n        static_cast<float*>(params.softmax_lse_ptr),\n        {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0},  // stride_LSE\n        static_cast<float*>(params.softmax_lseaccum_ptr),\n        {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q},  // stride_LSE_partial\n        params.h_k,\n        params.cu_seqlens_q, params.seqused_q\n    };\n\n    int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k);\n    int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{}));\n    num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{}));\n    typename flash::TileSchedulerArguments scheduler_args {\n        num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits,\n        params.h / params.h_k,\n        params.seqlen_q,\n        params.seqlen_k, params.d, params.dv, sizeof(Element), \n        params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,\n        params.num_splits_dynamic_ptr,\n        params.num_m_blocks_ptr,\n        params.varlen_batch_idx_ptr,\n        params.num_nheads_in_l2_ptr\n    };\n\n    if (Varlen && !params.skip_scheduler_metadata_computation) {\n        prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 && params.prepare_varlen_pdl /*enable_pdl*/);\n        CHECK_CUDA_KERNEL_LAUNCH();\n    }\n\n    int device;\n    CHECK_CUDA(cudaGetDevice(&device));\n    typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({\n        mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args\n    });\n\n    dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);\n    dim3 block_dims = AttnKernel::get_block_shape();\n    int smem_size = AttnKernel::SharedStorageSize;\n    // int smem_size_q = sizeof(decltype((typename 
CollectiveMainloop::TensorStorage{}).smem_q));\n    // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));\n    // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));\n    // printf(\"smem_size = %d, q = %d, k = %d, v = %d\\n\", smem_size, smem_size_q, smem_size_k, smem_size_v);\n    // Get the ptr to kernel function.\n    if constexpr (size(ClusterShape{}) > 1) {\n        void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;\n        if (smem_size >= 48 * 1024) {\n            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n        }\n        dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));\n        cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};\n        CHECK_CUTLASS(cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params));\n    } else {\n        auto kernel = cutlass::device_kernel<AttnKernel>;\n        if (smem_size >= 48 * 1024) {\n            CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));\n        }\n        // kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);\n        CHECK_CUTLASS(cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params,\n                                           Arch >= 90 && Varlen && !params.skip_scheduler_metadata_computation && params.prepare_varlen_pdl /*launch_with_pdl*/));\n    }\n}\n\ntemplate<int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>\nvoid run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {\n    static_assert(sizeof(T) == 2 || sizeof(T) == 1, \"Only 16bit and 8bit are supported\");\n    static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, 
cutlass::float_e5m2_t>;\n    using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;\n    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {\n        VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] {\n            static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1;\n            VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {\n                // Only needed here to decide if we should use cluster\n                static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128;\n                static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen;\n                BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {\n                    static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256;\n                    APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {\n                        // Only use Cluster if number of tiles along seqlen_q is even and not varlen\n                        CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {\n                            static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;\n                            run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV && Varlen, HasQv, PackGQA, Split, V_colmajor>(params, stream);\n                        });\n                    });\n                });\n            });\n        });\n    });\n}\n"
  },
  {
    "path": "hopper/flash_prepare_scheduler.cu",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#include <cub/cub.cuh>\n#include \"cutlass/fast_math.h\"\n#include \"cutlass/barrier.h\"\n#include \"cutlass/arch/barrier.h\"\n\n#include \"cutlass/arch/grid_dependency_control.h\"\n\n#include \"flash.h\"\n\n#include \"static_switch.h\"\n\nnamespace flash {\n\n// Sort in descending order\ntemplate <typename T>\nstruct PrepareSortOp\n{\n    __device__ __forceinline__ bool operator()(T const & lhs, T const & rhs)\n    {\n        return lhs > rhs;\n    }\n};\n\ntemplate <>\nstruct PrepareSortOp<int2> {\n    __device__ __forceinline__ bool operator()(int2 const & lhs, int2 const & rhs) const {\n        return lhs.x > rhs.x;\n    }\n};\n\ntemplate <>\nstruct PrepareSortOp<int4> {\n    __device__ __forceinline__ bool operator()(int4 const & lhs, int4 const & rhs) const {\n        return lhs.x > rhs.x;\n    }\n};\n\ntemplate <int NumWarps, bool Sort>\n__global__ void prepare_varlen_num_blocks_kernel(\n        int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static,\n        int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new,\n        int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr,\n        int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static,\n        cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod,\n        int* const tile_count_semaphore,\n        int* const num_m_blocks_ptr,\n        int* const num_splits_dynamic_ptr,\n        int* const varlen_batch_idx_ptr,\n        // int* const num_n_blocks_ptr,\n        int* const num_nheads_in_l2_ptr,\n        bool enable_pdl,\n        bool is_causal,\n        bool packgqa,\n        int max_kvblocks_in_l2) 
{\n\n    static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1;\n    static constexpr int kSmemSize = 1;\n    static constexpr int BLOCK_DIM_X = NumWarps * 32;\n    static constexpr int ITEMS_PER_THREAD = 1;\n    static_assert(BLOCK_DIM_X * ITEMS_PER_THREAD == NumWarps * 32);\n    using BlockMergeSort = cub::BlockMergeSort<int4, BLOCK_DIM_X, ITEMS_PER_THREAD>;\n\n    __shared__ int total_blocks_smem[kSmemSize];\n\n    // Allocate shared memory for BlockMergeSort operations\n    __shared__ typename BlockMergeSort::TempStorage temp_storage;\n\n    if (enable_pdl) { cutlass::arch::launch_dependent_grids(); }\n\n    if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; }\n    __syncthreads();\n\n    if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; }\n\n    int lane = threadIdx.x % cutlass::NumThreadsPerWarp;\n\n    auto get_num_m_blocks = [&](int batch_idx) {\n        int seqlen;\n        if (seqused_q) {\n            seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0;\n        } else if (cu_seqlens_q) {\n            int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0;\n            int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);\n            seqlen = next_cu_seqlen - cur_cu_seqlen;\n        } else {\n            seqlen = seqlen_q_static;\n        }\n        if(packgqa) { seqlen *= qhead_per_khead; }\n        return batch_idx < num_batch && lane < kNumBatchPerWarp\n            ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0;\n    };\n\n    auto get_num_n_blocks = [&](int batch_idx) {\n        int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0;\n        int seqlen;\n        if (seqused_k) {\n            seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0;\n        } else if (cu_seqlens_k) {\n            int cur_cu_seqlen = batch_idx <= num_batch ? 
cu_seqlens_k[batch_idx] : 0;\n            int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);\n            seqlen = next_cu_seqlen - cur_cu_seqlen;\n        } else {\n            seqlen = seqlen_k_static;\n        }\n        int seqlen_new;\n        if (cu_seqlens_k_new) {\n            int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0;\n            int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1);\n            seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new;\n        } else {\n            seqlen_new = seqlen_k_new_static;\n        }\n        // if (threadIdx.x == 0) { printf(\"seqlen = %d, seqlen_new = %d, leftpad_k = %d\\n\", seqlen, seqlen_new, leftpad_k); }\n        seqlen = seqlen - leftpad_k + seqlen_new;\n        return batch_idx < num_batch && lane < kNumBatchPerWarp\n            ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0;\n    };\n\n    int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp;\n    int batch_cta_idx_offset = int(blockIdx.x) * 992;\n    int bidb_start = batch_cta_idx_offset + kNumBatchPerWarp * warp_idx;\n    int batch_idx = lane + bidb_start;\n    int num_m_blocks = get_num_m_blocks(batch_idx);\n    int num_n_blocks = get_num_n_blocks(batch_idx);\n\n    auto get_nheads_in_l2 = [&](int n_blocks) {\n        int nheads_in_l2 = n_blocks * 16 <= max_kvblocks_in_l2 ? 16\n            : n_blocks * 8 <= max_kvblocks_in_l2 ? 8\n            : n_blocks * 4 <= max_kvblocks_in_l2 ? 4\n            : n_blocks * 2 <= max_kvblocks_in_l2 ? 
2\n            : 1;\n        if(!packgqa) { nheads_in_l2 *= qhead_per_khead; }\n        return min(nheads_in_l2, num_head);\n    };\n    \n    int num_splits_dynamic;\n    if (int(gridDim.x) > 1 || num_splits_static == 1) {\n        // set num splits for all batches to 1 (note that user expects num_splits_static to mean upper bound on splits)\n        // for batch size > 992, we expect GPU occupancy to not be an issue except in degenerate cases (e.g., most are zero-length)\n        num_splits_dynamic = 1;\n    } else {\n        int total_blocks = num_m_blocks * num_n_blocks;\n        // Warp sum\n        #pragma unroll\n        for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) {\n            total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i);\n        }\n        if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); }\n        __syncthreads();\n        total_blocks = total_blocks_smem[0];\n        // 10% margin\n        int blocks_per_sm = static_cast<int>(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm)));\n        // blocks_per_sm = std::max(1, blocks_per_sm);  // 1 is the minimum number of blocks per SM\n        num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1);\n        // num_n_blocks per work tile for the batch\n        num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); \n    }\n\n    if constexpr (Sort) {\n        if(lane == kNumBatchPerWarp || batch_idx >= num_batch) {\n            num_n_blocks = INT_MIN; // sort last\n        } else if (is_causal) {\n            // sort by shortest member to process\n            num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor;\n        }\n        int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread\n        batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx);\n\n        // if (threadIdx.x == 0) {\n        //     
printf(\"Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\\n\", \n        //         batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w);\n        // } __syncthreads();\n\n        // Sort batches by num_n_blocks in descending order\n        BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp<int4>());\n\n        // if (threadIdx.x == 0) {\n        //     printf(\"Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\\n\", \n        //         batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w);\n        // } __syncthreads();\n\n        if (is_causal) {\n            // reset value to num_n_blocks\n            batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor);\n        }\n\n        // When sorting, we re-index some metadata by 'virtual batch index'\n        // and also store the vbidx -> bidx mapping.\n        // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx]\n        // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx]\n        // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx]\n        // 4. 
varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx      \n        batch_idx = batch_cta_idx_offset + threadIdx.x;\n        if (batch_idx < num_batch && threadIdx.x < 992) {\n            // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1);\n            if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(batch_coords[0].x, 1)); }\n            num_m_blocks_ptr[batch_idx] = batch_coords[0].y;\n            num_splits_dynamic_ptr[batch_idx] = batch_coords[0].z;\n            varlen_batch_idx_ptr[batch_idx] = batch_coords[0].w;\n        }  \n    } else {\n        if (batch_idx < num_batch && lane < kNumBatchPerWarp) {\n            // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1);\n            if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); }\n            num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic;\n            num_m_blocks_ptr[batch_idx] = num_m_blocks;\n            // printf(\"idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\\n\", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic);\n        }\n    }\n    \n}\n\n} // flash\n\nvoid prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa,\n                               int blockM, int blockN, bool enable_pdl) {\n    int qhead_per_khead = cutlass::ceil_div(params.h, params.h_k);\n    int num_warps = cutlass::ceil_div(params.b, 31); // warp switch will cap this at 32\n    int num_ctas = cutlass::ceil_div(params.b, 31 * 32);\n    // int const size_l2 = 50 * 1024 * 1024; // 50 MB\n    int const size_l2 = 8 * 1024 * 1024; // underestimate seems better in practice\n    int const element_size = params.is_e4m3 ? 
1 : 2;\n    int const size_one_kvblock = blockN * (params.d + params.dv) * element_size;\n    // printf(\"block size = %d, element size = %d, headdim = %d, headdim_v = %d, size 1 kblock = %d.\\n\", blockN, element_size, params.d, params.dv, size_one_kvblock);\n    int const max_kvblocks_in_l2 = size_l2 / size_one_kvblock;\n    BOOL_SWITCH(params.varlen_sort_batches, Sort, [&] {\n        NUM_WARP_SWITCH(num_warps, NumWarps, [&] {\n            flash::prepare_varlen_num_blocks_kernel<NumWarps, Sort><<<num_ctas /*grid*/, 32 * NumWarps /*block*/, 0, stream>>>(\n                params.seqlen_q, params.seqlen_k, params.seqlen_knew,\n                params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,\n                params.seqused_q, params.seqused_k, params.leftpad_k,\n                params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits,\n                cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN),\n                params.tile_count_semaphore,\n                params.num_m_blocks_ptr,\n                params.num_splits_dynamic_ptr,\n                params.varlen_batch_idx_ptr,\n                // params.num_n_blocks_ptr,\n                params.num_nheads_in_l2_ptr,\n                enable_pdl,\n                params.is_causal,\n                packgqa,\n                max_kvblocks_in_l2);\n        });\n    });\n}\n"
  },
  {
    "path": "hopper/generate_kernels.py",
    "content": "# Copied from Driss Guessous's PR in PyTorch: https://github.com/pytorch/pytorch/pull/105602\n\n# This file is run to generate the kernel instantiations for the flash_attn kernels\n# They are written to several files in order to speed up compilation\n\nimport argparse\nimport itertools\nfrom collections import namedtuple\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import List, Optional\n\nKERNEL_BATCH = namedtuple(\"Kernel\", [\"template\", \"filename\"])\n\nDTYPE_MAP = {\n    \"fp16\": \"cutlass::half_t\",\n    \"bf16\": \"cutlass::bfloat16_t\",\n    \"e4m3\": \"cutlass::float_e4m3_t\",\n}\n\nDTYPE_MAP_FWD_SM8x = {\n    \"fp16\": \"cutlass::half_t\",\n    \"bf16\": \"cutlass::bfloat16_t\",\n}\n\nDTYPE_MAP_BWD = {\n    \"fp16\": \"cutlass::half_t\",\n    \"bf16\": \"cutlass::bfloat16_t\",\n}\n\nSM = [80, 90]  # Sm kernels support up to\nHEAD_DIMENSIONS = [64, 96, 128, 192, 256]\nPAGEDKV = [False, True]\nSPLIT = [False, True]\nSOFTCAP = [False, True]\nPACKGQA = [False, True]\n\nKERNEL_IMPL_TEMPLATE_FWD_SM90 = \"\"\"#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM}\ntemplate void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n\"\"\"\n\nKERNEL_IMPL_TEMPLATE_FWD_SM8x = \"\"\"#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM}\ntemplate void run_mha_fwd_<80, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n\"\"\"\n\nKERNEL_IMPL_TEMPLATE_BWD_SM90 = \"\"\"#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM}\ntemplate<>\nvoid 
run_mha_bwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {SOFTCAP}>(Flash_bwd_params &params, cudaStream_t stream) {{\n    run_mha_bwd_hdim{HEAD_DIM}<{ARCH}, {DTYPE}, {SOFTCAP}>(params, stream);\n}}\n#endif\n\"\"\"\n\nKERNEL_IMPL_TEMPLATE_BWD_SM8x = \"\"\"#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM}\ntemplate<>\nvoid run_mha_bwd_<80, {DTYPE}, {HEAD_DIM}, {SOFTCAP}>(Flash_bwd_params &params, cudaStream_t stream) {{\n    run_mha_bwd_hdim{HEAD_DIM}<80, {DTYPE}, {SOFTCAP}>(params, stream);\n}}\ntemplate<>\nvoid run_mha_bwd_<86, {DTYPE}, {HEAD_DIM}, {SOFTCAP}>(Flash_bwd_params &params, cudaStream_t stream) {{\n    run_mha_bwd_hdim{HEAD_DIM}<86, {DTYPE}, {SOFTCAP}>(params, stream);\n}}\n#endif\n#endif\n\"\"\"\n\n\n\n@dataclass\nclass Kernel:\n    sm: int\n    dtype: str\n    head_dim: int\n    head_dim_v: int\n    split: bool\n    paged_kv: bool\n    softcap: bool\n    packgqa: bool\n    direction: str\n\n    @property\n    def template(self) -> str:\n        if self.direction == \"fwd\":\n            if self.sm == 90:\n                # Always enable PackGQA for PagedKV or Split to reduce compilation\n                packgqa = self.packgqa or self.paged_kv or self.split\n                return KERNEL_IMPL_TEMPLATE_FWD_SM90.format(\n                    ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype],\n                    HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v,\n                    SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(),\n                    SOFTCAP=str(self.softcap).lower(), PACKGQA=str(packgqa).lower()\n                )\n            else:\n                # Always enable PackGQA for Sm8x to reduce compilation\n                return KERNEL_IMPL_TEMPLATE_FWD_SM8x.format(\n                    DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v,\n                    SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(),\n         
           SOFTCAP=str(self.softcap).lower(), PACKGQA=str(True).lower()\n                )\n        elif self.direction == \"bwd\":\n            if self.sm == 90:\n                return KERNEL_IMPL_TEMPLATE_BWD_SM90.format(\n                    ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim,\n                    SOFTCAP=str(self.softcap).lower()\n                )\n            else:\n                return KERNEL_IMPL_TEMPLATE_BWD_SM8x.format(\n                    DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim,\n                    SOFTCAP=str(self.softcap).lower()\n                )\n\n    @property\n    def filename(self) -> str:\n        return f\"flash_{self.direction}_hdim{self.head_dim}{f'_{self.head_dim_v}' if self.head_dim_v != self.head_dim else ''}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu\"\n\n\ndef get_all_kernels() -> List[Kernel]:\n    for dtype, head_dim, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM):\n        # PackGQA is always enabled for Sm8x, PagedKV, or Split, so skip packgqa=True in those cases:\n        # the kernel is emitted with packgqa=False to keep `_packgqa` out of the filename.\n        if packgqa and (sm < 90 or (sm >= 90 and (paged_kv or split))):\n            continue\n        if sm >= 90 or dtype in DTYPE_MAP_FWD_SM8x:\n            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction=\"fwd\")\n        if sm == 90 and head_dim == 192:\n            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction=\"fwd\")\n        if sm == 90 and head_dim == 64 and dtype in [\"bf16\", \"fp16\"]:\n            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, 
head_dim_v=256, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction=\"fwd\")\n            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=512, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction=\"fwd\")\n    for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM):\n        yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction=\"bwd\")\n\n\ndef batch_hdim(kernels_all) -> List[KERNEL_BATCH]:\n    for dtype, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM):\n        if sm < 90:\n            continue\n        # Same hdim and hdimv\n        kernels = [k for k in kernels_all if k.direction == \"fwd\" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim == k.head_dim_v]\n        if len(kernels) > 0:\n            filename = f\"flash_fwd_hdimall_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu\"\n            template = \"\\n\".join([f\"#include \\\"{k.filename}\\\"\" for k in kernels])\n            yield KERNEL_BATCH(template, filename)\n        # Different hdim and hdimv\n        kernels = [k for k in kernels_all if k.direction == \"fwd\" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim != k.head_dim_v]\n        if len(kernels) > 0:\n            filename = f\"flash_fwd_hdimdiff_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu\"\n            template = \"\\n\".join([f\"#include \\\"{k.filename}\\\"\" for k in kernels])\n            yield 
KERNEL_BATCH(template, filename)\n\n\ndef batch_softcap(kernels_all) -> List[KERNEL_BATCH]:\n    for dtype, head_dim, split, paged_kv, packgqa, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, PACKGQA, SM):\n        if sm >= 90:\n            continue\n        kernels = [k for k in kernels_all if k.direction == \"fwd\" and k.dtype == dtype and k.head_dim == head_dim and k.split == split and k.paged_kv == paged_kv and k.packgqa == packgqa and k.sm == sm]\n        if len(kernels) > 0:\n            filename = f\"flash_fwd_hdim{head_dim}_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}_softcapall{'_packgqa' if packgqa else ''}_sm{sm}.cu\"\n            template = \"\\n\".join([f\"#include \\\"{k.filename}\\\"\" for k in kernels])\n            yield KERNEL_BATCH(template, filename)\n\n    # Bwd\n    for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):\n        if sm < 90:\n            continue\n        kernels = [k for k in kernels_all if k.direction == \"bwd\" and k.dtype == dtype and k.head_dim == head_dim and k.sm == sm]\n        if len(kernels) > 0:\n            filename = f\"flash_bwd_hdim{head_dim}_{dtype}_softcapall_sm{sm}.cu\"\n            template = \"\\n\".join([f\"#include \\\"{k.filename}\\\"\" for k in kernels])\n            yield KERNEL_BATCH(template, filename)\n\n\ndef write_kernel(kernel: Kernel, autogen_dir: Path) -> None:\n    prelude = \"\"\"// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. 
See \"generate_kernels.py\"\\n\n\"\"\"\n    (autogen_dir / kernel.filename).write_text(prelude + kernel.template)\n\n\ndef main(output_dir: Optional[str]) -> None:\n    output_dir = Path(output_dir) if output_dir is not None else Path(__file__).parent\n    output_dir.mkdir(parents=True, exist_ok=True)\n    kernels_all = list(get_all_kernels())\n    for kernel in kernels_all:\n        write_kernel(kernel, output_dir)\n    for kernel in batch_hdim(kernels_all):\n        write_kernel(kernel, output_dir)\n    for kernel in batch_softcap(kernels_all):\n        write_kernel(kernel, output_dir)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        prog=\"generate_kernels\",\n        description=\"Generate the flash_attention kernels template instantiations\",\n    )\n    # Set an optional output directory\n    parser.add_argument(\n        \"-o\",\n        \"--output_dir\",\n        default=\"instantiations\",\n        required=False,\n        help=\"Directory in which to write the generated kernel files \"\n        \"(defaults to 'instantiations', relative to the current directory)\",\n    )\n    args = parser.parse_args()\n    main(args.output_dir)\n"
  },
  {
    "path": "hopper/heuristics.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <vector>\n\ninline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {\n    // If varlen, we don't actually know seqlen_q but only max_seqlen_q.\n    if (varlen_q) return true;\n    // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM\n    auto round_up = [](int a, int b) { return (a + b - 1) / b * b; };\n    float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM));\n    float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM));\n    return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;\n};\n\n// Find the number of splits that maximizes the occupancy. For example, if we have\n// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is\n// better than having 3 splits (efficiency = 0.67). 
However, we also don't want too many\n// splits as that would incur more HBM reads/writes.\n// So we find the best efficiency, then find the smallest number of splits that gets 85%\n// of the best efficiency.\ninline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) {\n    // If we have enough to almost fill the SMs, then just use 1 split\n    // However, in the case of super long seqlen where each head of KV doesn't even fit into\n    // L2 (we assume that L2 size is 50MB), we want to split.\n    if (total_mblocks >= 0.8f * num_SMs) {\n        int const size_l2 = 50 * 1024 * 1024;\n        // Only split if there are enough queries to go over the KV at least twice\n        // Don't split if causal\n        if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) {\n            return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits);\n        } else {\n            return 1;\n        }\n    }\n    // If num_n_blocks is too small, use 1 split. 
For example, we never split for hdim = 128 and seqlen_k = 512.\n    if (num_n_blocks <= 4) { return 1; }\n    max_splits = std::min({max_splits, num_SMs, num_n_blocks});\n    float max_efficiency = 0.f;\n    std::vector<float> efficiency;\n    efficiency.reserve(max_splits);\n    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {\n        float n_waves = float(total_mblocks * num_splits) / num_SMs;\n        float eff = n_waves / ceil(n_waves);\n        // printf(\"num_splits = %d, eff = %f\\n\", num_splits, eff);\n        if (eff > max_efficiency) { max_efficiency = eff; }\n        efficiency.push_back(eff);\n    }\n    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {\n        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {\n            // printf(\"num_splits chosen = %d\\n\", num_splits);\n            return num_splits;\n        }\n    }\n    return 1;\n}\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim128_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<80, cutlass::bfloat16_t, false>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<86, cutlass::bfloat16_t, false>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim128_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<90, cutlass::bfloat16_t, false>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<80, cutlass::bfloat16_t, true>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<86, cutlass::bfloat16_t, true>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<90, cutlass::bfloat16_t, true>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_hdim128_bf16_sm90.cu\"\n#include \"flash_bwd_hdim128_bf16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim128_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<80, cutlass::half_t, false>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<86, cutlass::half_t, false>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim128_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<90, cutlass::half_t, false>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<80, cutlass::half_t, true>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<86, cutlass::half_t, true>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim128<90, cutlass::half_t, true>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_hdim128_fp16_sm90.cu\"\n#include \"flash_bwd_hdim128_fp16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim192_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<80, cutlass::bfloat16_t, false>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<86, cutlass::bfloat16_t, false>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim192_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<90, cutlass::bfloat16_t, false>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<80, cutlass::bfloat16_t, true>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<86, cutlass::bfloat16_t, true>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<90, cutlass::bfloat16_t, true>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_hdim192_bf16_sm90.cu\"\n#include \"flash_bwd_hdim192_bf16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim192_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<80, cutlass::half_t, false>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<86, cutlass::half_t, false>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim192_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<90, cutlass::half_t, false>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<80, cutlass::half_t, true>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<86, cutlass::half_t, true>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim192<90, cutlass::half_t, true>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_hdim192_fp16_sm90.cu\"\n#include \"flash_bwd_hdim192_fp16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim256_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<80, cutlass::bfloat16_t, false>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<86, cutlass::bfloat16_t, false>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim256_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<90, cutlass::bfloat16_t, false>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<80, cutlass::bfloat16_t, true>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<86, cutlass::bfloat16_t, true>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<90, cutlass::bfloat16_t, true>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_hdim256_bf16_sm90.cu\"\n#include \"flash_bwd_hdim256_bf16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim256_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<80, cutlass::half_t, false>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<86, cutlass::half_t, false>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim256_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<90, cutlass::half_t, false>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::half_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<80, cutlass::half_t, true>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::half_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<86, cutlass::half_t, true>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::half_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim256<90, cutlass::half_t, true>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim256_fp16_softcapall_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_hdim256_fp16_sm90.cu\"\n#include \"flash_bwd_hdim256_fp16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim64_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<80, cutlass::bfloat16_t, false>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<86, cutlass::bfloat16_t, false>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim64_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<90, cutlass::bfloat16_t, false>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<80, cutlass::bfloat16_t, true>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<86, cutlass::bfloat16_t, true>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<90, cutlass::bfloat16_t, true>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim64_bf16_softcapall_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_hdim64_bf16_sm90.cu\"\n#include \"flash_bwd_hdim64_bf16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim64_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<80, cutlass::half_t, false>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<86, cutlass::half_t, false>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim64_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<90, cutlass::half_t, false>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<80, cutlass::half_t, true>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<86, cutlass::half_t, true>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim64<90, cutlass::half_t, true>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim64_fp16_softcapall_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_hdim64_fp16_sm90.cu\"\n#include \"flash_bwd_hdim64_fp16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim96_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<80, cutlass::bfloat16_t, false>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<86, cutlass::bfloat16_t, false>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim96_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<90, cutlass::bfloat16_t, false>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<80, cutlass::bfloat16_t, true>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<86, cutlass::bfloat16_t, true>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<90, cutlass::bfloat16_t, true>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim96_bf16_softcapall_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_hdim96_bf16_sm90.cu\"\n#include \"flash_bwd_hdim96_bf16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim96_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<80, cutlass::half_t, false>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<86, cutlass::half_t, false>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim96_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<90, cutlass::half_t, false>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate<>\nvoid run_mha_bwd_<80, cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<80, cutlass::half_t, true>(params, stream);\n}\ntemplate<>\nvoid run_mha_bwd_<86, cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<86, cutlass::half_t, true>(params, stream);\n}\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate<>\nvoid run_mha_bwd_<90, cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {\n    run_mha_bwd_hdim96<90, cutlass::half_t, true>(params, stream);\n}\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_bwd_hdim96_fp16_softcapall_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_bwd_hdim96_fp16_sm90.cu\"\n#include \"flash_bwd_hdim96_fp16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim128_bf16_paged_sm80.cu\"\n#include \"flash_fwd_hdim128_bf16_paged_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim128_bf16_paged_split_sm80.cu\"\n#include \"flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_sm100.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<100, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim128_bf16_sm80.cu\"\n#include \"flash_fwd_hdim128_bf16_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_bf16_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim128_bf16_split_sm80.cu\"\n#include \"flash_fwd_hdim128_bf16_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim128_fp16_paged_sm80.cu\"\n#include \"flash_fwd_hdim128_fp16_paged_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim128_fp16_paged_split_sm80.cu\"\n#include \"flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim128_fp16_sm80.cu\"\n#include \"flash_fwd_hdim128_fp16_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM128\ntemplate void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim128_fp16_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim128_fp16_split_sm80.cu\"\n#include \"flash_fwd_hdim128_fp16_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_bf16_paged_sm80.cu\"\n#include \"flash_fwd_hdim192_bf16_paged_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_bf16_paged_split_sm80.cu\"\n#include \"flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_bf16_sm80.cu\"\n#include \"flash_fwd_hdim192_bf16_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_bf16_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_bf16_split_sm80.cu\"\n#include \"flash_fwd_hdim192_bf16_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_fp16_paged_sm80.cu\"\n#include \"flash_fwd_hdim192_fp16_paged_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_fp16_paged_split_sm80.cu\"\n#include \"flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_fp16_sm80.cu\"\n#include \"flash_fwd_hdim192_fp16_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM192\ntemplate void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim192_fp16_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_fp16_split_sm80.cu\"\n#include \"flash_fwd_hdim192_fp16_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim256_bf16_paged_sm80.cu\"\n#include \"flash_fwd_hdim256_bf16_paged_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim256_bf16_paged_split_sm80.cu\"\n#include \"flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim256_bf16_sm80.cu\"\n#include \"flash_fwd_hdim256_bf16_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_bf16_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim256_bf16_split_sm80.cu\"\n#include \"flash_fwd_hdim256_bf16_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim256_fp16_paged_sm80.cu\"\n#include \"flash_fwd_hdim256_fp16_paged_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim256_fp16_paged_split_sm80.cu\"\n#include \"flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim256_fp16_sm80.cu\"\n#include \"flash_fwd_hdim256_fp16_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM256\ntemplate void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim256_fp16_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim256_fp16_split_sm80.cu\"\n#include \"flash_fwd_hdim256_fp16_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_paged_sm80.cu\"\n#include \"flash_fwd_hdim64_bf16_paged_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_paged_split_sm80.cu\"\n#include \"flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_sm80.cu\"\n#include \"flash_fwd_hdim64_bf16_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_bf16_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_split_sm80.cu\"\n#include \"flash_fwd_hdim64_bf16_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_paged_sm80.cu\"\n#include \"flash_fwd_hdim64_fp16_paged_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_paged_split_sm80.cu\"\n#include \"flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_sm80.cu\"\n#include \"flash_fwd_hdim64_fp16_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM64\ntemplate void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim64_fp16_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_split_sm80.cu\"\n#include \"flash_fwd_hdim64_fp16_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim96_bf16_paged_sm80.cu\"\n#include \"flash_fwd_hdim96_bf16_paged_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim96_bf16_paged_split_sm80.cu\"\n#include \"flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim96_bf16_sm80.cu\"\n#include \"flash_fwd_hdim96_bf16_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_bf16_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim96_bf16_split_sm80.cu\"\n#include \"flash_fwd_hdim96_bf16_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim96_fp16_paged_sm80.cu\"\n#include \"flash_fwd_hdim96_fp16_paged_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim96_fp16_paged_split_sm80.cu\"\n#include \"flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim96_fp16_sm80.cu\"\n#include \"flash_fwd_hdim96_fp16_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_SM8x\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\ntemplate void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_launch_template.h\"\n\n#ifndef FLASHATTENTION_DISABLE_HDIM96\ntemplate void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);\n#endif\n"
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdim96_fp16_split_softcapall_sm80.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim96_fp16_split_sm80.cu\"\n#include \"flash_fwd_hdim96_fp16_split_softcap_sm80.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim96_bf16_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim128_bf16_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim192_bf16_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim256_bf16_packgqa_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_paged_sm90.cu\"\n#include \"flash_fwd_hdim96_bf16_paged_sm90.cu\"\n#include \"flash_fwd_hdim128_bf16_paged_sm90.cu\"\n#include \"flash_fwd_hdim192_bf16_paged_sm90.cu\"\n#include \"flash_fwd_hdim256_bf16_paged_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim96_bf16_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim128_bf16_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_bf16_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim256_bf16_paged_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim96_bf16_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim128_bf16_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim192_bf16_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim256_bf16_paged_split_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_sm90.cu\"\n#include \"flash_fwd_hdim96_bf16_sm90.cu\"\n#include \"flash_fwd_hdim128_bf16_sm90.cu\"\n#include \"flash_fwd_hdim192_bf16_sm90.cu\"\n#include \"flash_fwd_hdim256_bf16_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_softcap_sm90.cu\"\n#include \"flash_fwd_hdim96_bf16_softcap_sm90.cu\"\n#include \"flash_fwd_hdim128_bf16_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_bf16_softcap_sm90.cu\"\n#include \"flash_fwd_hdim256_bf16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_split_sm90.cu\"\n#include \"flash_fwd_hdim96_bf16_split_sm90.cu\"\n#include \"flash_fwd_hdim128_bf16_split_sm90.cu\"\n#include \"flash_fwd_hdim192_bf16_split_sm90.cu\"\n#include \"flash_fwd_hdim256_bf16_split_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_bf16_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim96_bf16_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim128_bf16_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_bf16_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim256_bf16_split_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_e4m3_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim96_e4m3_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim128_e4m3_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim192_e4m3_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim256_e4m3_packgqa_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_e4m3_paged_sm90.cu\"\n#include \"flash_fwd_hdim96_e4m3_paged_sm90.cu\"\n#include \"flash_fwd_hdim128_e4m3_paged_sm90.cu\"\n#include \"flash_fwd_hdim192_e4m3_paged_sm90.cu\"\n#include \"flash_fwd_hdim256_e4m3_paged_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_e4m3_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim96_e4m3_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim128_e4m3_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim192_e4m3_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim256_e4m3_paged_split_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_e4m3_sm90.cu\"\n#include \"flash_fwd_hdim96_e4m3_sm90.cu\"\n#include \"flash_fwd_hdim128_e4m3_sm90.cu\"\n#include \"flash_fwd_hdim192_e4m3_sm90.cu\"\n#include \"flash_fwd_hdim256_e4m3_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_e4m3_softcap_sm90.cu\"\n#include \"flash_fwd_hdim96_e4m3_softcap_sm90.cu\"\n#include \"flash_fwd_hdim128_e4m3_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_e4m3_softcap_sm90.cu\"\n#include \"flash_fwd_hdim256_e4m3_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_e4m3_split_sm90.cu\"\n#include \"flash_fwd_hdim96_e4m3_split_sm90.cu\"\n#include \"flash_fwd_hdim128_e4m3_split_sm90.cu\"\n#include \"flash_fwd_hdim192_e4m3_split_sm90.cu\"\n#include \"flash_fwd_hdim256_e4m3_split_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_e4m3_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim96_e4m3_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim128_e4m3_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_e4m3_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim256_e4m3_split_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim96_fp16_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim128_fp16_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim192_fp16_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim256_fp16_packgqa_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_paged_sm90.cu\"\n#include \"flash_fwd_hdim96_fp16_paged_sm90.cu\"\n#include \"flash_fwd_hdim128_fp16_paged_sm90.cu\"\n#include \"flash_fwd_hdim192_fp16_paged_sm90.cu\"\n#include \"flash_fwd_hdim256_fp16_paged_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim96_fp16_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim128_fp16_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_fp16_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim256_fp16_paged_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim96_fp16_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim128_fp16_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim192_fp16_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim256_fp16_paged_split_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_sm90.cu\"\n#include \"flash_fwd_hdim96_fp16_sm90.cu\"\n#include \"flash_fwd_hdim128_fp16_sm90.cu\"\n#include \"flash_fwd_hdim192_fp16_sm90.cu\"\n#include \"flash_fwd_hdim256_fp16_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_softcap_sm90.cu\"\n#include \"flash_fwd_hdim96_fp16_softcap_sm90.cu\"\n#include \"flash_fwd_hdim128_fp16_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_fp16_softcap_sm90.cu\"\n#include \"flash_fwd_hdim256_fp16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_split_sm90.cu\"\n#include \"flash_fwd_hdim96_fp16_split_sm90.cu\"\n#include \"flash_fwd_hdim128_fp16_split_sm90.cu\"\n#include \"flash_fwd_hdim192_fp16_split_sm90.cu\"\n#include \"flash_fwd_hdim256_fp16_split_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_fp16_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim96_fp16_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim128_fp16_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_fp16_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim256_fp16_split_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_bf16_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim64_512_bf16_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim192_128_bf16_packgqa_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_bf16_paged_sm90.cu\"\n#include \"flash_fwd_hdim64_512_bf16_paged_sm90.cu\"\n#include \"flash_fwd_hdim192_128_bf16_paged_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_bf16_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim64_512_bf16_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim192_128_bf16_paged_split_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_bf16_sm90.cu\"\n#include \"flash_fwd_hdim64_512_bf16_sm90.cu\"\n#include \"flash_fwd_hdim192_128_bf16_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_bf16_softcap_sm90.cu\"\n#include \"flash_fwd_hdim64_512_bf16_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_128_bf16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_bf16_split_sm90.cu\"\n#include \"flash_fwd_hdim64_512_bf16_split_sm90.cu\"\n#include \"flash_fwd_hdim192_128_bf16_split_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_128_e4m3_paged_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_128_e4m3_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_128_e4m3_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_128_e4m3_split_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_fp16_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim64_512_fp16_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim192_128_fp16_packgqa_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_fp16_paged_sm90.cu\"\n#include \"flash_fwd_hdim64_512_fp16_paged_sm90.cu\"\n#include \"flash_fwd_hdim192_128_fp16_paged_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_fp16_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim64_512_fp16_paged_split_sm90.cu\"\n#include \"flash_fwd_hdim192_128_fp16_paged_split_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_fp16_sm90.cu\"\n#include \"flash_fwd_hdim64_512_fp16_sm90.cu\"\n#include \"flash_fwd_hdim192_128_fp16_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu\"\n#include \"flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_fp16_softcap_sm90.cu\"\n#include \"flash_fwd_hdim64_512_fp16_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_128_fp16_softcap_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_fp16_split_sm90.cu\"\n#include \"flash_fwd_hdim64_512_fp16_split_sm90.cu\"\n#include \"flash_fwd_hdim192_128_fp16_split_sm90.cu\""
  },
  {
    "path": "hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu",
    "content": "// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n// Splitting the different template instantiations to different files to speed up compilation.\n// This file is auto-generated. See \"generate_kernels.py\"\n\n#include \"flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu\"\n#include \"flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu\""
  },
  {
    "path": "hopper/mainloop_bwd_sm80.hpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n#include <cutlass/numeric_conversion.h>\n\n#include \"cute/tensor.hpp\"\n\n#include \"seqlen.h\"\n#include \"mask.h\"\n#include \"mask.h\"\n#include \"softmax.h\"\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <int Stages, int Stages_dO, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,\n        bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool Deterministic,\n        bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,\n        int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=8, int AtomLayoutMdQ=1,\n        bool V_in_regs=false>\nstruct CollectiveMainloopBwdSm80 {\n\n    static constexpr int kStages = Stages;\n    static constexpr int kStages_dO = Stages_dO;\n    static_assert(kStages >= kStages_dO);\n    using TileShape_MNK = TileShape_MNK_;\n    using Element = Element_;\n    using ElementAccum = ElementAccum_;\n    using ArchTag = ArchTag_;\n    static constexpr bool Is_causal = Is_causal_;\n    static constexpr bool Is_local = Is_local_;\n    static constexpr bool Has_softcap = Has_softcap_;\n    static constexpr bool Varlen = Varlen_;\n    static constexpr int NumMmaWarps = NumMmaWarpGroups * cutlass::NumWarpsPerWarpGroup;\n\n    static constexpr bool SdP_swapAB = SdP_swapAB_;\n    static constexpr bool dKV_swapAB = dKV_swapAB_;\n    static constexpr bool dQ_swapAB = dQ_swapAB_;\n\n    static constexpr bool Q_dO_same_stages = kStages == kStages_dO;\n\n    static constexpr int kBlockM = get<0>(TileShape_MNK{});\n    static constexpr int kBlockN = get<1>(TileShape_MNK{});\n    static constexpr int kHeadDim = get<2>(TileShape_MNK{});\n\n    using 
SeqlenInfo_t = flash::SeqlenInfoQK<Varlen, kBlockM>;\n    using BlockMN_t = flash::BlockMN<SeqlenInfo_t, kBlockM, kBlockN, Is_causal, Is_local>;\n\n    static_assert(ArchTag::kMinComputeCapability >= 80);\n\n    static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;\n\n    static constexpr int NumMmaThreads = NumMmaWarps * cutlass::NumThreadsPerWarp;\n    static constexpr int NumProducerThreads = NumMmaThreads;  // For compatibility with TileScheduler\n\n    using MMA_Atom_Arch = std::conditional_t<\n        ArchTag::kMinComputeCapability >= 80,\n        std::conditional_t<\n            std::is_same_v<Element, cutlass::half_t>,\n            MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,\n            MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>\n        >,\n        MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>\n    >;\n\n    static_assert(NumMmaWarps % AtomLayoutMSdP == 0);\n    static_assert(NumMmaWarps % AtomLayoutNdKV == 0);\n    static_assert(NumMmaWarps % AtomLayoutMdQ == 0);\n    static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarps && SdP_swapAB && !dKV_swapAB;\n    static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarps && AtomLayoutMdQ == NumMmaWarps && !SdP_swapAB && !dQ_swapAB;  // If dQ_swapAB we can't use RS\n\n    using AtomLayoutSdP = std::conditional_t<\n        !SdP_swapAB,\n        Layout<Shape<Int<AtomLayoutMSdP>, Int<NumMmaWarps / AtomLayoutMSdP>, _1>>,\n        Layout<Shape<Int<NumMmaWarps / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>\n    >;\n    static constexpr bool MmaSdPEvenN = ((!SdP_swapAB ? kBlockN : kBlockM) / size<1>(AtomLayoutSdP{})) % 16 == 0;\n    using TiledMmaSdP = TiledMMA<\n        MMA_Atom_Arch,\n        AtomLayoutSdP,\n        Tile<Int<16 * CUTE_STATIC_V(size<0>(AtomLayoutSdP{}))>, Int<(MmaSdPEvenN ? 
16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutSdP{}))>, _16>>;\n\n    using AtomLayoutdKV = std::conditional_t<\n        !dKV_swapAB,\n        Layout<Shape<Int<AtomLayoutNdKV>, Int<NumMmaWarps / AtomLayoutNdKV>, _1>>,\n        Layout<Shape<Int<NumMmaWarps / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>\n    >;\n    static constexpr bool MmadKVEvenN = ((!dKV_swapAB ? kHeadDim : kBlockN) / size<1>(AtomLayoutdKV{})) % 16 == 0;\n    using TiledMmadKV = TiledMMA<\n        MMA_Atom_Arch,\n        AtomLayoutdKV,\n        Tile<Int<16 * CUTE_STATIC_V(size<0>(AtomLayoutdKV{}))>, Int<(MmadKVEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutdKV{}))>, _16>>;\n\n    using AtomLayoutdQ = std::conditional_t<\n        !dQ_swapAB,\n        Layout<Shape<Int<AtomLayoutMdQ>, Int<NumMmaWarps / AtomLayoutMdQ>, _1>>,\n        Layout<Shape<Int<NumMmaWarps / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>\n    >;\n    static constexpr bool MmadQEvenN = ((!dQ_swapAB ? kHeadDim : kBlockM) / size<1>(AtomLayoutdQ{})) % 16 == 0;\n    using TiledMmadQ = TiledMMA<\n        MMA_Atom_Arch,\n        AtomLayoutdQ,\n        Tile<Int<16 * CUTE_STATIC_V(size<0>(AtomLayoutdQ{}))>, Int<(MmadQEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutdQ{}))>, _16>>;\n\n    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);\n    static_assert(kHeadDim % kGmemElemsPerLoad == 0, \"Headdim must be a multiple of kGmemElemsPerLoad\");\n    // We want each \"row\" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each\n    // thread to have 4 loads in the M direction and 2 vectorized loads in the K direction.\n    static constexpr int kBytePerRow = kHeadDim * sizeof(Element);\n    static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);\n\n    static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 
2 : 1));\n    static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);\n\n    // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory.\n    // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma.\n    // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension\n    // changes the layout.\n    using SmemLayoutAtomQdO = decltype(\n        composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},\n                    Layout<Shape<_8, Int<kBlockKGmem>>,\n                           Stride<Int<kBlockKGmem>, _1>>{}));\n    using SmemLayoutQ =\n        decltype(tile_to_shape(SmemLayoutAtomQdO{},\n                 make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));\n    using SmemLayoutdO =\n        decltype(tile_to_shape(SmemLayoutAtomQdO{},\n                 make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages_dO>{})));\n\n    using SmemLayoutAtomKV = decltype(\n        composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},\n                    // TODO: FA2 has a slightly different layout, does it matter?\n                    Layout<Shape<_8, Int<kBlockKGmem>>,\n                           Stride<Int<kBlockKGmem>, _1>>{}));\n    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomKV{}, select<1, 2>(TileShape_MNK{})));\n\n    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomKV{}, select<1, 2>(TileShape_MNK{})));\n\n    // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.\n    static constexpr int kPBlockN = kBlockN % 64 == 0 ? 64 : (kBlockN % 32 == 0 ? 32 : 16);\n    static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);\n    // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 
2 : 3);\n    static constexpr int kSwizzlePdS = 3;\n    using SmemLayoutAtomPdS = decltype(\n        composition(Swizzle<kSwizzlePdS, kSwizzleBase, kSwizzleBase>{},\n                    Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,\n                           Stride<Int<kPBlockN>, _1>>{}));\n    using SmemLayoutPdS = decltype(tile_to_shape(\n        SmemLayoutAtomPdS{},\n        make_shape(Int<kBlockM>{}, Int<kBlockN>{})));\n\n    // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds,\n    // it's still a valid smem address.\n    using SmemLayoutLSE = cute::Layout<cute::Shape<Int<kBlockM>, Int<kStages>>, cute::Stride<_1, Int<cute::round_up(kBlockM, 64)>>>;\n    using SmemLayoutLSEMma = std::conditional_t<\n        SdP_swapAB,\n        cute::Layout<cute::Shape<Int<kBlockN>, Int<kBlockM>, Int<kStages>>, cute::Stride<_0, _1, Int<cute::round_up(kBlockM, 64)>>>,\n        cute::Layout<cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kStages>>, cute::Stride<_1, _0, Int<cute::round_up(kBlockM, 64)>>>\n    >;\n\n    // Note this is the transpose in terms of the view, not in terms of memory.\n    using SmemLayoutQt =\n        decltype(cute::composition(SmemLayoutQ{},\n                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),\n                                               make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));\n    using SmemLayoutdOt =\n        decltype(cute::composition(SmemLayoutdO{},\n                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages_dO>{}),\n                                               make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));\n    using SmemLayoutKt =\n        decltype(cute::composition(SmemLayoutK{},\n                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),\n                            
                   make_stride(Int<kBlockN>{}, _1{}))));\n    using SmemLayoutPdSt =\n        decltype(cute::composition(SmemLayoutPdS{},\n                                   make_layout(make_shape(Int<kBlockN>{}, Int<kBlockM>{}),\n                                               make_stride(Int<kBlockM>{}, _1{}))));\n\n    // Thread layout, 256 or 384 threads per row\n    using R2SLayoutAtomdQaccum = Layout<Shape<Int<NumMmaThreads>>>;\n    using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},\n                                                         Layout<Shape < _1>>{}));  // Val layout, 1 vals per store\n\n    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, Element>;\n    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, Element>;\n    // For the case where the N dimension of MmaSdP is divisible by 8 but not by 16\n    using SmemCopyAtomHalf = Copy_Atom<SM75_U32x2_LDSM_N, Element>;\n    // For the case where the N dimension of MmadQ is divisible by 8 but not by 16\n    using SmemCopyAtomTransposedHalf = Copy_Atom<SM75_U16x4_LDSM_T, Element>;\n    // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt.\n    // If PdS_major is MN, then we need to \"transpose\" the write.\n    // TODO: check this write\n    using R2SCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;\n\n    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading\n    // from the same address by the same threadblock. 
This is slightly faster.\n    using GmemCopyStruct = std::conditional_t<\n        Has_cp_async,\n        SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>,\n        AutoVectorizingCopyWithAssumedAlignment<128>\n    >;\n    using GmemCopyAtom = Copy_Atom<GmemCopyStruct, Element>;\n\n    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;\n    static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, \"NumMmaThreads must be a multiple of kGmemThreadsPerRow\");\n    using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                                  Stride<Int<kGmemThreadsPerRow>, _1>>;\n    using GmemTiledCopyQKV = decltype(\n        make_tiled_copy(GmemCopyAtom{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per read\n    using GmemCopyAtomLSE = Copy_Atom<GmemCopyStruct, float>;\n    using GmemLayoutAtomLSE = Layout<Shape<Int<NumMmaThreads>>>;\n    using GmemTiledCopyLSE = decltype(make_tiled_copy(GmemCopyAtomLSE{}, GmemLayoutAtomLSE{},\n                                                      Layout<Shape<_4>>{}));  // Val layout, 4 vals per store\n    // So that we don't have to check if we overshot kBlockM when we load Q\n    // static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);\n\n    using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen, d, head, batch)\n    using StrideQKV = cute::Stride<int64_t, _1, int64_t, int64_t>;\n    using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen, head, batch)\n    using StrideLSE = cute::Stride<_1, int64_t, int64_t>;  // (seqlen, head, batch)\n    using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)\n    using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;\n\n    // These are tuned for speed. 
They don't affect correctness.\n    // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64\n    // this helps quite a bit to not have to do causal masking for most of the iterations.\n    // For hdim 192, separating masking iterations results in register spills.\n    // static constexpr bool SeparateMaskingIterations = kHeadDim <= 64;\n    static constexpr bool SeparateMaskingIterations = false;\n    // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then\n    // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each\n    // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep\n    // statistics for 2 rows.\n    // static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64;\n    // static constexpr bool ShuffledPsum = SdP_swapAB && kHeadDim <= 64;\n    static constexpr bool ShuffleLSE = SdP_swapAB && false;\n    static constexpr bool ShuffledPsum = SdP_swapAB && false;\n\n    static constexpr bool Share_QV_Smem = V_in_regs;\n    using SmemP_t = std::conditional_t<Mma_dKV_is_RS, cute::array<Element, 0>, cute::array_aligned<Element, cute::cosize_v<SmemLayoutPdS>>>;\n\n    struct TensorStorageSharedQV : cute::aligned_struct<128> {\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;\n        union {\n            cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;\n            cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;\n        };\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;\n        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_lse;\n        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_dpsum;\n        SmemP_t smem_p;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutPdS>> smem_ds;\n    };\n\n    struct 
TensorStorageSeparateQV : cute::aligned_struct<128> {\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;\n        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_lse;\n        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_dpsum;\n        SmemP_t smem_p;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutPdS>> smem_ds;\n    };\n\n    using TensorStorage = std::conditional_t<Share_QV_Smem, TensorStorageSharedQV, TensorStorageSeparateQV>;\n\n    // Host side kernel arguments\n    struct Arguments {\n        Element const* const ptr_Q;\n        ShapeQKV const shape_Q;\n        StrideQKV const stride_Q;\n        Element const* const ptr_K;\n        ShapeQKV const shape_K;\n        StrideQKV const stride_K;\n        Element const* const ptr_V;\n        ShapeQKV const shape_V;\n        StrideQKV const stride_V;\n        Element const* const ptr_dO;\n        ShapeQKV const shape_dO;\n        StrideQKV const stride_dO;\n        ElementAccum* const ptr_dQaccum;\n        ShapedQaccum const shape_dQaccum;\n        StridedQaccum const stride_dQaccum;\n        float const* const ptr_LSE_log2;\n        ShapeLSE const shape_LSE;\n        StrideLSE const stride_LSE_log2;\n        float const* const ptr_dPsum;\n        StrideLSE const stride_dPsum;\n        float const softmax_scale;\n        int const window_size_left, window_size_right, attention_chunk;\n        float const softcap_val;\n        int const num_batch;\n        int* const dq_semaphore;\n        int const* const cu_seqlens_q = nullptr;\n        int const* const cu_seqlens_k = nullptr;\n        int const* const seqused_q = nullptr;\n        int const* const seqused_k = nullptr;\n    };\n\n    // Device 
side kernel params\n    struct Params {\n        Element const* const ptr_Q;\n        ShapeQKV const shape_Q;\n        StrideQKV const stride_Q;\n        Element const* const ptr_K;\n        ShapeQKV const shape_K;\n        StrideQKV const stride_K;\n        Element const* const ptr_V;\n        ShapeQKV const shape_V;\n        StrideQKV const stride_V;\n        Element const* const ptr_dO;\n        ShapeQKV const shape_dO;\n        StrideQKV const stride_dO;\n        ElementAccum* const ptr_dQaccum;\n        ShapedQaccum const shape_dQaccum;\n        StridedQaccum stride_dQaccum;\n        cutlass::FastDivmod qhead_per_khead_divmod;\n        float const* const ptr_LSE_log2;\n        ShapeLSE const shape_LSE;\n        StrideLSE const stride_LSE_log2;\n        float const* const ptr_dPsum;\n        StrideLSE const stride_dPsum;\n        float const softmax_scale, softmax_scale_log2;\n        int const window_size_left, window_size_right;\n        cutlass::FastDivmod attention_chunk_divmod;\n        float const softcap_val;\n        int const num_batch;\n        int *const dq_semaphore;\n        int const *const cu_seqlens_q = nullptr;\n        int const *const cu_seqlens_k = nullptr;\n        int const *const seqused_q = nullptr;\n        int const *const seqused_k = nullptr;\n    };\n\n    static Params\n    to_underlying_arguments(Arguments const& args) {\n        if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); }\n        // Avoid dividing by zero\n        cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? 
args.attention_chunk : 1);\n        attention_chunk_divmod.divisor = args.attention_chunk;\n        // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.\n        // Right after this, we multiply by log2(e) before applying exp2.\n        // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val\n        // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)\n        // (assigning it to params.softmax_scale_log2).\n        // In the backward, we need to multiply by\n        // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale.\n        // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale\n        // (the original softmax_scale) at the end.\n        return {args.ptr_Q, args.shape_Q, args.stride_Q,\n                args.ptr_K, args.shape_K, args.stride_K,\n                args.ptr_V, args.shape_V, args.stride_V,\n                args.ptr_dO, args.shape_dO, args.stride_dO,\n                args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum,\n                cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),\n                args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum,\n                args.softmax_scale,\n                !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),\n                args.window_size_left, args.window_size_right, attention_chunk_divmod,\n                !Has_softcap ? 
0.f : args.softmax_scale / args.softcap_val,\n                args.num_batch, args.dq_semaphore,\n                args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k};\n    }\n\n    template <typename SharedStorage, typename FrgTensordKV>\n    CUTLASS_DEVICE bool\n    mma(Params const& params,\n        FrgTensordKV& tdKrdK,\n        FrgTensordKV& tdVrdV,\n        int thread_idx,\n        cute::tuple<int32_t, int32_t, int32_t> block_coord,\n        SharedStorage& shared_storage\n        ) {\n        static_assert(is_rmem<FrgTensordKV>::value, \"dK and dV tensor must be rmem resident.\");\n\n        int n_block = get<0>(block_coord);\n        int bidh = get<1>(block_coord);\n        int bidb = get<2>(block_coord);\n        SeqlenInfo_t seqlen_info{\n            bidb, get<0>(params.shape_Q), size<0>(params.shape_K),\n            params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k\n        };\n        auto m_block_min_max = BlockMN_t::get_m_block_min_max(\n            seqlen_info, n_block, bidb,\n            params.window_size_left, params.window_size_right, 0 /*sink_token_length*/);\n        int const m_block_min = get<0>(m_block_min_max);\n        int const m_block_max = get<1>(m_block_min_max);\n        // It's possible to have m_block_max <= m_block_min. 
Exit early\n        if constexpr (Is_causal || Is_local || Varlen) {\n            if (m_block_max <= m_block_min) { return false; }\n        }\n\n        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});\n        Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{});\n        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});\n        Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});\n        Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{});\n        Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{});\n        Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{});\n        Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{});\n        Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{});\n        Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{});\n        Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{});\n        Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{});\n        Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{});\n        Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{});\n        Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{});\n\n        bool const is_varlen_q = Varlen && params.cu_seqlens_q;\n        bool const is_varlen_k = Varlen && 
params.cu_seqlens_k;\n        int bidh_kv = params.qhead_per_khead_divmod.divide(bidh);\n        Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q), params.shape_Q, params.stride_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);\n        Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_dO, params.stride_dO)(_, _, bidh, !is_varlen_q ? bidb : 0);\n        Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);\n        Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);\n        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0);\n        Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0);\n        Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.ptr_dQaccum)),\n                                      params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen_q ? 
bidb : 0);\n\n        Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)\n        Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)\n        Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)\n        Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)\n        Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_));  // (M, _)\n        Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_));  // (M, _)\n        Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(_));  // (M * K, _)\n\n        GmemTiledCopyQKV gmem_tiled_copy_QKV;\n        auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx);\n        auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{});  // For index calculation\n        GmemTiledCopyLSE gmem_tiled_copy_lse;\n        auto gmem_thr_copy_lse = gmem_tiled_copy_lse.get_thread_slice(thread_idx);\n        R2STiledCopydQaccum r2s_tiled_copy_dQaccum;\n        auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);\n\n        Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);\n        Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);\n        Tensor tdOgdO = gmem_thr_copy_QKV.partition_S(gdO);\n        Tensor tdOsdO = gmem_thr_copy_QKV.partition_D(sdO);\n        Tensor tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE);\n        Tensor tLSEsLSE = 
gmem_thr_copy_lse.partition_D(sLSE);\n        Tensor tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum);\n        Tensor tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum);\n        // We can reuse r2s_thr_copy_dQaccum for this partitioning\n        Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum);\n        // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf(\"\\n\"); print(gdQaccum_); printf(\"\\n\"); print(gdQaccum); printf(\"\\n\"); print(tdQgdQaccum); printf(\"\\n\"); }\n\n        TiledMmaSdP tiled_mma_SdP;\n        TiledMmadKV tiled_mma_dKV;\n        TiledMmadQ tiled_mma_dQ;\n\n        auto thr_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx);\n        auto thr_mma_dKV = tiled_mma_dKV.get_thread_slice(thread_idx);\n        auto thr_mma_dQ = tiled_mma_dQ.get_thread_slice(thread_idx);\n\n        // Allocate \"fragments/descriptors\"\n        // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda,\n        // because some partition_fragment_A/B don't compile.\n        // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function\n        Tensor tdPrV = mma_partition_fragment_AB</*A=*/SdP_swapAB>(thr_mma_SdP, sV);\n\n        // Copy Atom retiling\n        auto smem_copy_atom_SdP_B = cute::conditional_return<MmaSdPEvenN>(SmemCopyAtom{}, SmemCopyAtomHalf{});\n        auto smem_tiled_copy_QdO = cute::conditional_return<!SdP_swapAB>(make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_SdP), make_tiled_copy_B(smem_copy_atom_SdP_B, tiled_mma_SdP));\n        auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(thread_idx);\n        Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);\n        Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);\n\n        auto smem_tiled_copy_KV = cute::conditional_return<!SdP_swapAB>(make_tiled_copy_B(smem_copy_atom_SdP_B, tiled_mma_SdP), make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_SdP));\n        auto 
smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(thread_idx);\n        Tensor tSsK = smem_thr_copy_KV.partition_S(sK);\n        Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);\n\n        auto r2s_tiled_copy_PdS = make_tiled_copy_C(R2SCopyAtomPdS{}, tiled_mma_SdP);\n        auto r2s_thr_copy_PdS = r2s_tiled_copy_PdS.get_thread_slice(thread_idx);\n        Tensor tPsP = r2s_thr_copy_PdS.partition_D(cute::conditional_return<!SdP_swapAB>(sP, sPt));      // ((Atom,AtomNum),PIPE_M,PIPE_N)\n        Tensor tdSsdS = r2s_thr_copy_PdS.partition_D(cute::conditional_return<!SdP_swapAB>(sdS, sdSt));      // ((Atom,AtomNum),PIPE_M,PIPE_N)\n        // if (blockIdx.x == 0 && threadIdx.x == 128) { print(r2s_thr_copy_PdS); print(sP); printf(\"\\n\"); print(sPt); printf(\"\\n\"); print(tPsP); printf(\"\\n\"); print(tdSsdS); printf(\"\\n\"); }\n\n        auto smem_copy_atom_dKV_B = cute::conditional_return<MmadKVEvenN>(SmemCopyAtomTransposed{}, SmemCopyAtomTransposedHalf{});\n        auto smem_tiled_copy_PdSt = cute::conditional_return<!dKV_swapAB>(make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dKV), make_tiled_copy_B(smem_copy_atom_dKV_B, tiled_mma_dKV));\n        auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(thread_idx);\n        Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);\n        Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);\n\n        auto smem_tiled_copy_QdOt = cute::conditional_return<!dKV_swapAB>(make_tiled_copy_B(smem_copy_atom_dKV_B, tiled_mma_dKV), make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dKV));\n        auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(thread_idx);\n        Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt);\n        Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt);\n\n        auto smem_tiled_copy_dS = cute::conditional_return<!dQ_swapAB>(\n            make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_dQ),\n            
make_tiled_copy_B(cute::conditional_return<MmadQEvenN>(SmemCopyAtom{}, SmemCopyAtomHalf{}), tiled_mma_dQ));\n        auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(thread_idx);\n        Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS);\n\n        auto smem_tiled_copy_Kt = cute::conditional_return<!dQ_swapAB>(\n            make_tiled_copy_B(cute::conditional_return<MmadQEvenN>(SmemCopyAtomTransposed{}, SmemCopyAtomTransposedHalf{}), tiled_mma_dQ),\n            make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dQ));\n        auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(thread_idx);\n        Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt);\n\n        // thr_mma_SdP.partition_C(sLSEMma) has shape (MMA=4, MMA_M, MMA_N, PIPE), we only take the col indices\n        // or row indices, depending on whether SdP_swapAB.\n        Tensor tSsLSEMma = logical_divide(thr_mma_SdP.partition_C(sLSEMma), Shape<_2>{});  // (2, 2, MMA_M, MMA_N, PIPE)\n        Tensor tSsLSE = group_modes<0, 2>(cute::conditional_return<!SdP_swapAB>(\n            tSsLSEMma(make_coord(_0{}, _), _, _0{}, _),  // (2, MMA_M, PIPE)\n            tSsLSEMma(make_coord(_, _0{}), _0{}, _, _)));  // (2, MMA_N, PIPE)\n        Tensor tSsdPsumMma = logical_divide(thr_mma_SdP.partition_C(sdPsumMma), Shape<_2>{});\n        Tensor tSsdPsum = group_modes<0, 2>(cute::conditional_return<!SdP_swapAB>(\n            tSsdPsumMma(make_coord(_0{}, _), _, _0{}, _),  // (2, MMA_M, PIPE)\n            tSsdPsumMma(make_coord(_, _0{}), _0{}, _, _)));  // (2, MMA_N, PIPE)\n        // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf(\"\\n\"); print(tLSEsLSE); printf(\"\\n\"); }\n        // If we want to split the stats among the 8 threads that share the same rows.\n        static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tSsLSE))::value, 8);\n\n        // Predicates\n        Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));\n        Tensor tQcQ = 
gmem_thr_copy_QKV.partition_S(cQ);\n        Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ);\n        Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));\n        #pragma unroll\n        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); }\n        Tensor cLSE = cute::make_identity_tensor(select<0>(TileShape_MNK{}));\n        Tensor tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE);\n        Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOsdO)));\n        #pragma unroll\n        for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_dO); }\n\n        int const seqlen_q = seqlen_info.seqlen_q;\n        int const seqlen_k = seqlen_info.seqlen_k;\n\n        flash::Mask<kBlockM, kBlockN, false /*PackGQA*/, TiledMmaSdP, SdP_swapAB> mask(\n            thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/,\n            params.attention_chunk_divmod, params.qhead_per_khead_divmod\n        );\n\n        {\n            Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K, nblocksN)\n            Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);\n            Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K, nblocksN)\n            Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);\n            // Predicates\n            Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));\n            Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);\n            Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV);\n            Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));\n            Tensor tVpV = make_tensor<bool>(make_shape(size<2>(tVsV)));\n            #pragma unroll\n            for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); }\n            #pragma unroll\n            for (int k = 0; k < 
size(tVpV); ++k) { tVpV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_V); }\n            // Do we need bound check to make sure the row doesn't go above kBlockN\n            static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;\n            // static_assert(EvenN);  // It simplifies the loading of K and V\n            // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit\n            // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time.\n            // int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN\n            //     ? seqlen_info.seqlen_k - n_block * kBlockN\n            //     : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN));\n            // // Need Clear_OOB_MN to be true here since the gemm will sum over the kBlockN dimension\n            // flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(\n            //     gmem_tiled_copy_QKV, tVgV, tVsV, t0KVcKV, tKVpKV, seqlenk_row_limit);\n            int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{}));\n            #pragma unroll\n            for (int m = 0; m < size<1>(tVsV); ++m) {\n                // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked\n                if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) {\n                    bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit;\n                    #pragma unroll\n                    for (int k = 0; k < size<2>(tVsV); ++k) {\n                        cute::copy(gmem_tiled_copy_QKV.with(tVpV(k) && predicate_n), tVgV(_, m, k), tVsV(_, m, k));\n                    }\n                }\n            }\n            if constexpr (V_in_regs) { flash::cp_async_fence(); }\n            // 
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(\n            //     gmem_tiled_copy_QKV, tKgK, tKsK, t0KVcKV, tKVpKV, seqlenk_row_limit);\n            #pragma unroll\n            for (int m = 0; m < size<1>(tKsK); ++m) {\n                if (EvenN || m < size<1>(tKsK) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) {\n                    bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit;\n                    #pragma unroll\n                    for (int k = 0; k < size<2>(tKsK); ++k) {\n                        cute::copy(gmem_tiled_copy_QKV.with(tKpK(k) && predicate_n), tKgK(_, m, k), tKsK(_, m, k));\n                    }\n                }\n            }\n            flash::cp_async_fence();\n        }\n\n        if constexpr (V_in_regs) {\n            flash::cp_async_wait<1>();\n            __syncthreads();\n            Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV);\n            Tensor tdPsV_copy_view = smem_thr_copy_KV.partition_S(sV);\n            cute::copy(smem_tiled_copy_KV, tdPsV_copy_view, tdPrV_copy_view);\n            __syncthreads();  // Sync to avoid loading Q to smem_q, which overlaps with smem_v\n        }\n\n        // Do we need bound check to make sure the row doesn't go above kBlockM\n        static constexpr int kBlockM = get<0>(TileShape_MNK{});\n        static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;\n\n        auto load_Q_LSE = [&] (int const m_block, int const smem_pipe_write) {\n            // if (cute::thread0()) { printf(\"Inside load_Q_LSE, m_block = %d, smem_pipe_write = %d\\n\", m_block, smem_pipe_write); }\n            Tensor tQsQ_cur = tQsQ(_, _, _, smem_pipe_write);\n            Tensor tQgQ_cur = tQgQ(_, _, _, m_block);\n            // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit\n            // (seqlen_q - m_block * kBlockM). 
This is because the entries of t0QcQ are known at compile time.\n            // int const seqlenq_row_limit = -int(get<0>(tQcQ(_0{}, _0{}, _0{}))) + (EvenM\n            //     ? seqlen_info.seqlen_q - m_block * kBlockM\n            //     : std::min(seqlen_info.seqlen_q - m_block * kBlockM, kBlockM));\n            // Need Clear_OOB_MN to be true here since the gemm will sum over the kBlockM dimension\n            // flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(\n            //     gmem_tiled_copy_QKV, tQgQ(_, _, _, m_block), tQsQ_cur, t0QcQ, tQpQ, seqlenq_row_limit);\n            int const seqlenq_row_limit = seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{}));\n            #pragma unroll\n            for (int m = 0; m < size<1>(tQsQ); ++m) {\n                // If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked\n                if (EvenM || m < size<1>(tQsQ) - 1 || get<0>(tQcQ(_0{}, m, _0{})) < kBlockM) {\n                    bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit;\n                    #pragma unroll\n                    for (int k = 0; k < size<2>(tQsQ); ++k) {\n                        cute::copy(gmem_tiled_copy_QKV.with(tQpQ(k) && predicate_m), tQgQ_cur(_, m, k), tQsQ_cur(_, m, k));\n                    }\n                }\n            }\n            Tensor tLSEgLSE_cur = tLSEgLSE(_, _, m_block);\n            Tensor tLSEsLSE_cur = tLSEsLSE(_, _, smem_pipe_write);\n            // We made sure LSE length is padded so we read `kBlockM` elements so that all\n            // elements in sLSE are filled. 
Without this we might have uninitialized sLSE values.\n            #pragma unroll\n            for (int m = 0; m < size<1>(tLSEsLSE); ++m) {\n                if (get<0>(tLSEcLSE(_0{}, m)) < kBlockM) {\n                    cute::copy(gmem_tiled_copy_lse, tLSEgLSE_cur(_, m), tLSEsLSE_cur(_, m));\n                }\n            }\n        };\n\n        auto load_dO_dPsum = [&] (int const m_block, int const smem_pipe_write) {\n            // if (cute::thread0()) { printf(\"Inside load_dO_dPsum, m_block = %d, smem_pipe_write = %d\\n\", m_block, smem_pipe_write); }\n            Tensor tdOsdO_cur = tdOsdO(_, _, _, smem_pipe_write);\n            Tensor tdOgdO_cur = tdOgdO(_, _, _, m_block);\n            // int const seqlenq_row_limit = -int(get<0>(tQcQ(_0{}, _0{}, _0{}))) + (EvenM\n            //     ? seqlen_info.seqlen_q - m_block * kBlockM\n            //     : std::min(seqlen_info.seqlen_q - m_block * kBlockM, kBlockM));\n            // flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clear_OOB_K=*/true>(\n            //     gmem_tiled_copy_QKV, tdOgdO(_, _, _, m_block), tdOsdO_cur, t0QcQ, tQpQ, seqlenq_row_limit);\n            int const seqlenq_row_limit = seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{}));\n            #pragma unroll\n            for (int m = 0; m < size<1>(tdOsdO); ++m) {\n                // If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked\n                if (EvenM || m < size<1>(tdOsdO) - 1 || get<0>(tQcQ(_0{}, m, _0{})) < kBlockM) {\n                    bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit;\n                    #pragma unroll\n                    for (int k = 0; k < size<2>(tdOsdO); ++k) {\n                        cute::copy(gmem_tiled_copy_QKV.with(tdOpdO(k) && predicate_m), tdOgdO_cur(_, m, k), tdOsdO_cur(_, m, k));\n                    }\n                }\n            }\n            Tensor tLSEgdPsum_cur = 
tLSEgdPsum(_, _, m_block);\n            Tensor tLSEsdPsum_cur = tLSEsdPsum(_, _, smem_pipe_write);\n            #pragma unroll\n            for (int m = 0; m < size<1>(tLSEsdPsum); ++m) {\n                if (get<0>(tLSEcLSE(_0{}, m)) < kBlockM) {\n                    cute::copy(gmem_tiled_copy_lse, tLSEgdPsum_cur(_, m), tLSEsdPsum_cur(_, m));\n                }\n            }\n        };\n\n        int m_block = m_block_min;\n\n        // Note, using the for_each() function here to ensure `stage` is of type Int<x>.\n        for_each(make_int_sequence<kStages>{}, [&] (auto stage) {\n            static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0;\n            static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1;\n            if constexpr (!Is_last_stage || kStages == 1) {\n                if (Is_first_stage || m_block + stage < m_block_max) {\n                    load_Q_LSE(m_block + stage, stage);\n                }\n            }\n            // We want the fence outside the if statement to have a fixed number of cp.async commits,\n            // so that we can wait with the correct number of outstanding commits.\n            cute::cp_async_fence();\n            if constexpr (stage < kStages_dO) {\n                if (Is_first_stage || m_block + stage < m_block_max) {\n                    load_dO_dPsum(m_block + stage, stage);\n                }\n                cute::cp_async_fence();\n            }\n        });\n\n        int smem_pipe_read = 0, smem_pipe_read_do = 0, smem_pipe_write = kStages - 1, smem_pipe_write_do = 0;\n\n        auto load_Q_next = [&] {\n            // if (cute::thread0()) { printf(\"m_block = %d, m_block_max = %d, smem_pipe_write = %d\\n\", m_block, m_block_max, smem_pipe_write); }\n            if (m_block + (kStages > 1 ? kStages - 1 : 1) < m_block_max) {\n                load_Q_LSE(m_block + (kStages > 1 ? kStages - 1 : 1), kStages > 1 ? 
smem_pipe_write : 0);\n            }\n            cute::cp_async_fence();\n        };\n\n        auto load_dO_next = [&] {\n            // int smem_pipe_write_do_cur = Q_dO_same_stages ? smem_pipe_write : smem_pipe_write_do;\n            if (m_block + kStages_dO < m_block_max) {\n                // load_dO_dPsum(m_block + kStages_dO, kStages_dO > 1 ? smem_pipe_write_do_cur : 0);\n                load_dO_dPsum(m_block + kStages_dO, kStages_dO > 1 ? smem_pipe_write_do : 0);\n            }\n            cute::cp_async_fence();\n        };\n\n        clear(tdKrdK);\n        clear(tdVrdV);\n\n        auto bwd_step = [&](int m_block, auto mask_fn) {\n            Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<!SdP_swapAB ? 0 : 1, !SdP_swapAB ? 1 : 0>(TileShape_MNK{}));\n            clear(tSrS);\n            flash::cp_async_wait<(kStages > 1) ? 1 : 0>();\n            __syncthreads();\n            Tensor tSrQ = mma_partition_fragment_AB</*A=*/!SdP_swapAB>(thr_mma_SdP, sQ(_, _, _0{}));\n            Tensor tSrK = mma_partition_fragment_AB</*A=*/SdP_swapAB>(thr_mma_SdP, sK);\n            // if (cute::thread0()) { print(tiled_mma_SdP); print(tSrS); printf(\"\\n\"); print(tSrQ); printf(\"\\n\"); print(tSrK); printf(\"\\n\"); print(tSsQ); printf(\"\\n\"); print(tSsK); printf(\"\\n\"); }\n            flash::gemm_sm80<false /*A_in_regs*/, false /*B_in_regs*/, SdP_swapAB>(\n                tSrS, tSrQ, tSrK, tSsQ(_, _, _, kStages > 1 ? smem_pipe_read : 0), tSsK,\n                tiled_mma_SdP, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV, nullptr /*hook*/);\n            Tensor tLSErLSE = cute::conditional_return<!ShuffleLSE>(make_fragment_like(tSsLSE(_, _0{})), make_tensor<ElementAccum>(Int<kStatsPerThread>{}));\n            if constexpr (!ShuffleLSE) {\n                cute::copy(tSsLSE(_, kStages > 1 ? 
smem_pipe_read : 0), tLSErLSE);\n            } else {\n                #pragma unroll\n                for (int i = 0; i < kStatsPerThread; ++i) {\n                    // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values\n                    tLSErLSE(i) = tSsLSE((thread_idx % 32) / 4 + i * 8, kStages > 1 ? smem_pipe_read : 0);\n                }\n            }\n            if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); }\n\n            // Reshape tSrS from (4, MMA_N, MMA_M) to (nrow=(2, MMA_M), ncol=(2, MMA_N))\n            Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SdP_swapAB>(tSrS.layout()));\n            // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh\n            // if (cute::thread0()) { print_tensor(scores); }\n            auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }();\n            mask_fn(tSrS, m_block);\n            #pragma unroll\n            for (int mi = 0; mi < size<0>(scores); ++mi) {\n                float const lse_scaled = [&] {\n                    if constexpr (!ShuffleLSE) return tLSErLSE(mi);\n                    else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4));\n                }();\n                #pragma unroll\n                for (int ni = 0; ni < size<1>(scores); ++ni) {\n                    scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled);\n                }\n            }\n\n            Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<!SdP_swapAB ? 0 : 1, !SdP_swapAB ? 1 : 0>(TileShape_MNK{}));\n            clear(tdPrdP);\n            int smem_pipe_read_do_cur = Q_dO_same_stages ? smem_pipe_read : smem_pipe_read_do;\n            flash::cp_async_wait<(kStages_dO > 1) ? 
1 : 0>();\n            __syncthreads();\n            auto hook = cute::conditional_return<(kStages > 1)>(load_Q_next, nullptr);\n            Tensor tdPrdO = mma_partition_fragment_AB</*A=*/!SdP_swapAB>(thr_mma_SdP, sdO(_, _, _0{}));\n            Tensor tdPrV_cur = cute::conditional_return<V_in_regs>(tdPrV, mma_partition_fragment_AB</*A=*/SdP_swapAB>(thr_mma_SdP, sV));\n            flash::gemm_sm80<false /*A_in_regs*/, V_in_regs, SdP_swapAB>(\n                tdPrdP, tdPrdO, tdPrV_cur, tdPsdO(_, _, _, kStages_dO > 1 ? smem_pipe_read_do_cur : 0), tdPsV,\n                tiled_mma_SdP, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV, hook);\n            Tensor tLSErdPsum = cute::conditional_return<!ShuffledPsum>(make_fragment_like(tSsdPsum(_, _0{})), make_tensor<ElementAccum>(Int<kStatsPerThread>{}));\n            if constexpr (!ShuffledPsum) {\n                cute::copy(tSsdPsum(_, kStages_dO > 1 ? smem_pipe_read_do_cur : 0), tLSErdPsum);\n            } else {\n                #pragma unroll\n                for (int i = 0; i < kStatsPerThread; ++i) {\n                    tLSErdPsum(i) = tSsdPsum((thread_idx % 32) / 4 + i * 8, kStages_dO > 1 ? 
smem_pipe_read_do_cur : 0);\n                }\n            }\n\n            // Reshape tdPrdP from (4, MMA_N, MMA_M) to (nrow=(2, MMA_M), ncol=(2, MMA_N))\n            Tensor dS = make_tensor(tdPrdP.data(), scores.layout());\n            #pragma unroll\n            for (int mi = 0; mi < size<0>(dS); ++mi) {\n                float const dP_sum_cur = [&] {\n                    if constexpr (!ShuffledPsum) return tLSErdPsum(mi);\n                    else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4));\n                }();\n                #pragma unroll\n                for (int ni = 0; ni < size<1>(dS); ++ni) {\n                    dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur);\n                    if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); }\n                }\n            }\n            // if (cute::thread0()) { print_tensor(dS); }\n\n            // Convert scores from fp32 to fp16/bf16\n            Tensor rP = make_tensor_like<Element>(tSrS);\n            flash::convert_type_out(tSrS, rP);\n            if constexpr (!Mma_dKV_is_RS) {\n                Tensor tPaP = r2s_thr_copy_PdS.retile_S(rP);  // ((Atom,AtomNum), MMA_N, MMA_N)\n                cute::copy(r2s_tiled_copy_PdS, tPaP, tPsP);\n            }\n            Tensor rdS = make_tensor_like<Element>(tdPrdP);\n            flash::convert_type_out(tdPrdP, rdS);\n            if constexpr (!Mma_dKV_is_RS) { __syncthreads(); }  // Make sure P is written\n            // For hdim 64, It's faster to write to smem_dS first before the dV gemm\n            Tensor tdSadS = r2s_thr_copy_PdS.retile_S(rdS);   // ((Atom,AtomNum), MMA_N, MMA_N)\n            cute::copy(r2s_tiled_copy_PdS, tdSadS, tdSsdS);\n\n            Tensor tdVrdO = mma_partition_fragment_AB</*A=*/dKV_swapAB>(thr_mma_dKV, sdOt(_, _, _0{}));\n            Tensor tdVsdO_cur = tdVsdOt(_, _, _, kStages_dO > 1 ? 
smem_pipe_read_do_cur : 0);\n            if constexpr (Mma_dKV_is_RS) {\n                Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<TiledMmadKV>(tSrS.layout()));\n                flash::gemm_rs_sm80(tdVrdV, tdVrP, tdVrdO, tdVsdO_cur, tiled_mma_dKV, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);\n            } else {\n                Tensor tdVrP = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(thr_mma_dKV, sPt);\n                flash::gemm_sm80<false /*A_in_regs*/, false /*B_in_regs*/, /*SwapAB=*/dKV_swapAB>(\n                    tdVrdV, tdVrP, tdVrdO, tdVsPt, tdVsdO_cur,\n                    tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, nullptr);\n            }\n            // if (cute::thread0()) { print_tensor(tdVrdV); }\n            __syncthreads();  // make sure sdS is written\n            auto do_mma_dQ = [&] (auto hook) {\n                Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 
2 : 0>(TileShape_MNK{}));\n                clear(tdQrdQ);\n                Tensor tdQrdS = mma_partition_fragment_AB</*A=*/!dQ_swapAB>(thr_mma_dQ, sdS);\n                Tensor tdQrK = mma_partition_fragment_AB</*A=*/dQ_swapAB>(thr_mma_dQ, sKt);\n                flash::gemm_sm80<false /*A_in_regs*/, false /*B_in_regs*/, /*SwapAB=*/dQ_swapAB>(\n                    tdQrdQ, tdQrdS, tdQrK, tdQsdS, tdQsKt, tiled_mma_dQ,\n                    // smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, load_dO_next);\n                    smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, hook);\n                // if (cute::thread0()) { print_tensor(tdQrdQ); }\n                // We can reuse r2s_thr_copy_dQaccum for this partitioning\n                Tensor tdQrdQ_atomic = r2s_thr_copy_dQaccum.retile_S(tdQrdQ);\n                Tensor tdQgdQaccum_atomic = tdQgdQaccum(_, _, m_block);\n                static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic)));\n                #pragma unroll\n                for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }\n            };\n            // If kStages == 1, we want to do Mma_dK first so we can start loading Q for the next iteration\n            if constexpr (kStages > 1) { do_mma_dQ(load_dO_next); }\n            Tensor tdKrQ = mma_partition_fragment_AB</*A=*/dKV_swapAB>(thr_mma_dKV, sQt(_, _, _0{}));\n            Tensor tdKsQ_cur = tdKsQt(_, _, _, kStages > 1 ? 
smem_pipe_read : 0);\n            if constexpr (Mma_dKV_is_RS) {\n                Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<TiledMmadKV>(tdPrdP.layout()));\n                flash::gemm_rs_sm80(tdKrdK, tdKrdS, tdKrQ, tdKsQ_cur, tiled_mma_dKV, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);\n            } else {\n                Tensor tdKrdS = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(thr_mma_dKV, sdSt);\n                flash::gemm_sm80<false /*A_in_regs*/, false /*B_in_regs*/, /*SwapAB=*/dKV_swapAB>(\n                    tdKrdK, tdKrdS, tdKrQ, tdKsdSt, tdKsQ_cur,\n                    tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, cute::conditional_return<(kStages > 1)>(nullptr, load_dO_next));\n            }\n            if constexpr (kStages == 1) {\n                __syncthreads();\n                do_mma_dQ(load_Q_next);\n            }\n            // if (cute::thread0()) { print_tensor(tdKrdK); }\n\n            smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0;\n            smem_pipe_read_do = smem_pipe_read_do < kStages_dO - 1 ? smem_pipe_read_do + 1 : 0;\n            smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0;\n            smem_pipe_write_do = smem_pipe_write_do < kStages_dO - 1 ? smem_pipe_write_do + 1 : 0;\n\n        };\n\n        // We have separate iterations with causal masking. 
Not necessary for hdim 128 but for hdim 64\n        // this helps quite a bit to not have to do causal masking for most of the iterations.\n        if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) {\n            auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };\n            int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1;\n            CUTLASS_PRAGMA_NO_UNROLL\n            for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) {\n                bwd_step(m_block, mask_fn);\n            }\n        }\n\n        static constexpr int kBlockN = get<1>(TileShape_MNK{});\n        int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations\n            ? m_block_max\n            : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM);\n\n        auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal && !SeparateMaskingIterations, Is_local && !SeparateMaskingIterations>(tSrS, m_block, n_block); };\n        CUTLASS_PRAGMA_NO_UNROLL\n        for (; m_block < m_block_max_before_local_mask; ++m_block) {\n            bwd_step(m_block, mask_fn);\n        }\n\n        if constexpr (Is_local && SeparateMaskingIterations) {\n            auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, false /*Causal_mask*/, Is_local>(tSrS, m_block, n_block); };\n            CUTLASS_PRAGMA_NO_UNROLL\n            for (; m_block < m_block_max; ++m_block) {\n                bwd_step(m_block, mask_fn);\n            }\n        }\n\n        // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); }\n        #pragma unroll\n        for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; }\n\n        return true;\n    
}\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n#include <cutlass/numeric_conversion.h>\n#include <cutlass/barrier.h>\n#include \"cutlass/pipeline/pipeline.hpp\"\n\n#include \"cute/tensor.hpp\"\n\n#include \"cutlass/gemm/collective/builders/sm90_common.inl\"\n\n#include \"named_barrier.hpp\"\n#include \"seqlen.h\"\n#include \"block.h\"\n#include \"mask.h\"\n#include \"softmax.h\"\n#include \"utils.h\"\n#include \"copy_sm90_bulk_reduce.hpp\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <int Stages, int Stages_dO, int Stages_dS, class ClusterShape_, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,\n        bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool Deterministic,\n        bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,\n        int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,\n        bool Mma_dP_is_RS=false>\nstruct CollectiveMainloopBwdSm90 {\n\n    static constexpr int kStages = Stages;\n    static constexpr int kStages_dO = Stages_dO;\n    static constexpr int kStages_dS = Stages_dS;\n    static_assert(kStages >= kStages_dO);\n    static_assert(Stages_dS == 1 || Stages_dS == kStages);\n    static_assert(!Mma_dP_is_RS || SdP_swapAB_);  // If Mma_dP_is_RS, we need SdP_SwapAB\n    using ClusterShape = ClusterShape_;\n    using TileShape_MNK = TileShape_MNK_;\n    using Element = Element_;\n    using ElementAccum = ElementAccum_;\n    using ArchTag = ArchTag_;\n    static constexpr bool Is_causal = Is_causal_;\n    static constexpr bool Is_local = Is_local_;\n    static constexpr bool Has_softcap = Has_softcap_;\n    
static constexpr bool Varlen = Varlen_;\n\n    static constexpr bool SdP_swapAB = SdP_swapAB_;\n    static constexpr bool dKV_swapAB = dKV_swapAB_;\n    static constexpr bool dQ_swapAB = dQ_swapAB_;\n\n    static constexpr bool Q_dO_same_stages = kStages == kStages_dO;\n\n    static constexpr int kBlockM = get<0>(TileShape_MNK{});\n    static constexpr int kBlockN = get<1>(TileShape_MNK{});\n    static constexpr int kHeadDim = get<2>(TileShape_MNK{});\n\n    using SeqlenInfo_t = flash::SeqlenInfoQK<Varlen, kBlockM>;\n    using BlockMN_t = flash::BlockMN<SeqlenInfo_t, kBlockM, kBlockN, Is_causal, Is_local>;\n\n    static_assert(ArchTag::kMinComputeCapability >= 90);\n    static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1);\n\n    static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;\n    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp * 2;\n\n    static_assert(NumMmaWarpGroups % AtomLayoutMSdP == 0);\n    static_assert(NumMmaWarpGroups % AtomLayoutNdKV == 0);\n    static_assert(NumMmaWarpGroups % AtomLayoutMdQ == 0);\n    static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarpGroups && SdP_swapAB && !dKV_swapAB;\n    static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarpGroups && AtomLayoutMdQ == NumMmaWarpGroups && !SdP_swapAB && !dQ_swapAB;  // If dQ_swapAB we can't use RS\n\n    static constexpr GMMA::Major PdS_Major = GMMA::Major::K;\n    // static constexpr GMMA::Major PdS_Major = GMMA::Major::MN;\n    static constexpr GMMA::Major PdSt_Major = PdS_Major == GMMA::Major::K ? 
GMMA::Major::MN : GMMA::Major::K;\n\n    using TileShapeAtomSdP = std::conditional_t<\n        !SdP_swapAB,\n        Shape<Int<kBlockM>, Int<kBlockN / (NumMmaWarpGroups / AtomLayoutMSdP)>, Int<kHeadDim>>,\n        Shape<Int<kBlockN>, Int<kBlockM / AtomLayoutMSdP>, Int<kHeadDim>>\n    >;\n    using AtomLayoutSdP = std::conditional_t<\n        !SdP_swapAB,\n        Layout<Shape<Int<AtomLayoutMSdP>, Int<NumMmaWarpGroups / AtomLayoutMSdP>, _1>>,\n        Layout<Shape<Int<NumMmaWarpGroups / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>\n    >;\n    using TiledMmaSdP = decltype(cute::make_tiled_mma(\n        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),\n        AtomLayoutSdP{}));\n\n    using TiledMmadPRS = decltype(cute::make_tiled_mma(\n        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),\n        AtomLayoutSdP{}));\n\n    using TileShapeAtomdKV = std::conditional_t<\n        !dKV_swapAB,\n        Shape<Int<kBlockN>, Int<kHeadDim / (NumMmaWarpGroups / AtomLayoutNdKV)>, Int<kBlockM>>,\n        Shape<Int<kHeadDim>, Int<kBlockN / AtomLayoutNdKV>, Int<kBlockM>>\n    >;\n    using AtomLayoutdKV = std::conditional_t<\n        !dKV_swapAB,\n        Layout<Shape<Int<AtomLayoutNdKV>, Int<NumMmaWarpGroups / AtomLayoutNdKV>, _1>>,\n        Layout<Shape<Int<NumMmaWarpGroups / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>\n    >;\n    using TiledMmadKV = decltype(cute::make_tiled_mma(\n        std::conditional_t<\n            Mma_dKV_is_RS,\n            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>()),\n            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, !dKV_swapAB ? PdSt_Major : GMMA::Major::MN, !dKV_swapAB ? 
GMMA::Major::MN : PdSt_Major>())\n        >{},\n        AtomLayoutdKV{}));\n\n    using TileShapeAtomdQ = std::conditional_t<\n        !dQ_swapAB,\n        Shape<Int<kBlockM>, Int<kHeadDim / (NumMmaWarpGroups / AtomLayoutMdQ)>, Int<kBlockN>>,\n        Shape<Int<kHeadDim>, Int<kBlockM / AtomLayoutMdQ>, Int<kBlockN>>\n    >;\n    using AtomLayoutdQ = std::conditional_t<\n        !dQ_swapAB,\n        Layout<Shape<Int<AtomLayoutMdQ>, Int<NumMmaWarpGroups / AtomLayoutMdQ>, _1>>,\n        Layout<Shape<Int<NumMmaWarpGroups / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>\n    >;\n    using TiledMmadQ = decltype(cute::make_tiled_mma(\n        std::conditional_t<\n            Mma_dQ_is_RS,\n            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),\n            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, !dQ_swapAB ? PdS_Major : GMMA::Major::MN, !dQ_swapAB ? GMMA::Major::MN : PdS_Major>())\n        >{},\n        AtomLayoutdQ{}));\n\n    // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory.\n    // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma.\n    // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension\n    // changes the layout.\n    using SmemLayoutAtomQdO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,\n                                       Int<kBlockM>, Int<kHeadDim / (NumMmaWarpGroups / AtomLayoutNdKV)>>()); // for dKV_Mma\n    using SmemLayoutQ =\n        decltype(tile_to_shape(SmemLayoutAtomQdO{},\n                 make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));\n    using SmemLayoutdO =\n        decltype(tile_to_shape(SmemLayoutAtomQdO{},\n                 make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages_dO>{})));\n\n    using SmemLayoutAtomK = 
decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,\n                                     Int<kBlockN>, Int<kHeadDim / (NumMmaWarpGroups / AtomLayoutMdQ)>>());\n    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));\n\n    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,\n        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());\n    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));\n\n    using SmemLayoutAtomPdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<PdS_Major, Element,\n                                       Int<kBlockM / AtomLayoutMSdP>,\n                                       Int<kBlockN / (NumMmaWarpGroups / AtomLayoutMSdP)>>());\n    using SmemLayoutPdS = decltype(tile_to_shape(\n        SmemLayoutAtomPdS{},\n        make_shape(Int<kBlockM>{}, Int<kBlockN>{}, Int<kStages_dS>{}),\n        std::conditional_t<PdS_Major == GMMA::Major::K, cute::Step<_1, _2, _3>, cute::Step<_2, _1, _3>>{}));\n\n    // Need stride to be multiple of 32, otherwise we get error (misaligned address) when doing TMA if e.g. 
kBlockM=80\n    // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds,\n    // it's still a valid smem address.\n    using SmemLayoutLSE = cute::Layout<cute::Shape<Int<kBlockM>, Int<kStages>>, cute::Stride<_1, Int<cute::round_up(kBlockM, 64)>>>;\n    using SmemLayoutLSEMma = std::conditional_t<\n        SdP_swapAB,\n        cute::Layout<cute::Shape<Int<kBlockN>, Int<kBlockM>, Int<kStages>>, cute::Stride<_0, _1, Int<cute::round_up(kBlockM, 64)>>>,\n        cute::Layout<cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kStages>>, cute::Stride<_1, _0, Int<cute::round_up(kBlockM, 64)>>>\n    >;\n\n    // Note this is the transpose in terms of the view, not in terms of memory.\n    using SmemLayoutQt =\n        decltype(cute::composition(SmemLayoutQ{},\n                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),\n                                               make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));\n    using SmemLayoutdOt =\n        decltype(cute::composition(SmemLayoutdO{},\n                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages_dO>{}),\n                                               make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));\n    using SmemLayoutKt =\n        decltype(cute::composition(SmemLayoutK{},\n                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),\n                                               make_stride(Int<kBlockN>{}, _1{}))));\n    using SmemLayoutPdSt =\n        decltype(cute::composition(SmemLayoutPdS{},\n                                   make_layout(make_shape(Int<kBlockN>{}, Int<kBlockM>{}, Int<kStages_dS>{}),\n                                               make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kBlockN>{}))));\n\n    // Thread layout, 256 or 384 threads per row\n  
  // We split into NumMmaWarpGroups so that we can do Bulk reduce add for each WG separately.\n    using R2SLayoutAtomdQaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumMmaWarpGroups>>>;\n    using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},\n                                                         Layout<Shape < _4>>{}));  // Val layout, 4 vals per store\n    using SmemLayoutdQaccum = Layout<Shape<Int<kBlockM * kHeadDim / NumMmaWarpGroups>, Int<NumMmaWarpGroups>>>;\n\n    static constexpr int kNumPdSStore = kBlockM * kBlockN / NumMmaThreads;\n    // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt.\n    // If PdS_major is MN, then we need to \"transpose\" the write.\n    using SmemCopyAtomPdS = Copy_Atom<\n        std::conditional_t<(!SdP_swapAB) ^ (PdS_Major == GMMA::Major::MN),\n            std::conditional_t<kNumPdSStore % 8 == 0, cute::SM90_U32x4_STSM_N, cute::SM90_U32x2_STSM_N>,\n            std::conditional_t<kNumPdSStore % 8 == 0, cute::SM90_U16x8_STSM_T, cute::SM90_U16x4_STSM_T>\n        >,\n        Element\n    >;\n\n    using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape{})));\n    using GmemTiledCopyKV = cute::SM90_TMA_LOAD;\n\n    using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen, d, head, batch)\n    using StrideQKV = cute::Stride<int64_t, _1, int64_t, int64_t>;\n    using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen, head, batch)\n    using StrideLSE = cute::Stride<_1, int64_t, int64_t>;  // (seqlen, head, batch)\n    using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>;  // (seqlen_q * d, head, batch)\n    using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;\n\n    using TMA_QdO = decltype(make_tma_copy_A_sm90(\n        GmemTiledCopyQdO{},\n        
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),\n        take<0, 2>(SmemLayoutQ{}),\n        TileShape_MNK{},\n        ClusterShape{})); // mcast along N mode for this M load, if any\n\n    using TMA_K = decltype(make_tma_copy_B_sm90(\n        GmemTiledCopyKV{},\n        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),\n        SmemLayoutK{},\n        TileShape_MNK{},\n        ClusterShape{})); // no mcast for KV\n\n    using TMA_V = decltype(make_tma_copy_B_sm90(\n        GmemTiledCopyKV{},\n        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),\n        SmemLayoutV{},\n        TileShape_MNK{},\n        ClusterShape{})); // no mcast for KV\n\n    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;\n    using PipelineState = typename MainloopPipeline::PipelineState;\n    using MainloopPipeline_dO = typename cutlass::PipelineTmaAsync<kStages_dO>;\n    using PipelineState_dO = typename MainloopPipeline_dO::PipelineState;\n\n    // Set the bytes transferred in this TMA transaction (may involve multiple issues)\n    static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutQ{})) * cutlass::sizeof_bits_v<Element> / 8);\n    static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(SmemLayoutK{}) * cutlass::sizeof_bits_v<Element> / 8);\n    static constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size(SmemLayoutV{}) * cutlass::sizeof_bits_v<Element> / 8);\n    static constexpr uint32_t TmaTransactionBytesLSE = static_cast<uint32_t>(size(select<0>(SmemLayoutLSE{})) * cutlass::sizeof_bits_v<ElementAccum> / 8);\n\n    // These are tuned for speed. They don't affect correctness.\n    // We have separate iterations with causal masking. 
Not necessary for hdim 128 but for hdim 64\n    // this helps quite a bit to not have to do causal masking for most of the iterations.\n    // For hdim 192, separating masking iterations results in register spills.\n    static constexpr bool SeparateMaskingIterations = kHeadDim <= 64;\n    // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then\n    // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each\n    // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep\n    // statistics for 2 rows.\n    static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64;\n    static constexpr bool ShuffledPsum = SdP_swapAB && kHeadDim <= 64;\n    static constexpr bool dQacc_use_TMA = kHeadDim < 256;\n    // For hdim256, we want to slice the dQ MMA (64 x 256 on 2 WGs) into two (64 x 128 on 2 WGs) so that we can\n    // do atomic add on one half before doing the other half of the MMA, to reduce register pressure.\n    static constexpr bool Slice_dQKV_Mma = kHeadDim == 256 && !dQacc_use_TMA && dQ_swapAB && AtomLayoutMdQ == 1 && NumMmaWarpGroups == 2;\n    static_assert(!(Deterministic && Slice_dQKV_Mma), \"Deterministic mode not supported with Slice_dQKV_Mma\");\n\n    static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{});\n    static constexpr size_t SmemAlignmentdS = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{});\n    // Without this SmemAlignment, with hdim 256 we get \"misaligned address\" error in TMA\n    static constexpr size_t SmemAlignmentQKVdO = kHeadDim % 256 == 0 ? 256 : 128;\n    static constexpr size_t SmemAlignmentV = !Mma_dP_is_RS ? 
SmemAlignmentQKVdO : cutlass::detail::alignment_for_swizzle(SmemLayoutV{});\n    static_assert(SmemAlignmentP >= 128 && SmemAlignmentdS >= 128, \"Require at least 128B alignment\");\n\n    // TODO: do we have to worry that smem_dk and smem_dv in the epilogue don't line up w smem_k and smem_v due to alignment?\n    using SmemdQacc_t = std::conditional_t<!dQacc_use_TMA, cute::array<ElementAccum, 0>, cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>>>;\n    using SmemP_t = std::conditional_t<Mma_dKV_is_RS, cute::array<Element, 0>, cute::array_aligned<Element, cute::cosize_v<SmemLayoutPdS>, SmemAlignmentP>>;\n    struct TensorStorage : cute::aligned_struct<cute::max(SmemAlignmentP, SmemAlignmentdS, SmemAlignmentQKVdO)> {\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>, SmemAlignmentQKVdO> smem_k;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>, SmemAlignmentV> smem_v;\n        SmemdQacc_t smem_dqacc;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>, SmemAlignmentQKVdO> smem_q;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>, SmemAlignmentQKVdO> smem_do;\n        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_lse;\n        cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_dpsum;\n        SmemP_t smem_p;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutPdS>, SmemAlignmentdS> smem_ds;\n    };\n\n    // Host side kernel arguments\n    struct Arguments {\n        Element const* const ptr_Q;\n        ShapeQKV const shape_Q;\n        StrideQKV const stride_Q;\n        Element const* const ptr_K;\n        ShapeQKV const shape_K;\n        StrideQKV const stride_K;\n        Element const* const ptr_V;\n        ShapeQKV const shape_V;\n        StrideQKV const stride_V;\n        Element const* const ptr_dO;\n        ShapeQKV const shape_dO;\n        StrideQKV const stride_dO;\n        ElementAccum* const ptr_dQaccum;\n        
ShapedQaccum const shape_dQaccum;\n        StridedQaccum const stride_dQaccum;\n        float const* const ptr_LSE_log2;\n        ShapeLSE const shape_LSE;\n        StrideLSE const stride_LSE_log2;\n        float const* const ptr_dPsum;\n        StrideLSE const stride_dPsum;\n        float const softmax_scale;\n        int const window_size_left, window_size_right, attention_chunk;\n        float const softcap_val;\n        int const num_batch;\n        int* const dq_semaphore;\n        int const* const cu_seqlens_q = nullptr;\n        int const* const cu_seqlens_k = nullptr;\n        int const* const seqused_q = nullptr;\n        int const* const seqused_k = nullptr;\n    };\n\n    // Device side kernel params\n    struct Params {\n        ShapeQKV const shape_Q;\n        ShapeQKV const shape_K;\n        ShapeQKV const shape_V;\n        ShapeQKV const shape_dO;\n        ElementAccum* const ptr_dQaccum;\n        ShapedQaccum const shape_dQaccum;\n        StridedQaccum stride_dQaccum;\n        cutlass::FastDivmod qhead_per_khead_divmod;\n        TMA_QdO tma_load_Q, tma_load_dO;\n        TMA_K tma_load_K;\n        TMA_V tma_load_V;\n        float const* const ptr_LSE_log2;\n        ShapeLSE const shape_LSE;\n        StrideLSE const stride_LSE_log2;\n        float const* const ptr_dPsum;\n        StrideLSE const stride_dPsum;\n        float const softmax_scale, softmax_scale_log2;\n        int const window_size_left, window_size_right;\n        cutlass::FastDivmod attention_chunk_divmod;\n        float const softcap_val;\n        int const num_batch;\n        int* const dq_semaphore;\n        int const* const cu_seqlens_q = nullptr;\n        int const* const cu_seqlens_k = nullptr;\n        int const* const seqused_q = nullptr;\n        int const* const seqused_k = nullptr;\n    };\n\n    static Params\n    to_underlying_arguments(Arguments const& args) {\n        Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q);\n        TMA_QdO 
tma_load_Q = make_tma_copy_A_sm90(\n            GmemTiledCopyQdO{},\n            mQ,\n            SmemLayoutQ{}(_, _, _0{}),\n            TileShape_MNK{},\n            ClusterShape{}); // mcast along N mode for this M load, if any\n        Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_dO, args.stride_dO);\n        TMA_QdO tma_load_dO = make_tma_copy_A_sm90(\n            GmemTiledCopyQdO{},\n            mdO,\n            SmemLayoutdO{}(_, _, _0{}),\n            TileShape_MNK{},\n            ClusterShape{}); // mcast along N mode for this M load, if any\n        Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K);\n        TMA_K tma_load_K = make_tma_copy_B_sm90(\n            GmemTiledCopyKV{},\n            mK,\n            SmemLayoutK{},\n            TileShape_MNK{},\n            ClusterShape{}); // no mcast for KV\n        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_V, args.stride_V);\n        TMA_V tma_load_V = make_tma_copy_B_sm90(\n            GmemTiledCopyKV{},\n            mV,\n            SmemLayoutV{},\n            TileShape_MNK{},\n            ClusterShape{}); // no mcast for KV\n        if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); }\n        // Avoid dividing by zero\n        cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? 
args.attention_chunk : 1);\n        attention_chunk_divmod.divisor = args.attention_chunk;\n        // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.\n        // Right after this, we multiply by log2(e) before applying exp2.\n        // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val\n        // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)\n        // (assigning it to params.softmax_scale_log2).\n        // In the backward, we need to multiply by\n        // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale.\n        // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale\n        // (the original softmax_scale) at the end.\n        return {args.shape_Q, args.shape_K, \n                args.shape_V, args.shape_dO,\n                args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum,\n                cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),\n                tma_load_Q, tma_load_dO, tma_load_K, tma_load_V,\n                args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum,\n                args.softmax_scale,\n                !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),\n                args.window_size_left, args.window_size_right, attention_chunk_divmod,\n                !Has_softcap ? 
0.f : args.softmax_scale / args.softcap_val,\n                args.num_batch, args.dq_semaphore,\n                args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k};\n    }\n\n    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance\n    CUTLASS_DEVICE\n    static void prefetch_tma_descriptors(Params const& params) {\n        cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor());\n        cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor());\n        cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor());\n        cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor());\n    }\n\n    template <typename SchedulerPrefetch, typename SharedStorage>\n    CUTLASS_DEVICE void\n    load(Params const& params,\n         MainloopPipeline pipeline_q,\n         MainloopPipeline_dO pipeline_do,\n         PipelineState& smem_pipe_write,\n         PipelineState_dO& smem_pipe_write_do,\n         SharedStorage &shared_storage,\n         SchedulerPrefetch const& scheduler_prefetch,\n         cute::tuple<int32_t, int32_t, int32_t> block_coord\n         ) {\n\n        auto [n_block, bidh, bidb] = block_coord;\n        SeqlenInfo_t seqlen_info{\n            bidb, get<0>(params.shape_Q), size<0>(params.shape_K),\n            params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k\n        };\n        auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max(\n            seqlen_info, n_block, bidb,\n            params.window_size_left, params.window_size_right, 0 /*sink_token_length*/);\n        // It's possible to have m_block_max <= m_block_min. 
Loading Q, K can cause illegal memory access.\n        if constexpr (Is_causal || Is_local || Varlen) {\n            if (m_block_max <= m_block_min) {\n                scheduler_prefetch();\n                return;\n            }\n        }\n\n        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});\n        Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{});\n        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});\n        Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});\n        Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{});\n        Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{});\n\n        int bidh_kv = params.qhead_per_khead_divmod.divide(bidh);\n\n        // Prepare the TMA loads\n        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();\n        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());\n        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};\n        bool const is_varlen_q = Varlen && params.cu_seqlens_q;\n        bool const is_varlen_k = Varlen && params.cu_seqlens_k;\n        Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);\n        Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_dO)(_, _, bidh, !is_varlen_q ? bidb : 0);\n        Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);\n        Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_V)(_, _, bidh_kv, !is_varlen_k ? 
bidb : 0);\n        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0);\n        Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0);\n\n        Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)\n        Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)\n        Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)\n        Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)\n        Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_));  // (M, _)\n        Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_));  // (M, _)\n\n        Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{}));\n        Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{}));\n        Tensor sV_x = make_tensor(sV.data(), make_layout(sV.layout(), Layout<_1>{}));\n        Tensor gV_x = make_tensor(gV.data(), make_layout(gV.layout(), Layout<_1>{}));\n        // auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout<ClusterShape>{},\n        //                                   group_modes<0, 2>(sQ), group_modes<0, 2>(gQ));  // (TMA, k), (TMA, PIPE)\n        // auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout<ClusterShape>{},\n        //                                   
group_modes<0, 2>(sdO), group_modes<0, 2>(gdO));  // (TMA, k), (TMA, PIPE)\n        auto block_tma_Q = params.tma_load_Q.get_slice(cluster_local_block_id.y);\n        auto block_tma_dO = params.tma_load_dO.get_slice(cluster_local_block_id.y);\n        Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ));\n        Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ));\n        Tensor tdOgdO = group_modes<0, 3>(block_tma_dO.partition_S(gdO));\n        Tensor tdOsdO = group_modes<0, 3>(block_tma_dO.partition_D(sdO));\n        auto [tKgK, tKsK] = tma_partition(params.tma_load_K, _0{}, Layout<_1>{},\n                                          group_modes<0, 2>(sK_x), group_modes<0, 2>(gK_x));  // (TMA), (TMA)\n        auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{},\n                                          group_modes<0, 2>(sV_x), group_modes<0, 2>(gV_x));  // (TMA), (TMA)\n        auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};\n\n        uint16_t mcast_mask_qdo = 0;\n        if constexpr (cute::is_same_v<GmemTiledCopyQdO, SM90_TMA_LOAD_MULTICAST>) {\n            auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id\n            for (int n = 0; n < size<1>(block_layout); ++n) {\n                mcast_mask_qdo |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, _0{}));\n            }\n        }\n\n        int m_block = m_block_min;\n\n        int lane_predicate = cute::elect_one_sync();\n\n        if (lane_predicate) {\n            pipeline_q.producer_acquire(smem_pipe_write);\n            copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST),\n                 tQgQ(_, m_block), tQsQ(_, smem_pipe_write.index()));\n            copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)),\n                 gLSE(_, m_block), sLSE(_, smem_pipe_write.index()));\n        }\n\n        // // Wait for the MMA warpgroups to say that 
smem_k and smem_v are ready\n        // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);\n\n        if (lane_predicate) {\n            // Copy K tile and V tile from GMEM to SMEM.\n            shared_storage.pipelines.barrier_KV.arrive_and_expect_tx(TmaTransactionBytesK + TmaTransactionBytesV);\n            copy(params.tma_load_K.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tKgK, tKsK);\n            copy(params.tma_load_V.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tVgV, tVsV);\n\n            #pragma unroll (kHeadDim < 256 ? 2 : 1)\n            for (; m_block < m_block_max - 1; ++m_block) {\n                // If Q and dO have the same number of stages, we can use the same pipeline state variable\n                // to reduce registers\n                PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return<Q_dO_same_stages>(smem_pipe_write, smem_pipe_write_do);\n                pipeline_do.producer_acquire(smem_pipe_write_do_cur);\n                copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST),\n                     tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index()));\n                copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)),\n                     gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index()));\n                if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; }\n                ++smem_pipe_write;\n                pipeline_q.producer_acquire(smem_pipe_write);\n                copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST),\n                     tQgQ(_, 
m_block + 1), tQsQ(_, smem_pipe_write.index()));\n                copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)),\n                     gLSE(_, m_block + 1), sLSE(_, smem_pipe_write.index()));\n            }\n        }\n        scheduler_prefetch();\n        if (lane_predicate) {\n            PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return<Q_dO_same_stages>(smem_pipe_write, smem_pipe_write_do);\n            pipeline_do.producer_acquire(smem_pipe_write_do_cur);\n            copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST),\n                 tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index()));\n            copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)),\n                 gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index()));\n            if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; }\n            ++smem_pipe_write;\n        }\n        if constexpr (Q_dO_same_stages) { smem_pipe_write_do = smem_pipe_write; }\n    }\n\n    /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster\n    CUTLASS_DEVICE void\n    load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do,\n              PipelineState& smem_pipe_write) {\n        static_assert(Q_dO_same_stages, \"Q and dO must have the same number of stages\");\n        // Need to copy since pipeline_q.producer_tail(smem_pipe_write) will increment smem_pipe_write\n        PipelineState smem_pipe_write_do = smem_pipe_write;\n        // Issue the epilogue waits\n        if (cute::elect_one_sync()) {\n            /* This helps avoid early exit of blocks in Cluster\n            * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used\n            * then would just be acquired since the phase was still inverted from make_producer_start_state\n            */\n        
    pipeline_q.producer_tail(smem_pipe_write);\n            pipeline_do.producer_tail(smem_pipe_write_do);\n        }\n    }\n\n    /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster\n    CUTLASS_DEVICE void\n    load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do,\n              PipelineState& smem_pipe_write, PipelineState_dO& smem_pipe_write_do) {\n        // Issue the epilogue waits\n        if (cute::elect_one_sync()) {\n            /* This helps avoid early exit of blocks in Cluster\n            * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used\n            * then would just be acquired since the phase was still inverted from make_producer_start_state\n            */\n            pipeline_q.producer_tail(smem_pipe_write);\n            pipeline_do.producer_tail(smem_pipe_write_do);\n        }\n    }\n\n    template <typename SharedStorage>\n    CUTLASS_DEVICE void\n    store_dq(Params const& params,\n             SharedStorage &shared_storage,\n             cute::tuple<int32_t, int32_t, int32_t> block_coord\n             ) {\n        if constexpr (!dQacc_use_TMA) { return; }\n\n        auto [n_block, bidh, bidb] = block_coord;\n        SeqlenInfo_t seqlen_info{\n            bidb, get<0>(params.shape_Q), size<0>(params.shape_K),\n            params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k\n        };\n        auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max(\n            seqlen_info, n_block, bidb, params.window_size_left,\n            params.window_size_right, 0 /*sink_token_length*/);\n        // It's possible to have m_block_max <= m_block_min. 
Exit early\n        // Though if local and deterministic, still need to increment dq semaphore\n        if constexpr ((Is_causal || Is_local || Varlen) && !(Is_local && Deterministic)) {\n            if (m_block_max <= m_block_min) { return; }\n        }\n\n        Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{});\n        static constexpr int dQ_TMA_num_bytes = CUTE_STATIC_V(size<0>(sdQ)) * sizeof(ElementAccum);\n\n        bool const is_varlen = Varlen && params.cu_seqlens_q;\n        Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.ptr_dQaccum)),\n                                      params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);\n        Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(_));  // (M * K, _)\n        Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int<kBlockM * kHeadDim / NumMmaWarpGroups>{});  // (M * K / WG, WG, _)\n\n        int const num_batch = params.num_batch;\n        int const num_head = get<2>(params.shape_Q);\n        int *lock_ptr = !Deterministic ? 
nullptr : params.dq_semaphore + bidb * num_head + bidh;\n        using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;\n        bool const lane_predicate = cute::elect_one_sync();\n        int m_block = m_block_min;\n        constexpr int kBlockM = get<0>(TileShape_MNK{});\n        constexpr int kBlockN = get<1>(TileShape_MNK{});\n        int n_block_global_max = cute::ceil_div(seqlen_info.seqlen_k, kBlockN);\n        #pragma unroll 2\n        for (; m_block < m_block_max; ++m_block) {\n            if constexpr (Deterministic) {\n                if constexpr(Is_causal) {\n                    int n_block_max_for_m_block = std::min(n_block_global_max, cute::ceil_div((m_block + 1) * kBlockM + seqlen_info.seqlen_k - seqlen_info.seqlen_q, kBlockN));\n                    Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block_max_for_m_block - 1 - n_block);\n                } else {\n                    Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block);\n                }\n            }\n            #pragma unroll\n            for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) {\n                cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::dQFullWG1) + warpgroup_idx /*id*/);  // sdQ full, to be written to gmem\n                if (lane_predicate) {\n                    SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdQ(_, warpgroup_idx).data()), raw_pointer_cast(gdQaccum(_, warpgroup_idx, m_block).data()), dQ_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));\n                    tma_store_arrive();\n                }\n            }\n            // Note, the for_each() function is required here to ensure `warpgroup_idx` is of type Int<x>.\n            for_each(make_int_sequence<NumMmaWarpGroups>{}, [&] (auto 
warpgroup_idx) {\n                if (lane_predicate) { tma_store_wait<NumMmaWarpGroups - 1 - CUTE_STATIC_V(warpgroup_idx)>(); }\n                cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::dQEmptyWG1) + warpgroup_idx /*id*/);  // sdQ empty, ready to be written to\n            });\n            if constexpr (Deterministic) {\n                Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);\n            }\n        }\n        if constexpr (Is_local && Deterministic) {\n            int const m_block_global_max = cute::ceil_div(seqlen_info.seqlen_q, kBlockM);\n            #pragma unroll 2\n            for (; m_block < m_block_global_max; ++m_block) {\n                Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);\n            }\n        }\n    }\n\n    CUTLASS_DEVICE void\n    mma_init() {\n        // We're not currently using this bc we're not using persistent scheduler\n        // // Tell producer (warp 0) that smem_k and smem_v are ready\n        // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);\n        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);\n        if constexpr (dQacc_use_TMA) {\n            if (warp_idx_in_warpgroup == 0) {\n                cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::dQEmptyWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/);  // sdQ empty, ready to be written to\n            }\n        }\n    }\n\n    template <typename SharedStorage, typename FrgTensordKV>\n    CUTLASS_DEVICE bool\n    mma(Params const& params,\n        MainloopPipeline pipeline_q,\n        MainloopPipeline_dO pipeline_do,\n        
PipelineState& smem_pipe_read,\n        PipelineState_dO& smem_pipe_read_do,\n        FrgTensordKV& tdKrdK,\n        FrgTensordKV& tdVrdV,\n        int thread_idx,\n        int &work_idx,\n        cute::tuple<int32_t, int32_t, int32_t> block_coord,\n        SharedStorage& shared_storage\n        ) {\n        static_assert(is_rmem<FrgTensordKV>::value, \"dK and dV tensor must be rmem resident.\");\n\n        int n_block = get<0>(block_coord);\n        int bidb = get<2>(block_coord);\n        SeqlenInfo_t seqlen_info{\n            bidb, get<0>(params.shape_Q), size<0>(params.shape_K),\n            params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k\n        };\n        auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max(\n            seqlen_info, n_block, bidb, params.window_size_left,\n            params.window_size_right, 0 /*sink_token_length*/);\n        // It's possible to have m_block_max <= m_block_min. Exit early\n        if constexpr (Is_causal || Is_local || Varlen) {\n            if (m_block_max <= m_block_min) { return false; }\n        }\n\n        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});\n        Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{});\n        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});\n        Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});\n        Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{});\n        Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{});\n        Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{});\n        Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), 
SmemLayoutPdS{});\n        Tensor sP_pi = cute::as_position_independent_swizzle_tensor(sP);\n        Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{});\n        Tensor sPt_pi = cute::as_position_independent_swizzle_tensor(sPt);\n        Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{});\n        Tensor sdS_pi = cute::as_position_independent_swizzle_tensor(sdS);\n        Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{});\n        Tensor sdSt_pi = cute::as_position_independent_swizzle_tensor(sdSt);\n        Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{});\n        Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{});\n        Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{});\n\n        static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and\n                      stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and\n                      size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and\n                      size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup,\n                      \"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup\");\n        constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup;\n        Layout warp_group_thread_layout = make_layout(make_shape(Int<MmaWarpGroups>{}),\n                                                      make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));\n        Layout warp_group_thread_layout_dq = make_layout(make_shape(Int<NumMmaWarpGroups>{}),\n                                                      
make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));\n\n        int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);\n        TiledMmaSdP tiled_mma_SdP;\n        using TiledMmadP = std::conditional_t<!Mma_dP_is_RS, TiledMmaSdP, TiledMmadPRS>;\n        TiledMmadP tiled_mma_dP;\n        TiledMmadKV tiled_mma_dKV;\n        TiledMmadQ tiled_mma_dQ;\n\n        auto wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx));\n        auto wg_mma_dP = tiled_mma_dP.get_slice(warp_group_thread_layout(warp_group_idx));\n        auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx);\n        auto wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx));\n        auto wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout_dq(warp_group_idx));\n\n        auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP);\n        auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx);\n\n        R2STiledCopydQaccum r2s_tiled_copy_dQaccum;\n        auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);\n        Tensor tdQsdQaccum = r2s_thr_copy_dQaccum.partition_D(sdQ);\n        // if (thread_idx == 0) { print(sdQ); printf(\"\\n\"); print(tdQsdQaccum); printf(\"\\n\"); }\n\n        // Allocate \"fragments/descriptors\"\n        // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda,\n        // because some partition_fragment_A/B don't compile.\n        // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function\n        Tensor tSrQ = mma_partition_fragment_AB</*A=*/!SdP_swapAB>(wg_mma_SdP, sQ);\n        Tensor tSrK = mma_partition_fragment_AB</*A=*/SdP_swapAB>(wg_mma_SdP, sK);\n        Tensor tdPrdO = mma_partition_fragment_AB</*A=*/!SdP_swapAB>(wg_mma_SdP, sdO);\n        Tensor tdPrV = 
mma_partition_fragment_AB</*A=*/SdP_swapAB>(wg_mma_dP, sV);\n        Tensor tdVrdO = mma_partition_fragment_AB</*A=*/dKV_swapAB>(wg_mma_dKV, sdOt);\n        Tensor tdKrQ = mma_partition_fragment_AB</*A=*/dKV_swapAB>(wg_mma_dKV, sQt);\n        Tensor tdQrdS = mma_partition_fragment_AB</*A=*/!dQ_swapAB>(wg_mma_dQ, sdS);\n        Tensor tdQrK = mma_partition_fragment_AB</*A=*/dQ_swapAB>(wg_mma_dQ, sKt);\n\n        Tensor tPsP = smem_thr_copy_PdS.partition_D(cute::conditional_return<!SdP_swapAB>(sP_pi, sPt_pi));      // ((Atom,AtomNum),PIPE_M,PIPE_N)\n        Tensor tdSsdS = smem_thr_copy_PdS.partition_D(cute::conditional_return<!SdP_swapAB>(sdS_pi, sdSt_pi));      // ((Atom,AtomNum),PIPE_M,PIPE_N)\n        // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_PdS); print(sP_pi); printf(\"\\n\"); print(sPt_pi); printf(\"\\n\"); print(tPsP); printf(\"\\n\"); print(tdSsdS); printf(\"\\n\"); }\n\n        // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the col indices\n        // or row indices, depending on whether SdP_swapAB.\n        Tensor tLSEsLSE = cute::conditional_return<!SdP_swapAB>(\n            group_modes<0, 2>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)),  // (2, MMA_M, PIPE)\n            group_modes<0, 3>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _, _)));  // (2, V, MMA_N, PIPE)\n        Tensor tLSEsdPsum = cute::conditional_return<!SdP_swapAB>(\n            group_modes<0, 2>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)),\n            group_modes<0, 3>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _, _)));\n        // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf(\"\\n\"); print(tLSEsLSE); printf(\"\\n\"); }\n        // If we want to split the stats among the 8 threads that share the same rows.\n        static constexpr int kStatsPerThread = 
cute::ceil_div(decltype(size(tLSEsLSE))::value, 8);\n\n        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {\n            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);\n            pipeline.consumer_wait(smem_pipe_read, barrier_token);\n        };\n\n        int bidh = get<1>(block_coord);\n        int const seqlen_q = seqlen_info.seqlen_q;\n        int const seqlen_k = seqlen_info.seqlen_k;\n\n        // For the case where we do atomicAdd directly to gdQaccum instead of using TMA\n        bool const is_varlen = Varlen && params.cu_seqlens_q;\n        Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.ptr_dQaccum)),\n                                      params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);\n        Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(_));  // (M * K, _)\n        Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int<kBlockM * kHeadDim / NumMmaWarpGroups>{});  // (M * K / WG, WG, _)\n        // We can reuse r2s_thr_copy_dQaccum for this partitioning\n        Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum);\n        // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf(\"\\n\"); print(gdQaccum_); printf(\"\\n\"); print(gdQaccum); printf(\"\\n\"); print(tdQgdQaccum); printf(\"\\n\"); }\n\n        flash::Mask<kBlockM, kBlockN, false /*PackGQA*/, TiledMmaSdP, SdP_swapAB> mask(\n            thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/,\n            params.attention_chunk_divmod, params.qhead_per_khead_divmod\n        );\n\n        int m_block = m_block_min;\n\n        clear(tdKrdK);\n        clear(tdVrdV);\n        // tiled_mma_dKV.accumulate_ = GMMA::ScaleOut::Zero;\n\n        cutlass::ConsumerToken barrier_token = 
static_cast<cutlass::BarrierStatus>(shared_storage.pipelines.barrier_KV.try_wait(work_idx % 2));\n        if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.pipelines.barrier_KV.wait(work_idx % 2); }\n\n        if constexpr (Mma_dP_is_RS) {\n            using SmemCopyAtomV = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;\n            auto smem_tiled_copy_V = make_tiled_copy_A(SmemCopyAtomV{}, tiled_mma_dP);\n            auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx);\n            Tensor tdPrV_copy_view = smem_thr_copy_V.retile_D(tdPrV);\n            Tensor tdPsV_copy_view = smem_thr_copy_V.partition_S(cute::as_position_independent_swizzle_tensor(sV));\n            cute::copy(smem_tiled_copy_V, tdPsV_copy_view, tdPrV_copy_view);\n        }\n\n        auto bwd_step = [&](int m_block, auto mask_fn) {\n            Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<!SdP_swapAB ? 0 : 1, !SdP_swapAB ? 1 : 0>(TileShape_MNK{}));\n            consumer_wait(pipeline_q, smem_pipe_read);\n            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1, /*SwapAB=*/SdP_swapAB>(tiled_mma_SdP, tSrQ(_, _, _, smem_pipe_read.index()), tSrK, tSrS);\n            Tensor tLSErLSE = cute::conditional_return<!ShuffleLSE>(make_fragment_like(tLSEsLSE(_, _0{})), make_tensor<ElementAccum>(Int<kStatsPerThread>{}));\n            if constexpr (!ShuffleLSE) {\n                cute::copy(tLSEsLSE(_, smem_pipe_read.index()), tLSErLSE);\n            } else {\n                #pragma unroll\n                for (int i = 0; i < kStatsPerThread; ++i) {\n                    // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values\n                    tLSErLSE(i) = tLSEsLSE((thread_idx % 32) / 4 + i * 8, smem_pipe_read.index());\n                }\n            }\n            Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<!SdP_swapAB ? 0 : 1, !SdP_swapAB ? 
1 : 0>(TileShape_MNK{}));\n            PipelineState_dO smem_pipe_read_do_cur = cute::conditional_return<Q_dO_same_stages>(smem_pipe_read, smem_pipe_read_do);\n            consumer_wait(pipeline_do, smem_pipe_read_do_cur);\n            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1, /*SwapAB=*/SdP_swapAB>(tiled_mma_dP, tdPrdO(_, _, _, smem_pipe_read_do_cur.index()), tdPrV, tdPrdP);\n            warpgroup_wait<1>();\n            if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); }\n\n            // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))\n            Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SdP_swapAB>(tSrS.layout()));\n            // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh\n            auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }();\n            mask_fn(tSrS, m_block);\n            #pragma unroll\n            for (int mi = 0; mi < size<0>(scores); ++mi) {\n                float const lse_scaled = [&] {\n                    if constexpr (!ShuffleLSE) return tLSErLSE(mi);\n                    else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4));\n                }();\n                #pragma unroll\n                for (int ni = 0; ni < size<1>(scores); ++ni) {\n                    scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled);\n                }\n            }\n\n            Tensor tLSErdPsum = cute::conditional_return<!ShuffledPsum>(make_fragment_like(tLSEsdPsum(_, _0{})), make_tensor<ElementAccum>(Int<kStatsPerThread>{}));\n            if constexpr (!ShuffledPsum) {\n                cute::copy(tLSEsdPsum(_, smem_pipe_read_do_cur.index()), tLSErdPsum);\n            } else {\n                #pragma unroll\n                for (int i = 0; i < kStatsPerThread; ++i) {\n        
            tLSErdPsum(i) = tLSEsdPsum((thread_idx % 32) / 4 + i * 8, smem_pipe_read_do_cur.index());\n                }\n            }\n\n            warpgroup_wait<0>();\n            // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))\n            Tensor dS = make_tensor(tdPrdP.data(), scores.layout());\n            #pragma unroll\n            for (int mi = 0; mi < size<0>(dS); ++mi) {\n                float const dP_sum_cur = [&] {\n                    if constexpr (!ShuffledPsum) return tLSErdPsum(mi);\n                    else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4));\n                }();\n                #pragma unroll\n                for (int ni = 0; ni < size<1>(dS); ++ni) {\n                    dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur);\n                    if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); }\n                }\n            }\n\n            // Convert scores from fp32 to fp16/bf16\n            Tensor rP = make_tensor_like<Element>(tSrS);\n            flash::convert_type_out(tSrS, rP);\n            if constexpr (!Mma_dKV_is_RS) {\n                // Need to sync to make sure P has already been used in the previous iteration before writing new values\n                if constexpr (kStages_dS == 1) {\n                    cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(BwdNamedBarriers::PdS) /*id*/);\n                }\n                Tensor tPaP = smem_thr_copy_PdS.retile_S(rP);     // ((Atom,AtomNum), MMA_N, MMA_N)\n                cute::copy(smem_tiled_copy_PdS, tPaP, tPsP(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index())));\n            }\n            Tensor rdS = make_tensor_like<Element>(tdPrdP);\n            flash::convert_type_out(tdPrdP, rdS);\n            // If there's double buffering on dS, we don't need to sync here.\n            // Otherwise we might have WG1 writing to dS 
before WG2 is done reading from it during MmadQ.\n            // But because both WGs have to sync at the end of the loop, and because of double buffering,\n            // this race condition is not possible.\n            // This sync is to ensure (1) P is written in case of !Mma_dKV_is_RS and\n            // (2) dS is already read by the Mma in the previous iteration in case of Mma_dKV_is_RS.\n            if constexpr (!Mma_dKV_is_RS || (kStages_dS == 1 && Mma_dKV_is_RS)) {\n                cutlass::arch::fence_view_async_shared();\n                cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(BwdNamedBarriers::PdS) /*id*/);\n            }\n            // For hdim 64, it's faster to write to smem_dS first before the dV gemm\n            Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS);     // ((Atom,AtomNum), MMA_N, MMA_N)\n            cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index())));\n\n            if constexpr (!Slice_dQKV_Mma) {\n                // Most cases take this path, except for hdim256 where we want to slice to reduce register pressure\n                if constexpr (Mma_dKV_is_RS) {\n                    Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<TiledMmadKV>(tSrS.layout()));\n                    flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV);\n                } else {\n                    Tensor tdVrP = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(wg_mma_dKV, sPt);\n                    Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));\n                    flash::gemm</*zero_init=*/false, /*wg_wait=*/-1, /*SwapAB=*/dKV_swapAB>(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV);\n                }\n                // SMEM fence to make sure sdS is written before it's read by WGMMA\n 
               cutlass::arch::fence_view_async_shared();\n                cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(BwdNamedBarriers::PdS) /*id*/);\n                Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));\n                Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));\n                flash::gemm</*zero_init=*/true, /*wg_wait=*/1, /*SwapAB=*/dQ_swapAB>(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ);\n                pipeline_do.consumer_release(smem_pipe_read_do_cur);  // release dO\n\n                if constexpr (Mma_dKV_is_RS) {\n                    Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<TiledMmadKV>(tdPrdP.layout()));\n                    flash::gemm</*zero_init=*/false, /*wg_wait=*/1>(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);\n                } else {\n                    Tensor tdKrdS = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(wg_mma_dKV, sdSt);\n                    Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));\n                    flash::gemm</*zero_init=*/false, /*wg_wait=*/1, /*SwapAB=*/dKV_swapAB>(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);\n                }\n                if constexpr (dQacc_use_TMA) {\n                    int const warp_group_idx = flash::canonical_warp_group_idx_nosync() - 1;\n                    cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::dQEmptyWG1) + warp_group_idx /*id*/);  // sdQ full, to be written to gmem\n                    Tensor taccdQrdQ = r2s_thr_copy_dQaccum.retile_S(tdQrdQ);\n                    cute::copy(r2s_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum);\n                    cutlass::arch::fence_view_async_shared();\n  
                  cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::dQFullWG1) + warp_group_idx /*id*/);  // sdQ full, to be written to gmem\n                } else {\n                    // We can reuse r2s_thr_copy_dQaccum for this partitioning\n                    Tensor tdQrdQ_atomic = recast<float4>(r2s_thr_copy_dQaccum.retile_S(tdQrdQ));\n                    Tensor tdQgdQaccum_atomic = recast<float4>(tdQgdQaccum(_, _, _, m_block));\n                    static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic)));\n                    #pragma unroll\n                    for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }\n                }\n\n            } else {  // Slice_dQKV_Mma\n\n                static_assert(!(Slice_dQKV_Mma && Mma_dKV_is_RS));\n                Tensor tdVrP = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(wg_mma_dKV, sPt);\n                Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));\n                flash::gemm</*zero_init=*/false, /*wg_wait=*/-1, /*SwapAB=*/dKV_swapAB, /*M_slice=*/0>(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV);\n\n                cutlass::arch::fence_view_async_shared();\n                cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(BwdNamedBarriers::PdS) /*id*/);\n                Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 
2 : 0>(TileShape_MNK{}));\n                Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));\n                flash::gemm</*zero_init=*/true, /*wg_wait=*/-1, /*SwapAB=*/dQ_swapAB, /*M_slice=*/0>(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ);\n                flash::gemm</*zero_init=*/false, /*wg_wait=*/1, /*SwapAB=*/dKV_swapAB, /*M_slice=*/1>(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV);\n                Tensor tdQrdQ_atomic = recast<float4>(r2s_thr_copy_dQaccum.retile_S(tdQrdQ));\n                Tensor tdQgdQaccum_atomic = recast<float4>(tdQgdQaccum(_, _, _, m_block));\n                #pragma unroll\n                for (int i = 0; i < size(tdQrdQ_atomic) / 2; ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }\n\n                Tensor tdKrdS = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(wg_mma_dKV, sdSt);\n                Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));\n                flash::gemm</*zero_init=*/false, /*wg_wait=*/1, /*SwapAB=*/dKV_swapAB, /*M_slice=*/0>(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);\n                pipeline_do.consumer_release(smem_pipe_read_do_cur);  // release dO\n\n                flash::gemm</*zero_init=*/true, /*wg_wait=*/0, /*SwapAB=*/dQ_swapAB, /*M_slice=*/1>(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ);\n                #pragma unroll\n                for (int i = size(tdQrdQ_atomic) / 2;  i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }\n\n                flash::gemm</*zero_init=*/false, /*wg_wait=*/-1, /*SwapAB=*/dKV_swapAB, /*M_slice=*/1>(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);\n            }\n\n            warpgroup_wait<0>();\n            pipeline_q.consumer_release(smem_pipe_read);   // release Q\n            ++smem_pipe_read;\n            if constexpr 
(!Q_dO_same_stages) { ++smem_pipe_read_do; }\n        };\n\n        // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64\n        // this helps quite a bit to not have to do causal masking for most of the iterations.\n        if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) {\n            auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };\n            static constexpr int kBlockM = get<0>(TileShape_MNK{});\n            int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1;\n            CUTLASS_PRAGMA_NO_UNROLL\n            for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) {\n                bwd_step(m_block, mask_fn);\n            }\n        }\n\n        static constexpr int kBlockM = get<0>(TileShape_MNK{});\n        static constexpr int kBlockN = get<1>(TileShape_MNK{});\n        int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations\n            ? 
m_block_max\n            : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM);\n\n        auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal && !SeparateMaskingIterations, Is_local && !SeparateMaskingIterations>(tSrS, m_block, n_block); };\n        CUTLASS_PRAGMA_NO_UNROLL\n        for (; m_block < m_block_max_before_local_mask; ++m_block) {\n            bwd_step(m_block, mask_fn);\n        }\n\n        if constexpr (Is_local && SeparateMaskingIterations) {\n            auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, false /*Causal_mask*/, Is_local>(tSrS, m_block, n_block); };\n            CUTLASS_PRAGMA_NO_UNROLL\n            for (; m_block < m_block_max; ++m_block) {\n                bwd_step(m_block, mask_fn);\n            }\n        }\n\n        // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); }\n        #pragma unroll\n        for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; }\n\n        if constexpr (Q_dO_same_stages) { smem_pipe_read_do = smem_pipe_read; }\n        ++work_idx;\n        return true;\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/mainloop_fwd_sm80.hpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n#include <cutlass/numeric_conversion.h>\n\n#include \"cute/tensor.hpp\"\n\n#include \"seqlen.h\"\n#include \"block.h\"\n#include \"mask.h\"\n#include \"pack_gqa.h\"\n#include \"paged_kv.h\"\n#include \"rotary.h\"\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <int kNWarps, int Stages, bool Q_in_regs, class TileShape_MNK_, int kHeadDimV, class Element_, class ElementAccum_, class ArchTag_,\n        bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool PagedKV_, bool AppendKV_,\n        bool PackGQA_, bool Split_>\nstruct CollectiveMainloopFwdSm80 {\n\n    static constexpr int kStages = Stages;\n    static_assert(kStages > 0, \"kStages must be greater than 0\");\n    using TileShape_MNK = TileShape_MNK_;\n    using TileShape_MNK_PV = Shape<decltype(get<0>(TileShape_MNK{})), Int<kHeadDimV>, decltype(get<1>(TileShape_MNK{}))>;\n    using Element = Element_;\n    using ElementAccum = ElementAccum_;\n    using ArchTag = ArchTag_;\n    static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;;\n    static constexpr bool Is_causal = Is_causal_;\n    static constexpr bool Is_local = Is_local_;\n    static constexpr bool Has_softcap = Has_softcap_;\n    static constexpr bool Varlen = Varlen_;\n    static constexpr bool PagedKV = PagedKV_;\n    static constexpr bool AppendKV = AppendKV_;\n    static constexpr bool PackGQA = PackGQA_;\n    static constexpr bool Split = Split_;\n    static constexpr bool Transpose_V = Is_FP8;\n\n    static_assert(ArchTag::kMinComputeCapability >= 80);\n\n    static constexpr bool Has_cp_async = 
ArchTag::kMinComputeCapability >= 80;\n\n    static constexpr int kBlockM = get<0>(TileShape_MNK{});\n    static constexpr int kBlockN = get<1>(TileShape_MNK{});\n    static constexpr int kHeadDim = get<2>(TileShape_MNK{});\n\n    using SeqlenInfo_t = flash::SeqlenInfoQKNewK<Varlen, AppendKV>;\n    using BlockMN_t = flash::BlockMN<SeqlenInfo_t, kBlockM, kBlockN, Is_causal, Is_local, PackGQA, Split>;\n\n    using MMA_Atom_Arch = std::conditional_t<\n        ArchTag::kMinComputeCapability >= 80,\n        std::conditional_t<\n            std::is_same_v<Element, cutlass::half_t>,\n            MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,\n            MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>\n        >,\n        MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>\n    >;\n    using TiledMma = TiledMMA<\n        MMA_Atom_Arch,\n        Layout<Shape<Int<kNWarps>,_1,_1>>,  // 4x1x1 or 8x1x1 thread group\n        Tile<Int<16 * kNWarps>, _16, _16>>;\n\n    static constexpr int NumMmaThreads = size(TiledMma{});\n    static constexpr int NumProducerThreads = NumMmaThreads;  // For compatibility with TileScheduler\n\n    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);\n    static_assert(kHeadDim % kGmemElemsPerLoad == 0, \"Headdim must be a multiple of kGmemElemsPerLoad\");\n    // We want each \"row\" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each\n    // thread to have 4 loads in the M direction and 2 vectorized loads in the K direction.\n    static constexpr int kBytePerRow = kHeadDim * sizeof(Element);\n    static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);\n\n    static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));\n    static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 
3 : 4);\n    using SmemLayoutAtomQKV = decltype(\n        composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},\n                    Layout<Shape<_8, Int<kBlockKGmem>>,\n                           Stride<Int<kBlockKGmem>, _1>>{}));\n    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQKV{}, select<0, 2>(TileShape_MNK{})));\n\n    using SmemLayoutK = decltype(tile_to_shape(\n        SmemLayoutAtomQKV{},\n        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));\n\n    using SmemLayoutV = decltype(tile_to_shape(\n        SmemLayoutAtomQKV{},\n        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));\n    using SmemLayoutVt = decltype(\n        composition(SmemLayoutV{},\n                    make_ordered_layout(make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{}),\n                                        Step<_2, _1, _3>{})));\n\n    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, Element>;\n    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, Element>;\n\n    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading\n    // from the same address by the same threadblock. 
This is slightly faster.\n    using GmemCopyAtom = Copy_Atom<std::conditional_t<\n        Has_cp_async,\n        SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<cute::uint128_t>,\n        AutoVectorizingCopyWithAssumedAlignment<128>\n    >, Element>;\n\n    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;\n    static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, \"NumMmaThreads must be a multiple of kGmemThreadsPerRow\");\n    using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                                  Stride<Int<kGmemThreadsPerRow>, _1>>;\n    using GmemTiledCopyQKV = decltype(\n        make_tiled_copy(GmemCopyAtom{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per read\n    // So that we don't have to check if we overshot kBlockM when we load Q\n    static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);\n\n    // For AppendKV, we want each thread to have at least 2 loads in the K direction since in the case of\n    // non-interleaved rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc.),\n    // each thread will load twice from the same row.\n    static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element);\n    static constexpr int kBlockKGmemAppend = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element);\n    static constexpr int kGmemThreadsPerRowAppend = kBlockKGmemAppend / kGmemElemsPerLoad;\n    static_assert(NumMmaThreads % kGmemThreadsPerRowAppend == 0, \"NumMmaThreads must be a multiple of kGmemThreadsPerRowAppend\");\n    // We assume threads loading the same row are in the same warp. 
This is for an optimization in PagedKV where\n    // these threads share the same page table entry and share the work of computing pointers to paged K and paged V.\n    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRowAppend == 0, \"kGmemThreadsPerRowAppend must divide NumThreadsPerWarp\");\n    using GmemLayoutAtomAppend = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRowAppend>, Int<kGmemThreadsPerRowAppend>>,\n                                        Stride<Int<kGmemThreadsPerRowAppend>, _1>>;\n    // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication\n    static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtomAppend{})) == 0, \"kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRowAppend\");\n    using GmemTiledCopyAppendKV = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        GmemLayoutAtomAppend{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per store\n\n    using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen, d, head, batch)\n    using StrideQK = cute::Stride<int64_t, _1, int64_t, int64_t>;\n    using StrideV = StrideQK;\n    // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)\n    using ShapeQPacked = std::conditional_t<!PackGQA, ShapeQKV, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;\n    using StrideQPacked = std::conditional_t<!PackGQA, StrideQK, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t>>;\n    using ShapePageTable = cute::Shape<int32_t, int32_t>;  // (batch, max_num_pages_per_seq)\n    using StridePageTable = cute::Stride<int64_t, _1>;\n    using ShapeRotary = cute::Shape<int32_t, int32_t>;  // (seqlen_ro, rotary_dim // 2)\n    using StrideRotary = cute::Stride<int64_t, _1>;\n    using StrideDescale = cute::Stride<int64_t, int64_t>;\n\n    static 
constexpr bool Share_QV_Smem = Q_in_regs;\n\n    struct TensorStorageSharedQV : cute::aligned_struct<128> {\n        union {\n            cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;\n            cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;\n        };\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;\n    };\n\n    struct TensorStorageSeparateQV : cute::aligned_struct<128> {\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;\n    };\n\n    using TensorStorage = std::conditional_t<Share_QV_Smem, TensorStorageSharedQV, TensorStorageSeparateQV>;\n\n    // Host side kernel arguments\n    struct Arguments {\n        Element const* const ptr_Q;\n        ShapeQKV const shape_Q;\n        StrideQK const stride_Q;\n        Element* const ptr_K;  // Not Element const* since we might append to KV cache in-place\n        ShapeQKV const shape_K;\n        StrideQK const stride_K;\n        Element* const ptr_V;\n        int32_t const headdim_v;\n        StrideV const stride_V;\n        Element const* const ptr_K_new;\n        ShapeQKV const shape_K_new;\n        StrideQK const stride_K_new;\n        Element const* const ptr_V_new;\n        StrideV const stride_V_new;\n        Element const* const ptr_Qv;\n        StrideQK const stride_Qv;\n        Element const* const ptr_rotary_cos;\n        ShapeRotary const shape_rotary;\n        StrideRotary const stride_rotary_cos;\n        Element const* const ptr_rotary_sin;\n        StrideRotary const stride_rotary_sin;\n        bool const is_rotary_interleaved;\n        int const* const ptr_pagetable;\n        ShapePageTable const shape_pagetable;\n        StridePageTable const stride_pagetable;\n        float const softmax_scale;\n        float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale;\n  
      StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale;\n        int const window_size_left = -1, window_size_right = -1, attention_chunk = 0;\n        float const softcap_val;\n        int const num_splits;\n        int const* const kv_batch_idx = nullptr;\n        int const* const cu_seqlens_q = nullptr;\n        int const* const cu_seqlens_k = nullptr;\n        int const* const cu_seqlens_k_new = nullptr;\n        int const* const seqused_q = nullptr;\n        int const* const seqused_k = nullptr;\n        int const* const leftpad_k = nullptr;\n        int const* const seqlens_rotary = nullptr;\n    };\n\n    // Device side kernel params\n    struct Params {\n        Element const* const ptr_Q;\n        ShapeQKV const shape_Q;\n        StrideQK const stride_Q;\n        ShapeQPacked const shape_Q_packed;\n        StrideQPacked const stride_Q_packed;\n        Element* const ptr_K;\n        ShapeQKV const shape_K;\n        StrideQK const stride_K;\n        Element* const ptr_V;\n        int32_t const headdim_v;\n        StrideV const stride_V;\n        Element const* const ptr_K_new;\n        ShapeQKV const shape_K_new;\n        StrideQK const stride_K_new;\n        Element const* const ptr_V_new;\n        StrideV const stride_V_new;\n        Element const* const ptr_rotary_cos;\n        ShapeRotary const shape_rotary;\n        StrideRotary const stride_rotary_cos;\n        Element const* const ptr_rotary_sin;\n        StrideRotary const stride_rotary_sin;\n        bool const is_rotary_interleaved;\n        int const* const ptr_pagetable;\n        ShapePageTable const shape_pagetable;\n        StridePageTable const stride_pagetable;\n        cutlass::FastDivmod page_size_divmod;\n        cutlass::FastDivmod qhead_per_khead_divmod;\n        float const softmax_scale_log2;\n        float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale;\n        StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale;\n        float const 
softcap_val;\n        int const window_size_left, window_size_right;\n        cutlass::FastDivmod attention_chunk_divmod;\n        int const num_splits;\n        int const* const kv_batch_idx = nullptr;\n        int const* const cu_seqlens_q = nullptr;\n        int const* const cu_seqlens_k = nullptr;\n        int const* const cu_seqlens_k_new = nullptr;\n        int const* const seqused_q = nullptr;\n        int const* const seqused_k = nullptr;\n        int const* const leftpad_k = nullptr;\n        int const* const seqlens_rotary = nullptr;\n    };\n\n    static Params\n    to_underlying_arguments(Arguments const& args) {\n        // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size)\n        int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K));\n        auto const shape_Q_packed = cute::conditional_return<!PackGQA>(\n            args.shape_Q,\n            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q))\n        );\n        auto const stride_Q_packed = cute::conditional_return<!PackGQA>(\n            args.stride_Q,\n            make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q))\n        );\n        if (get<1>(args.shape_rotary) > 0) {\n            assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr);\n        }\n        assert(args.num_splits >= 1);\n        // Avoid dividing by zero\n        cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? 
args.attention_chunk : 1);\n        attention_chunk_divmod.divisor = args.attention_chunk;\n        // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.\n        // Right after this, we multiply by log2(e) before applying exp2.\n        // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val\n        // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)\n        // (assigning it to params.softmax_scale_log2).\n        return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed,\n                args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V,\n                args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new,\n                args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos,\n                args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved,\n                args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable,\n                cutlass::FastDivmod(int(get<0>(args.shape_K))),\n                cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),\n                !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),\n                args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale,\n                args.stride_q_descale, args.stride_k_descale, args.stride_v_descale,\n                !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val,\n                args.window_size_left, args.window_size_right, attention_chunk_divmod,\n                !Split ? 
1 : args.num_splits,\n                args.kv_batch_idx,\n                args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,\n                args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary};\n    }\n\n    template <typename SharedStorage, typename FrgTensorO, typename Softmax>\n    CUTLASS_DEVICE bool\n    mma(Params const& params,\n        FrgTensorO& tOrO,\n        Softmax& softmax,\n        int const thread_idx,\n        SeqlenInfo_t const& seqlen_info,\n        cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,\n        SharedStorage& shared_storage\n        ) {\n        static_assert(is_rmem<FrgTensorO>::value, \"O tensor must be rmem resident.\");\n        static constexpr int kBlockM = get<0>(TileShape_MNK{});\n        static constexpr int kBlockN = get<1>(TileShape_MNK{});\n\n        // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda\n        int const m_block = get<0>(block_coord);\n        int const bidh = get<1>(block_coord);\n        int const bidb = get<2>(block_coord);\n        int const split_idx = get<3>(block_coord);\n        int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;\n        auto n_block_min_max = BlockMN_t::get_n_block_min_max(\n            seqlen_info, m_block, bidb, split_idx, params.num_splits,\n            params.window_size_left, params.window_size_right, params.attention_chunk_divmod,\n            params.qhead_per_khead_divmod);\n        int const n_block_min = get<0>(n_block_min_max);\n        int const n_block_max = get<1>(n_block_min_max);\n        // It's possible to have n_block_max <= n_block_min. 
We don't want to load Q or change any barrier\n        if constexpr (Is_causal || Is_local || Varlen || Split) {\n            if (n_block_max <= n_block_min) { return false; }\n        }\n\n        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});\n        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});\n        Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});\n        Tensor sVt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{});\n\n        bool const is_varlen_q = Varlen && params.cu_seqlens_q;\n        bool const is_varlen_k = Varlen && params.cu_seqlens_k;\n\n        int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb];\n        Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0);\n        Tensor gQ = local_tile(mQ, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)\n        Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K + seqlen_info.offset_k * get<0>(params.stride_K)), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);\n        Tensor gK = local_tile(mK, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)\n        Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V + seqlen_info.offset_k * get<0>(params.stride_V)), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? 
bidb_kv : 0);\n        Tensor gV = local_tile(mV, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)\n\n        GmemTiledCopyQKV gmem_tiled_copy_QKV;\n        auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx);\n        auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{});  // For index calculation\n\n        Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K, nblocksN)\n        Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);\n        Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K, nblocksN)\n        Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);\n\n        TiledMma tiled_mma;\n        auto thr_mma = tiled_mma.get_slice(thread_idx);\n\n        // Allocate \"fragments/descriptors\"\n        Tensor tSrQ = thr_mma.partition_fragment_A(sQ);\n\n        // Copy Atom retiling\n        auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma);\n        auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx);\n        auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma);\n        auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(thread_idx);\n        auto smem_tiled_copy_V = make_tiled_copy_B(SmemCopyAtomTransposed{}, tiled_mma);\n        auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx);\n        Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);\n        Tensor tSsK = smem_thr_copy_K.partition_S(sK);\n        Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);\n\n        // Predicates\n        Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));\n        Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);\n        Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV);\n        Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));\n        #pragma unroll\n        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); }\n\n  
      int const seqlen_q = seqlen_info.seqlen_q;\n        int const seqlen_k = seqlen_info.seqlen_k;\n        int n_block = n_block_max - 1;\n\n        // Prologue: load Q, K, V\n        // If persistent, we don't need to wait for the previous work_idx to finish\n        // since we assume that all MMA threads sync in the epilogue before writing to smem_o.\n        // So if any thread gets there, all threads must have finished the previous MMA and at least started\n        // writing to smem_o.\n        // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v\n        if constexpr (Share_QV_Smem) { __syncthreads(); }\n        if constexpr (!PackGQA) {\n            Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);\n            Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);\n            Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));\n            Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);\n            Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ);\n            Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));\n            #pragma unroll\n            for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); }\n            // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit\n            // (seqlen_q - m_block * kBlockM). 
This is because the entries of t0QcQ are known at compile time.\n            // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs\n            flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/true>(\n                gmem_tiled_copy_QKV, tQgQ, tQsQ, t0QcQ, tQpQ, seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{}))\n            );\n        } else {\n            using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element>;\n            PackGQAt::load_Q(mQ, sQ, params.qhead_per_khead_divmod, thread_idx, seqlen_q, m_block);\n        }\n        cute::cp_async_fence();\n\n        using PagedKVManager_t = PagedKVManager<get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>;\n        PagedKVManager_t paged_kv_manager(\n            params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable,\n            params.ptr_K, params.shape_K, params.stride_K,\n            params.ptr_V, params.headdim_v, params.stride_V,\n            params.page_size_divmod,\n            params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/,\n            bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k,\n            0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/\n        );\n\n        auto load_K = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) {\n            static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;\n            if constexpr (!PagedKV) {\n                // Do we need bound check to make sure the row doesn't go above kBlockN\n                static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;\n                Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_write);\n                // Instead of 
passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit\n                // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time.\n                int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN\n                    ? seqlen_info.seqlen_k - n_block * kBlockN\n                    : (!Seqlenk_mask ? kBlockN : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN)));\n                // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.\n                flash::copy</*Is_even_MN=*/!Seqlenk_mask && EvenN, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/true>(\n                    gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK_cur, t0KVcKV, tKVpKV, seqlenk_row_limit);\n            } else {\n                paged_kv_manager.template load_page_table<Seqlenk_mask>(n_block);\n                paged_kv_manager.template load_K<Seqlenk_mask>(n_block, sK(_, _, smem_pipe_write));\n            }\n        };\n\n        auto load_V = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) {\n            static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;\n            if constexpr (!PagedKV) {\n                // Do we need bound check to make sure the row doesn't go above kBlockN\n                static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;\n                Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_write);\n                // We don't call flash::copy since it doesn't support bound checking\n                // to avoid overshooting kBlockN when writing to smem.\n                Tensor tVgV_cur = tVgV(_, _, _, n_block);\n                int const seqlenk_row_limit = seqlen_info.seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{}));\n                #pragma unroll\n                for (int m = 0; m < size<1>(tVsV); ++m) {\n                   
 // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked\n                    if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) {\n                        bool const predicate_n = !Seqlenk_mask || get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit;\n                        #pragma unroll\n                        for (int k = 0; k < size<2>(tVsV); ++k) {\n                            cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV_cur(_, m, k), tVsV_cur(_, m, k));\n                        }\n                    }\n                }\n            } else {\n                paged_kv_manager.template load_V<Seqlenk_mask>(n_block, sV(_, _, smem_pipe_write));\n            }\n        };\n\n        auto preprocess_Q = [&] {\n            if constexpr (!AppendKV) {\n                flash::cp_async_wait<Share_QV_Smem ? 1 : kStages * 2 - 1>();\n            } else {\n                if (get<1>(params.shape_rotary) > 0) {  // Apply rotary to Q\n                    using Rotary_t = Rotary<kBlockM, kHeadDim, NumMmaThreads, Element, !(Is_causal || Is_local) /*FixedPosition*/>;\n                    Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos,\n                                    params.ptr_rotary_sin, params.stride_rotary_sin,\n                                    params.is_rotary_interleaved, thread_idx, seqlen_q,\n                                    seqlen_info.seqlen_rotary);\n                    int const qhead_per_khead = !PackGQA ? 
1 : params.qhead_per_khead_divmod.divisor;\n                    if (params.is_rotary_interleaved) {\n                        auto [tRrCos, tRrSin] = cute::conditional_return<!PackGQA>(\n                            rotary.template load_cos_sin<true /*kInterleaved*/>(m_block),\n                            rotary.template load_cos_sin_packgqa<true /*kInterleaved*/>(m_block, params.qhead_per_khead_divmod)\n                        );\n                        flash::cp_async_wait<Share_QV_Smem ? 1 : kStages * 2 - 1>();\n                        __syncthreads();\n                        rotary.apply_Q_interleaved(sQ, tRrCos, tRrSin, m_block, qhead_per_khead);\n                    } else {\n                        auto [tRrCosCont, tRrSinCont] = cute::conditional_return<!PackGQA>(\n                            rotary.template load_cos_sin<false /*kInterleaved*/>(m_block),\n                            rotary.template load_cos_sin_packgqa<false /*kInterleaved*/>(m_block, params.qhead_per_khead_divmod)\n                        );\n                        flash::cp_async_wait<Share_QV_Smem ? 1 : kStages * 2 - 1>();\n                        __syncthreads();\n                        rotary.apply_Q_contiguous(sQ, tRrCosCont, tRrSinCont, m_block, qhead_per_khead);\n                    }\n                } else {\n                    flash::cp_async_wait<Share_QV_Smem ? 
1 : kStages * 2 - 1>();\n                }\n            }\n\n            if constexpr (Q_in_regs) {\n                __syncthreads();\n                Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);\n                Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(sQ);\n                cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view);\n            }\n        };\n\n        // If Share_QV_Smem, we load Q, then load 1 stage of K, then (optionally) rotate Q and\n        // read from smem_q to registers, then load V.\n        // If !Share_QV_Smem, we load Q, load all stages of K & V, then (optionally) rotate Q.\n\n        if constexpr (Share_QV_Smem) {\n            load_K(n_block, 0, cute::true_type{} /*Seqlenk_mask*/);\n            cute::cp_async_fence();\n            preprocess_Q();\n            __syncthreads();  // Make sure all threads have read smem_q before loading V\n        }\n\n        // For persistent, make sure all threads have finished reading smem_o\n        if constexpr (!Share_QV_Smem) { __syncthreads(); }\n        // Note, using the for_each() function here to ensure `stage` is of type Int<x>.\n        for_each(make_int_sequence<kStages>{}, [&] (auto stage) {\n            static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0;\n            static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1;\n            if constexpr (!Share_QV_Smem || !Is_first_stage) {\n                if (Is_first_stage || n_block - stage >= n_block_min) {\n                    load_K(n_block - stage, stage, cute::bool_constant<Is_first_stage>{} /*Seqlenk_mask*/);\n                }\n                // We want the fence outside the if statement to have a fixed number of cp.async commits.\n                // so that we can wait with the correct number of outstanding commits.\n                cute::cp_async_fence();\n            }\n            if constexpr (!Is_last_stage) {\n                if (Is_first_stage || n_block - stage 
>= n_block_min) {\n                    load_V(n_block - stage, stage, cute::bool_constant<Is_first_stage>{} /*Seqlenk_mask*/);\n                }\n                cute::cp_async_fence();\n            }\n        });\n\n        if constexpr (!Share_QV_Smem) { preprocess_Q(); }\n\n        flash::Mask<kBlockM, kBlockN, PackGQA, TiledMma> mask(\n            thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/,\n            params.attention_chunk_divmod, params.qhead_per_khead_divmod\n        );\n\n        float softcap_val = params.softcap_val;\n        if constexpr (Has_softcap && Is_FP8) {\n            float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)];\n            float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)];\n            softcap_val *= q_descale * k_descale;\n        }\n        // Softcapping needs to happen before masking since if we apply after masking, softcapping can turn\n        // -inf to e.g. -50.0, which can affect the attention softmax.\n        auto scoremod_premask_fn = [&](auto& tSrS) {\n            if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); }\n        };\n\n        int smem_pipe_read = 0, smem_pipe_write = kStages - 1;\n\n        auto load_K_next = [&] {\n            if (n_block - kStages >= n_block_min) {\n                load_K(n_block - kStages, kStages > 1 ? 
smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/);\n            }\n            cute::cp_async_fence();\n        };\n\n        auto sync = [&] {\n            flash::cp_async_wait<kStages * 2 - 2>();\n            __syncthreads();\n        };\n\n        clear(tOrO);\n\n        auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) {\n            static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value;\n            static constexpr bool Check_inf = decltype(check_inf_type)::value;\n            Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{}));\n            clear(tSrS);\n            sync();\n            auto load_V_next = [&] {\n                if (n_block - kStages + 1 >= n_block_min) {\n                    load_V(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant<Is_first_iter && kStages == 1>{} /*Seqlenk_mask*/);\n                }\n                cute::cp_async_fence();\n            };\n            Tensor tSrQ_cur = cute::conditional_return<Q_in_regs>(tSrQ, thr_mma.partition_fragment_A(sQ));\n            Tensor tSrK = thr_mma.partition_fragment_B(sK(_, _, _0{}));\n            flash::gemm_sm80<Q_in_regs>(\n                tSrS, tSrQ_cur, tSrK, tSsQ, tSsK(_, _, _, kStages > 1 ? smem_pipe_read : 0),\n                tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K, load_V_next\n            );\n            smem_pipe_write = smem_pipe_write < kStages - 1 ? 
smem_pipe_write + 1 : 0;\n            scoremod_premask_fn(tSrS);\n            // Faster to load_K before gemm if we only have 1 stage\n            if constexpr (kStages == 1) { sync(); load_K_next(); }\n            mask_fn(tSrS, n_block);\n            Tensor scores_scale = softmax.template max_get_scale</*Is_first=*/Is_first_iter, Check_inf>(tSrS);\n            softmax.template online_softmax</*Is_first=*/Is_first_iter, Check_inf>(tSrS);\n            if constexpr (Is_FP8) { flash::permute_Cregs_fp8(tSrS); }\n            Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<TiledMma>(tSrS.layout()));\n            Tensor tOrP = make_tensor_like<Element>(tOrP_acc);\n            convert_type_out(tOrP_acc, tOrP);\n            if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); }\n            if constexpr (kStages > 1) { sync(); }\n            Tensor tOrV = thr_mma.partition_fragment_B(sVt(_, _, _0{}));\n            flash::gemm_rs_sm80(tOrO, tOrP, tOrV, tOsVt(_, _, _, kStages > 1 ? smem_pipe_read : 0), tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);\n            if constexpr (kStages > 1) { load_K_next(); }\n            smem_pipe_read = smem_pipe_read < kStages - 1 ? 
smem_pipe_read + 1 : 0;\n        };\n\n        auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };\n        fwd_step(n_block, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/);\n        --n_block;\n        if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking\n            auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };\n            int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask(\n                seqlen_info, m_block, n_block_min, params.window_size_right,\n                params.attention_chunk_divmod, params.qhead_per_khead_divmod);\n            #pragma unroll 1\n            for (; n_block >= n_block_min_causal_local_mask; --n_block) {\n                fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/);\n            }\n        }\n        int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask(\n            seqlen_info, m_block, n_block_min, params.window_size_left,\n            params.attention_chunk_divmod, params.qhead_per_khead_divmod);\n        auto no_mask_fn = [](auto& tSrS, int n_block) { };\n        #pragma unroll 1\n        for (; n_block >= n_block_min_before_local_mask; --n_block) {\n            fwd_step(n_block, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/);\n        }\n        // Separate masking iterations on the left for local attention\n        if constexpr (Is_local) {\n            auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, false /*Causal_mask*/, Is_local>(tSrS, m_block, n_block); };\n            #pragma unroll 1\n            for (; n_block >= n_block_min; --n_block) {\n             
   fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant<Is_local>{} /*check_inf*/);\n            }\n        }\n        float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)];\n        Tensor scores_scale = softmax.finalize(v_descale);\n        softmax.rescale_o(tOrO, scores_scale);\n        if constexpr (Is_FP8) { flash::permute_output_fp8(tOrO); }\n        return true;\n    }\n\n    template <typename SharedStorage>\n    CUTLASS_DEVICE bool\n    store_kv_new(Params const& params,\n                 int const thread_idx,\n                 SharedStorage &shared_storage,\n                 SeqlenInfo_t const& seqlen_info,\n                 cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord\n    ) {\n        auto [m_block, bidh, bidb, split_idx] = block_coord;\n        auto n_block_new_min_max = BlockMN_t::get_n_block_k_new_min_max(\n            seqlen_info, m_block, bidb, split_idx, params.num_splits,\n            params.window_size_left, params.window_size_right, params.attention_chunk_divmod,\n            params.qhead_per_khead_divmod);\n        int const n_block_new_min = get<0>(n_block_new_min_max);\n        int const n_block_new_max = get<1>(n_block_new_min_max);\n        if (n_block_new_max <= n_block_new_min) { return false; }\n\n        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});\n        Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});\n\n        int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;\n        int const bidb_kv = params.kv_batch_idx == nullptr ? 
bidb : params.kv_batch_idx[bidb];\n\n        bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new;\n        Tensor mKnew = make_tensor(make_gmem_ptr(params.ptr_K_new), params.shape_K_new, params.stride_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0);\n        Tensor mVnew = make_tensor(make_gmem_ptr(params.ptr_V_new), params.shape_K_new, params.stride_V_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0);\n\n        bool const is_varlen_k = Varlen && params.cu_seqlens_k;\n        Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);\n        Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);\n\n        Tensor gKnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)\n        Tensor gVnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mVnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)\n        int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og;\n        Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)\n        Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)\n\n        static constexpr int kBlockN = get<1>(TileShape_MNK{});\n        static constexpr int kHeadDim = get<2>(TileShape_MNK{});\n        int const seqlen_k_new = seqlen_info.seqlen_k_new;\n        using Rotary_t = Rotary<kBlockN, kHeadDim, NumMmaThreads, Element>;\n        Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos,\n                        params.ptr_rotary_sin, params.stride_rotary_sin,\n                        params.is_rotary_interleaved, thread_idx, seqlen_k_new,\n       
                 seqlen_info.seqlen_rotary);\n\n        using PagedKVManager_t = PagedKVManager<get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>;\n        PagedKVManager_t paged_kv_manager(\n            params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable,\n            params.ptr_K, params.shape_K, params.stride_K,\n            params.ptr_V, params.headdim_v, params.stride_V,\n            params.page_size_divmod,\n            params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/,\n            bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k,\n            // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position\n            0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/\n        );\n\n        static_assert(std::is_same_v<GmemLayoutAtomAppend, typename Rotary_t::LayoutAtom>);\n        static_assert(!PagedKV || std::is_same_v<GmemLayoutAtomAppend, typename PagedKVManager_t::GmemLayoutAtomKVCpAsync>);\n        GmemTiledCopyQKV gmem_tiled_copy_kv_g2s;\n        auto gmem_thr_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(thread_idx);\n        auto gmem_thr0_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(_0{});  // Only for index calculation\n        GmemTiledCopyAppendKV gmem_tiled_copy_kv_s2g;\n        auto gmem_thr_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(thread_idx);\n        auto gmem_thr0_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(_0{});  // Only for index calculation\n        Tensor tKgKnew = gmem_thr_copy_kv_g2s.partition_S(gKnew);\n        Tensor tKsKg2s = gmem_thr_copy_kv_g2s.partition_S(sK);\n        Tensor tKsKs2g = gmem_thr_copy_kv_s2g.partition_S(sK);        // ((Atom,AtomNum),ATOM_M,ATOM_N)\n        Tensor tKgK = gmem_thr_copy_kv_s2g.partition_D(gK);\n        Tensor tVgVnew = gmem_thr_copy_kv_g2s.partition_S(gVnew);    
    // ((Atom,AtomNum),ATOM_M,ATOM_N)\n        Tensor tVsVg2s = gmem_thr_copy_kv_g2s.partition_S(sV);        // ((Atom,AtomNum),ATOM_M,ATOM_N)\n        Tensor tVsVs2g = gmem_thr_copy_kv_s2g.partition_S(sV);        // ((Atom,AtomNum),ATOM_M,ATOM_N)\n        Tensor tVgV = gmem_thr_copy_kv_s2g.partition_D(gV);\n        Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));  // (BLK_N,BLK_K) -> (blk_n,blk_k)\n        Tensor tKcKg2s = gmem_thr_copy_kv_g2s.partition_D(cK);\n        Tensor t0KcKg2s = gmem_thr0_copy_kv_g2s.partition_D(cK);\n        Tensor tKpKg2s = make_tensor<bool>(make_shape(size<2>(tKsKg2s)));\n        Tensor tKcKs2g = gmem_thr_copy_kv_s2g.partition_D(cK);\n        Tensor t0KcKs2g = gmem_thr0_copy_kv_s2g.partition_D(cK);\n        Tensor tKpKs2g = make_tensor<bool>(make_shape(size<2>(tKsKs2g)));\n        #pragma unroll\n        for (int k = 0; k < size(tKpKg2s); ++k) { tKpKg2s(k) = get<1>(tKcKg2s(_0{}, _0{}, k)) < get<1>(params.shape_K); }\n        #pragma unroll\n        for (int k = 0; k < size(tKpKs2g); ++k) { tKpKs2g(k) = get<1>(tKcKs2g(_0{}, _0{}, k)) < get<1>(params.shape_K); }\n\n        auto load_K_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) {\n            static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;\n            static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;\n            Tensor tKsK_cur = tKsKg2s(_, _, _, smem_pipe_write);\n            int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN\n                ? seqlen_k_new - n_block * kBlockN\n                : (!Seqlenk_mask ? 
kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN)));\n            // We don't need to clear the sK smem tiles since we won't write them out\n            flash::copy</*Is_even_MN=*/!Seqlenk_mask && EvenN, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/true>(\n                gmem_tiled_copy_kv_g2s, tKgKnew(_, _, _, n_block), tKsK_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit);\n        };\n\n        auto load_V_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) {\n            static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;\n            static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;\n            Tensor tVsV_cur = tVsVg2s(_, _, _, smem_pipe_write);\n            int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN\n                ? seqlen_k_new - n_block * kBlockN\n                : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN)));\n            // We don't need to clear the sV smem tiles since we won't write them out\n            flash::copy</*Is_even_MN=*/!Seqlenk_mask && EvenN, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/true>(\n                gmem_tiled_copy_kv_g2s, tVgVnew(_, _, _, n_block), tVsV_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit);\n        };\n\n        auto store_K = [&] (int const n_block, int const smem_pipe_read) {\n            int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN);\n            if (get<1>(params.shape_rotary) <= 0) {\n                Tensor tKsK_cur = tKsKs2g(_, _, _, smem_pipe_read);\n                if constexpr (!PagedKV) {\n                    Tensor tKgK_cur = tKgK(_, _, _, n_block);\n                    // Clear_OOB_K must be false since we don't want to write zeros to gmem\n                    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n         
               gmem_tiled_copy_kv_s2g, tKsK_cur, tKgK_cur, tKcKs2g, tKpKs2g, std::min(seqlen_k_new - n_block * kBlockN, kBlockN)\n                    );\n                } else {\n                    paged_kv_manager.store_K(n_block, tKsK_cur);\n                }\n            } else {\n                Tensor gK_cur = gK(_, _, n_block);\n                auto tPrKPtr = cute::conditional_return<PagedKV>(paged_kv_manager.compute_K_ptr(), nullptr);\n                if (params.is_rotary_interleaved) {\n                    auto [tRrCos, tRrSin] = rotary.template load_cos_sin<true /*kInterleaved*/>(n_block);\n                    rotary.template apply_K_interleaved<PagedKV>(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCos, tRrSin, tPrKPtr, n_block);\n                } else {\n                    auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin<false /*kInterleaved*/>(n_block);\n                    rotary.template apply_K_contiguous<PagedKV>(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K));\n                }\n            }\n        };\n\n        auto store_V = [&] (int const n_block, int const smem_pipe_read) {\n            int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN);\n            Tensor tVsV_cur = tVsVs2g(_, _, _, smem_pipe_read);\n            if constexpr (!PagedKV) {\n                Tensor tVgV_cur = tVgV(_, _, _, n_block);\n                // Clear_OOB_K must be false since we don't want to write zeros to gmem\n                flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n                    gmem_tiled_copy_kv_s2g, tVsV_cur, tVgV_cur, tKcKs2g, tKpKs2g, n_limit);\n            } else {\n                paged_kv_manager.store_V(n_block, tVsV_cur);\n            }\n        };\n\n        int n_block = n_block_new_max - 1;\n        // Note, using the for_each() function here to ensure `stage` is of type Int<x>.\n        
for_each(make_int_sequence<kStages>{}, [&] (auto stage) {\n            static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0;\n            static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1;\n            if (Is_first_stage || n_block - stage >= n_block_new_min) {\n                load_K_new(n_block - stage, stage, cute::bool_constant<Is_first_stage>{} /*Seqlenk_mask*/);\n            }\n            cute::cp_async_fence();\n            // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v\n            if constexpr (Is_first_stage) { __syncthreads(); }\n            if constexpr (!Is_last_stage) {\n                if (Is_first_stage || n_block - stage >= n_block_new_min) {\n                    load_V_new(n_block - stage, stage, cute::bool_constant<Is_first_stage>{} /*Seqlenk_mask*/);\n                }\n                cute::cp_async_fence();\n            }\n        });\n\n        int smem_pipe_read = 0, smem_pipe_write = kStages - 1;\n        #pragma unroll 1\n        for (; n_block >= n_block_new_min; --n_block) {\n            if constexpr (PagedKV) { paged_kv_manager.template load_page_table<true /*Seqlenk_mask*/>(n_block); }\n            flash::cp_async_wait<kStages * 2 - 2>();\n            __syncthreads();\n            store_K(n_block, kStages > 1 ? smem_pipe_read : 0);\n            if (n_block - kStages + 1 >= n_block_new_min) {\n                load_V_new(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant<kStages == 1>{} /*Seqlenk_mask*/);\n            }\n            cute::cp_async_fence();\n            smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0;\n            flash::cp_async_wait<kStages * 2 - 2>();\n            __syncthreads();\n            store_V(n_block, kStages > 1 ? smem_pipe_read : 0);\n            smem_pipe_read = smem_pipe_read < kStages - 1 ? 
smem_pipe_read + 1 : 0;\n            if (n_block - kStages >= n_block_new_min) {\n                load_K_new(n_block - kStages, kStages > 1 ? smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/);\n            }\n            cute::cp_async_fence();\n        }\n\n        return true;\n\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_types.h>\n#include <cutlass/numeric_conversion.h>\n#include \"cutlass/pipeline/pipeline.hpp\"\n\n#include \"cute/tensor.hpp\"\n\n#include \"cutlass/gemm/collective/builders/sm90_common.inl\"\n\n#include \"named_barrier.hpp\"\n#include \"seqlen.h\"\n#include \"block.h\"\n#include \"mask.h\"\n#include \"pack_gqa.h\"\n#include \"paged_kv.h\"\n#include \"rotary.h\"\n#include \"utils.h\"\n#include \"sm90_pipeline_no_cluster.hpp\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <int Stages, class ClusterShape_, class TileShape_MNK_, int kHeadDimV, class Element_, class ElementAccum_, class ArchTag_,\n        bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool PagedKVNonTMA_, bool AppendKV_, bool HasQv_,\n        bool MmaPV_is_RS, bool IntraWGOverlap, bool PackGQA_, bool Split_, bool V_colmajor_>\nstruct CollectiveMainloopFwdSm90 {\n\n    static constexpr int kStages = Stages;\n    using ClusterShape = ClusterShape_;\n    using TileShape_MNK = TileShape_MNK_;\n    using TileShape_MNK_PV = Shape<decltype(get<0>(TileShape_MNK{})), Int<kHeadDimV>, decltype(get<1>(TileShape_MNK{}))>;\n    using TileShape_MNK_QV = Shape<decltype(get<0>(TileShape_MNK{})), decltype(get<1>(TileShape_MNK{})), Int<kHeadDimV>>;\n    using Element = Element_;\n    using ElementAccum = ElementAccum_;\n    using ArchTag = ArchTag_;\n    static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;;\n    static constexpr bool Is_causal = Is_causal_;\n    static constexpr bool Is_local = Is_local_;\n    static constexpr bool 
Has_softcap = Has_softcap_;\n    static constexpr bool Varlen = Varlen_;\n    static constexpr bool PagedKVNonTMA = PagedKVNonTMA_;\n    static constexpr bool AppendKV = AppendKV_;\n    static constexpr bool HasQv = HasQv_;\n    static constexpr bool PackGQA = PackGQA_;\n    static constexpr bool Split = Split_;\n    static constexpr bool V_colmajor = V_colmajor_;\n    static constexpr bool Transpose_V = Is_FP8 && !V_colmajor;\n    static constexpr bool Use_TMA_Q = !PackGQA;\n    static constexpr bool Use_TMA_KV = !PagedKVNonTMA;\n    static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, \"If not using TMA for KV, ClusterShape must be 1\");\n    static_assert(Use_TMA_KV || !V_colmajor, \"If not using TMA for KV, V_colmajor is not supported\");\n    static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV;\n    static constexpr bool LargeHeadDimV = kHeadDimV > 256;\n\n    static_assert(ArchTag::kMinComputeCapability >= 90);\n\n    static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 && !V_colmajor ? GMMA::Major::MN : GMMA::Major::K;\n    static constexpr cute::GMMA::Major TmaMajorV = !V_colmajor ? 
GMMA::Major::MN : GMMA::Major::K;\n\n    static constexpr int kBlockM = get<0>(TileShape_MNK{});\n    static constexpr int kBlockN = get<1>(TileShape_MNK{});\n    static constexpr int kHeadDim = get<2>(TileShape_MNK{});\n\n    using SeqlenInfo_t = flash::SeqlenInfoQKNewK<Varlen, AppendKV>;\n    using BlockMN_t = flash::BlockMN<SeqlenInfo_t, kBlockM, kBlockN, Is_causal, Is_local, PackGQA, Split>;\n\n    static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0);\n    static_assert(!LargeHeadDimV || kBlockM <= 64, \"kBlockM must be 64 or less for large Headdim_V\");\n    static_assert(!LargeHeadDimV || !MmaPV_is_RS, \"MmaPV must be SS for large Headdim_V\");\n\n    // Register bandwidth is actually a bottleneck so we don't want Q to be in registers.\n    // Leaving this option here for reference.\n    static constexpr bool MmaQK_is_RS = false;\n    // We can have MmaPV with P in smem instead of rmem to reduce register pressure at the cost of more smem.\n    static_assert(!(!MmaPV_is_RS && Is_FP8), \"MmaPV must be RS if FP8\");\n    static_assert(!(!MmaPV_is_RS && Transpose_V), \"MmaPV must be RS if Transpose_V\");\n\n    // Slightly faster in this case to have WG1 use RS instead of SS to avoid waiting for the P smem write\n    static constexpr bool MmaPV_use_RS_WG1 = !MmaPV_is_RS && kHeadDim == 64 && kHeadDimV == 512;\n\n    using AtomLayoutQK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;\n    using TiledMmaQK = decltype(cute::make_tiled_mma(\n        std::conditional_t<\n            !MmaQK_is_RS,\n            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>()),\n            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>())\n        >{},\n        AtomLayoutQK{}));\n    using AtomLayoutPV = std::conditional_t<\n        !LargeHeadDimV,\n        AtomLayoutQK,\n        Layout<Shape<_1, Int<kHeadDimV / 256>, _1>>\n    >;\n    using TiledMmaPV = decltype(cute::make_tiled_mma(\n        std::conditional_t<\n            
!MmaPV_is_RS,\n            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum,\n                     TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()),\n            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum,\n                     TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>())\n        >{},\n        AtomLayoutPV{}));\n    using TiledMmaQV = decltype(cute::make_tiled_mma(\n        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK_QV>(),\n        AtomLayoutQK{}));\n    // For hdim64,512, WG1 can use RS but WG2 must use SS\n    using TiledMmaPV_RS = decltype(cute::make_tiled_mma(\n        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>(),\n        AtomLayoutPV{}));\n\n    static constexpr int NumMmaThreadsQK = size(TiledMmaQK{});\n    static constexpr int NumMmaThreads = size(TiledMmaPV{});\n    static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? 
cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup;\n    static_assert(NumMmaThreadsQK % cutlass::NumThreadsPerWarpGroup == 0);\n    static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0);\n    static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup;\n    static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);\n\n    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,\n        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());\n    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));\n\n    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,\n        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());\n    using SmemLayoutK = decltype(tile_to_shape(\n        SmemLayoutAtomK{},\n        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));\n\n    using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector<TmaMajorV, Element,\n                                      Int<kHeadDimV>, decltype(cute::get<2>(TileShape_MNK_PV{}))>());\n    using SmemLayoutVt = decltype(tile_to_shape(\n        SmemLayoutAtomVt{},\n        make_shape(Int<kHeadDimV>{}, shape<2>(TileShape_MNK_PV{}), Int<kStages>{}),\n        std::conditional_t<TmaMajorV == GMMA::Major::K, cute::Step<_1, _2, _3>, cute::Step<_2, _1, _3>>{}));\n\n    using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector<MmaMajorV, Element,\n                                         Int<kHeadDimV>, decltype(cute::get<2>(TileShape_MNK_PV{}))>());\n    using SmemLayoutVtMma = decltype(tile_to_shape(\n        SmemLayoutAtomVtMma{},\n        make_shape(Int<kHeadDimV>{}, shape<2>(TileShape_MNK_PV{}), Int<kStages>{}),\n        
std::conditional_t<MmaMajorV == GMMA::Major::K, cute::Step<_1, _2, _3>, cute::Step<_2, _1, _3>>{}));\n\n    using SmemLayoutAtomQv = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,\n        decltype(cute::get<0>(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>());\n    using SmemLayoutQv = decltype(tile_to_shape(SmemLayoutAtomQv{}, select<0, 2>(TileShape_MNK_QV{})));\n    using SmemLayoutAtomVMmaQV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,\n        decltype(cute::get<1>(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>());\n    using SmemLayoutVMmaQV = decltype(tile_to_shape(\n        SmemLayoutAtomVMmaQV{},\n        make_shape(shape<1>(TileShape_MNK_QV{}), Int<kHeadDimV>{}, Int<kStages>{})));\n    static_assert(CUTE_STATIC_V(size(SmemLayoutVMmaQV{})) == size(SmemLayoutVtMma{}));\n\n    // Only used if we're using cp.async to load V\n    using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,\n        decltype(cute::get<1>(TileShape_MNK{})), Int<kHeadDimV>>());\n    using SmemLayoutVCpAsync = decltype(tile_to_shape(\n        SmemLayoutAtomVCpAsync{},\n        make_shape(shape<1>(TileShape_MNK{}), Int<kHeadDimV>{}, Int<kStages>{})));\n\n    using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,\n        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());\n    using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));\n\n    // Only for LargeHeadDimV where WG0 sends WG1 the scales\n    using SmemLayoutScale = cute::Layout<cute::Shape<Int<kBlockM>, Int<kStages>>>;\n\n    using SmemCopyAtomP = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;\n\n    // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major.\n    // For FP16/BF16 we don't do any 
transposing.\n    static_assert(!Transpose_V || (kHeadDimV % 32 == 0 && kBlockN % 32 == 0));\n    static constexpr bool kHeadDimV_multiple_64 = kHeadDimV % 64 == 0;\n    // Either kHeadDimV is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose),\n    // or we need kBlockN to be a multiple of 64 (in which case we use a block size of 32 x 64 for the transpose).\n    static_assert(!Transpose_V || (kHeadDimV_multiple_64 || kBlockN % 64 == 0));\n    using LDSM_thread_shape  = std::conditional_t<kHeadDimV_multiple_64, Shape<_32, _4, _1, _1>, Shape<_16, _4, _1, _2>>;\n    using LDSM_thread_stride = std::conditional_t<kHeadDimV_multiple_64, Stride<_4, _1, _0, _0>, Stride<_4, _1, _0, _64>>;\n    using LDSM_value_shape = Shape<_2, _2, _1, _4>;\n    using LDSM_value_stride = Stride<_1, _2, _16, _4>;\n    using LDSM_divide_shape = std::conditional_t<kHeadDimV_multiple_64, Shape<_64, _8>, Shape<_32, _8>>;\n    using S2RTiledCopyVt = decltype(make_tiled_copy(\n        Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<LDSM_thread_shape, LDSM_thread_stride>{},\n        Layout<LDSM_value_shape, LDSM_value_stride>{}));\n\n    using STSM_thread_shape  = std::conditional_t<kHeadDimV_multiple_64, Shape<_8, _4, _4, _1>, Shape<_8, _4, _2, _2>>;\n    using STSM_thread_stride = std::conditional_t<kHeadDimV_multiple_64, Stride<_4, _1, _32, _0>, Stride<_4, _1, _32, _64>>;\n    using STSM_value_shape = Shape<_1, _4, _2, _2>;\n    using STSM_value_stride = Stride<_0, _1, _4, _8>;\n    using STSM_divide_shape = Shape<_8, _16>;\n    // These will not permute the columns of V (the kHeadDimV dimension) but incur bank conflicts\n    // so a little slower (e.g. 
1150 TFLOPS for hdim 256 instead of 1200 TFLOPS).\n    // Instead we will permute the cols of V, and un-permute the cols of O in the epilogue.\n    // using STSM_value_shape = Shape<_2, _4, _1, _2>;\n    // using STSM_value_stride = Stride<_4, _1, _0, _8>;\n    // using STSM_divide_shape = Shape<_16, _16>;\n    using R2STiledCopyV = decltype(make_tiled_copy(\n        Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<STSM_thread_shape, STSM_thread_stride>{},\n        Layout<STSM_value_shape, STSM_value_stride>{}));\n\n    using GmemTiledCopyQ = cute::SM90_TMA_LOAD;\n    using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));\n\n    // We use CpAsync for K and V if PagedKVNonTMA and AppendKV, since TMA doesn't work there\n    static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV);\n    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);\n    static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, \"Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad\");\n    // We want each \"row\" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each\n    // thread to have 4 loads in the M direction and 2 vectorized loads in the K direction.\n    // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved\n    // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will\n    // load twice from the same row.\n    static constexpr int kBytePerHalfRow = kHeadDimGCD / 2 * sizeof(Element);\n    static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 
64 : 32)) / sizeof(Element);\n    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;\n    static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, \"NumMmaThreads must be a multiple of kGmemThreadsPerRow\");\n    // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKVNonTMA where\n    // these threads share the same page table entry and share the work of computing pointers to paged K and paged V.\n    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, \"kGmemThreadsPerRow must divide NumThreadsPerWarp\");\n    using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                                  Stride<Int<kGmemThreadsPerRow>, _1>>;\n    // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication\n    static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, \"kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRow\");\n    using GmemTiledCopyAppendKV = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per store\n\n    using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen, d, head, batch)\n    using StrideQK = cute::Stride<int64_t, _1, int64_t, int64_t>;\n    using StrideV = std::conditional_t<!V_colmajor, StrideQK, cute::Stride<_1, int64_t, int64_t, int64_t>>;\n    // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)\n    using ShapeQPacked = std::conditional_t<!PackGQA, ShapeQKV, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;\n    using StrideQPacked = std::conditional_t<!PackGQA, StrideQK, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t>>;\n    using ShapePageTable = 
cute::Shape<int32_t, int32_t>;  // (batch, max_num_pages_per_seq)\n    using StridePageTable = cute::Stride<int64_t, _1>;\n    using ShapeRotary = cute::Shape<int32_t, int32_t>;  // (seqlen_ro, rotary_dim // 2)\n    using StrideRotary = cute::Stride<int64_t, _1>;\n    using StrideDescale = cute::Stride<int64_t, int64_t>;\n\n    using TMA_Q = decltype(make_tma_copy_A_sm90(\n        GmemTiledCopyQ{},\n        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQK{}),\n        SmemLayoutQ{},\n        TileShape_MNK{},\n        ClusterShape{}));\n\n    using TMA_K = decltype(make_tma_copy_B_sm90(\n        GmemTiledCopyKV{},\n        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQK{}),\n        take<0, 2>(SmemLayoutK{}),\n        TileShape_MNK{},\n        ClusterShape{})); // mcast along M mode for this N load, if any\n\n    using TMA_V = decltype(make_tma_copy(\n        GmemTiledCopyKV{},\n        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})),\n        take<0, 2>(SmemLayoutVt{}),\n        select<1, 2>(TileShape_MNK_PV{}),\n        size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any\n\n    using TMA_Qv_ = decltype(make_tma_copy_A_sm90(\n        GmemTiledCopyQ{},\n        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQK{}),\n        SmemLayoutQv{},\n        TileShape_MNK_QV{},\n        ClusterShape{}));\n    using TMA_Qv = std::conditional_t<HasQv, TMA_Qv_, std::nullptr_t>;\n\n    // Set the bytes transferred in this TMA transaction (may involve multiple issues)\n    static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);\n    static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);\n    static constexpr uint32_t 
TmaTransactionBytesV = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v<Element> / 8);\n    static constexpr uint32_t TmaTransactionBytesQv = static_cast<uint32_t>(size(SmemLayoutQv{}) * cutlass::sizeof_bits_v<Element> / 8);\n\n    using PipelineTmaAsync = std::conditional_t<CUTE_STATIC_V(size(ClusterShape{})) == 1, typename cutlass::PipelineTmaAsyncNoCluster<kStages>, typename cutlass::PipelineTmaAsync<kStages>>;\n    using MainloopPipelineK = std::conditional_t<Use_TMA_KV, PipelineTmaAsync, typename cutlass::PipelineAsync<kStages>>;\n    using MainloopPipelineV = std::conditional_t<!Transpose_V && Use_TMA_KV, PipelineTmaAsync, typename cutlass::PipelineAsync<kStages>>;\n    using MainloopPipelineVt = std::conditional_t<Use_TMA_KV, PipelineTmaAsync, typename cutlass::PipelineAsync<kStages>>;\n    // We always use TMA for K_new and V_new\n    using MainloopPipelineKVNew = PipelineTmaAsync;\n    using PipelineState = cutlass::PipelineState<kStages>;\n\n    // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned\n    // and have sQ being position_independent_swizzle_tensor.\n    // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned.\n    static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{});\n    static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{});\n    static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{});\n    static constexpr size_t SmemAlignmentQv = Use_TMA_Q ? 
128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQv{});\n    static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, \"Require at least 128B alignment\");\n    static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{});\n    static_assert(SmemAlignmentP >= 128, \"Require at least 128B alignment\");\n\n    using SmemP_t = std::conditional_t<MmaPV_is_RS, cute::array<Element, 0>, cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>, SmemAlignmentP>>;\n    using SmemScale_t = std::conditional_t<!LargeHeadDimV, cute::array<float, 0>, cute::array_aligned<float, cute::cosize_v<SmemLayoutScale>, 128>>;\n    using SmemQv_t = std::conditional_t<!HasQv, cute::array<Element, 0>, cute::array_aligned<Element, cute::cosize_v<SmemLayoutQv>, SmemAlignmentQv>>;\n    // Sometimes even with SmemP_t = cute::array<Element, 0>, putting it in the TensorStorage struct causes\n    // smem size to go from 227KB to 228KB and we get \"invalid argument\".\n\n    struct TensorStorageWithoutPNoTranspose : cute::aligned_struct<cute::max(SmemAlignmentQ, SmemAlignmentK, SmemAlignmentVtNoTranspose), _0> {\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutVt>, SmemAlignmentVtNoTranspose> smem_v;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>, SmemAlignmentQ> smem_q;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>, SmemAlignmentK> smem_k;\n        SmemQv_t smem_qv;\n    };\n\n    struct TensorStorageWithPNoTranspose : cute::aligned_struct<cute::max(SmemAlignmentQ, SmemAlignmentK, SmemAlignmentVtNoTranspose, SmemAlignmentP), _0> {\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutVt>, SmemAlignmentVtNoTranspose> smem_v;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>, SmemAlignmentQ> smem_q;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>, SmemAlignmentK> smem_k;\n        SmemQv_t smem_qv;\n        SmemP_t smem_p;\n   
 };\n    struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct<cute::max(SmemAlignmentQ, SmemAlignmentK, SmemAlignmentVtNoTranspose, SmemAlignmentP), _0> {\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutVt>, SmemAlignmentVtNoTranspose> smem_v;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>, SmemAlignmentQ> smem_q;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>, SmemAlignmentK> smem_k;\n        SmemQv_t smem_qv;\n        SmemP_t smem_p;\n        SmemScale_t smem_scale;\n    };\n\n    using TensorStorageNoTranspose = std::conditional_t<\n        MmaPV_is_RS,\n        TensorStorageWithoutPNoTranspose,\n        std::conditional_t<!LargeHeadDimV, TensorStorageWithPNoTranspose, TensorStorageWithPScaleNoTranspose>\n    >;\n\n    static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{});\n    static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{});\n    static_assert(SmemAlignmentVt >= 128 and SmemAlignmentV >= 128, \"Require at least 128B alignment\");\n    struct TensorStorageTransposeV : cute::aligned_struct<cute::max(SmemAlignmentQ, SmemAlignmentK, SmemAlignmentV), _0> {\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutVtMma>, SmemAlignmentV> smem_v;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutVt>, SmemAlignmentVt> smem_vt;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>, SmemAlignmentQ> smem_q;\n        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>, SmemAlignmentK> smem_k;\n        SmemQv_t smem_qv;\n        SmemScale_t smem_scale;\n    };\n\n    using TensorStorage = std::conditional_t<!Transpose_V, TensorStorageNoTranspose, TensorStorageTransposeV>;\n\n    // These are tuned for speed. They don't affect correctness.\n    static constexpr bool UseSchedulerBarrier = (IntraWGOverlap\n        ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? 
kHeadDim <= 128 : kHeadDim >= 128)\n        : NumMmaWarpGroups == 2)\n        && !LargeHeadDimV;\n    static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor) && IntraWGOverlap;\n\n    // Host side kernel arguments\n    struct Arguments {\n        Element const* const ptr_Q;\n        ShapeQKV const shape_Q;\n        StrideQK const stride_Q;\n        Element* const ptr_K;  // not Element const* since we might append to KV cache in-place\n        ShapeQKV const shape_K;\n        StrideQK const stride_K;\n        Element* const ptr_V;\n        int32_t const headdim_v;\n        StrideV const stride_V;\n        Element const* const ptr_K_new;\n        ShapeQKV const shape_K_new;\n        StrideQK const stride_K_new;\n        Element const* const ptr_V_new;\n        StrideV const stride_V_new;\n        Element const* const ptr_Qv;\n        StrideQK const stride_Qv;\n        Element const* const ptr_rotary_cos;\n        ShapeRotary const shape_rotary;\n        StrideRotary const stride_rotary_cos;\n        Element const* const ptr_rotary_sin;\n        StrideRotary const stride_rotary_sin;\n        bool const is_rotary_interleaved;\n        int const* const ptr_pagetable;\n        ShapePageTable const shape_pagetable;\n        StridePageTable const stride_pagetable;\n        float const softmax_scale;\n        float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale;\n        StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale;\n        int const window_size_left = -1, window_size_right = -1, attention_chunk = 0;\n        float const softcap_val;\n        int const num_splits;\n        int const* const kv_batch_idx = nullptr;\n        int const* const cu_seqlens_q = nullptr;\n        int const* const cu_seqlens_k = nullptr;\n        int const* const cu_seqlens_k_new = nullptr;\n        int const* const seqused_q = nullptr;\n        int const* const seqused_k = nullptr;\n        int const* const leftpad_k = nullptr;\n       
 int const* const seqlens_rotary = nullptr;\n    };\n\n    // Device side kernel params\n    struct Params {\n        Element const* const ptr_Q;\n        ShapeQKV const shape_Q;\n        StrideQK const stride_Q;\n        ShapeQPacked const shape_Q_packed;\n        StrideQPacked const stride_Q_packed;\n        Element* const ptr_K;\n        ShapeQKV const shape_K;\n        StrideQK const stride_K;\n        Element* const ptr_V;\n        int32_t const headdim_v;\n        StrideV const stride_V;\n        Element const* const ptr_K_new;\n        ShapeQKV const shape_K_new;\n        StrideQK const stride_K_new;\n        Element const* const ptr_V_new;\n        StrideV const stride_V_new;\n        Element const* const ptr_Qv;\n        StrideV const stride_Qv;\n        ShapeQPacked const shape_Qv_packed;\n        StrideQPacked const stride_Qv_packed;\n        Element const* const ptr_rotary_cos;\n        ShapeRotary const shape_rotary;\n        StrideRotary const stride_rotary_cos;\n        Element const* const ptr_rotary_sin;\n        StrideRotary const stride_rotary_sin;\n        bool const is_rotary_interleaved;\n        int const* const ptr_pagetable;\n        ShapePageTable const shape_pagetable;\n        StridePageTable const stride_pagetable;\n        cutlass::FastDivmod page_size_divmod;\n        cutlass::FastDivmod blockN_per_page_size_divmod;\n        cutlass::FastDivmod qhead_per_khead_divmod;\n        TMA_Q tma_load_Q;\n        TMA_K tma_load_K;\n        TMA_V tma_load_V;\n        TMA_K tma_load_K_new;\n        TMA_V tma_load_V_new;\n        TMA_Qv tma_load_Qv;\n        float const softmax_scale_log2;\n        float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale;\n        StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale;\n        float const softcap_val;\n        int const window_size_left, window_size_right;\n        cutlass::FastDivmod attention_chunk_divmod;\n        int const num_splits;\n        int const* const kv_batch_idx 
= nullptr;\n        int const* const cu_seqlens_q = nullptr;\n        int const* const cu_seqlens_k = nullptr;\n        int const* const cu_seqlens_k_new = nullptr;\n        int const* const seqused_q = nullptr;\n        int const* const seqused_k = nullptr;\n        int const* const leftpad_k = nullptr;\n        int const *const seqlens_rotary = nullptr;\n    };\n\n    static Params\n    to_underlying_arguments(Arguments const& args) {\n        Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q);\n        TMA_Q tma_load_Q = make_tma_copy_A_sm90(\n            GmemTiledCopyQ{},\n            mQ,\n            SmemLayoutQ{},\n            TileShape_MNK{},\n            ClusterShape{}); // no mcast for Q\n        Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K);\n        TMA_K tma_load_K = make_tma_copy_B_sm90(\n            GmemTiledCopyKV{},\n            mK,\n            take<0, 2>(SmemLayoutK{}),\n            TileShape_MNK{},\n            ClusterShape{}); // mcast along M mode for this N load, if any\n        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V),\n                                make_shape(args.headdim_v, get<0>(args.shape_K), get<2>(args.shape_K), get<3>(args.shape_K)),\n                                select<1, 0, 2, 3>(args.stride_V));\n        TMA_V tma_load_V = make_tma_copy(\n            GmemTiledCopyKV{},\n            mV,\n            take<0, 2>(SmemLayoutVt{}),\n            select<1, 2>(TileShape_MNK_PV{}),\n            size<0>(ClusterShape{})); // mcast along M mode for this N load, if any\n        Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new);\n        TMA_K tma_load_K_new = make_tma_copy_B_sm90(\n            GmemTiledCopyKV{},\n            cute::conditional_return<AppendKV>(mKnew, mK),\n            take<0, 2>(SmemLayoutK{}),\n            TileShape_MNK{},\n            ClusterShape{}); // mcast along M mode for this N load, if any\n        
Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new),\n                                   make_shape(args.headdim_v, get<0>(args.shape_K_new), get<2>(args.shape_K_new), get<3>(args.shape_K_new)),\n                                   select<1, 0, 2, 3>(args.stride_V_new));\n        TMA_V tma_load_V_new = make_tma_copy(\n            GmemTiledCopyKV{},\n            cute::conditional_return<AppendKV>(mVnew, mV),\n            take<0, 2>(SmemLayoutVt{}),\n            select<1, 2>(TileShape_MNK_PV{}),\n            size<0>(ClusterShape{})); // mcast along M mode for this N load, if any\n        auto shape_Qv = make_shape(get<0>(args.shape_Q), args.headdim_v, get<2>(args.shape_Q), get<3>(args.shape_Q));\n        Tensor mQv = make_tensor(make_gmem_ptr(args.ptr_Qv), shape_Qv, args.stride_Qv);\n        TMA_Qv tma_load_Qv = [&] {\n            if constexpr (HasQv) {\n                return make_tma_copy_A_sm90(\n                    GmemTiledCopyQ{},\n                    mQv,\n                    SmemLayoutQv{},\n                    TileShape_MNK_QV{},\n                    ClusterShape{}); // no mcast for Qv\n            } else {\n                return nullptr;\n            }\n        }();\n        // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size)\n        int const qhead_per_khead = !PackGQA ? 
1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K));\n        auto const shape_Q_packed = cute::conditional_return<!PackGQA>(\n            args.shape_Q,\n            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q))\n        );\n        auto const stride_Q_packed = cute::conditional_return<!PackGQA>(\n            args.stride_Q,\n            make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q))\n        );\n        auto const shape_Qv_packed = cute::conditional_return<!PackGQA>(\n            shape_Qv,\n            make_shape(make_shape(qhead_per_khead, get<0>(shape_Qv)), get<1>(shape_Qv), get<2>(args.shape_K), get<3>(shape_Qv))\n        );\n        auto const stride_Qv_packed = cute::conditional_return<!PackGQA>(\n            args.stride_Qv,\n            make_stride(make_stride(get<2>(args.stride_Qv), get<0>(args.stride_Qv)), get<1>(args.stride_Qv), get<2>(args.stride_Qv) * qhead_per_khead, get<3>(args.stride_Qv))\n        );\n        if (get<1>(args.shape_rotary) > 0) {\n            assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr);\n        }\n        assert(args.num_splits >= 1);\n        int page_size = !args.ptr_pagetable ? 1 : get<0>(args.shape_K);\n        if (!PagedKVNonTMA && args.ptr_pagetable != nullptr) {\n            assert(page_size % kBlockN == 0);\n            assert(!args.leftpad_k);\n        }\n        // Avoid dividing by zero\n        cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? 
args.attention_chunk : 1);\n        attention_chunk_divmod.divisor = args.attention_chunk;\n        // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.\n        // Right after this, we multiply by log2(e) before applying exp2.\n        // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val\n        // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)\n        // (assigning it to params.softmax_scale_log2).\n        return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed,\n                args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V,\n                args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new,\n                args.ptr_Qv, args.stride_Qv, shape_Qv_packed, stride_Qv_packed,\n                args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos,\n                args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved,\n                args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable,\n                cutlass::FastDivmod(page_size),  // page_size_divmod\n                cutlass::FastDivmod(!args.ptr_pagetable ? 1 : cute::ceil_div(page_size, kBlockN)),  // blockN_per_page_size_divmod\n                cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),\n                tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, tma_load_Qv,\n                !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),\n                args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale,\n                args.stride_q_descale, args.stride_k_descale, args.stride_v_descale,\n                !Has_softcap ? 
0.f : args.softmax_scale / args.softcap_val,\n                args.window_size_left, args.window_size_right, attention_chunk_divmod,\n                !Split ? 1 : args.num_splits,\n                args.kv_batch_idx,\n                args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,\n                args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary};\n    }\n\n    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance\n    CUTLASS_DEVICE\n    static void prefetch_tma_descriptors(Params const& params) {\n        if constexpr (Use_TMA_Q) {\n            cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor());\n            if constexpr (HasQv) {\n                cute::prefetch_tma_descriptor(params.tma_load_Qv.get_tma_descriptor());\n            }\n        }\n        if constexpr (Use_TMA_KV) {\n            cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor());\n            cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor());\n        }\n        if constexpr (AppendKV) {\n            cute::prefetch_tma_descriptor(params.tma_load_K_new.get_tma_descriptor());\n            cute::prefetch_tma_descriptor(params.tma_load_V_new.get_tma_descriptor());\n        }\n    }\n\n    template <typename SchedulerPrefetch, typename SharedStorage>\n    CUTLASS_DEVICE void\n    load(Params const& params,\n         MainloopPipelineK pipeline_k,\n         MainloopPipelineV pipeline_v,\n         MainloopPipelineVt pipeline_vt,\n         PipelineState& smem_pipe_write,\n         SharedStorage &shared_storage,\n         SchedulerPrefetch const& scheduler_prefetch,\n         SeqlenInfo_t const& seqlen_info,\n         cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,\n         int &work_idx\n         ) {\n\n        // some of these are captured in lambda so can't use structured binding\n        int const m_block = get<0>(block_coord);\n        int const bidh = 
get<1>(block_coord);\n        int const bidb = get<2>(block_coord);\n        int const split_idx = get<3>(block_coord);\n        auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max(\n            seqlen_info, m_block, bidb, split_idx, params.num_splits,\n            params.window_size_left, params.window_size_right, params.attention_chunk_divmod,\n            params.qhead_per_khead_divmod);\n        // It's possible to have n_block_max <= n_block_min. Loading K in that case can cause an illegal memory access.\n        if constexpr (Is_causal || Is_local || Varlen || Split) {\n            if (n_block_max <= n_block_min) {\n                scheduler_prefetch();\n                return;\n            }\n        }\n\n        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});\n        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});\n        Tensor sK_pi = as_position_independent_swizzle_tensor(sK);\n        // as_position_independent_swizzle_tensor makes address calculation easier when we do LDSM & STSM to transpose.\n        // But it requires smem_vt and smem_v to be aligned to e.g. 512 bytes.\n        Tensor sVt = [&] {\n            if constexpr (!Transpose_V) {\n                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{});\n            } else {\n                return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{}));\n            }\n        }();\n        // Only used if Transpose_V\n        Tensor sV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}));\n        // Only used if we're using cp.async to load V\n        Tensor sVcpasync = [&] {\n            if constexpr (!Transpose_V) {\n                return 
cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVCpAsync{}));\n            } else {\n                return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{}));\n            }\n        }();\n        Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{});\n\n        int const thread_idx = threadIdx.x % NumProducerThreads;\n        int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;\n        int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb];\n\n        // Prepare the TMA loads\n        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();\n        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());\n        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};\n\n        bool const is_varlen_q = Varlen && params.cu_seqlens_q;\n        bool const is_varlen_k = Varlen && params.cu_seqlens_k;\n        Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? 
bidb : 0);\n        Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, _);\n        auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K));\n        Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, _);\n\n        Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)\n        // if (cute::thread0()) { printf(\"Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\\n\", Varlen, params.leftpad_k, leftpad_k); }\n        Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}, _));  // (N, K, _, _)\n        Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k, _0{}), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _, _));  // (K, N, _, _)\n\n        auto block_tma_Q = params.tma_load_Q.get_slice(_0{});\n        Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ));  // (TMA)\n        Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ));  // (TMA)\n        if (Use_TMA_Q && thread_idx == 0) { prefetch(params.tma_load_Q, tQgQ); }\n        // tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually\n        auto block_tma_K = params.tma_load_K.get_slice(cluster_local_block_id.x);\n        Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA));  // (TMA, k, batch)\n        Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K.partition_D(sK));  // (TMA, PIPE)\n        auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x);\n        Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA));  // (TMA, k, batch)\n        Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt));  // (TMA, PIPE)\n        auto [tQvgQv, tQvsQv] = [&] {\n  
          if constexpr (HasQv) {\n                auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q));\n                Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0);\n                Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{}));  // (M, Kv)\n                auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{});\n                Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv));  // (TMA)\n                Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv));  // (TMA)\n                return cute::make_tuple(tQvgQv, tQvsQv);\n            } else {\n                return cute::make_tuple(nullptr, nullptr);\n            }\n        }();\n\n        // This is used to index into the batch dimension of mK and mV\n        int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? 
bidb_kv : 0;\n\n        using PagedKVManager_t = PagedKVManager<get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>;\n        PagedKVManager_t paged_kv_manager(\n            params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable,\n            params.ptr_K, params.shape_K, params.stride_K,\n            params.ptr_V, params.headdim_v, params.stride_V,\n            params.page_size_divmod, params.blockN_per_page_size_divmod,\n            bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, bidb_kv_idx\n        );\n\n        // Set up for transposing V, only used if Transpose_V\n        S2RTiledCopyVt s2r_tiled_copy_vt;\n        R2STiledCopyV r2s_tiled_copy_v;\n        auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(thread_idx);\n        auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(thread_idx);\n        // flat_divide(sVt, LDSM_divide_shape{}):  (64, 8, kHeadDim / 64, kBlockN / 8, kStages)\n        Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{}));  // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32, kStages)\n        // flat_divide(sV, STSM_divide_shape{}):  (8, 16, kHeadDim / 8, (4, kBlockN / 64), kStages)\n        Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{}));  // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64), kStages)\n        CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_));\n        CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_));\n        CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_));\n        CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_));\n        CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_));\n        CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_));\n        // Faster to have 2 LDSM.T, byte permute, STSM for better ILP\n        static 
constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 2 : 1;\n        Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_) - 1>(tTranssVt_), Shape<Underscore, Int<Transpose_ILP>>{});  // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages)\n        Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_) - 1>(tTranssV_), Shape<Underscore, Int<Transpose_ILP>>{});  // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages)\n        auto transpose_V = [&](int stage) {\n            if constexpr (Transpose_V) {\n                #pragma unroll\n                for (int i = 0; i < size<1, 1>(tTranssVt); ++i) {\n                    Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{}), _0{}));\n                    static_assert(size<0>(tTransrV) == 16);\n                    Tensor tTransrV_64 = recast<uint2>(tTransrV);\n                    cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i), stage), tTransrV);\n                    #pragma unroll\n                    for (int j = 0; j < size(tTransrV_64); ++j) {\n                        uint32_t upper = tTransrV_64[j].x;\n                        uint32_t lower = tTransrV_64[j].y;\n                        tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420);\n                        tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531);\n                    }\n                    cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i), stage));\n                }\n            }\n        };\n\n        uint16_t mcast_mask_kv = 0;\n        if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {\n            auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id\n            for (int m = 0; m < size<0>(block_layout); ++m) {\n                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));\n            }\n        }\n\n        auto load_K = [&] (int const n_block, auto const& 
smem_pipe_write, auto need_seqlenk_masking_type) {\n            pipeline_k.producer_acquire(smem_pipe_write);\n            if constexpr (!PagedKVNonTMA) {\n                auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_K_TMA();\n                copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),\n                    tKgK_TMA(_, n_block_idx, bidb_kv_idx), tKsK_TMA(_, smem_pipe_write.index()));\n            } else {\n                constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;\n                paged_kv_manager.template load_K<Seqlenk_mask>(n_block, sK_pi(_, _, smem_pipe_write.index()));\n                pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);\n            }\n        };\n\n        auto load_V = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) {\n            auto pipeline_v_load = cute::conditional_return<!Transpose_V>(pipeline_v, pipeline_vt);\n            pipeline_v_load.producer_acquire(smem_pipe_write);\n            if constexpr (!PagedKVNonTMA) {\n                auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_V_TMA();\n                copy(params.tma_load_V.with(*pipeline_v_load.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),\n                    tVgVt_TMA(_, n_block_idx, bidb_kv_idx), tVsVt_TMA(_, smem_pipe_write.index()));\n            } else {\n                constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;\n                paged_kv_manager.template load_V<Seqlenk_mask>(n_block, sVcpasync(_, _, smem_pipe_write.index()));\n                pipeline_v_load.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);\n            }\n        };\n\n        auto copy_Vt_to_V = [&] (auto const& smem_pipe_write) {\n            // Instead of maintaining smem_pipe_read as a 
separate variable, we can just use smem_pipe_write,\n            // and exploit the invariant that smem_pipe_write.phase() == smem_pipe_read.phase() ^ 1.\n            // This saves 1 or 2 registers.\n            PipelineState smem_pipe_read{smem_pipe_write.index(), smem_pipe_write.phase() ^ 1, smem_pipe_write.count()};\n            pipeline_vt.consumer_wait(smem_pipe_read);\n            pipeline_v.producer_acquire(smem_pipe_write);\n            transpose_V(smem_pipe_write.index());\n            // SMEM fence to make sure V is transposed before math\n            cutlass::arch::fence_view_async_shared();\n            pipeline_v.producer_commit(smem_pipe_write);\n            // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized\n            // before calling. Without this we get race conditions.\n            cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, cutlass::arch::ReservedNamedBarriers::TransposeBarrier /*id*/);\n            pipeline_vt.consumer_release(smem_pipe_read);\n        };\n\n        int n_block = n_block_max - 1;\n\n        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);\n        // If this is true, we're guaranteed that only the first warp will execute this function\n        static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;\n        bool should_load_KV = !Use_TMA_KV || ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync());\n\n        if (should_load_KV) {\n            if constexpr (PagedKVNonTMA) {\n                paged_kv_manager.template load_page_table<true /*Seqlenk_mask*/, true /*First_iter*/>(n_block);\n            } else {\n                paged_kv_manager.template load_page_table_TMA<true /*First_iter*/>(n_block);\n            }\n            if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); }\n            // if (thread_idx == 0) { 
printf(\"Producer: main load, before load_K, index = %d\\n\", smem_pipe_write.index());}\n            load_K(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/);\n            // if (thread_idx == 0) { printf(\"Producer: main load, after load K, index = %d\\n\", smem_pipe_write.index());}\n        }\n\n        if constexpr (Use_TMA_Q) {\n            // Wait for the MMA warpgroups to signal that smem_q is ready\n            if (SingleProducerWarp || warp_idx_in_warpgroup == 0) {\n                cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);\n            }\n\n            if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) {\n                shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);\n                copy(params.tma_load_Q.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST),\n                    tQgQ, tQsQ);\n                if constexpr (HasQv) {\n                    shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv);\n                    copy(params.tma_load_Qv.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? 
TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST),\n                        tQvgQv, tQvsQv);\n                }\n            }\n        } else {  // Load Q with cp.async\n            cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);\n            Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0);\n            Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ);\n            using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>;\n            PackGQAt::load_Q(mQ, sQ_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block);\n            auto &barrier_Q = shared_storage.pipelines.barrier_Q;\n            cutlass::arch::cpasync_barrier_arrive(reinterpret_cast<uint64_t*>(&barrier_Q));\n            barrier_Q.arrive();\n            if constexpr (HasQv) {\n                Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? 
bidb : 0);\n                Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv);\n                using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>;\n                PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block);\n                auto &barrier_Qv = shared_storage.pipelines.barrier_Qv;\n                cutlass::arch::cpasync_barrier_arrive(reinterpret_cast<uint64_t*>(&barrier_Qv));\n                barrier_Qv.arrive();\n            }\n        }\n\n        // Wait for the MMA WGs to signal that smem_v is ready and V can be copied from gmem.\n        // Need ClusterBarrier, not just NamedBarrier. Otherwise CTA 0 might finish its TMA store\n        // on O first and issue a TMA multicast load on V before CTA 1 has finished its TMA store on O.\n        // if (thread_idx == 0) { printf(\"Producer: main load, before barrier_O, work_idx = %d\\n\", work_idx);}\n        shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2);\n        // if (thread_idx == 0) { printf(\"Producer: main load, after barrier_O\\n\");}\n\n        if constexpr (!Transpose_V && !IntraWGOverlap) {\n            if (should_load_KV) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); }\n        }\n        int n_block_prev = n_block;\n        --n_block;\n        #pragma unroll (!Transpose_V && Use_TMA_KV ? 
2 : 1)\n        for (; n_block >= n_block_min; --n_block) {\n            PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind\n            ++smem_pipe_write;\n            if (should_load_KV) {\n                if constexpr (PagedKVNonTMA) {\n                    paged_kv_manager.template load_page_table<false /*Seqlenk_mask*/>(n_block);\n                } else {\n                    paged_kv_manager.load_page_table_TMA(n_block);\n                }\n                if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); }\n                load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/);\n                if constexpr (!Transpose_V) {\n                    if constexpr (IntraWGOverlap) {\n                        load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/);\n                    } else {\n                        load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/);\n                    }\n                }\n            }\n            n_block_prev = n_block;\n            if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); }\n        }\n        scheduler_prefetch();\n        if constexpr (!Transpose_V && IntraWGOverlap) {\n            if (should_load_KV) { load_V(n_block_prev, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); }\n        }\n        if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write); }\n        ++smem_pipe_write;\n        // At the end, all threads have the correct smem_pipe_write.\n        ++work_idx;\n    }\n\n    template <typename SharedStorage>\n    CUTLASS_DEVICE void\n    load_tail(MainloopPipelineK pipeline_k, MainloopPipelineV pipeline_v, MainloopPipelineVt pipeline_vt,\n              PipelineState& smem_pipe_write, SharedStorage &shared_storage, int const work_idx) {\n        // If we don't wait for barrier_O here, when using Cluster, CTA0 might exit early and CTA1 
will\n        // try to arrive on barrier_O of CTA0, causing \"unspecified launch failure\".\n        shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2);\n        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);\n        // Issue the epilogue waits\n        // TODO: check if this should be called by 1 thread or more\n        if (warp_idx_in_warpgroup == 0 && cute::elect_one_sync()) {\n            /* This helps avoid early exit of blocks in Cluster\n            *  Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used\n            *  then would just be acquired since the phase was still inverted from make_producer_start_state\n            */\n            pipeline_k.producer_tail(smem_pipe_write);\n            pipeline_v.producer_tail(smem_pipe_write);\n            if constexpr (Transpose_V) { pipeline_vt.producer_tail(smem_pipe_write); }\n        }\n    }\n\n    CUTLASS_DEVICE void\n    warp_scheduler_barrier_sync() {\n        if constexpr (UseSchedulerBarrier) {\n            cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/);\n        }\n    }\n\n    CUTLASS_DEVICE void\n    warp_scheduler_barrier_arrive() {\n        if constexpr (UseSchedulerBarrier) {\n            static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);\n            int const cur_WG = flash::canonical_warp_group_idx_nosync() - 1;\n            int const next_WG = NumMmaWarpGroups == 2\n                ? 1 - cur_WG\n                : (cur_WG < NumMmaWarpGroups - 1 ? 
cur_WG + 1 : 0);\n            cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) + next_WG /*id*/);\n        }\n    }\n\n    CUTLASS_DEVICE void\n    mma_init() {\n        int warp_group_idx = flash::canonical_warp_group_idx_nosync();\n        // Tell producers that smem_q is ready\n        if (!LargeHeadDimV || warp_group_idx == 1) {\n            cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);\n        }\n        if (LargeHeadDimV && warp_group_idx > 1) {\n            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PEmpty) /*id*/);\n        }\n        if constexpr (UseSchedulerBarrier) {\n            // We have NamedBarrier for up to 3 WGs\n            static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);\n            // WG1 needs the very first signal to start\n            if (warp_group_idx == 1) {\n                cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) /*id*/);\n            }\n        }\n    }\n\n    template <typename SharedStorage, typename FrgTensorO, typename Softmax>\n    CUTLASS_DEVICE bool\n    mma(Params const& params,\n        MainloopPipelineK pipeline_k,\n        MainloopPipelineV pipeline_v,\n        PipelineState& smem_pipe_read,\n        FrgTensorO& tOrO,\n        Softmax& softmax,\n        int const thread_idx,\n        int &work_idx,\n        SeqlenInfo_t const& seqlen_info,\n        cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,\n        SharedStorage& shared_storage\n        ) {\n        static_assert(is_rmem<FrgTensorO>::value, \"O tensor must be rmem resident.\");\n        static constexpr int kBlockM = get<0>(TileShape_MNK{});\n        static constexpr int kBlockN = 
get<1>(TileShape_MNK{});\n\n        // can't use auto [m_block, ...] = block_coord since structured bindings cannot be captured in a lambda\n        int const m_block = get<0>(block_coord);\n        int const bidh = get<1>(block_coord);\n        int const bidb = get<2>(block_coord);\n        int const split_idx = get<3>(block_coord);\n        int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;\n        auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max(\n            seqlen_info, m_block, bidb, split_idx, params.num_splits,\n            params.window_size_left, params.window_size_right, params.attention_chunk_divmod,\n            params.qhead_per_khead_divmod);\n        // It's possible to have n_block_max <= n_block_min. In that case we don't want to load Q or change any barrier.\n        if constexpr (Is_causal || Is_local || Varlen || Split) {\n            if (n_block_max <= n_block_min) { return false; }\n        }\n\n        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});\n        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});\n        Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{});\n        Tensor sP = [&] {\n            if constexpr (MmaPV_is_RS) {\n                // We might not have smem_p if MmaPV_is_RS, so just use smem_q as a placeholder since sP isn't used\n                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutP{});\n            } else {\n                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{});\n            }\n        }();\n        Tensor sScale = [&] {\n            if constexpr (LargeHeadDimV) {\n                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{});\n            } else { // won't be used, 
just a placeholder\n                return make_tensor(make_smem_ptr(static_cast<float*>(nullptr)), SmemLayoutScale{});\n            }\n        }();\n        Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{});\n        Tensor sVMmaQV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVMmaQV{});\n\n        if constexpr (!MmaQK_is_RS) {\n            static_assert(stride<0>(typename TiledMmaQK::ALayout{}) == 0 and\n                        stride<0>(typename TiledMmaQK::BLayout{}) == 0 and\n                        size<0>(typename TiledMmaQK::ALayout{}) == cutlass::NumThreadsPerWarpGroup and\n                        size<0>(typename TiledMmaQK::BLayout{}) == cutlass::NumThreadsPerWarpGroup,\n                \"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup\");\n        }\n        static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup;\n        Layout warp_group_thread_layout = make_layout(make_shape(Int<MmaWarpGroups>{}),\n                                                      make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));\n\n        int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);\n        TiledMmaQK tiled_mma_qk;\n        TiledMmaPV tiled_mma_pv;\n        TiledMmaQV tiled_mma_qv;\n        auto wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx));\n        auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx));\n        auto wg_mma_qv = tiled_mma_qv.get_slice(warp_group_thread_layout(warp_group_idx));\n\n        auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma_qk);\n        auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);\n\n        // Allocate \"fragments/descriptors\"\n        Tensor tSrQ = wg_mma_qk.partition_fragment_A(sQ);\n        Tensor tSrK = 
wg_mma_qk.partition_fragment_B(sK);\n        Tensor tOrV = wg_mma_pv.partition_fragment_B(sV);\n        Tensor tOsP = wg_mma_pv.partition_fragment_A(sP);\n        Tensor tSrQv = wg_mma_qv.partition_fragment_A(sQv);\n        Tensor tSrV = wg_mma_qv.partition_fragment_B(sVMmaQV);\n        Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP));\n\n        // For storing scales to smem, only used when LargeHeadDimV\n        auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx);\n        Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));\n        Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));\n        Tensor taccOcO_row = taccOcO_rowcol(_, _0{});\n        auto store_scales = [&](auto& scales, int stage) {\n            static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row)));\n            #pragma unroll\n            for (int mi = 0; mi < size(taccOcO_row); ++mi) {\n                if (get<1>(taccOcO_row(_0{})) == 0) {\n                    sScale(get<0>(taccOcO_row(mi)), stage) = scales(mi);\n                }\n            }\n        };\n\n        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {\n            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);\n            pipeline.consumer_wait(smem_pipe_read, barrier_token);\n        };\n\n        int const seqlen_q = seqlen_info.seqlen_q;\n        int const seqlen_k = seqlen_info.seqlen_k;\n        int n_block = n_block_max - 1;\n\n        flash::Mask<kBlockM, kBlockN, PackGQA, TiledMmaQK> mask(\n            thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/,\n            params.attention_chunk_divmod, params.qhead_per_khead_divmod\n        );\n\n        float softcap_val = params.softcap_val;\n        if constexpr (Has_softcap && Is_FP8) {\n            float 
const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)];\n            float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)];\n            softcap_val *= q_descale * k_descale;\n        }\n        // Softcapping needs to happen before masking since if we apply after masking, softcapping\n        // can turn -inf to e.g. -50.0, which can affect the attention softmax.\n        auto scoremod_premask_fn = [&](auto& tSrS) {\n            if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); }\n        };\n\n        auto write_P_to_smem = [&](auto& tOrP) {\n            if constexpr (LargeHeadDimV) {\n                cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PEmpty) /*id*/);\n            }\n            cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP);\n        };\n\n        auto arrive_on_P_write_barrier = [&] {\n            cutlass::arch::fence_view_async_shared();\n            __syncwarp();  // Only need syncwarp since each warp is using its own P values for MmaPV\n            if constexpr (LargeHeadDimV) {\n                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PFull) /*id*/);\n            }\n        };\n\n        auto &barrier_Q = shared_storage.pipelines.barrier_Q;\n        if constexpr (!AppendKV) {\n            barrier_Q.wait(work_idx % 2);\n        } else {\n            if (get<1>(params.shape_rotary) > 0) {  // Apply rotary to Q\n                using Rotary_t = Rotary<kBlockM, kHeadDim, NumMmaThreadsQK, Element, !(Is_causal || Is_local) /*FixedPosition*/>;\n                Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos,\n                                params.ptr_rotary_sin, 
params.stride_rotary_sin,\n                                params.is_rotary_interleaved, thread_idx, seqlen_q,\n                                seqlen_info.seqlen_rotary);\n                Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ);\n                int const qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;\n                if (params.is_rotary_interleaved) {\n                    auto [tRrCos, tRrSin] = cute::conditional_return<!PackGQA>(\n                        rotary.template load_cos_sin<true /*kInterleaved*/>(m_block),\n                        rotary.template load_cos_sin_packgqa<true /*kInterleaved*/>(m_block, params.qhead_per_khead_divmod)\n                    );\n                    barrier_Q.wait(work_idx % 2);\n                    rotary.apply_Q_interleaved(sQ_pi, tRrCos, tRrSin, m_block, qhead_per_khead);\n                } else {\n                    auto [tRrCosCont, tRrSinCont] = cute::conditional_return<!PackGQA>(\n                        rotary.template load_cos_sin<false /*kInterleaved*/>(m_block),\n                        rotary.template load_cos_sin_packgqa<false /*kInterleaved*/>(m_block, params.qhead_per_khead_divmod)\n                    );\n                    barrier_Q.wait(work_idx % 2);\n                    rotary.apply_Q_contiguous(sQ_pi, tRrCosCont, tRrSinCont, m_block, qhead_per_khead);\n                }\n                // SMEM fence to make sure the rotated Q is visible to GMMA\n                cutlass::arch::fence_view_async_shared();\n                cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK, static_cast<uint32_t>(FwdNamedBarriers::QueryRotated) /*id*/);\n            } else {\n                barrier_Q.wait(work_idx % 2);\n            }\n        }\n\n        if constexpr (MmaQK_is_RS) {\n            using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;\n            auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma_qk);\n            auto 
smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx);\n            Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);\n            Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(cute::as_position_independent_swizzle_tensor(sQ));\n            cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view);\n        }\n\n        if constexpr (IntraWGOverlap) {\n            Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{}));\n            consumer_wait(pipeline_k, smem_pipe_read);\n            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);\n            warpgroup_wait<0>();\n            pipeline_k.consumer_release(smem_pipe_read);\n            if constexpr (HasQv) {\n                shared_storage.pipelines.barrier_Qv.wait(work_idx % 2);\n                consumer_wait(pipeline_v, smem_pipe_read);\n                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS);\n            }\n            scoremod_premask_fn(tSrS);\n            mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block);\n\n            Tensor scores_scale = softmax.template max_get_scale</*Is_first=*/true, /*Check_inf=*/true>(tSrS);\n            // Don't need to store scales to send to WG1 (in the case of LargeHeadDimV) since it's 1.f\n\n            softmax.template online_softmax</*Is_first=*/true, /*Check_inf=*/true>(tSrS);\n            if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); }\n            Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<TiledMmaPV>(tSrS.layout()));\n            Tensor tOrP = make_tensor_like<Element>(tOrP_acc);\n            convert_type_out(tOrP_acc, tOrP);\n            if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }\n            if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); }\n            
if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); }\n            --n_block;\n\n            // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter\n            clear(tOrO);\n            // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero;\n\n            // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block.\n            auto fwd_step = [&](int const n_block, auto mask_fn, auto check_inf_type) {\n                static constexpr bool Check_inf = decltype(check_inf_type)::value;\n                PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count());\n                ++smem_pipe_read;\n                Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{}));\n                if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); }\n                warp_scheduler_barrier_sync();\n                flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);\n                if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }\n                if constexpr(!HasQv) {\n                    if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); }\n                }\n                flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_pv, cute::conditional_return<MmaPV_is_RS>(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);\n                warp_scheduler_barrier_arrive();\n                warpgroup_wait<1>();\n                pipeline_k.consumer_release(smem_pipe_read);  // release K\n                if constexpr (HasQv) {\n                    warpgroup_wait<0>();\n                    pipeline_v.consumer_release(smem_pipe_read_v);  // release V\n                    consumer_wait(pipeline_v, smem_pipe_read);\n                    
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS);\n                }\n                scoremod_premask_fn(tSrS);\n                mask_fn(tSrS, n_block);\n                cute::copy(softmax.template max_get_scale</*Is_first=*/false, Check_inf>(tSrS), scores_scale);\n                if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); }\n                softmax.template online_softmax</*Is_first=*/false, Check_inf>(tSrS);\n                if constexpr (!HasQv) {\n                    warpgroup_wait<0>();\n                    pipeline_v.consumer_release(smem_pipe_read_v);  // release V\n                }\n                if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); }\n                convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP);\n                if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }\n                if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); }\n                if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }\n                if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); }\n            };\n\n            if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking\n                auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };\n                int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask(\n                    seqlen_info, m_block, n_block_min, params.window_size_right,\n                    params.attention_chunk_divmod, params.qhead_per_khead_divmod);\n                #pragma unroll 1\n                for (; n_block >= n_block_min_causal_local_mask; --n_block) {\n                    fwd_step(n_block, mask_fn, cute::true_type{} /*check_inf*/);\n                }\n            }\n\n       
     int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask(\n                seqlen_info, m_block, n_block_min, params.window_size_left,\n                params.attention_chunk_divmod, params.qhead_per_khead_divmod);\n            auto no_mask_fn = [](auto& tSrS, int n_block) { };\n            #pragma unroll 1\n            for (; n_block >= n_block_min_before_local_mask; --n_block) {\n                fwd_step(n_block, no_mask_fn, cute::false_type{} /*check_inf*/);\n            }\n            // Separate masking iterations on the left for local attention\n            if constexpr (Is_local) {\n                auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, false /*Causal_mask*/, Is_local>(tSrS, m_block, n_block); };\n                #pragma unroll 1\n                for (; n_block >= n_block_min; --n_block) {\n                    fwd_step(n_block, local_mask_fn, cute::bool_constant<Is_local>{} /*check_inf*/);\n                }\n            }\n            // Tell producers that smem_q is ready\n            cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);\n            if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }\n            if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); }\n            flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_pv, cute::conditional_return<MmaPV_is_RS>(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO);\n            float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 
1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)];\n            cute::copy(softmax.finalize(v_descale), scores_scale);\n            if constexpr (LargeHeadDimV) {\n                cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PEmpty) /*id*/);\n                store_scales(scores_scale, smem_pipe_read.index());\n                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PFull) /*id*/);\n            }\n            warpgroup_wait<0>();\n            pipeline_v.consumer_release(smem_pipe_read);  // release V, otherwise producers will hang\n            softmax.rescale_o(tOrO, scores_scale);\n            if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); }\n            ++smem_pipe_read;\n\n        } else {  // No intra-WG overlap\n\n            warp_scheduler_barrier_sync();\n\n            auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) {\n                static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value;\n                static constexpr bool Check_inf = decltype(check_inf_type)::value;\n                auto smem_pipe_read_prev = smem_pipe_read;\n                if constexpr (!Is_first_iter) { ++smem_pipe_read; }\n                Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{}));\n                consumer_wait(pipeline_k, smem_pipe_read);\n                flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);\n                if constexpr (!HasQv) {\n                    warp_scheduler_barrier_arrive();\n                    warpgroup_wait<0>();\n                    pipeline_k.consumer_release(smem_pipe_read);  // release K\n                } else {\n                    if constexpr (Is_first_iter) {\n                        
shared_storage.pipelines.barrier_Qv.wait(work_idx % 2);\n                    }\n                    consumer_wait(pipeline_v, smem_pipe_read);\n                    flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS);\n                    warp_scheduler_barrier_arrive();\n                    warpgroup_wait<1>();\n                    pipeline_k.consumer_release(smem_pipe_read);  // release K\n                    warpgroup_wait<0>();\n                }\n                scoremod_premask_fn(tSrS);\n                mask_fn(tSrS, n_block);\n                Tensor scores_scale = softmax.template max_get_scale</*Is_first=*/Is_first_iter, Check_inf>(tSrS);\n                if constexpr (LargeHeadDimV && !Is_first_iter) { store_scales(scores_scale, smem_pipe_read_prev.index()); }\n                softmax.template online_softmax</*Is_first=*/Is_first_iter, Check_inf>(tSrS);\n                if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); }\n                Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<TiledMmaPV>(tSrS.layout()));\n                Tensor tOrP = make_tensor_like<Element>(tOrP_acc);\n                convert_type_out(tOrP_acc, tOrP);\n                if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }\n                if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); }\n                if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); }\n                if constexpr (!MmaPV_is_RS && !MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); }\n                if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); }\n                warp_scheduler_barrier_sync();\n                if constexpr (!MmaPV_use_RS_WG1) {\n                    flash::gemm</*zero_init=*/Is_first_iter, /*wg_wait=*/-1>(tiled_mma_pv, cute::conditional_return<MmaPV_is_RS>(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO);\n      
          } else {\n                    TiledMmaPV_RS tiled_mma_pv_rs;\n                    flash::gemm</*zero_init=*/Is_first_iter, /*wg_wait=*/-1>(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);\n                }\n                if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); }\n                warpgroup_wait<0>();\n                pipeline_v.consumer_release(smem_pipe_read);  // release V\n            };\n\n            auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };\n            fwd_step(n_block, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/);\n            --n_block;\n            if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking\n                auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };\n                int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask(\n                    seqlen_info, m_block, n_block_min, params.window_size_right,\n                    params.attention_chunk_divmod, params.qhead_per_khead_divmod);\n                #pragma unroll 1\n                for (; n_block >= n_block_min_causal_local_mask; --n_block) {\n                    fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/);\n                }\n            }\n            int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask(\n                seqlen_info, m_block, n_block_min, params.window_size_left,\n                params.attention_chunk_divmod, params.qhead_per_khead_divmod);\n            auto no_mask_fn = [](auto& tSrS, int n_block) { };\n            #pragma unroll 1\n            for (; n_block >= n_block_min_before_local_mask; --n_block) 
{\n                fwd_step(n_block, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/);\n            }\n            // Separate masking iterations on the left for local attention\n            if constexpr (Is_local) {\n                auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, false /*Causal_mask*/, Is_local>(tSrS, m_block, n_block); };\n                #pragma unroll 1\n                for (; n_block >= n_block_min; --n_block) {\n                    fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant<Is_local>{} /*check_inf*/);\n                }\n            }\n            warp_scheduler_barrier_arrive();\n            // Tell producers that smem_q is ready\n            cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);\n            float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 
1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)];\n            Tensor scores_scale = softmax.finalize(v_descale);\n            if constexpr (LargeHeadDimV) {\n                cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PEmpty) /*id*/);\n                store_scales(scores_scale, smem_pipe_read.index());\n                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PFull) /*id*/);\n            }\n            softmax.rescale_o(tOrO, scores_scale);\n            if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); }\n            ++smem_pipe_read;\n        }\n        ++work_idx;\n        return true;\n    }\n\n    template <typename SharedStorage, typename FrgTensorO, typename Softmax>\n    CUTLASS_DEVICE bool\n    mma_pv(Params const& params,\n           MainloopPipelineV pipeline_v,\n           PipelineState& smem_pipe_read,\n           FrgTensorO& tOrO,\n           Softmax& softmax,\n           int const thread_idx,\n           SeqlenInfo_t const& seqlen_info,\n           cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,\n           SharedStorage& shared_storage\n           ) {\n        static_assert(is_rmem<FrgTensorO>::value, \"O tensor must be rmem resident.\");\n        // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda\n        int const m_block = get<0>(block_coord);\n        int const bidb = get<2>(block_coord);\n        int const split_idx = get<3>(block_coord);\n        auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max(\n            seqlen_info, m_block, bidb, split_idx, params.num_splits,\n            params.window_size_left, params.window_size_right, params.attention_chunk_divmod,\n            params.qhead_per_khead_divmod);\n        // It's possible to have n_block_max <= n_block_min. 
We don't want to load Q or change any barrier\n        if constexpr (Is_causal || Is_local || Varlen || Split) {\n            if (n_block_max <= n_block_min) { return false; }\n        }\n\n        Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{});\n        Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{});\n        Tensor sScale = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{});\n        static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup;\n        Layout warp_group_thread_layout = make_layout(make_shape(Int<MmaWarpGroups>{}),\n                                                      make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));\n\n        int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);\n        TiledMmaPV tiled_mma_pv;\n        auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx));\n\n        // Allocate \"fragments/descriptors\"\n        Tensor tOrV = wg_mma_pv.partition_fragment_B(sV);\n        Tensor tOsP = wg_mma_pv.partition_fragment_A(sP);\n\n        // For load scales to smem, pretend thread_idx is thread_idx % 128\n        auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup);\n        Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));\n        Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));\n        Tensor taccOcO_row = taccOcO_rowcol(_, _0{});\n        auto load_scales = [&](auto& scales, int stage) {\n            static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row)));\n            #pragma unroll\n            for (int mi = 0; mi < size(taccOcO_row); ++mi) {\n                scales(mi) = 
sScale(get<0>(taccOcO_row(mi)), stage);\n            }\n        };\n\n        // clear(tOrO);\n        // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero;\n\n        typename Softmax::TensorT scores_scale;\n\n        int n_block = n_block_max - 1;\n        // If HasQv, then by the time P is ready, V must have been ready as well\n        if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); }\n        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PFull) /*id*/);\n        flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);\n        cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PEmpty) /*id*/);\n        pipeline_v.consumer_release(smem_pipe_read);  // release V\n        --n_block;\n\n        #pragma unroll 1\n        for (; n_block >= n_block_min; --n_block) {\n            cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PFull) /*id*/);\n            load_scales(scores_scale, smem_pipe_read.index());\n            softmax.rescale_o(tOrO, scores_scale);\n            ++smem_pipe_read;\n            if constexpr (!HasQv) {\n                auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read);\n                pipeline_v.consumer_wait(smem_pipe_read, barrier_token);\n            }\n            flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);\n            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PEmpty) /*id*/);\n            pipeline_v.consumer_release(smem_pipe_read);  // release V\n        };\n        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PFull) /*id*/);\n        load_scales(scores_scale, smem_pipe_read.index());\n        cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 
static_cast<uint32_t>(FwdNamedBarriers::PEmpty) /*id*/);\n        softmax.rescale_o(tOrO, scores_scale);\n        if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); }\n        ++smem_pipe_read;\n        return true;\n    }\n\n    template <typename SharedStorage>\n    CUTLASS_DEVICE bool\n    load_kv_new(Params const& params,\n         MainloopPipelineKVNew pipeline_k_new,\n         MainloopPipelineKVNew pipeline_v_new,\n         PipelineState& smem_pipe_write,\n         SharedStorage &shared_storage,\n         SeqlenInfo_t const& seqlen_info,\n         cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,\n         int const work_idx\n         ) {\n\n        auto [m_block, bidh, bidb, split_idx] = block_coord;\n        auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max(\n            seqlen_info, m_block, bidb, split_idx, params.num_splits,\n            params.window_size_left, params.window_size_right, params.attention_chunk_divmod,\n            params.qhead_per_khead_divmod);\n\n        if (n_block_new_max <= n_block_new_min) { return false; }\n\n        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});\n        Tensor sVt = [&] {\n            if constexpr (!Transpose_V) {\n                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{});\n            } else {\n                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{});\n            }\n        }();\n\n        // int const thread_idx = threadIdx.x % NumProducerThreads;\n        int const bidh_kv = !PackGQA ? 
params.qhead_per_khead_divmod.divide(bidh) : bidh;\n\n        // Prepare the TMA loads\n        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();\n        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());\n        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};\n\n        bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new;\n        Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0);\n        auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new));\n        Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0);\n\n        Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)\n        Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _));  // (K, N, _)\n\n        auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x);\n        Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA));  // (TMA, k)\n        Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K_new.partition_D(sK));  // (TMA, PIPE)\n        auto block_tma_V_new = params.tma_load_V_new.get_slice(cluster_local_block_id.x);\n        Tensor tVgVnewt_TMA = group_modes<0, 3>(block_tma_V_new.partition_S(gVnewt_TMA));  // (TMA, k)\n        Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V_new.partition_D(sVt));  // (TMA, PIPE)\n\n        uint16_t mcast_mask_kv = 0;\n        if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {\n            auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id\n            for (int m = 0; m < 
size<0>(block_layout); ++m) {\n                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));\n            }\n        }\n\n        auto load_K_new = [&] (int const n_block, auto const& smem_pipe_write) {\n            pipeline_k_new.producer_acquire(smem_pipe_write);\n            copy(params.tma_load_K_new.with(*pipeline_k_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST),\n                tKgKnew_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index()));\n        };\n\n        auto load_V_new = [&] (int const n_block, auto const& smem_pipe_write) {\n            pipeline_v_new.producer_acquire(smem_pipe_write);\n            copy(params.tma_load_V_new.with(*pipeline_v_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST),\n                tVgVnewt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index()));\n        };\n\n        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);\n        // If this is true, we're guaranteed that only the first warp will execute this function\n        static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;\n        bool should_load_KV = (SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync();\n\n        int n_block = n_block_new_max - 1;\n        // Need to wait for barrier_O even before load_K_new since the pipelines for AppendKV\n        // and the main attention are not the same. 
We want to make sure the consumers\n        // have finished reading all smem_k and smem_v for the previous iteration.\n        shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2);\n        if (should_load_KV) { load_K_new(n_block, smem_pipe_write); }\n        // if (thread_idx == 0) { printf(\"Producer: Done loading K, n_block = %d, n_block_new_min = %d\\n\", n_block, n_block_new_min); }\n        if (should_load_KV) { load_V_new(n_block, smem_pipe_write); }\n        // if (thread_idx == 0) { printf(\"Producer: Done loading V, n_block = %d, n_block_new_min = %d\\n\", n_block, n_block_new_min); }\n        ++smem_pipe_write;\n        --n_block;\n        // if (thread_idx == 0) { printf(\"Producer: before for loop\\n\"); }\n        #pragma unroll 1\n        for (; n_block >= n_block_new_min; --n_block) {\n            if (should_load_KV) {\n                load_K_new(n_block, smem_pipe_write);\n                // if (thread_idx == 0) { printf(\"Producer: Done loading K, n_block = %d, n_block_new_min = %d\\n\", n_block, n_block_new_min); }\n                load_V_new(n_block, smem_pipe_write);\n                // if (thread_idx == 0) { printf(\"Producer: Done loading V, n_block = %d, n_block_new_min = %d\\n\", n_block, n_block_new_min); }\n            }\n            ++smem_pipe_write;\n        }\n        // if (thread_idx == 0) { printf(\"Producer: after for loop\\n\"); }\n        // At the end, all threads have the correct smem_pipe_write.\n        return true;\n    }\n\n    template <typename SharedStorage>\n    CUTLASS_DEVICE bool\n    store_kv_new(Params const& params,\n                 MainloopPipelineKVNew pipeline_k_new,\n                 MainloopPipelineKVNew pipeline_v_new,\n                 PipelineState& smem_pipe_read,\n                 int const thread_idx,\n                 SharedStorage &shared_storage,\n                 SeqlenInfo_t const& seqlen_info,\n                 cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord\n    ) {\n        
auto [m_block, bidh, bidb, split_idx] = block_coord;\n        auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max(\n            seqlen_info, m_block, bidb, split_idx, params.num_splits,\n            params.window_size_left, params.window_size_right, params.attention_chunk_divmod,\n            params.qhead_per_khead_divmod);\n        if (n_block_new_max <= n_block_new_min) { return false; }\n\n        // as_position_independent_swizzle_tensor makes address calculation easier\n        Tensor sK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}));\n        // We want to use SmemLayoutVCpAsync to have shape (kBlockN, kHeadDim) instead of (kHeadDim, kBlockN)\n        Tensor sV = [&] {\n            if constexpr (!Transpose_V) {\n                return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVCpAsync{}));\n            } else {\n                return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{}));\n            }\n        }();\n\n        int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;\n        int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb];\n\n        bool const is_varlen_k = Varlen && params.cu_seqlens_k;\n        Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);\n        auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K));\n        Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? 
bidb_kv : 0);\n\n        int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og;\n        Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)\n        Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<2, 1>(TileShape_MNK_PV{}), make_coord(_, _0{}));  // (N, K_v, _)\n\n        static constexpr int kBlockN = get<1>(TileShape_MNK{});\n        static constexpr int kHeadDim = get<2>(TileShape_MNK{});\n        int const seqlen_k_new = seqlen_info.seqlen_k_new;\n        using Rotary_t = Rotary<kBlockN, kHeadDim, NumMmaThreads, Element>;\n        Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos,\n                        params.ptr_rotary_sin, params.stride_rotary_sin,\n                        params.is_rotary_interleaved, thread_idx, seqlen_k_new,\n                        seqlen_info.seqlen_rotary);\n\n        // This is used to index into the batch dimension of mK and mV\n        int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? 
bidb_kv : 0;\n\n        using PagedKVManager_t = PagedKVManager<get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>;\n        PagedKVManager_t paged_kv_manager(\n            params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable,\n            params.ptr_K, params.shape_K, params.stride_K,\n            params.ptr_V, params.headdim_v, params.stride_V,\n            params.page_size_divmod, params.blockN_per_page_size_divmod,\n            bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, bidb_kv_idx\n            // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position\n        );\n\n        if constexpr (UseSchedulerBarrier) {\n            // WG1 already got the very first signal from mma_init(), but we'll be using the same NamedBarrier.\n            // So we'll need to \"cancel it out\" here and then re-signal it at the end.\n            if (flash::canonical_warp_group_idx_nosync() == 1) {\n                cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) /*id*/);\n            }\n        }\n\n        static_assert(std::is_same_v<GmemLayoutAtom, typename Rotary_t::LayoutAtom>);\n        static_assert(!PagedKVNonTMA || std::is_same_v<GmemLayoutAtom, typename PagedKVManager_t::GmemLayoutAtomKVCpAsync>);\n        GmemTiledCopyAppendKV gmem_tiled_copy_kv;\n        auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_thread_slice(thread_idx);\n        Tensor tKsK = gmem_thr_copy_kv.partition_S(sK);        // ((Atom,AtomNum),ATOM_M,ATOM_N)\n        Tensor tKgK = gmem_thr_copy_kv.partition_D(gK);\n        Tensor tVsV = gmem_thr_copy_kv.partition_S(sV);        // ((Atom,AtomNum),ATOM_M,ATOM_N)\n        Tensor tVgV = gmem_thr_copy_kv.partition_D(gV);\n        Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));  // (BLK_N,BLK_K) -> 
(blk_n,blk_k)\n        Tensor tKcK = gmem_thr_copy_kv.partition_D(cK);\n        Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));\n        #pragma unroll\n        for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(params.shape_K); }\n        Tensor cV = cute::make_identity_tensor(select<2, 1>(TileShape_MNK_PV{}));  // (BLK_N,BLK_K_V) -> (blk_n,blk_k_v)\n        Tensor tVcV = cute::conditional_return<SameHeadDim>(tKcK, gmem_thr_copy_kv.partition_D(cV));\n        Tensor tVpV_ = make_tensor<bool>(make_shape(size<2>(tVsV)));\n        #pragma unroll\n        for (int k = 0; k < size(tVpV_); ++k) { tVpV_(k) = get<1>(tVcV(_0{}, _0{}, k)) < params.headdim_v; }\n        Tensor tVpV = cute::conditional_return<SameHeadDim>(tKpK, tVpV_);\n\n        auto store_K = [&] (int const n_block, auto const& smem_pipe_read) {\n            int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN);\n            if (get<1>(params.shape_rotary) <= 0) {\n                pipeline_k_new.consumer_wait(smem_pipe_read);\n                Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_read.index());\n                if constexpr (!PagedKVNonTMA) {\n                    Tensor tKgK_cur = tKgK(_, _, _, n_block);\n                    // Clear_OOB_K must be false since we don't want to write zeros to gmem\n                    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n                        gmem_tiled_copy_kv, tKsK_cur, tKgK_cur, tKcK, tKpK, std::min(seqlen_k_new - n_block * kBlockN, kBlockN)\n                    );\n                } else {\n                    paged_kv_manager.store_K(n_block, tKsK_cur);\n                }\n            } else {\n                Tensor gK_cur = gK(_, _, n_block);\n                auto tPrKPtr = cute::conditional_return<PagedKVNonTMA>(paged_kv_manager.compute_K_ptr(), nullptr);\n                if (params.is_rotary_interleaved) {\n                    auto 
[tRrCos, tRrSin] = rotary.template load_cos_sin<true /*kInterleaved*/>(n_block);\n                    pipeline_k_new.consumer_wait(smem_pipe_read);\n                    rotary.template apply_K_interleaved<PagedKVNonTMA>(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCos, tRrSin, tPrKPtr, n_block);\n                } else {\n                    auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin<false /*kInterleaved*/>(n_block);\n                    pipeline_k_new.consumer_wait(smem_pipe_read);\n                    rotary.template apply_K_contiguous<PagedKVNonTMA>(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K));\n                }\n            }\n            // Without this fence I'm getting race condition when seqlen_k is large\n            cutlass::arch::fence_view_async_shared();\n            // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized\n            // before calling.\n            cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/);\n            pipeline_k_new.consumer_release(smem_pipe_read);\n            // if (thread_idx == 0) { print_tensor(tKpK); printf(\"\\n\"); printf(\"seqlen_limit = %d\\n\", seqlen_k_new - n_block * kBlockN);}\n        };\n\n        auto store_V = [&] (int const n_block, auto const& smem_pipe_read) {\n            pipeline_v_new.consumer_wait(smem_pipe_read);\n            int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN);\n            Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_read.index());\n            if constexpr (!PagedKVNonTMA) {\n                Tensor tVgV_cur = tVgV(_, _, _, n_block);\n                // Clear_OOB_K must be false since we don't want to write zeros to gmem\n                flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, 
/*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(\n                    gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tVcV, tVpV, n_limit);\n            } else {\n                paged_kv_manager.store_V(n_block, tVsV_cur);\n            }\n            cutlass::arch::fence_view_async_shared();\n            cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/);\n            pipeline_v_new.consumer_release(smem_pipe_read);\n        };\n\n        #pragma unroll 1\n        for (int n_block = n_block_new_max - 1; n_block >= n_block_new_min; --n_block) {\n            if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table<true /*Seqlenk_mask*/>(n_block); }\n            store_K(n_block, smem_pipe_read);\n            // if (thread_idx == 0) { printf(\"Done storing K, n_block = %d, n_block_new_min = %d\\n\", n_block, n_block_new_min); }\n            store_V(n_block, smem_pipe_read);\n            // if (thread_idx == 0) { printf(\"Done storing V, n_block = %d, n_block_new_min = %d\\n\", n_block, n_block_new_min); }\n            ++smem_pipe_read;\n        }\n        // if (thread_idx == 0) { printf(\"After for loop\\n\"); }\n\n        // Re-signaling the NamedBarrier that we \"canceled out\"\n        if constexpr (UseSchedulerBarrier) {\n            if (flash::canonical_warp_group_idx_nosync() == 1) {\n                cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1) /*id*/);\n            }\n        }\n\n        return true;\n\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/mask.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cute/tensor.hpp>\n\n#include \"cutlass/fast_math.h\"  // For cutlass::FastDivmod\n\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <int kBlockM, int kBlockN, bool PackGQA, typename TiledMma, bool SwapAB=false>\nstruct Mask {\n\n    static_assert(!(PackGQA && SwapAB), \"Cannot be both PackGQA and SwapAB\");\n\n    int const thread_idx;\n    int const seqlen_q, seqlen_k;\n    int const window_size_left, window_size_right, sink_token_length;\n    cutlass::FastDivmod const attention_chunk_divmod;\n    cutlass::FastDivmod const qhead_per_khead_divmod;\n\n    CUTLASS_DEVICE\n    Mask(const int thread_idx, const int seqlen_q, const int seqlen_k,\n         const int window_size_left, const int window_size_right, const int sink_token_length,\n         cutlass::FastDivmod const &attention_chunk_divmod,\n         cutlass::FastDivmod const &qhead_per_khead_divmod)\n        : thread_idx(thread_idx)\n        , seqlen_q(seqlen_q)\n        , seqlen_k(seqlen_k)\n        , window_size_left(window_size_left)\n        , window_size_right(window_size_right)\n        , sink_token_length(sink_token_length)\n        , attention_chunk_divmod(attention_chunk_divmod)\n        , qhead_per_khead_divmod(qhead_per_khead_divmod)\n    {\n    };\n\n    template <bool Seqlenk_mask=false, bool Causal_mask=false, bool Local_mask=false,\n        typename Engine, typename Layout>\n    CUTLASS_DEVICE\n    void apply(Tensor<Engine, Layout> &tSrS, const int m_block, const int n_block) const {\n        static_assert(!(Causal_mask && Local_mask), \"Cannot be both causal and local\");\n        static_assert(Layout::rank == 3, \"Only support 3D Tensor\");\n        if 
(!Seqlenk_mask && !Causal_mask && !Local_mask) { return; }\n\n        auto thread_mma = TiledMma{}.get_thread_slice(thread_idx);\n        auto thread0_mma = TiledMma{}.get_thread_slice(_0{});\n\n        static constexpr int Row = !SwapAB ? 0 : 1, Col = !SwapAB ? 1 : 0;\n\n        Tensor cS = cute::make_identity_tensor(Shape<Int<!SwapAB ? kBlockM : kBlockN>, Int<!SwapAB ? kBlockN : kBlockM>>{});\n        Tensor tScS = thread_mma.partition_C(cS);\n        Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tSrS.layout()));\n        Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(tScS.layout()));\n        Tensor t0ScS = thread0_mma.partition_C(cS);\n        Tensor t0ScS_rowcol = make_tensor(t0ScS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SwapAB>(t0ScS.layout()));\n        // We want to use the col indices of thread0 to compare, since that is known at compile time.\n        // So we subtract the limit by the first col index of this thread (get<Col>(tScS_rowcol(_0{}, _0{})))\n        int const thread_col_offset = get<Col>(tScS_rowcol(_0{}, _0{}));\n        int const seqlenk_col_limit = seqlen_k - n_block * kBlockN - thread_col_offset;\n        if constexpr (!Causal_mask && !Local_mask) {\n            if constexpr (Seqlenk_mask) {  // Just masking based on col\n                #pragma unroll\n                for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {\n                    if (int(get<Col>(t0ScS_rowcol(_0{}, n))) >= seqlenk_col_limit) {\n                        #pragma unroll\n                        for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { tSrS_rowcol(m, n) = -INFINITY; }\n                    }\n                }\n            }\n        } else {  // mask based on both row and col\n            if constexpr (!SwapAB) {\n                // If PackGQA, we split the work of compute divmod among threads in the same row\n                static 
constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{});\n                static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0);\n                static_assert(!PackGQA || CUTE_STATIC_V(size<0>(tSrS_rowcol)) <= kMmaThreadsPerRow);\n                int mma_m_idx;\n                // Might get OOB but it's ok since we'll check it later\n                if constexpr (PackGQA) {\n                    mma_m_idx = qhead_per_khead_divmod.divide(m_block * kBlockM + get<Row>(tScS_rowcol(thread_idx % kMmaThreadsPerRow, _0{})));\n                }\n                int const causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q - thread_col_offset;\n                if constexpr (Causal_mask) {\n                    #pragma unroll\n                    for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {\n                        int const row_idx = !PackGQA\n                            ? get<Row>(tScS_rowcol(m, _0{})) + m_block * kBlockM\n                            :  __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow);\n                        int const col_limit_right = !Seqlenk_mask\n                            ? 
row_idx + causal_row_offset\n                            : __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit);\n                        #pragma unroll\n                        for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {\n                            if (int(get<Col>(t0ScS_rowcol(_0{}, n))) >= col_limit_right) { tSrS_rowcol(m, n) = -INFINITY; }\n                        }\n                    }\n                } else {\n                    int const local_row_offset_right = causal_row_offset + window_size_right;\n                    int const local_row_offset_left = causal_row_offset - 1 - window_size_left;\n                    int const col_limit_sink = sink_token_length - n_block * kBlockN;  // TODO: subtract thread_col_offset?\n                    #pragma unroll\n                    for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {\n                        int const row_idx = !PackGQA\n                            ? get<Row>(tScS_rowcol(m, _0{})) + m_block * kBlockM\n                            :  __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow);\n                        int col_limit_right = !Seqlenk_mask\n                            ? 
row_idx + local_row_offset_right\n                            : __viaddmin_s32(row_idx, local_row_offset_right, seqlenk_col_limit);\n                        int col_limit_left = row_idx + local_row_offset_left;\n                        if (attention_chunk_divmod.divisor > 0) {\n                            int col_limit_left_chunk = flash::round_down(attention_chunk_divmod, row_idx + seqlen_k - seqlen_q) - n_block * kBlockN - thread_col_offset;\n                            col_limit_left = std::max(col_limit_left, col_limit_left_chunk);\n                            col_limit_right = std::min(col_limit_right, col_limit_left_chunk + attention_chunk_divmod.divisor);\n                        }\n                        #pragma unroll\n                        for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {\n                            int const col_idx = int(get<Col>(t0ScS_rowcol(m, n)));\n                            if (col_idx >= col_limit_right || (col_idx < col_limit_left && col_idx >= col_limit_sink)) { tSrS_rowcol(m, n) = -INFINITY; }\n                        }\n                    }\n                }\n            } else {\n                // TODO: backward does not support attention_chunk yet\n                int const thread_row_offset = get<Row>(tScS_rowcol(_0{}, _0{}));\n                int const causal_row_offset = seqlenk_col_limit - seqlen_q + m_block * kBlockM + thread_row_offset;\n                if constexpr (Causal_mask) {\n                    #pragma unroll\n                    for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {\n                        int const col0 = int(get<Col>(t0ScS_rowcol(_0{}, n)));\n                        // If col0 is beyond the column limit, we want to mask out the entire column, by setting\n                        // row limit to be kBlockM.\n                        int const row_limit_top = col0 >= seqlenk_col_limit ? 
kBlockM : col0 - causal_row_offset;\n                        #pragma unroll\n                        for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {\n                            if (int(get<Row>(t0ScS_rowcol(m, _0{}))) < row_limit_top) { tSrS_rowcol(m, n) = -INFINITY; }\n                        }\n                    }\n                } else {\n                    int const col_limit_sink = sink_token_length - n_block * kBlockN - thread_col_offset;\n                    #pragma unroll\n                    for (int n = 0; n < size<1>(tSrS_rowcol); ++n) {\n                        int const col0 = int(get<Col>(t0ScS_rowcol(_0{}, n)));\n                        // If col0 is beyond the column limit, we want to mask out the entire column, by setting\n                        // row limit to be kBlockM.\n                        int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset - window_size_right;\n                        int const row_limit_bot = col0 < col_limit_sink ? kBlockM : col0 - causal_row_offset + window_size_left;\n                        #pragma unroll\n                        for (int m = 0; m < size<0>(tSrS_rowcol); ++m) {\n                            int const row_idx = int(get<Row>(t0ScS_rowcol(m, _0{})));\n                            if (row_idx < row_limit_top || row_idx > row_limit_bot) { tSrS_rowcol(m, n) = -INFINITY; }\n                        }\n                    }\n                }\n            }\n        }\n    };\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/named_barrier.hpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cutlass/arch/barrier.h\"\n\nnamespace flash {\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// cutlass::arch::NamedBarrier::sync/arrive are only enabled for Sm90 even though they work\n// for Sm80 as well. We reimplement them here, enabled for both Sm90 and Sm80.\n\nCUTLASS_DEVICE\nstatic void named_barrier_sync(uint32_t num_threads, uint32_t barrier_id_) {\n    static constexpr uint32_t ReservedNamedBarrierCount = static_cast<uint32_t>(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier);\n    uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount;\n    asm volatile(\"bar.sync %0, %1;\" : : \"r\"(barrier_id), \"r\"(num_threads));\n    cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id);\n}\n\nCUTLASS_DEVICE\nstatic void named_barrier_sync(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) {\n    uint32_t barrier_id = static_cast<uint32_t>(reserved_named_barriers);\n    asm volatile(\"bar.sync %0, %1;\" : : \"r\"(barrier_id), \"r\"(num_threads));\n    cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id);\n}\n\nCUTLASS_DEVICE\nstatic void named_barrier_arrive(uint32_t num_threads, uint32_t barrier_id_) {\n    static constexpr uint32_t ReservedNamedBarrierCount = static_cast<uint32_t>(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier);\n    uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount;\n    cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id);\n    asm volatile(\"bar.arrive %0, %1;\" : : \"r\"(barrier_id), 
\"r\"(num_threads));\n}\n\nCUTLASS_DEVICE\nstatic void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) {\n    uint32_t barrier_id = static_cast<uint32_t>(reserved_named_barriers);\n    cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id);\n    asm volatile(\"bar.arrive %0, %1;\" : : \"r\"(barrier_id), \"r\"(num_threads));\n}\n\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// Enumerates the reserved named barriers to avoid potential conflicts\n\nenum class FwdNamedBarriers {\n    QueryEmpty = 0,\n    WarpSchedulerWG1 = 1,\n    WarpSchedulerWG2 = 2,\n    WarpSchedulerWG3 = 3,\n    AppendKV = 4,\n    QueryRotated = 5,\n    PFull = 6,\n    PEmpty = 7,\n};\n\nenum class BwdNamedBarriers {\n    KVEmpty = 0,\n    PdS = 1,\n    dQEmptyWG1 = 2,\n    dQEmptyWG2 = 3,\n    dQEmptyWG3 = 4,\n    dQFullWG1 = 5,\n    dQFullWG2 = 6,\n    dQFullWG3 = 7,\n};\n\n} // flash\n"
  },
  {
    "path": "hopper/pack_gqa.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cute/tensor.hpp>\n\n#include \"cutlass/fast_math.h\"  // For cutlass::FastDivmod\n\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <int kBlockM, int kHeadDim, int NumThreads, typename Element>\nstruct PackGQAManager {\n    // We use CpAsync for Q, since TMA doesn't work there\n    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);\n    static constexpr int kGmemElemsPerStore = kGmemElemsPerLoad;\n    static_assert(kHeadDim % kGmemElemsPerLoad == 0, \"Headdim must be a multiple of kGmemElemsPerLoad\");\n    // We want each \"row\" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each\n    // thread to have 4 loads in the M direction and 2 vectorized load in the K direction.\n    // In the case of PackGQA, this reduces the number of times we need to call divmod.\n    static constexpr int kBytePerRow = kHeadDim * sizeof(Element);\n    static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);\n    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;\n    static_assert(NumThreads % kGmemThreadsPerRow == 0, \"NumThreads must be a multiple of kGmemThreadsPerRow\");\n    // We assume threads loading the same row are in the same warp. 
This is for an optimization in PagedKV where\n    // these threads share the same page table entry and share the work of computing pointers to paged K and paged V.\n    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, \"kGmemThreadsPerRow must divide NumThreadsPerWarp\");\n    using GmemCopyAtomCpAsync = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<uint128_t>, Element>;\n    using GmemLayoutAtom = Layout<Shape <Int<NumThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                                  Stride<Int<kGmemThreadsPerRow>, _1>>;\n    using GmemTiledCopyQCpAsync = decltype(\n        make_tiled_copy(GmemCopyAtomCpAsync{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load\n\n    // Was trying to have each WG loading Q to the rows in sQ that only that WG needs so that we only need\n    // to sync within each WG, but didn't seem to be any faster.\n    // using GmemLayoutAtomWG = Layout<Shape <Int<128 / kGmemThreadsPerRow>, Int<NumThreads / 128>, Int<kGmemThreadsPerRow> >,\n    //     Stride<Int<kGmemThreadsPerRow>, _128, _1>>;\n    // using GmemTiledCopyQCpAsyncWG = decltype(\n    //     make_tiled_copy(GmemCopyAtomCpAsync{},\n    //                     GmemLayoutAtomNew{},\n    //                     Layout<Shape<_1, _1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load\n\n    using GmemTiledCopyO = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        GmemLayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerStore>>>{}));  // Val layout, 8 or 16 vals per store\n\n    template <int NumThreadsPerRow=kGmemThreadsPerRow, typename Engine, typename Layout, typename TensorC>\n    CUTLASS_DEVICE\n    static auto\n    compute_ptr(Tensor<Engine, Layout> &tensor, TensorC const &tRows,\n                cutlass::FastDivmod const 
&qhead_per_khead_divmod, int const thread_idx, int const m_block) {\n        // tensor of shape ((qhead_per_khead, seqlen_q))\n        static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size(tRows)), NumThreadsPerRow);\n        using TensorType = typename Engine::value_type;\n        Tensor tPrPtr = make_tensor<TensorType const*>(Shape<Int<NumPtrPerThread>>{});\n        #pragma unroll\n        for (int i = 0; i < NumPtrPerThread; ++i) {\n            int const row = i * NumThreads + get<0>(tRows(thread_idx % NumThreadsPerRow));\n            int const idx = m_block * kBlockM + row;\n            int m_idx, h_idx;\n            m_idx = qhead_per_khead_divmod.divmod(h_idx, idx);\n            tPrPtr[i] = &tensor(make_coord(make_coord(h_idx, m_idx)));\n        }\n        return tPrPtr;\n    }\n\n\n    template <typename TensormQ, typename TensorsQ>\n    CUTLASS_DEVICE\n    static void\n    load_Q(TensormQ const &mQ,  // ((qhead_per_khead, seqlen_q), headdim)\n           TensorsQ &sQ,  // (kBlockM, kHeadDim)\n           cutlass::FastDivmod const &qhead_per_khead_divmod,\n           int const thread_idx, int const seqlen_q, int const m_block\n          )\n    {\n        GmemTiledCopyQCpAsync gmem_tiled_copy_Q_cp_async;\n        // GmemTiledCopyQCpAsyncNew gmem_tiled_copy_Q_cp_async;\n        auto gmem_thr_copy_Q_cp_async = gmem_tiled_copy_Q_cp_async.get_thread_slice(thread_idx);\n        Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n        Tensor tQcQ = gmem_thr_copy_Q_cp_async.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)\n        Tensor tQsQ = gmem_thr_copy_Q_cp_async.partition_D(sQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)\n        // Tensor tQcQ_ = gmem_thr_copy_Q_cp_async.partition_S(cute::flat_divide(cQ, _64{}));       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)\n        // Tensor tQsQ_ = gmem_thr_copy_Q_cp_async.partition_D(cute::flat_divide(sQ, _64{}));       
// (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)\n        // Tensor tQcQ = group_modes<1, rank(tQcQ_) - 1>(tQcQ_);\n        // Tensor tQsQ = group_modes<1, rank(tQsQ_) - 1>(tQsQ_);\n        Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));\n        #pragma unroll\n        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < size<1>(mQ); }\n\n        // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for Q.\n        // We split the work among threads loading the same row of Q, then __shfl_sync the pointers.\n        Tensor mQ_0 = mQ(_, _0{});\n        Tensor tQcQ_row = tQcQ(_0{}, _, _0{});\n        Tensor tPrQPtr = compute_ptr(mQ_0, tQcQ_row, qhead_per_khead_divmod, thread_idx, m_block);\n        int const qhead_per_khead = qhead_per_khead_divmod.divisor;\n        #pragma unroll\n        for (int m = 0; m < size<1>(tQsQ); ++m) {\n            int idx = m_block * kBlockM + get<0>(tQcQ(_0{}, m, _0{}));\n            Element const* q_ptr = reinterpret_cast<Element const*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrQPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));\n            if (idx < seqlen_q * qhead_per_khead) {\n                // if (thread_idx == 0) { printf(\"m: %d, m_idx: %d, h_idx: %d, q_ptr = %p, q_ptr_og = %p\\n\", m, m_idx, h_idx, q_ptr, &mQ_copy(0, make_coord(h_idx, m_idx), 0));}\n                Tensor mQ_cur = make_tensor(make_gmem_ptr(q_ptr), Shape<Int<kHeadDim>>{});\n                Tensor mQ_cur_copy = cute::tiled_divide(mQ_cur, Shape<Int<kGmemElemsPerLoad>>{});\n                #pragma unroll\n                for (int k = 0; k < size<2>(tQsQ); ++k) {\n                    int ki = get<1>(tQcQ(_0{}, _0{}, k)) / kGmemElemsPerLoad;\n                    // the \"tiled_copy.with(tQpQ(k))\"\" will fill in zero for columns where tQpQ(k) is false\n                    // TODO: check this\n                    cute::copy(gmem_tiled_copy_Q_cp_async.with(tQpQ(k)), 
mQ_cur_copy(_, ki), tQsQ(_, m, k));\n                }\n            } // Don't need to fill in 0s for sQ since we're not gonna write the output to gmem for those rows\n        }\n    };\n\n    template <typename TensormLSE, typename TensorsLSE, typename TiledMma>\n    CUTLASS_DEVICE\n    static void\n    store_LSE(TensormLSE &mLSE,  // ((qhead_per_khead, seqlen_q))\n              TensorsLSE const &tLSErLSE,  // (kBlockM) split across threads according to tiled_mma\n              TiledMma tiled_mma,\n              cutlass::FastDivmod const &qhead_per_khead_divmod,\n              int const thread_idx, int const seqlen_o, int const m_block\n             )\n    {\n        Tensor caccO = cute::make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});\n        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);\n        Tensor taccOcO = thread_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)\n        Tensor taccOcO_row = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()))(_, _0{});\n        CUTE_STATIC_ASSERT_V(size(tLSErLSE) == size(taccOcO_row));                     // MMA_M\n\n        // If PackGQA, we split the work of compute divmod among threads in the same row\n        static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{});\n        static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0);\n        static_assert(CUTE_STATIC_V(size(tLSErLSE)) <= kMmaThreadsPerRow);\n        static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow);\n\n        Tensor tPrLSEPtr = compute_ptr<kMmaThreadsPerRow>(mLSE, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block);\n        static_assert(CUTE_STATIC_V(size(tPrLSEPtr)) == 1);\n        int const qhead_per_khead = qhead_per_khead_divmod.divisor;\n        #pragma unroll\n        for (int mi = 0; mi < size(tLSErLSE); ++mi) {\n            int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));\n            float* 
ptr_LSE_cur = reinterpret_cast<float*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrLSEPtr[0]), mi % kMmaThreadsPerRow, kMmaThreadsPerRow));\n            if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o * qhead_per_khead) {\n                *ptr_LSE_cur = tLSErLSE(mi);\n            }\n        }\n    };\n\n    template <typename TensormO, typename TensorrO>\n    CUTLASS_DEVICE\n    static void\n    store_O(TensormO &mO,  // ((qhead_per_khead, seqlen_o), headdim)\n            TensorrO const &tOrO,  // (kBlockM, kHeadDim) split across threads according to gmem_tiled_copy_O\n            cutlass::FastDivmod const &qhead_per_khead_divmod,\n            int const thread_idx, int const seqlen_o, int const m_block\n          )\n    {\n        GmemTiledCopyO gmem_tiled_copy_O;\n        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);\n        Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)\n        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)\n        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));\n        #pragma unroll\n        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < size<1>(mO); }\n\n        // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O.\n        // We split the work among threads loading the same row of O, then __shfl_sync the pointers.\n        Tensor mO_0 = mO(_, _0{});\n        Tensor tOcO_row = tOcO(_0{}, _, _0{});\n        Tensor tPrOPtr = compute_ptr(mO_0, tOcO_row, qhead_per_khead_divmod, thread_idx, m_block);\n        int const qhead_per_khead = qhead_per_khead_divmod.divisor;\n        #pragma unroll\n        for (int m = 0; m < size<1>(tOrO); ++m) {\n            int idx = m_block * kBlockM + get<0>(tOcO(_0{}, m, _0{}));\n            Element* o_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, 
reinterpret_cast<uint64_t>(tPrOPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));\n            if (idx < seqlen_o * qhead_per_khead) {\n                Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape<Int<kHeadDim>>{});\n                Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape<Int<kGmemElemsPerStore>>{});\n                #pragma unroll\n                for (int k = 0; k < size<2>(tOrO); ++k) {\n                    int ki = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerStore;\n                    if (tOpO(k)) {\n                        cute::copy(gmem_tiled_copy_O, tOrO(_, m, k), mO_cur_copy(_, ki));\n                    }\n                }\n            }\n        }\n    };\n\n    template <typename TensormO, typename TensorrO, typename TiledMma>\n    CUTLASS_DEVICE\n    static void\n    store_O_direct(TensormO &mO,  // ((qhead_per_khead, seqlen_o), headdim)\n                   TensorrO const &tOrO,  // (kBlockM, kHeadDim) split across threads according to tiled_mma\n                   TiledMma tiled_mma,\n                   cutlass::FastDivmod const &qhead_per_khead_divmod,\n                   int const thread_idx, int const seqlen_o, int const m_block\n                 )\n    {\n        static constexpr int kGmemElemsPerStoreDirect = 2;\n        cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element> gmem_copy_direct;\n        // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))\n        Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));\n        Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});\n\n        Tensor caccO = cute::make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});\n        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);\n        Tensor taccOcO = thread_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)\n        Tensor taccOcO_rowcol 
= make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));\n        Tensor taccOcO_row = taccOcO_rowcol(_, _0{});\n        Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);\n\n        // If PackGQA, we split the work of compute divmod among threads in the same row\n        static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{});\n        static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0);\n        static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow);\n\n        // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O.\n        // We split the work among threads loading the same row of O, then __shfl_sync the pointers.\n        Tensor mO_0 = mO(_, _0{});\n        Tensor tPrOPtr = compute_ptr<kMmaThreadsPerRow>(mO_0, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block);\n        static_assert(CUTE_STATIC_V(size(tPrOPtr)) == 1);\n\n        int const qhead_per_khead = qhead_per_khead_divmod.divisor;\n        #pragma unroll\n        for (int m = 0; m < size<1>(tOrO_copy); ++m) {\n            int row = m_block * kBlockM + get<0>(taccOcO_row(m));\n            Element* o_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrOPtr[0]), m % kMmaThreadsPerRow, kMmaThreadsPerRow));\n            if (row < seqlen_o * qhead_per_khead) {\n                Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape<Int<kHeadDim>>{});\n                Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape<Int<kGmemElemsPerStoreDirect>>{});\n                #pragma unroll\n                for (int k = 0; k < size<2>(tOrO_copy); ++k) {\n                    int col = get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect));\n                    if (col < size<1>(mO)) {\n                        cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), mO_cur_copy(_, col / kGmemElemsPerStoreDirect));\n                    }\n                }\n            }\n 
       }\n    };\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/padding.py",
    "content": "# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py\n\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\n\n\ndef unpad_input(hidden_states, attention_mask, unused_mask=None):\n    \"\"\"\n    Arguments:\n        hidden_states: (batch, seqlen, ...)\n        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.\n        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.\n    Return:\n        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.\n        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.\n        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.\n        max_seqlen_in_batch: int\n        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.\n    \"\"\"\n    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask\n    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)\n    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the\n    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim\n    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to\n    # index with integer indices.\n    return (\n        rearrange(hidden_states, \"b s ... 
-> (b s) ...\")[indices],\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n        used_seqlens_in_batch,\n    )\n\n\ndef pad_input(hidden_states, indices, batch, seqlen):\n    \"\"\"\n    Arguments:\n        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.\n        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.\n        batch: int, batch size for the padded sequence.\n        seqlen: int, maximum sequence length for the padded sequence.\n    Return:\n        hidden_states: (batch, seqlen, ...)\n    \"\"\"\n    dim = hidden_states.shape[1:]\n    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)\n    output[indices] = hidden_states\n    return rearrange(output, \"(b s) ... -> b s ...\", b=batch)\n"
  },
  {
    "path": "hopper/paged_kv.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cute/tensor.hpp>\n\n#include \"cutlass/fast_math.h\"  // For cutlass::FastDivmod\n\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\ntemplate <int kBlockN, int kHeadDim, int kHeadDimV, int NumThreads, typename Element, bool KV_Same_Iter=false, int LoadsPerRow_LB=1>\nstruct PagedKVManager {\n    // If KV_Same_Iter=false, then we do load_page_table(0), load_K(0), load_page_table(1), load_K(1), load_V(0),\n    // load_page_table(2), load_K(2), load_V(1), etc.\n    // So we need to compute the V pointers for the previous iteration.\n\n    // LoadsPerRow_LB is the lower bound on number of loads per row in the K direction. This is useful for\n    // rotary where we want each thread to have at least 2 loads per row.\n\n    static constexpr bool SameHeadDim = (kHeadDim == kHeadDimV);\n    static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV);\n\n    // We use CpAsync for K and V if PagedKV, since TMA doesn't work there\n    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);\n    static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, \"Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad\");\n    // We want each \"row\" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. 
if hdim=128, we want each\n    // thread to have 4 loads in the M direction and 2 vectorized load in the K direction.\n    // In the case of PackGQA, this reduces the number of times we need to call divmod.\n    static_assert(kHeadDimGCD % LoadsPerRow_LB == 0, \"Headdim and HeaddimV must be a multiple of LoadsPerRow_LB\");\n    static constexpr int kBytePerRow = kHeadDimGCD / LoadsPerRow_LB * sizeof(Element);\n    static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);\n    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;\n    static_assert(NumThreads % kGmemThreadsPerRow == 0, \"NumThreads must be a multiple of kGmemThreadsPerRow\");\n    // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where\n    // these threads share the same page table entry and share the work of computing pointers to paged K and paged V.\n    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, \"kGmemThreadsPerRow must divide NumThreadsPerWarp\");\n    using GmemCopyAtomCpAsync = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<uint128_t>, Element>;\n    using GmemLayoutAtomKVCpAsync = Layout<Shape <Int<NumThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                                           Stride<Int<kGmemThreadsPerRow>, _1>>;\n    using GmemTiledCopyKVCpAsync = decltype(\n        make_tiled_copy(GmemCopyAtomCpAsync{},\n                        GmemLayoutAtomKVCpAsync{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load\n    using GmemTiledCopyKVStore = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        GmemLayoutAtomKVCpAsync{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per load\n\n    using ShapeKV = cute::Shape<int32_t, 
int32_t, int32_t, int32_t>;  // (seqlen, d, head, batch)\n    using StrideKV = cute::Stride<int64_t, _1, int64_t, int64_t>;\n    using ShapePageTable = cute::Shape<int32_t, int32_t>;  // (batch, max_num_pages_per_seq)\n    using StridePageTable = cute::Stride<int64_t, _1>;\n\n    using TensorPageTable = decltype(make_tensor(make_gmem_ptr(static_cast<int const*>(nullptr)), ShapePageTable{}, StridePageTable{})(int(0), _));\n    using TensorKV = decltype(make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeKV{}, StrideKV{})(_, _, int(0), _));\n    using GmemThrCopyKVCpAsync = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)));\n    using TensortKcK = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{})));\n    using TensortKpK = decltype(make_tensor<bool>(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{}));\n    using TensortVcV = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDimV>>{})));\n    using TensortVpV = decltype(make_tensor<bool>(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{}));\n\n    // For PagedKV, it's expensive the calculate the pointers to K and V for each page table entry,\n    // since those require int64_t arithmetic. We optimize by having threads split this work.\n    // Typically there are 8 threads loading per row (e.g. 
hdim 64 and 128), and there are 11 rows\n    // that each thread needs to load for the case of hdim 128 and kBlockN = 176.\n    // So each of those 8 threads will calculate the K_ptr and V_ptr for 11 / 8 = 2 rows.\n    // We then use __shfl_sync to broadcast the pointers to the other threads in the warp.\n    static_assert(CUTE_STATIC_V(size<1>(TensortKcK{})) == CUTE_STATIC_V(size<1>(TensortVcV{})));\n    static constexpr int kPageEntryPerThread = cute::ceil_div(size<1>(TensortKcK{}), kGmemThreadsPerRow);\n    using TensorPageOffset = decltype(make_tensor<cute::tuple<int, int>>(Shape<Int<kPageEntryPerThread>>{}));\n    using TensorKVPtr = decltype(make_tensor<Element*>(Shape<Int<kPageEntryPerThread>>{}));\n\n    GmemTiledCopyKVCpAsync gmem_tiled_copy_kv;\n    cutlass::FastDivmod const &page_size_divmod;\n    cutlass::FastDivmod const &blockN_per_page_size_divmod;\n    int const thread_idx;\n    int const seqlen_k;\n    int const leftpad_k;\n    int const* const ptr_page_table;\n    GmemThrCopyKVCpAsync const gmem_thr_copy_kv;\n    TensorPageTable mPageTable;\n    TensorKV mK_paged, mV_paged;\n    TensortKpK tKpK;\n    TensortVpV tVpV;\n    TensorPageOffset tPrPageOffset;\n    TensorKVPtr tPrVPtr;\n    int bidb_kv_idx, bidb_kv_idx_prev, n_block_idx, n_block_idx_prev;  // Only used for TMA\n\n    CUTLASS_DEVICE\n    PagedKVManager(int const* const ptr_page_table_,\n                   ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable,\n                   Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K,\n                   Element* const ptr_V, int const headdim_v, StrideKV const &stride_V,\n                   cutlass::FastDivmod const &page_size_divmod,\n                   cutlass::FastDivmod const &blockN_per_page_size_divmod,\n                   int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k,\n                   int bidb_kv_idx\n                   )\n        : 
page_size_divmod(page_size_divmod)\n        , blockN_per_page_size_divmod(blockN_per_page_size_divmod)\n        , thread_idx(thread_idx)\n        , seqlen_k(seqlen_k)\n        , leftpad_k(leftpad_k)\n        , ptr_page_table(ptr_page_table_)\n        , gmem_thr_copy_kv(gmem_tiled_copy_kv.get_thread_slice(thread_idx))\n        , bidb_kv_idx(bidb_kv_idx)\n        , bidb_kv_idx_prev(bidb_kv_idx)\n\n    {\n        mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _);\n        mK_paged = make_tensor(make_gmem_ptr(ptr_K), shape_K, stride_K)(_, _, bidh, _);\n        auto shape_V = make_shape(get<0>(shape_K), headdim_v, get<2>(shape_K), get<3>(shape_K));\n        mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_V, stride_V)(_, _, bidh, _);\n        tKpK = make_tensor<bool>(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{});\n        Tensor cK = cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});  // (BLK_N,BLK_K) -> (blk_n,blk_k)\n        Tensor tKcK = gmem_thr_copy_kv.partition_S(cK);\n        #pragma unroll\n        for (int k = 0; k < size<1>(tKpK); ++k) { tKpK(_0{}, k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(shape_K); }\n        Tensor tVpV_ = make_tensor<bool>(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{});\n        Tensor cV = cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDimV>>{});  // (BLK_N,BLK_K) -> (blk_n,blk_k)\n        Tensor tVcV = gmem_thr_copy_kv.partition_S(cV);\n        #pragma unroll\n        for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_V); }\n        tVpV = cute::conditional_return<SameHeadDim>(tKpK, tVpV_);\n    };\n\n    template <bool Seqlenk_mask=false, bool First_iter=false>\n    CUTLASS_DEVICE\n    void load_page_table(const int n_block) {\n        // The uncoalesced gmem load is intentional. 
This is so that each thread only loads the page table entries\n        // it needs, and we don't need any sync between warps.\n        // Assuming 8 threads per row, and 176 rows, then the rows from 0 to 175 are loaded by\n        // threads 0, 8, 16, ..., 120, 1, 9, ..., 121, 2, 10, ..., 122, etc.\n        #pragma unroll\n        for (int i = 0; i < kPageEntryPerThread; ++i) {\n            int const row = i * NumThreads + (thread_idx % kGmemThreadsPerRow) * (NumThreads / kGmemThreadsPerRow) + (thread_idx / kGmemThreadsPerRow);\n            int const row_idx = n_block * kBlockN + row;\n            int page_idx, page_offset;\n            page_idx = page_size_divmod.divmod(page_offset, row_idx + leftpad_k);\n            // Add the condition (i + 1) * NumThreads <= kBlockN since that is an upper bound of row\n            // and is known at compile time. It avoids branching when e.g., kBlockN = 176 and i = 0.\n            int const page = ((i + 1) * NumThreads <= kBlockN || row < kBlockN) && (!Seqlenk_mask || row_idx < seqlen_k) ? 
mPageTable[page_idx] : 0;\n            tPrPageOffset[i] = {page, page_offset};\n            // if (cute::thread0()) { printf(\"row = %d, page_idx = %d, page_offset = %d, page = %d, leftpad_k = %d, seqlen_k = %d\\n\", row, page_idx, page_offset, page, leftpad_k, seqlen_k); }\n        }\n        if constexpr (First_iter && !KV_Same_Iter) { compute_V_ptr(); }\n    };\n\n    template <bool First_iter=false>\n    CUTLASS_DEVICE\n    void load_page_table_TMA(const int n_block) {\n        // We require that page size is a multiple of kBlockN, and there's no leftpad_k\n        if (ptr_page_table) {\n            bidb_kv_idx = mPageTable[blockN_per_page_size_divmod.divmod(n_block_idx, n_block)];\n        } else {\n            n_block_idx = n_block;\n        }\n        if constexpr (First_iter && !KV_Same_Iter) {\n            bidb_kv_idx_prev = bidb_kv_idx;\n            n_block_idx_prev = n_block_idx;\n        }\n    };\n\n    CUTLASS_DEVICE\n    cute::tuple<int, int> get_indices_for_K_TMA() {\n        return {n_block_idx, bidb_kv_idx};\n    };\n\n    CUTLASS_DEVICE\n    cute::tuple<int, int> get_indices_for_V_TMA() {\n        if constexpr (KV_Same_Iter) {\n            return {n_block_idx, bidb_kv_idx};\n        } else {\n            cute::tuple<int, int> const indices = {n_block_idx_prev, bidb_kv_idx_prev};\n            bidb_kv_idx_prev = bidb_kv_idx;\n            n_block_idx_prev = n_block_idx;\n            return indices;\n        }\n    };\n\n    CUTLASS_DEVICE\n    TensorKVPtr compute_K_ptr() {\n        Tensor tPrKPtr = make_tensor<Element*>(Shape<Int<kPageEntryPerThread>>{});\n        #pragma unroll\n        for (int i = 0; i < kPageEntryPerThread; ++i) {\n            auto [page, page_offset] = tPrPageOffset[i];\n            tPrKPtr[i] = &mK_paged(page_offset, _0{}, page);\n        }\n        return tPrKPtr;\n    };\n\n    CUTLASS_DEVICE\n    void compute_V_ptr() {\n        #pragma unroll\n        for (int i = 0; i < kPageEntryPerThread; ++i) {\n            auto [page, 
page_offset] = tPrPageOffset[i];\n            tPrVPtr[i] = &mV_paged(page_offset, _0{}, page);\n        }\n    };\n\n    template <bool Seqlenk_mask=false, typename TensorK>\n    CUTLASS_DEVICE\n    void load_K(const int n_block, TensorK &&sK) {\n        // Do we need bound check to make sure the row doesn't go above kBlockN\n        static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0;\n\n        Tensor tPrKPtr = compute_K_ptr();\n\n        // Only for index calculation, since all the indices of thread 0 are known at compile time\n        auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{});\n        Tensor tKsK = gmem_thr_copy_kv.partition_D(sK);\n        Tensor cK = cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});  // (BLK_N,BLK_K) -> (blk_n,blk_k)\n        // Repeat the partitioning with identity layouts\n        Tensor tKcK = gmem_thr_copy_kv.partition_S(cK);\n        Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK);\n\n        // We want to use the row indices of thread0 to compare, since that is known at compile time.\n        // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{})))\n        int const seqlenk_row_limit = -int(get<0>(tKcK(_0{}, _0{}, _0{}))) + (EvenN\n            ? seqlen_k - n_block * kBlockN\n            : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k - n_block * kBlockN, kBlockN)));\n        #pragma unroll\n        for (int m = 0; m < size<1>(tKsK); ++m) {\n            bool const should_load = EvenN\n                ? 
(!Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit)\n                : get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit;\n            Element const* k_ptr = reinterpret_cast<Element const*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow));\n            Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape<Int<kHeadDim>>{});\n            Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape<Int<kGmemElemsPerLoad>>{});\n            if (should_load) {\n                #pragma unroll\n                for (int k = 0; k < size<2>(tKsK); ++k) {\n                    int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad;\n                    cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k)), mK_paged_cur_copy(_, ki), tKsK(_, m, k));\n                }\n            }  // Don't need to clear out the rest of the smem since we'll mask out the scores anyway\n        }\n    };\n\n    template <bool Seqlenk_mask=false, typename TensorV>\n    CUTLASS_DEVICE\n    void load_V(const int n_block, TensorV &&sV) {\n        // Do we need bound check to make sure the row doesn't go above kBlockN\n        static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0;\n\n        if constexpr (KV_Same_Iter) { compute_V_ptr(); }\n        // Only for index calculation, since all the indices of thread 0 are known at compile time\n        auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{});\n        Tensor tVsV = gmem_thr_copy_kv.partition_D(sV);\n        Tensor cV = cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDimV>>{});  // (BLK_N,BLK_K) -> (blk_n,blk_k)\n        // Repeat the partitioning with identity layouts\n        Tensor tVcV = gmem_thr_copy_kv.partition_S(cV);\n        Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV);\n\n        int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - 
get<0>(tVcV(_0{}, _0{}, _0{}));\n        #pragma unroll\n        for (int m = 0; m < size<1>(tVsV); ++m) {\n            // Faster to rely on the cp.async to clear smem that are out of bound,\n            // rather than calling cute::clear directly.\n            // We have to be careful not to write to smem past `kBlockN` if !EvenN.\n            // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked\n            if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tVcV(_0{}, m, _0{})) < kBlockN) {\n                bool const should_load = !Seqlenk_mask || get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit;\n                Element const* v_ptr = reinterpret_cast<Element const*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));\n                Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape<Int<kHeadDimV>>{});\n                Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape<Int<kGmemElemsPerLoad>>{});\n                #pragma unroll\n                for (int k = 0; k < size<2>(tVsV); ++k) {\n                    int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad;\n                    cute::copy(gmem_tiled_copy_kv.with(tVpV(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k));\n                }\n            }\n        }\n        if constexpr (!KV_Same_Iter) { compute_V_ptr(); }\n    };\n\n    template <typename TensorK>\n    CUTLASS_DEVICE\n    void store_K(const int n_block, TensorK &&tKrK) {\n        Tensor tPrKPtr = compute_K_ptr();\n        // We're using the same partitioning as GmemTiledCopyKVCpAsync (used for loading)\n        // Only for index calculation, since all the indices of thread 0 are known at compile time\n        auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{});\n        Tensor cK = cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{});  // (BLK_N,BLK_K) -> 
(blk_n,blk_k)\n        // Repeat the partitioning with identity layouts\n        Tensor tKcK = gmem_thr_copy_kv.partition_S(cK);\n        Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK);\n\n        GmemTiledCopyKVStore gmem_tiled_copy_kv_store;\n        // We want to use the row indices of thread0 to compare, since that is known at compile time.\n        // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{})))\n        // int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{}));\n        int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{}));\n        // if (threadIdx.x == 128) { printf(\"bidx = %d, bidy = %d, bidz = %d, seqlen_k = %d, seqlenk_row_limit = %d\\n\", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_k, seqlenk_row_limit); }\n        #pragma unroll\n        for (int m = 0; m < size<1>(tKrK); ++m) {\n            bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit;\n            Element* k_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow));\n            Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape<Int<kHeadDim>>{});\n            Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape<Int<kGmemElemsPerLoad>>{});\n            if (should_load) {\n                #pragma unroll\n                for (int k = 0; k < size<2>(tKrK); ++k) {\n                    int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad;\n                    if (tKpK(_0{}, k)) {\n                        cute::copy(gmem_tiled_copy_kv_store, tKrK(_, m, k), mK_paged_cur_copy(_, ki));\n                    }\n                }\n            }\n        }\n    };\n\n    template <typename TensorV>\n    CUTLASS_DEVICE\n    void store_V(const int n_block, TensorV &&tVrV) {\n        if constexpr 
(KV_Same_Iter) { compute_V_ptr(); }\n        // Only for index calculation, since all the indices of thread 0 are known at compile time\n        auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{});\n        Tensor cV = cute::make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDimV>>{});  // (BLK_N,BLK_K) -> (blk_n,blk_k)\n        // Repeat the partitioning with identity layouts\n        Tensor tVcV = gmem_thr_copy_kv.partition_S(cV);\n        Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV);\n\n        GmemTiledCopyKVStore gmem_tiled_copy_kv_store;\n        int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tVcV(_0{}, _0{}, _0{}));\n        #pragma unroll\n        for (int m = 0; m < size<1>(tVrV); ++m) {\n            bool const should_load = get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit;\n            Element* v_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));\n            Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape<Int<kHeadDimV>>{});\n            Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape<Int<kGmemElemsPerLoad>>{});\n            if (should_load) {\n                #pragma unroll\n                for (int k = 0; k < size<2>(tVrV); ++k) {\n                    int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad;\n                    if (tVpV(_0{}, k)) {\n                        cute::copy(gmem_tiled_copy_kv_store, tVrV(_, m, k), mV_paged_cur_copy(_, ki));\n                    }\n                }\n            }\n        }\n        if constexpr (!KV_Same_Iter) { compute_V_ptr(); }\n    };\n\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/rotary.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cute/tensor.hpp>\n\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Engine1, typename Layout1, typename Engine2, typename Layout2>\nCUTLASS_DEVICE void\napply_rotary_interleaved(Tensor<Engine1, Layout1> &rK,\n                         Tensor<Engine2, Layout2> const &rCos,\n                         Tensor<Engine2, Layout2> const &rSin) {\n    CUTE_STATIC_ASSERT_V(rank(rK) == _1{});\n    CUTE_STATIC_ASSERT_V(rank(rCos) == _1{});\n    CUTE_STATIC_ASSERT_V(rank(rSin) == _1{});\n    CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin));\n    static_assert(decltype(size<0>(rK))::value == decltype(size<0>(rCos))::value * 2);\n    static_assert(decltype(size<0>(rCos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32\n    Tensor K_fp32 = make_tensor_like<float>(rK);\n    convert_type_out(rK, K_fp32);\n    Tensor cos_fp32 = make_tensor_like<float>(rCos);\n    convert_type_out(rCos, cos_fp32);\n    Tensor sin_fp32 = make_tensor_like<float>(rSin);\n    convert_type_out(rSin, sin_fp32);\n    #pragma unroll\n    for (int i = 0; i < size<0>(K_fp32) / 2; ++i) {\n        float real = K_fp32[2 * i] * cos_fp32[i] - K_fp32[2 * i + 1] * sin_fp32[i];\n        float imag = K_fp32[2 * i] * sin_fp32[i] + K_fp32[2 * i + 1] * cos_fp32[i];\n        K_fp32[2 * i] = real;\n        K_fp32[2 * i + 1] = imag;\n    }\n    convert_type_out(K_fp32, rK);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Engine1, typename Layout1, typename Engine2, 
typename Layout2>\nCUTLASS_DEVICE void\napply_rotary_contiguous(Tensor<Engine1, Layout1> &rK_left,\n                        Tensor<Engine1, Layout1> &rK_right,\n                        Tensor<Engine2, Layout2> const &rCos,\n                        Tensor<Engine2, Layout2> const &rSin) {\n    CUTE_STATIC_ASSERT_V(rank(rK_left) == _1{});\n    CUTE_STATIC_ASSERT_V(rank(rK_right) == _1{});\n    CUTE_STATIC_ASSERT_V(rank(rCos) == _1{});\n    CUTE_STATIC_ASSERT_V(rank(rSin) == _1{});\n    CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rK_right));\n    CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rCos));\n    CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin));\n    static_assert(decltype(size<0>(rCos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32\n    Tensor K_left_fp32 = make_tensor_like<float>(rK_left);\n    convert_type_out(rK_left, K_left_fp32);\n    Tensor K_right_fp32 = make_tensor_like<float>(rK_right);\n    convert_type_out(rK_right, K_right_fp32);\n    Tensor cos_fp32 = make_tensor_like<float>(rCos);\n    convert_type_out(rCos, cos_fp32);\n    Tensor sin_fp32 = make_tensor_like<float>(rSin);\n    convert_type_out(rSin, sin_fp32);\n    #pragma unroll\n    for (int i = 0; i < size<0>(K_left_fp32); ++i) {\n        float real = K_left_fp32[i] * cos_fp32[i] - K_right_fp32[i] * sin_fp32[i];\n        float imag = K_left_fp32[i] * sin_fp32[i] + K_right_fp32[i] * cos_fp32[i];\n        K_left_fp32[i] = real;\n        K_right_fp32[i] = imag;\n    }\n    convert_type_out(K_left_fp32, rK_left);\n    convert_type_out(K_right_fp32, rK_right);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int kBlockMN, int kHeadDim, int NumThreads, typename Element, bool FixedPosition=false>\nstruct Rotary {\n\n    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);\n    static_assert(kHeadDim % kGmemElemsPerLoad == 0, \"Headdim must be a multiple of 
kGmemElemsPerLoad\");\n    // We want each \"row\" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each\n    // thread to have 4 loads in the M direction and 2 vectorized load in the K direction.\n    // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved\n    // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will\n    // load twice from the same row.\n    static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element);\n    static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element);\n    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;\n    static_assert(NumThreads % kGmemThreadsPerRow == 0, \"NumThreads must be a multiple of kGmemThreadsPerRow\");\n    // We assume threads loading the same row are in the same warp.\n    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, \"kGmemThreadsPerRow must divide NumThreadsPerWarp\");\n\n    using LayoutAtom = Layout<Shape <Int<NumThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,\n                                  Stride<Int<kGmemThreadsPerRow>, _1>>;\n    using TiledCopyQK = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        LayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per store\n    using GmemTiledCopyRotary = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<64>, Element>{},\n                        LayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad / 2>>>{}));  // Val layout, 4 or 8 vals per store\n    using GmemTiledCopyRotaryCont = decltype(\n        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},\n                        
LayoutAtom{},\n                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per store\n\n    using ShapeRotary = cute::Shape<int32_t, int32_t>;  // (seqlen_ro, rotary_dim // 2)\n    using StrideRotary = cute::Stride<int64_t, _1>;\n\n    using GmemThrCopyRotary = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0)));\n    using GmemThrCopyRotaryCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)));\n    using TensortRcR = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{})));\n    using TensortRpR = decltype(make_tensor<bool>(make_shape(size<2>(TensortRcR{}))));\n    using TensortRcRCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{})));\n    using TensortRpRCont = decltype(make_tensor<bool>(make_shape(size<2>(TensortRcRCont{}))));\n    using TensormR = decltype(make_tensor(\n        make_gmem_ptr((Element const*)nullptr),\n        ShapeRotary{},\n        make_stride(cute::conditional_return<FixedPosition>(_0{}, int64_t(0)), _1{})));\n    using TensortRgR = decltype(\n        GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_S(make_tensor(\n            make_gmem_ptr((Element const*)nullptr),\n            make_shape(Int<kBlockMN>{}, Int<kHeadDim / 2>{}, int(0)),\n            make_stride(cute::conditional_return<FixedPosition>(_0{}, int64_t(0)), _1{}, cute::conditional_return<FixedPosition>(_0{}, int64_t(0))))));\n    using TensortRgRCont = decltype(\n        GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_S(make_tensor(\n            make_gmem_ptr((Element const*)nullptr),\n            make_shape(Int<kBlockMN>{}, Int<kHeadDim / 2>{}, int(0)),\n            make_stride(cute::conditional_return<FixedPosition>(_0{}, int64_t(0)), _1{}, cute::conditional_return<FixedPosition>(_0{}, int64_t(0))))));\n\n    
GmemTiledCopyRotary gmem_tiled_copy_rotary;\n    GmemTiledCopyRotaryCont gmem_tiled_copy_rotary_cont;\n    bool const is_rotary_interleaved;\n    int const rotary_dim;\n    int const thread_idx;\n    int const max_seqlen;\n    GmemThrCopyRotary const gmem_thr_copy_rotary;\n    GmemThrCopyRotaryCont const gmem_thr_copy_rotary_cont;\n    TensortRpR tRpR;\n    TensortRpRCont tRpRCont;\n    TensormR mCos, mSin;\n    TensortRgR tRgCos, tRgSin;\n    TensortRgRCont tRgCosCont, tRgSinCont;\n\n    CUTLASS_DEVICE\n    Rotary(Element const* const ptr_rotary_cos, ShapeRotary const &shape_rotary, StrideRotary const &stride_rotary_cos_,\n           Element const* const ptr_rotary_sin, StrideRotary const &stride_rotary_sin_,\n           bool const is_rotary_interleaved, int const thread_idx, int const max_seqlen, int const start_idx)\n        : is_rotary_interleaved(is_rotary_interleaved)\n        , rotary_dim(get<1>(shape_rotary) * 2)\n        , thread_idx(thread_idx)\n        , max_seqlen(max_seqlen)\n        , gmem_thr_copy_rotary(gmem_tiled_copy_rotary.get_thread_slice(thread_idx))\n        , gmem_thr_copy_rotary_cont(gmem_tiled_copy_rotary_cont.get_thread_slice(thread_idx))\n\n    {\n        auto stride_rotary_cos = make_stride(cute::conditional_return<!FixedPosition>(get<0>(stride_rotary_cos_), _0{}), get<1>(stride_rotary_cos_));\n        auto stride_rotary_sin = make_stride(cute::conditional_return<!FixedPosition>(get<0>(stride_rotary_sin_), _0{}), get<1>(stride_rotary_sin_));\n        mCos = make_tensor(make_gmem_ptr(ptr_rotary_cos + start_idx * get<0>(stride_rotary_cos_)), shape_rotary, stride_rotary_cos);\n        mSin = make_tensor(make_gmem_ptr(ptr_rotary_sin + start_idx * get<0>(stride_rotary_sin_)), shape_rotary, stride_rotary_sin);\n        Tensor gCos = local_tile(mCos, Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{}, make_coord(_, _0{}));  // (MN, K / 2, _)\n        Tensor gSin = local_tile(mSin, Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{}, make_coord(_, _0{}));  // (MN, 
K / 2, _)\n        tRgCos = gmem_thr_copy_rotary.partition_S(gCos);\n        tRgSin = gmem_thr_copy_rotary.partition_S(gSin);\n        tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCos);\n        tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSin);\n        Tensor cR = cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{});  // (BLK_N,BLK_K / 2)\n        Tensor tRcR = gmem_thr_copy_rotary.partition_D(cR);\n        tRpR = make_tensor<bool>(make_shape(size<2>(tRcR)));\n        #pragma unroll\n        for (int k = 0; k < size(tRpR); ++k) { tRpR(k) = get<1>(tRcR(_0{}, _0{}, k)) < get<1>(shape_rotary); }\n        Tensor tRcRCont = gmem_thr_copy_rotary_cont.partition_D(cR);\n        tRpRCont = make_tensor<bool>(make_shape(size<2>(tRcRCont)));\n        #pragma unroll\n        for (int k = 0; k < size(tRpRCont); ++k) { tRpRCont(k) = get<1>(tRcRCont(_0{}, _0{}, k)) < get<1>(shape_rotary); }\n    };\n\n    template <bool kInterleaved=true>\n    CUTLASS_DEVICE\n    auto load_cos_sin(int const block) {\n        using GmemTiledCopyRo = std::conditional_t<kInterleaved, GmemTiledCopyRotary, GmemTiledCopyRotaryCont>;\n        auto gmem_thr_copy_ro = cute::conditional_return<kInterleaved>(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont);\n        Tensor tRpRCur = cute::conditional_return<kInterleaved>(tRpR, tRpRCont);\n        Tensor tRgCosCur = cute::conditional_return<kInterleaved>(tRgCos, tRgCosCont)(_, _, _, block);\n        Tensor tRgSinCur = cute::conditional_return<kInterleaved>(tRgSin, tRgSinCont)(_, _, _, block);\n        // make_tensor_like, not make_fragment_like. 
If the row_stride is _0{} we want to keep it that way\n        Tensor tRrCos = make_tensor_like(tRgCosCur);\n        Tensor tRrSin = make_tensor_like(tRgSinCur);\n        Tensor cR = cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{});  // (BLK_N,BLK_K / 2)\n        Tensor tRcR = gmem_thr_copy_ro.partition_D(cR);\n        // If FixedPosition, only copy the first row as we only need the cos/sin for position cache_seqlens\n        #pragma unroll\n        for (int m = 0; m < (!FixedPosition ? size<1>(tRrCos) : 1); ++m) {\n            if (get<0>(tRcR(_0{}, m, _0{})) < std::min(max_seqlen - block * kBlockMN, kBlockMN)) {\n                #pragma unroll\n                for (int k = 0; k < size<2>(tRrCos); ++k) {\n                    if (tRpRCur(k)) {\n                        cute::copy(GmemTiledCopyRo{}, tRgCosCur(_, m, k), tRrCos(_, m, k));\n                        cute::copy(GmemTiledCopyRo{}, tRgSinCur(_, m, k), tRrSin(_, m, k));\n                    }\n                }\n            }\n        }\n        return cute::make_tuple(tRrCos, tRrSin);\n    }\n\n    template <bool kInterleaved=true>\n    CUTLASS_DEVICE\n    auto load_cos_sin_packgqa(int const block, cutlass::FastDivmod const &qhead_per_khead_divmod) {\n        static constexpr int kGmemElemsPerLoadCur = kInterleaved ? kGmemElemsPerLoad / 2 : kGmemElemsPerLoad;\n        using GmemTiledCopyRo = std::conditional_t<kInterleaved, GmemTiledCopyRotary, GmemTiledCopyRotaryCont>;\n        auto gmem_thr_copy_ro = cute::conditional_return<kInterleaved>(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont);\n        Tensor tRpRCur = cute::conditional_return<kInterleaved>(tRpR, tRpRCont);\n        // make_tensor_like, not make_fragment_like. 
If the row_stride is _0{} we want to keep it that way\n        Tensor tRrCos = make_tensor_like(cute::conditional_return<kInterleaved>(tRgCos, tRgCosCont)(_, _, _, _0{}));\n        Tensor tRrSin = make_tensor_like(cute::conditional_return<kInterleaved>(tRgSin, tRgSinCont)(_, _, _, _0{}));\n        int const qhead_per_khead = qhead_per_khead_divmod.divisor;\n        Tensor cR = cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{});  // (BLK_N,BLK_K / 2)\n        Tensor tRcR = gmem_thr_copy_ro.partition_D(cR);\n\n        // The main bottleneck here is actually instruction cache misses.\n\n        // Similar to PagedKVNonTMA, it's expensive to compute the pointers.\n        // We split the work among threads loading the same row, then __shfl_sync the pointers.\n        static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size<1>(tRrCos)), kGmemThreadsPerRow);\n        Tensor tPrCosPtr = make_tensor<Element const*>(Shape<Int<NumPtrPerThread>>{});\n        Tensor tPrSinPtr = make_tensor<Element const*>(Shape<Int<NumPtrPerThread>>{});\n        #pragma unroll\n        for (int i = 0; i < NumPtrPerThread; ++i) {\n            int const row = i * NumThreads + get<0>(tRcR(_0{}, thread_idx % kGmemThreadsPerRow, _0{}));\n            int const idx = block * kBlockMN + row;\n            int row_actual = qhead_per_khead_divmod.divide(idx);\n            tPrCosPtr[i] = &mCos(row_actual, _0{});\n            tPrSinPtr[i] = &mSin(row_actual, _0{});\n        }\n\n        #pragma unroll\n        for (int m = 0; m < (!FixedPosition ? 
size<1>(tRgCos) : 1); ++m) {\n            int const idx = block * kBlockMN + get<0>(tRcR(_0{}, m, _0{}));\n            Element const* cos_ptr = reinterpret_cast<Element const*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrCosPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));\n            Element const* sin_ptr = reinterpret_cast<Element const*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrSinPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));\n            if (idx < max_seqlen * qhead_per_khead) {\n                Tensor mCos_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(cos_ptr), Shape<Int<kHeadDim / 2>>{}),\n                                                    Shape<Int<kGmemElemsPerLoadCur>>{});\n                Tensor mSin_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(sin_ptr), Shape<Int<kHeadDim / 2>>{}),\n                                                    Shape<Int<kGmemElemsPerLoadCur>>{});\n                #pragma unroll\n                for (int k = 0; k < size<2>(tRgCos); ++k) {\n                    int const ki = get<1>(tRcR(_0{}, _0{}, k)) / (kGmemElemsPerLoadCur);\n                    if (tRpRCur(k)) {\n                        cute::copy(GmemTiledCopyRo{}, mCos_copy(_, ki), tRrCos(_, m, k));\n                        cute::copy(GmemTiledCopyRo{}, mSin_copy(_, ki), tRrSin(_, m, k));\n                    }\n                }\n            }\n        }\n        return cute::make_tuple(tRrCos, tRrSin);\n    }\n\n    template <typename TensorsQ, typename TensortRrR>\n    CUTLASS_DEVICE\n    void\n    apply_Q_interleaved(TensorsQ &sQ,  // (kBlockM, kHeadDim)\n                        TensortRrR const &tRrCos,   // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary\n                        TensortRrR const &tRrSin,   // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary\n                        int const m_block, int const qhead_per_khead=1)\n    {\n        
TiledCopyQK tiled_copy_q;\n        auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx);\n        Tensor tQsQ = gmem_thr_copy_q.partition_S(sQ);\n        Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim>>{}));\n\n        CUTE_STATIC_ASSERT_V(rank(tQsQ) == _3{});\n        CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{});\n        CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{});\n        CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrCos));\n        CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrCos));\n        CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrSin));\n        CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrSin));\n        CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin));\n        static_assert(decltype(size<0>(tQsQ))::value == decltype(size<0>(tRrCos))::value * 2);\n        static_assert(decltype(size<0>(tRrCos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32\n\n        #pragma unroll\n        for (int m = 0; m < size<1>(tQsQ); ++m) {\n            if (get<0>(tQcQ(_0{}, m, _0{})) < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) {\n                #pragma unroll\n                for (int k = 0; k < size<2>(tQsQ); ++k) {\n                    if (tRpR(k)) {\n                        Tensor rQ = make_fragment_like(tQsQ(_, m, k));\n                        cute::copy(tiled_copy_q, tQsQ(_, m, k), rQ);\n                        apply_rotary_interleaved(rQ, tRrCos(_, m, k), tRrSin(_, m, k));\n                        cute::copy(tiled_copy_q, rQ, tQsQ(_, m, k));\n                    }\n                }\n            }\n        }\n    };\n\n    template <typename TensorsQ, typename TensortRrR>\n    CUTLASS_DEVICE\n    void\n    apply_Q_contiguous(TensorsQ &sQ,  // (kBlockM, kHeadDim)\n                       TensortRrR const &tRrCosCont, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont\n                       TensortRrR const 
&tRrSinCont, // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont\n                       int const m_block, int const qhead_per_khead=1)\n    {\n        TiledCopyQK tiled_copy_q;\n        auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx);\n        Tensor sQ_copy = cute::tiled_divide(sQ, Shape<_1, Int<kGmemElemsPerLoad>>{});\n        Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{}));\n\n        CUTE_STATIC_ASSERT_V(rank(tQcQ) == _3{});\n        CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{});\n        CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{});\n        CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrCosCont));\n        CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == size<2>(tRrCosCont));\n        CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrSinCont));\n        CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == size<2>(tRrSinCont));\n        CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont));\n        CUTE_STATIC_ASSERT_V(size<0>(tQcQ) == size<0>(tRrCosCont));\n        static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32\n\n        #pragma unroll\n        for (int m = 0; m < size<1>(tQcQ); ++m) {\n            int const row = get<0>(tQcQ(_0{}, m, _0{}));\n            if (row < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) {\n                #pragma unroll\n                for (int k = 0; k < size<2>(tQcQ); ++k) {\n                    int const col = get<1>(tQcQ(_0{}, _0{}, k));\n                    if (col < rotary_dim / 2) {\n                        int const col_idx_left = col / kGmemElemsPerLoad;\n                        int const col_idx_right = col / kGmemElemsPerLoad + rotary_dim / (2 * kGmemElemsPerLoad);\n                        Tensor rQ_left = make_fragment_like(sQ_copy(_, row, col_idx_left));\n                        cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_left), 
rQ_left);\n                        Tensor rQ_right = make_fragment_like(rQ_left);\n                        cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_right), rQ_right);\n                        apply_rotary_contiguous(rQ_left, rQ_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k));\n                        cute::copy(tiled_copy_q, rQ_left, sQ_copy(_, row, col_idx_left));\n                        cute::copy(tiled_copy_q, rQ_right, sQ_copy(_, row, col_idx_right));\n                    }\n                }\n            }\n        }\n    };\n\n    template <bool PagedKVNonTMA=false, typename TensorsK, typename TensorgK, typename TensorpK, typename TensortRrR, typename TensorKPtr>\n    CUTLASS_DEVICE\n    void\n    apply_K_interleaved(TensorsK const &sK,  // (kBlockN, kHeadDim)\n                        TensorgK &gK,  // (kBlockN, kHeadDim)\n                        TensorpK const &tKpK,  // (kBlockN, kHeadDim) split according to ThrCopyKV\n                        TensortRrR const &tRrCos,   // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary\n                        TensortRrR const &tRrSin,   // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary\n                        TensorKPtr const &tPrKPtr,\n                        int const n_block)\n    {\n        TiledCopyQK tiled_copy_k;\n        auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx);\n        Tensor tKsK = gmem_thr_copy_q.partition_S(sK);\n        Tensor tKgK = gmem_thr_copy_q.partition_S(gK);\n        Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim>>{}));\n\n        CUTE_STATIC_ASSERT_V(rank(tKsK) == _3{});\n        CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{});\n        CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{});\n        CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrCos));\n        CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrCos));\n        CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrSin));\n        
CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrSin));\n        CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin));\n        static_assert(decltype(size<0>(tKsK))::value == decltype(size<0>(tRrCos))::value * 2);\n        static_assert(decltype(size<0>(tRrCos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32\n        if constexpr (PagedKVNonTMA) {\n            static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow));\n        }\n\n        #pragma unroll\n        for (int m = 0; m < size<1>(tKsK); ++m) {\n            int const row = get<0>(tKcK(_0{}, m, _0{}));\n            auto mK_cur_copy = [&] {\n                if constexpr (PagedKVNonTMA) {\n                    Element* k_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow));\n                    Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape<Int<kHeadDim>>{});\n                    return cute::tiled_divide(mK_cur, Shape<Int<kGmemElemsPerLoad>>{});\n                } else {\n                    return nullptr;\n                }\n            }();\n            if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) {\n                #pragma unroll\n                for (int k = 0; k < size<2>(tKsK); ++k) {\n                    if (tKpK(k)) {\n                        Tensor rK = make_fragment_like(tKsK(_, m, k));\n                        cute::copy(tiled_copy_k, tKsK(_, m, k), rK);\n                        if (tRpR(k)) { apply_rotary_interleaved(rK, tRrCos(_, m, k), tRrSin(_, m, k)); }\n                        if constexpr (!PagedKVNonTMA) {\n                            cute::copy(tiled_copy_k, rK, tKgK(_, m, k));\n                        } else {\n                            int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad;\n                            cute::copy(tiled_copy_k, rK, mK_cur_copy(_, 
ki));\n                        }\n                    }\n                }\n            }\n        }\n    };\n\n    template <bool PagedKVNonTMA=false, typename TensorsK, typename TensorgK, typename TensorpK, typename TensortRrR, typename TensorKPtr>\n    CUTLASS_DEVICE\n    void\n    apply_K_contiguous(TensorsK const &sK,  // (kBlockN, kHeadDim)\n                       TensorgK &gK,  // (kBlockN, kHeadDim)\n                       TensorpK const &tKpK,  // (kBlockN, kHeadDim) split according to ThrCopyKV\n                       TensortRrR const &tRrCosCont,   // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont\n                       TensortRrR const &tRrSinCont,   // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont\n                       TensorKPtr const &tPrKPtr,\n                       int const n_block, int const max_k)\n    {\n        TiledCopyQK tiled_copy_k;\n        auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx);\n        Tensor sK_copy = cute::tiled_divide(sK, Shape<_1, Int<kGmemElemsPerLoad>>{});\n        Tensor gK_copy = cute::tiled_divide(gK, Shape<_1, Int<kGmemElemsPerLoad>>{});\n        Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape<Int<kBlockMN>, Int<kHeadDim / 2>>{}));\n\n        CUTE_STATIC_ASSERT_V(rank(tKcK) == _3{});\n        CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{});\n        CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{});\n        CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrCosCont));\n        CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrCosCont));\n        CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrSinCont));\n        CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrSinCont));\n        CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont));\n        CUTE_STATIC_ASSERT_V(size<0>(tKcK) == size<0>(tRrCosCont));\n        static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32\n       
 if constexpr (PagedKVNonTMA) {\n            static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow));\n        }\n\n        const int ro_dim_vec = rotary_dim / kGmemElemsPerLoad;\n        const int non_ro_dim_vec = (max_k - rotary_dim) / kGmemElemsPerLoad;\n        #pragma unroll\n        for (int m = 0; m < size<1>(tKcK); ++m) {\n            int const row = get<0>(tKcK(_0{}, m, _0{}));\n            Tensor gK_cur_copy = [&] {\n                if constexpr (PagedKVNonTMA) {\n                    Element* k_ptr = reinterpret_cast<Element*>(__shfl_sync(0xffffffff, reinterpret_cast<uint64_t>(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow));\n                    Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape<Int<kHeadDim>>{});\n                    return cute::tiled_divide(mK_cur, Shape<Int<kGmemElemsPerLoad>>{});\n                } else {\n                    return gK_copy(_, row, _);\n                }\n            }();\n            if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) {\n                #pragma unroll\n                for (int k = 0; k < size<2>(tKcK); ++k) {\n                    if (tKpK(k)) {\n                        int const col = get<1>(tKcK(_0{}, _0{}, k));\n                        bool rotate = col < rotary_dim / 2;\n                        int const col_idx_left = rotate ? col / kGmemElemsPerLoad : (col + rotary_dim / 2) / kGmemElemsPerLoad;\n                        int const col_idx_right = col_idx_left + (rotate ? 
ro_dim_vec / 2 : non_ro_dim_vec / 2);\n                        Tensor rK_left = make_fragment_like(sK_copy(_, row, col_idx_left));\n                        cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_left), rK_left);\n                        Tensor rK_right = make_fragment_like(rK_left);\n                        cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_right), rK_right);\n                        if (rotate) {\n                            apply_rotary_contiguous(rK_left, rK_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k));\n                        }\n                        cute::copy(tiled_copy_k, rK_left, gK_cur_copy(_, col_idx_left));\n                        if (col_idx_right * kGmemElemsPerLoad < max_k) {\n                            cute::copy(tiled_copy_k, rK_right, gK_cur_copy(_, col_idx_right));\n                        }\n                    }\n                }\n            }\n        }\n    };\n\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/seqlen.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\nnamespace flash {\n\n// We consolidate all the info related to sequence length here. This is so that we can do all\n// the gmem reads once at the beginning of each tile, rather than having to repeat these reads\n// to compute various things like n_block_min, n_block_max, etc.\n\ntemplate <bool Varlen, int kBlock>\nstruct SeqlenInfo {\n\n    int const offset, offset_padded;\n    int const seqlen;\n\n    CUTLASS_DEVICE\n    SeqlenInfo(int const bidb, int const seqlen_static, int const* const cu_seqlens, int const* const seqused)\n        : offset(!Varlen || cu_seqlens == nullptr ? 0 : cu_seqlens[bidb])\n        , offset_padded(!Varlen || cu_seqlens == nullptr ? 0 : (cu_seqlens[bidb] + bidb * kBlock) / kBlock * kBlock)\n        , seqlen(!Varlen\n                 ? seqlen_static\n                 : (seqused ? seqused[bidb] : (cu_seqlens ? cu_seqlens[bidb + 1] - cu_seqlens[bidb] : seqlen_static)))\n    {\n    }\n\n};\n\ntemplate <bool Varlen, int kBlockM>\nstruct SeqlenInfoQK {\n\n    int const offset_q, offset_k, offset_q_padded;\n    int const seqlen_q, seqlen_k;\n\n    CUTLASS_DEVICE\n    SeqlenInfoQK(int const bidb, int const seqlen_q_static, int const seqlen_k_static,\n                 int const* const cu_seqlens_q, int const* const cu_seqlens_k,\n                 int const* const seqused_q, int const* const seqused_k\n                 )\n        : offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb])\n        , offset_k(!Varlen || cu_seqlens_k == nullptr ? 
0 : cu_seqlens_k[bidb])\n        // If varlen, the layout for dPSum, LSE_log2, and dQaccum is that we pad each sequence in the batch\n        // by an extra kBlockM, so that the write for each sequence doesn't touch the next sequence.\n        // Sequence i starts at cu_seqlens[i] + i * kBlockM and ends at cu_seqlens[i + 1] + i * kBlockM\n        // However, the start must align to multiples of kBlockM.\n        , offset_q_padded(!Varlen || cu_seqlens_q == nullptr ? 0 : (cu_seqlens_q[bidb] + bidb * kBlockM) / kBlockM * kBlockM)\n        , seqlen_q(!Varlen\n                   ? seqlen_q_static\n                   : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static)))\n        , seqlen_k(!Varlen\n                   ? seqlen_k_static\n                   : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)))\n    {\n    }\n\n};\n\ntemplate <bool Varlen, bool AppendKV>\nstruct SeqlenInfoQKNewK {\n\n    static_assert(!(AppendKV && !Varlen), \"AppendKV is only supported with Varlen\");\n\n    int const leftpad_k;\n    int const offset_q, offset_k, offset_k_new;\n    int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k, seqlen_rotary;\n\n    CUTLASS_DEVICE\n    SeqlenInfoQKNewK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, int const shape_K_new_0,\n                     int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new,\n                     int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k,\n                     int const* const seqlens_rotary\n                     )\n        : leftpad_k(ptr_leftpad_k ? ptr_leftpad_k[bidb] : 0)\n        , offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb])\n        , offset_k(!Varlen ? 0 : (cu_seqlens_k ? 
cu_seqlens_k[bidb] : 0) + leftpad_k)\n        , offset_k_new(!AppendKV || cu_seqlens_k_new == nullptr ? 0 : cu_seqlens_k_new[bidb])\n        , seqlen_q(!Varlen\n                   ? seqlen_q_static\n                   : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static)))\n        , seqlen_k_og(!Varlen\n                      ? seqlen_k_static\n                      : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)) - leftpad_k)\n        , seqlen_k_new(!AppendKV\n                       ? 0\n                       : (cu_seqlens_k_new ? cu_seqlens_k_new[bidb + 1] - cu_seqlens_k_new[bidb] : shape_K_new_0))\n        , seqlen_k(!AppendKV ? seqlen_k_og : seqlen_k_og + seqlen_k_new)\n        , seqlen_rotary(!AppendKV || !seqlens_rotary ? seqlen_k_og + leftpad_k : seqlens_rotary[bidb])\n    {\n    }\n\n};\n\n} // namespace flash\n"
  },
  {
    "path": "hopper/setup.py",
    "content": "# Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n\nimport sys\nimport warnings\nimport os\nimport stat\nimport re\nimport shutil\nimport ast\nfrom pathlib import Path\nfrom packaging.version import parse, Version\nimport platform\nimport sysconfig\nimport tarfile\nimport itertools\n\nfrom setuptools import setup, find_packages\nimport subprocess\n\nimport urllib.request\nimport urllib.error\nfrom wheel.bdist_wheel import bdist_wheel as _bdist_wheel\n\nimport torch\nfrom torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME\n\n\n# with open(\"../README.md\", \"r\", encoding=\"utf-8\") as fh:\nwith open(\"../README.md\", \"r\", encoding=\"utf-8\") as fh:\n    long_description = fh.read()\n\n\n# ninja build does not work unless include_dirs are abs path\nthis_dir = os.path.dirname(os.path.abspath(__file__))\n\nPACKAGE_NAME = \"flash_attn_3\"\n\nBASE_WHEEL_URL = \"https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}\"\n\n# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels\n# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation\nFORCE_BUILD = os.getenv(\"FLASH_ATTENTION_FORCE_BUILD\", \"FALSE\") == \"TRUE\"\nSKIP_CUDA_BUILD = os.getenv(\"FLASH_ATTENTION_SKIP_CUDA_BUILD\", \"FALSE\") == \"TRUE\"\n# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI\nFORCE_CXX11_ABI = os.getenv(\"FLASH_ATTENTION_FORCE_CXX11_ABI\", \"FALSE\") == \"TRUE\"\n# ROCm specific settings\nUSE_TRITON_ROCM = os.getenv(\"FLASH_ATTENTION_TRITON_AMD_ENABLE\", \"FALSE\") == \"TRUE\"\nif USE_TRITON_ROCM:\n    SKIP_CUDA_BUILD = True\n\nDISABLE_BACKWARD = os.getenv(\"FLASH_ATTENTION_DISABLE_BACKWARD\", \"FALSE\") == \"TRUE\"\nDISABLE_SPLIT = os.getenv(\"FLASH_ATTENTION_DISABLE_SPLIT\", \"FALSE\") == 
\"TRUE\"\nDISABLE_PAGEDKV = os.getenv(\"FLASH_ATTENTION_DISABLE_PAGEDKV\", \"FALSE\") == \"TRUE\"\nDISABLE_APPENDKV = os.getenv(\"FLASH_ATTENTION_DISABLE_APPENDKV\", \"FALSE\") == \"TRUE\"\nDISABLE_LOCAL = os.getenv(\"FLASH_ATTENTION_DISABLE_LOCAL\", \"FALSE\") == \"TRUE\"\nDISABLE_SOFTCAP = os.getenv(\"FLASH_ATTENTION_DISABLE_SOFTCAP\", \"FALSE\") == \"TRUE\"\nDISABLE_PACKGQA = os.getenv(\"FLASH_ATTENTION_DISABLE_PACKGQA\", \"FALSE\") == \"TRUE\"\nDISABLE_FP16 = os.getenv(\"FLASH_ATTENTION_DISABLE_FP16\", \"FALSE\") == \"TRUE\"\nDISABLE_FP8 = os.getenv(\"FLASH_ATTENTION_DISABLE_FP8\", \"FALSE\") == \"TRUE\"\nDISABLE_VARLEN = os.getenv(\"FLASH_ATTENTION_DISABLE_VARLEN\", \"FALSE\") == \"TRUE\"\nDISABLE_CLUSTER = os.getenv(\"FLASH_ATTENTION_DISABLE_CLUSTER\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM64 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM64\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM96 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM96\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM128 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM128\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM192 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM192\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM256 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM256\", \"FALSE\") == \"TRUE\"\nDISABLE_SM8x = os.getenv(\"FLASH_ATTENTION_DISABLE_SM80\", \"FALSE\") == \"TRUE\"\n\nENABLE_VCOLMAJOR = os.getenv(\"FLASH_ATTENTION_ENABLE_VCOLMAJOR\", \"FALSE\") == \"TRUE\"\n\nDISABLE_HDIMDIFF64 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIMDIFF64\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIMDIFF192 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIMDIFF192\", \"FALSE\") == \"TRUE\"\n\n# HACK: we monkey patch pytorch's _write_ninja_file to pass\n# \"-gencode arch=compute_sm90a,code=sm_90a\" to files ending in '_sm90.cu',\n# and pass \"-gencode arch=compute_sm80,code=sm_80\" to files ending in '_sm80.cu'\nfrom torch.utils.cpp_extension import (\n    IS_HIP_EXTENSION,\n    COMMON_HIP_FLAGS,\n    SUBPROCESS_DECODE_ARGS,\n    IS_WINDOWS,\n    get_cxx_compiler,\n    
_join_rocm_home,\n    _join_cuda_home,\n    _is_cuda_file,\n    _maybe_write,\n)\n\ndef create_build_config_file():\n    CONFIG = {\n        \"build_flags\": {\n            \"FLASHATTENTION_DISABLE_BACKWARD\": DISABLE_BACKWARD,\n            \"FLASHATTENTION_DISABLE_SPLIT\": DISABLE_SPLIT,\n            \"FLASHATTENTION_DISABLE_PAGEDKV\": DISABLE_PAGEDKV,\n            \"FLASHATTENTION_DISABLE_APPENDKV\": DISABLE_APPENDKV,\n            \"FLASHATTENTION_DISABLE_LOCAL\": DISABLE_LOCAL,\n            \"FLASHATTENTION_DISABLE_SOFTCAP\": DISABLE_SOFTCAP,\n            \"FLASHATTENTION_DISABLE_PACKGQA\": DISABLE_PACKGQA,\n            \"FLASHATTENTION_DISABLE_FP16\": DISABLE_FP16,\n            \"FLASHATTENTION_DISABLE_FP8\": DISABLE_FP8,\n            \"FLASHATTENTION_DISABLE_VARLEN\": DISABLE_VARLEN,\n            \"FLASHATTENTION_DISABLE_CLUSTER\": DISABLE_CLUSTER,\n            \"FLASHATTENTION_DISABLE_HDIM64\": DISABLE_HDIM64,\n            \"FLASHATTENTION_DISABLE_HDIM96\": DISABLE_HDIM96,\n            \"FLASHATTENTION_DISABLE_HDIM128\": DISABLE_HDIM128,\n            \"FLASHATTENTION_DISABLE_HDIM192\": DISABLE_HDIM192,\n            \"FLASHATTENTION_DISABLE_HDIM256\": DISABLE_HDIM256,\n            \"FLASHATTENTION_DISABLE_SM8x\": DISABLE_SM8x,\n            \"FLASHATTENTION_ENABLE_VCOLMAJOR\": ENABLE_VCOLMAJOR,\n            \"FLASH_ATTENTION_DISABLE_HDIMDIFF64\": DISABLE_HDIMDIFF64,\n            \"FLASH_ATTENTION_DISABLE_HDIMDIFF192\": DISABLE_HDIMDIFF192,\n        }\n    }\n\n    with open(\"flash_attn_config.py\", \"w\") as f:\n        f.write(\"# Auto-generated by flash attention 3 setup.py\\n\")\n        f.write(f\"CONFIG = {repr(CONFIG)}\\n\")\n        f.write(\"\\n\")\n\n        f.write(\"def show():\\n\")\n        f.write(\"    from pprint import pprint\\n\")\n        f.write(\"    pprint(CONFIG)\\n\")\n        f.write(\"\\n\")\n\ndef _write_ninja_file(path,\n                      cflags,\n                      post_cflags,\n                      cuda_cflags,\n           
           cuda_post_cflags,\n                      cuda_dlink_post_cflags,\n                      sources,\n                      objects,\n                      ldflags,\n                      library_target,\n                      with_cuda,\n                      **kwargs,  # kwargs (ignored) to absorb new flags in torch.utils.cpp_extension\n                      ) -> None:\n    r\"\"\"Write a ninja file that does the desired compiling and linking.\n\n    `path`: Where to write this file\n    `cflags`: list of flags to pass to $cxx. Can be None.\n    `post_cflags`: list of flags to append to the $cxx invocation. Can be None.\n    `cuda_cflags`: list of flags to pass to $nvcc. Can be None.\n    `cuda_post_cflags`: list of flags to append to the $nvcc invocation. Can be None.\n    `cuda_dlink_post_cflags`: list of flags to append to the $nvcc device-link invocation. Can be None.\n    `sources`: list of paths to source files\n    `objects`: list of desired paths to objects, one per source.\n    `ldflags`: list of flags to pass to linker. Can be None.\n    `library_target`: Name of the output library. 
Can be None; in that case,\n                      we do no linking.\n    `with_cuda`: If we should be compiling with CUDA.\n    \"\"\"\n    def sanitize_flags(flags):\n        if flags is None:\n            return []\n        else:\n            return [flag.strip() for flag in flags]\n\n    cflags = sanitize_flags(cflags)\n    post_cflags = sanitize_flags(post_cflags)\n    cuda_cflags = sanitize_flags(cuda_cflags)\n    cuda_post_cflags = sanitize_flags(cuda_post_cflags)\n    cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags)\n    ldflags = sanitize_flags(ldflags)\n\n    # Sanity checks...\n    assert len(sources) == len(objects)\n    assert len(sources) > 0\n\n    compiler = get_cxx_compiler()\n\n    # Version 1.3 is required for the `deps` directive.\n    config = ['ninja_required_version = 1.3']\n    config.append(f'cxx = {compiler}')\n    if with_cuda or cuda_dlink_post_cflags:\n        if IS_HIP_EXTENSION:\n            nvcc = _join_rocm_home('bin', 'hipcc')\n        else:\n            nvcc = _join_cuda_home('bin', 'nvcc')\n        if \"PYTORCH_NVCC\" in os.environ:\n            nvcc_from_env = os.getenv(\"PYTORCH_NVCC\")    # user can set nvcc compiler with ccache using the environment variable here\n        else:\n            nvcc_from_env = nvcc\n        config.append(f'nvcc_from_env = {nvcc_from_env}')\n        config.append(f'nvcc = {nvcc}')\n\n    if IS_HIP_EXTENSION:\n        post_cflags = COMMON_HIP_FLAGS + post_cflags\n    flags = [f'cflags = {\" \".join(cflags)}']\n    flags.append(f'post_cflags = {\" \".join(post_cflags)}')\n    if with_cuda:\n        flags.append(f'cuda_cflags = {\" \".join(cuda_cflags)}')\n        flags.append(f'cuda_post_cflags = {\" \".join(cuda_post_cflags)}')\n        cuda_post_cflags_sm80 = [s if s != 'arch=compute_90a,code=sm_90a' else 'arch=compute_80,code=sm_80' for s in cuda_post_cflags]\n        flags.append(f'cuda_post_cflags_sm80 = {\" \".join(cuda_post_cflags_sm80)}')\n        cuda_post_cflags_sm80_sm90 = 
cuda_post_cflags + ['-gencode', 'arch=compute_80,code=sm_80']\n        flags.append(f'cuda_post_cflags_sm80_sm90 = {\" \".join(cuda_post_cflags_sm80_sm90)}')\n        cuda_post_cflags_sm100 = [s if s != 'arch=compute_90a,code=sm_90a' else 'arch=compute_100a,code=sm_100a' for s in cuda_post_cflags]\n        flags.append(f'cuda_post_cflags_sm100 = {\" \".join(cuda_post_cflags_sm100)}')\n    flags.append(f'cuda_dlink_post_cflags = {\" \".join(cuda_dlink_post_cflags)}')\n    flags.append(f'ldflags = {\" \".join(ldflags)}')\n\n    # Turn into absolute paths so we can emit them into the ninja build\n    # file wherever it is.\n    sources = [os.path.abspath(file) for file in sources]\n\n    # See https://ninja-build.org/build.ninja.html for reference.\n    compile_rule = ['rule compile']\n    if IS_WINDOWS:\n        compile_rule.append(\n            '  command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags')\n        compile_rule.append('  deps = msvc')\n    else:\n        compile_rule.append(\n            '  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags')\n        compile_rule.append('  depfile = $out.d')\n        compile_rule.append('  deps = gcc')\n\n    if with_cuda:\n        cuda_compile_rule = ['rule cuda_compile']\n        nvcc_gendeps = ''\n        # --generate-dependencies-with-compile is not supported by ROCm\n        # Nvcc flag `--generate-dependencies-with-compile` is not supported by sccache, which may increase build time.\n        if torch.version.cuda is not None and os.getenv('TORCH_EXTENSION_SKIP_NVCC_GEN_DEPENDENCIES', '0') != '1':\n            cuda_compile_rule.append('  depfile = $out.d')\n            cuda_compile_rule.append('  deps = gcc')\n            # Note: non-system deps with nvcc are only supported\n            # on Linux so use --generate-dependencies-with-compile\n            # to make this work on Windows too.\n            nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d'\n        
cuda_compile_rule_sm80 = ['rule cuda_compile_sm80'] + cuda_compile_rule[1:] + [\n            f'  command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80'\n        ]\n        cuda_compile_rule_sm80_sm90 = ['rule cuda_compile_sm80_sm90'] + cuda_compile_rule[1:] + [\n            f'  command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90'\n        ]\n        cuda_compile_rule_sm100 = ['rule cuda_compile_sm100'] + cuda_compile_rule[1:] + [\n            f'  command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm100'\n        ]\n        cuda_compile_rule.append(\n            f'  command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags')\n\n    # Emit one build rule per source to enable incremental build.\n    build = []\n    for source_file, object_file in zip(sources, objects):\n        is_cuda_source = _is_cuda_file(source_file) and with_cuda\n        if is_cuda_source:\n            if source_file.endswith('_sm90.cu'):\n                rule = 'cuda_compile'\n            elif source_file.endswith('_sm80.cu'):\n                rule = 'cuda_compile_sm80'\n            elif source_file.endswith('_sm100.cu'):\n                rule = 'cuda_compile_sm100'\n            else:\n                rule = 'cuda_compile_sm80_sm90'\n        else:\n            rule = 'compile'\n        if IS_WINDOWS:\n            source_file = source_file.replace(':', '$:')\n            object_file = object_file.replace(':', '$:')\n        source_file = source_file.replace(\" \", \"$ \")\n        object_file = object_file.replace(\" \", \"$ \")\n        build.append(f'build {object_file}: {rule} {source_file}')\n\n    if cuda_dlink_post_cflags:\n        devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o')\n        devlink_rule = ['rule cuda_devlink']\n        devlink_rule.append('  command = $nvcc $in -o $out $cuda_dlink_post_cflags')\n        devlink = 
[f'build {devlink_out}: cuda_devlink {\" \".join(objects)}']\n        objects += [devlink_out]\n    else:\n        devlink_rule, devlink = [], []\n\n    if library_target is not None:\n        link_rule = ['rule link']\n        if IS_WINDOWS:\n            cl_paths = subprocess.check_output(['where',\n                                                'cl']).decode(*SUBPROCESS_DECODE_ARGS).split('\\r\\n')\n            if len(cl_paths) >= 1:\n                cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:')\n            else:\n                raise RuntimeError(\"MSVC is required to load C++ extensions\")\n            link_rule.append(f'  command = \"{cl_path}/link.exe\" $in /nologo $ldflags /out:$out')\n        else:\n            link_rule.append('  command = $cxx $in $ldflags -o $out')\n\n        link = [f'build {library_target}: link {\" \".join(objects)}']\n\n        default = [f'default {library_target}']\n    else:\n        link_rule, link, default = [], [], []\n\n    # 'Blocks' should be separated by newlines, for visual benefit.\n    blocks = [config, flags, compile_rule]\n    if with_cuda:\n        blocks.append(cuda_compile_rule)  # type: ignore[possibly-undefined]\n        blocks.append(cuda_compile_rule_sm80)  # type: ignore[possibly-undefined]\n        blocks.append(cuda_compile_rule_sm80_sm90)  # type: ignore[possibly-undefined]\n        blocks.append(cuda_compile_rule_sm100)  # type: ignore[possibly-undefined]\n    blocks += [devlink_rule, link_rule, build, devlink, link, default]\n    content = \"\\n\\n\".join(\"\\n\".join(b) for b in blocks)\n    # Ninja requires a newline at the end of the .ninja file\n    content += \"\\n\"\n    _maybe_write(path, content)\n\n\n# Monkey patching: replace PyTorch's ninja writer with the per-architecture version above\ntorch.utils.cpp_extension._write_ninja_file = _write_ninja_file\n\n\ndef get_platform():\n    \"\"\"\n    Returns the platform name as used in wheel filenames.\n    \"\"\"\n    if sys.platform.startswith(\"linux\"):\n        return \"linux_x86_64\"\n    elif 
sys.platform == \"darwin\":\n        mac_version = \".\".join(platform.mac_ver()[0].split(\".\")[:2])\n        return f\"macosx_{mac_version}_x86_64\"\n    elif sys.platform == \"win32\":\n        return \"win_amd64\"\n    else:\n        raise ValueError(\"Unsupported platform: {}\".format(sys.platform))\n\n\ndef get_cuda_bare_metal_version(cuda_dir):\n    raw_output = subprocess.check_output([cuda_dir + \"/bin/nvcc\", \"-V\"], universal_newlines=True)\n    output = raw_output.split()\n    release_idx = output.index(\"release\") + 1\n    bare_metal_version = parse(output[release_idx].split(\",\")[0])\n\n    return raw_output, bare_metal_version\n\n\ndef check_if_cuda_home_none(global_option: str) -> None:\n    if CUDA_HOME is not None:\n        return\n    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary\n    # in that case.\n    warnings.warn(\n        f\"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  \"\n        \"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, \"\n        \"only images whose names contain 'devel' will provide nvcc.\"\n    )\n\n\n# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py\ndef check_env_flag(name: str, default: str = \"\") -> bool:\n    return os.getenv(name, default).upper() in [\"ON\", \"1\", \"YES\", \"TRUE\", \"Y\"]\n\n\n# Copied from https://github.com/triton-lang/triton/blob/main/python/setup.py\ndef is_offline_build() -> bool:\n    \"\"\"\n    Downstream projects and distributions which bootstrap their own dependencies from scratch\n    and run builds in offline sandboxes\n    may set `FLASH_ATTENTION_OFFLINE_BUILD` in the build environment to prevent any attempts at downloading\n    pinned dependencies from the internet or at using dependencies vendored in-tree.\n\n    Dependencies must be defined using respective search paths (cf. 
`syspath_var_name` in `Package`).\n    Missing dependencies lead to an early abortion.\n    Dependencies' compatibility is not verified.\n\n    Note that this flag isn't tested by the CI and does not provide any guarantees.\n    \"\"\"\n    return check_env_flag(\"FLASH_ATTENTION_OFFLINE_BUILD\", \"\")\n\n\n# Copied from https://github.com/triton-lang/triton/blob/main/python/setup.py\ndef get_flashattn_cache_path():\n    user_home = os.getenv(\"FLASH_ATTENTION_HOME\")\n    if not user_home:\n        user_home = os.getenv(\"HOME\") or os.getenv(\"USERPROFILE\") or os.getenv(\"HOMEPATH\") or None\n    if not user_home:\n        raise RuntimeError(\"Could not find user home directory\")\n    return os.path.join(user_home, \".flashattn\")\n\n\ndef open_url(url):\n    user_agent = 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0'\n    headers = {\n        'User-Agent': user_agent,\n    }\n    request = urllib.request.Request(url, None, headers)\n    # Set timeout to 300 seconds to prevent the request from hanging forever.\n    return urllib.request.urlopen(request, timeout=300)\n\n\ndef download_and_copy(name, src_func, dst_path, version, url_func):\n    if is_offline_build():\n        return\n    flashattn_cache_path = get_flashattn_cache_path()\n    base_dir = os.path.dirname(__file__)\n    system = platform.system()\n    arch = platform.machine()\n    arch = {\"arm64\": \"aarch64\"}.get(arch, arch)\n    supported = {\"Linux\": \"linux\", \"Darwin\": \"linux\"}\n    url = url_func(supported[system], arch, version)\n    src_path = src_func(supported[system], arch, version)\n    tmp_path = os.path.join(flashattn_cache_path, \"nvidia\", name)  # path to cache the download\n    dst_path = os.path.join(base_dir, os.pardir, \"third_party\", \"nvidia\", \"backend\", dst_path)  # final binary path\n    src_path = os.path.join(tmp_path, src_path)\n    download = not os.path.exists(src_path)\n    if download:\n        print(f'downloading and extracting 
{url} ...')\n        file = tarfile.open(fileobj=open_url(url), mode=\"r|*\")\n        file.extractall(path=tmp_path)\n    os.makedirs(os.path.split(dst_path)[0], exist_ok=True)\n    print(f'copy {src_path} to {dst_path} ...')\n    if os.path.isdir(src_path):\n        shutil.copytree(src_path, dst_path, dirs_exist_ok=True)\n    else:\n        shutil.copy(src_path, dst_path)\n\n\ndef nvcc_threads_args():\n    nvcc_threads = os.getenv(\"NVCC_THREADS\") or \"2\"\n    return [\"--threads\", nvcc_threads]\n\n\n# NVIDIA_TOOLCHAIN_VERSION = {\"nvcc\": \"12.3.107\"}\nNVIDIA_TOOLCHAIN_VERSION = {\"nvcc\": \"12.6.85\", \"ptxas\": \"12.8.93\"}\n\nexe_extension = sysconfig.get_config_var(\"EXE\")\n\n\ncmdclass = {}\next_modules = []\n# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp\n# files included in the source distribution, in case the user compiles from source.\nif not USE_TRITON_ROCM:\n    subprocess.run([\"git\", \"submodule\", \"update\", \"--init\", \"../csrc/cutlass\"])\n\nif not SKIP_CUDA_BUILD:\n    print(\"\\n\\ntorch.__version__  = {}\\n\\n\".format(torch.__version__))\n    TORCH_MAJOR = int(torch.__version__.split(\".\")[0])\n    TORCH_MINOR = int(torch.__version__.split(\".\")[1])\n\n    create_build_config_file()\n    check_if_cuda_home_none(PACKAGE_NAME)\n    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)\n    if bare_metal_version < Version(\"12.3\"):\n        raise RuntimeError(\"FlashAttention-3 is only supported on CUDA 12.3 and above\")\n    elif bare_metal_version >= Version(\"13.0\"):\n        # CUDA 13.0+ uses system nvcc and CCCL headers are in /usr/local/cuda/include/cccl/\n        cccl_include = os.path.join(CUDA_HOME, \"include\", \"cccl\")\n        for env_var in [\"CPLUS_INCLUDE_PATH\", \"C_INCLUDE_PATH\"]:\n            current = os.environ.get(env_var, \"\")\n            os.environ[env_var] = cccl_include + (\":\" + current if current else \"\")\n\n    # ptxas 12.8 gives the best 
perf currently\n    # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8\n    # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have.\n    # For CUDA 13.0+, use system nvcc instead of downloading CUDA 12.x toolchain\n    if bare_metal_version >= Version(\"12.3\") and bare_metal_version < Version(\"13.0\") and bare_metal_version != Version(\"12.8\"):\n        download_and_copy(\n            name=\"nvcc\",\n            src_func=lambda system, arch, version: f\"cuda_nvcc-{system}-{arch}-{version}-archive/bin\",\n            dst_path=\"bin\",\n            version=NVIDIA_TOOLCHAIN_VERSION[\"nvcc\"],\n            url_func=lambda system, arch, version:\n            f\"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz\",\n        )\n        download_and_copy(\n            name=\"ptxas\",\n            src_func=lambda system, arch, version: f\"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas\",\n            dst_path=\"bin\",\n            version=NVIDIA_TOOLCHAIN_VERSION[\"ptxas\"],\n            url_func=lambda system, arch, version:\n            f\"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz\",\n        )\n        download_and_copy(\n            name=\"ptxas\",\n            src_func=lambda system, arch, version: f\"cuda_nvcc-{system}-{arch}-{version}-archive/nvvm/bin\",\n            dst_path=\"nvvm/bin\",\n            version=NVIDIA_TOOLCHAIN_VERSION[\"ptxas\"],\n            url_func=lambda system, arch, version:\n            f\"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz\",\n        )\n        base_dir = os.path.dirname(__file__)\n        ctk_path_new = os.path.abspath(os.path.join(base_dir, os.pardir, \"third_party\", \"nvidia\", \"backend\", 
\"bin\"))\n        nvcc_path_new = os.path.join(ctk_path_new, f\"nvcc{exe_extension}\")\n        # Need to append to path otherwise nvcc can't find cicc in nvvm/bin/cicc\n        # nvcc 12.8 seems to hard-code looking for cicc in ../nvvm/bin/cicc\n        os.environ[\"PATH\"] = ctk_path_new + os.pathsep + os.environ[\"PATH\"]\n        os.environ[\"PYTORCH_NVCC\"] = nvcc_path_new\n        # Make nvcc executable, sometimes after the copy it loses its permissions\n        os.chmod(nvcc_path_new, os.stat(nvcc_path_new).st_mode | stat.S_IEXEC)\n\n    cc_flag = []\n    cc_flag.append(\"-gencode\")\n    cc_flag.append(\"arch=compute_90a,code=sm_90a\")\n\n    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as\n    # torch._C._GLIBCXX_USE_CXX11_ABI\n    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920\n    if FORCE_CXX11_ABI:\n        torch._C._GLIBCXX_USE_CXX11_ABI = True\n    repo_dir = Path(this_dir).parent\n    cutlass_dir = repo_dir / \"csrc\" / \"cutlass\"\n\n    feature_args = (\n        []\n        + ([\"-DFLASHATTENTION_DISABLE_BACKWARD\"] if DISABLE_BACKWARD else [])\n        + ([\"-DFLASHATTENTION_DISABLE_PAGEDKV\"] if DISABLE_PAGEDKV else [])\n        + ([\"-DFLASHATTENTION_DISABLE_SPLIT\"] if DISABLE_SPLIT else [])\n        + ([\"-DFLASHATTENTION_DISABLE_APPENDKV\"] if DISABLE_APPENDKV else [])\n        + ([\"-DFLASHATTENTION_DISABLE_LOCAL\"] if DISABLE_LOCAL else [])\n        + ([\"-DFLASHATTENTION_DISABLE_SOFTCAP\"] if DISABLE_SOFTCAP else [])\n        + ([\"-DFLASHATTENTION_DISABLE_PACKGQA\"] if DISABLE_PACKGQA else [])\n        + ([\"-DFLASHATTENTION_DISABLE_FP16\"] if DISABLE_FP16 else [])\n        + ([\"-DFLASHATTENTION_DISABLE_FP8\"] if DISABLE_FP8 else [])\n        + ([\"-DFLASHATTENTION_DISABLE_VARLEN\"] if DISABLE_VARLEN else [])\n        + ([\"-DFLASHATTENTION_DISABLE_CLUSTER\"] if DISABLE_CLUSTER else [])\n        + ([\"-DFLASHATTENTION_DISABLE_HDIM64\"] if 
DISABLE_HDIM64 else [])\n        + ([\"-DFLASHATTENTION_DISABLE_HDIM96\"] if DISABLE_HDIM96 else [])\n        + ([\"-DFLASHATTENTION_DISABLE_HDIM128\"] if DISABLE_HDIM128 else [])\n        + ([\"-DFLASHATTENTION_DISABLE_HDIM192\"] if DISABLE_HDIM192 else [])\n        + ([\"-DFLASHATTENTION_DISABLE_HDIM256\"] if DISABLE_HDIM256 else [])\n        + ([\"-DFLASHATTENTION_DISABLE_SM8x\"] if DISABLE_SM8x else [])\n        + ([\"-DFLASHATTENTION_ENABLE_VCOLMAJOR\"] if ENABLE_VCOLMAJOR else [])\n        + ([\"-DFLASHATTENTION_DISABLE_HDIMDIFF64\"] if DISABLE_HDIMDIFF64 else [])\n        + ([\"-DFLASHATTENTION_DISABLE_HDIMDIFF192\"] if DISABLE_HDIMDIFF192 else [])\n    )\n\n    DTYPE_FWD_SM80 = [\"bf16\"] + ([\"fp16\"] if not DISABLE_FP16 else [])\n    DTYPE_FWD_SM90 = [\"bf16\"] + ([\"fp16\"] if not DISABLE_FP16 else []) + ([\"e4m3\"] if not DISABLE_FP8 else [])\n    HALF_DTYPE_FWD_SM90 = [\"bf16\"] + ([\"fp16\"] if not DISABLE_FP16 else [])\n    DTYPE_BWD = [\"bf16\"] + ([\"fp16\"] if not DISABLE_FP16 else [])\n    HEAD_DIMENSIONS_BWD = (\n        []\n        + ([64] if not DISABLE_HDIM64 else [])\n        + ([96] if not DISABLE_HDIM96 else [])\n        + ([128] if not DISABLE_HDIM128 else [])\n        + ([192] if not DISABLE_HDIM192 else [])\n        + ([256] if not DISABLE_HDIM256 else [])\n    )\n    # build will now explode with this compilation grouping given all our templating\n    # HEAD_DIMENSIONS_FWD = [\"all\", \"diff\"]\n    HEAD_DIMENSIONS_FWD = HEAD_DIMENSIONS_BWD\n    HEAD_DIMENSIONS_DIFF64_FWD = (\n        []\n        + ([\"64_256\"] if not DISABLE_HDIMDIFF64 else [])\n        + ([\"64_512\"] if not DISABLE_HDIMDIFF64 else [])\n    )\n    HEAD_DIMENSIONS_DIFF192_FWD = (\n        []\n        + ([\"192_128\"] if not DISABLE_HDIMDIFF192 else [])\n    )\n    HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD\n    SPLIT = [\"\"] + ([\"_split\"] if not DISABLE_SPLIT else [])\n    PAGEDKV = [\"\"] + ([\"_paged\"] if not DISABLE_PAGEDKV else [])\n    SOFTCAP = [\"\"] + 
([\"_softcap\"] if not DISABLE_SOFTCAP else [])\n    SOFTCAP_ALL = [\"\"] if DISABLE_SOFTCAP else [\"_softcapall\"]\n    PACKGQA = [\"\"] + ([\"_packgqa\"] if not DISABLE_PACKGQA else [])\n    # We already always hard-code PackGQA=true for Sm8x\n    sources_fwd_sm80 = [f\"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}_sm80.cu\"\n                        for hdim, dtype, split, paged, softcap in itertools.product(HEAD_DIMENSIONS_FWD_SM80, DTYPE_FWD_SM80, SPLIT, PAGEDKV, SOFTCAP_ALL)]\n    # We already always hard-code PackGQA=true for Sm9x if PagedKV or Split\n    sources_fwd_sm90 = [f\"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu\"\n                        for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA)\n                        if not (packgqa and (paged or split))]\n    if not DISABLE_HDIMDIFF64:\n        sources_fwd_sm90 += [f\"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu\"\n                             for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF64_FWD, HALF_DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA)\n                             if not (packgqa and (paged or split))]\n    if not DISABLE_HDIMDIFF192:\n        sources_fwd_sm90 += [f\"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu\"\n                            for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF192_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA)\n                            if not (packgqa and (paged or split))]\n    sources_bwd_sm80 = [f\"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm80.cu\"\n                        for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP)]\n    sources_bwd_sm90 = 
[f\"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm90.cu\"\n                        for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP_ALL)]\n    if DISABLE_BACKWARD:\n        sources_bwd_sm90 = []\n        sources_bwd_sm80 = []\n    \n    # Choose between flash_api.cpp and flash_api_stable.cpp based on torch version\n    torch_version = parse(torch.__version__)\n    target_version = parse(\"2.9.0.dev20250830\")\n    stable_args = []\n      \n    if torch_version >= target_version:\n        flash_api_source = \"flash_api_stable.cpp\"\n        stable_args = [\"-DTORCH_TARGET_VERSION=0x0209000000000000\"]  # Targets minimum runtime version torch 2.9.0\n    else:\n        flash_api_source = \"flash_api.cpp\"\n\n    sources = (\n        [flash_api_source]\n        + (sources_fwd_sm80 if not DISABLE_SM8x else []) + sources_fwd_sm90\n        + (sources_bwd_sm80 if not DISABLE_SM8x else []) + sources_bwd_sm90\n    )\n    if not DISABLE_SPLIT:\n        sources += [\"flash_fwd_combine.cu\"]\n    sources += [\"flash_prepare_scheduler.cu\"]\n    nvcc_flags = [\n        \"-O3\",\n        \"-std=c++17\",\n        \"--ftemplate-backtrace-limit=0\",  # To debug template code\n        \"--use_fast_math\",\n        # \"--keep\",\n        # \"--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage\",  # printing out number of registers\n        \"--resource-usage\",  # printing out number of registers\n        # f\"--split-compile={os.getenv('NVCC_THREADS', '4')}\",  # split-compile is faster\n        \"-lineinfo\",  # TODO: disable this for release to reduce binary size\n        \"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED\",  # Necessary for the WGMMA shapes that we use\n        \"-DCUTLASS_ENABLE_GDC_FOR_SM90\",  # For PDL\n        \"-DCUTLASS_DEBUG_TRACE_LEVEL=0\",  # Can toggle for debugging\n        \"-DNDEBUG\",  # Important, otherwise performance is severely impacted\n    ]\n    if get_platform() == 
\"win_amd64\":\n        nvcc_flags.extend(\n            [\n                \"-D_USE_MATH_DEFINES\",  # for M_LN2\n                \"-Xcompiler=/Zc:__cplusplus\",  # sets __cplusplus correctly, CUTLASS_CONSTEXPR_IF_CXX17 needed for cutlass::gcd\n            ]\n        )\n    include_dirs = [\n        Path(this_dir),\n        cutlass_dir / \"include\",\n    ]\n\n    ext_modules.append(\n        CUDAExtension(\n            name=f\"{PACKAGE_NAME}._C\",\n            sources=sources,\n            extra_compile_args={\n                \"cxx\": [\"-O3\", \"-std=c++17\", \"-DPy_LIMITED_API=0x03090000\"] + stable_args + feature_args,\n                \"nvcc\": nvcc_threads_args() + nvcc_flags + cc_flag + feature_args,\n            },\n            include_dirs=include_dirs,\n            py_limited_api=True,\n        )\n    )\n\n\ndef get_package_version():\n    with open(Path(this_dir) / \"__init__.py\", \"r\") as f:\n        version_match = re.search(r\"^__version__\\s*=\\s*(.*)$\", f.read(), re.MULTILINE)\n    public_version = ast.literal_eval(version_match.group(1))\n    local_version = os.environ.get(\"FLASH_ATTN_LOCAL_VERSION\")\n    if local_version:\n        return f\"{public_version}+{local_version}\"\n    else:\n        return str(public_version)\n\n\ndef get_wheel_url():\n    # Determine the version numbers that will be used to determine the correct wheel\n    # We're using the CUDA version used to build torch, not the one currently installed\n    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)\n    torch_cuda_version = parse(torch.version.cuda)\n    torch_version_raw = parse(torch.__version__)\n    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2\n    # to save CI time. 
Minor versions should be compatible.\n    torch_cuda_version = parse(\"11.8\") if torch_cuda_version.major == 11 else parse(\"12.2\")\n    python_version = f\"cp{sys.version_info.major}{sys.version_info.minor}\"\n    platform_name = get_platform()\n    package_version = get_package_version()\n    # cuda_version = f\"{cuda_version_raw.major}{cuda_version_raw.minor}\"\n    cuda_version = f\"{torch_cuda_version.major}{torch_cuda_version.minor}\"\n    torch_version = f\"{torch_version_raw.major}.{torch_version_raw.minor}\"\n    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()\n\n    # Determine wheel URL based on CUDA version, torch version, python version and OS\n    wheel_filename = f\"{PACKAGE_NAME}-{package_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl\"\n    wheel_url = BASE_WHEEL_URL.format(tag_name=f\"v{package_version}\", wheel_name=wheel_filename)\n    return wheel_url, wheel_filename\n\n\nclass CachedWheelsCommand(_bdist_wheel):\n    \"\"\"\n    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot\n    find an existing wheel (which is currently the case for all installs). 
We use\n    the environment parameters to detect whether there is already a pre-built version of a compatible\n    wheel available and, if so, short-circuit the standard full build pipeline.\n    \"\"\"\n\n    def run(self):\n        if FORCE_BUILD:\n            return super().run()\n\n        wheel_url, wheel_filename = get_wheel_url()\n        print(\"Guessing wheel URL: \", wheel_url)\n        try:\n            urllib.request.urlretrieve(wheel_url, wheel_filename)\n\n            # Make the archive\n            # Lifted from the root wheel processing command\n            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85\n            if not os.path.exists(self.dist_dir):\n                os.makedirs(self.dist_dir)\n\n            impl_tag, abi_tag, plat_tag = self.get_tag()\n            archive_basename = f\"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}\"\n\n            wheel_path = os.path.join(self.dist_dir, archive_basename + \".whl\")\n            print(\"Raw wheel path\", wheel_path)\n            shutil.move(wheel_filename, wheel_path)\n        # HTTPError is a subclass of URLError, so this also covers network failures (e.g. no connectivity)\n        except urllib.error.URLError:\n            print(\"Precompiled wheel not found. 
Building from source...\")\n            # If the wheel could not be downloaded, build from source\n            super().run()\n\nsetup(\n    name=PACKAGE_NAME,\n    version=get_package_version(),\n    packages=find_packages(\n        exclude=(\n            \"build\",\n            \"csrc\",\n            \"include\",\n            \"tests\",\n            \"dist\",\n            \"docs\",\n            \"benchmarks\",\n        )\n    ),\n    py_modules=[\"flash_attn_interface\", \"flash_attn_config\"],\n    description=\"FlashAttention-3\",\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n    classifiers=[\n        \"Programming Language :: Python :: 3\",\n        \"License :: OSI Approved :: Apache Software License\",\n        \"Operating System :: Unix\",\n    ],\n    ext_modules=ext_modules,\n    cmdclass={\"bdist_wheel\": CachedWheelsCommand, \"build_ext\": BuildExtension}\n    if ext_modules\n    else {\n        \"bdist_wheel\": CachedWheelsCommand,\n    },\n    python_requires=\">=3.8\",\n    install_requires=[\n        \"torch\",\n        \"einops\",\n        \"packaging\",\n        \"ninja\",\n    ],\n    options={\"bdist_wheel\": {\"py_limited_api\": \"cp39\"}},\n)\n"
  },
  {
    "path": "hopper/sm90_pipeline_no_cluster.hpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include<cutlass/pipeline/sm90_pipeline.hpp>\n\nnamespace cutlass {\n\nusing namespace cute;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// As of Cutlass v3.6.0, if size(ClusterShape) == 1, PipelineTmaAsync has all threads\n// signaling the barrier during consumer_release. This causes a perf regression in FA3\n// forward pass (especially hdim 128 causal). We instead reimplement the version of\n// PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier.\n//\n// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0\ntemplate <int Stages_, class Base=cutlass::PipelineTmaAsync<Stages_>>\nclass PipelineTmaAsyncNoCluster: public Base {\npublic:\n  using FullBarrier = typename Base::FullBarrier;\n  using EmptyBarrier = typename Base::EmptyBarrier;\n  static constexpr uint32_t Stages = Stages_;\n  using PipelineState = typename Base::PipelineState;\n\n  using SharedStorage = typename Base::SharedStorage;\n  using ThreadCategory = typename Base::ThreadCategory;\n  using Params = typename Base::Params;\n\n  static\n  CUTLASS_DEVICE\n  void\n  init_barriers(SharedStorage& storage, Params params) {\n    int warp_idx = canonical_warp_idx_sync();\n    bool is_initializing_warp = (warp_idx == 0);\n    if (is_initializing_warp) {\n      // Barrier FULL and EMPTY init\n      constexpr int producer_arv_cnt = 1;\n      uint32_t const num_consumer_warpgroups_per_cluster = (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup;\n      uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster;\n\n      
cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(\n          storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count);\n    }\n    cutlass::arch::fence_barrier_init();\n  }\n\n  template<class ClusterShape, class InitBarriers, class InitMasks>\n  CUTLASS_DEVICE\n  PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {})\n      : Base(storage, params, make_shape(_1{}, _1{}, _1{}) /*cluster_shape*/, cute::false_type{} /*init_barriers*/, cute::false_type{} /*init_masks*/)\n      , empty_barrier_ptr_(&storage.empty_barrier_[0]) {\n\n    int warp_idx = canonical_warp_idx_sync();\n    int lane_predicate = cute::elect_one_sync();\n\n    static_assert(cute::is_same_v<InitBarriers, cute::true_type> || cute::is_same_v<InitBarriers, cute::false_type>);\n    static_assert(cute::is_same_v<InitMasks, cute::true_type> || cute::is_same_v<InitMasks, cute::false_type>);\n    if constexpr (cute::is_same_v<InitBarriers, cute::true_type>) {\n      init_barriers(storage, params);\n    }\n\n  }\n\n  // Constructor\n  template<class ClusterShape>\n  CUTLASS_DEVICE\n  PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape)\n      : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, cute::true_type{}, cute::true_type{}) { }\n\n  template<class ClusterShape, class InitBarriers>\n  CUTLASS_DEVICE\n  PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {})\n      : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, InitBarriers{}, cute::true_type{}) { }\n\n  CUTLASS_DEVICE\n  void consumer_release(PipelineState state) {\n    consumer_release(state.index());\n  }\n\nprivate:\n  EmptyBarrier* const empty_barrier_ptr_ = nullptr;\n\n  // Consumer signalling Producer of completion\n  // Ensures all 
blocks in the Same Row and Column get notified.\n  CUTLASS_DEVICE\n  void consumer_release(uint32_t stage, uint32_t skip = false) {\n    empty_barrier_ptr_[stage].arrive(0 /*dst_blockid_*/, uint32_t(threadIdx.x % cutlass::NumThreadsPerWarpGroup == 0) & (!skip) /*is_signaling_thread*/);\n  }\n\n};\n\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n} // end namespace cutlass\n"
  },
  {
    "path": "hopper/softmax.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <cmath>\n\n#include <cute/tensor.hpp>\n\n#include <cutlass/numeric_types.h>\n\n#include \"utils.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>\n__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {\n    static_assert(Layout0::rank == 2, \"Only support 2D Tensor\");\n    static_assert(Layout1::rank == 1, \"Only support 1D Tensor\");\n    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));\n    #pragma unroll\n    for (int ni = 0; ni < size<1>(tensor); ni++) {\n        #pragma unroll\n        for (int mi = 0; mi < size<0>(tensor); mi++) {\n            summary(mi) = zero_init && ni == 0 ? 
tensor(mi, ni) : op(summary(mi), tensor(mi, ni));\n        }\n    }\n}\n\ntemplate<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>\n__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {\n    CUTE_STATIC_ASSERT_V(size(dst) == size(src));\n    #pragma unroll\n    for (int i = 0; i < size(dst); i++) {\n        dst(i) = Allreduce<4>::run(src(i), op);\n    }\n}\n\ntemplate<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>\n__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {\n    thread_reduce_<zero_init>(tensor, summary, op);\n    quad_allreduce_(summary, summary, op);\n}\n\ntemplate<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>\n__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){\n    MaxOp<float> max_op;\n    reduce_<zero_init>(tensor, max, max_op);\n}\n\ntemplate<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>\n__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){\n    SumOp<float> sum_op;\n    thread_reduce_<zero_init>(tensor, sum, sum_op);\n    if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }\n}\n\n// Apply the exp to all the elements.\ntemplate <bool Scale_max=true, bool Check_inf=true, int Max_offset=0,\n        typename Engine0, typename Layout0, typename Engine1, typename Layout1>\n__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {\n    // For FP8, we can subtract max by 8.0 so that the value after exp2 is in the range of [0, 256].\n    // This lets us use more of the 
FP8 range (instead of just [0, 1]) to reduce underflow.\n    static constexpr float max_offset = float(Max_offset);  // We can only template on int, not float\n    static_assert(Layout0::rank == 2, \"Only support 2D Tensor\");\n    static_assert(Layout1::rank == 1, \"Only support 1D Tensor\");\n    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));\n    #pragma unroll\n    for (int mi = 0; mi < size<0>(tensor); ++mi) {\n        // If max is -inf, then all elements must have been -inf (possibly due to masking).\n        // We don't want (-inf - (-inf)) since that would give NaN.\n        const float max_scaled = Check_inf\n            ? (max(mi) == -INFINITY ? 0.f : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset)\n            : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset;\n        #pragma unroll\n        for (int ni = 0; ni < size<1>(tensor); ++ni)  {\n            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -\n            // max * log_2(e)). This allows the compiler to use the ffma\n            // instruction instead of fadd and fmul separately.\n            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int kNRows, int Max_offset=0>\nstruct Softmax {\n\n    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));\n    TensorT row_max, row_sum;\n    float const softmax_scale_log2;\n\n    CUTLASS_DEVICE Softmax(float const softmax_scale_log2_) : softmax_scale_log2(softmax_scale_log2_) {};\n\n    template<bool Is_first, bool Check_inf=false, typename Tensor0>\n    __forceinline__ __device__ TensorT max_get_scale(Tensor0 &acc_s) {\n        // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))\n        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));\n        
static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows);\n        TensorT scores_scale;\n        if constexpr (Is_first) {\n            flash::template reduce_max</*zero_init=*/true>(scores, row_max);\n            cute::fill(scores_scale, 1.f);\n        } else {\n            Tensor scores_max_prev = make_fragment_like(row_max);\n            cute::copy(row_max, scores_max_prev);\n            flash::template reduce_max</*zero_init=*/false>(scores, row_max);\n            #pragma unroll\n            for (int mi = 0; mi < size(row_max); ++mi) {\n                float scores_max_cur = !Check_inf\n                    ? row_max(mi)\n                    : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));\n                scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);\n                row_sum(mi) *= scores_scale(mi);\n            }\n        }\n        return scores_scale;\n    };\n\n    template<bool Is_first, bool Check_inf=false, typename Tensor0>\n    __forceinline__ __device__ void online_softmax(Tensor0 &acc_s) {\n        // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))\n        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));\n        static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows);\n        flash::template scale_apply_exp2</*Scale_max=*/true, Check_inf, Max_offset>(scores, row_max, softmax_scale_log2);\n        // We don't do the reduce across threads here since we don't need to use the row_sum.\n        // We do that reduce at the end when we need to normalize the softmax.\n        flash::reduce_sum</*zero_init=*/Is_first, /*warp_reduce=*/false>(scores, row_sum);\n    };\n\n    __forceinline__ __device__ TensorT finalize(float const final_scale=1.f) {\n        SumOp<float> sum_op;\n        quad_allreduce_(row_sum, row_sum, sum_op);\n        TensorT scores_scale;\n        #pragma unroll\n        for (int mi = 0; mi < size(row_sum); ++mi) 
{\n            float sum = row_sum(mi);\n            float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum;\n            scores_scale(mi) = inv_sum * final_scale;\n            // For FP8, we might have scaled the output of exp by 2**8 so we need to divide sum by that amount.\n            if constexpr (Max_offset != 0) {\n                static constexpr float sum_scale = 1.f / float(1 << Max_offset);\n                sum *= sum_scale;\n            }\n            row_sum(mi) = (sum == 0.f || sum != sum) ? -INFINITY : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);\n        }\n        return scores_scale;\n    };\n\n    template<typename Tensor1>\n    __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {\n        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))\n        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));\n        static_assert(CUTE_STATIC_V(size<0>(acc_o_rowcol)) == kNRows);\n        #pragma unroll\n        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {\n            #pragma unroll\n            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale(mi); }\n        }\n    };\n\n};\n\n}  // namespace flash\n"
  },
  {
    "path": "hopper/static_switch.h",
    "content": "// Inspired by\n// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h\n// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h\n\n#pragma once\n\n/// @param COND       - a boolean expression to switch by\n/// @param CONST_NAME - a name given for the constexpr bool variable.\n/// @param ...       - code to execute for true and false\n///\n/// Usage:\n/// ```\n/// BOOL_SWITCH(flag, BoolConst, [&] {\n///     some_function<BoolConst>(...);\n/// });\n/// ```\n//\n\n#define BOOL_SWITCH(COND, CONST_NAME, ...)                                                       \\\n  [&] {                                                                                          \\\n    if (COND) {                                                                                  \\\n      constexpr static bool CONST_NAME = true;                                                   \\\n      return __VA_ARGS__();                                                                      \\\n    } else {                                                                                     \\\n      constexpr static bool CONST_NAME = false;                                                  \\\n      return __VA_ARGS__();                                                                      \\\n    }                                                                                            \\\n  }()\n\n#ifdef FLASHATTENTION_DISABLE_LOCAL\n  #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) 
\\\n    [&] {                                                                                        \\\n      constexpr static bool LOCAL_CONST_NAME = false;                                            \\\n      if (CAUSAL_COND) {                                                                         \\\n        constexpr static bool CAUSAL_CONST_NAME = true;                                          \\\n        return __VA_ARGS__();                                                                    \\\n      } else {                                                                                   \\\n        constexpr static bool CAUSAL_CONST_NAME = false;                                         \\\n        return __VA_ARGS__();                                                                    \\\n      }                                                                                          \\\n    }()\n#else\n  #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) 
\\\n    [&] {                                                                                        \\\n      if (CAUSAL_COND) {                                                                         \\\n        constexpr static bool CAUSAL_CONST_NAME = true;                                          \\\n        constexpr static bool LOCAL_CONST_NAME = false;                                          \\\n        return __VA_ARGS__();                                                                    \\\n      } else if (LOCAL_COND) {                                                                   \\\n        constexpr static bool CAUSAL_CONST_NAME = false;                                         \\\n        constexpr static bool LOCAL_CONST_NAME = true;                                           \\\n        return __VA_ARGS__();                                                                    \\\n      } else {                                                                                   \\\n        constexpr static bool CAUSAL_CONST_NAME = false;                                         \\\n        constexpr static bool LOCAL_CONST_NAME = false;                                          \\\n        return __VA_ARGS__();                                                                    \\\n      }                                                                                          \\\n    }()\n#endif\n\n#ifdef FLASHATTENTION_DISABLE_SOFTCAP\n  #define SOFTCAP_SWITCH(COND, CONST_NAME, ...)                                                  
\\\n  [&] {                                                                                          \\\n    constexpr static bool CONST_NAME = false;                                                    \\\n    return __VA_ARGS__();                                                                        \\\n  }()\n#else\n  #define SOFTCAP_SWITCH BOOL_SWITCH\n#endif\n\n#ifdef FLASHATTENTION_DISABLE_PAGEDKV\n  #define PAGEDKV_SWITCH(COND, CONST_NAME, ...)                                                  \\\n  [&] {                                                                                          \\\n    constexpr static bool CONST_NAME = false;                                                    \\\n    return __VA_ARGS__();                                                                        \\\n  }()\n#else\n  #define PAGEDKV_SWITCH BOOL_SWITCH\n#endif\n\n#ifdef FLASHATTENTION_DISABLE_SPLIT\n  #define SPLIT_SWITCH(COND, CONST_NAME, ...)                                                    \\\n  [&] {                                                                                          \\\n    constexpr static bool CONST_NAME = false;                                                    \\\n    return __VA_ARGS__();                                                                        \\\n  }()\n#else\n  #define SPLIT_SWITCH BOOL_SWITCH\n#endif\n\n#ifdef FLASHATTENTION_DISABLE_APPENDKV\n  #define APPENDKV_SWITCH(COND, CONST_NAME, ...)                                                 \\\n  [&] {                                                                                          \\\n    constexpr static bool CONST_NAME = false;                                                    \\\n    return __VA_ARGS__();                                                                        \\\n  }()\n#else\n  #define APPENDKV_SWITCH BOOL_SWITCH\n#endif\n\n#ifdef FLASHATTENTION_DISABLE_PACKGQA\n  #define PACKGQA_SWITCH(COND, CONST_NAME, ...)                                
                  \\\n  [&] {                                                                                          \\\n    constexpr static bool CONST_NAME = false;                                                    \\\n    return __VA_ARGS__();                                                                        \\\n  }()\n#else\n  #define PACKGQA_SWITCH BOOL_SWITCH\n#endif\n\n#ifdef FLASHATTENTION_DISABLE_VARLEN\n  #define VARLEN_SWITCH(COND, CONST_NAME, ...)                                                   \\\n  [&] {                                                                                          \\\n    constexpr static bool CONST_NAME = false;                                                    \\\n    return __VA_ARGS__();                                                                        \\\n  }()\n#else\n  #define VARLEN_SWITCH BOOL_SWITCH\n#endif\n\n#ifdef FLASHATTENTION_DISABLE_CLUSTER\n  #define CLUSTER_SWITCH(COND, CONST_NAME, ...)                                                  \\\n  [&] {                                                                                          \\\n    constexpr static bool CONST_NAME = false;                                                    \\\n    return __VA_ARGS__();                                                                        \\\n  }()\n#else\n  #define CLUSTER_SWITCH BOOL_SWITCH\n#endif\n\n#ifdef FLASHATTENTION_DISABLE_SM8x\n  #define ARCH_SWITCH(ARCH, ARCH_NAME, ...)                                                      \\\n  [&] {                                                                                          \\\n    constexpr static int ARCH_NAME = 90;                                                         \\\n    return __VA_ARGS__();                                                                        \\\n  }()\n#else\n  #define ARCH_SWITCH(ARCH, ARCH_NAME, ...)                                                      
\\\n  [&] {                                                                                          \\\n    if (ARCH == 86 || ARCH == 89) {                                                              \\\n      constexpr static int ARCH_NAME = 86;                                                       \\\n      return __VA_ARGS__();                                                                      \\\n    } else if (ARCH < 90) {                                                                      \\\n      constexpr static int ARCH_NAME = 80;                                                       \\\n      return __VA_ARGS__();                                                                      \\\n    } else {                                                                                     \\\n      constexpr static int ARCH_NAME = 90;                                                       \\\n      return __VA_ARGS__();                                                                      \\\n    }                                                                                            \\\n  }()\n#endif\n\n#ifndef FLASHATTENTION_ENABLE_VCOLMAJOR\n  #define VCOLMAJOR_SWITCH(COND, CONST_NAME, ...)                                                \\\n  [&] {                                                                                          \\\n    constexpr static bool CONST_NAME = false;                                                    \\\n    return __VA_ARGS__();                                                                        \\\n  }()\n#else\n  #define VCOLMAJOR_SWITCH BOOL_SWITCH\n#endif\n\n#define HEADDIM_SWITCH(HEADDIM, ...)                                                             
\\\n  [&] {                                                                                          \\\n    if (HEADDIM == 64) {                                                                         \\\n      constexpr static int kHeadSize = 64;                                                       \\\n      return __VA_ARGS__();                                                                      \\\n    } else if (HEADDIM == 96) {                                                                  \\\n      constexpr static int kHeadSize = 96;                                                       \\\n      return __VA_ARGS__();                                                                      \\\n    } else if (HEADDIM == 128) {                                                                 \\\n      constexpr static int kHeadSize = 128;                                                      \\\n      return __VA_ARGS__();                                                                      \\\n    } else if (HEADDIM == 256) {                                                                 \\\n      constexpr static int kHeadSize = 256;                                                      \\\n      return __VA_ARGS__();                                                                      \\\n    }                                                                                            \\\n  }()\n\n#define NUM_WARP_SWITCH(VALUE, CONST_NAME, ...)                                                  
\\\n  [&] {                                                                                          \\\n    if (VALUE <= 1) {                                                                            \\\n      constexpr static int CONST_NAME = 1;                                                       \\\n      return __VA_ARGS__();                                                                      \\\n    } else if (VALUE <= 2) {                                                                     \\\n      constexpr static int CONST_NAME = 2;                                                       \\\n      return __VA_ARGS__();                                                                      \\\n    } else if (VALUE <= 4) {                                                                     \\\n      constexpr static int CONST_NAME = 4;                                                       \\\n      return __VA_ARGS__();                                                                      \\\n    } else if (VALUE <= 8) {                                                                     \\\n      constexpr static int CONST_NAME = 8;                                                       \\\n      return __VA_ARGS__();                                                                      \\\n    } else if (VALUE <= 16) {                                                                    \\\n      constexpr static int CONST_NAME = 16;                                                      \\\n      return __VA_ARGS__();                                                                      \\\n    } else {                                                                                     \\\n      constexpr static int CONST_NAME = 32;                                                      \\\n      return __VA_ARGS__();                                                                      \\\n    }                                                                        
                    \\\n  }()\n"
  },
  {
    "path": "hopper/test_attn_kvcache.py",
    "content": "import pytest\nfrom einops import rearrange, repeat\nimport torch\nimport flash_attn\nimport flash_attn_interface\nimport itertools\nimport math\nimport time\n\ndef construct_local_mask(\n    seqlen_q,\n    seqlen_k,\n    window_size=(-1, -1),  # -1 means infinite window size\n    query_padding_mask=None,\n    key_padding_mask=None,\n    device=None,\n    key_leftpad=None,\n):\n    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n    if key_leftpad is not None:\n        key_leftpad = rearrange(key_leftpad, \"b -> b 1 1 1\")\n        col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)\n    sk = (\n        seqlen_k\n        if key_padding_mask is None\n        else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sq = (\n        seqlen_q\n        if query_padding_mask is None\n        else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    if window_size[0] < 0:\n        return col_idx > row_idx + sk - sq + window_size[1]\n    else:\n        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk\n        return torch.logical_or(\n            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),\n            col_idx < row_idx + sk - sq - window_size[0],\n        )\n\n\ndef attention_ref(\n    q,\n    k,\n    v,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    attn_bias=None,\n    dropout_p=0.0,\n    dropout_mask=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n    softcap=0.0,\n    upcast=True,\n    reorder_ops=False,\n    key_leftpad=None,\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, head_dim)\n        k: (batch_size, seqlen_k, nheads_k, head_dim)\n        v: (batch_size, seqlen_k, nheads_k, 
head_dim)\n        query_padding_mask: (batch_size, seqlen_q)\n        key_padding_mask: (batch_size, seqlen_k)\n        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)\n        dropout_p: float\n        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)\n        causal: whether to apply causal masking\n        window_size: (int, int), left and right window size\n        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast\n            output back to fp16/bf16.\n        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)\n            without changing the math. This is to estimate the numerical error from operation\n            reordering.\n    Output:\n        output: (batch_size, seqlen_q, nheads, head_dim)\n        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    dtype_og = q.dtype\n    if upcast:\n        q, k, v = q.float(), k.float(), v.float()\n    seqlen_q, seqlen_k = q.shape[1], k.shape[1]\n    k = repeat(k, \"b s h d -> b s (h g) d\", g=q.shape[2] // k.shape[2])\n    v = repeat(v, \"b s h d -> b s (h g) d\", g=q.shape[2] // v.shape[2])\n    d = q.shape[-1]\n    if not reorder_ops:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q / math.sqrt(d), k)\n    else:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q, k / math.sqrt(d))\n    if softcap > 0:\n        scores = scores / softcap\n        scores = scores.tanh()\n        scores = scores * softcap\n    if key_padding_mask is not None:\n        scores.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            query_padding_mask,\n            key_padding_mask,\n            q.device,\n            
key_leftpad=key_leftpad,\n        )\n        scores.masked_fill_(local_mask, float(\"-inf\"))\n    if attn_bias is not None:\n        scores = scores + attn_bias\n    attention = torch.softmax(scores, dim=-1).to(v.dtype)\n    # Some rows might be completely masked out so we fill them with zero instead of NaN\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)\n    # We want to mask here so that the attention matrix doesn't have any NaNs\n    # Otherwise we'll get NaN in dV\n    if query_padding_mask is not None:\n        attention = attention.masked_fill(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), 0.0)\n    dropout_scaling = 1.0 / (1 - dropout_p)\n    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling\n    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)\n    if dropout_mask is not None:\n        attention_drop = attention.masked_fill(~dropout_mask, 0.0)\n    else:\n        attention_drop = attention\n    output = torch.einsum(\"bhts,bshd->bthd\", attention_drop, v * dropout_scaling)\n    if query_padding_mask is not None:\n        output.masked_fill_(rearrange(~query_padding_mask, \"b s -> b s 1 1\"), 0.0)\n    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)\n\n\n@pytest.mark.parametrize(\"causal\", [True, False])\n@pytest.mark.parametrize(\"num_requests\", [1, 4])\n@pytest.mark.parametrize(\"query_seqlen\", [1, 8, 120])\n@pytest.mark.parametrize(\"context_seqlen\", [1024, 3131, 4224])\n@pytest.mark.parametrize(\"headdim\", [64, 128, 256])\n@pytest.mark.parametrize(\"gqa_parallel\", [False, True])\n@pytest.mark.parametrize(\n    \"nheads_kv, gqa_ratio\",\n    [\n        (1, 1),\n        (2, 5),\n        (3, 3),\n        (1, 32),\n        (5, 7),\n        (8, 1),\n        (1, 16),\n        (12, 4),\n        (8, 2),\n    ],\n)\ndef test_flash_attn_kvcache_nosplit(nheads_kv, gqa_ratio, num_requests, 
query_seqlen, context_seqlen, headdim, causal, gqa_parallel):\n    device = \"cuda\"\n    num_caches = num_requests\n    cache_seqlen = context_seqlen\n    nheads_q = nheads_kv * gqa_ratio\n\n    k_cache = torch.randn(\n        (num_caches, cache_seqlen, nheads_kv, headdim), device=\"cuda\", dtype=torch.bfloat16\n    )\n    v_cache = torch.randn(\n        (num_caches, cache_seqlen, nheads_kv, headdim), device=\"cuda\", dtype=torch.bfloat16\n    )\n    q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device=\"cuda\", dtype=torch.bfloat16)\n    # cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device=\"cuda\")[:num_requests]\n    cache_seqlens = torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device=\"cuda\")\n    torch.cuda.synchronize()\n\n    out_ref, _ = attention_ref(\n        q,\n        k_cache,\n        v_cache,\n        causal=causal,\n    )\n\n    out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache(\n                    q=q,\n                    k_cache=k_cache,\n                    v_cache=v_cache,\n                    cache_seqlens=cache_seqlens,\n                    # cache_batch_idx=cache_idxs,\n                    causal=causal,\n                    num_splits=1,\n                    return_softmax_lse=True,\n                    gqa_parallel=gqa_parallel\n                )\n\n\n    torch.cuda.synchronize()\n    assert ((out_ref - out_fa3).abs().max().item() <= 4e-3)\n    assert ((out_ref - out_fa3).abs().mean().item() <= 2e-4)\n\n\n@pytest.mark.parametrize(\"causal\", [True, False])\n@pytest.mark.parametrize(\"num_requests\", [1, 3])\n@pytest.mark.parametrize(\"query_seqlen\", [1, 8, 120])\n@pytest.mark.parametrize(\"context_seqlen\", [1600, 4000, 5555])\n@pytest.mark.parametrize(\"headdim\", [64, 128, 256])\n@pytest.mark.parametrize(\"gqa_parallel\", [True, False])\n@pytest.mark.parametrize(\n    \"nheads_kv, gqa_ratio\",\n    [\n        (1, 1),\n        (2, 5),\n        (3, 3),\n        (1, 
32),\n        (5, 7),\n        (8, 1),\n        (1, 16),\n        (12, 4),\n        (8, 2),\n    ],\n)\ndef test_flash_attn_kvcache_nosplit_fp8(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, gqa_parallel):\n    device = \"cuda\"\n    num_caches = num_requests\n    cache_seqlen = context_seqlen\n    nheads_q = nheads_kv * gqa_ratio\n\n    k_cache = torch.randn(\n        (num_caches, cache_seqlen, nheads_kv, headdim), device=\"cuda\", dtype=torch.bfloat16\n    )\n    v_cache = torch.randn(\n        (num_caches, cache_seqlen, nheads_kv, headdim), device=\"cuda\", dtype=torch.bfloat16\n    )\n    q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device=\"cuda\", dtype=torch.bfloat16)\n    q = q.to(torch.float8_e4m3fn)\n    k_cache = k_cache.to(torch.float8_e4m3fn)\n    v_cache = v_cache.to(torch.float8_e4m3fn)\n    # cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device=\"cuda\")[:num_requests]\n    cache_seqlens = torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device=\"cuda\")\n    torch.cuda.synchronize()\n\n    out_ref, _ = attention_ref(\n        q,\n        k_cache,\n        v_cache,\n        causal=causal,\n    )\n\n    descale_q = torch.tensor([1.0], dtype=torch.float32, device='cuda')\n    descale_k = torch.tensor([1.0], dtype=torch.float32, device='cuda')\n    descale_v = torch.tensor([1.0], dtype=torch.float32, device='cuda')\n    out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache(\n                    q=q,\n                    k_cache=k_cache,\n                    v_cache=v_cache,\n                    cache_seqlens=cache_seqlens,\n                    # cache_batch_idx=cache_idxs,\n                    causal=causal,\n                    num_splits=1,\n                    return_softmax_lse=True,\n                    gqa_parallel=gqa_parallel,\n                    descale_q=descale_q, descale_k=descale_k, descale_v=descale_v\n                )\n\n\n    
torch.cuda.synchronize()\n    assert ((out_ref - out_fa3).abs().max().item() <= 4e-2)\n    assert ((out_ref - out_fa3).abs().mean().item() <= 2e-3)\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"use_heuristic_only\", [True])\n# @pytest.mark.parametrize(\"use_heuristic_only\", [False])\n@pytest.mark.parametrize(\"causal\", [True, False])\n# @pytest.mark.parametrize(\"num_requests\", [1, 4, 16])\n@pytest.mark.parametrize(\"num_requests\", [1, 3])\n# @pytest.mark.parametrize(\"query_seqlen\", [1, 16, 32, 128])\n@pytest.mark.parametrize(\"query_seqlen\", [1, 8, 25])\n# @pytest.mark.parametrize(\"context_seqlen\", [4096, 16384, 65536])\n@pytest.mark.parametrize(\"context_seqlen\", [1600, 4000, 5555])\n@pytest.mark.parametrize(\"headdim\", [64, 128, 256])\n@pytest.mark.parametrize(\"cache_seqlen_rand\", [True, False])\n@pytest.mark.parametrize(\"gqa_parallel\", [True, False])\n@pytest.mark.parametrize(\n    \"nheads_kv, gqa_ratio\",\n    [\n        (1, 1),\n        (4, 1),\n        (2, 2),\n        (3, 3),\n        (4, 4),\n        (2, 5),\n        (3, 9),\n        (1, 16),\n        (1, 32),\n    ],\n)\ndef test_flash_attn_kvcache_output(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, use_heuristic_only, cache_seqlen_rand, gqa_parallel, dtype):\n    device = \"cuda\"\n    num_caches = 16\n    if context_seqlen <= 65536:\n        cache_seqlen = 65536\n    else:\n        cache_seqlen = context_seqlen\n    nheads_q = nheads_kv * gqa_ratio\n    if use_heuristic_only:\n        max_splits = 1\n    else:\n        max_splits = 128\n\n    k_cache = torch.randn(\n        (num_caches, cache_seqlen, nheads_kv, headdim), device=\"cuda\", dtype=torch.bfloat16\n    )\n    v_cache = torch.randn(\n        (num_caches, cache_seqlen, nheads_kv, headdim), device=\"cuda\", dtype=torch.bfloat16\n    )\n    q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device=\"cuda\", dtype=torch.bfloat16)\n\n    q = 
q.to(dtype)\n    k_cache = k_cache.to(dtype)\n    v_cache = v_cache.to(dtype)\n    cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device=\"cuda\")[:num_requests]\n    cache_seqlens = torch.randint(1, context_seqlen-1, (num_requests,), dtype=torch.int32).to(device) if cache_seqlen_rand else torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device=\"cuda\")\n    torch.cuda.synchronize()\n\n    out_ref, lse_ref = flash_attn_interface.flash_attn_with_kvcache(\n                    q=q,\n                    k_cache=k_cache,\n                    v_cache=v_cache,\n                    cache_seqlens=cache_seqlens,\n                    cache_batch_idx=cache_idxs,\n                    causal=causal,\n                    num_splits=1,\n                    return_softmax_lse=True,\n                    gqa_parallel=False\n                )\n\n    # i=0 case is with num splits heuristic\n    for i in range(0, max_splits+1):\n                out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache(\n                    q=q,\n                    k_cache=k_cache,\n                    v_cache=v_cache,\n                    cache_seqlens=cache_seqlens,\n                    cache_batch_idx=cache_idxs,\n                    causal=causal,\n                    num_splits=i,\n                    return_softmax_lse=True,\n                    gqa_parallel=gqa_parallel,\n                    max_seqlen_k_hint=context_seqlen\n                )\n\n                torch.cuda.synchronize()\n                print ('output-ref', i, out_ref)\n                print ('output-fa3',i, out_fa3)\n                print ('output-max-diff', i, context_seqlen, (out_ref - out_fa3).abs().max().item())\n                print ('output-mean-diff',i, context_seqlen, (out_ref - out_fa3).abs().mean().item())\n                print ('lse-max-diff',i, context_seqlen, (lse_ref - lse_fa3).abs().max().item())\n                print ('lse-mean-diff',i,  context_seqlen, (lse_ref - 
lse_fa3).abs().mean().item())\n\n                if cache_seqlen_rand:\n                    assert ((out_ref - out_fa3).abs().max().item() <= 1e-2)\n                    assert ((out_ref - out_fa3).abs().mean().item() <= 1e-3)\n                else:\n                    assert ((out_ref - out_fa3).abs().max().item() <= 2e-3)\n                    assert ((out_ref - out_fa3).abs().mean().item() <= 1e-4)\n                lse_max_ref = lse_ref.abs().max().item()\n                lse_mean_ref = lse_ref.abs().mean().item()\n                lse_max_fa3 = lse_fa3.abs().max().item()\n                lse_mean_fa3 = lse_fa3.abs().mean().item()\n                lse_max_diff = (lse_ref - lse_fa3).abs().max().item()\n                lse_mean_diff = (lse_ref - lse_fa3).abs().mean().item()\n                assert ((lse_max_ref == math.inf and lse_max_fa3 == math.inf) or lse_max_diff <= 1e-3)\n                assert ((lse_mean_ref == math.inf and lse_mean_fa3 == math.inf) or lse_mean_diff <= 1e-4)\n\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"use_heuristic_only\", [True])\n# @pytest.mark.parametrize(\"use_heuristic_only\", [False])\n@pytest.mark.parametrize(\"causal\", [True, False])\n# @pytest.mark.parametrize(\"num_requests\", [1, 4, 16])\n@pytest.mark.parametrize(\"num_requests\", [1, 3])\n# @pytest.mark.parametrize(\"query_seqlen\", [1, 16, 32, 128])\n@pytest.mark.parametrize(\"query_seqlen\", [1, 8, 25])\n# @pytest.mark.parametrize(\"context_seqlen\", [4096, 16384, 65536])\n@pytest.mark.parametrize(\"context_seqlen\", [1600, 4000, 5555])\n@pytest.mark.parametrize(\"headdim\", [64, 128, 256])\n@pytest.mark.parametrize(\"cache_seqlen_rand\", [True, False])\n@pytest.mark.parametrize(\"gqa_parallel\", [True, False])\n@pytest.mark.parametrize(\n    \"nheads_kv, gqa_ratio\",\n    [\n        (1, 1),\n        (4, 1),\n        (2, 2),\n        (3, 3),\n        (4, 4),\n        (2, 5),\n        (3, 9),\n        (1, 16),\n        (1, 32),\n  
  ],\n)\ndef test_flash_attn_kvcache_output_fp8(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, use_heuristic_only, cache_seqlen_rand, gqa_parallel, dtype):\n    device = \"cuda\"\n    num_caches = 16\n    if context_seqlen <= 65536:\n        cache_seqlen = 65536\n    else:\n        cache_seqlen = context_seqlen\n    nheads_q = nheads_kv * gqa_ratio\n    if use_heuristic_only:\n        max_splits = 1\n    else:\n        max_splits = 128\n\n    k_cache = torch.randn(\n        (num_caches, cache_seqlen, nheads_kv, headdim), device=\"cuda\", dtype=torch.bfloat16\n    )\n    v_cache = torch.randn(\n        (num_caches, cache_seqlen, nheads_kv, headdim), device=\"cuda\", dtype=torch.bfloat16\n    )\n    q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device=\"cuda\", dtype=torch.bfloat16)\n\n    q = q.to(dtype)\n    k_cache = k_cache.to(dtype)\n    v_cache = v_cache.to(dtype)\n    cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device=\"cuda\")[:num_requests]\n    cache_seqlens = torch.randint(1, context_seqlen-1, (num_requests,), dtype=torch.int32).to(device) if cache_seqlen_rand else torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device=\"cuda\")\n    torch.cuda.synchronize()\n\n\n    descale_q = torch.tensor([1.0], dtype=torch.float32, device='cuda')\n    descale_k = torch.tensor([1.0], dtype=torch.float32, device='cuda')\n    descale_v = torch.tensor([1.0], dtype=torch.float32, device='cuda')\n\n    out_ref, lse_ref = flash_attn_interface.flash_attn_with_kvcache(\n                    q=q,\n                    k_cache=k_cache,\n                    v_cache=v_cache,\n                    cache_seqlens=cache_seqlens,\n                    cache_batch_idx=cache_idxs,\n                    causal=causal,\n                    num_splits=1,\n                    return_softmax_lse=True,\n                    gqa_parallel=False,\n                    descale_q=descale_q, descale_k=descale_k, 
descale_v=descale_v\n                )\n\n    # i=0 case is with num splits heuristic\n    for i in range(0, max_splits+1):\n                out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache(\n                    q=q,\n                    k_cache=k_cache,\n                    v_cache=v_cache,\n                    cache_seqlens=cache_seqlens,\n                    cache_batch_idx=cache_idxs,\n                    causal=causal,\n                    num_splits=i,\n                    return_softmax_lse=True,\n                    gqa_parallel=gqa_parallel,\n                    max_seqlen_k_hint=context_seqlen,\n                    descale_q=descale_q, descale_k=descale_k, descale_v=descale_v\n                )\n\n                torch.cuda.synchronize()\n                print ('output-ref', i, out_ref)\n                print ('output-fa3',i, out_fa3)\n                print ('output-max-diff', i, context_seqlen, (out_ref - out_fa3).abs().max().item())\n                print ('output-mean-diff',i, context_seqlen, (out_ref - out_fa3).abs().mean().item())\n                print ('lse-max-diff',i, context_seqlen, (lse_ref - lse_fa3).abs().max().item())\n                print ('lse-mean-diff',i,  context_seqlen, (lse_ref - lse_fa3).abs().mean().item())\n\n                if cache_seqlen_rand:\n                    assert ((out_ref - out_fa3).abs().max().item() <= 1e-1)\n                    assert ((out_ref - out_fa3).abs().mean().item() <= 1e-2)\n                else:\n                    assert ((out_ref - out_fa3).abs().max().item() <= 2e-2)\n                    assert ((out_ref - out_fa3).abs().mean().item() <= 2e-3)\n                lse_max_ref = lse_ref.abs().max().item()\n                lse_mean_ref = lse_ref.abs().mean().item()\n                lse_max_fa3 = lse_fa3.abs().max().item()\n                lse_mean_fa3 = lse_fa3.abs().mean().item()\n                lse_max_diff = (lse_ref - lse_fa3).abs().max().item()\n                lse_mean_diff = (lse_ref 
- lse_fa3).abs().mean().item()\n                assert ((lse_max_ref == math.inf and lse_max_fa3 == math.inf) or lse_max_diff <= 1e-3)\n                assert ((lse_mean_ref == math.inf and lse_mean_fa3 == math.inf) or lse_mean_diff <= 1e-4)\n\n\nif __name__ == \"__main__\":\n    # No main() is defined in this module; run the tests through pytest instead.\n    raise SystemExit(pytest.main([__file__]))\n"
  },
  {
    "path": "hopper/test_flash_attn.py",
    "content": "import os\nimport math\nimport itertools\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom torch._C import parse_schema\nfrom torch.testing._internal.optests.generate_tests import (\n    safe_fake_check,\n    safe_schema_check,\n    safe_aot_autograd_check,\n)\n\nfrom einops import rearrange, repeat\ntry:\n    from flash_attn.layers.rotary import apply_rotary_emb\nexcept ImportError:\n    apply_rotary_emb = None\n\nfrom padding import pad_input, unpad_input\nfrom test_util import (\n    attention_ref,\n    generate_qkv,\n    generate_random_padding_mask,\n)\n\nfrom flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine\nfrom flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata\n\n\nDISABLE_BACKWARD = os.getenv(\"FLASH_ATTENTION_DISABLE_BACKWARD\", \"FALSE\") == \"TRUE\"\nDISABLE_SPLIT = os.getenv(\"FLASH_ATTENTION_DISABLE_SPLIT\", \"FALSE\") == \"TRUE\"\nDISABLE_PAGEDKV = os.getenv(\"FLASH_ATTENTION_DISABLE_PAGEDKV\", \"FALSE\") == \"TRUE\"\nDISABLE_APPENDKV = os.getenv(\"FLASH_ATTENTION_DISABLE_APPENDKV\", \"FALSE\") == \"TRUE\"\nDISABLE_LOCAL = os.getenv(\"FLASH_ATTENTION_DISABLE_LOCAL\", \"FALSE\") == \"TRUE\"\nDISABLE_SOFTCAP = os.getenv(\"FLASH_ATTENTION_DISABLE_SOFTCAP\", \"FALSE\") == \"TRUE\"\nDISABLE_PACKGQA = os.getenv(\"FLASH_ATTENTION_DISABLE_PACKGQA\", \"FALSE\") == \"TRUE\"\nDISABLE_FP16 = os.getenv(\"FLASH_ATTENTION_DISABLE_FP16\", \"FALSE\") == \"TRUE\"\nDISABLE_FP8 = os.getenv(\"FLASH_ATTENTION_DISABLE_FP8\", \"FALSE\") == \"TRUE\" or torch.cuda.get_device_capability(\"cuda\")[0] < 9\nDISABLE_HDIM64 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM64\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM96 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM96\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM128 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM128\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM192 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM192\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM256 = 
os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM256\", \"FALSE\") == \"TRUE\"\nENABLE_OPCHECK = os.getenv(\"FLASH_ATTENTION_ENABLE_OPCHECK\", \"FALSE\") == \"TRUE\"\nENABLE_AUTOGRAD_CHECK = os.getenv(\"FLASH_ATTENTION_ENABLE_AUTOGRAD_CHECK\", \"FALSE\") == \"TRUE\"\n\nCOMPILED_HDIMS = (\n    []\n    + ([64] if not DISABLE_HDIM64 else [])\n    + ([96] if not DISABLE_HDIM96 else [])\n    + ([128] if not DISABLE_HDIM128 else [])\n    + ([192] if not DISABLE_HDIM192 else [])\n    + ([256] if not DISABLE_HDIM256 else [])\n)\n\ndef should_test_backward(args, kwargs):\n    v = args[2]\n    num_splits = kwargs.get(\"num_splits\", 1)\n    dtype = v.dtype\n    has_qv = V_colmajor = False  # no test runs this with V_colmajor or has_qv == True\n    attention_chunk = kwargs.get(\"attention_chunk\")\n    dv = v.size(-1)\n\n    if (\n        ENABLE_AUTOGRAD_CHECK\n        and not DISABLE_BACKWARD\n        and dtype != torch.float8_e4m3fn\n        and not V_colmajor\n        and not has_qv\n        and not dv > 256\n        and not attention_chunk != 0\n        and num_splits > 0  # we don't support num_split == 0 on torch.compile yet\n    ):\n        return True\n    return False\n\n\ndef should_run_schema_check(args, kwargs):\n    v = args[2]\n    if v.dtype == torch.float8_e4m3fn:\n        return False\n    return True\n\n\ndef should_run_fake_check(args, kwargs):\n    if 'num_splits' in kwargs:\n        return kwargs['num_splits'] > 0\n    return True\n\n\ndef run_opcheck(fn):\n    def wrapper(*args, **kwargs):\n        if should_run_schema_check(args, kwargs):\n            safe_schema_check(fn, args, kwargs)\n\n        if should_run_fake_check(args, kwargs):\n            safe_fake_check(fn, args, kwargs)\n\n        if should_test_backward(args, kwargs):\n            # Expensive check\n            safe_aot_autograd_check(fn, args, kwargs, dynamic=False)\n            safe_aot_autograd_check(fn, args, kwargs, dynamic=True)\n        return fn(*args, **kwargs)\n    return wrapper\n\n\nif 
ENABLE_OPCHECK:\n    flash_attn_func = run_opcheck(flash_attn_func)\n    flash_attn_varlen_func = run_opcheck(flash_attn_varlen_func)\n\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n@pytest.mark.parametrize(\"has_qv\", [False, True])\n# @pytest.mark.parametrize(\"has_qv\", [True])\n# @pytest.mark.parametrize(\"deterministic\", [False, True])\n@pytest.mark.parametrize(\"deterministic\", [False])\n@pytest.mark.parametrize(\"softcap\", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))\n# @pytest.mark.parametrize(\"softcap\", [0.0])\n@pytest.mark.parametrize(\"local\", [False] + ([True] if not DISABLE_LOCAL else []))\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n# @pytest.mark.parametrize(\"V_colmajor\", [False, True])\n@pytest.mark.parametrize(\"V_colmajor\", [False])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64, 128, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [64, 96, 128, 192])\n@pytest.mark.parametrize(\"d\", COMPILED_HDIMS)\n# @pytest.mark.parametrize(\"d\", [64])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 1),\n        (64, 128),\n        (128, 192),\n        (256, 256),\n        (239, 1),\n   
     (799, 3),\n        (113, 203),\n        (113, 128),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (384, 256),\n        (640, 128),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (4096, 4096),\n        (4224, 4224),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])\ndef test_flash_attn_output(\n        seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype\n):\n    if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn):\n        pytest.skip(\"V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn\")\n    if has_qv and (d != 64 or dtype == torch.float8_e4m3fn):\n        pytest.skip(\"Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)\")\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    # batch_size = 40\n    # nheads = 16\n    batch_size = 9 if seqlen_k <= 2048 else 2\n    # batch_size = 1\n    nheads = 6\n    # nheads = 1\n    nheads_kv = nheads if mha_type == \"mha\" else (2 if mha_type == \"gqa\" else 1)\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    if dtype == torch.float8_e4m3fn:\n        dv_vals = [d]\n    if has_qv:\n        dv_vals = [256, 512]\n    attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        print(f\"{dv = }, {attention_chunk = }\")\n        q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)\n        if softcap > 0.0:\n            # Ensure the values of qk are at least within softcap range.\n            q_ref = (q_ref * softcap / 4)\n        q_ref = 
q_ref.to(dtype).to(dtype_ref).requires_grad_()\n        k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()\n        v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()\n        if has_qv:\n            qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n        else:\n            qv_ref = None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist()\n        # window_size = (-1, -1) if not local else (16, 0)\n        if dtype == torch.float8_e4m3fn:\n            q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]\n        else:\n            q_descale, k_descale, v_descale = None, None, None\n        q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]\n        qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None\n        if V_colmajor:\n            v = rearrange(rearrange(v.detach(), \"b s h d -> b h d s\").contiguous(), \"b h d s -> b s h d\").requires_grad_()\n        out_ref, attn_ref = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            None,\n            None,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            softcap=softcap\n        )\n        out_pt, attn_pt = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            None,\n            None,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n     
       window_size=window_size,\n            attention_chunk=attention_chunk,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n        )\n\n        # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float()\n        # if qv is not None:\n        #     qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float()\n        # m = qk.amax(-1, keepdim=True)\n        # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n        # exp_sum = s_tmp.sum(-1)\n        # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float())\n        # lse_ref = torch.logsumexp(qk, dim=-1)\n\n        # Numerical error if we just do any arithmetic on out_ref\n        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n        rtol = 2 if softcap == 0.0 else 3\n\n        print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n        print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n        pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]\n        num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]\n        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):\n            print(f\"{pack_gqa = }, {num_splits = }\")\n            out = flash_attn_func(\n                q,\n                k,\n                v,\n                causal=causal,\n                qv=qv,\n                q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n                window_size=window_size,\n                attention_chunk=attention_chunk,\n                softcap=softcap,\n                pack_gqa=pack_gqa,\n                num_splits=num_splits\n            )\n            print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n            print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n            # if not causal:\n            #     
print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n            # breakpoint()\n\n            # Check that FlashAttention's numerical error is at most twice the numerical error\n            # of a Pytorch implementation.\n            assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol\n\n        if (\n            not DISABLE_BACKWARD \n            and dtype != torch.float8_e4m3fn \n            and not V_colmajor \n            and not has_qv\n            and not dv > 256\n            and not attention_chunk != 0\n        ):\n            g = torch.randn_like(out)\n            do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)\n            # import flash_attn_3_cuda\n            # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd(\n            #     g,\n            #     q,\n            #     k,\n            #     v,\n            #     out,\n            #     lse,\n            #     None,\n            #     None,\n            #     None,\n            #     d ** (-0.5),\n            #     causal,\n            #     window_size[0], window_size[1],\n            #     softcap,\n            #     deterministic,\n            #     0,  # sm_margin\n            # )\n            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            # print(f\"dO_O max diff: {(softmax_d - do_o).abs().max().item()}\")\n            # assert (softmax_d - do_o).abs().max().item() <= 1e-5\n            # assert dq_accum.abs().max().item() == 0.0\n\n            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())\n            # P = torch.softmax(qk, -1)\n            # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1))\n            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())\n            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())\n            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())\n\n            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n        
    dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)\n            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)\n            print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n            print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n            print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n            print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n            print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n            print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n            print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n            print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n            print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n            print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n            print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n            print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n            # breakpoint()\n            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol\n            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol\n            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol\n\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 
else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n@pytest.mark.parametrize(\"has_qv\", [False, True])\n# @pytest.mark.parametrize(\"has_qv\", [False])\n# @pytest.mark.parametrize(\"deterministic\", [False, True])\n@pytest.mark.parametrize(\"deterministic\", [False])\n@pytest.mark.parametrize(\"softcap\", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))\n# @pytest.mark.parametrize(\"softcap\", [0.0])\n@pytest.mark.parametrize(\"local\", [False] + ([True] if not DISABLE_LOCAL else []))\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n@pytest.mark.parametrize(\"add_unused_qkv\", [False, True])\n# @pytest.mark.parametrize(\"add_unused_qkv\", [True])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [64, 96, 128])\n@pytest.mark.parametrize(\"d\", COMPILED_HDIMS)\n# @pytest.mark.parametrize(\"d\", [64])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 1),\n        (1, 3),\n        (2, 1),\n        (511, 1),\n        (3, 513),\n        (64, 128),\n        (128, 128),\n        (256, 256),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (307, 256),\n        (640, 128),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (1024, 1024),\n        (2048, 2048),\n        (4096, 4096),\n    ],\n)\ndef 
test_flash_attn_varlen_output(\n    seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype,\n):\n    if has_qv and (d != 64 or dtype == torch.float8_e4m3fn):\n        pytest.skip(\"Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)\")\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))\n    # batch_size = 40\n    # nheads = 16\n    batch_size = 9 if seqlen_q <= 2048 else 2\n    # batch_size = 32\n    nheads = 6\n    nheads_kv = nheads if mha_type == \"mha\" else (2 if mha_type == \"gqa\" else 1)\n    # batch_size = 2\n    # nheads = 1\n    # nheads_kv = nheads\n    \n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    if dtype == torch.float8_e4m3fn:\n        dv_vals = [d]\n    if has_qv:\n        dv_vals = [256, 512]\n    attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        print(f\"{dv = }, {attention_chunk = }\")\n        q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)\n        if softcap > 0.0:\n            # Ensure the values of qk are at least within softcap range.\n            q_ref = (q_ref * softcap / 4).detach().requires_grad_()\n        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()\n        k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()\n        v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()\n        if has_qv:\n            qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, 
dtype=dtype_ref).to(dtype).to(dtype_ref)\n        else:\n            qv_ref = None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n        if dtype == torch.float8_e4m3fn:\n            q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]\n        else:\n            q_descale, k_descale, v_descale = None, None, None\n        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]\n        qv = qv_ref.detach() if has_qv else None\n        query_padding_mask = generate_random_padding_mask(\n            seqlen_q, batch_size, device, mode=\"random\", zero_lengths=False\n        )\n        key_padding_mask = generate_random_padding_mask(\n            seqlen_k, batch_size, device, mode=\"random\", zero_lengths=True\n        )\n\n        def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):\n            if add_unused:\n                another_mask = generate_random_padding_mask(max_seq_len, bs, device)\n                attn_mask = torch.logical_and(padding_mask, another_mask)\n                unused_mask = torch.logical_xor(\n                    torch.logical_or(padding_mask, another_mask), attn_mask\n                )\n            else:\n                attn_mask = padding_mask\n                unused_mask = None\n            return attn_mask, unused_mask\n\n        query_padding_mask, query_unused_mask = _gen_unused_masks(\n            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device\n        )\n        key_padding_mask, key_unused_mask = _gen_unused_masks(\n            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device\n        )\n\n        (\n            q_unpad,\n            k_unpad,\n            v_unpad,\n            qv_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n 
           seqused_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q,\n            k,\n            v,\n            qv,\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False,\n                        query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask)\n        q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)]\n        out_ref, attn_ref = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            softcap=softcap\n        )\n        out_pt, attn_pt = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n        )\n\n\n        print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n        print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n        if query_unused_mask is not None:\n            q_zero_masking = rearrange(query_unused_mask, \"b s -> b s 1 1\")\n\n        # Numerical error if we just do any arithmetic on out_ref\n        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n        rtol = 2 if softcap == 0.0 else 3\n\n        pack_gqa_vals 
= [False, True] if not DISABLE_PACKGQA else [False]\n        # pack_gqa_vals = [False]\n        num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1]\n        # num_splits_vals = [1]\n        # print(\"cu_seqlens_q: \", cu_seqlens_q)\n        # print(\"cu_seqlens_k: \", cu_seqlens_k)\n        # print(\"seqused_q: \", seqused_q)\n        # print(\"seqused_k: \", seqused_k)\n        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):\n            print(f\"{pack_gqa = }, {num_splits = }\")\n            out_unpad = flash_attn_varlen_func(\n                q_unpad,\n                k_unpad,\n                v_unpad,\n                cu_seqlens_q,\n                cu_seqlens_k,\n                max_seqlen_q,\n                max_seqlen_k,\n                seqused_q=seqused_q,\n                seqused_k=seqused_k,\n                causal=causal,\n                qv=qv_unpad,\n                q_descale=q_descale,\n                k_descale=k_descale, v_descale=v_descale,\n                window_size=window_size,\n                attention_chunk=attention_chunk,\n                softcap=softcap,\n                pack_gqa=pack_gqa,\n                num_splits=num_splits,\n            )\n            out = output_pad_fn(out_unpad)\n            if query_unused_mask is not None:\n                out.masked_fill_(q_zero_masking, 0.0)\n            print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n            print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n            # if not causal:\n            #     print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n            # breakpoint()\n\n            # Check that FlashAttention's numerical error is at most rtol times (2x, or 3x with softcap)\n            # the numerical error of a PyTorch implementation, plus fwd_atol.\n            assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol\n\n\n        if (\n            not DISABLE_BACKWARD \n            and 
dtype != torch.float8_e4m3fn \n            and not has_qv\n            and not dv > 256\n            and not attention_chunk != 0\n        ):\n            g_unpad = torch.randn_like(out_unpad)\n            do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)\n            # import flash_attn_3_cuda\n            # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(\n            #     g_unpad,\n            #     q_unpad,\n            #     k_unpad,\n            #     v_unpad,\n            #     out_unpad,\n            #     lse,\n            #     None,\n            #     None,\n            #     None,\n            #     cu_seqlens_q,\n            #     cu_seqlens_k,\n            #     None, None,\n            #     max_seqlen_q,\n            #     max_seqlen_k,\n            #     d ** (-0.5),\n            #     causal,\n            #     window_size[0], window_size[1],\n            #     softcap,\n            #     deterministic,\n            #     0,  # sm_margin\n            # )\n            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad)\n            dq = dq_pad_fn(dq_unpad)\n            dk = dk_pad_fn(dk_unpad)\n            dv = dk_pad_fn(dv_unpad)\n            if key_unused_mask is not None:\n                k_zero_masking = rearrange(key_unused_mask, \"b s -> b s 1 1\")\n                dk.masked_fill_(k_zero_masking, 0.0)\n                dv.masked_fill_(k_zero_masking, 0.0)\n            if query_unused_mask is not None:\n                dq.masked_fill_(q_zero_masking, 0.0)\n            # print(f\"dO_O max diff: {(softmax_d - do_o).abs().max().item()}\")\n            # assert (softmax_d - do_o).abs().max().item() <= 1e-5\n            # assert dq_accum.abs().max().item() == 0.0\n            g = output_pad_fn(g_unpad)\n\n            # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()\n            # qk = torch.masked_fill(qk, 
rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())\n            # P = torch.softmax(qk, -1)\n            # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))\n            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())\n            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())\n            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())\n\n\n            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)\n            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)\n            print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n            print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n            print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n            print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n            print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n            print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n            print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n            print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n            print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n            print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n            print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n            print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n            # breakpoint()\n            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol\n            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - 
dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol\n            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol\n\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n@pytest.mark.parametrize(\"new_kv\", [False] + ([True] if not DISABLE_APPENDKV else []))\n# @pytest.mark.parametrize(\"new_kv\", [False])\n@pytest.mark.parametrize(\"causal,local\", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []))\n# @pytest.mark.parametrize(\"causal,local\", [(False, False), (True, False)])\n# @pytest.mark.parametrize(\"causal,local\", [(True, False)])\n@pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [True, False] if not DISABLE_APPENDKV else [True])\n# @pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [False])\n# @pytest.mark.parametrize(\"has_rotary_seqlens\", [False, True])\n@pytest.mark.parametrize(\"has_rotary_seqlens\", [False])\n@pytest.mark.parametrize(\"rotary_interleaved\", [False, True] if not DISABLE_APPENDKV else [False])\n# @pytest.mark.parametrize(\"rotary_interleaved\", [False])\n@pytest.mark.parametrize(\"rotary_fraction\", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0])\n# @pytest.mark.parametrize(\"rotary_fraction\", [0.0])\n@pytest.mark.parametrize(\"page_size\", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else []))\n# 
@pytest.mark.parametrize(\"page_size\", [None])\n@pytest.mark.parametrize(\"has_leftpad\", [False, True])\n# @pytest.mark.parametrize(\"has_leftpad\", [False])\n@pytest.mark.parametrize(\"has_batch_idx\", [False, True])\n# @pytest.mark.parametrize(\"has_batch_idx\", [True])\n@pytest.mark.parametrize(\"varlen_q\", [False, True])\n# @pytest.mark.parametrize(\"varlen_q\", [True])\n# @pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 128, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n@pytest.mark.parametrize(\"d\", [128])\n# @pytest.mark.parametrize(\"d\", [192])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 128),\n        (1, 339),\n        (3, 1024),\n        (64, 800),\n        (64, 256),\n        (3, 799),\n        (64, 2048),\n        (16, 20000),\n        # (1, 128 * 1024),\n        # (16, 128 * 1024),\n        (128, 128),\n        (256, 512),  # To test appending KV with more than 1 block\n        (2048, 3577),  # Enough tile to test persistent scheduler\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\ndef test_flash_attn_kvcache(\n    seqlen_q,\n    seqlen_k,\n    d,\n    varlen_q,\n    has_batch_idx,\n    has_leftpad,\n    page_size,\n    rotary_fraction,\n    rotary_interleaved,\n    has_rotary_seqlens,\n    seqlen_new_eq_seqlen_q,\n    causal,\n    local,\n    new_kv,\n    mha_type,\n    dtype,\n):\n    if page_size is not None and seqlen_k % page_size != 0:\n        pytest.skip()\n    if seqlen_q > seqlen_k and new_kv:\n        pytest.skip()\n    if not new_kv and rotary_fraction > 0.0:\n        pytest.skip()\n    if rotary_fraction == 0.0 and has_rotary_seqlens:\n        pytest.skip()\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 5\n    # batch_size = 1\n    batch_size_cache = batch_size if not has_batch_idx else 
batch_size * 2\n    nheads = 6\n    # nheads = 1\n    # rotary_dim must be a multiple of 16, and must be <= d\n    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 3)\n    assert nheads % nheads_k == 0\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    if dtype == torch.float8_e4m3fn:\n        dv_vals = [d]\n    attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        print(f\"{dv = }, {attention_chunk = }\")\n        has_qv = d == 64 and dv >= 256\n        q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n        if has_qv:\n            qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n        else:\n            qv = None\n        if varlen_q:\n            query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode=\"random\")\n            q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask)\n            output_pad_fn = lambda output_unpad: pad_input(\n                output_unpad, indices_q, batch_size, seqlen_q\n            )\n            qv_unpad = rearrange(qv, \"b s ... 
-> (b s) ...\")[indices_q] if has_qv else None\n        else:\n            query_padding_mask = None\n            q_unpad = q\n            qv_unpad = qv\n            cu_seqlens_q, max_seqlen_q = None, None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n\n        seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()\n        cu_seqlens_k_new = None\n        key_new_padding_mask = None\n        if new_kv:\n            k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n            v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n            if varlen_q:  # k & v are also varlen\n                key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode=\"random\")\n                k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask)\n                v_unpad, *rest = unpad_input(v, key_new_padding_mask)\n            else:\n                k_unpad, v_unpad = k, v\n        else:\n            k, v, k_unpad, v_unpad = None, None, None, None\n        if page_size is None:\n            k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n            v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n            page_table = None\n        else:\n            (\n                k_cache,\n                v_cache,\n                page_table,\n                k_cache_paged,\n                v_cache_paged,\n                num_blocks,\n            ) = _generate_block_kvcache(\n                seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref\n            )\n        
cache_seqlens = torch.randint(\n            0 if new_kv else 1,\n            # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough\n            (\n                (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)\n                if new_kv\n                else (seqlen_k + 1)\n            ),\n            (batch_size,),\n            dtype=torch.int32,\n            device=device,\n        )\n        if has_leftpad:\n            cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)\n                                    if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)\n                                    for i in range(batch_size)])\n        else:\n            cache_leftpad = None\n        if has_batch_idx:\n            cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[\n                :batch_size\n            ]\n        else:\n            cache_batch_idx = None\n        arange = rearrange(torch.arange(seqlen_k, device=device), \"s -> 1 s\")\n        cache_seqlens_expanded = rearrange(cache_seqlens, \"b -> b 1\")\n        if not new_kv:\n            key_padding_mask = arange < cache_seqlens_expanded\n        else:\n            k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new\n            key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens\n        if has_leftpad:\n            key_padding_mask = torch.logical_and(\n                key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)\n            )\n        # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)\n        rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2\n        if rotary_dim > 0:\n            angle = (\n                torch.rand(\n                    seqlen_k if page_size is None else 
num_blocks * page_size,\n                    rotary_dim // 2,\n                    device=device,\n                )\n                * 2\n                * math.pi\n            )\n            cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)\n            sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)\n            if causal or local:\n                q_ro = apply_rotary_emb(\n                    q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved\n                )\n            else:\n                q_ro = rearrange(\n                    apply_rotary_emb(\n                        rearrange(q, \"b s h d -> b 1 (s h) d\"),\n                        cos,\n                        sin,\n                        seqlen_offsets=rotary_seqlens,\n                        interleaved=rotary_interleaved,\n                    ),\n                    \"b 1 (s h) d -> b s h d\",\n                    s=seqlen_q,\n                )\n            # q_ro = q\n            k_ro = apply_rotary_emb(\n                k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved\n            )\n        else:\n            cos, sin = None, None\n            q_ro, k_ro = q, k\n        # k_cache[:, 64:] = -1\n        k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()\n        v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()\n        if new_kv:\n            update_mask = torch.logical_and(\n                cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens\n            )\n            k_to_update = rearrange(k_ro, \"b s ... -> (b s) ...\")\n            v_to_update = rearrange(v, \"b s ... 
-> (b s) ...\")\n            if varlen_q:\n                k_to_update = k_to_update[indices_k]\n                v_to_update = v_to_update[indices_k]\n            k_cache_ref[update_mask] = k_to_update\n            v_cache_ref[update_mask] = v_to_update\n        k_cache_rep = repeat(k_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n        v_cache_rep = repeat(v_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n        out_ref, _ = attention_ref(\n            q_ro,\n            k_cache_rep,\n            v_cache_rep,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            key_leftpad=cache_leftpad,\n        )\n        out_pt, _ = attention_ref(\n            q_ro,\n            k_cache_rep,\n            v_cache_rep,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            upcast=False,\n            reorder_ops=True,\n            key_leftpad=cache_leftpad,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None\n        )\n        q = q.to(dtype)\n        q_unpad = q_unpad.to(dtype) if varlen_q else None\n        k_cache = k_cache.to(dtype)\n        v_cache = v_cache.to(dtype)\n        k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None\n        v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None\n        k = k.to(dtype) if k is not None else None\n        v = v.to(dtype) if v is not None else None\n        k_unpad = k_unpad.to(dtype) if k_unpad is not None else None\n        v_unpad = v_unpad.to(dtype) if v_unpad is not None else None\n        qv = qv.to(dtype) if qv is not None else None\n        qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None\n     
   cos = cos.to(dtype) if cos is not None else None\n        sin = sin.to(dtype) if sin is not None else None\n        k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()\n        v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()\n        num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1]\n        precompute_metadata_vals = [False, True]\n        for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals):\n            print(f\"{num_splits = }, {precompute_metadata = }\")\n            if precompute_metadata:\n                scheduler_metadata = get_scheduler_metadata(\n                    batch_size,\n                    max_seqlen_q if varlen_q else seqlen_q,\n                    seqlen_k if page_size is None else page_table.shape[1] * page_size,\n                    nheads, nheads_k, d,\n                    cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,\n                    max_seqlen_k_new=seqlen_new, page_size=page_size,\n                    causal=causal, window_size=window_size, attention_chunk=attention_chunk,\n                    num_splits=num_splits,\n                )\n            else:\n                scheduler_metadata = None\n            # Repeat to test metadata reuse\n            for _ in range(1 if not precompute_metadata else 2):\n                if page_size is None:\n                    k_cache.copy_(k_cache_saved)\n                    v_cache.copy_(v_cache_saved)\n                else:\n                    k_cache_paged.copy_(k_cache_saved)\n                    v_cache_paged.copy_(v_cache_saved)\n                out, lse, *rest = flash_attn_with_kvcache(\n                    q if not varlen_q else q_unpad,\n                    k_cache if page_size is None else k_cache_paged,\n                    v_cache if page_size is None else 
v_cache_paged,\n                    k if not new_kv or not varlen_q else k_unpad,\n                    v if not new_kv or not varlen_q else v_unpad,\n                    qv=qv if not varlen_q else qv_unpad,\n                    rotary_cos=cos,\n                    rotary_sin=sin,\n                    cache_seqlens=cache_seqlens,\n                    cache_batch_idx=cache_batch_idx,\n                    cache_leftpad=cache_leftpad,\n                    page_table=page_table,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k_new=cu_seqlens_k_new,\n                    max_seqlen_q=max_seqlen_q,\n                    rotary_seqlens=rotary_seqlens,\n                    causal=causal,\n                    window_size=window_size,\n                    attention_chunk=attention_chunk,\n                    rotary_interleaved=rotary_interleaved,\n                    scheduler_metadata=scheduler_metadata,\n                    num_splits=num_splits,\n                    return_softmax_lse=True,\n                )\n                if varlen_q:\n                    out = output_pad_fn(out)\n                # out = flash_attn_with_kvcache(\n                #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size\n                # )\n                # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)\n                # qk = torch.einsum(\"bqhd,bkhd->bhqk\", q, k_cache_ref)\n                # m = qk.amax(-1, keepdim=True)\n                # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n                # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)\n                # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)\n                # probs = torch.softmax(qk, dim=-1)\n                print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n                print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n                
print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n                print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n                # breakpoint()\n\n                # If new KV was appended, first check the updated cache against the reference. The output\n                # check that follows bounds FlashAttention's max error by 2x (4x for FP8) the numerical error\n                # of a PyTorch implementation.\n                if new_kv:\n                    if page_size is None:\n                        k_cache_select = (\n                            k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx]\n                        )\n                        v_cache_select = (\n                            v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx]\n                        )\n                    else:\n                        k_cache_select = rearrange(\n                            k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()],\n                            \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n                            b=batch_size,\n                        )[:, :seqlen_k].to(dtype_ref)\n                        v_cache_select = rearrange(\n                            v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()],\n                            \"(b nblocks) block_size ... 
-> b (nblocks block_size) ...\",\n                            b=batch_size,\n                        )[:, :seqlen_k].to(dtype_ref)\n                    k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)\n                    v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)\n                    if dtype is not torch.float8_e4m3fn:\n                        assert torch.equal(v_cache_select, v_cache_ref)\n                    else:\n                        assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3)\n                    # breakpoint()\n                    # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:\n                    if rotary_dim == 0:\n                        assert torch.equal(k_cache_select, k_cache_ref)\n                    else:\n                        # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):\n                        #     breakpoint()\n                        if dtype is not torch.float8_e4m3fn:\n                            assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)\n                        else:\n                            assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1)\n                mult = 4 if dtype == torch.float8_e4m3fn else 2\n                assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5\n                mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5\n                assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item()\n\n\ndef _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref):\n    num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3\n    k_cache_paged = torch.randn(\n        num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref\n    ).to(dtype).to(dtype_ref)\n    v_cache_paged = torch.randn(\n        num_blocks, page_size, nheads_k, dv, 
device=device, dtype=dtype_ref\n    ).to(dtype).to(dtype_ref)\n    page_table = rearrange(\n        torch.randperm(num_blocks, dtype=torch.int32, device=device),\n        \"(b nblocks) -> b nblocks\",\n        b=batch_size,\n    )\n    k_cache = rearrange(\n        k_cache_paged[page_table.flatten()],\n        \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    v_cache = rearrange(\n        v_cache_paged[page_table.flatten()],\n        \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [False])\n@pytest.mark.parametrize('d', [128])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (64, 8192),\n    ],\n)\ndef test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype):\n    device = \"cuda\"\n    torch.random.manual_seed(0)\n    batch_size = 2\n    nheads = 16\n    nheads_kv = 4\n    # There was a bug where this would cause \"unspecified launch failure\" due to Cluster\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)\n    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)\n    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)\n    for _ in range(100):\n        flash_attn_func(q, k, v, causal=causal)\n\n\n# @pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [False])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 
128])\n# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [80])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (239, 1),\n        (3, 799),\n        (799, 3),\n        (1024, 128),\n        (97, 97),\n        (128, 128),\n        (200, 200),\n        (256, 256),\n        (257, 257),\n        (384, 384),\n        (512, 512),\n        (768, 768),\n        (1024, 1024),\n        (2048, 2048),\n    ],\n)\ndef test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    # Simulate under memory load\n    dummy = torch.empty(70 * 1024 ** 3, dtype=torch.uint8, device=device)\n    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger\n    nheads = 4\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    torch.random.manual_seed(42)\n    out0 = flash_attn_func(q, k, v, causal=causal)\n    g = torch.randn_like(out0)\n    dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)\n    # Numerical error if we just do any arithmetic on dq\n    dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()\n\n    for i in range(1000):\n        torch.random.manual_seed(42)\n        out = flash_attn_func(q, k, v, causal=causal)\n        assert torch.equal(out, out0)\n        # assert torch.equal(lse, lse0)\n\n        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n        dq_equal = torch.allclose(dq, dq0, atol=dq_atol)\n        if not dq_equal:\n            print(f\"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}\")\n            # breakpoint()\n   
     assert torch.equal(dv, dv0)\n        assert torch.equal(dk, dk0)\n        assert dq_equal\n\n\ndef attention_combine_ref(out_partial, lse_partial):\n    \"\"\"\n    out_partial: (num_splits, batch_size, seqlen, nheads, d)\n    lse_partial: (num_splits, batch_size, nheads, seqlen)\n    \"\"\"\n    lse = torch.logsumexp(lse_partial, dim=0)\n    scale = torch.exp(lse_partial - lse)\n    scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale)\n    out = (scale.unsqueeze(-1) * out_partial).sum(0)\n    return out, lse\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16, torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float32])\n# @pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n@pytest.mark.parametrize(\"d\", [64, 96, 128, 192, 256, 512])\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\"seqlen\", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024])\n# @pytest.mark.parametrize(\"seqlen\", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192])\n# @pytest.mark.parametrize(\"seqlen\", [15])\n@pytest.mark.parametrize(\"num_splits\", [1, 2, 3, 5, 17, 32, 55, 97, 133])\n# @pytest.mark.parametrize(\"num_splits\", [1, 2, 3, 5, 11])\n# @pytest.mark.parametrize(\"num_splits\", [128])\ndef test_flash_attn_combine(num_splits, seqlen, d, dtype):\n    if DISABLE_SPLIT:\n        pytest.skip()\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(1)\n    batch_size = 5\n    nheads = 16\n    # batch_size = 1\n    # nheads = 1\n    out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits]  # To test non-contiguous tensor\n    lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads]  # To test non-contiguous tensor\n    # To test short-circuiting based on num_splits\n    
lse_partial[num_splits // 2:, :batch_size // 3] = -float(\"inf\")\n    out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype)\n    out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial)\n    out_pt = out_ref.to(dtype)\n\n    print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n    print(f\"LSE mean diff: {(lse - lse_ref).abs().mean().item()}\")\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n    # breakpoint()\n\n    assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5)\n    multiple = 2\n    assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5)\n\n    # from flash_attn.utils.benchmark import pytorch_profiler\n    # # pytorch_profiler(torch.sum, lse_partial)\n    # pytorch_profiler(flash_attn_combine, out_partial, lse_partial)\n    # pytorch_profiler(torch.sum, out_partial)\n\ndef test_flash3_bw_compatibility() -> None:\n    # Let's try to always stay backward compatible! This will make life easier\n    # for downstream libraries, users, and exported models.\n    # 1/ Instead of removing arguments, error out if their value is no longer supported\n    # 2/ When adding arguments, add them at the end with a default value\n    assert torch.ops.flash_attn_3.fwd.default._schema.is_backward_compatible_with(parse_schema(\n        \"flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, \"\n        \"Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, \"\n        \"Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, \"\n        \"Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, \"\n        \"int? max_seqlen_q=None, int? 
max_seqlen_k=None, Tensor? page_table=None, \"\n        \"Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, \"\n        \"Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? v_descale=None, \"\n        \"float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, \"\n        \"int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, \"\n        \"Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) \"\n        \"-> (Tensor(out!), Tensor, Tensor, Tensor)\"\n    ))\n    assert torch.ops.flash_attn_3.bwd.default._schema.is_backward_compatible_with(parse_schema(\n        \"flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, \"\n        \"Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, \"\n        \"Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, \"\n        \"int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, \"\n        \"int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) \"\n        \"-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)\"\n    ))\n    assert torch.ops.flash_attn_3.fwd_combine.default._schema.is_backward_compatible_with(parse_schema(\n        \"flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, \"\n        \"ScalarType? 
out_dtype=None) -> (Tensor(out!), Tensor)\"\n    ))\n    assert torch.ops.flash_attn_3.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema(\n        \"flash_attn_3::get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, \"\n        \"int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, \"\n        \"Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, \"\n        \"Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, \"\n        \"bool is_causal=False, int window_size_left=-1, int window_size_right=-1, \"\n        \"int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? pack_gqa=None, \"\n        \"int sm_margin=0) -> Tensor\"\n    ))\n"
  },
  {
    "path": "hopper/test_flash_attn_bwd_determinism.py",
    "content": "import os\nimport math\nimport itertools\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom torch._C import parse_schema\n\nfrom einops import rearrange, repeat\ntry:\n    from flash_attn.layers.rotary import apply_rotary_emb\nexcept ImportError:\n    apply_rotary_emb = None\n\nfrom padding import pad_input, unpad_input\nfrom test_util import (\n    attention_ref,\n    generate_qkv,\n    generate_random_padding_mask,\n)\n\nfrom flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine\nfrom flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata\n\nfrom flash_attn_interface import _flash_attn_backward\n\n\nDISABLE_BACKWARD = os.getenv(\"FLASH_ATTENTION_DISABLE_BACKWARD\", \"FALSE\") == \"TRUE\"\nDISABLE_SPLIT = os.getenv(\"FLASH_ATTENTION_DISABLE_SPLIT\", \"FALSE\") == \"TRUE\"\nDISABLE_PAGEDKV = os.getenv(\"FLASH_ATTENTION_DISABLE_PAGEDKV\", \"FALSE\") == \"TRUE\"\nDISABLE_APPENDKV = os.getenv(\"FLASH_ATTENTION_DISABLE_APPENDKV\", \"FALSE\") == \"TRUE\"\nDISABLE_LOCAL = os.getenv(\"FLASH_ATTENTION_DISABLE_LOCAL\", \"FALSE\") == \"TRUE\"\nDISABLE_SOFTCAP = os.getenv(\"FLASH_ATTENTION_DISABLE_SOFTCAP\", \"FALSE\") == \"TRUE\"\nDISABLE_PACKGQA = os.getenv(\"FLASH_ATTENTION_DISABLE_PACKGQA\", \"FALSE\") == \"TRUE\"\nDISABLE_FP16 = os.getenv(\"FLASH_ATTENTION_DISABLE_FP16\", \"FALSE\") == \"TRUE\"\nDISABLE_FP8 = os.getenv(\"FLASH_ATTENTION_DISABLE_FP8\", \"FALSE\") == \"TRUE\" or torch.cuda.get_device_capability(\"cuda\")[0] < 9\nDISABLE_HDIM64 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM64\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM96 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM96\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM128 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM128\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM192 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM192\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM256 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM256\", \"FALSE\") == \"TRUE\"\n\n# deterministic mode 
not supported for hdim 256\nDISABLE_HDIM256 = True\n\nCOMPILED_HDIMS = (\n    []\n    + ([64] if not DISABLE_HDIM64 else [])\n    + ([96] if not DISABLE_HDIM96 else [])\n    + ([128] if not DISABLE_HDIM128 else [])\n    + ([192] if not DISABLE_HDIM192 else [])\n    + ([256] if not DISABLE_HDIM256 else [])\n)\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mqa\"])\n# @pytest.mark.parametrize(\"has_qv\", [False, True])\n@pytest.mark.parametrize(\"has_qv\", [False])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n# @pytest.mark.parametrize(\"deterministic\", [True])\n@pytest.mark.parametrize(\"softcap\", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))\n# @pytest.mark.parametrize(\"softcap\", [0.0])\n@pytest.mark.parametrize(\"local\", [False] + ([True] if not DISABLE_LOCAL else []))\n# @pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [False])\n# @pytest.mark.parametrize(\"V_colmajor\", [False, True])\n@pytest.mark.parametrize(\"V_colmajor\", [False])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64, 128, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [64, 96, 128, 192])\n@pytest.mark.parametrize(\"d\", COMPILED_HDIMS)\n# 
@pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 1),\n        (64, 128),\n        (128, 192),\n        (256, 256),\n        (239, 1),\n        (799, 3),\n        (113, 203),\n        (113, 128),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (384, 256),\n        (640, 128),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (4096, 4096),\n        # (4224, 4224),\n        # (8192, 8192),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])\ndef test_flash_attn_output(\n        seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype\n):\n    if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn):\n        pytest.skip(\"V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn\")\n    if has_qv and (d != 64 or dtype == torch.float8_e4m3fn):\n        pytest.skip(\"Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)\")\n    if deterministic and d == 256:\n        pytest.skip(\"Deterministic mode not supported for hdim 256\")\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    # batch_size = 40\n    # nheads = 16\n    batch_size = 9 if seqlen_k <= 2048 else 2\n    # batch_size = 1\n    nheads = 6\n    # nheads = 1\n    nheads_kv = nheads if mha_type == \"mha\" else (2 if mha_type == \"gqa\" else 1)\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    # if dtype == torch.float8_e4m3fn:\n    #     dv_vals = [d]\n    # if has_qv:\n    #     dv_vals = [256, 512]\n    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0]\n    dv_vals = [d]\n    attention_chunk_vals = [0]\n    for dv, attention_chunk in 
itertools.product(dv_vals, attention_chunk_vals):\n        print(f\"{dv = }, {attention_chunk = }\")\n        q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)\n        if softcap > 0.0:\n            # Ensure the values of qk are at least within softcap range.\n            q_ref = (q_ref * softcap / 4)\n        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()\n        k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()\n        v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()\n        if has_qv:\n            qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n        else:\n            qv_ref = None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist()\n        # window_size = (-1, -1) if not local else (16, 0)\n        if dtype == torch.float8_e4m3fn:\n            q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]\n        else:\n            q_descale, k_descale, v_descale = None, None, None\n        q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]\n        qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None\n        if V_colmajor:\n            v = rearrange(rearrange(v.detach(), \"b s h d -> b h d s\").contiguous(), \"b h d s -> b s h d\").requires_grad_()\n        out_ref, attn_ref = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            None,\n            None,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n            window_size=window_size,\n   
         attention_chunk=attention_chunk,\n            softcap=softcap\n        )\n        out_pt, attn_pt = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            None,\n            None,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n        )\n\n        # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float()\n        # if qv is not None:\n        #     qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float()\n        # m = qk.amax(-1, keepdim=True)\n        # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n        # exp_sum = s_tmp.sum(-1)\n        # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float())\n        # lse_ref = torch.logsumexp(qk, dim=-1)\n\n        # Numerical error if we just do any arithmetic on out_ref\n        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n        rtol = 2 if softcap == 0.0 else 3\n\n        print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n        print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n        # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]\n        # num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]\n        pack_gqa_vals = [False]\n        num_splits_vals = [1]\n        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):\n            print(f\"{pack_gqa = }, {num_splits = }\")\n            out, softmax_lse = flash_attn_func(\n                q,\n                k,\n                v,\n                causal=causal,\n                qv=qv,\n                q_descale=q_descale, k_descale=k_descale, 
v_descale=v_descale,\n                window_size=window_size,\n                attention_chunk=attention_chunk,\n                softcap=softcap,\n                pack_gqa=pack_gqa,\n                num_splits=num_splits,\n                return_attn_probs=True,\n            )\n            print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n            print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n            # if not causal:\n            #     print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n            # breakpoint()\n\n            # Check that FlashAttention's numerical error is at most twice the numerical error\n            # of a Pytorch implementation.\n            assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol\n\n        if (\n            not DISABLE_BACKWARD \n            and dtype != torch.float8_e4m3fn \n            and not V_colmajor \n            and not has_qv\n            and not dv > 256\n            and not attention_chunk != 0\n        ):\n            g = torch.randn_like(out)\n            do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)\n            dq = torch.empty_like(q)\n            dk = torch.empty_like(k)\n            dv = torch.empty_like(v)\n            dq, dk, dv, softmax_d = _flash_attn_backward(\n                g,\n                q,\n                k,\n                v,\n                out,\n                softmax_lse,\n                None, None, # cu_seqlens_q, cu_seqlens_k,\n                None, None, # seqused_q, seqused_k,\n                None, None, # max_seqlen_q, max_seqlen_k,\n                dq,\n                dk,\n                dv,\n                d ** (-0.5),\n                causal,\n                window_size=window_size,\n                softcap=softcap,\n                deterministic=deterministic,\n            )\n            # print(f\"dO_O max diff: {(softmax_d - 
do_o).abs().max().item()}\")\n            # assert (softmax_d - do_o).abs().max().item() <= 1e-5\n            # assert dq_accum.abs().max().item() == 0.0\n\n            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())\n            # P = torch.softmax(qk, -1)\n            # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1))\n            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())\n            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())\n            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())\n\n            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)\n            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)\n            print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n            print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n            print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n            print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n            print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n            print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n            print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n            print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n            print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n            print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n            print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n            print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n            # breakpoint()\n            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol\n            
dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol\n            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol\n\n            if deterministic:\n                iterations = 1000\n\n                for i in range(iterations):\n                    dq2 = torch.empty_like(dq)\n                    dk2 = torch.empty_like(dk)\n                    dv2 = torch.empty_like(dv)\n                    dq2, dk2, dv2, softmax_d = _flash_attn_backward(\n                        g,\n                        q,\n                        k,\n                        v,\n                        out,\n                        softmax_lse,\n                        None, None, # cu_seqlens_q, cu_seqlens_k,\n                        None, None, # seqused_q, seqused_k,\n                        None, None, # max_seqlen_q, max_seqlen_k,\n                        dq2,\n                        dk2,\n                        dv2,\n                        d ** (-0.5),\n                        causal,\n                        window_size=window_size,\n                        softcap=softcap,\n                        deterministic=deterministic,\n                    )\n                    print(f'dq max diff with myself: {(dq2 - dq).abs().max().item()}')\n                    print(f'dk max diff with myself: {(dk2 - dk).abs().max().item()}')\n                    print(f'dv max diff with myself: {(dv2 - dv).abs().max().item()}')\n                    assert torch.equal(dq, dq2), f\"dq not deterministic\"\n                    assert torch.equal(dk, dk2), f\"dk not deterministic\"\n                    assert torch.equal(dv, dv2), f\"dv not deterministic\"\n                    print(f\"✅ Iteration {i} 
passed!\")\n\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n# @pytest.mark.parametrize(\"has_qv\", [False, True])\n@pytest.mark.parametrize(\"has_qv\", [False])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n# @pytest.mark.parametrize(\"deterministic\", [True])\n@pytest.mark.parametrize(\"softcap\", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))\n# @pytest.mark.parametrize(\"softcap\", [0.0])\n@pytest.mark.parametrize(\"local\", [False] + ([True] if not DISABLE_LOCAL else []))\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [False])\n@pytest.mark.parametrize(\"add_unused_qkv\", [False, True])\n# @pytest.mark.parametrize(\"add_unused_qkv\", [True])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [64, 96, 128])\n@pytest.mark.parametrize(\"d\", COMPILED_HDIMS)\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 1),\n        (1, 3),\n        (2, 1),\n        (511, 1),\n        (3, 513),\n        (64, 128),\n        (128, 128),\n        (256, 256),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (307, 
256),\n        (640, 128),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (1024, 1024),\n        (2048, 2048),\n        (4096, 4096),\n    ],\n)\ndef test_flash_attn_varlen_output(\n    seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype,\n):\n    if has_qv and (d != 64 or dtype == torch.float8_e4m3fn):\n        pytest.skip(\"Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)\")\n    if deterministic and d == 256:\n        pytest.skip(\"Deterministic mode not supported for hdim 256\")\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))\n    # batch_size = 40\n    # nheads = 16\n    batch_size = 9 if seqlen_q <= 2048 else 2\n    # batch_size = 32\n    nheads = 6\n    nheads_kv = nheads if mha_type == \"mha\" else (2 if mha_type == \"gqa\" else 1)\n    # batch_size = 2\n    # nheads = 1\n    # nheads_kv = nheads\n    \n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    # if dtype == torch.float8_e4m3fn:\n    #     dv_vals = [d]\n    # if has_qv:\n    #     dv_vals = [256, 512]\n    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0]\n    dv_vals = [d]\n    attention_chunk_vals = [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        print(f\"{dv = }, {attention_chunk = }\")\n        q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)\n        if softcap > 0.0:\n            # Ensure the values of qk are at least within softcap range.\n            q_ref = (q_ref * softcap / 4).detach().requires_grad_()\n        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()\n        k_ref = torch.randn(batch_size, 
seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()\n        v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()\n        if has_qv:\n            qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n        else:\n            qv_ref = None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n        if dtype == torch.float8_e4m3fn:\n            q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]\n        else:\n            q_descale, k_descale, v_descale = None, None, None\n        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]\n        qv = qv_ref.detach() if has_qv else None\n        query_padding_mask = generate_random_padding_mask(\n            seqlen_q, batch_size, device, mode=\"random\", zero_lengths=False\n        )\n        key_padding_mask = generate_random_padding_mask(\n            seqlen_k, batch_size, device, mode=\"random\", zero_lengths=True\n        )\n\n        def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):\n            if add_unused:\n                another_mask = generate_random_padding_mask(max_seq_len, bs, device)\n                attn_mask = torch.logical_and(padding_mask, another_mask)\n                unused_mask = torch.logical_xor(\n                    torch.logical_or(padding_mask, another_mask), attn_mask\n                )\n            else:\n                attn_mask = padding_mask\n                unused_mask = None\n            return attn_mask, unused_mask\n\n        query_padding_mask, query_unused_mask = _gen_unused_masks(\n            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device\n      
  )\n        key_padding_mask, key_unused_mask = _gen_unused_masks(\n            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device\n        )\n\n        (\n            q_unpad,\n            k_unpad,\n            v_unpad,\n            qv_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q,\n            k,\n            v,\n            qv,\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False,\n                        query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask)\n        q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)]\n        out_ref, attn_ref = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            softcap=softcap\n        )\n        out_pt, attn_pt = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n        )\n\n\n        print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n        print(f\"Pytorch mean diff: {(out_pt - 
out_ref).abs().mean().item()}\")\n\n        if query_unused_mask is not None:\n            q_zero_masking = rearrange(query_unused_mask, \"b s -> b s 1 1\")\n\n        # Numerical error if we just do any arithmetic on out_ref\n        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n        rtol = 2 if softcap == 0.0 else 3\n\n        # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]\n        # num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1]\n        pack_gqa_vals = [False]\n        num_splits_vals = [1]\n        print(\"cu_seqlens_q: \", cu_seqlens_q)\n        print(\"cu_seqlens_k: \", cu_seqlens_k)\n        print(\"seqused_q: \", seqused_q)\n        print(\"seqused_k: \", seqused_k)\n        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):\n            print(f\"{pack_gqa = }, {num_splits = }\")\n            out_unpad, softmax_lse = flash_attn_varlen_func(\n                q_unpad,\n                k_unpad,\n                v_unpad,\n                cu_seqlens_q,\n                cu_seqlens_k,\n                max_seqlen_q,\n                max_seqlen_k,\n                seqused_q=seqused_q,\n                seqused_k=seqused_k,\n                causal=causal,\n                qv=qv_unpad,\n                q_descale=q_descale,\n                k_descale=k_descale, v_descale=v_descale,\n                window_size=window_size,\n                attention_chunk=attention_chunk,\n                softcap=softcap,\n                pack_gqa=pack_gqa,\n                num_splits=num_splits,\n                deterministic=deterministic,\n                return_attn_probs=True,\n            )\n            out = output_pad_fn(out_unpad)\n            if query_unused_mask is not None:\n                out.masked_fill_(q_zero_masking, 0.0)\n            print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n            print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n      
      # if not causal:\n            #     print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n            # breakpoint()\n\n            # Check that FlashAttention's numerical error is at most rtol times (2x, or 3x with softcap)\n            # the numerical error of a PyTorch implementation, plus fwd_atol.\n            assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol\n\n\n        if (\n            not DISABLE_BACKWARD \n            and dtype != torch.float8_e4m3fn \n            and not has_qv\n            and not dv > 256\n            and not attention_chunk != 0\n        ):\n            g_unpad = torch.randn_like(out_unpad)\n            do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)\n            dq_unpad = torch.empty_like(q_unpad)\n            dk_unpad = torch.empty_like(k_unpad)\n            dv_unpad = torch.empty_like(v_unpad)\n            dq_unpad, dk_unpad, dv_unpad, softmax_d = _flash_attn_backward(\n                g_unpad,\n                q_unpad,\n                k_unpad,\n                v_unpad,\n                out_unpad,\n                softmax_lse,\n                cu_seqlens_q, cu_seqlens_k,\n                seqused_q, seqused_k,\n                max_seqlen_q, max_seqlen_k,\n                dq_unpad,\n                dk_unpad,\n                dv_unpad,\n                d ** (-0.5),\n                causal,\n                window_size=window_size,\n                softcap=softcap,\n                deterministic=deterministic,\n            )\n            dq = dq_pad_fn(dq_unpad)\n            dk = dk_pad_fn(dk_unpad)\n            dv = dk_pad_fn(dv_unpad)\n            if key_unused_mask is not None:\n                k_zero_masking = rearrange(key_unused_mask, \"b s -> b s 1 1\")\n                dk.masked_fill_(k_zero_masking, 0.0)\n                dv.masked_fill_(k_zero_masking, 0.0)\n            if query_unused_mask is not None:\n                dq.masked_fill_(q_zero_masking, 0.0)\n            # 
print(f\"dO_O max diff: {(softmax_d - do_o).abs().max().item()}\")\n            # assert (softmax_d - do_o).abs().max().item() <= 1e-5\n            # assert dq_accum.abs().max().item() == 0.0\n            g = output_pad_fn(g_unpad)\n\n            # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()\n            # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())\n            # P = torch.softmax(qk, -1)\n            # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))\n            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())\n            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())\n            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())\n\n\n            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)\n            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)\n            print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n            print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n            print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n            print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n            print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n            print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n            print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n            print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n            print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n            print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n            print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n            print(f\"dV Pytorch mean diff: 
{(dv_pt - dv_ref).abs().mean().item()}\")\n            # breakpoint()\n            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol\n            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol\n            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol\n\n            print(dq_unpad.shape)\n            print(dk_unpad.shape)\n            print(dv_unpad.shape)\n\n            print(dq.shape)\n            print(dk.shape)\n            print(dv.shape)\n\n            if deterministic:\n                iterations = 1000\n\n                for i in range(iterations):\n                    dq_unpad2 = torch.empty_like(q_unpad)\n                    dk_unpad2 = torch.empty_like(k_unpad)\n                    dv_unpad2 = torch.empty_like(v_unpad)\n                    dq_unpad2, dk_unpad2, dv_unpad2, softmax_d = _flash_attn_backward(\n                        g_unpad,\n                        q_unpad,\n                        k_unpad,\n                        v_unpad,\n                        out_unpad,\n                        softmax_lse,\n                        cu_seqlens_q, cu_seqlens_k,\n                        seqused_q, seqused_k,\n                        max_seqlen_q, max_seqlen_k,\n                        dq_unpad2,\n                        dk_unpad2,\n                        dv_unpad2,\n                        d ** (-0.5),\n                        causal,\n                        window_size=window_size,\n                        softcap=softcap,\n                        deterministic=deterministic,\n              
      )\n\n                    dq2 = dq_pad_fn(dq_unpad2)\n                    dk2 = dk_pad_fn(dk_unpad2)\n                    dv2 = dk_pad_fn(dv_unpad2)\n                    if key_unused_mask is not None:\n                        k_zero_masking = rearrange(key_unused_mask, \"b s -> b s 1 1\")\n                        dk2.masked_fill_(k_zero_masking, 0.0)\n                        dv2.masked_fill_(k_zero_masking, 0.0)\n                    if query_unused_mask is not None:\n                        dq2.masked_fill_(q_zero_masking, 0.0)\n                    \n                    print(f'dq max diff with myself: {(dq2 - dq).abs().max().item()}')\n                    print(f'dk max diff with myself: {(dk2 - dk).abs().max().item()}')\n                    print(f'dv max diff with myself: {(dv2 - dv).abs().max().item()}')\n                    \n                    assert torch.equal(dq, dq2), f\"dq not deterministic\"\n                    assert torch.equal(dk, dk2), f\"dk not deterministic\"\n                    assert torch.equal(dv, dv2), f\"dv not deterministic\"\n\n                    print(f\"✅ Iteration {i} passed!\")"
  },
  {
    "path": "hopper/test_flash_attn_triton_amd.py",
    "content": "import os\nimport math\nimport itertools\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom torch._C import parse_schema\n\nfrom einops import rearrange, repeat\ntry:\n    from flash_attn.layers.rotary import apply_rotary_emb\nexcept ImportError:\n    apply_rotary_emb = None\n\nfrom padding import pad_input, unpad_input\nfrom test_util import (\n    attention_ref,\n    generate_qkv,\n    generate_random_padding_mask,\n)\n\nfrom flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine\nfrom flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata\n\n\nDISABLE_BACKWARD = os.getenv(\"FLASH_ATTENTION_DISABLE_BACKWARD\", \"FALSE\") == \"TRUE\"\nDISABLE_SPLIT = os.getenv(\"FLASH_ATTENTION_DISABLE_SPLIT\", \"TRUE\") == \"TRUE\"\nDISABLE_PAGEDKV = os.getenv(\"FLASH_ATTENTION_DISABLE_PAGEDKV\", \"FALSE\") == \"TRUE\"\nDISABLE_APPENDKV = os.getenv(\"FLASH_ATTENTION_DISABLE_APPENDKV\", \"FALSE\") == \"TRUE\"\nDISABLE_LOCAL = os.getenv(\"FLASH_ATTENTION_DISABLE_LOCAL\", \"TRUE\") == \"TRUE\"\nDISABLE_SOFTCAP = os.getenv(\"FLASH_ATTENTION_DISABLE_SOFTCAP\", \"TRUE\") == \"TRUE\"\nDISABLE_PACKGQA = os.getenv(\"FLASH_ATTENTION_DISABLE_PACKGQA\", \"TRUE\") == \"TRUE\"\nDISABLE_FP16 = os.getenv(\"FLASH_ATTENTION_DISABLE_FP16\", \"FALSE\") == \"TRUE\"\nDISABLE_FP8 = os.getenv(\"FLASH_ATTENTION_DISABLE_FP8\", \"FALSE\") == \"TRUE\" or torch.cuda.get_device_capability(\"cuda\")[0] < 9\nDISABLE_HDIM64 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM64\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM96 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM96\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM128 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM128\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM192 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM192\", \"FALSE\") == \"TRUE\"\nDISABLE_HDIM256 = os.getenv(\"FLASH_ATTENTION_DISABLE_HDIM256\", \"FALSE\") == \"TRUE\"\n\nCOMPILED_HDIMS = (\n    []\n    + ([64] if not DISABLE_HDIM64 else [])\n    + 
([96] if not DISABLE_HDIM96 else [])\n    + ([128] if not DISABLE_HDIM128 else [])\n    + ([192] if not DISABLE_HDIM192 else [])\n    + ([256] if not DISABLE_HDIM256 else [])\n)\n\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n# @pytest.mark.parametrize(\"has_qv\", [False, True])\n@pytest.mark.parametrize(\"has_qv\", [False])\n# @pytest.mark.parametrize(\"deterministic\", [False, True])\n@pytest.mark.parametrize(\"deterministic\", [False])\n@pytest.mark.parametrize(\"softcap\", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))\n# @pytest.mark.parametrize(\"softcap\", [0.0])\n@pytest.mark.parametrize(\"local\", [False] + ([True] if not DISABLE_LOCAL else []))\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n# @pytest.mark.parametrize(\"V_colmajor\", [False, True])\n@pytest.mark.parametrize(\"V_colmajor\", [False])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64, 128, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [64, 96, 128, 192])\n@pytest.mark.parametrize(\"d\", COMPILED_HDIMS)\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 1),\n        (64, 128),\n        (128, 
192),\n        (256, 256),\n        (239, 1),\n        (799, 3),\n        (113, 203),\n        (113, 128),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (384, 256),\n        (640, 128),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (4096, 4096),\n        (4224, 4224),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])\ndef test_flash_attn_output(\n        seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype\n):\n    if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn):\n        pytest.skip(\"V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn\")\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    # batch_size = 40\n    # nheads = 16\n    batch_size = 9 if seqlen_k <= 2048 else 2\n    # batch_size = 1\n    nheads = 6\n    # nheads = 1\n    nheads_kv = nheads if mha_type == \"mha\" else (2 if mha_type == \"gqa\" else 1)\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    if dtype == torch.float8_e4m3fn:\n        dv_vals = [d]\n    attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)\n        if softcap > 0.0:\n            # Ensure the values of qk are at least within softcap range.\n            q_ref = (q_ref * softcap / 4)\n        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()\n        k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()\n        v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, 
device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()\n        if has_qv:\n            qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n        else:\n            qv_ref = None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist()\n        # window_size = (-1, -1) if not local else (16, 0)\n        if dtype == torch.float8_e4m3fn:\n            q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]\n        else:\n            q_descale, k_descale, v_descale = None, None, None\n        q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]\n        qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None\n        if V_colmajor:\n            v = rearrange(rearrange(v.detach(), \"b s h d -> b h d s\").contiguous(), \"b h d s -> b s h d\").requires_grad_()\n        out_ref, attn_ref = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            None,\n            None,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            softcap=softcap\n        )\n        out_pt, attn_pt = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            None,\n            None,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n   
     )\n\n        # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float()\n        # if qv is not None:\n        #     qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float()\n        # m = qk.amax(-1, keepdim=True)\n        # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n        # exp_sum = s_tmp.sum(-1)\n        # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float())\n        # lse_ref = torch.logsumexp(qk, dim=-1)\n\n        # Numerical error if we just do any arithmetic on out_ref\n        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n        rtol = 2 if softcap == 0.0 else 3\n\n        print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n        print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n        pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]\n        num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]\n        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):\n            out = flash_attn_func(\n                q,\n                k,\n                v,\n                causal=causal,\n                qv=qv,\n                q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n                window_size=window_size,\n                attention_chunk=attention_chunk,\n                softcap=softcap,\n                pack_gqa=pack_gqa,\n                num_splits=num_splits\n            )\n            print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n            print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n            # if not causal:\n            #     print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n            # breakpoint()\n\n            # Check that FlashAttention's numerical error is at most rtol times (2x, or 3x with softcap)\n            # the numerical error of a PyTorch implementation, plus fwd_atol.\n            assert (out - out_ref).abs().max().item() <= rtol * (out_pt - 
out_ref).abs().max().item() + fwd_atol\n\n        if (\n            not DISABLE_BACKWARD \n            and dtype != torch.float8_e4m3fn \n            and not V_colmajor \n            and not has_qv\n            and not dv > 256\n            and not attention_chunk != 0\n        ):\n            g = torch.randn_like(out)\n            do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)\n            # import flash_attn_3_cuda\n            # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd(\n            #     g,\n            #     q,\n            #     k,\n            #     v,\n            #     out,\n            #     lse,\n            #     None,\n            #     None,\n            #     None,\n            #     d ** (-0.5),\n            #     causal,\n            #     window_size[0], window_size[1],\n            #     softcap,\n            #     deterministic,\n            #     0,  # sm_margin\n            # )\n            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            # print(f\"dO_O max diff: {(softmax_d - do_o).abs().max().item()}\")\n            # assert (softmax_d - do_o).abs().max().item() <= 1e-5\n            # assert dq_accum.abs().max().item() == 0.0\n\n            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())\n            # P = torch.softmax(qk, -1)\n            # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1))\n            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())\n            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())\n            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())\n\n            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)\n            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)\n            print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n            print(f\"dK max diff: {(dk - 
dk_ref).abs().max().item()}\")\n            print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n            print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n            print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n            print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n            print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n            print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n            print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n            print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n            print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n            print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n            # breakpoint()\n            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol\n            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol\n            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol\n\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# 
@pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n# @pytest.mark.parametrize(\"has_qv\", [False, True])\n@pytest.mark.parametrize(\"has_qv\", [False])\n# @pytest.mark.parametrize(\"deterministic\", [False, True])\n@pytest.mark.parametrize(\"deterministic\", [False])\n@pytest.mark.parametrize(\"softcap\", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))\n# @pytest.mark.parametrize(\"softcap\", [0.0])\n@pytest.mark.parametrize(\"local\", [False] + ([True] if not DISABLE_LOCAL else []))\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [False])\n@pytest.mark.parametrize(\"add_unused_qkv\", [False, True])\n# @pytest.mark.parametrize(\"add_unused_qkv\", [True])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [64, 96, 128])\n@pytest.mark.parametrize(\"d\", COMPILED_HDIMS)\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 1),\n        (1, 3),\n        (2, 1),\n        (511, 1),\n        (3, 513),\n        (64, 128),\n        (128, 128),\n        (256, 256),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (307, 256),\n        (640, 128),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (2048, 2048),\n    ],\n)\ndef test_flash_attn_varlen_output(\n        seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype\n):\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))\n    # batch_size = 40\n    
# nheads = 16\n    batch_size = 9 if seqlen_q <= 2048 else 2\n    nheads = 6\n    # batch_size = 2\n    # nheads = 1\n    nheads_kv = nheads if mha_type == \"mha\" else (2 if mha_type == \"gqa\" else 1)\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    if dtype == torch.float8_e4m3fn:\n        dv_vals = [d]\n    attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)\n        if softcap > 0.0:\n            # Ensure the values of qk are at least within softcap range.\n            q_ref = (q_ref * softcap / 4).detach().requires_grad_()\n        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()\n        k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()\n        v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()\n        if has_qv:\n            qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n        else:\n            qv_ref = None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n        if dtype == torch.float8_e4m3fn:\n            q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]\n        else:\n            q_descale, k_descale, v_descale = None, None, None\n        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]\n        qv = qv_ref.detach() if has_qv else None\n  
      query_padding_mask = generate_random_padding_mask(\n            seqlen_q, batch_size, device, mode=\"random\", zero_lengths=False\n        )\n        key_padding_mask = generate_random_padding_mask(\n            seqlen_k, batch_size, device, mode=\"random\", zero_lengths=True\n        )\n\n        def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):\n            if add_unused:\n                another_mask = generate_random_padding_mask(max_seq_len, bs, device)\n                attn_mask = torch.logical_and(padding_mask, another_mask)\n                unused_mask = torch.logical_xor(\n                    torch.logical_or(padding_mask, another_mask), attn_mask\n                )\n            else:\n                attn_mask = padding_mask\n                unused_mask = None\n            return attn_mask, unused_mask\n\n        query_padding_mask, query_unused_mask = _gen_unused_masks(\n            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device\n        )\n        key_padding_mask, key_unused_mask = _gen_unused_masks(\n            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device\n        )\n\n        (\n            q_unpad,\n            k_unpad,\n            v_unpad,\n            qv_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q,\n            k,\n            v,\n            qv,\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False,\n                        query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask)\n        q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)]\n        out_ref, attn_ref = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            
query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            softcap=softcap\n        )\n        out_pt, attn_pt = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n        )\n\n\n        print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n        print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n        if query_unused_mask is not None:\n            q_zero_masking = rearrange(query_unused_mask, \"b s -> b s 1 1\")\n\n        # Numerical error if we just do any arithmetic on out_ref\n        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n        rtol = 2 if softcap == 0.0 else 3\n\n        pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]\n        num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]\n        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):\n            out_unpad = flash_attn_varlen_func(\n                q_unpad,\n                k_unpad,\n                v_unpad,\n                cu_seqlens_q,\n                cu_seqlens_k,\n                max_seqlen_q,\n                max_seqlen_k,\n                seqused_q=seqused_q,\n                seqused_k=seqused_k,\n                causal=causal,\n                qv=qv_unpad,\n                q_descale=q_descale,\n       
         k_descale=k_descale, v_descale=v_descale,\n                window_size=window_size,\n                attention_chunk=attention_chunk,\n                softcap=softcap,\n            )\n            out = output_pad_fn(out_unpad)\n            if query_unused_mask is not None:\n                out.masked_fill_(q_zero_masking, 0.0)\n            print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n            print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n            # if not causal:\n            #     print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n            # breakpoint()\n\n            # Check that FlashAttention's numerical error is at most 3x the numerical error\n            # of a Pytorch implementation.\n            assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol\n\n\n        if (\n            not DISABLE_BACKWARD \n            and dtype != torch.float8_e4m3fn \n            and not has_qv\n            and not dv > 256\n            and not attention_chunk != 0\n        ):\n            g_unpad = torch.randn_like(out_unpad)\n            do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)\n            # import flash_attn_3_cuda\n            # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(\n            #     g_unpad,\n            #     q_unpad,\n            #     k_unpad,\n            #     v_unpad,\n            #     out_unpad,\n            #     lse,\n            #     None,\n            #     None,\n            #     None,\n            #     cu_seqlens_q,\n            #     cu_seqlens_k,\n            #     None, None,\n            #     max_seqlen_q,\n            #     max_seqlen_k,\n            #     d ** (-0.5),\n            #     causal,\n            #     window_size[0], window_size[1],\n            #     softcap,\n            #     deterministic,\n            #     0,  # sm_margin\n 
           # )\n            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad)\n            dq = dq_pad_fn(dq_unpad)\n            dk = dk_pad_fn(dk_unpad)\n            dv = dk_pad_fn(dv_unpad)\n            if key_unused_mask is not None:\n                k_zero_masking = rearrange(key_unused_mask, \"b s -> b s 1 1\")\n                dk.masked_fill_(k_zero_masking, 0.0)\n                dv.masked_fill_(k_zero_masking, 0.0)\n            if query_unused_mask is not None:\n                dq.masked_fill_(q_zero_masking, 0.0)\n            # print(f\"dO_O max diff: {(softmax_d - do_o).abs().max().item()}\")\n            # assert (softmax_d - do_o).abs().max().item() <= 1e-5\n            # assert dq_accum.abs().max().item() == 0.0\n            g = output_pad_fn(g_unpad)\n\n            # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()\n            # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())\n            # P = torch.softmax(qk, -1)\n            # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))\n            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())\n            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())\n            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())\n\n\n            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)\n            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)\n            print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n            print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n            print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n            print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n            print(f\"dK 
mean diff: {(dk - dk_ref).abs().mean().item()}\")\n            print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n            print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n            print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n            print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n            print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n            print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n            print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n            # breakpoint()\n            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol\n            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol\n            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)\n            assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol\n\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n@pytest.mark.parametrize(\"new_kv\", [False] + ([True] if not DISABLE_APPENDKV else []))\n# @pytest.mark.parametrize(\"new_kv\", [True])\n@pytest.mark.parametrize(\"causal,local\", [(False, False), (True, False)] + ([(False, True)] if not 
DISABLE_LOCAL else []))\n# @pytest.mark.parametrize(\"causal,local\", [(False, False), (True, False)])\n# @pytest.mark.parametrize(\"causal,local\", [(False, False)])\n@pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [True, False] if not DISABLE_APPENDKV else [True])\n# @pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [True])\n@pytest.mark.parametrize(\"has_rotary_seqlens\", [False, True])\n# @pytest.mark.parametrize(\"has_rotary_seqlens\", [False])\n@pytest.mark.parametrize(\"rotary_interleaved\", [False, True] if not DISABLE_APPENDKV else [False])\n# @pytest.mark.parametrize(\"rotary_interleaved\", [True])\n@pytest.mark.parametrize(\"rotary_fraction\", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0])\n# @pytest.mark.parametrize(\"rotary_fraction\", [0.0])\n@pytest.mark.parametrize(\"page_size\", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else []))\n# @pytest.mark.parametrize(\"page_size\", [None])\n@pytest.mark.parametrize(\"has_leftpad\", [False])\n# @pytest.mark.parametrize(\"has_leftpad\", [False])\n@pytest.mark.parametrize(\"has_batch_idx\", [False])\n# @pytest.mark.parametrize(\"has_batch_idx\", [False])\n@pytest.mark.parametrize(\"varlen_q\", [False])\n# @pytest.mark.parametrize(\"varlen_q\", [False])\n# @pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 128, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n@pytest.mark.parametrize(\"d\", [128])\n# @pytest.mark.parametrize(\"d\", [192])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 128),\n        (1, 339),\n        (3, 1024),\n        (64, 800),\n        (64, 256),\n        (3, 799),\n        (64, 2048),\n        (16, 20000),\n        # (1, 128 * 1024),\n        # (16, 128 * 1024),\n        (128, 128),\n        (256, 512),  # To test appending KV with more than 1 block\n        (2048, 
3577),  # Enough tile to test persistent scheduler\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\ndef test_flash_attn_kvcache(\n    seqlen_q,\n    seqlen_k,\n    d,\n    varlen_q,\n    has_batch_idx,\n    has_leftpad,\n    page_size,\n    rotary_fraction,\n    rotary_interleaved,\n    has_rotary_seqlens,\n    seqlen_new_eq_seqlen_q,\n    causal,\n    local,\n    new_kv,\n    mha_type,\n    dtype,\n):\n    if page_size is not None and seqlen_k % page_size != 0:\n        pytest.skip()\n    if seqlen_q > seqlen_k and new_kv:\n        pytest.skip()\n    if not new_kv and rotary_fraction > 0.0:\n        pytest.skip()\n    if rotary_fraction == 0.0 and has_rotary_seqlens:\n        pytest.skip()\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 5\n    # batch_size = 1\n    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2\n    nheads = 6\n    # nheads = 1\n    # rotary_dim must be a multiple of 16, and must be <= d\n    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 3)\n    assert nheads % nheads_k == 0\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    if dtype == torch.float8_e4m3fn:\n        dv_vals = [d]\n    attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        has_qv = d == 64 and dv >= 256\n        q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n        if has_qv:\n            qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n        else:\n            qv = None\n        if varlen_q:\n            
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode=\"random\")\n            q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask)\n            output_pad_fn = lambda output_unpad: pad_input(\n                output_unpad, indices_q, batch_size, seqlen_q\n            )\n            qv_unpad = rearrange(qv, \"b s ... -> (b s) ...\")[indices_q] if has_qv else None\n        else:\n            query_padding_mask = None\n            q_unpad = q\n            qv_unpad = qv\n            cu_seqlens_q, max_seqlen_q = None, None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n\n        seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()\n        cu_seqlens_k_new = None\n        key_new_padding_mask = None\n        if new_kv:\n            k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n            v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n            if varlen_q:  # k & v are also varlen\n                key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode=\"random\")\n                k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask)\n                v_unpad, *rest = unpad_input(v, key_new_padding_mask)\n            else:\n                k_unpad, v_unpad = k, v\n        else:\n            k, v, k_unpad, v_unpad = None, None, None, None\n        if page_size is None:\n            k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n            v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)\n            
page_table = None\n        else:\n            (\n                k_cache,\n                v_cache,\n                page_table,\n                k_cache_paged,\n                v_cache_paged,\n                num_blocks,\n            ) = _generate_block_kvcache(\n                seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref\n            )\n        cache_seqlens = torch.randint(\n            0 if new_kv else 1,\n            # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough\n            (\n                (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)\n                if new_kv\n                else (seqlen_k + 1)\n            ),\n            (batch_size,),\n            dtype=torch.int32,\n            device=device,\n        )\n        if has_leftpad:\n            cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)\n                                    if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)\n                                    for i in range(batch_size)])\n        else:\n            cache_leftpad = None\n        if has_batch_idx:\n            cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[\n                :batch_size\n            ]\n        else:\n            cache_batch_idx = None\n        arange = rearrange(torch.arange(seqlen_k, device=device), \"s -> 1 s\")\n        cache_seqlens_expanded = rearrange(cache_seqlens, \"b -> b 1\")\n        if not new_kv:\n            key_padding_mask = arange < cache_seqlens_expanded\n        else:\n            k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new\n            key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens\n        if has_leftpad:\n            key_padding_mask = torch.logical_and(\n                
key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)\n            )\n        # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)\n        rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2\n        if rotary_dim > 0:\n            angle = (\n                torch.rand(\n                    seqlen_k if page_size is None else num_blocks * page_size,\n                    rotary_dim // 2,\n                    device=device,\n                )\n                * 2\n                * math.pi\n            )\n            cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)\n            sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)\n            if causal or local:\n                q_ro = apply_rotary_emb(\n                    q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved\n                )\n            else:\n                q_ro = rearrange(\n                    apply_rotary_emb(\n                        rearrange(q, \"b s h d -> b 1 (s h) d\"),\n                        cos,\n                        sin,\n                        seqlen_offsets=rotary_seqlens,\n                        interleaved=rotary_interleaved,\n                    ),\n                    \"b 1 (s h) d -> b s h d\",\n                    s=seqlen_q,\n                )\n            # q_ro = q\n            k_ro = apply_rotary_emb(\n                k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved\n            )\n        else:\n            cos, sin = None, None\n            q_ro, k_ro = q, k\n        # k_cache[:, 64:] = -1\n        k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()\n        v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()\n        if new_kv:\n            update_mask = torch.logical_and(\n                cache_seqlens_expanded <= arange, arange < 
cache_seqlens_expanded + k_new_seqlens\n            )\n            k_to_update = rearrange(k_ro, \"b s ... -> (b s) ...\")\n            v_to_update = rearrange(v, \"b s ... -> (b s) ...\")\n            if varlen_q:\n                k_to_update = k_to_update[indices_k]\n                v_to_update = v_to_update[indices_k]\n            k_cache_ref[update_mask] = k_to_update\n            v_cache_ref[update_mask] = v_to_update\n        k_cache_rep = repeat(k_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n        v_cache_rep = repeat(v_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n        out_ref, _ = attention_ref(\n            q_ro,\n            k_cache_rep,\n            v_cache_rep,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            key_leftpad=cache_leftpad,\n        )\n        out_pt, _ = attention_ref(\n            q_ro,\n            k_cache_rep,\n            v_cache_rep,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            upcast=False,\n            reorder_ops=True,\n            key_leftpad=cache_leftpad,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None\n        )\n        q = q.to(dtype)\n        q_unpad = q_unpad.to(dtype) if varlen_q else None\n        k_cache = k_cache.to(dtype)\n        v_cache = v_cache.to(dtype)\n        k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None\n        v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None\n        k = k.to(dtype) if k is not None else None\n        v = v.to(dtype) if v is not None else None\n        k_unpad = k_unpad.to(dtype) if k_unpad is not None else None\n        v_unpad = v_unpad.to(dtype) if 
v_unpad is not None else None\n        qv = qv.to(dtype) if qv is not None else None\n        qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None\n        cos = cos.to(dtype) if cos is not None else None\n        sin = sin.to(dtype) if sin is not None else None\n        k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()\n        v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()\n        num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1]\n        precompute_metadata_vals = [False]\n        for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals):\n            if precompute_metadata:\n                scheduler_metadata = get_scheduler_metadata(\n                    batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,\n                    cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,\n                    max_seqlen_k_new=seqlen_new, page_size=page_size,\n                    causal=causal, window_size=window_size, attention_chunk=attention_chunk,\n                    num_splits=num_splits\n                )\n            else:\n                scheduler_metadata = None\n            # Repeat to test metadata reuse\n            for _ in range(1 if not precompute_metadata else 2):\n                if page_size is None:\n                    k_cache.copy_(k_cache_saved)\n                    v_cache.copy_(v_cache_saved)\n                else:\n                    k_cache_paged.copy_(k_cache_saved)\n                    v_cache_paged.copy_(v_cache_saved)\n                out, lse, *rest = flash_attn_with_kvcache(\n                    q if not varlen_q else q_unpad,\n                    k_cache if page_size is None else k_cache_paged,\n                    v_cache if page_size is None else v_cache_paged,\n                   
 k if not new_kv or not varlen_q else k_unpad,\n                    v if not new_kv or not varlen_q else v_unpad,\n                    qv=qv if not varlen_q else qv_unpad,\n                    rotary_cos=cos,\n                    rotary_sin=sin,\n                    cache_seqlens=cache_seqlens,\n                    cache_batch_idx=cache_batch_idx,\n                    cache_leftpad=cache_leftpad,\n                    page_table=page_table,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k_new=cu_seqlens_k_new,\n                    max_seqlen_q=max_seqlen_q,\n                    rotary_seqlens=rotary_seqlens,\n                    causal=causal,\n                    window_size=window_size,\n                    attention_chunk=attention_chunk,\n                    rotary_interleaved=rotary_interleaved,\n                    scheduler_metadata=scheduler_metadata,\n                    num_splits=num_splits,\n                    return_softmax_lse=True\n                )\n                if varlen_q:\n                    out = output_pad_fn(out)\n                # out = flash_attn_with_kvcache(\n                #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size\n                # )\n                # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)\n                # qk = torch.einsum(\"bqhd,bkhd->bhqk\", q, k_cache_ref)\n                # m = qk.amax(-1, keepdim=True)\n                # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n                # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)\n                # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)\n                # probs = torch.softmax(qk, dim=-1)\n                print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n                print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n                print(f\"Pytorch max diff: {(out_pt - 
out_ref).abs().max().item()}\")\n                print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n                # breakpoint()\n\n                # Check that FlashAttention's numerical error is at most twice the numerical error\n                # of a Pytorch implementation.\n                if new_kv:\n                    if page_size is None:\n                        k_cache_select = (\n                            k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx]\n                        )\n                        v_cache_select = (\n                            v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx]\n                        )\n                    else:\n                        k_cache_select = rearrange(\n                            k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()],\n                            \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n                            b=batch_size,\n                        )[:, :seqlen_k].to(dtype_ref)\n                        v_cache_select = rearrange(\n                            v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()],\n                            \"(b nblocks) block_size ... 
-> b (nblocks block_size) ...\",\n                            b=batch_size,\n                        )[:, :seqlen_k].to(dtype_ref)\n                    k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)\n                    v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)\n                    if dtype is not torch.float8_e4m3fn:\n                        assert torch.equal(v_cache_select, v_cache_ref)\n                    else:\n                        assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3)\n                    # breakpoint()\n                    # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:\n                    if rotary_dim == 0:\n                        assert torch.equal(k_cache_select, k_cache_ref)\n                    else:\n                        # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):\n                        #     breakpoint()\n                        if dtype is not torch.float8_e4m3fn:\n                            assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)\n                        else:\n                            assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1)\n                mult = 4 if dtype == torch.float8_e4m3fn else 2\n                assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5\n                mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5\n                assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item()\n\n\ndef _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref):\n    num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3\n    k_cache_paged = torch.randn(\n        num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref\n    ).to(dtype).to(dtype_ref)\n    v_cache_paged = torch.randn(\n        num_blocks, page_size, nheads_k, dv, 
device=device, dtype=dtype_ref\n    ).to(dtype).to(dtype_ref)\n    page_table = rearrange(\n        torch.randperm(num_blocks, dtype=torch.int32, device=device),\n        \"(b nblocks) -> b nblocks\",\n        b=batch_size,\n    )\n    k_cache = rearrange(\n        k_cache_paged[page_table.flatten()],\n        \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    v_cache = rearrange(\n        v_cache_paged[page_table.flatten()],\n        \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [False])\n@pytest.mark.parametrize('d', [128])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (64, 8192),\n    ],\n)\ndef test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype):\n    device = \"cuda\"\n    torch.random.manual_seed(0)\n    batch_size = 2\n    nheads = 16\n    nheads_kv = 4\n    # There was a bug where this would cause \"unspecified launch failure\" due to Cluster\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)\n    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)\n    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)\n    for _ in range(100):\n        flash_attn_func(q, k, v, causal=causal)\n\n\n# @pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [False])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 
128])\n# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [80])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (239, 1),\n        (3, 799),\n        (799, 3),\n        (1024, 128),\n        (97, 97),\n        (128, 128),\n        (200, 200),\n        (256, 256),\n        (257, 257),\n        (384, 384),\n        (512, 512),\n        (768, 768),\n        (1024, 1024),\n        (2048, 2048),\n    ],\n)\n@pytest.mark.skip(reason=\"Cannot be run in parallel with other tests due to memory usage\")\ndef test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    # Simulate under memory load\n    dummy = torch.empty(70 * 1024 ** 3, dtype=torch.uint8, device=device)\n    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger\n    nheads = 4\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    torch.random.manual_seed(42)\n    out0 = flash_attn_func(q, k, v, causal=causal)\n    g = torch.randn_like(out0)\n    dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)\n    # Numerical error if we just do any arithmetic on dq\n    dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()\n\n    for i in range(1000):\n        torch.random.manual_seed(42)\n        out = flash_attn_func(q, k, v, causal=causal)\n        assert torch.equal(out, out0)\n        # assert torch.equal(lse, lse0)\n\n        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n        dq_equal = torch.allclose(dq, dq0, atol=dq_atol)\n        if not dq_equal:\n            print(f\"Iter {i}, 
{dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}\")\n            # breakpoint()\n        assert torch.equal(dv, dv0)\n        assert torch.equal(dk, dk0)\n        assert dq_equal\n\n\ndef attention_combine_ref(out_partial, lse_partial):\n    \"\"\"\n    out_partial: (num_splits, batch_size, seqlen, nheads, d)\n    lse_partial: (num_splits, batch_size, seqlen, nheads)\n    \"\"\"\n    lse = torch.logsumexp(lse_partial, dim=0)\n    scale = torch.exp(lse_partial - lse)\n    scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale)\n    out = (scale.unsqueeze(-1) * out_partial).sum(0)\n    return out, lse\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16, torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float32])\n# @pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n@pytest.mark.parametrize(\"d\", [64, 96, 128, 192, 256, 512])\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\"seqlen\", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024])\n# @pytest.mark.parametrize(\"seqlen\", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192])\n# @pytest.mark.parametrize(\"seqlen\", [15])\n@pytest.mark.parametrize(\"num_splits\", [1, 2, 3, 5, 17, 32, 55, 97, 133])\n# @pytest.mark.parametrize(\"num_splits\", [1, 2, 3, 5, 11])\n# @pytest.mark.parametrize(\"num_splits\", [128])\ndef test_flash_attn_combine(num_splits, seqlen, d, dtype):\n    if DISABLE_SPLIT:\n        pytest.skip()\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(1)\n    batch_size = 5\n    nheads = 16\n    # batch_size = 1\n    # nheads = 1\n    out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits]  # To test non-contiguous tensor\n    lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads]  # To test 
non-contiguous tensor\n    # To test short-circuiting based on num_splits\n    lse_partial[num_splits // 2:, :batch_size // 3] = -float(\"inf\")\n    out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype)\n    out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial)\n    out_pt = out_ref.to(dtype)\n\n    print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n    print(f\"LSE mean diff: {(lse - lse_ref).abs().mean().item()}\")\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n    # breakpoint()\n\n    assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5)\n    multiple = 2\n    assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5)\n\n    # from flash_attn.utils.benchmark import pytorch_profiler\n    # # pytorch_profiler(torch.sum, lse_partial)\n    # pytorch_profiler(flash_attn_combine, out_partial, lse_partial)\n    # pytorch_profiler(torch.sum, out_partial)\n\n@pytest.mark.skip(reason=\"AMD Triton backend doesn't use torch ops registration\")\ndef test_flash3_bw_compatibility() -> None:\n    # Let's try to always stay backward compatible! This will make life easier\n    # for downstream libraries, users, and exported models.\n    # 1/ Instead of removing arguments, error out if their value is no longer supported\n    # 2/ When adding arguments, add them at the end with a default value\n    assert torch.ops.flash_attn_3.fwd.default._schema.is_backward_compatible_with(parse_schema(\n        \"flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, \"\n        \"Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, \"\n        \"Tensor? 
cu_seqlens_q=None, Tensor? cu_seqlens_k=None, \"\n        \"Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, \"\n        \"int? max_seqlen_q=None, int? max_seqlen_k=None, Tensor? page_table=None, \"\n        \"Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, \"\n        \"Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? v_descale=None, \"\n        \"float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, \"\n        \"int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, \"\n        \"Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) \"\n        \"-> (Tensor(out!), Tensor, Tensor, Tensor)\"\n    ))\n    assert torch.ops.flash_attn_3.bwd.default._schema.is_backward_compatible_with(parse_schema(\n        \"flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, \"\n        \"Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, \"\n        \"Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, \"\n        \"int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, \"\n        \"int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) \"\n        \"-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)\"\n    ))\n    assert torch.ops.flash_attn_3.fwd_combine.default._schema.is_backward_compatible_with(parse_schema(\n        \"flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, \"\n        \"ScalarType? 
out_dtype=None) -> (Tensor(out!), Tensor)\"\n    ))\n    assert torch.ops.flash_attn_3.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema(\n        \"flash_attn_3::get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, \"\n        \"int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, \"\n        \"Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, \"\n        \"Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, \"\n        \"bool is_causal=False, int window_size_left=-1, int window_size_right=-1, \"\n        \"int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? pack_gqa=None, \"\n        \"int sm_margin=0) -> Tensor\"\n    ))\n"
  },
  {
    "path": "hopper/test_kvcache.py",
    "content": "import torch\n#from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache\nimport flash_attn_interface as fa3\nimport flash_attn as fa2\nimport torch.utils.benchmark as benchmark\nimport time\n\nimport argparse\nimport math\n\nparser = argparse.ArgumentParser(description='Process some integers.')\nparser.add_argument('--causal', action='store_true')\nparser.add_argument('--splits', type=int, default=1)\nparser.add_argument('--repeats', type=int, default=10)\nparser.add_argument('--validate', action='store_true')\nparser.add_argument('--gqa', action='store_true')\n\nargs = parser.parse_args()\n\ndef benchmark_fa_kv_old(fn, repeats=10, desc='', verbose=True, **kwinputs):\n    \"\"\"Use Pytorch Benchmark on the forward pass of an arbitrary function.\"\"\"\n    if verbose:\n        print(desc, '- Forward pass')\n    t = benchmark.Timer(\n            stmt='fn(**kwinputs)',\n            globals={'fn': fn, 'kwinputs': kwinputs},\n            num_threads=torch.get_num_threads(),\n            )\n    m = t.timeit(repeats)\n    if verbose:\n        print(desc, m)\n    return t, m\n\ndef benchmark_fa_kv(fn, repeats=10, *args, **kwargs):\n    # warmup\n    for _ in range(5):\n        fn(*args, **kwargs)\n    niters = repeats\n    torch.cuda.synchronize()\n    start = time.time()\n    for _ in range(niters):\n        fn(*args, **kwargs)\n    torch.cuda.synchronize()\n    end = time.time()\n    return (end - start) / niters\n\ndef main():\n    # *SAMPLE CONFIG*\n    # Model arch params:\n    nheads_q = 64\n    nheads_kv = 8\n    headdim = 128\n    #dtype = torch.bfloat16\n    dtype = torch.float16\n\n    # Cache settings:\n    num_caches = 8\n    cache_seqlen = 1024 * 16\n\n    # Batching settings\n    ntokens = 1024\n    max_queries_per_batch = 4\n    small_request_ntokens = 16\n\n    # Input settings\n    query_seqlens = [900, 12, 1]\n    num_queries = len(query_seqlens)\n    # Need to add empty queries to fill out 
`max_queries_per_batch`\n    num_padding_queries = max_queries_per_batch - num_queries\n    context_seqlens = [4096, 5120*2, 6145*2]\n    #context_seqlens = [4096, 5120*2, 6152*2]\n\n    # Validation\n    assert sum(query_seqlens) <= ntokens\n    assert all(s < small_request_ntokens for s in query_seqlens[1:])\n    assert num_queries <= max_queries_per_batch\n    assert all(s < cache_seqlen for s in context_seqlens)\n\n    torch.manual_seed(5434)\n\n    # Allocate some tensors\n    k_cache = torch.randn(\n        (num_caches, cache_seqlen, nheads_kv, headdim), device=\"cuda\", dtype=dtype\n    )\n    v_cache = torch.randn(\n        (num_caches, cache_seqlen, nheads_kv, headdim), device=\"cuda\", dtype=dtype\n    )\n\n    q_buf_large = torch.randn(\n        (1, ntokens, nheads_q, headdim), device=\"cuda\", dtype=dtype\n    )\n    cache_seqlen_large = torch.tensor(\n        [context_seqlens[0]], dtype=torch.int32, device=\"cuda\"\n    )\n    cache_idx_large = torch.tensor([1], dtype=torch.int32, device=\"cuda\")\n\n    q_buf_small = torch.randn(\n        (max_queries_per_batch - 1, small_request_ntokens, nheads_q, headdim),\n        device=\"cuda\",\n        dtype=dtype,\n    )\n    cache_seqlens_small = torch.tensor(\n        context_seqlens[1:] + [0] * num_padding_queries, dtype=torch.int32, device=\"cuda\"\n    )\n    cache_idxs_small = torch.randperm(num_caches, dtype=torch.int32, device=\"cuda\")[\n        : max_queries_per_batch - 1\n    ]\n\n    if args.validate:\n        # Call flash attn\n        # First for the single full-sized query\n        out0, lse0 = fa3.flash_attn_with_kvcache(\n            q=q_buf_large,\n            k_cache=k_cache,\n            v_cache=v_cache,\n            cache_seqlens=cache_seqlen_large,\n            cache_batch_idx=cache_idx_large,\n            causal=bool(args.causal),\n            num_splits=args.splits,\n            return_softmax_lse=True,\n           #num_splits=1\n        )   \n\n         # Second for n-1 small queries\n 
       out1_split1, lse1_split1 = fa3.flash_attn_with_kvcache(\n            q=q_buf_small,\n            k_cache=k_cache,\n            v_cache=v_cache,\n            cache_seqlens=cache_seqlens_small,\n            cache_batch_idx=cache_idxs_small,\n            causal=bool(args.causal),\n            num_splits=1,\n            gqa_decoding=bool(args.gqa),\n            return_softmax_lse=True,\n        )\n\n        # Second for n-1 small queries\n        out1, lse1 = fa3.flash_attn_with_kvcache(\n            q=q_buf_small,\n            k_cache=k_cache,\n            v_cache=v_cache,\n            cache_seqlens=cache_seqlens_small,\n            cache_batch_idx=cache_idxs_small,\n            causal=bool(args.causal),\n            num_splits=args.splits,\n            gqa_decoding=bool(args.gqa),\n            return_softmax_lse=True,\n        )\n\n        # Call flash attn\n        # First for the single full-sized query\n        out2 = fa2.flash_attn_with_kvcache(\n            q=q_buf_large,\n            k_cache=k_cache,\n            v_cache=v_cache,\n            cache_seqlens=cache_seqlen_large,\n            cache_batch_idx=cache_idx_large,\n            causal=bool(args.causal),\n            num_splits=args.splits,\n        )\n\n        print ('big')\n        print ('diff-max', (out0 - out2).abs().max().item(), cache_seqlens_small)\n        print ('diff-mean', (out0 - out2).abs().mean().item())\n\n\n        # Second for n-1 small queries\n        out3, lse_fa2 = fa2.flash_attn_with_kvcache(\n            q=q_buf_small,\n            k_cache=k_cache,\n            v_cache=v_cache,\n            cache_seqlens=cache_seqlens_small,\n            cache_batch_idx=cache_idxs_small,\n            causal=bool(args.causal),\n            num_splits=args.splits,\n            return_softmax_lse=True,\n            #num_splits=1\n        )\n\n        print ('small') #, out1)\n        print ('lse', lse1, lse_fa2, (lse1 - lse_fa2).abs(), out1.shape)\n        print ('lse-dif-max', (lse1 - 
lse_fa2).abs().max().item())\n        print ('diff-max', (out1 - out3).abs().max().item())\n        print ('diff-mean', (out1 - out3).abs().mean().item())\n\n\n    print ('fa3', args.repeats)\n    time_fa3_big = benchmark_fa_kv(fa3.flash_attn_with_kvcache, repeats=args.repeats, \n        q=q_buf_large,\n        k_cache=k_cache,\n        v_cache=v_cache,\n        cache_seqlens=cache_seqlen_large,\n        cache_batch_idx=cache_idx_large,\n        causal=bool(args.causal),\n        num_splits=args.splits,\n    )\n\n    time_fa3_small = benchmark_fa_kv(fa3.flash_attn_with_kvcache, repeats=args.repeats,\n        q=q_buf_small,\n        k_cache=k_cache,\n        v_cache=v_cache,\n        cache_seqlens=cache_seqlens_small,\n        cache_batch_idx=cache_idxs_small,\n        causal=bool(args.causal),\n        num_splits=args.splits,\n    )\n\n    print ('fa2 ')\n\n    time_fa2_big = benchmark_fa_kv(fa2.flash_attn_with_kvcache, repeats=args.repeats, \n            q=q_buf_large,\n            k_cache=k_cache,\n            v_cache=v_cache,\n            cache_seqlens=cache_seqlen_large,\n            cache_batch_idx=cache_idx_large,\n            causal=bool(args.causal),\n            num_splits=args.splits\n    )\n\n    time_fa2_small = benchmark_fa_kv(fa2.flash_attn_with_kvcache, repeats=args.repeats, \n            q=q_buf_small,\n            k_cache=k_cache,\n            v_cache=v_cache,\n            cache_seqlens=cache_seqlens_small,\n            cache_batch_idx=cache_idxs_small,\n            causal=bool(args.causal),\n            num_splits=args.splits\n    )\n\n    print ('big (split, fa3, fa2, ratio):', args.splits, time_fa3_big * 1000000, time_fa2_big * 1000000, time_fa3_big / time_fa2_big)\n    print ('small (split, fa3, fa2, ratio):', args.splits, time_fa3_small * 1000000, time_fa2_small * 1000000, time_fa3_small / time_fa2_small)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "hopper/test_torch_compile_and_export.py",
    "content": "import torch\nfrom flash_attn_interface import flash_attn_func\nfrom torch import nn\n\n\nclass EfficienctMultiHeadAttention(nn.Module):\n    def __init__(self, embed_size, num_heads, dropout=0.0, use_flash_attn=True):\n        super().__init__()\n        assert embed_size % num_heads == 0, f\"{embed_size=} {num_heads=}\"\n\n        self.embed_size = embed_size\n        self.num_heads = num_heads\n        self.head_dim = embed_size // num_heads\n        self.use_flash_attn = use_flash_attn and (flash_attn_func is not None)\n\n        self.qkv_proj = nn.Linear(embed_size, 3 * embed_size)\n        self.out_proj = nn.Linear(embed_size, embed_size)\n        self.dropout = dropout\n\n    def forward(self, x, attention_mask=None):\n        N, seq_length, _ = x.shape\n\n        qkv = self.qkv_proj(x)\n        q, k, v = qkv.chunk(3, dim=-1)\n\n        q = q.view(N, seq_length, self.num_heads, self.head_dim)\n        k = k.view(N, seq_length, self.num_heads, self.head_dim)\n        v = v.view(N, seq_length, self.num_heads, self.head_dim)\n\n        if self.use_flash_attn and attention_mask is None:\n            out = flash_attn_func(\n                q, k, v\n            )\n        out = out.reshape(N, seq_length, self.embed_size)\n        out = self.out_proj(out)\n        return out\n\n\ndef create_model(batch_size=16, sequence_length=256, embedding_dim=2048, num_heads=16):\n    model = EfficienctMultiHeadAttention(embedding_dim, num_heads).cuda().bfloat16()\n    input_tensor = torch.randn(batch_size, sequence_length, embedding_dim).cuda().bfloat16()\n    return model, input_tensor\n\n\ndef test_export_model():\n    model, input_tensor = create_model()\n    expected = torch.compile(model, backend=\"aot_eager\")(input_tensor)\n    loss = expected.sum()\n    loss.backward()\n\n    ep = torch.export.export(model, (input_tensor,))\n    got = ep.module()(input_tensor,)\n    assert torch.equal(expected, got)\n\n    loss_2 = got.sum()\n    loss_2.backward()\n\n    
assert torch.equal(loss, loss_2)\n\n\ndef test_compile_and_package_model():\n    model, input_tensor = create_model()\n    expected = torch.compile(model, backend=\"aot_eager\")(input_tensor)\n\n    exported = torch.export.export(model, (input_tensor,))\n    torch._inductor.aoti_compile_and_package(\n        exported,\n        package_path=\"model.pt2\",\n    )\n\n    compiled_model = torch._inductor.package.load_package(\"model.pt2\")\n    out = compiled_model(input_tensor,)\n    assert torch.equal(expected, out)\n"
  },
  {
    "path": "hopper/test_util.py",
    "content": "import math\n\nimport torch\nfrom einops import rearrange, repeat\n\nfrom padding import pad_input, unpad_input\n\n\ndef generate_random_padding_mask(max_seqlen, batch_size, device, mode=\"random\", zero_lengths=False):\n    assert mode in [\"full\", \"random\", \"third\"]\n    if mode == \"full\":\n        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)\n    elif mode == \"random\":\n        lengths = torch.randint(\n            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device\n        )\n    elif mode == \"third\":\n        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)\n\n    if zero_lengths:\n        # Generate zero-lengths every 5 batches and the last batch.\n        for i in range(batch_size):\n            if i % 5 == 0:\n                lengths[i] = 0\n        lengths[-1] = 0\n    padding_mask = (\n        repeat(torch.arange(max_seqlen, device=device), \"s -> b s\", b=batch_size) < lengths\n    )\n    return padding_mask\n\n\ndef generate_qkv(\n    q, k, v, query_padding_mask=None, key_padding_mask=None, qv=None, kvpacked=False, qkvpacked=False,\n    query_unused_mask=None, key_unused_mask=None,\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, d)\n        k: (batch_size, seqlen_k, nheads_k, d)\n        v: (batch_size, seqlen_k, nheads_k, d_v)\n        query_padding_mask: (batch_size, seqlen), bool\n        key_padding_mask: (batch_size, seqlen), bool\n    \"\"\"\n    assert not (kvpacked and qkvpacked)\n    batch_size, seqlen_q, nheads, d = q.shape\n    d_v = v.shape[-1]\n    _, seqlen_k, nheads_k, _ = k.shape\n    assert k.shape == (batch_size, seqlen_k, nheads_k, d)\n    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)\n    if query_unused_mask is not None or key_unused_mask is not None:\n        assert not kvpacked\n        assert not qkvpacked\n\n    if query_padding_mask is not 
None:\n        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(\n            q, query_padding_mask, query_unused_mask\n        )\n        output_pad_fn = lambda output_unpad: pad_input(\n            output_unpad, indices_q, batch_size, seqlen_q\n        )\n        qv_unpad = rearrange(qv, \"b s ... -> (b s) ...\")[indices_q] if qv is not None else None\n    else:\n        q_unpad = rearrange(q, \"b s h d -> (b s) h d\")\n        cu_seqlens_q = torch.arange(\n            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device\n        )\n        seqused_q = None\n        max_seqlen_q = seqlen_q\n        output_pad_fn = lambda output_unpad: rearrange(\n            output_unpad, \"(b s) h d -> b s h d\", b=batch_size\n        )\n        qv_unpad = rearrange(qv, \"b s ... -> (b s) ...\") if qv is not None else None\n\n    if key_padding_mask is not None:\n        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(\n            k, key_padding_mask, key_unused_mask\n        )\n        v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask)\n    else:\n        k_unpad = rearrange(k, \"b s h d -> (b s) h d\")\n        v_unpad = rearrange(v, \"b s h d -> (b s) h d\")\n        cu_seqlens_k = torch.arange(\n            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device\n        )\n        seqused_k = None\n        max_seqlen_k = seqlen_k\n\n    if qkvpacked:\n        assert (query_padding_mask == key_padding_mask).all()\n        assert nheads == nheads_k\n        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)\n        qkv = torch.stack([q, k, v], dim=2)\n        if query_padding_mask is not None:\n            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)\n        else:\n            dqkv_pad_fn = lambda dqkv_unpad: rearrange(\n                dqkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n   
         )\n        return (\n            qkv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            max_seqlen_q,\n            qkv.detach().requires_grad_(),\n            output_pad_fn,\n            dqkv_pad_fn,\n        )\n    elif kvpacked:\n        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)\n        kv = torch.stack([k, v], dim=2)\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)\n        else:\n            dkv_pad_fn = lambda dkv_unpad: rearrange(\n                dkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n            q_unpad.detach().requires_grad_(),\n            kv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            kv.detach().requires_grad_(),\n            output_pad_fn,\n            dq_pad_fn,\n            dkv_pad_fn,\n        )\n    else:\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)\n        else:\n            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, \"(b s) h d -> b s h d\", b=batch_size)\n        return (\n            q_unpad.detach().requires_grad_(),\n            k_unpad.detach().requires_grad_(),\n            v_unpad.detach().requires_grad_(),\n            qv_unpad.detach()  if qv is not None else None,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            k.detach().requires_grad_(),\n            v.detach().requires_grad_(),\n            qv.detach() if qv is not None else None,\n            output_pad_fn,\n            
dq_pad_fn,\n            dk_pad_fn,\n        )\n\n\ndef construct_local_mask(\n    seqlen_q,\n    seqlen_k,\n    window_size=(-1, -1),  # -1 means infinite window size\n    sink_token_length=0,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    key_leftpad=None,\n    device=None,\n):\n    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n    if key_leftpad is not None:\n        key_leftpad = rearrange(key_leftpad, \"b -> b 1 1 1\")\n        col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)\n    sk = (\n        seqlen_k\n        if key_padding_mask is None\n        else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sq = (\n        seqlen_q\n        if query_padding_mask is None\n        else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    if window_size[0] < 0:\n        return col_idx > row_idx + sk - sq + window_size[1]\n    else:\n        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk\n        return torch.logical_or(\n            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),\n            torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length),\n        )\n\n\ndef construct_chunk_mask(\n    seqlen_q,\n    seqlen_k,\n    attention_chunk,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    key_leftpad=None,\n    device=None,\n):\n    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n    if key_leftpad is not None:\n        key_leftpad = rearrange(key_leftpad, \"b -> b 1 1 1\")\n        col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n        col_idx = torch.where(col_idx >= 
key_leftpad, col_idx - key_leftpad, 2**32)\n    sk = (\n        seqlen_k\n        if key_padding_mask is None\n        else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sq = (\n        seqlen_q\n        if query_padding_mask is None\n        else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk\n    # Subtract remainder instead of divide and then multiply to take care of negative values\n    col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk\n    return torch.logical_or(\n        col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk\n    )\n\n\ndef attention_ref(\n    q,\n    k,\n    v,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    key_leftpad=None,\n    attn_bias=None,\n    dropout_p=0.0,\n    dropout_mask=None,\n    causal=False,\n    qv=None,\n    q_descale=None, k_descale=None, v_descale=None,\n    window_size=(-1, -1),  # -1 means infinite window size\n    attention_chunk=0,\n    sink_token_length=0,\n    softcap=0.0,\n    upcast=True,\n    reorder_ops=False,\n    intermediate_dtype=None,\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, head_dim)\n        k: (batch_size, seqlen_k, nheads, head_dim)\n        v: (batch_size, seqlen_k, nheads, head_dim_v)\n        qv: (batch_size, seqlen_q, nheads, head_dim_v)\n        query_padding_mask: (batch_size, seqlen_q)\n        key_padding_mask: (batch_size, seqlen_k)\n        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)\n        dropout_p: float\n        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)\n        causal: whether to apply causal masking\n        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast\n            output back to fp16/bf16.\n        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, 
etc.)\n            without changing the math. This is to estimate the numerical error from operation\n            reordering.\n    Output:\n        output: (batch_size, seqlen_q, nheads, head_dim_v)\n        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    dtype_og = q.dtype\n    if upcast:\n        q, k, v = q.float(), k.float(), v.float()\n        qv = qv.float() if qv is not None else None\n    if q_descale is not None:\n        q_descale = repeat(q_descale, \"b h -> b 1 (h g) 1\", g=q.shape[2] // k.shape[2])\n        q = (q.float() * q_descale).to(q.dtype)\n        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None\n    if k_descale is not None:\n        k = (k.float() * rearrange(k_descale, \"b h -> b 1 h 1\")).to(dtype=k.dtype)\n    if v_descale is not None:\n        v = (v.float() * rearrange(v_descale, \"b h -> b 1 h 1\")).to(dtype=v.dtype)\n    seqlen_q, seqlen_k = q.shape[1], k.shape[1]\n    k = repeat(k, \"b s h d -> b s (h g) d\", g=q.shape[2] // k.shape[2])\n    v = repeat(v, \"b s h d -> b s (h g) d\", g=q.shape[2] // v.shape[2])\n    d = q.shape[-1]\n    dv = v.shape[-1]\n    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)\n    if not reorder_ops:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q * softmax_scale, k)\n    else:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q, k * softmax_scale)\n    if qv is not None:\n        scores = scores + torch.einsum(\"bthd,bshd->bhts\", qv * softmax_scale, v)\n    if softcap > 0:\n        scores = torch.tanh(scores / softcap) * softcap\n    if key_padding_mask is not None:\n        scores.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n    local_mask = None\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            
sink_token_length,\n            query_padding_mask,\n            key_padding_mask,\n            key_leftpad=key_leftpad,\n            device=q.device,\n        )\n    if attention_chunk > 0:\n        chunk_mask = construct_chunk_mask(\n            seqlen_q,\n            seqlen_k,\n            attention_chunk,\n            query_padding_mask,\n            key_padding_mask,\n            key_leftpad=key_leftpad,\n            device=q.device,\n        )\n        local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask\n    if local_mask is not None:\n        scores.masked_fill_(local_mask, float(\"-inf\"))\n    if attn_bias is not None:\n        scores = scores + attn_bias\n    attention = torch.softmax(scores, dim=-1).to(v.dtype)\n    # We want to mask here so that the attention matrix doesn't have any NaNs\n    # Otherwise we'll get NaN in dV\n    if query_padding_mask is not None:\n        attention = attention.masked_fill(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), 0.0)\n    # Without this we might get NaN in dv\n    if key_padding_mask is not None:\n        attention = attention.masked_fill(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), 0.0)\n    # Some rows might be completely masked out so we fill them with zero instead of NaN\n    if local_mask is not None:\n        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)\n    dropout_scaling = 1.0 / (1 - dropout_p)\n    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling\n    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)\n    if dropout_mask is not None:\n        attention_drop = attention.masked_fill(~dropout_mask, 0.0)\n    else:\n        attention_drop = attention\n    if intermediate_dtype is not None:\n        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)\n    output = torch.einsum(\"bhts,bshd->bthd\", attention_drop, v * dropout_scaling)\n    if 
query_padding_mask is not None:\n        output.masked_fill_(rearrange(~query_padding_mask, \"b s -> b s 1 1\"), 0.0)\n    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)\n"
  },
  {
    "path": "hopper/tile_scheduler.hpp",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include \"cutlass/fast_math.h\"\n#include \"cutlass/arch/barrier.h\"\n\n#include \"named_barrier.hpp\"\n#include \"utils.h\"\n\nnamespace flash {\n\n///////////////////////////////////////////////////////////////////////////////\n\n// Host side kernel arguments\nstruct TileSchedulerArguments {\n    // num_head is num_head_q if not PackGQA, else num_head_k\n    int const num_blocks, num_head, num_batch, num_splits;\n    int const qhead_per_khead;\n    int const seqlen;  // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr\n    int const seqlen_k, headdim, headdim_v, element_size;  // Used to calculate L2 swizzling\n    int* const tile_count_semaphore = nullptr;\n    int const* const cu_seqlens = nullptr;\n    int const* const seqused = nullptr;\n    int const* const num_splits_dynamic_ptr = nullptr;\n    int const* const num_m_blocks_ptr = nullptr;\n    int const* const varlen_batch_idx_ptr = nullptr;\n    // int const* const num_n_blocks_ptr = nullptr;\n    int const* const num_nheads_in_l2_ptr = nullptr;\n};\n\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate<bool Varlen=false, bool Split=false, bool PackGQA=false, int kBlock=128>\nclass SingleTileScheduler {\n\npublic:\n\n    using SharedStorage = int;\n\n    // Device side kernel params\n    struct Params {\n        int const num_blocks, num_head, num_batch, num_splits;\n        int const qhead_per_khead;\n        int const seqlen;\n        cutlass::FastDivmod nsplits_divmod;\n        int const* const cu_seqlens;\n        int const* const seqused;\n        int const* const num_splits_dynamic_ptr = nullptr;\n    };\n\n    static Params\n    
to_underlying_arguments(TileSchedulerArguments const& args) {\n        assert(!Split || !Varlen || args.num_splits_dynamic_ptr != nullptr);\n        assert(!Split || !Varlen || args.num_splits < (1 << 16)); // We use the top 16 bits to store num_splits\n        return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits,\n                args.qhead_per_khead, args.seqlen,\n                cutlass::FastDivmod(!Split ? 1 : args.num_splits),\n                !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused,\n                args.num_splits_dynamic_ptr};\n    }\n\n    static dim3\n    get_grid_shape(Params const& params, int num_sm) {\n        return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)};\n    }\n\n    struct WorkTileInfo {\n        int block_idx = 0;\n        int bidh = 0;\n        int bidb = 0;\n        int split_idx = 0;\n\n        CUTLASS_DEVICE\n        bool\n        is_valid(Params const& params) const {\n            return bidb >= 0;\n        }\n\n        CUTLASS_DEVICE\n        cute::tuple<int32_t, int32_t, int32_t, int32_t>\n        get_block_coord(Params const& params) const {\n            return {block_idx, bidh, bidb, !Split ? 0 : split_idx};\n        }\n\n    };\n\n    CUTLASS_DEVICE\n    SingleTileScheduler(SharedStorage* const smem_scheduler) { }\n\n    template<bool IsProducerWarp=false>\n    CUTLASS_DEVICE\n    WorkTileInfo\n    get_initial_work(Params const& params) const {\n        WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), 0};\n        if constexpr (Split) {\n            int split_idx;\n            work_info.bidh = params.nsplits_divmod.divmod(split_idx, work_info.bidh);\n            work_info.split_idx = split_idx;\n        }\n        bool is_valid_tile = true;\n        if constexpr (Varlen) {\n            int seqlen = params.seqused\n                ? 
params.seqused[work_info.bidb]\n                : (params.cu_seqlens ? params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen);\n            if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; }\n            is_valid_tile = work_info.block_idx * kBlock < seqlen;\n        }\n        if constexpr (Varlen && Split) {\n            int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[work_info.bidb] : params.num_splits;\n            is_valid_tile &= work_info.split_idx < num_splits_dynamic;\n            // Use the top 16 bits to store num_splits\n            work_info.split_idx |= (num_splits_dynamic << 16);\n        }\n        work_info.bidb = is_valid_tile ? work_info.bidb : -1;\n        return work_info;\n    }\n\n    CUTLASS_DEVICE\n    void\n    init_consumer() const {}\n\n    CUTLASS_DEVICE\n    void\n    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}\n\n    template<bool IsProducerWarp=false>\n    CUTLASS_DEVICE\n    WorkTileInfo\n    get_next_work(Params const& params, WorkTileInfo const& current_work) const {\n        return {0, 0, -1, 0};\n    }\n\n};\n\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate<bool Split=false>\nclass StaticPersistentTileScheduler {\n\npublic:\n\n    using SharedStorage = int;\n\n    // Device side kernel params\n    struct Params {\n        int total_blocks;\n        cutlass::FastDivmod m_block_divmod, head_divmod;\n        cutlass::FastDivmod nsplits_divmod;\n    };\n\n    static Params\n    to_underlying_arguments(TileSchedulerArguments const& args) {\n        return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits),\n                cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)),\n                cutlass::FastDivmod(!Split ? 
1 : args.num_splits)};\n    }\n\n    static dim3\n    get_grid_shape(Params const& params, int num_sm) {\n        return {uint32_t(num_sm)};\n    }\n\n    struct WorkTileInfo {\n        int tile_idx;\n\n        CUTLASS_DEVICE\n        bool\n        is_valid(Params const& params) const {\n            return tile_idx < params.total_blocks;\n        }\n\n        CUTLASS_DEVICE\n        cute::tuple<int32_t, int32_t, int32_t, int32_t>\n        get_block_coord(Params const& params) const {\n            int block, bidh, bidb;\n            bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx));\n            int split_idx = 0;\n            if constexpr (Split) {\n                bidh = params.nsplits_divmod.divmod(split_idx, bidh);\n            }\n            return {block, bidh, bidb, split_idx};\n        }\n\n    };\n\n    CUTLASS_DEVICE\n    StaticPersistentTileScheduler(SharedStorage* const smem_scheduler) {};\n\n    template<bool IsProducerWarp=false>\n    CUTLASS_DEVICE\n    WorkTileInfo\n    get_initial_work(Params const& params) const {\n        return {int(blockIdx.x)};\n    }\n\n    CUTLASS_DEVICE\n    void\n    init_consumer() const {}\n\n    CUTLASS_DEVICE\n    void\n    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}\n\n    template<bool IsProducerWarp=false>\n    CUTLASS_DEVICE\n    WorkTileInfo\n    get_next_work(Params const& params, WorkTileInfo const& current_work) const {\n        return {current_work.tile_idx + int(gridDim.x)};\n    }\n\n};\n\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads=cutlass::NumThreadsPerWarp,\n        bool Split=false, bool PackGQA=false, bool WarpSpecialized=true>\nclass DynamicPersistentTileScheduler {\n\n    // This scheduler targets the causal (or local) case where each tile takes different\n    // amount of time. 
We use longest-processing-time-first scheduling:\n    // the longest remaining tile is assigned to the first SM that's free.\n    // An SM indicates that it is free by incrementing a semaphore.\n    // However, we have to make sure K & V still fit into L2 cache, so we perform scheduling\n    // on \"sections\" of the head & batch dimension, each section consisting of e.g. 8 heads.\n    // This is the L2 swizzling part. The size of each section is precomputed based on the\n    // size of K & V and the L2 cache size.\n\n    static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads);\n    static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads;\n\npublic:\n    using SharedStorage = int;\n\nprotected:\n    SharedStorage* const tile_count_smem;\n\npublic:\n\n    // Device side kernel params\n    struct Params {\n        int const total_blocks;\n        cutlass::FastDivmod const m_block_divmod, head_divmod;\n        cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod;\n        cutlass::FastDivmod const l2_minor_residual_divmod;\n        int const num_hb_quotient;\n        int* const tile_count_semaphore;\n    };\n\n    static Params\n    to_underlying_arguments(TileSchedulerArguments const& args) {\n        long long const size_one_kv_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size);\n        int const size_l2 = 32 * 1024 * 1024;  // 32 MB for K & V\n        // Swizzle is the size of each \"section\". Round swizzle to a power of 2\n        // If not PackGQA already, the size of each section can increase by qhead_per_khead\n        // Need to be careful about the case where only one head will fit\n        auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); };\n        // Seems faster if swizzle is a power of 2\n        int const swizzle = (size_l2 < size_one_kv_head ? 1 : (1 << find_log2_floor(size_l2 / size_one_kv_head))) * (PackGQA ? 
1 : args.qhead_per_khead);\n        // If we're in the last section (called residual), we don't want to divide by\n        // swizzle. Instead we want to divide by the remainder.\n        int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle;\n        int const num_split_blocks = args.num_blocks * (!Split ? 1 : args.num_splits);\n        // printf(\"num_split_blocks = %d, num_head = %d, num_batch = %d, swizzle = %d, PackGQA = %d, qhead_per_khead = %d, num_hb_remainder = %d\\n\", num_split_blocks, args.num_head, args.num_batch, swizzle, int(PackGQA), args.qhead_per_khead, num_hb_remainder);\n        assert(args.tile_count_semaphore != nullptr);\n        return {num_split_blocks * args.num_head * args.num_batch,\n                cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head),\n                cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * num_split_blocks),\n                // don't divide by 0\n                cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1),\n                (args.num_head * args.num_batch) / swizzle,\n                args.tile_count_semaphore};\n    }\n\n    static dim3\n    get_grid_shape(Params const& params, int num_sm) {\n        return {uint32_t(num_sm)};\n    }\n\n    struct WorkTileInfo {\n        int tile_idx;\n\n        CUTLASS_DEVICE\n        bool\n        is_valid(Params const& params) const {\n            return tile_idx < params.total_blocks;\n        }\n\n        CUTLASS_DEVICE\n        cute::tuple<int32_t, int32_t, int32_t, int32_t>\n        get_block_coord(Params const& params) const {\n            int block, bidh, bidb;\n            int l2_mod, bidhb, bidhb_residual;\n            bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx);\n            // If we're in the last section (called residual), we don't want to divide by\n            // swizzle. 
Instead we want to divide by the remainder.\n            if (bidhb < params.num_hb_quotient) {\n                block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod);\n            } else {\n                block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod);\n            }\n            bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual);\n            int split_idx = 0;\n            if constexpr (Split) {\n                split_idx = params.m_block_divmod.divmod(block, block);\n            }\n            // Longest-processing-time-first\n            block = params.m_block_divmod.divisor - 1 - block;\n            return {block, bidh, bidb, split_idx};\n        }\n\n    };\n\n    CUTLASS_DEVICE\n    DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {};\n\n    template<bool IsProducerWarp=false>\n    CUTLASS_DEVICE\n    WorkTileInfo\n    get_initial_work(Params const& params) const {\n        return {int(blockIdx.x)};\n    }\n\n    CUTLASS_DEVICE\n    void\n    init_consumer() const {\n        if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) {\n            flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/);  // TileCountSmemEmpty\n        }\n    }\n\n    CUTLASS_DEVICE\n    void\n    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {\n        if (threadIdx.x % NumProducerThreads == 0) {\n            current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);\n        }\n    }\n\n    template<bool IsProducerWarp=false>\n    CUTLASS_DEVICE\n    WorkTileInfo\n    get_next_work(Params const& params, WorkTileInfo const& current_work) const {\n        if constexpr (IsProducerWarp) {\n            // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0\n            int new_tile_idx = __shfl_sync(0xffffffff, 
current_work.tile_idx, 0 /*lane*/);\n            flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/);  // TileCountSmemEmpty\n            if (threadIdx.x % NumProducerThreads == 0) {\n                *tile_count_smem = current_work.tile_idx;\n            }\n            flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/);  // TileCountSmemFull\n            return {new_tile_idx};\n        } else {\n            flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/);  // TileCountSmemFull\n            int tile_idx = *tile_count_smem;\n            flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/);  // TileCountSmemEmpty\n            return {tile_idx};\n        }\n    }\n\n};\n\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool Varlen, int kBlock, bool SPT = false>\nclass SingleTileBwdLPTScheduler {\n\npublic:\n\n    using SharedStorage = int;\n\n    // Device side kernel params\n    struct Params {\n        int const total_blocks;\n        cutlass::FastDivmod const block_divmod, head_divmod;\n        cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod;\n        cutlass::FastDivmod const l2_minor_residual_divmod;\n        int const num_hb_quotient;\n        int const seqlen;\n        int const* const cu_seqlens;\n        int const* const seqused;\n    };\n\n    static Params\n    to_underlying_arguments(TileSchedulerArguments const& args) {\n        // Since it's the bwd pass, seqlen_k gets passed to args.seqlen and seqlen_q gets passed to args.seqlen_k\n        long long const size_one_qdo_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size);\n        long long const size_one_dqaccum_head = long(args.seqlen_k) * long(args.headdim) * sizeof(float);\n        long long const size_one_head = 
size_one_qdo_head + size_one_dqaccum_head;\n        int const size_l2 = 40 * 1024 * 1024;  // 40 MB for Q, dO, and dQaccum\n        // Swizzle is the size of each \"section\". Round swizzle to a power of 2\n        // Need to be careful about the case where only one head will fit\n        auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); };\n        // Seems faster if swizzle is a power of 2\n        int const swizzle = size_l2 < size_one_head ? 1 : (1 << find_log2_floor(size_l2 / size_one_head));\n        // If we're in the last section (called residual), we don't want to divide by\n        // swizzle. Instead we want to divide by the remainder.\n        int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle;\n        // printf(\"num_blocks = %d, num_head = %d, num_batch = %d, size_one_head = %d, ratio = %d, swizzle = %d, num_hb_remainder = %d\\n\", args.num_blocks, args.num_head, args.num_batch, size_one_head, size_l2 / size_one_head, swizzle, num_hb_remainder);\n        assert(args.tile_count_semaphore != nullptr);\n        return {args.num_blocks * args.num_head * args.num_batch,\n                cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head),\n                cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * args.num_blocks),\n                // don't divide by 0\n                cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1),\n                (args.num_head * args.num_batch) / swizzle,\n                args.seqlen, !Varlen ? nullptr : args.cu_seqlens, !Varlen ? 
nullptr : args.seqused};\n    }\n\n    static dim3\n    get_grid_shape(Params const& params, int num_sm) {\n        return {uint32_t(params.total_blocks)};\n    }\n\n    struct WorkTileInfo {\n        int block;\n        int bidh;\n        int bidb;\n\n        CUTLASS_DEVICE\n        bool\n        is_valid(Params const& params) const {\n            return bidb >= 0;\n        }\n\n        CUTLASS_DEVICE\n        cute::tuple<int32_t, int32_t, int32_t, int32_t>\n        get_block_coord(Params const& params) const {\n            return {block, bidh, bidb, 0 /*split_idx*/};\n        }\n\n    };\n\n    CUTLASS_DEVICE\n    SingleTileBwdLPTScheduler(SharedStorage* const smem_scheduler) { }\n\n    template<bool IsProducerWarp=false>\n    CUTLASS_DEVICE\n    WorkTileInfo\n    get_initial_work(Params const& params) const {\n        int tile_idx = blockIdx.x;\n        int block, bidh, bidb;\n        int l2_mod, bidhb, bidhb_residual;\n        bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx);\n        // If we're in the last section (called residual), we don't want to divide by\n        // swizzle. Instead we want to divide by the remainder.\n        if (bidhb < params.num_hb_quotient) {\n            block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod);\n        } else {\n            block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod);\n        }\n        bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual);\n        bool is_valid_tile = true;\n        int num_blocks;\n        if constexpr (Varlen) {\n            int seqlen = params.seqused\n                ? params.seqused[bidb]\n                : (params.cu_seqlens ? 
params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb] : params.seqlen);\n            num_blocks = cute::ceil_div(seqlen, Int<kBlock>{});\n            is_valid_tile = block < num_blocks;\n        } else {\n            num_blocks = params.block_divmod.divisor;\n        }\n        if constexpr (SPT) {\n            block = num_blocks - block - 1;\n        }\n        return {block, bidh, is_valid_tile ? bidb : -1};\n    }\n\n    CUTLASS_DEVICE\n    void\n    init_consumer() const {}\n\n    CUTLASS_DEVICE\n    void\n    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}\n\n    template<bool IsProducerWarp=false>\n    CUTLASS_DEVICE\n    WorkTileInfo\n    get_next_work(Params const& params, WorkTileInfo const& current_work) const {\n        return {0, 0, -1};\n    }\n\n};\n\n///////////////////////////////////////////////////////////////////////////////\n\ntemplate<int kBlockM, int kBlockN, int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads=cutlass::NumThreadsPerWarp,\n         bool Split=false, bool PackGQA=false, bool WarpSpecialized=true, bool LPT = false, bool Sort = false, bool Prepared = true>\nclass VarlenDynamicPersistentTileScheduler {\n\n    static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads);\n    static constexpr int NumThreads = WarpSpecialized ? 
NumMmaThreads + NumProducerThreads : NumMmaThreads;\n\npublic:\n    using SharedStorage = int4;\n\nprotected:\n    SharedStorage* const work_info_smem;\n\npublic:\n\n    // Device side kernel params\n    struct Params {\n        int num_head, num_batch;\n        int const qhead_per_khead;\n        int const seqlen;\n        // int const max_kvblocks_in_l2;\n        cutlass::FastDivmod head_divmod;\n        cutlass::FastDivmod nsplits_divmod;\n        int* const tile_count_semaphore;\n        int const* const cu_seqlens;\n        int const* const seqused;\n        int const* const num_splits_dynamic_ptr;\n        int const* const num_m_blocks_ptr;\n        int const* const varlen_batch_idx_ptr;\n        // int const* const num_n_blocks_ptr;\n        int const* const num_nheads_in_l2_ptr;\n    };\n\n    static Params\n    to_underlying_arguments(TileSchedulerArguments const& args) {\n        // If Split, for the purpose of scheduling, we pretend that instead there are\n        // (args.num_splits * args.num_head) number of heads.\n        assert(args.tile_count_semaphore != nullptr);\n        assert(args.num_head < (1 << 16));  // We use the top 16 bits to store num_splits & split_idx\n        assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits\n        // int const size_l2 = 50 * 1024 * 1024; // 50 MB\n        // int const size_one_kvblock = kBlockN * (args.headdim + args.headdim_v) * args.element_size;\n        // int max_kvblocks_in_l2 = size_l2 / size_one_kvblock;\n        return {args.num_head, args.num_batch,\n                args.qhead_per_khead, args.seqlen,\n                // max_kvblocks_in_l2,\n                cutlass::FastDivmod(args.num_head),\n                cutlass::FastDivmod(!Split ? 
1 : args.num_splits),\n                args.tile_count_semaphore, args.cu_seqlens, args.seqused,\n                args.num_splits_dynamic_ptr,\n                args.num_m_blocks_ptr,\n                args.varlen_batch_idx_ptr,\n                // args.num_n_blocks_ptr,\n                args.num_nheads_in_l2_ptr};\n    }\n\n    static dim3\n    get_grid_shape(Params const& params, int num_sm) {\n        return {uint32_t(num_sm)};\n    }\n\n    struct WorkTileInfo {\n        int tile_idx, block, bidh, bidb;\n\n        CUTLASS_DEVICE\n        bool\n        is_valid(Params const& params) const {\n            // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf(\"blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\\n\", blockIdx.x, threadIdx.x, bidb, params.num_batch); }\n            return bidb < params.num_batch;\n        }\n\n        CUTLASS_DEVICE\n        cute::tuple<int32_t, int32_t, int32_t, int32_t>\n        get_block_coord(Params const& params) const {\n            auto get_actual_batch = [&](int virtual_batch) {\n                if constexpr(Prepared && Sort) {\n                    return params.varlen_batch_idx_ptr[virtual_batch];\n                } else {\n                    return virtual_batch;\n                }\n            };\n            if constexpr (!Split) {\n                return {block, bidh, get_actual_batch(bidb), 0 /*split_idx*/};\n            } else {\n                // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx\n                // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift\n                uint32_t bidh_packed = reinterpret_cast<uint32_t const&>(bidh);\n                uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF;\n                int bidh_actual = reinterpret_cast<int&>(bidh_actual_u);\n                // Use the top 16 bits of split_idx to store num_splits and the next 16 bits to store split_idx\n      
          uint32_t split_idx_u = ((bidh_packed & 0x00FF0000) >> 16) + ((bidh_packed & 0xFF000000) >> 8);\n                int split_idx = reinterpret_cast<int&>(split_idx_u);\n                // int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh);\n                // if (threadIdx.x == 128) {\n                //     printf(\"blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\\n\", blockIdx.x, bidb, bidh, bidh_actual, split_idx);\n                // }\n                return {block, bidh_actual, get_actual_batch(bidb), split_idx};\n            }\n        }\n    };\n\n    CUTLASS_DEVICE\n    VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {};\n\n    CUTLASS_DEVICE\n    WorkTileInfo\n    tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const {\n        int lane = threadIdx.x % cutlass::NumThreadsPerWarp;\n        auto get_num_m_blocks = [&] (int bidb_start) {\n            int batch_idx = lane + bidb_start;\n            if constexpr (Prepared) {\n                return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1\n                    ? params.num_m_blocks_ptr[batch_idx] : 0;\n            } else {\n                int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead);\n                if (seqlen > kBlockM) {\n                    if (params.seqused) {\n                        seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0;\n                    } else if (params.cu_seqlens) {\n                        int cur_cu_seqlen = batch_idx <= params.num_batch ? 
params.cu_seqlens[batch_idx] : 0;\n                        int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);\n                        seqlen = next_cu_seqlen - cur_cu_seqlen;\n                    } else {\n                        seqlen = params.seqlen;\n                    }\n                    if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; }\n                }\n                return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1\n                    ? cute::ceil_div(seqlen, kBlockM) : 0;\n                    // ? params.num_m_blocks_ptr[batch_idx] : 0;\n            }\n        };\n\n        auto get_num_splits = [&] (int bidb_start) {\n            int batch_idx = lane + bidb_start;\n            bool is_valid = batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1;\n            if constexpr (!Split) {\n                return is_valid ? 1 : 0;\n            } else if constexpr(Prepared) {\n                return is_valid ? params.num_splits_dynamic_ptr[batch_idx] : 0;\n            } else {\n                return is_valid ? params.nsplits_divmod.divisor : 0;\n            }\n        };\n\n        int num_m_blocks = get_num_m_blocks(current_work.bidb);  // Different for each lane\n        int num_splits = get_num_splits(current_work.bidb);\n        int num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits;\n        // Cumulative number of blocks for the next 31 batches\n        int num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks);\n        // Total number of blocks for the next 31 batches\n        int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1);\n        // Only the lower 16 bits are the actual bidh\n        // int current_bidh = !Split ? 
current_work.bidh : (current_work.bidh & 0x0000FFFF);\n        // int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head;  // Same for all lanes\n        // if constexpr (Split) {\n        //     int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16;\n        //     group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/);\n        // }\n        // NEW: current_work.tile_idx holds group_start_tile for starting batch\n        int group_end_tile = current_work.tile_idx + m_blocks_in_group * params.num_head;  // Same for all lanes\n        int bidb = current_work.bidb;\n        // if (blockIdx.x <= 9 && threadIdx.x == 0) {\n        //     printf(\"Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\\n\", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group);\n        // }\n        // if (threadIdx.x == 0 && blockIdx.x == 0) { printf(\"tile_idx = %d, group_end_tile = %d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d\\n\", current_work.tile_idx, group_end_tile, num_m_blocks_cumulative, m_blocks_in_group); }\n        while (group_end_tile <= next_tile_idx) {\n            bidb += cutlass::NumThreadsPerWarp - 1;\n            if (bidb >= params.num_batch) {\n                // if (blockIdx.x <= 9 && threadIdx.x == 0) {\n                //     printf(\"Returning early, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\\n\", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group);\n                // }\n  
              return {next_tile_idx, 0, 0, params.num_batch};\n            }\n            num_m_blocks = get_num_m_blocks(bidb);\n            num_splits = get_num_splits(bidb);\n            num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits;\n            num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks);\n            m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1);\n            group_end_tile += m_blocks_in_group * params.num_head;\n            // if (blockIdx.x <= 9 && threadIdx.x == 0) {\n            //     printf(\"Bottom of while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\\n\", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group);\n            // }\n        }\n        int group_start_tile = group_end_tile - m_blocks_in_group * params.num_head;\n        // The next problem to process is the first one that does not have ending tile position\n        // that is greater than or equal to tile index.\n        int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx));\n        // if (threadIdx.x == 31 || threadIdx.x == 0) { printf(\"blockIdx.x = %d, tidx %d, group_start_tile = %d, num_m_blocks_cumulative = %d, num_head = %d, next_tile_idx = %d, ballot = %x, batch_idx_in_group = %d\\n\", blockIdx.x, threadIdx.x, group_start_tile, num_m_blocks_cumulative, params.num_head, next_tile_idx, tmp, batch_idx_in_group); }\n        bidb += batch_idx_in_group;\n        num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group);\n        if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); }\n        group_start_tile += (batch_idx_in_group == 0 ? 
0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head;\n        int mh_block = next_tile_idx - group_start_tile;\n        int block, bidh;\n        if constexpr (LPT) {\n            if (!Split || num_splits == 1) {\n                // NOTE: code for computing nheads_in_l2 directly left as reference\n                // int num_n_blocks = params.num_n_blocks_ptr ? params.num_n_blocks_ptr[bidb] : num_m_blocks;\n                // auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); };\n                // int nheads_in_l2 = params.max_kvblocks_in_l2 < num_n_blocks\n                //     ? 1 : 1 << find_log2_floor(params.max_kvblocks_in_l2 / num_n_blocks);\n                // if constexpr (!PackGQA) { nheads_in_l2 *= params.qhead_per_khead; }\n                // nheads_in_l2 = min(nheads_in_l2, params.num_head);\n                auto get_nheads_in_l2 = [&](int batch_idx) {\n                    if constexpr(Prepared) {\n                        return params.num_nheads_in_l2_ptr[batch_idx];\n                    } else {\n                        return !PackGQA ? params.qhead_per_khead : 1;\n                    }\n                };\n                int nheads_in_l2 = get_nheads_in_l2(bidb);\n                int mh_in_l2 = nheads_in_l2 * num_m_blocks;\n                int section_idx = mh_block / mh_in_l2;\n                int l2_mod = mh_block - section_idx * mh_in_l2;\n                // tail section\n                int nheads_remainder = params.num_head - section_idx * nheads_in_l2;\n                int nheads_in_this_section = nheads_in_l2 <= nheads_remainder ? 
nheads_in_l2 : nheads_remainder;\n                block = l2_mod / nheads_in_this_section;\n                int bidh_residual = l2_mod - block * nheads_in_this_section;\n                bidh = section_idx * nheads_in_l2 + bidh_residual;\n                if constexpr(Split) {\n                    // remember to set num_splits = 1 in work tile\n                    uint32_t bidh_packed = reinterpret_cast<uint32_t&>(bidh) + (reinterpret_cast<uint32_t&>(num_splits) << 24);\n                    bidh = reinterpret_cast<int&>(bidh_packed);\n                }\n            } else {\n                // NOTE: leave traverse heads first version for reference\n                // block = params.head_divmod.divmod(bidh, mh_block);\n                // if constexpr (Split) {\n                //     int split_idx = block / num_m_blocks;\n                //     block = block - split_idx * num_m_blocks;\n                //     uint32_t bidh_packed = reinterpret_cast<uint32_t&>(bidh) + (reinterpret_cast<uint32_t&>(split_idx) << 16) + (reinterpret_cast<uint32_t&>(num_splits) << 24);\n                //     bidh = reinterpret_cast<int&>(bidh_packed);\n                // }\n                bidh = mh_block / num_m_blocks;\n                block = mh_block - bidh * num_m_blocks;\n                if constexpr (Split) {\n                    int bidh_actual = bidh / num_splits;\n                    int split_idx = bidh - bidh_actual * num_splits;\n                    uint32_t bidh_packed = reinterpret_cast<uint32_t&>(bidh_actual) + (reinterpret_cast<uint32_t&>(split_idx) << 16) + (reinterpret_cast<uint32_t&>(num_splits) << 24);\n                    bidh = reinterpret_cast<int&>(bidh_packed);\n                }\n            }\n            block = num_m_blocks - 1 - block;\n        } else {\n            bidh = mh_block / num_m_blocks;\n            block = mh_block - bidh * num_m_blocks;\n            if constexpr (Split) {\n                int bidh_actual = bidh / num_splits;\n                int 
split_idx = bidh - bidh_actual * num_splits;\n                // TODO: idk why this gives wrong answer nondeterministically\n                // int bidh_actual, split_idx;\n                // split_idx = params.head_divmod.divmod(bidh_actual, bidh);\n                // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx\n                // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift\n                uint32_t bidh_packed = reinterpret_cast<uint32_t&>(bidh_actual) + (reinterpret_cast<uint32_t&>(split_idx) << 16) + (reinterpret_cast<uint32_t&>(num_splits) << 24);\n                // if (threadIdx.x == 0) {\n                //     printf(\"blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\\n\", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed);\n                // }\n                bidh = reinterpret_cast<int&>(bidh_packed);\n            }\n            // if (blockIdx.x <= 9 && threadIdx.x == 0) {\n            //     printf(\"Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\\n\", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block);\n            // }\n        }\n        return {group_start_tile, block, bidh, bidb};\n    }\n\n    template<bool IsProducerWarp=false>\n    CUTLASS_DEVICE\n    WorkTileInfo\n    get_initial_work(Params const& params) const {\n        if constexpr (IsProducerWarp) {\n            WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0});\n       
     if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {\n                *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb);\n            }\n            flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/);  // TileCountSmemFull\n            return work_info;\n        } else {\n            return get_next_work<false>(params, {0, 0, 0, 0});\n        }\n    }\n\n    CUTLASS_DEVICE\n    void\n    init_consumer() const {\n        // Don't arrive at the TileCountSmemEmpty barrier here, because get_initial_work will do that\n    }\n\n    CUTLASS_DEVICE\n    void\n    prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {\n        if (threadIdx.x % NumProducerThreads == 0) {\n            current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);\n        }\n    }\n\n    template<bool IsProducerWarp=false>\n    CUTLASS_DEVICE\n    WorkTileInfo\n    get_next_work(Params const& params, WorkTileInfo const& current_work) const {\n        if constexpr (IsProducerWarp) {\n            // thread 0 has the next tile_idx, just need to broadcast to the rest of warp 0\n            int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/);\n            WorkTileInfo work_info = {__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb};\n            work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info);\n            flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/);  // TileCountSmemEmpty\n            if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {\n                *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb);\n            }\n            flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/);  // 
TileCountSmemFull\n            return work_info;\n        } else {\n            flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/);  // TileCountSmemFull\n            int4 work_info = *work_info_smem;\n            flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/);  // TileCountSmemEmpty\n            return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w};\n        }\n    }\n\n};\n\n///////////////////////////////////////////////////////////////////////////////\n\n} // flash\n"
  },
  {
    "path": "hopper/tile_size.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <tuple>\n\n// Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap}\nconstexpr std::tuple<int, int, bool, bool> tile_size_fwd_sm90(\n        int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2,\n        bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) {\n    if (element_size == 2) {\n        if (headdim <= 64) {\n            // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim};\n            // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why\n            // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131\n            if (headdim_v == 512) {\n                return {64, 64, false, false};\n            } else if (headdim_v == 256) {\n                return {128, 96, true, false};\n            } else {\n                // Switch to tile size 192 x 192 for now\n                bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA;\n                return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true};\n            }\n            // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen\n            // return {192, is_causal || is_local ? 192 : 176, true, false};\n        } else if (headdim <= 96) {\n            return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true};\n        } else if (headdim <= 128) {\n            bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA;\n            return {128, use_blockN_128 ? 
128 : 176, true, true};\n            // {128, 192, true, false} and {192, 128, false, true} are quite good too\n            // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS\n        } else if (headdim <= 192) {\n            return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true};  // 128 x 112 hits the limit of smem\n        } else {\n            return {128, is_local ? 64 : 80, true, true};  // 128 x 80 hits the limit of smem\n        }\n    } else {\n        if (headdim <= 64) {\n            return {192, 160, true, true};\n        } else if (headdim <= 96) {\n            return {192, 128, true, true};\n        } else if (headdim <= 128) {\n            return {128, paged_kv_non_TMA ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true};\n        } else if (headdim <= 192) {\n            return {128, (paged_kv_non_TMA || softcap) && is_local ? 128 : 160, true, true};\n        } else {\n            return {128, is_local ? 64 : 128, true, !paged_kv_non_TMA};  // PagedKV uses more registers so we disabled IntraWGOverlap\n        }\n    }\n}\n\n// Return {kBlockM, kBlockN, kNWarps, kStages, Q_in_regs}\nconstexpr std::tuple<int, int, int, int, bool> tile_size_fwd_sm8x(\n        bool sm86_or_89, int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2,\n        bool paged_kv=false, bool varlen_and_split=false,\n        bool softcap=false, bool append_kv=false) {\n    if (element_size == 2) {\n        if (headdim <= 64) {\n            return {128, varlen_and_split ? 80 : (is_local ? 96 : 112), 4, 1, false};\n        } else if (headdim <= 96) {\n            return {128, varlen_and_split || is_local ? 48 : 64, 4, 1, false};\n        } else if (headdim <= 128) {\n            bool const use_8_warps = sm86_or_89 | varlen_and_split;\n            return {128, use_8_warps ? (varlen_and_split ? (is_local ? 96 : 112) : (is_local ? 96 : 128)) : (is_local ? 
48 : 64), use_8_warps ? 8 : 4, 1, use_8_warps};\n        } else if (headdim <= 192) {\n            bool const kBlockN_64 = append_kv || is_local || varlen_and_split || paged_kv;\n            return {128, kBlockN_64 ? 64 : 96, 8, sm86_or_89 ? 1 : 2, !kBlockN_64};\n        } else {\n            return {128, sm86_or_89 ? (append_kv ? 32 : (varlen_and_split || is_local ? 48 : 64)) : (append_kv ? 48 : (varlen_and_split || is_local ? 64 : 96)), 8, 1, sm86_or_89 && !append_kv};\n        }\n    } else {\n        // Placeholder for now\n        return {128, 64, 8, 2, false};\n    }\n}\n"
  },
  {
    "path": "hopper/utils.h",
    "content": "/******************************************************************************\n * Copyright (c) 2024, Tri Dao.\n ******************************************************************************/\n\n#pragma once\n\n#include <assert.h>\n#include <stdint.h>\n#include <stdlib.h>\n\n#include <cuda_fp16.h>\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n#include <cuda_bf16.h>\n#endif\n\n#include <cute/tensor.hpp>\n\n#include <cutlass/cutlass.h>\n#include <cutlass/array.h>\n#include <cutlass/numeric_conversion.h>\n#include <cutlass/numeric_types.h>\n\n#include \"cuda_check.h\"\n\nnamespace flash {\n\nusing namespace cute;\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// A wrapper for the kernel that is used to guard against compilation on\n// architectures that will never use the kernel. The purpose of this is to\n// reduce the size of the compiled binary.\n// Adapted from https://github.com/vllm-project/vllm/blob/4d29e91be84d27ca313d657eee92c067439a4c23/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh#L55\ntemplate <typename Kernel>\nstruct enable_sm90 : Kernel {\n    template <typename... Args>\n    CUTLASS_DEVICE void operator()(Args&&... args) {\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)\n        Kernel::operator()(std::forward<Args>(args)...);\n#endif\n    }\n};\n\ntemplate <typename Kernel>\nstruct enable_sm80_to_sm89 : Kernel {\n    template <typename... Args>\n    CUTLASS_DEVICE void operator()(Args&&... args) {\n#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ <= 890)\n        Kernel::operator()(std::forward<Args>(args)...);\n#endif\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T>\nstruct MaxOp {\n__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? 
x : y; }\n};\n\ntemplate <>\nstruct MaxOp<float> {\n// This is slightly faster\n__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename T>\nstruct SumOp {\n__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<int THREADS>\nstruct Allreduce {\n    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);\n    template<typename T, typename Operator>\n    static __device__ __forceinline__ T run(T x, Operator &op) {\n        constexpr int OFFSET = THREADS / 2;\n        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));\n        return Allreduce<OFFSET>::run(x, op);\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<>\nstruct Allreduce<2> {\ntemplate<typename T, typename Operator>\nstatic __device__ __forceinline__ T run(T x, Operator &op) {\n    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));\n    return x;\n}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nCUTLASS_HOST_DEVICE\nint div_floor(cutlass::FastDivmod const& divmod, int dividend) {\n    // Take care of the negative case: https://stackoverflow.com/questions/39304681/division-with-negative-dividend-but-rounded-towards-negative-infinity\n    // Maybe the compiler will turn the -1 - * into bit negation operation, I haven't checked.\n    return dividend >= 0 ? 
divmod.divide(dividend) : -1 - divmod.divide(-1 - dividend);\n}\n\nCUTLASS_HOST_DEVICE\nint round_down(cutlass::FastDivmod const& divmod, int dividend) {\n    return div_floor(divmod, dividend) * divmod.divisor;\n}\n\nCUTLASS_HOST_DEVICE\nint round_up(cutlass::FastDivmod const& divmod, int dividend) {\n    return div_floor(divmod, dividend - 1) * divmod.divisor + divmod.divisor;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))\n// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))\ntemplate<bool Transposed=false, typename Layout0>\nCUTLASS_DEVICE auto convert_layout_acc_rowcol(Layout0 acc_layout) {\n    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90\n        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);\n        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);\n        static_assert(decltype(rank(acc_layout))::value == 3);\n        auto l = acc_layout;\n        if constexpr (!Transposed) {\n            return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));\n        } else {\n             return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));\n        }\n\n    } else {  // SM80\n        static_assert(decltype(size<0>(acc_layout))::value == 4);\n        static_assert(decltype(rank(acc_layout))::value == 3);\n        auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)\n        if constexpr (!Transposed) {\n            return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));\n        } else {\n            return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));\n        }\n    
}\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)\n// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.\n// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))\n// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))\ntemplate<typename MMA_Traits, typename Layout0>\nCUTLASS_DEVICE auto convert_layout_acc_Aregs(Layout0 acc_layout) {\n    using X = Underscore;\n    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90\n        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);\n        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);\n        static_assert(decltype(rank(acc_layout))::value == 3);\n        static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);\n        if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {\n            auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{});  // ((2, N / 16))\n            return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));\n        } else {\n            static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);\n            static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);\n            static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);\n            auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{});  // (((2, 2), N / 32))\n            // This combines the first two modes (<0, 0> and <0, 1>) into one mode.\n            // Will require register shuffling later to be correct.\n            return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),\n                               get<1>(acc_layout),\n          
                     coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));  // ((4, 2, 2), MMA_M, N / 32 * MMA_N)\n            // This combination is right but doesn't work with register shuffling.\n            // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),\n            //                    get<1>(acc_layout),\n            //                    coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));\n        }\n    } else {  // SM80\n        static_assert(decltype(size<0>(acc_layout))::value == 4);\n        static_assert(decltype(rank(acc_layout))::value == 3);\n        constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});\n        static_assert(mma_shape_K == 8 || mma_shape_K == 16);\n        if constexpr (mma_shape_K == 8) {\n            return acc_layout;\n        } else {\n            auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))\n            return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));\n        }\n    }\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename To_type, typename Engine, typename Layout>\nCUTLASS_DEVICE auto convert_type_unsafe(Tensor<Engine, Layout> const &tensor) {\n    using From_type = typename Engine::value_type;\n    static constexpr int numel = decltype(size(tensor))::value;\n    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;\n    // HACK: this requires tensor to be \"contiguous\"\n    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));\n    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());\n    // Unsafe because we're returning a tensor with memory allocated on the stack. 
If the compiler does not\n    // inline this function, then the memory might not be valid.\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Engine, typename Layout, typename EngineOut>\nCUTLASS_DEVICE void convert_type_out(Tensor<Engine, Layout> const &tensor, Tensor<EngineOut, Layout> &out) {\n    // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong.\n    using From_type = typename Engine::value_type;\n    using To_type = typename EngineOut::value_type;\n    static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type));\n    static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, \"Fragment size does not vectorize properly\");\n    Tensor frag = recast<cutlass::Array<From_type, FragmentSize> const>(tensor);\n    Tensor out_frg = recast<cutlass::Array<To_type, FragmentSize>>(out);\n    static_assert(size(frag) == size(out_frg));\n    cutlass::NumericArrayConverter<To_type, From_type, FragmentSize> convert_op;\n    #pragma unroll\n    for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// Blocks until all but N previous cp.async.commit_group operations have committed.\n// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all\n// (which is equivalent to commit_group then wait_group 0).\n// Instead we just call cp.async.wait_group 0, which is slightly faster.\n// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113\ntemplate <int N>\nCUTE_HOST_DEVICE\nvoid cp_async_wait() {\n#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)\n    asm volatile(\"cp.async.wait_group %0;\\n\" :: 
\"n\"(N));\n#endif\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool A, class Mma, class Tensor0>\nCUTLASS_DEVICE\nauto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) {\n    if constexpr (A) {\n        return mma.partition_fragment_A(tensor0);\n    } else {\n        return mma.partition_fragment_B(tensor0);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool zero_init=false, int wg_wait=0, bool SwapAB=false, int M_slice=-1,\n        typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>\nCUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) {\n    if constexpr (M_slice >= 0) {\n        static constexpr int MMA_M = decltype(size<1>(tCrC))::value;\n        static_assert(M_slice < MMA_M);\n        // After logical_divide, C has shape ((2,2,V), (MMA_M, 1), MMA_N)\n        Tensor tCrC_slice = cute::logical_divide(tCrC, Shape<cute::Underscore, Int<MMA_M>>{})(_, make_coord(Int<M_slice>{}, _), _);\n        if constexpr (!SwapAB) {\n            Tensor tCrA_slice = cute::logical_divide(tCrA, Shape<cute::Underscore, Int<MMA_M>>{})(_, make_coord(Int<M_slice>{}, _), _);\n            gemm<zero_init, wg_wait, SwapAB, /*M_slice=*/-1>(tiled_mma, tCrA_slice, tCrB, tCrC_slice);\n        } else {\n            Tensor tCrB_slice = cute::logical_divide(tCrB, Shape<cute::Underscore, Int<MMA_M>>{})(_, make_coord(Int<M_slice>{}, _), _);\n            gemm<zero_init, wg_wait, SwapAB, /*M_slice=*/-1>(tiled_mma, tCrA, tCrB_slice, tCrC_slice);\n        }\n    } else {\n        constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;\n        // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const\n        if constexpr (Is_RS) {\n            if constexpr (!SwapAB) {\n              
  warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));\n            } else {\n                warpgroup_fence_operand(const_cast<Tensor1 &>(tCrB));\n            }\n        }\n        warpgroup_fence_operand(tCrC);\n        warpgroup_arrive();\n        if constexpr (zero_init) {\n            tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;\n        }\n        static constexpr int kNumKIters = CUTE_STATIC_V(size<2>(tCrA));\n        static constexpr int kMaxKIters = 16;\n        // Unroll the K mode manually to set scale D to 1\n        CUTLASS_PRAGMA_UNROLL\n        for (int k_block = 0; k_block < std::min(kNumKIters, kMaxKIters); ++k_block) {\n            if constexpr (!SwapAB) {\n                cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);\n            } else {\n                cute::gemm(tiled_mma, tCrB(_,_,k_block), tCrA(_,_,k_block), tCrC);\n            }\n            tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n        }\n        // In the case of large kNumKIters, the compiler chooses to store the smem addresses\n        // in registers, causing spills. This loop forces the compiler to recompute the addresses.\n        if constexpr (kNumKIters > kMaxKIters) {\n            // This will always be zero, just a way to force the compiler to recompute the smem\n            // addresses. This results in USEL instructions. There's probably a better way to do this.\n            int const k_offset = cutlass::canonical_warp_group_idx() < 128 ? 
0 : 1;\n            CUTLASS_PRAGMA_UNROLL\n            for (int k_block = kMaxKIters; k_block < kNumKIters; ++k_block) {\n                if constexpr (!SwapAB) {\n                    cute::gemm(tiled_mma, tCrA(_,_,k_block + k_offset), tCrB(_,_,k_block + k_offset), tCrC);\n                } else {\n                    cute::gemm(tiled_mma, tCrB(_,_,k_block + k_offset), tCrA(_,_,k_block + k_offset), tCrC);\n                }\n                tiled_mma.accumulate_ = GMMA::ScaleOut::One;\n            }\n        }\n        warpgroup_commit_batch();\n        if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }\n        warpgroup_fence_operand(tCrC);\n        if constexpr (Is_RS) {\n            if constexpr (!SwapAB) {\n                warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));\n            } else {\n                warpgroup_fence_operand(const_cast<Tensor1 &>(tCrB));\n            }\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<bool A_in_regs=false, bool B_in_regs=false, bool SwapAB=false,\n         typename Tensor0, typename Tensor1,\n         typename Tensor2, typename Tensor3, typename Tensor4,\n         typename TiledMma, typename TiledCopyA, typename TiledCopyB,\n         typename ThrCopyA, typename ThrCopyB, typename Hook>\nCUTLASS_DEVICE void gemm_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,\n                              Tensor4 const& tCsB, TiledMma tiled_mma,\n                              TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,\n                              ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B, Hook fn) {\n    if constexpr (SwapAB) {\n        gemm_sm80<B_in_regs, A_in_regs>(acc, tCrB, tCrA, tCsB, tCsA, tiled_mma, smem_tiled_copy_B, smem_tiled_copy_A, smem_thr_copy_B, smem_thr_copy_A, fn);\n    } else {\n        CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M\n  
      CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N\n        CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                     // MMA_K\n        Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);\n        CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));            // M\n        Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);\n        CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N\n        if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }\n        if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }\n        #pragma unroll\n        for (int i = 0; i < size<2>(tCrA); ++i) {\n            if (i < size<2>(tCrA) - 1) {\n                if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }\n                if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }\n            }\n            if constexpr (!std::is_same_v<Hook, std::nullptr_t>) {\n                if (i == 0) { fn(); }\n            }\n            cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,\n         typename TiledMma, typename TiledCopy, typename ThrCopy>\nCUTLASS_DEVICE void gemm_rs_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,\n                                 TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,\n                                 ThrCopy smem_thr_copy_B) {\n    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M\n    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N\n    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == 
size<2>(tCrB));                     // MMA_K\n    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);\n    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N\n    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));\n    #pragma unroll\n    for (int i = 0; i < size<2>(tCrA); ++i) {\n        if (i < size<2>(tCrA) - 1) {\n            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));\n        }\n        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool zero_init=false, typename Atom, typename TA, typename TB, typename TC>\nCUTLASS_DEVICE void gemm_sm100(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {\n    static constexpr int rA = decltype(rank(tA))::value;\n    static constexpr int rB = decltype(rank(tB))::value;\n    static constexpr int rC = decltype(rank(tC))::value;\n    static_assert(rA == 3 && rB == 3 && rC == 3);\n\n    if constexpr (zero_init) { atom.accumulate_ = decltype(atom.accumulate_)::Zero; }\n    CUTLASS_PRAGMA_UNROLL\n    for (int k_block = 0; k_block < size<2>(tA); k_block++) {\n        cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC);\n        atom.accumulate_ = decltype(atom.accumulate_)::One;\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... 
TMs>\nCUTE_HOST_DEVICE constexpr\nauto\nto_tiled_mma_sm100_ts(\n    TiledMMA<MMA_Atom<\n      MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,\n                    cute::C<M>, cute::C<N>,\n                    cute::integral_constant<UMMA::Major, a_major>,\n                    cute::integral_constant<UMMA::Major, b_major>,\n                    cute::integral_constant<UMMA::ScaleIn, a_neg>,\n                    cute::integral_constant<UMMA::ScaleIn, b_neg>>,\n      TAs...>, TMs...>) {\n\n  return TiledMMA<MMA_Atom<\n    MMA_Traits<SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,\n                                M, N,\n                                a_major, b_major,\n                                a_neg, b_neg, UMMA::Saturate::False>>,\n    TAs...>, TMs...>{};\n}\n\ntemplate <class a_type, class b_type, class c_type,\n          int M, int N, UMMA::Major a_major, UMMA::Major b_major,\n          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>\nCUTE_HOST_DEVICE constexpr\nauto\nto_tiled_mma_sm100_ts(\n    TiledMMA<MMA_Atom<\n      SM100_MMA_F16BF16_SS<a_type, b_type, c_type,\n                    M, N,\n                    a_major,\n                    b_major,\n                    a_neg,\n                    b_neg>,\n      TAs...>, TMs...>) {\n  return TiledMMA<MMA_Atom<\n    SM100_MMA_F16BF16_TS<a_type, b_type, c_type,\n                                M, N,\n                                a_major, b_major,\n                                a_neg, b_neg, UMMA::Saturate::False>,\n    TAs...>, TMs...>{};\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,\n          class CopyAtom, class TV, class Tiler, typename Engine0, typename Layout0, typename Engine1, typename Layout1,\n          typename Engine2, typename Layout2, typename Engine3, typename Layout3>\nCUTLASS_DEVICE void 
copy(TiledCopy<CopyAtom, TV, Tiler> const &tiled_copy, Tensor<Engine0, Layout0> const &S,\n                         Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,\n                         Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {\n    // Decay TiledCopy to CopyAtom\n    auto copy_atom = static_cast<CopyAtom const&>(tiled_copy);\n    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});\n    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});\n    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA\n    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M\n    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K\n    // There's no case where !Clear_OOB_K && Clear_OOB_MN\n    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));\n    auto has_with_bool = cute::is_valid([](auto t)->void_t<decltype(declval<typename decltype(t)::Traits>().with(true))>{}, copy_atom);\n    #pragma unroll\n    for (int m = 0; m < size<1>(S); ++m) {\n        bool predicate_mn = Is_even_MN || get<0>(identity_MN(_0{}, m, _0{})) < max_MN;\n        if constexpr (Is_even_MN || !Clear_OOB_MN) {\n            if (Is_even_MN || predicate_mn) {\n                #pragma unroll\n                for (int k = 0; k < size<2>(S); ++k) {\n                    if constexpr (Is_even_K || !Clear_OOB_K) {\n                        if (Is_even_K || predicate_K(k)) { cute::copy(copy_atom, S(_, m, k), D(_, m, k)); }\n                    } else {  // Clear_OOB_K == true && Is_even_K == false\n                        // If copy traits can be transformed with a predicate value, do it, otherwise branch here\n                        if constexpr (has_with_bool) {\n                            cute::copy(copy_atom.with(predicate_K(k)), S(_, m, k), D(_, m, k));\n                        } else {\n                            if (predicate_K(k)) {\n                                cute::copy(copy_atom, S(_, m, k), D(_, m, k));\n    
                        } else {\n                                cute::clear(D(_, m, k));\n                            }\n                        }\n                    }\n                }\n            }\n        } else {  // Clear_OOB_MN == true && Is_even_MN == false, also implies Clear_OOB_K == true\n            if constexpr (!has_with_bool) {\n                if (predicate_mn) {\n                    #pragma unroll\n                    for (int k = 0; k < size<2>(S); ++k) {\n                        if (Is_even_K || predicate_K(k)) {\n                            cute::copy(copy_atom, S(_, m, k), D(_, m, k));\n                        } else if (Clear_OOB_K) {\n                            cute::clear(D(_, m, k));\n                        }\n                    }\n                } else {\n                    cute::clear(D(_, m, _));\n                }\n            } else {  // combine the mn predicate with the k predicate\n                #pragma unroll\n                for (int k = 0; k < size<2>(S); ++k) {\n                    cute::copy(copy_atom.with(predicate_mn && (Is_even_K || predicate_K(k))), S(_, m, k), D(_, m, k));\n                }\n            }\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n// Byte permute and shuffle to match register layout of\n// (FP8 downcasted) accumulator of GEMM-I to FP8 operand A of GEMM-II.\ntemplate <typename Fragment>\nCUTLASS_DEVICE void permute_Aregs_fp8(Fragment &frag) {\n    // frag has shape ((4, 2, 2), MMA_M, MMA_N), each element is 8 bits\n    static_assert(decltype(size<0, 0>(frag))::value == 4);\n    static_assert(decltype(size<0, 1>(frag))::value == 2);\n    static_assert(decltype(stride<0, 0>(frag))::value == 1);\n    static_assert(decltype(stride<0, 1>(frag))::value == 4);\n    static_assert(sizeof(typename Fragment::value_type) == 1);\n\n    int quad_idx = threadIdx.x % 4;\n    bool lane_03 = quad_idx == 0 || quad_idx == 3;\n    
int selector_upper = lane_03 ? 0x5410 : 0x1054;\n    int selector_lower = lane_03 ? 0x7632 : 0x3276;\n\n    static constexpr int upper_map[4] = {0, 3, 1, 2};\n    // static constexpr int lower_map[4] = {1, 2, 0, 3};\n\n    Tensor frag_64b = recast<uint2>(frag);  // ((1, 1, 2), MMA_M, MMA_N)\n    #pragma unroll\n    for (int i = 0; i < size(frag_64b); ++i) {\n        uint32_t upper = frag_64b[i].x;\n        uint32_t lower = frag_64b[i].y;\n        uint32_t upper0 = lane_03 ? upper : lower;\n        uint32_t lower0 = lane_03 ? lower : upper;\n        upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4);\n        // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4);\n        lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 1, 4);\n        frag_64b[i].x = __byte_perm(upper0, lower0, selector_upper);\n        frag_64b[i].y = __byte_perm(upper0, lower0, selector_lower);\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Fragment>\nCUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {\n    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits\n    static_assert(decltype(size<0, 0>(frag))::value == 2);\n    static_assert(decltype(size<0, 1>(frag))::value == 2);\n    static_assert(decltype(size<0, 2>(frag))::value % 2 == 0);\n    static_assert(decltype(stride<0, 0>(frag))::value == 1);\n    static_assert(sizeof(typename Fragment::value_type) == 4);\n    Tensor frag_64b = group_modes<1, 3>(recast<uint2>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))\n    #pragma unroll\n    for (int mi = 0; mi < size<1>(frag_64b); ++mi) {\n        #pragma unroll\n        for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) {\n            cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi));\n        }\n    
}\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Fragment>\nCUTLASS_DEVICE void permute_output_fp8(Fragment &out) {\n    // out has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits\n    static_assert(decltype(size<0, 0>(out))::value == 2);\n    static_assert(decltype(size<0, 1>(out))::value == 2);\n    static_assert(decltype(size<0, 2>(out))::value % 2 == 0);\n    static_assert(decltype(stride<0, 0>(out))::value == 1);\n    static_assert(sizeof(typename Fragment::value_type) == 4);\n    Tensor frag = group_modes<1, 3>(out);  // ((2, 2, N / 8), (MMA_M, MMA_N))\n    #pragma unroll\n    for (int mi = 0; mi < size<1>(frag); ++mi) {\n        #pragma unroll\n        for (int j = 0; j < size<0, 1>(frag); ++j) {\n            #pragma unroll\n            for (int i = 0; i < size<0, 2>(frag) / 2; ++i) {\n                cutlass::swap(frag(make_coord(_1{}, j, 2 * i), mi), frag(make_coord(_0{}, j, 2 * i + 1), mi));\n            }\n        }\n    }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Fragment>\nCUTLASS_DEVICE void permute_output_fp8_Vcolmajor(Fragment &frag) {\n    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 16 bits\n    static_assert(decltype(size<0, 0>(frag))::value == 2);\n    static_assert(decltype(size<0, 1>(frag))::value == 2);\n    static_assert(decltype(stride<0, 0>(frag))::value == 1);\n    static_assert(sizeof(typename Fragment::value_type) == 2 || sizeof(typename Fragment::value_type) == 4);\n\n    int quad_idx = threadIdx.x % 4;\n    bool lane_03 = quad_idx == 0 || quad_idx == 3;\n\n    static constexpr int upper_map[4] = {0, 2, 3, 1};\n    // static constexpr int lower_map[4] = {2, 0, 1, 3};\n\n    // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); }\n    using type2 = std::conditional_t<sizeof(typename Fragment::value_type) == 2, uint32_t, 
uint64_t>;\n    Tensor frag_2 = group_modes<1, 3>(recast<type2>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))\n    // if (blockIdx.x == 0 && threadIdx.x == 128) { print(frag); printf(\"\\n\"); print(frag_2); }\n    #pragma unroll\n    for (int mi = 0; mi < size<1>(frag_2); ++mi) {\n        #pragma unroll\n        for (int j = 0; j < size<0, 1>(frag_2); ++j) {\n            #pragma unroll\n            for (int i = 0; i < size<0, 2>(frag_2) / 2; ++i) {\n                type2 upper = frag_2(make_coord(_0{}, j, 2 * i), mi);\n                type2 lower = frag_2(make_coord(_0{}, j, 2 * i + 1), mi);\n                type2 upper0 = lane_03 ? upper : lower;\n                type2 lower0 = lane_03 ? lower : upper;\n                upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4);\n                // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4);\n                lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 2, 4);\n                frag_2(make_coord(_0{}, j, 2 * i), mi) = lane_03 ? upper0 : lower0;\n                frag_2(make_coord(_0{}, j, 2 * i + 1), mi) = lane_03 ? 
lower0 : upper0;\n            }\n        }\n    }\n    // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); }\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename Engine, typename Layout>\nCUTLASS_DEVICE void apply_softcap(Tensor<Engine, Layout> &tensor, float const softcap){\n    #pragma unroll\n    for (int i = 0; i < size(tensor); ++i) {\n        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);\n    }\n}\n\ntemplate <typename Engine, typename Layout>\nCUTLASS_DEVICE auto calculate_dtanh(Tensor<Engine, Layout> &tensor){\n    Tensor out = make_fragment_like<float>(tensor);\n    #pragma unroll\n    for (int i = 0; i < size(tensor); ++i) {\n        out(i) = 1.f - (tensor(i) * tensor(i));\n    }\n    return out;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<class T>\nCUTE_DEVICE T warp_prefix_sum(T val) {\n    int lane = threadIdx.x % cutlass::NumThreadsPerWarp;\n    CUTLASS_PRAGMA_UNROLL\n    for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) {\n        T partial_sum = __shfl_up_sync(0xffffffff, val, i);\n        if (lane >= i) { val += partial_sum; }\n    }\n    return val;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate<class T>\nCUTE_DEVICE T warp_uniform(T a) {\n    return __shfl_sync(0xffffffff, a, 0);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\nCUTLASS_DEVICE\nint canonical_warp_group_idx_nosync() {\n    return threadIdx.x / cutlass::NumThreadsPerWarpGroup;\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n\n} // namespace flash\n"
  },
  {
    "path": "setup.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport sys\nimport functools\nimport warnings\nimport os\nimport re\nimport ast\nimport glob\nimport shutil\nfrom pathlib import Path\nfrom typing import Literal, Optional\nfrom packaging.version import parse, Version\nimport platform\n\nfrom setuptools import setup, find_packages\nimport subprocess\n\nimport urllib.request\nimport urllib.error\nfrom wheel.bdist_wheel import bdist_wheel as _bdist_wheel\n\nimport torch\nfrom torch.utils.cpp_extension import (\n    BuildExtension,\n    CppExtension,\n    CUDAExtension,\n    CUDA_HOME,\n    ROCM_HOME,\n    IS_HIP_EXTENSION,\n)\n\n\nwith open(\"README.md\", \"r\", encoding=\"utf-8\") as fh:\n    long_description = fh.read()\n\n\n# ninja build does not work unless include_dirs are abs path\nthis_dir = os.path.dirname(os.path.abspath(__file__))\n\nBUILD_TARGET = os.environ.get(\"BUILD_TARGET\", \"auto\")\n\nif BUILD_TARGET == \"auto\":\n    if IS_HIP_EXTENSION:\n        IS_ROCM = True\n    else:\n        IS_ROCM = False\nelse:\n    if BUILD_TARGET == \"cuda\":\n        IS_ROCM = False\n    elif BUILD_TARGET == \"rocm\":\n        IS_ROCM = True\n\nPACKAGE_NAME = \"flash_attn\"\n\nBASE_WHEEL_URL = (\n    \"https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}\"\n)\n\n# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels\n# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation\nFORCE_BUILD = os.getenv(\"FLASH_ATTENTION_FORCE_BUILD\", \"FALSE\") == \"TRUE\"\nSKIP_CUDA_BUILD = os.getenv(\"FLASH_ATTENTION_SKIP_CUDA_BUILD\", \"FALSE\") == \"TRUE\"\n# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI\nFORCE_CXX11_ABI = os.getenv(\"FLASH_ATTENTION_FORCE_CXX11_ABI\", \"FALSE\") == \"TRUE\"\nROCM_BACKEND: Optional[Literal[\"triton\", \"ck\"]] = None\nif IS_ROCM:\n    ROCM_BACKEND = \"triton\" if 
os.getenv(\"FLASH_ATTENTION_TRITON_AMD_ENABLE\", \"FALSE\") == \"TRUE\" else \"ck\"\nNVCC_THREADS = os.getenv(\"NVCC_THREADS\") or \"4\"\n\n@functools.lru_cache(maxsize=None)\ndef cuda_archs() -> list:\n    return os.getenv(\"FLASH_ATTN_CUDA_ARCHS\", \"80;90;100;110;120\").split(\";\")\n\n\ndef get_platform():\n    \"\"\"\n    Returns the platform name as used in wheel filenames.\n    \"\"\"\n    if sys.platform.startswith(\"linux\"):\n        return f'linux_{platform.uname().machine}'\n    elif sys.platform == \"darwin\":\n        mac_version = \".\".join(platform.mac_ver()[0].split(\".\")[:2])\n        return f\"macosx_{mac_version}_x86_64\"\n    elif sys.platform == \"win32\":\n        return \"win_amd64\"\n    else:\n        raise ValueError(\"Unsupported platform: {}\".format(sys.platform))\n\n\ndef get_cuda_bare_metal_version(cuda_dir):\n    raw_output = subprocess.check_output([cuda_dir + \"/bin/nvcc\", \"-V\"], universal_newlines=True)\n    output = raw_output.split()\n    release_idx = output.index(\"release\") + 1\n    bare_metal_version = parse(output[release_idx].split(\",\")[0])\n\n    return raw_output, bare_metal_version\n\n\ndef add_cuda_gencodes(cc_flag, archs, bare_metal_version):\n    \"\"\"\n    Adds -gencode flags based on nvcc capabilities:\n      - sm_80/90 (regular)\n      - sm_100/120 on CUDA >= 12.8\n      - Use 100f on CUDA >= 12.9 (Blackwell family-specific)\n      - Map requested 110 -> 101 if CUDA < 13.0 (Thor rename)\n      - Embed PTX for newest arch for forward compatibility\n    \"\"\"\n    # Always-regular 80\n    if \"80\" in archs:\n        cc_flag += [\"-gencode\", \"arch=compute_80,code=sm_80\"]\n\n    # Hopper 9.0 needs >= 11.8\n    if bare_metal_version >= Version(\"11.8\") and \"90\" in archs:\n        cc_flag += [\"-gencode\", \"arch=compute_90,code=sm_90\"]\n\n    # Blackwell 10.x requires >= 12.8\n    if bare_metal_version >= Version(\"12.8\"):\n        if \"100\" in archs:\n            # CUDA 12.9 introduced 
\"family-specific\" for Blackwell (100f)\n            if bare_metal_version >= Version(\"12.9\"):\n                cc_flag += [\"-gencode\", \"arch=compute_100f,code=sm_100\"]\n            else:\n                cc_flag += [\"-gencode\", \"arch=compute_100,code=sm_100\"]\n\n        if \"120\" in archs:\n            # sm_120 is supported in CUDA 12.8/12.9+ toolkits\n            if bare_metal_version >= Version(\"12.9\"):\n                cc_flag += [\"-gencode\", \"arch=compute_120f,code=sm_120\"]\n            else:\n                cc_flag += [\"-gencode\", \"arch=compute_120,code=sm_120\"]\n\n\n        # Thor rename: 12.9 uses sm_101; 13.0+ uses sm_110\n        if \"110\" in archs:\n            if bare_metal_version >= Version(\"13.0\"):\n                cc_flag += [\"-gencode\", \"arch=compute_110f,code=sm_110\"]\n            else:\n                # Provide Thor support for CUDA 12.9 via sm_101\n                if bare_metal_version >= Version(\"12.8\"):\n                    cc_flag += [\"-gencode\", \"arch=compute_101,code=sm_101\"]\n                # else: no Thor support in older toolkits\n\n    # PTX for newest requested arch (forward-compat)\n    numeric = [a for a in archs if a.isdigit()]\n    if numeric:\n        newest = max(numeric, key=int)\n        cc_flag += [\"-gencode\", f\"arch=compute_{newest},code=compute_{newest}\"]\n\n    return cc_flag\n\n\ndef get_hip_version():\n    return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))\n\n\ndef check_if_cuda_home_none(global_option: str) -> None:\n    if CUDA_HOME is not None:\n        return\n    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary\n    # in that case.\n    warnings.warn(\n        f\"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  
\"\n        \"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, \"\n        \"only images whose names contain 'devel' will provide nvcc.\"\n    )\n\n\ndef check_if_rocm_home_none(global_option: str) -> None:\n    if ROCM_HOME is not None:\n        return\n    # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary\n    # in that case.\n    warnings.warn(\n        f\"{global_option} was requested, but hipcc was not found.\"\n    )\n\n\ndef detect_hipify_v2():\n    try:\n        from torch.utils.hipify import __version__\n        from packaging.version import Version\n        if Version(__version__) >= Version(\"2.0.0\"):\n            return True\n    except Exception as e:\n        print(\"failed to detect pytorch hipify version, defaulting to version 1.0.0 behavior\")\n        print(e)\n    return False\n\n\ndef append_nvcc_threads(nvcc_extra_args):\n    return nvcc_extra_args + [\"--threads\", NVCC_THREADS]\n\n\ndef rename_cpp_to_cu(cpp_files):\n    for entry in cpp_files:\n        shutil.copy(entry, os.path.splitext(entry)[0] + \".cu\")\n\n\ndef validate_and_update_archs(archs):\n    # List of allowed architectures\n    allowed_archs = [\"native\", \"gfx90a\", \"gfx950\", \"gfx942\"]\n\n    # Validate if each element in archs is in allowed_archs\n    assert all(\n        arch in allowed_archs for arch in archs\n    ), f\"One of GPU archs of {archs} is invalid or not supported by Flash-Attention\"\n\n\ncmdclass = {}\next_modules = []\n\n# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp\n# files included in the source distribution, in case the user compiles from source.\nif IS_ROCM:\n    if ROCM_BACKEND == \"triton\":\n        if os.path.isdir(\".git\"):\n            subprocess.run([\"git\", \"submodule\", \"update\", \"--init\", \"third_party/aiter\"], check=True)\n        else:\n            assert 
os.path.isdir(\"third_party/aiter\"), (\n                \"third_party/aiter is missing, please use source distribution or git clone\"\n            )\n        subprocess.run(\n            [sys.executable, \"-m\", \"pip\", \"install\", \"--no-build-isolation\", \"third_party/aiter\"],\n            check=True,\n        )\n    elif ROCM_BACKEND == \"ck\":\n        if os.path.isdir(\".git\"):\n            subprocess.run([\"git\", \"submodule\", \"update\", \"--init\", \"csrc/composable_kernel\"], check=True)\n        else:\n            assert os.path.exists(\"csrc/composable_kernel/example/ck_tile/01_fmha/generate.py\"), (\n                \"csrc/composable_kernel is missing, please use source distribution or git clone\"\n            )\nelse:\n    # CUDA: cutlass submodule\n    if os.path.isdir(\".git\"):\n        subprocess.run([\"git\", \"submodule\", \"update\", \"--init\", \"csrc/cutlass\"], check=True)\n    else:\n        assert os.path.exists(\"csrc/cutlass/include/cutlass/cutlass.h\"), (\n            \"csrc/cutlass is missing, please use source distribution or git clone\"\n        )\n\nif not SKIP_CUDA_BUILD and not IS_ROCM:\n    print(\"\\n\\ntorch.__version__  = {}\\n\\n\".format(torch.__version__))\n    TORCH_MAJOR = int(torch.__version__.split(\".\")[0])\n    TORCH_MINOR = int(torch.__version__.split(\".\")[1])\n\n    check_if_cuda_home_none(\"flash_attn\")\n    # Check, if CUDA11 is installed for compute capability 8.0\n    cc_flag = []\n    if CUDA_HOME is not None:\n        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)\n        if bare_metal_version < Version(\"11.7\"):\n            raise RuntimeError(\n                \"FlashAttention is only supported on CUDA 11.7 and above.  
\"\n                \"Note: make sure nvcc has a supported version by running nvcc -V.\"\n            )\n        # Build -gencode (regular + PTX + family-specific 'f' when available)\n        add_cuda_gencodes(cc_flag, set(cuda_archs()), bare_metal_version)\n    else:\n        # No nvcc present; warnings already emitted above\n        pass\n\n    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as\n    # torch._C._GLIBCXX_USE_CXX11_ABI\n    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920\n    if FORCE_CXX11_ABI:\n        torch._C._GLIBCXX_USE_CXX11_ABI = True\n\n    nvcc_flags = [\n    \"-O3\",\n    \"-std=c++17\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n    \"-U__CUDA_NO_BFLOAT16_CONVERSIONS__\",\n    \"--expt-relaxed-constexpr\",\n    \"--expt-extended-lambda\",\n    \"--use_fast_math\",\n    # \"--ptxas-options=-v\",\n    # \"--ptxas-options=-O2\",\n    # \"-lineinfo\",\n    # \"-DFLASHATTENTION_DISABLE_BACKWARD\",\n    # \"-DFLASHATTENTION_DISABLE_DROPOUT\",\n    # \"-DFLASHATTENTION_DISABLE_ALIBI\",\n    # \"-DFLASHATTENTION_DISABLE_SOFTCAP\",\n    # \"-DFLASHATTENTION_DISABLE_UNEVEN_K\",\n    # \"-DFLASHATTENTION_DISABLE_LOCAL\",\n    ]\n\n    compiler_c17_flag=[\"-O3\", \"-std=c++17\"]\n    # Add Windows-specific flags\n    if sys.platform == \"win32\" and os.getenv('DISTUTILS_USE_SDK') == '1':\n        nvcc_flags.extend([\"-Xcompiler\", \"/Zc:__cplusplus\"])\n        compiler_c17_flag=[\"-O2\", \"/std:c++17\", \"/Zc:__cplusplus\"]\n\n    ext_modules.append(\n        CUDAExtension(\n            name=\"flash_attn_2_cuda\",\n            sources=[\n                \"csrc/flash_attn/flash_api.cpp\",\n                \"csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu\",\n                
\"csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu\",\n                
\"csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu\",\n            
    \"csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu\",\n                \"csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu\",\n            ],\n            extra_compile_args={\n                \"cxx\": compiler_c17_flag,\n                \"nvcc\": append_nvcc_threads(nvcc_flags + cc_flag),\n            },\n            include_dirs=[\n                Path(this_dir) / \"csrc\" / \"flash_attn\",\n                Path(this_dir) / \"csrc\" / \"flash_attn\" / \"src\",\n                Path(this_dir) / \"csrc\" / \"cutlass\" / \"include\",\n            ],\n        )\n    )\nelif not SKIP_CUDA_BUILD and IS_ROCM:\n    
print(\"\\n\\ntorch.__version__  = {}\\n\\n\".format(torch.__version__))\n    TORCH_MAJOR = int(torch.__version__.split(\".\")[0])\n    TORCH_MINOR = int(torch.__version__.split(\".\")[1])\n\n    # Skips CK C++ extension compilation if using Triton Backend\n    if ROCM_BACKEND == \"ck\":\n        ck_dir = \"csrc/composable_kernel\"\n\n        #use codegen get code dispatch\n        if not os.path.exists(\"./build\"):\n            os.makedirs(\"build\")\n\n        optdim = os.getenv(\"OPT_DIM\", \"32,64,128,256\")\n        subprocess.run([sys.executable, f\"{ck_dir}/example/ck_tile/01_fmha/generate.py\", \"-d\", \"fwd\", \"--output_dir\", \"build\", \"--receipt\", \"2\", \"--optdim\", optdim], check=True)\n        subprocess.run([sys.executable, f\"{ck_dir}/example/ck_tile/01_fmha/generate.py\", \"-d\", \"fwd_appendkv\", \"--output_dir\", \"build\", \"--receipt\", \"2\", \"--optdim\", optdim], check=True)\n        subprocess.run([sys.executable, f\"{ck_dir}/example/ck_tile/01_fmha/generate.py\", \"-d\", \"fwd_splitkv\", \"--output_dir\", \"build\", \"--receipt\", \"2\", \"--optdim\", optdim], check=True)\n        subprocess.run([sys.executable, f\"{ck_dir}/example/ck_tile/01_fmha/generate.py\", \"-d\", \"bwd\", \"--output_dir\", \"build\", \"--receipt\", \"2\", \"--optdim\", optdim], check=True)\n\n        # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h\n        # See https://github.com/pytorch/pytorch/pull/70650\n        generator_flag = []\n        torch_dir = torch.__path__[0]\n        if os.path.exists(os.path.join(torch_dir, \"include\", \"ATen\", \"CUDAGeneratorImpl.h\")):\n            generator_flag = [\"-DOLD_GENERATOR_PATH\"]\n\n        check_if_rocm_home_none(\"flash_attn\")\n        archs = os.getenv(\"GPU_ARCHS\", \"native\").split(\";\")\n        validate_and_update_archs(archs)\n\n        if archs != ['native']:\n            cc_flag = [f\"--offload-arch={arch}\" for arch in archs]\n        else:\n            
arch = torch.cuda.get_device_properties(\"cuda\").gcnArchName.split(\":\")[0]\n            cc_flag = [f\"--offload-arch={arch}\"]\n\n        # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as\n        # torch._C._GLIBCXX_USE_CXX11_ABI\n        # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920\n        if FORCE_CXX11_ABI:\n            torch._C._GLIBCXX_USE_CXX11_ABI = True\n\n        sources = [\"csrc/flash_attn_ck/flash_api.cpp\",\n                \"csrc/flash_attn_ck/flash_common.cpp\",\n                \"csrc/flash_attn_ck/mha_bwd.cpp\",\n                \"csrc/flash_attn_ck/mha_fwd_kvcache.cpp\",\n                \"csrc/flash_attn_ck/mha_fwd.cpp\",\n                \"csrc/flash_attn_ck/mha_varlen_bwd.cpp\",\n                \"csrc/flash_attn_ck/mha_varlen_fwd.cpp\"] + glob.glob(\n            f\"build/fmha_*wd*.cpp\"\n        )\n\n        # Check if torch is using hipify v2. Until CK is updated with HIPIFY_V2 macro,\n        # we must replace the incorrect APIs.\n        maybe_hipify_v2_flag = []\n        if detect_hipify_v2():\n            maybe_hipify_v2_flag = [\"-DHIPIFY_V2\"]\n\n        rename_cpp_to_cu(sources)\n\n        renamed_sources = [\"csrc/flash_attn_ck/flash_api.cu\",\n                        \"csrc/flash_attn_ck/flash_common.cu\",\n                        \"csrc/flash_attn_ck/mha_bwd.cu\",\n                        \"csrc/flash_attn_ck/mha_fwd_kvcache.cu\",\n                        \"csrc/flash_attn_ck/mha_fwd.cu\",\n                        \"csrc/flash_attn_ck/mha_varlen_bwd.cu\",\n                        \"csrc/flash_attn_ck/mha_varlen_fwd.cu\"] + glob.glob(f\"build/fmha_*wd*.cu\")\n\n        cc_flag += [\"-O3\",\"-std=c++20\",\n                    \"-Wno-unknown-warning-option\",\n                    \"-fbracket-depth=1024\",\n                    \"-DCK_TILE_FMHA_FWD_FAST_EXP2=1\",\n                    \"-fgpu-flush-denormals-to-zero\",\n                  
  \"-DCK_ENABLE_BF16\",\n                    \"-DCK_ENABLE_BF8\",\n                    \"-DCK_ENABLE_FP16\",\n                    \"-DCK_ENABLE_FP32\",\n                    \"-DCK_ENABLE_FP64\",\n                    \"-DCK_ENABLE_FP8\",\n                    \"-DCK_ENABLE_INT8\",\n                    \"-DCK_USE_XDL\",\n                    \"-DUSE_PROF_API=1\",\n                    # \"-DFLASHATTENTION_DISABLE_BACKWARD\",\n                    \"-D__HIP_PLATFORM_HCC__=1\"]\n\n        cc_flag += [f\"-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get('CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT', 3)}\"]\n\n        # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214\n        hip_version = get_hip_version()\n        if hip_version > Version('5.5.00000'):\n            cc_flag += [\"-mllvm\", \"--lsr-drop-solution=1\"]\n        if hip_version > Version('5.7.23302'):\n            cc_flag += [\"-fno-offload-uniform-block\"]\n        if hip_version > Version('6.1.40090'):\n            cc_flag += [\"-mllvm\", \"-enable-post-misched=0\"]\n        if hip_version > Version('6.2.41132'):\n            cc_flag += [\"-mllvm\", \"-amdgpu-early-inline-all=true\",\n                        \"-mllvm\", \"-amdgpu-function-calls=false\"]\n        if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'):\n            cc_flag += [\"-mllvm\", \"-amdgpu-coerce-illegal-types=1\"]\n\n        extra_compile_args = {\n            \"cxx\": [\"-O3\", \"-std=c++20\"] + generator_flag + maybe_hipify_v2_flag,\n            \"nvcc\": cc_flag + generator_flag + maybe_hipify_v2_flag,\n        }\n\n        include_dirs = [\n            Path(this_dir) / \"csrc\" / \"composable_kernel\" / \"include\",\n            Path(this_dir) / \"csrc\" / \"composable_kernel\" / \"library\" / \"include\",\n            Path(this_dir) / \"csrc\" / \"composable_kernel\" / \"example\" / \"ck_tile\" / \"01_fmha\",\n        ]\n\n        
ext_modules.append(\n            CUDAExtension(\n                name=\"flash_attn_2_cuda\",\n                sources=renamed_sources,\n                extra_compile_args=extra_compile_args,\n                include_dirs=include_dirs,\n            )\n        )\n\n\ndef get_package_version():\n    with open(Path(this_dir) / \"flash_attn\" / \"__init__.py\", \"r\") as f:\n        version_match = re.search(r\"^__version__\\s*=\\s*(.*)$\", f.read(), re.MULTILINE)\n    public_version = ast.literal_eval(version_match.group(1))\n    local_version = os.environ.get(\"FLASH_ATTN_LOCAL_VERSION\")\n    if local_version:\n        return f\"{public_version}+{local_version}\"\n    else:\n        return str(public_version)\n\n\ndef get_wheel_url():\n    torch_version_raw = parse(torch.__version__)\n    python_version = f\"cp{sys.version_info.major}{sys.version_info.minor}\"\n    platform_name = get_platform()\n    flash_version = get_package_version()\n    torch_version = f\"{torch_version_raw.major}.{torch_version_raw.minor}\"\n    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()\n\n    if IS_ROCM:\n        torch_hip_version = get_hip_version()\n        hip_version = f\"{torch_hip_version.major}{torch_hip_version.minor}\"\n        wheel_filename = f\"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl\"\n    else:\n        # Determine the version numbers that will be used to determine the correct wheel\n        # We're using the CUDA version used to build torch, not the one currently installed\n        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)\n        torch_cuda_version = parse(torch.version.cuda)\n        # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3\n        # to save CI time. 
Minor versions should be compatible.\n        torch_cuda_version = parse(\"11.8\") if torch_cuda_version.major == 11 else parse(\"12.3\")\n        # cuda_version = f\"{cuda_version_raw.major}{cuda_version_raw.minor}\"\n        cuda_version = f\"{torch_cuda_version.major}\"\n\n        # Determine wheel URL based on CUDA version, torch version, python version and OS\n        wheel_filename = f\"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl\"\n\n    wheel_url = BASE_WHEEL_URL.format(tag_name=f\"v{flash_version}\", wheel_name=wheel_filename)\n\n    return wheel_url, wheel_filename\n\n\nclass CachedWheelsCommand(_bdist_wheel):\n    \"\"\"\n    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot\n    find an existing wheel (which is currently the case for all flash attention installs). We use\n    the environment parameters to detect whether a pre-built version of a compatible wheel is\n    already available and, if so, short-circuit the standard full build pipeline.\n    \"\"\"\n\n    def run(self):\n        if FORCE_BUILD:\n            return super().run()\n\n        wheel_url, wheel_filename = get_wheel_url()\n        print(\"Guessing wheel URL: \", wheel_url)\n        try:\n            urllib.request.urlretrieve(wheel_url, wheel_filename)\n\n            # Make the archive\n            # Lifted from the root wheel processing command\n            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85\n            if not os.path.exists(self.dist_dir):\n                os.makedirs(self.dist_dir)\n\n            impl_tag, abi_tag, plat_tag = self.get_tag()\n            archive_basename = f\"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}\"\n\n            wheel_path = os.path.join(self.dist_dir, archive_basename + \".whl\")\n            print(\"Raw wheel path\", 
wheel_path)\n            os.rename(wheel_filename, wheel_path)\n        except (urllib.error.HTTPError, urllib.error.URLError):\n            print(\"Precompiled wheel not found. Building from source...\")\n            # If the wheel could not be downloaded, build from source\n            super().run()\n\n\nclass NinjaBuildExtension(BuildExtension):\n    def __init__(self, *args, **kwargs) -> None:\n        # do not override env MAX_JOBS if already exists\n        if not os.environ.get(\"MAX_JOBS\"):\n            import psutil\n\n            nvcc_threads = max(1, int(NVCC_THREADS))\n\n            # calculate the maximum allowed NUM_JOBS based on cores\n            max_num_jobs_cores = max(1, os.cpu_count() // 2)\n\n            # calculate the maximum allowed NUM_JOBS based on free memory\n            free_memory_gb = psutil.virtual_memory().available / (1024 ** 3)  # free memory in GB\n            # Assume worst-case peak observed memory usage of ~5GB per NVCC thread.\n            # Limit: peak_threads = max_jobs * nvcc_threads and peak_threads * 5GB <= free_memory.\n            max_num_jobs_memory = max(1, int(free_memory_gb / (5 * nvcc_threads)))\n\n            # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation\n            max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))\n            print(\n                f\"Auto set MAX_JOBS to `{max_jobs}`, NVCC_THREADS to `{nvcc_threads}`. \"\n                \"If you see memory pressure, please use a lower `MAX_JOBS=N` or `NVCC_THREADS=N` value.\"\n            )\n            os.environ[\"MAX_JOBS\"] = str(max_jobs)\n\n        super().__init__(*args, **kwargs)\n\n\n# Build install_requires based on platform\nif ROCM_BACKEND == \"triton\":\n    # Note: torch is excluded because pip resolves it to CUDA PyTorch from PyPI, overwriting any pre-installed ROCm PyTorch. 
Users must have torch installed.\n    install_requires = [\n        \"einops\",\n        \"triton==3.5.1\",\n    ]\nelse:\n    install_requires = [\n        \"torch\",\n        \"einops\",\n    ]\n\nsetup(\n    name=PACKAGE_NAME,\n    version=get_package_version(),\n    packages=find_packages(\n        exclude=(\n            \"build\",\n            \"csrc\",\n            \"include\",\n            \"tests\",\n            \"dist\",\n            \"docs\",\n            \"benchmarks\",\n            \"flash_attn.egg-info\",\n            \"flash_attn.cute\",\n            \"flash_attn.cute.*\",\n        )\n    ),\n    author=\"Tri Dao\",\n    author_email=\"tri@tridao.me\",\n    description=\"Flash Attention: Fast and Memory-Efficient Exact Attention\",\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n    url=\"https://github.com/Dao-AILab/flash-attention\",\n    classifiers=[\n        \"Programming Language :: Python :: 3\",\n        \"License :: OSI Approved :: BSD License\",\n        \"Operating System :: Unix\",\n    ],\n    ext_modules=ext_modules,\n    cmdclass={\"bdist_wheel\": CachedWheelsCommand, \"build_ext\": NinjaBuildExtension}\n    if ext_modules\n    else {\n        \"bdist_wheel\": CachedWheelsCommand,\n    },\n    python_requires=\">=3.9\",\n    install_requires=install_requires,\n    setup_requires=[\n        \"packaging\",\n        \"psutil\",\n        \"ninja\",\n    ],\n)\n"
  },
  {
    "path": "tests/cute/benchmark_block_sparsity.py",
    "content": "\"\"\"\nComparative benchmark: CuTe DSL vs Native PyTorch block sparsity computation.\n\"\"\"\n\nimport torch\nfrom dataclasses import dataclass\nfrom typing import Callable, Optional, List\nfrom tabulate import tabulate\nfrom tqdm import tqdm\nimport itertools\n\nfrom cutlass.cute.runtime import from_dlpack\nfrom cutlass.cute.testing import benchmark as cute_benchmark\nimport cutlass.cute as cute\nfrom flash_attn.cute.compute_block_sparsity import BlockSparsityKernel\nfrom flash_attn.cute.block_sparsity import BlockSparseTensors\nfrom mask_mod_definitions import (\n    get_mask_pair,\n    random_doc_id_tensor,\n    flex_document_mask,\n    cute_document_mask,\n)\n\nfrom torch.nn.attention.flex_attention import create_block_mask\nfrom triton.testing import do_bench\n\n# Configure torch.compile cache to prevent memory buildup\ntorch._dynamo.config.cache_size_limit = 1000\n\n\n@dataclass(frozen=True)\nclass BenchmarkConfig:\n    \"\"\"Configuration for a benchmark run.\"\"\"\n\n    batch_size: int\n    num_heads: int\n    seqlen_q: int\n    seqlen_k: int\n    mask_name: str\n    tile_m: int = 128\n    tile_n: int = 128\n    use_fast_sampling: bool = False\n    aux_tensors_cute: Optional[list] = None\n\n\n@dataclass(frozen=True)\nclass BenchmarkResult:\n    \"\"\"Result of a single benchmark run.\"\"\"\n\n    config: BenchmarkConfig\n    cute_time_ms: Optional[float]\n    pytorch_time_ms: Optional[float]\n    error_message: Optional[str] = None\n\n\ndef benchmark_pytorch_block_sparsity(\n    config: BenchmarkConfig,\n    mask_fn: Callable,\n) -> Optional[float]:\n    \"\"\"\n    Benchmark PyTorch block mask creation (compiled).\n    Returns: creation_time_ms\n    \"\"\"\n    device = \"cuda\"\n\n    try:\n        cbm = torch.compile(create_block_mask)\n\n        def run_benchmark():\n            return cbm(\n                mask_fn,\n                config.batch_size,\n                config.num_heads,\n                config.seqlen_q,\n                
config.seqlen_k,\n                device=device,\n            )\n\n        creation_time_ms = do_bench(run_benchmark, warmup=10, rep=100)\n\n        return creation_time_ms\n\n    except Exception as e:\n        print(f\"PyTorch benchmark failed ({config.mask_name}): {e}\")\n        import traceback\n\n        traceback.print_exc()\n        return None\n\n\ndef benchmark_cute_block_sparsity(\n    config: BenchmarkConfig,\n    mask_fn: Callable,\n) -> Optional[float]:\n    \"\"\"\n    Benchmark CuTe block sparsity kernel.\n    Returns: creation_time_ms\n    \"\"\"\n    device = \"cuda\"\n\n    try:\n        num_m_blocks = (config.seqlen_q + config.tile_m - 1) // config.tile_m\n        num_n_blocks = (config.seqlen_k + config.tile_n - 1) // config.tile_n\n\n        mask_block_cnt = torch.zeros(\n            (config.batch_size, config.num_heads, num_m_blocks),\n            device=device,\n            dtype=torch.int32,\n        )\n        mask_block_idx = torch.zeros(\n            (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks),\n            device=device,\n            dtype=torch.int32,\n        )\n        full_block_cnt = torch.zeros(\n            (config.batch_size, config.num_heads, num_m_blocks),\n            device=device,\n            dtype=torch.int32,\n        )\n        full_block_idx = torch.zeros(\n            (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks),\n            device=device,\n            dtype=torch.int32,\n        )\n\n        # Convert to CuTe tensors\n        mask_cnt_cute = from_dlpack(\n            mask_block_cnt.detach(), assumed_align=4\n        ).mark_layout_dynamic(leading_dim=2)\n        mask_idx_cute = from_dlpack(\n            mask_block_idx.detach(), assumed_align=4\n        ).mark_layout_dynamic(leading_dim=3)\n        full_cnt_cute = from_dlpack(\n            full_block_cnt.detach(), assumed_align=4\n        ).mark_layout_dynamic(leading_dim=2)\n        full_idx_cute = from_dlpack(\n            
full_block_idx.detach(), assumed_align=4\n        ).mark_layout_dynamic(leading_dim=3)\n\n        blocksparse_tensors = BlockSparseTensors(\n            mask_block_cnt=mask_cnt_cute,\n            mask_block_idx=mask_idx_cute,\n            full_block_cnt=full_cnt_cute,\n            full_block_idx=full_idx_cute,\n        )\n\n        # Create kernel\n        use_aux = (\n            config.aux_tensors_cute is not None and len(config.aux_tensors_cute) > 0\n        )\n        kernel = BlockSparsityKernel(\n            mask_mod=mask_fn,\n            tile_mn=(config.tile_m, config.tile_n),\n            compute_full_blocks=True,\n            use_aux_tensors=use_aux,\n            use_fast_sampling=config.use_fast_sampling,\n        )\n\n        # Compile kernel\n        compiled_kernel = cute.compile(\n            kernel,\n            blocksparse_tensors,\n            config.seqlen_q,\n            config.seqlen_k,\n            config.aux_tensors_cute,\n        )\n\n        def generate_tensors():\n            from cutlass.cute.testing import JitArguments\n\n            return JitArguments(\n                blocksparse_tensors,\n                config.seqlen_q,\n                config.seqlen_k,\n                config.aux_tensors_cute,\n            )\n\n        creation_time_us = cute_benchmark(\n            compiled_kernel,\n            workspace_generator=generate_tensors,\n            warmup_iterations=10,\n            iterations=100,\n        )\n\n        torch.cuda.synchronize(device)\n        creation_time_ms = creation_time_us / 1000.0\n\n        return creation_time_ms\n\n    except Exception as e:\n        print(f\"CuTe benchmark failed: {e}\")\n        return None\n\n\ndef run_benchmark(\n    config: BenchmarkConfig,\n    pytorch_mask_fn: Callable,\n    cute_mask_fn: Callable,\n) -> BenchmarkResult:\n    \"\"\"Run benchmarks for both implementations.\"\"\"\n\n    print(\n        f\"Benchmarking {config.mask_name} - B={config.batch_size}, H={config.num_heads}, \"\n 
       f\"M={config.seqlen_q}, N={config.seqlen_k}\"\n    )\n\n    # Benchmark PyTorch\n    pytorch_time = benchmark_pytorch_block_sparsity(config, pytorch_mask_fn)\n\n    # Benchmark CuTe\n    cute_time = benchmark_cute_block_sparsity(config, cute_mask_fn)\n\n    return BenchmarkResult(\n        config=config,\n        cute_time_ms=cute_time,\n        pytorch_time_ms=pytorch_time,\n    )\n\n\ndef generate_configs(\n    batch_sizes: List[int],\n    num_heads: List[int],\n    seqlens: List[int],\n    mask_names: List[str],\n) -> List[BenchmarkConfig]:\n    \"\"\"Generate all benchmark configurations.\"\"\"\n    configs = []\n    for B, H, S, mask_name in itertools.product(\n        batch_sizes, num_heads, seqlens, mask_names\n    ):\n        configs.append(\n            BenchmarkConfig(\n                batch_size=B,\n                num_heads=H,\n                seqlen_q=S,\n                seqlen_k=S,\n                mask_name=mask_name,\n            )\n        )\n    return configs\n\n\ndef print_results(results: List[BenchmarkResult]):\n    successful_results = [\n        r\n        for r in results\n        if r.cute_time_ms is not None and r.pytorch_time_ms is not None\n    ]\n\n    if not successful_results:\n        print(\"No successful benchmark results to display\")\n        return\n\n    headers = [\n        \"B\",\n        \"H\",\n        \"M\",\n        \"N\",\n        \"Mask Type\",\n        \"CuTe Time (ms)\",\n        \"PyTorch Time (ms)\",\n        \"Speedup\",\n    ]\n\n    rows = []\n    for result in successful_results:\n        speedup = (\n            result.pytorch_time_ms / result.cute_time_ms\n            if result.cute_time_ms > 0\n            else 0\n        )\n\n        rows.append(\n            [\n                result.config.batch_size,\n                result.config.num_heads,\n                result.config.seqlen_q,\n                result.config.seqlen_k,\n                result.config.mask_name,\n                
f\"{result.cute_time_ms:.4f}\",\n                f\"{result.pytorch_time_ms:.4f}\",\n                f\"{speedup:.2f}x\",\n            ]\n        )\n\n    # Sort by batch, head, seqlen, then mask type\n    rows.sort(key=lambda x: (x[0], x[1], x[2], x[4]))\n\n    print(\"\\n\" + \"=\" * 100)\n    print(\"CuTe DSL vs PyTorch Block Sparsity Benchmark Results\")\n    print(\"=\" * 100)\n    print(tabulate(rows, headers=headers, tablefmt=\"github\"))\n    print(\"=\" * 100)\n\n\ndef main():\n    \"\"\"Run the comparative benchmark.\"\"\"\n\n    # Configuration\n    batch_sizes = [1, 4, 8]\n    num_heads = [8, 16]\n    seqlens = [1024, 2048, 4096, 8192]\n    mask_names = [\n        \"causal\",\n        \"sliding_window\",\n        \"prefix_lm\",\n        \"dilated_sliding_window\",\n        \"document\",\n    ]\n\n    device = \"cuda\"\n    max_seqlen = max(seqlens)\n    max_batch = max(batch_sizes)\n    max_heads = max(num_heads)\n\n    # Create document IDs using the helper from mask_definitions\n    doc_ids = random_doc_id_tensor(max_heads, max_batch, max_seqlen, device=device)\n    doc_ids_cute = from_dlpack(doc_ids.detach(), assumed_align=4).mark_layout_dynamic(\n        leading_dim=2\n    )\n\n    # Generate base configurations\n    base_configs = generate_configs(batch_sizes, num_heads, seqlens, mask_names)\n\n    # Update configs with aux tensors for document masking\n    configs = []\n    for config in base_configs:\n        if config.mask_name == \"document\":\n            # Add aux tensors for document masking\n            configs.append(\n                BenchmarkConfig(\n                    batch_size=config.batch_size,\n                    num_heads=config.num_heads,\n                    seqlen_q=config.seqlen_q,\n                    seqlen_k=config.seqlen_k,\n                    mask_name=config.mask_name,\n                    tile_m=config.tile_m,\n                    tile_n=config.tile_n,\n                    use_fast_sampling=False,\n                    
aux_tensors_cute=[doc_ids_cute],\n                )\n            )\n        else:\n            configs.append(config)\n\n    # Run benchmarks\n    results = []\n    print(f\"Running {len(configs)} benchmark configurations...\")\n    for config in tqdm(configs, desc=\"Benchmarking\"):\n        try:\n            # Get mask pair from mask_definitions\n            mask_kwargs = {}\n            if config.mask_name == \"sliding_window\":\n                mask_kwargs[\"window_size\"] = 128  # Default window size\n\n            cute_mask_fn, pytorch_mask_fn = get_mask_pair(\n                config.mask_name,\n                seqlen_q=config.seqlen_q,\n                seqlen_k=config.seqlen_k,\n                **mask_kwargs,\n            )\n\n            # For document masking, create wrapper that captures doc_ids\n            if config.mask_name == \"document\":\n                # PyTorch wrapper\n                def pytorch_mask_fn(b, h, q, kv):\n                    return flex_document_mask(b, h, q, kv, doc_ids)\n\n                # CuTe wrapper - reuse cute_document_mask with aux_tensors\n                cute_mask_fn = cute_document_mask\n\n            result = run_benchmark(config, pytorch_mask_fn, cute_mask_fn)\n            results.append(result)\n\n        except Exception as e:\n            print(f\"Failed to run config {config}: {e}\")\n            results.append(\n                BenchmarkResult(\n                    config=config,\n                    cute_time_ms=None,\n                    pytorch_time_ms=None,\n                    error_message=str(e),\n                )\n            )\n        finally:\n            torch.cuda.empty_cache()\n            torch._dynamo.reset()\n\n    print_results(results)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tests/cute/benchmark_mask_mod.py",
    "content": "\"\"\"\nFlashAttention benchmarking script with Flex Attention-style\nmask mod support and varlen sequences.\n\"\"\"\n\nfrom dataclasses import dataclass\nimport math\nfrom typing import Any, Dict, Optional, Tuple\n\nimport cuda.bindings.driver as cuda\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass.cute.runtime import from_dlpack\nimport numpy as np\nimport torch\n\nfrom flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90\nfrom mask_mod_definitions import (\n    get_mask_pair,\n    random_doc_id_tensor,\n)\nfrom flash_attn.cute.block_sparsity import (\n    BlockSparseTensorsTorch,\n    to_cute_block_sparse_tensors,\n)\nfrom flash_attn.cute.compute_block_sparsity import compute_block_sparsity\n\n\n@dataclass\nclass BenchmarkConfig:\n    \"\"\"Benchmark configuration\"\"\"\n\n    # Model parameters\n    headdim: int\n    headdim_v: int\n    nheads: int\n    nheads_kv: int\n    dtype: torch.dtype\n\n    # Sequence parameters\n    batch_size: int = 2\n    seqlen_q: int = 8192\n    seqlen_k: int = 8192\n\n    # Varlen parameters\n    use_varlen: bool = False\n    min_seqlen_q: Optional[int] = None  # If None, use seqlen_q // 2\n    max_seqlen_q: Optional[int] = None  # If None, use seqlen_q\n    min_seqlen_k: Optional[int] = None  # If None, use seqlen_k // 2\n    max_seqlen_k: Optional[int] = None  # If None, use seqlen_k\n\n    # Mask parameters\n    use_mask_mod: bool = True\n    mask_mod_name: str = \"causal\"\n    has_aux_tensors: bool = mask_mod_name == \"document\"\n\n    # Sliding window parameter (used when mask_mod_name == \"sliding_window\")\n    window_size: int = 128\n\n    # Attention parameters\n    causal: bool = False\n    is_local: bool = False\n    window_left: Optional[int] = 128  # For base Flash Attention local\n    window_right: Optional[int] = 0  # For base Flash Attention local\n    softcap: Optional[float] = None\n    use_learnable_sink: bool = False\n\n    # Kernel configuration\n    tile_m: int = 128\n  
  tile_n: int = 128\n    num_stages: int = 2\n    num_threads: int = 384\n    intra_wg_overlap: bool = True\n    mma_pv_is_rs: bool = True\n\n    # Benchmark parameters\n    warmup_iters: int = 10\n    benchmark_iters: int = 25\n    verbose: bool = False\n    seed: int = 42\n\n\nclass FlashAttentionBenchmark:\n    def __init__(self, config: BenchmarkConfig):\n        self.config = config\n\n        torch.manual_seed(config.seed)\n        np.random.seed(config.seed)\n\n        # Verify SM90 compute capability\n        compute_capability = torch.cuda.get_device_capability()\n        assert compute_capability >= (9, 0), (\n            f\"Requires SM90+, got SM{compute_capability[0]}{compute_capability[1]}\"\n        )\n        # causal overrides use_mask_mod\n        if config.causal:\n            config.use_mask_mod = False\n\n        if config.use_mask_mod:\n            self.mask_mod_cute, self.mask_mod_flex = get_mask_pair(\n                config.mask_mod_name,\n                seqlen_q=config.seqlen_q,\n                seqlen_k=config.seqlen_k,\n                window_size=config.window_size,\n            )\n        else:\n            self.mask_mod_cute = None\n            self.mask_mod_flex = None\n\n        self._validate_config()\n\n    def _validate_config(self):\n        config = self.config\n\n        assert config.headdim <= 256, \"headdim must be <= 256\"\n        assert config.headdim_v <= 256, \"headdim_v must be <= 256\"\n        assert config.nheads % config.nheads_kv == 0, \"nheads must be divisible by nheads_kv\"\n\n        alignment = 16 // config.dtype.itemsize\n        assert config.headdim % alignment == 0, f\"headdim must be divisible by {alignment}\"\n        assert config.headdim_v % alignment == 0, f\"headdim_v must be divisible by {alignment}\"\n\n        # Validate is_local configuration\n        if config.is_local:\n            assert config.window_left is not None or config.window_right is not None, (\n                \"When 
is_local=True, at least one of window_left or window_right must be set\"\n            )\n            assert not config.use_mask_mod, (\n                \"Cannot use both is_local and use_mask_mod simultaneously\"\n            )\n            assert not config.causal, \"Cannot use both is_local and causal simultaneously\"\n\n        # Validate mask_mod configuration\n        if config.use_mask_mod and config.mask_mod_name == \"sliding_window\":\n            assert config.window_size > 0, (\n                \"window_size must be positive when using sliding_window mask\"\n            )\n\n    def _generate_varlen_seqlens(self, min_len: int, max_len: int) -> Tuple[torch.Tensor, int]:\n        \"\"\"Generate random sequence lengths and compute cumulative lengths.\"\"\"\n        seqlens = torch.randint(\n            min_len, max_len + 1, (self.config.batch_size,), dtype=torch.int32, device=\"cuda\"\n        )\n        cu_seqlens = torch.cat(\n            [\n                torch.zeros(1, dtype=torch.int32, device=\"cuda\"),\n                torch.cumsum(seqlens, dtype=torch.int32, dim=0),\n            ]\n        )\n\n        total_tokens = cu_seqlens[-1].item()\n        return cu_seqlens, total_tokens\n\n    def _create_tensors(self) -> Dict[str, torch.Tensor]:\n        config = self.config\n        device = \"cuda\"\n\n        if config.use_varlen:\n            # Set defaults for varlen range\n            min_q = config.min_seqlen_q if config.min_seqlen_q is not None else config.seqlen_q // 2\n            max_q = config.max_seqlen_q if config.max_seqlen_q is not None else config.seqlen_q\n            min_k = config.min_seqlen_k if config.min_seqlen_k is not None else config.seqlen_k // 2\n            max_k = config.max_seqlen_k if config.max_seqlen_k is not None else config.seqlen_k\n\n            # Generate cu_seqlens\n            cu_seqlens_q, total_q = self._generate_varlen_seqlens(min_q, max_q)\n            cu_seqlens_k, total_k = self._generate_varlen_seqlens(min_k, 
max_k)\n\n            # Varlen shape: (total_tokens, nheads, headdim)\n            q = torch.randn(\n                total_q, config.nheads, config.headdim, dtype=config.dtype, device=device\n            )\n            k = torch.randn(\n                total_k, config.nheads_kv, config.headdim, dtype=config.dtype, device=device\n            )\n            v = torch.randn(\n                total_k, config.nheads_kv, config.headdim_v, dtype=config.dtype, device=device\n            )\n            out = torch.empty(\n                total_q, config.nheads, config.headdim_v, dtype=config.dtype, device=device\n            )\n            lse = torch.empty(config.nheads, total_q, dtype=torch.float32, device=device)\n\n            tensors = {\n                \"q\": q.contiguous(),\n                \"k\": k.contiguous(),\n                \"v\": v.contiguous(),\n                \"out\": out.contiguous(),\n                \"lse\": lse.contiguous(),\n                \"cu_seqlens_q\": cu_seqlens_q.contiguous(),\n                \"cu_seqlens_k\": cu_seqlens_k.contiguous(),\n            }\n\n            if config.verbose:\n                print(f\"Varlen: total_q={total_q}, total_k={total_k}\")\n                print(f\"Q seqlens: {cu_seqlens_q[1:] - cu_seqlens_q[:-1]}\")\n                print(f\"K seqlens: {cu_seqlens_k[1:] - cu_seqlens_k[:-1]}\")\n        else:\n            # Standard shape: (batch, seqlen, nheads, headdim)\n            q = torch.randn(\n                config.batch_size,\n                config.seqlen_q,\n                config.nheads,\n                config.headdim,\n                dtype=config.dtype,\n                device=device,\n            )\n            k = torch.randn(\n                config.batch_size,\n                config.seqlen_k,\n                config.nheads_kv,\n                config.headdim,\n                dtype=config.dtype,\n                device=device,\n            )\n            v = torch.randn(\n                
config.batch_size,\n                config.seqlen_k,\n                config.nheads_kv,\n                config.headdim_v,\n                dtype=config.dtype,\n                device=device,\n            )\n            out = torch.empty(\n                config.batch_size,\n                config.seqlen_q,\n                config.nheads,\n                config.headdim_v,\n                dtype=config.dtype,\n                device=device,\n            )\n            lse = torch.empty(\n                config.batch_size,\n                config.nheads,\n                config.seqlen_q,\n                dtype=torch.float32,\n                device=device,\n            )\n\n            tensors = {\n                \"q\": q.contiguous(),\n                \"k\": k.contiguous(),\n                \"v\": v.contiguous(),\n                \"out\": out.contiguous(),\n                \"lse\": lse.contiguous(),\n            }\n\n        if config.use_learnable_sink:\n            learnable_sink = torch.rand(config.nheads, dtype=torch.bfloat16, device=device)\n\n            tensors[\"learnable_sink\"] = learnable_sink.contiguous()\n\n        # Compute block sparsity when using mask_mod\n        if config.use_mask_mod:\n            if config.mask_mod_name == \"document\":\n                doc_id = random_doc_id_tensor(\n                    config.batch_size, config.nheads, config.seqlen_q, device=device\n                )\n                tensors[\"aux_tensors\"] = [doc_id.contiguous()]\n\n            _, blocksparse_torch_tensors = compute_block_sparsity(\n                tile_m=self.config.tile_m,\n                tile_n=self.config.tile_n,\n                batch_size=self.config.batch_size,\n                num_heads=self.config.nheads,\n                seqlen_q=self.config.seqlen_q,\n                seqlen_k=self.config.seqlen_k,\n                mask_mod=self.mask_mod_cute,\n                device=device,\n                cu_seqlens_q=tensors.get(\"cu_seqlens_q\"),\n         
       cu_seqlens_k=tensors.get(\"cu_seqlens_k\"),\n                aux_tensors=tensors.get(\"aux_tensors\"),\n            )\n            if blocksparse_torch_tensors is not None:\n                tensors[\"block_sparse_tensors\"] = blocksparse_torch_tensors\n\n                if config.verbose:\n                    total_full = blocksparse_torch_tensors.full_block_cnt.sum().item()\n                    total_partial = blocksparse_torch_tensors.mask_block_cnt.sum().item()\n\n                    if config.use_varlen:\n                        # Compute max possible blocks across all sequences\n                        max_blocks = 0\n                        for i in range(config.batch_size):\n                            seq_len_q = (\n                                tensors[\"cu_seqlens_q\"][i + 1] - tensors[\"cu_seqlens_q\"][i]\n                            ).item()\n                            seq_len_k = (\n                                tensors[\"cu_seqlens_k\"][i + 1] - tensors[\"cu_seqlens_k\"][i]\n                            ).item()\n                            n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m\n                            n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n\n                            max_blocks += n_blocks_q * n_blocks_k * config.nheads\n                    else:\n                        n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n\n                        n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m\n                        max_blocks = n_blocks_k * n_blocks_q * config.nheads * config.batch_size\n\n                    skipped = max_blocks - total_full - total_partial\n                    print(\n                        f\"Block stats: Full={total_full}, Partial={total_partial}, \"\n                        f\"Skipped={skipped}/{max_blocks}\"\n                    )\n\n        return tensors\n\n    def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> 
Tuple[Any, tuple]:\n        config = self.config\n\n        dtype_map = {\n            torch.float16: cutlass.Float16,\n            torch.bfloat16: cutlass.BFloat16,\n            torch.float32: cutlass.Float32,\n        }\n        cute_dtype = dtype_map[config.dtype]\n\n        qhead_per_kvhead = config.nheads // config.nheads_kv\n        kernel = FlashAttentionForwardSm90(\n            cute_dtype,\n            config.headdim,\n            config.headdim_v,\n            qhead_per_kvhead,\n            is_causal=config.causal,\n            is_local=config.is_local,\n            pack_gqa=False,\n            tile_m=config.tile_m,\n            tile_n=config.tile_n,\n            num_stages=config.num_stages,\n            num_threads=config.num_threads,\n            intra_wg_overlap=config.intra_wg_overlap,\n            mma_pv_is_rs=config.mma_pv_is_rs,\n            mask_mod=self.mask_mod_cute,\n            Q_in_regs=False,\n            has_aux_tensors=config.has_aux_tensors,\n        )\n\n        softmax_scale = 1.0 / math.sqrt(config.headdim)\n        current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)\n\n        # Convert tensors to cute\n        q_cute = from_dlpack(tensors[\"q\"].detach(), assumed_align=16).mark_layout_dynamic(\n            leading_dim=tensors[\"q\"].ndim - 1\n        )\n        k_cute = from_dlpack(tensors[\"k\"].detach(), assumed_align=16).mark_layout_dynamic(\n            leading_dim=tensors[\"k\"].ndim - 1\n        )\n        v_cute = from_dlpack(tensors[\"v\"].detach(), assumed_align=16).mark_layout_dynamic(\n            leading_dim=tensors[\"v\"].ndim - 1\n        )\n        out_cute = from_dlpack(tensors[\"out\"].detach(), assumed_align=16).mark_layout_dynamic(\n            leading_dim=tensors[\"out\"].ndim - 1\n        )\n        lse_cute = from_dlpack(tensors[\"lse\"].detach(), assumed_align=4).mark_layout_dynamic(\n            leading_dim=tensors[\"lse\"].ndim - 1\n        )\n\n        # Varlen tensors\n        
cu_seqlens_q_cute = (\n            from_dlpack(tensors[\"cu_seqlens_q\"].detach(), assumed_align=4).mark_layout_dynamic(\n                leading_dim=0\n            )\n            if \"cu_seqlens_q\" in tensors\n            else None\n        )\n        cu_seqlens_k_cute = (\n            from_dlpack(tensors[\"cu_seqlens_k\"].detach(), assumed_align=4).mark_layout_dynamic(\n                leading_dim=0\n            )\n            if \"cu_seqlens_k\" in tensors\n            else None\n        )\n        learnable_sink_cute = (\n            from_dlpack(tensors[\"learnable_sink\"].detach(), assumed_align=4).mark_layout_dynamic(\n                leading_dim=0\n            )\n            if \"learnable_sink\" in tensors\n            else None\n        )\n\n        blocksparse_tensors_cute = (\n            to_cute_block_sparse_tensors(tensors[\"block_sparse_tensors\"])\n            if \"block_sparse_tensors\" in tensors\n            else None\n        )\n\n        if \"aux_tensors\" in tensors:\n            aux_tensors_cute = []\n            for i in range(len(tensors[\"aux_tensors\"])):\n                buf = from_dlpack(tensors[\"aux_tensors\"][i].detach(), assumed_align=4)\n                aux_tensors_cute.append(buf.mark_layout_dynamic(leading_dim=2))\n\n        else:\n            aux_tensors_cute = None\n\n        # Window parameters for is_local\n        window_left_cute = (\n            cutlass.Int32(config.window_left) if config.window_left is not None else None\n        )\n        window_right_cute = (\n            cutlass.Int32(config.window_right) if config.window_right is not None else None\n        )\n\n        compiled = cute.compile(\n            kernel,\n            q_cute,\n            k_cute,\n            v_cute,\n            out_cute,\n            lse_cute,\n            softmax_scale,\n            current_stream,\n            cu_seqlens_q_cute,\n            cu_seqlens_k_cute,\n            None,  # seqused_q\n            None,  # seqused_k\n            
None,  # page_table\n            window_left_cute,\n            window_right_cute,\n            learnable_sink_cute,\n            blocksparse_tensors_cute,\n            aux_tensors_cute,\n            # None,\n        )\n\n        args = (\n            q_cute,\n            k_cute,\n            v_cute,\n            out_cute,\n            lse_cute,\n            softmax_scale,\n            current_stream,\n            cu_seqlens_q_cute,\n            cu_seqlens_k_cute,\n            None,\n            None,\n            None,\n            window_left_cute,\n            window_right_cute,\n            learnable_sink_cute,\n            blocksparse_tensors_cute,\n            aux_tensors_cute,\n            # None,\n        )\n\n        return compiled, args\n\n    def _calculate_flops(self, tensors: Dict[str, torch.Tensor]) -> float:\n        config = self.config\n\n        # Estimate sparsity for known mask patterns\n        if config.is_local:\n            # Local attention with window_left and window_right\n            window_left = config.window_left if config.window_left is not None else 0\n            window_right = config.window_right if config.window_right is not None else 0\n            total_window = window_left + window_right + 1  # +1 for current position\n            sparsity_ratio = min(1.0, total_window / config.seqlen_k)\n        elif config.use_mask_mod:\n            if config.mask_mod_name in [\"identity\", \"identity_partial\"]:\n                sparsity_ratio = 1.0\n            elif config.mask_mod_name in [\"causal\", \"block_causal\"]:\n                sparsity_ratio = 0.5\n            elif config.mask_mod_name == \"sliding_window\":\n                # Use configured window size\n                sparsity_ratio = min(1.0, config.window_size / config.seqlen_k)\n            elif config.mask_mod_name == \"block_diagonal\":\n                block_size = 64\n                num_blocks = (config.seqlen_k + block_size - 1) // block_size\n                
sparsity_ratio = 1.0 / num_blocks if num_blocks > 1 else 1.0\n            elif config.mask_mod_name == \"document\":\n                vals = tensors[\"aux_tensors\"][0]\n                val_mask = torch.ones_like(vals, dtype=torch.bool)\n                val_mask[..., 1:] = vals[..., 1:] != vals[..., :-1]\n                total = torch.where(val_mask, vals.square(), 0).sum()\n                sparsity_ratio = total / (config.seqlen_q * config.seqlen_k)\n            else:\n                sparsity_ratio = 1.0\n        elif config.causal:\n            sparsity_ratio = 0.5\n        else:\n            sparsity_ratio = 1.0\n\n        if config.use_varlen:\n            # Compute FLOPs per sequence and sum\n            total_flops = 0\n            cu_q = tensors[\"cu_seqlens_q\"]\n            cu_k = tensors[\"cu_seqlens_k\"]\n            for i in range(config.batch_size):\n                seq_len_q = (cu_q[i + 1] - cu_q[i]).item()\n                seq_len_k = (cu_k[i + 1] - cu_k[i]).item()\n\n                # Adjust sparsity for local attention in varlen case\n                if config.is_local:\n                    window_left = config.window_left if config.window_left is not None else 0\n                    window_right = config.window_right if config.window_right is not None else 0\n                    total_window = window_left + window_right + 1\n                    seq_sparsity = min(1.0, total_window / seq_len_k)\n                elif config.use_mask_mod and config.mask_mod_name == \"sliding_window\":\n                    seq_sparsity = min(1.0, config.window_size / seq_len_k)\n                else:\n                    seq_sparsity = sparsity_ratio\n\n                num_cells = int(seq_len_q * seq_len_k * seq_sparsity)\n\n                if config.headdim == config.headdim_v:\n                    flops_this_seq = 4 * config.nheads * num_cells * config.headdim\n                else:\n                    flops_this_seq = (\n                        2 * config.nheads 
* num_cells * config.headdim\n                        + 2 * config.nheads * num_cells * config.headdim_v\n                    )\n                total_flops += flops_this_seq\n            return total_flops\n        else:\n            num_cells = int(config.seqlen_q * config.seqlen_k * sparsity_ratio)\n            if config.headdim == config.headdim_v:\n                flops_per_batch = 4 * config.nheads * num_cells * config.headdim\n            else:\n                flops_per_batch = (\n                    2 * config.nheads * num_cells * config.headdim\n                    + 2 * config.nheads * num_cells * config.headdim_v\n                )\n            return flops_per_batch * config.batch_size\n\n    def benchmark(self) -> Dict[str, Any]:\n        config = self.config\n\n        tensors = self._create_tensors()\n        compiled_kernel, args = self._compile_kernel(tensors)\n\n        # Warmup\n        for _ in range(config.warmup_iters):\n            compiled_kernel(*args)\n        torch.cuda.synchronize()\n\n        # Benchmark\n        times = []\n        for _ in range(config.benchmark_iters):\n            start = torch.cuda.Event(enable_timing=True)\n            end = torch.cuda.Event(enable_timing=True)\n\n            start.record()\n            compiled_kernel(*args)\n            end.record()\n            torch.cuda.synchronize()\n\n            times.append(start.elapsed_time(end))\n\n        times_tensor = torch.tensor(times)\n        mean_time = times_tensor.mean().item()\n        std_time = times_tensor.std().item() if len(times) > 1 else 0.0\n\n        total_flops = self._calculate_flops(tensors)\n        tflops = total_flops / (mean_time * 1e-3) / 1e12\n\n        # Bandwidth calculation\n        bytes_per_element = config.dtype.itemsize\n        if config.use_varlen:\n            total_q = tensors[\"q\"].shape[0]\n            total_k = tensors[\"k\"].shape[0]\n            memory_accessed = (\n                total_q * config.nheads * config.headdim 
* bytes_per_element\n                + total_k * config.nheads_kv * config.headdim * bytes_per_element\n                + total_k * config.nheads_kv * config.headdim_v * bytes_per_element\n                + total_q * config.nheads * config.headdim_v * bytes_per_element\n            )\n        else:\n            memory_accessed = (\n                config.batch_size\n                * config.seqlen_q\n                * config.nheads\n                * config.headdim\n                * bytes_per_element\n                + config.batch_size\n                * config.seqlen_k\n                * config.nheads_kv\n                * config.headdim\n                * bytes_per_element\n                + config.batch_size\n                * config.seqlen_k\n                * config.nheads_kv\n                * config.headdim_v\n                * bytes_per_element\n                + config.batch_size\n                * config.seqlen_q\n                * config.nheads\n                * config.headdim_v\n                * bytes_per_element\n            )\n        bandwidth_gbps = memory_accessed / (mean_time * 1e-3) / 1e9\n\n        results = {\n            \"mean_time_ms\": mean_time,\n            \"std_time_ms\": std_time,\n            \"tflops\": tflops,\n            \"bandwidth_gbps\": bandwidth_gbps,\n        }\n\n        if config.verbose:\n            self._print_results(results)\n\n        return results\n\n    def _print_results(self, results: Dict[str, Any]):\n        config = self.config\n\n        # Basic configuration\n        if config.use_varlen:\n            print(\n                f\"Shape: B={config.batch_size} (varlen), HD={config.headdim}, \"\n                f\"NH={config.nheads}, NKV={config.nheads_kv}\"\n            )\n        else:\n            print(\n                f\"Shape: B={config.batch_size}, Q={config.seqlen_q}, K={config.seqlen_k}, \"\n                f\"HD={config.headdim}, NH={config.nheads}, NKV={config.nheads_kv}\"\n            )\n\n      
  # Attention pattern\n        attn_info = []\n        if config.causal:\n            attn_info.append(\"causal\")\n        if config.is_local:\n            window_info = f\"local(L={config.window_left},R={config.window_right})\"\n            attn_info.append(window_info)\n        if config.use_mask_mod:\n            if config.mask_mod_name == \"sliding_window\":\n                attn_info.append(f\"mask_mod={config.mask_mod_name}(w={config.window_size})\")\n            else:\n                attn_info.append(f\"mask_mod={config.mask_mod_name}\")\n        if config.use_varlen:\n            attn_info.append(\"varlen\")\n        if attn_info:\n            print(f\"Attention: {', '.join(attn_info)}\")\n\n        # Performance metrics\n        print(f\"Time: {results['mean_time_ms']:.3f} ± {results['std_time_ms']:.3f} ms\")\n        print(f\"Throughput: {results['tflops']:.2f} TFLOPS\")\n        print(f\"Bandwidth: {results['bandwidth_gbps']:.1f} GB/s\")\n\n\nif __name__ == \"__main__\":\n    B = 2\n    config = BenchmarkConfig(\n        headdim=128,\n        headdim_v=128,\n        nheads=16,\n        nheads_kv=16,\n        dtype=torch.bfloat16,\n        batch_size=B,\n        # batch_size=1,\n        seqlen_q=8192,\n        # seqlen_q=128,\n        seqlen_k=8192,\n        # seqlen_k=192,\n        use_varlen=False,\n        use_mask_mod=False,\n        mask_mod_name=\"causal\",\n        window_size=128,  # Configurable window size for mask_mod\n        use_learnable_sink=False,\n        causal=True,\n        is_local=False,\n        verbose=True,\n    )\n\n    # Example 2: Base Flash Attention Local\n    # config = BenchmarkConfig(\n    #     headdim=64,\n    #     headdim_v=64,\n    #     nheads=64,\n    #     nheads_kv=8,\n    #     dtype=torch.bfloat16,\n    #     batch_size=2,\n    #     seqlen_q=8192,\n    #     seqlen_k=8192,\n    #     use_varlen=False,\n    #     use_mask_mod=False,\n    #     causal=False,\n    #     is_local=True,\n    #     window_left=128, 
  # Left window size for base local attention\n    #     window_right=0,    # Right window size for base local attention\n    #     verbose=True,\n    # )\n\n    benchmark = FlashAttentionBenchmark(config)\n    results = benchmark.benchmark()\n"
  },
  {
    "path": "tests/cute/conftest.py",
    "content": "import os\nimport subprocess\nimport logging\nimport tempfile\nimport json\nimport time\nfrom pathlib import Path\nfrom getpass import getuser\n\n\ndef _get_gpu_ids():\n    visible = os.environ.get(\"CUDA_VISIBLE_DEVICES\")\n    if visible:\n        return [g.strip() for g in visible.split(\",\")]\n\n    try:\n        result = subprocess.run(\n            [\"nvidia-smi\", \"--query-gpu=index\", \"--format=csv,noheader\"],\n            capture_output=True,\n            text=True,\n            timeout=5,\n        )\n        if result.returncode == 0:\n            return result.stdout.strip().splitlines()\n    except (FileNotFoundError, subprocess.TimeoutExpired):\n        # nvidia-smi is missing or did not answer within the timeout; fall back below.\n        pass\n\n    logging.warning(\"Failed to get GPU ids; using default '0'\")\n    return [\"0\"]\n\n\ndef pytest_configure(config):\n    tmp = Path(tempfile.gettempdir()) / getuser() / \"flash_attention_tests\"\n    tmp.mkdir(parents=True, exist_ok=True)\n\n    worker_id = os.environ.get(\"PYTEST_XDIST_WORKER\")\n    logging.basicConfig(\n        format=config.getini(\"log_file_format\"),\n        filename=str(tmp / f\"tests_{worker_id}.log\"),\n        level=config.getini(\"log_file_level\"),\n    )\n    if not worker_id:\n        return\n    worker_num = int(worker_id.replace(\"gw\", \"\"))\n\n    # cache gpu_ids, because nvidia-smi is expensive when we launch many workers doing torch initialization\n    # Always elect worker_0 to get gpu_ids.\n    cached_gpu_ids = tmp / \"gpu_ids.json\"\n    if worker_num == 0:\n        gpu_ids = _get_gpu_ids()\n        with cached_gpu_ids.open(mode=\"w\") as f:\n            json.dump(gpu_ids, f)\n    else:\n        while not cached_gpu_ids.exists():\n            time.sleep(1)\n        with cached_gpu_ids.open() as f:\n            gpu_ids = json.load(f)\n\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = gpu_ids[worker_num % len(gpu_ids)]\n\n\ndef pytest_collection_finish(session):\n    # file_name -> test_name -> counter\n    test_counts: dict[str, dict[str, int]] = {}\n    for item 
in session.items:\n        funcname = item.function.__name__\n        parent = test_counts.setdefault(item.parent.name, {})\n        parent[funcname] = parent.setdefault(funcname, 0) + 1\n    print(json.dumps(test_counts, indent=2))\n"
  },
  {
    "path": "tests/cute/mask_mod_definitions.py",
    "content": "from typing import Callable, Optional\n\nimport random\nimport math\n\nimport cutlass\nimport cutlass.cute as cute\nimport torch\n\nfrom flash_attn.cute import utils\nfrom flash_attn.cute.block_sparsity import fast_sampling\n\n\n# =============================================================================\n# CuTe mask_mod functions (for kernel compilation)\n# All use signature: (batch, head, m_idx, n_idx, seqlen_info, aux_tensors)\n# =============================================================================\n\n# =============================================================================\n# mask_mod functions that don't use global indices\n# =============================================================================\n\n\n@fast_sampling\n@cute.jit\ndef cute_causal_mask(\n    batch: cute.TensorSSA,\n    head: cute.TensorSSA,\n    m_idx: cute.TensorSSA,\n    n_idx: cute.TensorSSA,\n    seqlen_info,\n    aux_tensors: None,\n) -> cute.TensorSSA:\n    offset = seqlen_info.seqlen_k - seqlen_info.seqlen_q\n    offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32)\n    return n_idx <= (m_idx + offset_ssa)\n\n\ndef get_cute_causal_mask(offset: int):\n    return cute_causal_mask\n\n\ndef get_cute_block_causal_mask(offset: int):\n    @fast_sampling\n    @cute.jit\n    def _cute_block_causal_mask(\n        batch: cute.TensorSSA,\n        head: cute.TensorSSA,\n        m_idx: cute.TensorSSA,\n        n_idx: cute.TensorSSA,\n        seqlen_info,\n        aux_tensors: None,\n    ) -> cute.TensorSSA:\n        offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32)\n        return n_idx <= (m_idx + offset_ssa)\n\n    return _cute_block_causal_mask\n\n\ndef get_cute_sliding_window_mask(window_left: int, window_right: int, offset: int):\n    @fast_sampling\n    @cute.jit\n    def _cute_sliding_window_mask(\n        batch: cute.TensorSSA,\n        head: cute.TensorSSA,\n        m_idx: cute.TensorSSA,\n        n_idx: cute.TensorSSA,\n        seqlen_info,\n      
  aux_tensors,\n    ) -> cute.TensorSSA:\n        offset = seqlen_info.seqlen_k - seqlen_info.seqlen_q\n        offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32)\n        window_left_ssa = utils.scalar_to_ssa(window_left, cutlass.Int32)\n        window_right_ssa = utils.scalar_to_ssa(window_right, cutlass.Int32)\n        center = m_idx + offset_ssa\n        lower = center - window_left_ssa\n        upper = center + window_right_ssa\n        return (n_idx >= lower) & (n_idx <= upper)\n\n    return _cute_sliding_window_mask\n\n\n@fast_sampling\n@cute.jit\ndef cute_block_diagonal_mask(\n    batch: cute.TensorSSA,\n    head: cute.TensorSSA,\n    m_idx: cute.TensorSSA,\n    n_idx: cute.TensorSSA,\n    seqlen_info,\n    aux_tensors,\n) -> cute.TensorSSA:\n    block_size_ssa = utils.scalar_to_ssa(128, cutlass.Int32)\n    return (m_idx // block_size_ssa) == (n_idx // block_size_ssa)\n\n\n@cute.jit\ndef cute_mini_causal_mask(\n    batch: cute.TensorSSA,\n    head: cute.TensorSSA,\n    m_idx: cute.TensorSSA,\n    n_idx: cute.TensorSSA,\n    seqlen_info,\n    aux_tensors,\n) -> cute.TensorSSA:\n    tile_size_ssa = utils.scalar_to_ssa(128, cutlass.Int32)\n    m_mod = m_idx % tile_size_ssa\n    n_mod = n_idx % tile_size_ssa\n    return m_mod >= n_mod\n\n\n@fast_sampling\n@cute.jit\ndef cute_prefix_lm_mask(\n    batch: cute.TensorSSA,\n    head: cute.TensorSSA,\n    m_idx: cute.TensorSSA,\n    n_idx: cute.TensorSSA,\n    seqlen_info,\n    aux_tensors,\n) -> cute.TensorSSA:\n    \"\"\"Prefix LM mask: first 512 tokens attend bidirectionally, rest use causal masking.\"\"\"\n    prefix_size_ssa = utils.scalar_to_ssa(512, cutlass.Int32)\n    both_in_prefix = (m_idx < prefix_size_ssa) & (n_idx < prefix_size_ssa)\n    causal_part = m_idx >= n_idx\n    return both_in_prefix | causal_part\n\n\n@cute.jit\ndef cute_dilated_sliding_window_mask(\n    batch: cute.TensorSSA,\n    head: cute.TensorSSA,\n    m_idx: cute.TensorSSA,\n    n_idx: cute.TensorSSA,\n    seqlen_info,\n    
aux_tensors,\n) -> cute.TensorSSA:\n    \"\"\"Dilated sliding window: every other position in a 256-position window.\"\"\"\n    window_size_ssa = utils.scalar_to_ssa(256, cutlass.Int32)\n    dilation_ssa = utils.scalar_to_ssa(2, cutlass.Int32)\n    in_window = (m_idx >= n_idx) & (m_idx - n_idx < window_size_ssa)\n    dilated = ((m_idx - n_idx) % dilation_ssa) == utils.scalar_to_ssa(0, cutlass.Int32)\n    return in_window & dilated\n\n\n@fast_sampling\n@cute.jit\ndef cute_document_mask(\n    batch: cute.TensorSSA,\n    head: cute.TensorSSA,\n    m_idx: cute.TensorSSA,\n    n_idx: cute.TensorSSA,\n    seqlen_info,\n    aux_tensors: list,\n) -> cute.TensorSSA:\n    doc_id = aux_tensors[0]\n    m_doc = utils.scalar_to_ssa(doc_id[batch[0], head[0], m_idx[0]], cutlass.Int32)\n    n_doc = utils.scalar_to_ssa(doc_id[batch[0], head[0], n_idx[0]], cutlass.Int32)\n    return m_doc == n_doc\n\n\n@fast_sampling\n@cute.jit\ndef cute_ima_mask(\n    batch: cute.TensorSSA,\n    head: cute.TensorSSA,\n    m_idx: cute.TensorSSA,\n    n_idx: cute.TensorSSA,\n    seqlen_info,\n    aux_tensors,\n) -> cute.TensorSSA:\n    bias = aux_tensors[0]\n    threshold = utils.scalar_to_ssa(bias[n_idx[0]], cutlass.Int32)\n    return n_idx >= threshold\n\n\n# =============================================================================\n# mask_mod functions that use global indices (for use with variable sequence length)\n# Global indices computed as: m_idx_global = m_idx + seqlen_info.offset_q\n#                            n_idx_global = n_idx + seqlen_info.offset_k\n# =============================================================================\n\n# TODO: Add varlen mask implementations here\n\n\n# =============================================================================\n# Eager reference functions (PyTorch/Flex Attention signatures)\n# =============================================================================\n\n\ndef get_flex_causal_mask(offset: int):\n    def _flex_causal_mask(b, h, 
q_idx, kv_idx):\n        return kv_idx <= q_idx + offset\n\n    return _flex_causal_mask\n\n\ndef get_flex_block_causal_mask(offset: int):\n    def _flex_block_causal_mask(b, h, q_idx, kv_idx):\n        return kv_idx <= q_idx + offset\n\n    return _flex_block_causal_mask\n\n\ndef get_flex_sliding_window_mask(window_left: int, window_right: int, offset: int):\n    def _flex_sliding_window_mask(b, h, q_idx, kv_idx):\n        center = q_idx + offset\n        lower = center - window_left\n        upper = center + window_right\n        return (kv_idx >= lower) & (kv_idx <= upper)\n\n    return _flex_sliding_window_mask\n\n\ndef flex_block_diagonal_mask(b, h, q_idx, kv_idx):\n    block_size = 128\n    return (q_idx // block_size) == (kv_idx // block_size)\n\n\ndef flex_mini_causal_mask(b, h, q_idx, kv_idx):\n    return (q_idx % 128) >= (kv_idx % 128)\n\n\ndef flex_prefix_lm_mask(b, h, q_idx, kv_idx):\n    \"\"\"Prefix LM mask: first 512 tokens attend bidirectionally, rest use causal masking.\"\"\"\n    prefix_size = 512\n    both_in_prefix = (q_idx < prefix_size) & (kv_idx < prefix_size)\n    causal_part = q_idx >= kv_idx\n    return both_in_prefix | causal_part\n\n\ndef flex_dilated_sliding_window_mask(b, h, q_idx, kv_idx):\n    \"\"\"Dilated sliding window: every other position in a 256-position window.\"\"\"\n    window_size = 256\n    dilation = 2\n    in_window = (q_idx >= kv_idx) & (q_idx - kv_idx < window_size)\n    dilated = ((q_idx - kv_idx) % dilation) == 0\n    return in_window & dilated\n\n\ndef flex_document_mask(b, h, q_idx, kv_idx, doc_id):\n    return doc_id[b, h, q_idx] == doc_id[b, h, kv_idx]\n\n\ndef flex_ima_mask(b, h, q_idx, kv_idx, bias):\n    return kv_idx >= bias[kv_idx]\n\n\n# =============================================================================\n# Utility functions\n# =============================================================================\n\n\ndef random_doc_id_tensor(nheads, batch, seqlen_q, device=\"cpu\"):\n    \"\"\"Generate 
synthetic document ids shared across heads.\"\"\"\n    doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device)\n    for b in range(batch):\n        N = seqlen_q\n        max_segments = max(1, math.ceil(math.sqrt(max(N // 4, 1))))\n        n = random.randint(1, max_segments)\n        n = min(n, N)\n        cuts = sorted(random.sample(range(1, N), n - 1))\n        lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))]\n        base_doc_ids = torch.repeat_interleave(\n            torch.arange(len(lengths), device=device, dtype=torch.int32),\n            torch.tensor(lengths, device=device, dtype=torch.int32),\n        )\n\n        for h in range(nheads):\n            doc_ids_tensor[b, h, :] = base_doc_ids\n    return doc_ids_tensor\n\n\n# =============================================================================\n# Mask registry and factory functions\n# =============================================================================\n\n\nSTATIC_MASKS = {\n    \"block_diagonal\": (cute_block_diagonal_mask, flex_block_diagonal_mask),\n    \"mini_causal\": (cute_mini_causal_mask, flex_mini_causal_mask),\n    \"prefix_lm\": (cute_prefix_lm_mask, flex_prefix_lm_mask),\n    \"dilated_sliding_window\": (\n        cute_dilated_sliding_window_mask,\n        flex_dilated_sliding_window_mask,\n    ),\n    \"document\": (cute_document_mask, flex_document_mask),\n    \"ima\": (cute_ima_mask, flex_ima_mask),\n}\n\nPARAMETERIZED_MASK_FACTORIES = {\n    \"causal\": (get_cute_causal_mask, get_flex_causal_mask),\n    \"block_causal\": (get_cute_block_causal_mask, get_flex_block_causal_mask),\n    \"sliding_window\": (get_cute_sliding_window_mask, get_flex_sliding_window_mask),\n}\n\n\ndef get_mask_pair(mask_name, seqlen_q=None, seqlen_k=None, window_size=None):\n    \"\"\"Get (cute_mask, flex_mask) pair for the given mask name.\n\n    For static masks, seqlen info is not needed.\n    For parameterized masks, seqlen_q and seqlen_k are required.\n    
\"\"\"\n    if mask_name in STATIC_MASKS:\n        return STATIC_MASKS[mask_name]\n\n    if mask_name not in PARAMETERIZED_MASK_FACTORIES:\n        raise ValueError(f\"Unknown mask: {mask_name}\")\n\n    if seqlen_q is None or seqlen_k is None:\n        raise ValueError(\n            f\"Parameterized mask '{mask_name}' requires seqlen_q and seqlen_k\"\n        )\n\n    cute_factory, flex_factory = PARAMETERIZED_MASK_FACTORIES[mask_name]\n    offset = seqlen_k - seqlen_q\n\n    if mask_name == \"sliding_window\":\n        if window_size is None:\n            raise ValueError(\"sliding_window mask requires window_size parameter\")\n        cute_mask = cute_factory(window_size, window_size, offset)\n        flex_mask = flex_factory(window_size, window_size, offset)\n    else:\n        cute_mask = cute_factory(offset)\n        flex_mask = flex_factory(offset)\n\n    return cute_mask, flex_mask\n\n\nif __name__ == \"__main__\":\n    doc_ids = random_doc_id_tensor(1, 2, 128)\n    print(f\"{doc_ids = }\")\n"
  },
  {
    "path": "tests/cute/score_mod_definitions.py",
    "content": "import torch\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass._mlir.dialects import math as mlir_math\nimport operator\n\n# =============================================================================\n# Score_mod functions that don't use global indices\n# All use signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors)\n# =============================================================================\n\n\n@cute.jit\ndef score_mod_identity(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    return tSrS_ssa\n\n\n@cute.jit\ndef score_mod_identity_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    return tSrS_ssa\n\n\n@cute.jit\ndef score_mod_causal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    mask = operator.ge(q_idx, kv_idx)\n    return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float(\"-inf\")))\n\n\n@cute.jit\ndef score_mod_causal_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    mask = cute.make_rmem_tensor(kv_idx.shape, dtype=cutlass.Boolean)\n    kv_idx0 = kv_idx[0]\n    q_idx0 = q_idx[0]\n    for i in cutlass.range_constexpr(cute.size(mask.shape)):\n        mask[i] = q_idx0 >= kv_idx0 + i\n    mask_ssa = mask.load()\n    return cute.where(mask_ssa, tSrS_ssa, cute.full_like(tSrS_ssa, float(\"-inf\")))\n\n\n@cute.jit\ndef score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    diff = q_idx - kv_idx\n    abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype)\n    return tSrS_ssa + abs_diff.to(cutlass.Float32)\n\n\n@cute.jit\ndef score_mod_rel_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    q_idx0 = q_idx[0]\n    kv_idx0 = kv_idx[0]\n    diff0 = q_idx0 - kv_idx0\n    abs_diff = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype)\n    for i in cutlass.range_constexpr(cute.size(kv_idx.shape)):\n        diffi = diff0 
- i\n        abs_diff[i] = mlir_math.absi(diffi)\n    return tSrS_ssa + abs_diff.load().to(cutlass.Float32)\n\n\n@cute.jit\ndef score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    diff = q_idx - kv_idx\n    abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype)\n    scaled = abs_diff * cute.full_like(abs_diff, 2)\n    return tSrS_ssa + scaled.to(cutlass.Float32)\n\n\n@cute.jit\ndef score_mod_rel_bias_x2_vectorized(\n    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors\n):\n    q_idx0 = q_idx[0]\n    kv_idx0 = kv_idx[0]\n    diff0 = q_idx0 - kv_idx0\n    abs_diff_x2 = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype)\n    for i in cutlass.range_constexpr(cute.size(kv_idx.shape)):\n        diffi = diff0 - i\n        abs_diff_x2[i] = mlir_math.absi(diffi) * 2\n    return tSrS_ssa + abs_diff_x2.load().to(cutlass.Float32)\n\n\n@cute.jit\ndef score_mod_times_two(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    return tSrS_ssa * cute.full_like(tSrS_ssa, 2)\n\nscore_mod_times_two_vectorized = score_mod_times_two\n\n@cute.jit\ndef score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    score = tSrS_ssa.to(cutlass.Float32)\n    slope_exp = (h_idx + cute.full_like(h_idx, 1)) * cute.full_like(h_idx, -8)\n    slope = cute.math.exp2(\n        slope_exp.to(cutlass.Float32)\n        * cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634)\n    )\n    diff = q_idx - kv_idx\n    abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype).to(cutlass.Float32)\n    return score - slope * abs_diff\n\n@cute.jit\ndef score_mod_alibi_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    score = tSrS_ssa.to(cutlass.Float32)\n    slope_exp = (h_idx + cute.full_like(h_idx, 1)) * cute.full_like(h_idx, -8)\n    slope = cute.math.exp2(\n        slope_exp.to(cutlass.Float32)\n        * cute.full_like(score, 0.125 
* 0.6931471805599453 * 1.4426950408889634)\n    )\n    diff0 = q_idx[0] - kv_idx[0]\n    abs_diff = cute.make_rmem_tensor(kv_idx.shape, diff0.dtype)\n    for i in cutlass.range_constexpr(cute.size(abs_diff.shape)):\n        diffi = diff0 - i\n        abs_diff[i] = mlir_math.absi(diffi)\n    return score - slope * abs_diff.load().to(cutlass.Float32)\n\n\n@cute.jit\ndef score_mod_sliding_window(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    diff = q_idx - kv_idx\n    abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype)\n    mask = operator.le(abs_diff, cute.full_like(abs_diff, 256))\n    return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float(\"-inf\")))\n\n\n@cute.jit\ndef score_mod_block_diagonal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    q_block = q_idx // 64\n    kv_block = kv_idx // 64\n    mask = operator.eq(q_block, kv_block)\n    return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float(\"-inf\")))\n\n\n@cute.jit\ndef score_mod_causal_v2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    diff = q_idx - kv_idx\n    mask = operator.ge(diff, cute.full_like(diff, 0))\n    return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float(\"-inf\")))\n\n\n@cute.jit\ndef score_mod_batch_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    batch_bias = aux_tensors[0]\n    dtype = batch_bias.element_type\n    b_frag = cute.make_fragment(1, cutlass.Int32)\n    b_frag.store(b_idx)\n    bias_frag = cute.make_fragment(1, dtype)\n    bias_frag[0] = batch_bias[b_frag[0]]\n    bias_val = (bias_frag.load()).to(cutlass.Float32)\n    return tSrS_ssa + bias_val\n\n@cute.jit\ndef score_mod_batch_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    batch_bias = aux_tensors[0]\n    dtype = batch_bias.element_type\n    b_idx0 = b_idx[0]\n    bias_frag = cute.make_rmem_tensor(1, dtype)\n    bias_frag[0] = 
batch_bias[b_idx0]\n    bias_val = (bias_frag.load()).to(cutlass.Float32)\n    return tSrS_ssa + bias_val\n\n\n@cute.jit\ndef score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    head_bias = aux_tensors[0]\n    pos_bias = aux_tensors[1]\n    dtype = head_bias.element_type\n\n    h_frag = cute.make_fragment(1, cutlass.Int32)\n    h_frag.store(h_idx)\n    head_val_frag = cute.make_fragment(1, dtype)\n    head_val_frag[0] = head_bias[h_frag[0]]\n    head_val = (head_val_frag.load()).to(cutlass.Float32)\n\n    q_frag = cute.make_fragment(1, cutlass.Int32)\n    q_frag.store(q_idx)\n    pos_val_frag = cute.make_fragment(1, dtype)\n    pos_val_frag[0] = pos_bias[q_frag[0]]\n    pos_val = (pos_val_frag.load()).to(cutlass.Float32)\n\n    return tSrS_ssa + head_val + pos_val\n\n@cute.jit\ndef score_mod_dual_buffer_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    head_bias = aux_tensors[0]\n    pos_bias = aux_tensors[1]\n    dtype = head_bias.element_type\n\n    head_val_frag = cute.make_fragment(1, dtype)\n    head_val_frag[0] = head_bias[h_idx[0]]\n    head_val = (head_val_frag.load()).to(cutlass.Float32)\n\n    pos_val_frag = cute.make_fragment(1, dtype)\n    pos_val_frag[0] = pos_bias[q_idx[0]]\n    pos_val = (pos_val_frag.load()).to(cutlass.Float32)\n\n    return tSrS_ssa + head_val + pos_val\n\n\n# =============================================================================\n# Score_mod functions that use global indices\n# All use signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors)\n# Global indices computed as: q_idx_global = q_idx + seqlen_info.offset_q (and similarly for kv)\n# =============================================================================\n\n\n@cute.jit\ndef score_mod_global_kv_bias(\n    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors\n):\n    \"\"\"Per-token bias using global kv index.\"\"\"\n    offset_k = seqlen_info.offset_k\n    
kv_idx_global = kv_idx + offset_k\n    token_bias = aux_tensors[0]\n    dtype = token_bias.element_type\n    kv_frag = cute.make_fragment(1, cutlass.Int32)\n    kv_frag.store(kv_idx_global)\n    bias_frag = cute.make_fragment(1, dtype)\n    bias_frag[0] = token_bias[kv_frag[0]]\n\n    return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32)\n\n\n@cute.jit\ndef score_mod_global_q_bias(\n    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors\n):\n    \"\"\"Per-token bias using global q index.\"\"\"\n    offset_q = seqlen_info.offset_q\n    q_idx_global = q_idx + offset_q\n    token_bias = aux_tensors[0]\n    dtype = token_bias.element_type\n    q_frag = cute.make_fragment(1, cutlass.Int32)\n    q_frag.store(q_idx_global)\n    bias_frag = cute.make_fragment(1, dtype)\n    bias_frag[0] = token_bias[q_frag[0]]\n    return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32)\n\n\n@cute.jit\ndef score_mod_global_rel_plus_kv_bias(\n    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors\n):\n    \"\"\"Relative position (logical) + per-token bias (global kv).\"\"\"\n    offset_k = seqlen_info.offset_k\n    kv_idx_global = kv_idx + offset_k\n    token_bias = aux_tensors[0]\n    dtype = token_bias.element_type\n\n    rel_pos = q_idx - kv_idx\n    rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype)\n    rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.1)\n\n    kv_frag = cute.make_fragment(1, cutlass.Int32)\n    kv_frag.store(kv_idx_global)\n    bias_frag = cute.make_fragment(1, dtype)\n    bias_frag[0] = token_bias[kv_frag[0]]\n\n    return tSrS_ssa + rel_bias + (bias_frag.load()).to(cutlass.Float32)\n\n\n@cute.jit\ndef score_mod_global_q_and_kv_bias(\n    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors\n):\n    \"\"\"Both q and kv global indices.\"\"\"\n    offset_q = seqlen_info.offset_q\n    q_idx_global = q_idx + offset_q\n    offset_k = seqlen_info.offset_k\n    kv_idx_global = 
kv_idx + offset_k\n    q_bias = aux_tensors[0]\n    kv_bias = aux_tensors[1]\n    dtype = q_bias.element_type\n\n    q_frag = cute.make_fragment(1, cutlass.Int32)\n    q_frag.store(q_idx_global)\n    q_bias_frag = cute.make_fragment(1, dtype)\n    q_bias_frag[0] = q_bias[q_frag[0]]\n\n    kv_frag = cute.make_fragment(1, cutlass.Int32)\n    kv_frag.store(kv_idx_global)\n    kv_bias_frag = cute.make_fragment(1, dtype)\n    kv_bias_frag[0] = kv_bias[kv_frag[0]]\n\n    return (\n        tSrS_ssa\n        + (q_bias_frag.load()).to(cutlass.Float32)\n        + (kv_bias_frag.load()).to(cutlass.Float32)\n    )\n\n\n@cute.jit\ndef score_mod_global_logical_rel_plus_kv_bias(\n    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors\n):\n    \"\"\"Logical relative + global-indexed per-token bias.\"\"\"\n    offset_k = seqlen_info.offset_k\n    kv_idx_global = kv_idx + offset_k\n    token_bias = aux_tensors[0]\n    dtype = token_bias.element_type\n\n    rel_pos = q_idx - kv_idx\n    rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype)\n    rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.01)\n\n    kv_frag = cute.make_fragment(1, cutlass.Int32)\n    kv_frag.store(kv_idx_global)\n    bias_frag = cute.make_fragment(1, dtype)\n    bias_frag[0] = token_bias[kv_frag[0]]\n\n    return tSrS_ssa + rel_bias + (bias_frag.load()).to(cutlass.Float32)\n\n\n# \"Stress tests\" - score_mods with complex global index usage\n\n@cute.jit\ndef score_mod_stress_complex_arithmetic(\n    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors\n):\n    \"\"\"All indices in complex arithmetic.\"\"\"\n    offset_q = seqlen_info.offset_q\n    q_idx_global = q_idx + offset_q\n    bias = aux_tensors[0]\n    dtype = bias.element_type\n\n    # Use absolute value instead of squaring to avoid overflow with large sequences\n    rel_pos = q_idx - kv_idx\n    rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, 
rel_pos.dtype)\n    rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001)\n\n    q_frag = cute.make_fragment(1, cutlass.Int32)\n    q_frag.store(q_idx_global)\n    bias_q_frag = cute.make_fragment(1, dtype)\n    bias_q_frag[0] = bias[q_frag[0]]\n    bias_q = (bias_q_frag.load()).to(cutlass.Float32)\n\n    scale = (b_idx + cute.full_like(b_idx, 1)) * (h_idx + cute.full_like(h_idx, 1))\n    scale_f32 = scale.to(cutlass.Float32) * 0.001\n\n    result = tSrS_ssa + rel_bias + bias_q * scale_f32\n    return result\n\n\n@cute.jit\ndef score_mod_stress_conditional_mask(\n    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors\n):\n    \"\"\"Conditional masking with global vs logical.\"\"\"\n    offset_q = seqlen_info.offset_q\n    q_idx_global = q_idx + offset_q\n    offset_k = seqlen_info.offset_k\n    kv_idx_global = kv_idx + offset_k\n    token_bias = aux_tensors[0]\n    dtype = token_bias.element_type\n\n    kv_frag = cute.make_fragment(1, cutlass.Int32)\n    kv_frag.store(kv_idx_global)\n    bias_frag = cute.make_fragment(1, dtype)\n    bias_frag[0] = token_bias[kv_frag[0]]\n    bias_val = (bias_frag.load()).to(cutlass.Float32)\n\n    is_causal = operator.ge(q_idx, kv_idx)\n\n    global_diff = q_idx_global - kv_idx_global\n    is_nearby = operator.le(\n        cute.TensorSSA(mlir_math.absi(global_diff), global_diff.shape, global_diff.dtype),\n        cute.full_like(global_diff, 512),\n    )\n\n    both_conditions = is_causal & is_nearby\n    return cute.where(both_conditions, tSrS_ssa + bias_val, cute.full_like(tSrS_ssa, float(\"-inf\")))\n\n\n@cute.jit\ndef score_mod_stress_multi_buffer(\n    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors\n):\n    \"\"\"Multiple aux tensors with different indexing.\"\"\"\n    offset_q = seqlen_info.offset_q\n    q_idx_global = q_idx + offset_q\n    offset_k = seqlen_info.offset_k\n    kv_idx_global = kv_idx + offset_k\n    batch_bias = aux_tensors[0]\n    head_scale = 
aux_tensors[1]\n    q_pos_bias = aux_tensors[2]\n    kv_pos_bias = aux_tensors[3]\n    rel_pos_scale = aux_tensors[4]\n\n    dtype = batch_bias.element_type\n\n    b_frag = cute.make_fragment(1, cutlass.Int32)\n    b_frag.store(b_idx)\n    bb_frag = cute.make_fragment(1, dtype)\n    bb_frag[0] = batch_bias[b_frag[0]]\n    bb_val = (bb_frag.load()).to(cutlass.Float32)\n\n    h_frag = cute.make_fragment(1, cutlass.Int32)\n    h_frag.store(h_idx)\n    hs_frag = cute.make_fragment(1, dtype)\n    hs_frag[0] = head_scale[h_frag[0]]\n    hs_val = (hs_frag.load()).to(cutlass.Float32)\n\n    qg_frag = cute.make_fragment(1, cutlass.Int32)\n    qg_frag.store(q_idx_global)\n    qpb_frag = cute.make_fragment(1, dtype)\n    qpb_frag[0] = q_pos_bias[qg_frag[0]]\n    qpb_val = (qpb_frag.load()).to(cutlass.Float32)\n\n    kvg_frag = cute.make_fragment(1, cutlass.Int32)\n    kvg_frag.store(kv_idx_global)\n    kvpb_frag = cute.make_fragment(1, dtype)\n    kvpb_frag[0] = kv_pos_bias[kvg_frag[0]]\n    kvpb_val = (kvpb_frag.load()).to(cutlass.Float32)\n\n    rel_idx = q_idx - kv_idx + cute.full_like(q_idx, 512)\n    rel_idx_clamped = cute.where(\n        operator.lt(rel_idx, cute.full_like(rel_idx, 0)), cute.full_like(rel_idx, 0), rel_idx\n    )\n    rel_idx_clamped = cute.where(\n        operator.gt(rel_idx_clamped, cute.full_like(rel_idx_clamped, 1024)),\n        cute.full_like(rel_idx_clamped, 1024),\n        rel_idx_clamped,\n    )\n    ri_frag = cute.make_fragment(1, cutlass.Int32)\n    ri_frag.store(rel_idx_clamped)\n    rps_frag = cute.make_fragment(1, dtype)\n    rps_frag[0] = rel_pos_scale[ri_frag[0]]\n    rps_val = (rps_frag.load()).to(cutlass.Float32)\n\n    return tSrS_ssa * hs_val + bb_val + qpb_val + kvpb_val + rps_val * cute.full_like(tSrS_ssa, 0.1)\n\n\n@cute.jit\ndef score_mod_stress_global_offset(\n    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors\n):\n    \"\"\"Verify global - logical = offset.\"\"\"\n    offset_k = seqlen_info.offset_k\n    
kv_idx_global = kv_idx + offset_k\n    token_bias = aux_tensors[0]\n    dtype = token_bias.element_type\n\n    kv_frag = cute.make_fragment(1, cutlass.Int32)\n    kv_frag.store(kv_idx_global)\n    bias_frag = cute.make_fragment(1, dtype)\n    bias_frag[0] = token_bias[kv_frag[0]]\n\n    return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32)\n\n\n@cute.jit\ndef score_mod_stress_xor_pattern(\n    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors\n):\n    \"\"\"XOR-based pattern using index bits.\"\"\"\n    offset_k = seqlen_info.offset_k\n    kv_idx_global = kv_idx + offset_k\n    token_bias = aux_tensors[0]\n    dtype = token_bias.element_type\n\n    xor_logical = q_idx ^ kv_idx\n    pattern_logical = xor_logical & cute.full_like(xor_logical, 0xFF)\n    pattern_bias = pattern_logical.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001)\n\n    kv_frag = cute.make_fragment(1, cutlass.Int32)\n    kv_frag.store(kv_idx_global)\n    bias_frag = cute.make_fragment(1, dtype)\n    bias_frag[0] = token_bias[kv_frag[0]]\n\n    return (\n        tSrS_ssa\n        + pattern_bias\n        + (bias_frag.load()).to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.1)\n    )\n\n\n@cute.jit\ndef score_mod_debug_global_idx(\n    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors\n):\n    # Don't read from aux_tensors at all - just add the global index as bias\n    offset_k = seqlen_info.offset_k\n    kv_idx_global = kv_idx + offset_k\n    bias = kv_idx_global.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001)\n    return tSrS_ssa + bias\n\n\n# =============================================================================\n# Eager reference functions\n# =============================================================================\n\n\ndef identity_eager(score, b, h, q_idx, kv_idx):\n    return score\n\n\ndef causal_eager(score, b, h, q_idx, kv_idx):\n    return torch.where(q_idx >= kv_idx, score, float(\"-inf\"))\n\n\ndef rel_bias_eager(score, b, h, q_idx, 
kv_idx):\n    return score + torch.abs(q_idx - kv_idx)\n\n\ndef rel_bias_x2_eager(score, b, h, q_idx, kv_idx):\n    return score + 2 * torch.abs(q_idx - kv_idx)\n\n\ndef times_two_eager(score, b, h, q_idx, kv_idx):\n    return score * 2\n\n\ndef alibi_eager(score, b, h, q_idx, kv_idx):\n    slope = 2 ** (-8 * (h + 1) / 8)\n    return score - slope * torch.abs(q_idx - kv_idx)\n\n\ndef sliding_window_eager(score, b, h, q_idx, kv_idx):\n    return torch.where(torch.abs(q_idx - kv_idx) <= 256, score, float(\"-inf\"))\n\n\ndef block_diagonal_eager(score, b, h, q_idx, kv_idx):\n    return torch.where(q_idx // 64 == kv_idx // 64, score, float(\"-inf\"))\n\n\ndef causal_v2_eager(score, b, h, q_idx, kv_idx):\n    return torch.where(q_idx - kv_idx >= 0, score, float(\"-inf\"))\n\n\ndef batch_bias_factory(bias_tensor):\n    def mod(score, b, h, q_idx, kv_idx):\n        return score + bias_tensor[b]\n\n    return mod\n\n\ndef dual_buffer_factory(head_bias, pos_bias):\n    def mod(score, b, h, q_idx, kv_idx):\n        return score + head_bias[h] + pos_bias[q_idx]\n\n    return mod\n\n\ndef packed_kv_bias_factory(bias_tensor, cu_seqlens_k):\n    def mod(score, b, h, q_idx, kv_idx):\n        # Calculate valid length for this sequence\n        start = cu_seqlens_k[b]\n        seq_len = cu_seqlens_k[b+1] - start\n\n        # Clamp kv_idx.\n        safe_kv_idx = torch.clamp(kv_idx, max=seq_len - 1)\n\n        return score + bias_tensor[start + safe_kv_idx]\n    return mod\n\n\ndef packed_q_bias_factory(bias_tensor, cu_seqlens_q):\n    def mod(score, b, h, q_idx, kv_idx):\n        start = cu_seqlens_q[b]\n        seq_len = cu_seqlens_q[b+1] - start\n\n        # Clamp q_idx\n        safe_q_idx = torch.clamp(q_idx, max=seq_len - 1)\n\n        return score + bias_tensor[start + safe_q_idx]\n    return mod\n\n\ndef packed_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k):\n    def mod(score, b, h, q_idx, kv_idx):\n        start = cu_seqlens_k[b]\n        seq_len = cu_seqlens_k[b+1] - 
start\n\n        # Clamp kv_idx\n        safe_kv_idx = torch.clamp(kv_idx, max=seq_len - 1)\n\n        rel_bias = torch.abs(q_idx - kv_idx).float() * 0.1\n        return score + rel_bias + bias_tensor[start + safe_kv_idx]\n\n    return mod\n\n\ndef packed_q_and_kv_bias_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k):\n    def mod(score, b, h, q_idx, kv_idx):\n        # Handle Q bounds\n        q_start = cu_seqlens_q[b]\n        q_len = cu_seqlens_q[b+1] - q_start\n        safe_q_idx = torch.clamp(q_idx, max=q_len - 1)\n\n        # Handle KV bounds\n        kv_start = cu_seqlens_k[b]\n        kv_len = cu_seqlens_k[b+1] - kv_start\n        safe_kv_idx = torch.clamp(kv_idx, max=kv_len - 1)\n\n        return score + q_bias[q_start + safe_q_idx] + kv_bias[kv_start + safe_kv_idx]\n\n    return mod\n\n\ndef packed_logical_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k):\n    def mod(score, b, h, q_idx, kv_idx):\n        rel_bias = torch.abs(q_idx - kv_idx).float() * 0.01\n        return score + rel_bias + bias_tensor[cu_seqlens_k[b] + kv_idx]\n\n    return mod\n\n\ndef stress_complex_arithmetic_factory(bias, cu_seqlens_q):\n    def mod(score, b, h, q_idx, kv_idx):\n        # Use absolute value instead of squaring to avoid overflow with large sequences\n        rel_pos_abs = torch.abs(q_idx - kv_idx)\n        q_global = cu_seqlens_q[b] + q_idx\n        bias_q = bias[q_global]\n        scale = (b + 1) * (h + 1) * 0.001\n        rel_bias = rel_pos_abs * 0.001\n        return score + rel_bias + bias_q * scale\n\n    return mod\n\n\ndef stress_conditional_mask_factory(token_bias, cu_seqlens_q, cu_seqlens_k):\n    def mod(score, b, h, q_idx, kv_idx):\n        kv_global = cu_seqlens_k[b] + kv_idx\n        bias_val = token_bias[kv_global]\n        is_causal = q_idx >= kv_idx\n        q_global = cu_seqlens_q[b] + q_idx\n        global_diff = q_global - kv_global\n        is_nearby = torch.abs(global_diff) <= 512\n        both_conditions = is_causal & is_nearby\n        
return torch.where(both_conditions, score + bias_val, float(\"-inf\"))\n\n    return mod\n\n\ndef stress_multi_buffer_factory(\n    batch_bias,\n    head_scale,\n    q_pos_bias,\n    kv_pos_bias,\n    rel_pos_scale,\n    cu_seqlens_q,\n    cu_seqlens_k,\n    max_rel_pos=512,\n):\n    def mod(score, b, h, q_idx, kv_idx):\n        bb_val = batch_bias[b]\n        hs_val = head_scale[h]\n        qpb_val = q_pos_bias[cu_seqlens_q[b] + q_idx]\n        kvpb_val = kv_pos_bias[cu_seqlens_k[b] + kv_idx]\n        rel_idx = (q_idx - kv_idx + max_rel_pos).clamp(0, max_rel_pos * 2)\n        rps_val = rel_pos_scale[rel_idx]\n        return score * hs_val + bb_val + qpb_val + kvpb_val + rps_val * 0.1\n\n    return mod\n\n\ndef stress_global_offset_factory(token_bias, cu_seqlens_k):\n    def mod(score, b, h, q_idx, kv_idx):\n        return score + token_bias[cu_seqlens_k[b] + kv_idx]\n\n    return mod\n\n\ndef stress_xor_pattern_factory(token_bias, cu_seqlens_q, cu_seqlens_k):\n    def mod(score, b, h, q_idx, kv_idx):\n        xor_logical = q_idx ^ kv_idx\n        pattern_bias = (xor_logical & 0xFF).float() * 0.001\n        kv_global = cu_seqlens_k[b] + kv_idx\n        return score + pattern_bias + token_bias[kv_global] * 0.1\n\n    return mod\n\n\ndef debug_global_idx_factory(bias, cu_seqlens_k):\n    # `bias` is intentionally unused: this mirrors score_mod_debug_global_idx,\n    # which never reads aux_tensors and only adds the global kv index as a bias.\n    offsets = cu_seqlens_k.tolist()\n\n    def mod(score, b, h, q_idx, kv_idx):\n        global_kv = offsets[b] + kv_idx\n        return score + global_kv.float() * 0.001\n\n    return mod\n"
  },
  {
    "path": "tests/cute/test_block_sparsity.py",
    "content": "\"\"\"Tests for block sparsity computation in flash attention.\"\"\"\n\nimport pytest\nimport torch\nfrom torch.nn.attention.flex_attention import create_block_mask\n\nfrom mask_mod_definitions import get_mask_pair\nfrom flash_attn.cute.compute_block_sparsity import compute_block_sparsity\n\n\ndef _call_compute_block_sparsity(\n    batch_size,\n    nheads,\n    seqlen_q,\n    seqlen_k,\n    tile_m,\n    tile_n,\n    mask_name,\n    window_size=None,\n    aux_tensors=None,\n    use_fast_sampling=False,\n):\n    \"\"\"Call compute_block_sparsity and return torch tensors.\"\"\"\n    cute_mask, _ = get_mask_pair(\n        mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size\n    )\n    _, torch_tensors = compute_block_sparsity(\n        tile_m=tile_m,\n        tile_n=tile_n,\n        batch_size=batch_size,\n        num_heads=nheads,\n        seqlen_q=seqlen_q,\n        seqlen_k=seqlen_k,\n        mask_mod=cute_mask,\n        aux_tensors=aux_tensors,\n        device=\"cuda\",\n        use_fast_sampling=use_fast_sampling,\n    )\n    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = torch_tensors\n    return mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx\n\n\ndef _compare_block_sparsity(\n    mask_block_cnt,\n    mask_block_idx,\n    full_block_cnt,\n    full_block_idx,\n    mask_block_cnt_ref,\n    mask_block_idx_ref,\n    full_block_cnt_ref,\n    full_block_idx_ref,\n    batch_size,\n    nheads,\n    seqlen_q,\n    seqlen_k,\n    tile_m,\n    tile_n,\n):\n    \"\"\"Compare block sparsity against reference, handling boundary block semantics.\n\n    PyTorch treats OOB regions as masked, so boundary blocks with all in-bounds\n    elements unmasked appear as \"partial\" in PyTorch but \"full\" in CuTe.\n\n    This applies to BOTH boundary m_blocks (OOB q_idx) and boundary n_blocks (OOB kv_idx).\n    \"\"\"\n    if not isinstance(mask_block_cnt, torch.Tensor):\n        return False, f\"mask_block_cnt is 
not a tensor: {type(mask_block_cnt)}\"\n\n    n_blocks_q = mask_block_cnt.shape[2]\n\n    # Identify boundary blocks\n    last_m_block = (seqlen_q - 1) // tile_m\n    last_n_block = (seqlen_k - 1) // tile_n\n    m_is_boundary = seqlen_q % tile_m != 0\n    n_is_boundary = seqlen_k % tile_n != 0\n\n    def is_boundary_n_block(n_block):\n        return n_is_boundary and n_block == last_n_block\n\n    def is_boundary_m_block(m_block):\n        return m_is_boundary and m_block == last_m_block\n\n    for b in range(batch_size):\n        for h in range(nheads):\n            for m in range(n_blocks_q):\n                cute_mask_cnt = mask_block_cnt[b, h, m].item()\n                cute_full_cnt = full_block_cnt[b, h, m].item()\n                ref_mask_cnt = mask_block_cnt_ref[b, h, m].item()\n                ref_full_cnt = full_block_cnt_ref[b, h, m].item()\n\n                cute_mask_set = set(mask_block_idx[b, h, m, :cute_mask_cnt].tolist())\n                cute_full_set = set(full_block_idx[b, h, m, :cute_full_cnt].tolist())\n                ref_mask_set = set(mask_block_idx_ref[b, h, m, :ref_mask_cnt].tolist())\n                ref_full_set = set(full_block_idx_ref[b, h, m, :ref_full_cnt].tolist())\n\n                # A block is \"boundary-affected\" if EITHER the m_block OR n_block is at boundary\n                def is_boundary_affected(n_block):\n                    return is_boundary_m_block(m) or is_boundary_n_block(n_block)\n\n                # Blocks that are full in CuTe but not in ref\n                full_in_cute_not_ref = cute_full_set - ref_full_set\n\n                for n_block in full_in_cute_not_ref:\n                    if not is_boundary_affected(n_block):\n                        return False, (\n                            f\"Non-boundary block mismatch at [{b},{h},{m}]: \"\n                            f\"n_block {n_block} is full in CuTe but not in ref\"\n                        )\n                    # Boundary-affected: CuTe says full, ref 
should say partial\n                    if n_block not in ref_mask_set:\n                        # Check if ref skipped it entirely (all masked)\n                        # This is valid for boundary blocks\n                        pass\n\n                # Blocks that are partial in CuTe but full in ref (would be a bug)\n                partial_in_cute_full_in_ref = cute_mask_set & ref_full_set\n                if partial_in_cute_full_in_ref:\n                    return False, (\n                        f\"Block mismatch at [{b},{h},{m}]: \"\n                        f\"n_blocks {sorted(partial_in_cute_full_in_ref)} are partial in CuTe but full in ref\"\n                    )\n\n                # Check non-boundary blocks match exactly\n                non_boundary_cute_full = {\n                    n for n in cute_full_set if not is_boundary_affected(n)\n                }\n                non_boundary_ref_full = {\n                    n for n in ref_full_set if not is_boundary_affected(n)\n                }\n                if non_boundary_cute_full != non_boundary_ref_full:\n                    return False, (\n                        f\"Non-boundary full block mismatch at [{b},{h},{m}]: \"\n                        f\"CuTe={sorted(non_boundary_cute_full)}, ref={sorted(non_boundary_ref_full)}\"\n                    )\n\n                non_boundary_cute_mask = {\n                    n for n in cute_mask_set if not is_boundary_affected(n)\n                }\n                non_boundary_ref_mask = {\n                    n for n in ref_mask_set if not is_boundary_affected(n)\n                }\n                if non_boundary_cute_mask != non_boundary_ref_mask:\n                    return False, (\n                        f\"Non-boundary partial block mismatch at [{b},{h},{m}]: \"\n                        f\"CuTe={sorted(non_boundary_cute_mask)}, ref={sorted(non_boundary_ref_mask)}\"\n                    )\n\n    return True, \"\"\n\n\n# Test 
configurations\nSEQLEN_PAIRS = [\n    # Small aligned\n    (64, 64),\n    (128, 128),\n    (256, 256),\n    (512, 512),\n    # Rectangular\n    (128, 256),\n    (256, 128),\n    (512, 256),\n    (256, 512),\n    # Large aligned\n    (1024, 1024),\n    (2048, 2048),\n    (4096, 4096),\n    (8192, 8192),\n    # Large unaligned\n    (1000, 1000),\n    (2000, 2000),\n    (4000, 4000),\n    # Edge cases with unaligned seqlens\n    (113, 203),\n    (127, 127),\n    (129, 129),\n    (255, 255),\n    (257, 257),\n    (1023, 1023),\n    (1025, 1025),\n    (2047, 2047),\n    (2049, 2049),\n]\nTILE_SIZES = [\n    # Standard powers of 2\n    (32, 32),\n    (64, 64),\n    (128, 128),\n    (256, 256),\n    # Rectangular\n    (32, 64),\n    (64, 32),\n    (64, 128),\n    (128, 64),\n    (128, 256),\n    (256, 128),\n    # Unusual sizes\n    (40, 40),\n    (48, 48),\n    (96, 96),\n    (112, 112),\n    (32, 128),\n    (128, 32),\n    (40, 96),\n    (96, 40),\n]\n\n\n@pytest.mark.parametrize(\"seqlen_q,seqlen_k\", SEQLEN_PAIRS)\n@pytest.mark.parametrize(\"tile_m,tile_n\", TILE_SIZES)\n@pytest.mark.parametrize(\"batch_size\", [1, 2])\n@pytest.mark.parametrize(\"nheads\", [1, 4])\n@pytest.mark.parametrize(\"mask_name\", [\"block_diagonal\", \"mini_causal\"])\ndef test_fixed_length_masks(\n    seqlen_q, seqlen_k, tile_m, tile_n, batch_size, nheads, mask_name\n):\n    \"\"\"Test fixed-length masks.\"\"\"\n    seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0)\n\n    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = (\n        _call_compute_block_sparsity(\n            batch_size,\n            nheads,\n            seqlen_q,\n            seqlen_k,\n            tile_m,\n            tile_n,\n            mask_name,\n            use_fast_sampling=False,\n        )\n    )\n\n    _, mask_mod_flex = get_mask_pair(mask_name)\n    block_mask = create_block_mask(\n        mask_mod_flex,\n        B=batch_size,\n        H=nheads,\n        Q_LEN=seqlen_q,\n        
KV_LEN=seqlen_k,\n        device=\"cuda\",\n        BLOCK_SIZE=(tile_m, tile_n),\n    )\n    (\n        _,\n        _,\n        mask_block_cnt_ref,\n        mask_block_idx_ref,\n        full_block_cnt_ref,\n        full_block_idx_ref,\n        *_,\n    ) = block_mask.as_tuple()\n\n    print(\"CuTe results:\")\n    print(f\"    mask_block_cnt: {mask_block_cnt}\")\n    print(f\"    full_block_cnt: {full_block_cnt}\")\n    print(f\"    mask_block_idx: {mask_block_idx}\")\n    print(f\"    full_block_idx: {full_block_idx}\")\n    print(\"Torch results:\")\n    print(f\"    mask_block_cnt: {mask_block_cnt_ref}\")\n    print(f\"    full_block_cnt: {full_block_cnt_ref}\")\n    print(f\"    mask_block_idx: {mask_block_idx_ref}\")\n    print(f\"    full_block_idx: {full_block_idx_ref}\")\n\n    all_match, error_msg = _compare_block_sparsity(\n        mask_block_cnt,\n        mask_block_idx,\n        full_block_cnt,\n        full_block_idx,\n        mask_block_cnt_ref,\n        mask_block_idx_ref,\n        full_block_cnt_ref,\n        full_block_idx_ref,\n        batch_size,\n        nheads,\n        seqlen_q,\n        seqlen_k,\n        tile_m,\n        tile_n,\n    )\n    assert all_match, f\"Mismatch: {error_msg}\"\n\n\n@pytest.mark.parametrize(\"seqlen_q,seqlen_k\", SEQLEN_PAIRS)\n@pytest.mark.parametrize(\n    \"tile_m,tile_n\", [(64, 64), (128, 128), (64, 128), (128, 64), (256, 256)]\n)\n@pytest.mark.parametrize(\"batch_size\", [1])\n@pytest.mark.parametrize(\"nheads\", [1, 4])\n@pytest.mark.parametrize(\n    \"mask_name,window_size\",\n    [(\"causal\", None), (\"sliding_window\", 64), (\"sliding_window\", 256)],\n)\ndef test_parameterized_masks(\n    seqlen_q, seqlen_k, tile_m, tile_n, batch_size, nheads, mask_name, window_size\n):\n    \"\"\"Test parameterized masks.\"\"\"\n    if mask_name == \"sliding_window\" and seqlen_q > seqlen_k:\n        pytest.skip(\"Sliding window not supported for seqlen_q > seqlen_k\")\n\n    mask_block_cnt, mask_block_idx, 
full_block_cnt, full_block_idx = (\n        _call_compute_block_sparsity(\n            batch_size,\n            nheads,\n            seqlen_q,\n            seqlen_k,\n            tile_m,\n            tile_n,\n            mask_name,\n            window_size=window_size,\n        )\n    )\n\n    _, mask_mod_flex = get_mask_pair(\n        mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size\n    )\n    block_mask = create_block_mask(\n        mask_mod_flex,\n        B=batch_size,\n        H=nheads,\n        Q_LEN=seqlen_q,\n        KV_LEN=seqlen_k,\n        device=\"cuda\",\n        BLOCK_SIZE=(tile_m, tile_n),\n    )\n    (\n        _,\n        _,\n        mask_block_cnt_ref,\n        mask_block_idx_ref,\n        full_block_cnt_ref,\n        full_block_idx_ref,\n        *_,\n    ) = block_mask.as_tuple()\n\n    all_match, error_msg = _compare_block_sparsity(\n        mask_block_cnt,\n        mask_block_idx,\n        full_block_cnt,\n        full_block_idx,\n        mask_block_cnt_ref,\n        mask_block_idx_ref,\n        full_block_cnt_ref,\n        full_block_idx_ref,\n        batch_size,\n        nheads,\n        seqlen_q,\n        seqlen_k,\n        tile_m,\n        tile_n,\n    )\n\n    assert all_match, f\"Mismatch: {error_msg}\"\n\n\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k,tile_m,tile_n\",\n    [\n        (1, 1, 64, 64),\n        (63, 63, 64, 64),\n        (65, 65, 64, 64),\n        (129, 129, 128, 128),\n        (100, 200, 64, 128),\n    ],\n)\ndef test_edge_cases(seqlen_q, seqlen_k, tile_m, tile_n):\n    \"\"\"Test edge cases with unaligned dimensions.\"\"\"\n    batch_size, nheads = 1, 1\n    seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0)\n\n    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = (\n        _call_compute_block_sparsity(\n            batch_size,\n            nheads,\n            seqlen_q,\n            seqlen_k,\n            tile_m,\n            tile_n,\n            
\"causal\",\n        )\n    )\n\n    _, mask_mod_flex = get_mask_pair(\"causal\", seqlen_q=seqlen_q, seqlen_k=seqlen_k)\n    block_mask = create_block_mask(\n        mask_mod_flex,\n        B=batch_size,\n        H=nheads,\n        Q_LEN=seqlen_q,\n        KV_LEN=seqlen_k,\n        device=\"cuda\",\n        BLOCK_SIZE=(tile_m, tile_n),\n    )\n    (\n        _,\n        _,\n        mask_block_cnt_ref,\n        mask_block_idx_ref,\n        full_block_cnt_ref,\n        full_block_idx_ref,\n        *_,\n    ) = block_mask.as_tuple()\n\n    all_match, error_msg = _compare_block_sparsity(\n        mask_block_cnt,\n        mask_block_idx,\n        full_block_cnt,\n        full_block_idx,\n        mask_block_cnt_ref,\n        mask_block_idx_ref,\n        full_block_cnt_ref,\n        full_block_idx_ref,\n        batch_size,\n        nheads,\n        seqlen_q,\n        seqlen_k,\n        tile_m,\n        tile_n,\n    )\n    assert all_match, f\"Mismatch: {error_msg}\"\n\n\n@pytest.mark.parametrize(\"seqlen_q,seqlen_k\", SEQLEN_PAIRS)\n@pytest.mark.parametrize(\n    \"tile_m,tile_n\", [(64, 64), (128, 128), (64, 128), (128, 64), (256, 256)]\n)\n@pytest.mark.parametrize(\"nheads\", [1, 4])\n@pytest.mark.parametrize(\"mask_name\", [\"causal\", \"block_diagonal\"])\ndef test_fast_sampling(seqlen_q, seqlen_k, tile_m, tile_n, nheads, mask_name):\n    \"\"\"Test fast sampling mode (5-point sampling).\"\"\"\n    batch_size = 1\n    seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0)\n\n    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = (\n        _call_compute_block_sparsity(\n            batch_size,\n            nheads,\n            seqlen_q,\n            seqlen_k,\n            tile_m,\n            tile_n,\n            mask_name,\n            use_fast_sampling=True,\n        )\n    )\n\n    _, mask_mod_flex = get_mask_pair(mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k)\n    block_mask = create_block_mask(\n        mask_mod_flex,\n        
B=batch_size,\n        H=nheads,\n        Q_LEN=seqlen_q,\n        KV_LEN=seqlen_k,\n        device=\"cuda\",\n        BLOCK_SIZE=(tile_m, tile_n),\n    )\n    (\n        _,\n        _,\n        mask_block_cnt_ref,\n        mask_block_idx_ref,\n        full_block_cnt_ref,\n        full_block_idx_ref,\n        *_,\n    ) = block_mask.as_tuple()\n\n    all_match, error_msg = _compare_block_sparsity(\n        mask_block_cnt,\n        mask_block_idx,\n        full_block_cnt,\n        full_block_idx,\n        mask_block_cnt_ref,\n        mask_block_idx_ref,\n        full_block_cnt_ref,\n        full_block_idx_ref,\n        batch_size,\n        nheads,\n        seqlen_q,\n        seqlen_k,\n        tile_m,\n        tile_n,\n    )\n\n    assert all_match, f\"Mismatch: {error_msg}\"\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\"])\n"
  },
  {
    "path": "tests/cute/test_flash_attn.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n\nimport math\nimport itertools\nimport os\nimport random\nimport re\n\nimport pytest\nimport torch\n\nfrom einops import rearrange, repeat\n\ntry:\n    from flash_attn.layers.rotary import apply_rotary_emb\nexcept ImportError:\n    apply_rotary_emb = None\n\nfrom flash_attn.cute.testing import (\n    attention_ref,\n    generate_qkv,\n    generate_random_padding_mask,\n    pad_input,\n    unpad_input,\n    maybe_fake_tensor_mode,\n    is_fake_mode,\n)\nfrom flash_attn.cute.interface import (\n    flash_attn_func,\n    flash_attn_varlen_func,\n)\n\n# torch's FakeTensorMode enables fast CuTe DSL kernel compilation without allocating actual GPU memory or running the kernel.\n# When operating on fake tensors, we cannot perform data-dependent operations (e.g., `tensor.max()`).\nUSE_FAKE_TENSOR = int(os.getenv(\"FLASH_ATTENTION_FAKE_TENSOR\", 0)) == 1\nDISABLE_SPLIT = os.getenv(\"FLASH_ATTENTION_DISABLE_SPLIT\", \"FALSE\") == \"TRUE\"\n# SplitKV is not supported on SM90\nIS_SM90 = torch.cuda.get_device_capability()[0] == 9\nIS_SM100 = torch.cuda.get_device_capability()[0] == 10\nTEST_BWD_ONLY = False\nVERBOSE = True\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n@pytest.mark.parametrize(\"has_learnable_sink\", [False, True])\n# @pytest.mark.parametrize(\"has_learnable_sink\", [False])\n# @pytest.mark.parametrize(\"has_qv\", [False, True])\n@pytest.mark.parametrize(\"has_qv\", [False])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n# @pytest.mark.parametrize(\"deterministic\", [False])\n# @pytest.mark.parametrize(\"softcap\", [0.0, 15.0])\n@pytest.mark.parametrize(\"softcap\", 
[0.0])\n@pytest.mark.parametrize(\"local_enum\", [0, 1, 2, 3])\n# @pytest.mark.parametrize(\"local_enum\", [0])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [False])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64, 128, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [64, 96, 128, 192])\n# @pytest.mark.parametrize(\"d\", [128, 192])\n@pytest.mark.parametrize(\"d\", [64, 96, 128, 192, 256])\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 1),\n        (3, 3),\n        (64, 32),\n        (64, 128),\n        (128, 128),\n        (128, 192),\n        (256, 256),\n        (239, 1),\n        (799, 3),\n        (113, 203),\n        (113, 128),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (384, 256),\n        (640, 128),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (2048, 2048),\n        (4096, 4096),\n        (4224, 4224),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_output(\n    seqlen_q,\n    seqlen_k,\n    d,\n    causal,\n    local_enum,\n    softcap,\n    deterministic,\n    has_qv,\n    has_learnable_sink,\n    mha_type,\n    dtype,\n):\n    local = local_enum > 0\n    if local and causal:\n        pytest.skip()\n    device = \"cuda\"\n    # set seed\n    seed = 0\n    random.seed(seed)\n    torch.random.manual_seed(seed)\n    torch.cuda.empty_cache()\n    torch.cuda.synchronize()\n    batch_size = 9 if seqlen_k <= 2048 else 2\n    # batch_size = 2\n    nheads = 6\n    
# nheads = 1\n    nheads_kv = nheads if mha_type == \"mha\" else (3 if mha_type == \"gqa\" else 1)\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])\n    if dtype == torch.float8_e4m3fn:\n        dv_vals = [d]\n    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0]\n    attention_chunk_vals = [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        q_ref = torch.randn(\n            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref\n        )\n        if softcap > 0.0:\n            # Ensure the values of qk are at least within softcap range.\n            q_ref = q_ref * softcap / 4\n        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()\n        k_ref = (\n            torch.randn(\n                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref\n            )\n            .to(dtype)\n            .to(dtype_ref)\n            .requires_grad_()\n        )\n        v_ref = (\n            torch.randn(\n                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref\n            )\n            .to(dtype)\n            .to(dtype_ref)\n            .requires_grad_()\n        )\n        if has_qv:\n            qv_ref = (\n                torch.randn(\n                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n        else:\n            qv_ref = None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (\n            (None, None) if not local else tuple(random.randrange(0, seqlen_k) for _ in range(2))\n        )\n        if local_enum == 2:\n            window_size = (None, -window_size[1])\n        
elif local_enum == 3:\n            window_size = (-window_size[0], None)\n        if local:\n            print(\"window size = \", window_size)\n        # window_size = (-1, -1) if not local else (16, 0)\n        if has_learnable_sink:\n            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)\n        else:\n            learnable_sink = None\n        if dtype == torch.float8_e4m3fn:\n            q_descale, k_descale, v_descale = [\n                torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)\n                * 2\n                for _ in range(3)\n            ]\n        else:\n            q_descale, k_descale, v_descale = None, None, None\n        q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]\n        qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None\n        out_ref, attn_ref = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            None,\n            None,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale,\n            k_descale=k_descale,\n            v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            learnable_sink=learnable_sink,\n            softcap=softcap,\n        )\n        out_pt, attn_pt = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            None,\n            None,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale,\n            k_descale=k_descale,\n            v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            learnable_sink=learnable_sink,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n        )\n\n        # k_extended = repeat(k_ref, \"b s h d -> b s (h k) 
d\", k=nheads // nheads_kv)\n        # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float()\n        # # if qv is not None:\n        # #     qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float()\n        # m = qk.amax(-1, keepdim=True)\n        # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n        # exp_sum = s_tmp.sum(-1)\n        # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float())\n        # # lse_ref = torch.logsumexp(qk, dim=-1)\n\n        # Numerical error if we just do any arithmetic on out_ref\n        if not is_fake_mode():\n            fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n            rtol = 2 if softcap == 0.0 else 3\n\n            print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n            print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n        # num_splits_vals = [1, 3]\n        pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False]\n        # SplitKV is not supported for hdim >= 192\n        # pack_gqa_vals = [False]\n        num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]\n        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):\n            # SplitKV not supported on SM90 - skip this iteration\n            if IS_SM90 and num_splits > 1:\n                continue\n            if IS_SM100 and (d >= 192 and dv >= 192):  # hdim 192 and 256 not support on SM100\n                continue\n            out, lse = flash_attn_func(\n                q,\n                k,\n                v,\n                causal=causal,\n                # qv=qv,\n                # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n                window_size=window_size,\n                # attention_chunk=attention_chunk,\n                softcap=softcap,\n                learnable_sink=learnable_sink,\n                pack_gqa=pack_gqa,\n             
   num_splits=num_splits,\n                deterministic=deterministic,\n            )\n            if is_fake_mode():\n                # no more flash_attn cutedsl calls for the rest of the loop\n                # skip data-dependent postprocessing\n                continue\n            print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n            print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n            # if not causal:\n            #     print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n            # breakpoint()\n\n            # Check that FlashAttention's numerical error is at most twice the numerical error\n            # of a Pytorch implementation.\n            assert (out - out_ref).abs().max().item() <= rtol * (\n                out_pt - out_ref\n            ).abs().max().item() + fwd_atol\n\n        if (\n            dtype != torch.float8_e4m3fn\n            and not has_qv\n            and not dv > 256\n            and not attention_chunk != 0\n            and softcap == 0.0\n            and ((dv == d and d <= 128) or (d == 192 and dv == 128))\n            and learnable_sink is None\n            # and False\n            and not ((causal or local) and seqlen_k < seqlen_q)\n        ):\n            if d > 192 and IS_SM90:\n                pytest.xfail(\"hdim > 192 backward: SM90 not supported yet\")\n            if d != dv and mha_type != \"mha\" and IS_SM90:\n                pytest.xfail(\"SM90 GQA bwd currently requires headdim == headdim_v\")\n            g = torch.randn_like(out)\n            # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)\n            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            if is_fake_mode():\n                # no more flash_attn cutedsl calls for the rest of the loop\n                # skip data-dependent postprocessing\n                continue\n            # print(f\"dO_O max diff: {(softmax_d - do_o).abs().max().item()}\")\n            # 
assert (softmax_d - do_o).abs().max().item() <= 1e-5\n            # assert dq_accum.abs().max().item() == 0.0\n\n            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())\n            # P = torch.softmax(qk, -1)\n            # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1))\n            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())\n            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())\n            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())\n            # breakpoint()\n\n            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            dq_ref, dk_ref, dv_ref = torch.autograd.grad(\n                out_ref, (q_ref, k_ref, v_ref), g\n            )\n            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)\n            print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n            print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n            print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n            print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n            print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n            print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n            print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n            print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n            print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n            print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n            print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n            print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n            if VERBOSE:\n                diff_dq = (dq - dq_ref).abs()\n                max_idx = diff_dq.argmax()\n                coords = torch.unravel_index(max_idx, diff_dq.shape)\n                print(f\"dQ max diff: 
{diff_dq.max().item()}\")\n                print(f\"  at coordinates {tuple(c.item() for c in coords)}: dQ={dq[coords].item()}, dQ_ref={dq_ref[coords].item()}\")\n\n                diff_dk = (dk - dk_ref).abs()\n                max_idx = diff_dk.argmax()\n                coords = torch.unravel_index(max_idx, diff_dk.shape)\n                print(f\"dK max diff: {diff_dk.max().item()}\")\n                print(f\"  at coordinates {tuple(c.item() for c in coords)}: dK={dk[coords].item()}, dK_ref={dk_ref[coords].item()}\")\n\n                diff_dv = (dv - dv_ref).abs()\n                max_idx = diff_dv.argmax()\n                coords = torch.unravel_index(max_idx, diff_dv.shape)\n                print(f\"dV max diff: {diff_dv.max().item()}\")\n                print(f\"  at coordinates {tuple(c.item() for c in coords)}: dV={dv[coords].item()}, dV_ref={dv_ref[coords].item()}\")\n\n            # breakpoint()\n            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dq - dq_ref).abs().max().item() <= rtol * (\n                dq_pt - dq_ref\n            ).abs().max().item() + dq_atol\n            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dk - dk_ref).abs().max().item() <= rtol * (\n                dk_pt - dk_ref\n            ).abs().max().item() + dk_atol\n            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dv - dv_ref).abs().max().item() <= rtol * (\n                dv_pt - dv_ref\n            ).abs().max().item() + dv_atol\n\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# 
@pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n# @pytest.mark.parametrize(\"has_learnable_sink\", [False, True])\n@pytest.mark.parametrize(\"has_learnable_sink\", [False])\n# @pytest.mark.parametrize(\"has_qv\", [False, True])\n@pytest.mark.parametrize(\"has_qv\", [False])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n# @pytest.mark.parametrize(\"deterministic\", [False])\n# @pytest.mark.parametrize(\"softcap\", [0.0, 15.0])\n@pytest.mark.parametrize(\"softcap\", [0.0])\n@pytest.mark.parametrize(\"local_enum\", [0, 1, 2, 3])\n# @pytest.mark.parametrize(\"local_enum\", [0])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [False])\n# @pytest.mark.parametrize(\"add_unused_qkv\", [False, True])\n@pytest.mark.parametrize(\"add_unused_qkv\", [False])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [64, 96, 128])\n# @pytest.mark.parametrize(\"d\", [128, 192])\n@pytest.mark.parametrize(\"d\", [64, 128, 192])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        # (1, 1),\n        # (1, 3),\n        # (2, 1),\n        (511, 1),\n        (3, 513),\n        (64, 128),\n        (128, 128),\n        (256, 256),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (307, 256),\n        (640, 128),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (2048, 2048),\n    ],\n)\n@pytest.mark.parametrize(\"varlen_mode\", [\"random\", \"third\", \"full\"])\n# @pytest.mark.parametrize(\"varlen_mode\", [\"full\"])\n@pytest.mark.parametrize(\n    \"zero_lengths_q, zero_lengths_k\",\n    [\n        (False, 
False),\n        (True, False),\n        (False, True),\n        (True, True),\n    ],\n)\n@pytest.mark.parametrize(\n    \"unpad_q, unpad_kv\",\n    [\n        (True, True),\n        (False, False),\n        (True, False),\n        (False, True),\n    ],\n)\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_varlen_output(\n    seqlen_q,\n    seqlen_k,\n    d,\n    add_unused_qkv,\n    causal,\n    local_enum,\n    softcap,\n    deterministic,\n    has_qv,\n    has_learnable_sink,\n    mha_type,\n    dtype,\n    varlen_mode,\n    zero_lengths_q,\n    zero_lengths_k,\n    unpad_q,\n    unpad_kv,\n):\n    local = local_enum > 0\n    if local and causal:\n        pytest.skip()\n    if (\n        causal or local\n    ):  # Right now reference only supports causal attention with seqlen_k == seqlen_q\n        seqlen_k = seqlen_q\n    device = \"cuda\"\n    # set seed\n    seed = seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)\n    random.seed(seed)\n    torch.random.manual_seed(seed)\n    batch_size = 49 if seqlen_q <= 512 else 7\n    nheads = 6\n    # nheads = 1\n    nheads_kv = nheads if mha_type == \"mha\" else (3 if mha_type == \"gqa\" else 1)\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])\n    if dtype == torch.float8_e4m3fn:\n        dv_vals = [d]\n    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0]\n    attention_chunk_vals = [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        q_ref = torch.randn(\n            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref\n        )\n        if softcap > 0.0:\n            # Ensure the values of qk are at least within softcap range.\n            q_ref = (q_ref * softcap / 4).detach().requires_grad_()\n        
q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()\n        k_ref = (\n            torch.randn(\n                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref\n            )\n            .to(dtype)\n            .to(dtype_ref)\n            .requires_grad_()\n        )\n        v_ref = (\n            torch.randn(\n                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref\n            )\n            .to(dtype)\n            .to(dtype_ref)\n            .requires_grad_()\n        )\n        if has_qv:\n            qv_ref = (\n                torch.randn(\n                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n        else:\n            qv_ref = None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (\n            (None, None) if not local else tuple(random.randrange(0, seqlen_k) for _ in range(2))\n        )\n        if local_enum == 2:\n            window_size = (None, window_size[1])\n        elif local_enum == 3:\n            window_size = (window_size[0], None)\n        if local:\n            print(\"window size = \", window_size)\n        if has_learnable_sink:\n            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)\n        else:\n            learnable_sink = None\n        if dtype == torch.float8_e4m3fn:\n            q_descale, k_descale, v_descale = [\n                torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)\n                * 2\n                for _ in range(3)\n            ]\n        else:\n            q_descale, k_descale, v_descale = None, None, None\n        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]\n        qv = qv_ref.detach() if has_qv else None\n        query_padding_mask = generate_random_padding_mask(\n            seqlen_q,\n      
      batch_size,\n            device,\n            mode=varlen_mode,\n            zero_lengths=zero_lengths_q,\n        )\n        key_padding_mask = generate_random_padding_mask(\n            seqlen_k,\n            batch_size,\n            device,\n            mode=varlen_mode,\n            zero_lengths=zero_lengths_k,\n        )\n        def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):\n            if add_unused:\n                another_mask = generate_random_padding_mask(max_seq_len, bs, device)\n                attn_mask = torch.logical_and(padding_mask, another_mask)\n                unused_mask = torch.logical_xor(\n                    torch.logical_or(padding_mask, another_mask), attn_mask\n                )\n            else:\n                attn_mask = padding_mask\n                unused_mask = None\n            return attn_mask, unused_mask\n\n        query_padding_mask, query_unused_mask = _gen_unused_masks(\n            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device\n        )\n        # query_padding_mask[:] = True\n        # query_unused_mask = None\n        key_padding_mask, key_unused_mask = _gen_unused_masks(\n            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device\n        )\n\n        if causal or local:\n            key_padding_mask = query_padding_mask\n\n        (\n            q_unpad,\n            k_unpad,\n            v_unpad,\n            qv_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q,\n            k,\n            v,\n            qv,\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        ) = generate_qkv(\n            q,\n            k,\n            v,\n            query_padding_mask,\n            key_padding_mask,\n            qv=qv,\n            kvpacked=False,\n            
query_unused_mask=query_unused_mask,\n            key_unused_mask=key_unused_mask,\n        )\n        if unpad_q:\n            print(\"cu_seqlens_q = \", cu_seqlens_q)\n        else:\n            print(\"seqused_q = \", seqused_q)\n        if unpad_kv:\n            print(\"cu_seqlens_k = \", cu_seqlens_k)\n        else:\n            print(\"seqused_k = \", seqused_k)\n        q_unpad, k_unpad, v_unpad = [\n            x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)\n        ]\n\n        out_ref, attn_ref = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale,\n            k_descale=k_descale,\n            v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            learnable_sink=learnable_sink,\n            softcap=softcap,\n        )\n        out_pt, attn_pt = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale,\n            k_descale=k_descale,\n            v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            learnable_sink=learnable_sink,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n        )\n\n        if not is_fake_mode():\n            print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n            print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n            if query_unused_mask is not None:\n                q_zero_masking = rearrange(query_unused_mask, \"b s -> b s 1 1\")\n\n            # Numerical error if 
we just do any arithmetic on out_ref\n            fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n            rtol = 2 if softcap == 0.0 else 3\n\n        pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False]\n        # pack_gqa_vals = [False]\n        # num_splits_vals = [1, 3]\n        # SplitKV is not supported for hdim >= 192\n        num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]\n        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):\n            # SplitKV not supported on SM90 - skip this iteration\n            if IS_SM90 and num_splits > 1:\n                continue\n            out_unpad, lse = flash_attn_varlen_func(\n                q_unpad if unpad_q else q,\n                k_unpad if unpad_kv else k,\n                v_unpad if unpad_kv else v,\n                cu_seqlens_q=cu_seqlens_q if unpad_q else None,\n                cu_seqlens_k=cu_seqlens_k if unpad_kv else None,\n                max_seqlen_q=seqlen_q,\n                max_seqlen_k=seqlen_k,\n                seqused_q=seqused_q if not unpad_q else None,\n                seqused_k=seqused_k if not unpad_kv else None,\n                causal=causal,\n                # qv=qv_unpad,\n                # q_descale=q_descale,\n                # k_descale=k_descale, v_descale=v_descale,\n                window_size=window_size,\n                # attention_chunk=attention_chunk,\n                learnable_sink=learnable_sink,\n                softcap=softcap,\n                num_splits=num_splits,\n                pack_gqa=pack_gqa,\n                deterministic=deterministic,\n            )\n            out = output_pad_fn(out_unpad) if unpad_q else out_unpad\n            if is_fake_mode():\n                # no more flash_attn cutedsl calls for the rest of the loop\n                # skip data-dependent postprocessing\n                continue\n            if query_unused_mask is not 
None:\n                out.masked_fill_(q_zero_masking, 0.0)\n            print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n            print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n            # if not causal:\n            #     print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n            # breakpoint()\n\n            # Check that FlashAttention's numerical error is at most 3x the numerical error\n            # of a Pytorch implementation.\n            assert (out - out_ref).abs().max().item() <= rtol * (\n                out_pt - out_ref\n            ).abs().max().item() + fwd_atol\n\n        if (\n            dtype != torch.float8_e4m3fn\n            and not has_qv\n            and not dv > 256\n            and not attention_chunk != 0\n            and ((dv == d and d <= 128) or (d == 192 and dv == 128))\n            and not has_learnable_sink\n            # and False\n        ):\n            if d > 192 and IS_SM90:\n                pytest.xfail(\"hdim > 192 backward: SM90 not supported yet\")\n            if d != dv and mha_type != \"mha\" and IS_SM90:\n                pytest.xfail(\"SM90 GQA bwd currently requires headdim == headdim_v\")\n            g_unpad = torch.randn_like(out_unpad)\n            # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)\n            # import flash_attn_3_cuda\n            # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(\n            #     g_unpad,\n            #     q_unpad,\n            #     k_unpad,\n            #     v_unpad,\n            #     out_unpad,\n            #     lse,\n            #     None,\n            #     None,\n            #     None,\n            #     cu_seqlens_q,\n            #     cu_seqlens_k,\n            #     None, None,\n            #     max_seqlen_q,\n            #     max_seqlen_k,\n            #     d ** (-0.5),\n            #     causal,\n            #     window_size[0], 
window_size[1],\n            #     softcap,\n            #     deterministic,\n            #     0,  # sm_margin\n            # )\n            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(\n                out_unpad,\n                (\n                    q_unpad if unpad_q else q,\n                    k_unpad if unpad_kv else k,\n                    v_unpad if unpad_kv else v,\n                ),\n                g_unpad\n            )\n            if is_fake_mode():\n                # no more flash_attn cutedsl calls for the rest of the loop\n                # skip data-dependent postprocessing\n                continue\n            dq = dq_pad_fn(dq_unpad) if unpad_q else dq_unpad\n            dk = dk_pad_fn(dk_unpad) if unpad_kv else dk_unpad\n            dv = dk_pad_fn(dv_unpad) if unpad_kv else dv_unpad\n            if key_unused_mask is not None:\n                k_zero_masking = rearrange(key_unused_mask, \"b s -> b s 1 1\")\n                dk.masked_fill_(k_zero_masking, 0.0)\n                dv.masked_fill_(k_zero_masking, 0.0)\n            if query_unused_mask is not None:\n                dq.masked_fill_(q_zero_masking, 0.0)\n            if not unpad_kv:\n                dk.masked_fill_(rearrange(~key_padding_mask, \"b s -> b s 1 1\"), 0.0)\n                dv.masked_fill_(rearrange(~key_padding_mask, \"b s -> b s 1 1\"), 0.0)\n            if not unpad_q:\n                dq.masked_fill_(rearrange(~query_padding_mask, \"b s -> b s 1 1\"), 0.0)\n            # print(f\"dO_O max diff: {(softmax_d - do_o).abs().max().item()}\")\n            # assert (softmax_d - do_o).abs().max().item() <= 1e-5\n            # assert dq_accum.abs().max().item() == 0.0\n            g = output_pad_fn(g_unpad) if unpad_q else g_unpad\n\n            # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()\n            # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n            # dS = 
torch.einsum('bthd,bshd->bhts', g.float(), v.float())\n            # P = torch.softmax(qk, -1)\n            # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))\n            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())\n            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())\n            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())\n\n            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            dq_ref, dk_ref, dv_ref = torch.autograd.grad(\n                out_ref, (q_ref, k_ref, v_ref), g\n            )\n            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)\n            print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n            print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n            print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n            print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n            print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n            print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n            print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n            print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n            print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n            print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n            print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n            print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n            if VERBOSE:\n                diff_dq = (dq - dq_ref).abs()\n                max_idx = diff_dq.argmax()\n                coords = torch.unravel_index(max_idx, diff_dq.shape)\n                print(f\"dQ max diff: {diff_dq.max().item()}\")\n                print(f\"  at coordinates {tuple(c.item() for c in coords)}: dQ={dq[coords].item()}, 
dQ_ref={dq_ref[coords].item()}\")\n\n                diff_dk = (dk - dk_ref).abs()\n                max_idx = diff_dk.argmax()\n                coords = torch.unravel_index(max_idx, diff_dk.shape)\n                print(f\"dK max diff: {diff_dk.max().item()}\")\n                print(f\"  at coordinates {tuple(c.item() for c in coords)}: dK={dk[coords].item()}, dK_ref={dk_ref[coords].item()}\")\n\n                diff_dv = (dv - dv_ref).abs()\n                max_idx = diff_dv.argmax()\n                coords = torch.unravel_index(max_idx, diff_dv.shape)\n                print(f\"dV max diff: {diff_dv.max().item()}\")\n                print(f\"  at coordinates {tuple(c.item() for c in coords)}: dV={dv[coords].item()}, dV_ref={dv_ref[coords].item()}\")\n            # breakpoint()\n            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dq - dq_ref).abs().max().item() <= rtol * (\n                dq_pt - dq_ref\n            ).abs().max().item() + dq_atol\n            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dk - dk_ref).abs().max().item() <= rtol * (\n                dk_pt - dk_ref\n            ).abs().max().item() + dk_atol\n            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dv - dv_ref).abs().max().item() <= rtol * (\n                dv_pt - dv_ref\n            ).abs().max().item() + dv_atol\n\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", 
[\"mha\"])\n@pytest.mark.parametrize(\"has_learnable_sink\", [False, True])\n# @pytest.mark.parametrize(\"has_learnable_sink\", [False])\n# @pytest.mark.parametrize(\"new_kv\", [False, True])\n@pytest.mark.parametrize(\"new_kv\", [False])\n@pytest.mark.parametrize(\"local\", [False, True])\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n# @pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [True, False])\n@pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [False])\n# @pytest.mark.parametrize(\"has_rotary_seqlens\", [False, True])\n@pytest.mark.parametrize(\"has_rotary_seqlens\", [False])\n# @pytest.mark.parametrize(\"rotary_interleaved\", [False, True])\n@pytest.mark.parametrize(\"rotary_interleaved\", [True])\n# @pytest.mark.parametrize(\"rotary_fraction\", [0.0, 0.5, 1.0])\n@pytest.mark.parametrize(\"rotary_fraction\", [0.0])\n@pytest.mark.parametrize(\"page_size\", [None] + ([1, 4, 128]))\n# @pytest.mark.parametrize(\"page_size\", [None, 128])\n# @pytest.mark.parametrize(\"page_size\", [128])\n# @pytest.mark.parametrize(\"has_leftpad\", [False, True])\n@pytest.mark.parametrize(\"has_leftpad\", [False])\n# @pytest.mark.parametrize(\"has_batch_idx\", [False, True])\n@pytest.mark.parametrize(\"has_batch_idx\", [False])\n@pytest.mark.parametrize(\"varlen_q\", [False, True])\n# @pytest.mark.parametrize(\"varlen_q\", [False])\n# @pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 128, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\"d\", [64, 128])\n# @pytest.mark.parametrize(\"d\", [192])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 128),\n        (1, 339),\n        (3, 1024),\n        (64, 800),\n        (64, 256),\n        
(3, 799),\n        (64, 2048),\n        (16, 20000),\n        # # (1, 128 * 1024),\n        # # (16, 128 * 1024),\n        # (128, 128),\n        # (256, 512),  # To test appending KV with more than 1 block\n        # (2048, 3577),  # Enough tile to test persistent scheduler\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_kvcache(\n    seqlen_q,\n    seqlen_k,\n    d,\n    varlen_q,\n    has_batch_idx,\n    has_leftpad,\n    page_size,\n    rotary_fraction,\n    rotary_interleaved,\n    has_rotary_seqlens,\n    seqlen_new_eq_seqlen_q,\n    causal,\n    local,\n    new_kv,\n    has_learnable_sink,\n    mha_type,\n    dtype,\n):\n    if page_size is not None and seqlen_k % page_size != 0:\n        pytest.skip()\n    if seqlen_q > seqlen_k and new_kv:\n        pytest.skip()\n    if not new_kv and rotary_fraction > 0.0:\n        pytest.skip()\n    if rotary_fraction == 0.0 and has_rotary_seqlens:\n        pytest.skip()\n    device = \"cuda\"\n    # set seed\n    seed = 0\n    random.seed(seed)\n    torch.random.manual_seed(seed)\n    batch_size = 5\n    # batch_size = 1\n    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2\n    nheads = 6\n    # nheads = 1\n    # rotary_dim must be a multiple of 16, and must be <= d\n    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 3)\n    assert nheads % nheads_k == 0\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    dv_vals = [d]\n    if dtype == torch.float8_e4m3fn:\n        dv_vals = [d]\n    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) else [0]\n    attention_chunk_vals = [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n  
      # has_qv = d == 64 and dv >= 256\n        has_qv = False\n        q = (\n            torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)\n            .to(dtype)\n            .to(dtype_ref)\n        )\n        if has_qv:\n            qv = (\n                torch.randn(\n                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n        else:\n            qv = None\n        if varlen_q:\n            query_padding_mask = generate_random_padding_mask(\n                seqlen_q, batch_size, device, mode=\"random\"\n            )\n            q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(\n                q, query_padding_mask\n            )\n            output_pad_fn = lambda output_unpad: pad_input(\n                output_unpad, indices_q, batch_size, seqlen_q\n            )\n            qv_unpad = (\n                rearrange(qv, \"b s ... 
-> (b s) ...\")[indices_q] if has_qv else None\n            )\n        else:\n            query_padding_mask = None\n            q_unpad = q\n            qv_unpad = qv\n            cu_seqlens_q, max_seqlen_q = None, None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (\n            (None, None) if not local else tuple(random.randrange(0, seqlen_k) for _ in range(2))\n        )\n        if has_learnable_sink:\n            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)\n        else:\n            learnable_sink = None\n\n        seqlen_new = (\n            seqlen_q\n            if seqlen_new_eq_seqlen_q\n            else random.randrange(1, seqlen_q + 1)\n        )\n        cu_seqlens_k_new = None\n        key_new_padding_mask = None\n        if new_kv:\n            k = (\n                torch.randn(\n                    batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n            v = (\n                torch.randn(\n                    batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n            if varlen_q:  # k & v are also varlen\n                key_new_padding_mask = generate_random_padding_mask(\n                    seqlen_new, batch_size, device, mode=\"random\"\n                )\n                k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(\n                    k, key_new_padding_mask\n                )\n                v_unpad, *rest = unpad_input(v, key_new_padding_mask)\n            else:\n                k_unpad, v_unpad = k, v\n        else:\n            k, v, k_unpad, v_unpad = None, None, None, None\n        if page_size is None:\n            k_cache = (\n                torch.randn(\n                    
batch_size_cache,\n                    seqlen_k,\n                    nheads_k,\n                    d,\n                    device=device,\n                    dtype=dtype_ref,\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n            v_cache = (\n                torch.randn(\n                    batch_size_cache,\n                    seqlen_k,\n                    nheads_k,\n                    dv,\n                    device=device,\n                    dtype=dtype_ref,\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n            page_table = None\n        else:\n            (\n                k_cache,\n                v_cache,\n                page_table,\n                k_cache_paged,\n                v_cache_paged,\n                num_blocks,\n            ) = _generate_block_kvcache(\n                seqlen_k,\n                page_size,\n                batch_size_cache,\n                nheads_k,\n                d,\n                dv,\n                device,\n                dtype,\n                dtype_ref,\n            )\n        if not is_fake_mode():\n            cache_seqlens = torch.randint(\n                0 if new_kv else 1,\n                # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough\n                (\n                    (\n                        seqlen_k\n                        - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new)\n                        + 1\n                    )\n                    if new_kv\n                    else (seqlen_k + 1)\n                ),\n                (batch_size,),\n                dtype=torch.int32,\n                device=device,\n            )\n        else:\n            cache_seqlens = torch.ones(\n                batch_size,\n                dtype=torch.int32,\n                device=device,\n            )\n        if has_leftpad:\n         
   if not is_fake_mode():\n                cache_leftpad = torch.cat(\n                    [\n                        torch.randint(\n                            0,\n                            cache_seqlens[i].item(),\n                            (1,),\n                            dtype=torch.int32,\n                            device=device,\n                        )\n                        if cache_seqlens[i].item() > 0\n                        else torch.zeros(1, dtype=torch.int32, device=device)\n                        for i in range(batch_size)\n                    ]\n                )\n            else:\n                cache_leftpad = torch.zeros(batch_size, dtype=torch.int32, device=device)\n        else:\n            cache_leftpad = None\n        if has_batch_idx:\n            if not is_fake_mode():\n                cache_batch_idx = torch.randperm(\n                    batch_size_cache, dtype=torch.int32, device=device\n                )[:batch_size]\n            else:\n                cache_batch_idx = torch.arange(\n                    batch_size, dtype=torch.int32, device=device\n                )\n        else:\n            cache_batch_idx = None\n        arange = rearrange(torch.arange(seqlen_k, device=device), \"s -> 1 s\")\n        cache_seqlens_expanded = rearrange(cache_seqlens, \"b -> b 1\")\n        if not new_kv:\n            key_padding_mask = arange < cache_seqlens_expanded\n        else:\n            k_new_seqlens = (\n                key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new\n            )\n            key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens\n        if has_leftpad:\n            key_padding_mask = torch.logical_and(\n                key_padding_mask,\n                arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k),\n            )\n        # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)\n        rotary_seqlens = cache_seqlens if not 
has_rotary_seqlens else cache_seqlens // 2\n        if rotary_dim > 0:\n            angle = (\n                torch.rand(\n                    seqlen_k if page_size is None else num_blocks * page_size,\n                    rotary_dim // 2,\n                    device=device,\n                )\n                * 2\n                * math.pi\n            )\n            cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)\n            sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)\n            if causal or local:\n                q_ro = apply_rotary_emb(\n                    q,\n                    cos,\n                    sin,\n                    seqlen_offsets=rotary_seqlens,\n                    interleaved=rotary_interleaved,\n                )\n            else:\n                q_ro = rearrange(\n                    apply_rotary_emb(\n                        rearrange(q, \"b s h d -> b 1 (s h) d\"),\n                        cos,\n                        sin,\n                        seqlen_offsets=rotary_seqlens,\n                        interleaved=rotary_interleaved,\n                    ),\n                    \"b 1 (s h) d -> b s h d\",\n                    s=seqlen_q,\n                )\n            # q_ro = q\n            k_ro = apply_rotary_emb(\n                k,\n                cos,\n                sin,\n                seqlen_offsets=rotary_seqlens,\n                interleaved=rotary_interleaved,\n            )\n        else:\n            cos, sin = None, None\n            q_ro, k_ro = q, k\n        # k_cache[:, 64:] = -1\n        k_cache_ref = (\n            k_cache if not has_batch_idx else k_cache[cache_batch_idx]\n        ).clone()\n        v_cache_ref = (\n            v_cache if not has_batch_idx else v_cache[cache_batch_idx]\n        ).clone()\n        if new_kv:\n            update_mask = torch.logical_and(\n                cache_seqlens_expanded <= arange,\n                arange < 
cache_seqlens_expanded + k_new_seqlens,\n            )\n            k_to_update = rearrange(k_ro, \"b s ... -> (b s) ...\")\n            v_to_update = rearrange(v, \"b s ... -> (b s) ...\")\n            if varlen_q:\n                k_to_update = k_to_update[indices_k]\n                v_to_update = v_to_update[indices_k]\n            k_cache_ref[update_mask] = k_to_update\n            v_cache_ref[update_mask] = v_to_update\n        k_cache_rep = repeat(\n            k_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k\n        )\n        v_cache_rep = repeat(\n            v_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k\n        )\n        out_ref, _ = attention_ref(\n            q_ro,\n            k_cache_rep,\n            v_cache_rep,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv,\n            window_size=window_size,\n            learnable_sink=learnable_sink,\n            attention_chunk=attention_chunk,\n            key_leftpad=cache_leftpad,\n        )\n        out_pt, _ = attention_ref(\n            q_ro,\n            k_cache_rep,\n            v_cache_rep,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv,\n            window_size=window_size,\n            learnable_sink=learnable_sink,\n            attention_chunk=attention_chunk,\n            upcast=False,\n            reorder_ops=True,\n            key_leftpad=cache_leftpad,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n        )\n        q = q.to(dtype)\n        q_unpad = q_unpad.to(dtype) if varlen_q else None\n        k_cache = k_cache.to(dtype)\n        v_cache = v_cache.to(dtype)\n        k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None\n        v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None\n        k = k.to(dtype) if k is not None else None\n        v = 
v.to(dtype) if v is not None else None\n        k_unpad = k_unpad.to(dtype) if k_unpad is not None else None\n        v_unpad = v_unpad.to(dtype) if v_unpad is not None else None\n        qv = qv.to(dtype) if qv is not None else None\n        qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None\n        cos = cos.to(dtype) if cos is not None else None\n        sin = sin.to(dtype) if sin is not None else None\n        k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()\n        v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()\n        # num_splits_vals = [1, 0]\n        # SplitKV is not supported for hdim >= 192\n        num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]\n        # precompute_metadata_vals = [False, True]\n        precompute_metadata_vals = [False]\n        for num_splits, precompute_metadata in itertools.product(\n            num_splits_vals, precompute_metadata_vals\n        ):\n            # SplitKV not supported on SM90 - skip this iteration\n            if IS_SM90 and num_splits > 1:\n                continue\n            # if precompute_metadata:\n            #     scheduler_metadata = get_scheduler_metadata(\n            #         batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,\n            #         cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,\n            #         cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,\n            #         max_seqlen_k_new=seqlen_new, page_size=page_size,\n            #         causal=causal, window_size=window_size, attention_chunk=attention_chunk,\n            #         num_splits=num_splits\n            #     )\n            # else:\n            #     scheduler_metadata = None\n            scheduler_metadata = None\n            # Repeat to test metadata reuse\n            for _ in range(1 if not precompute_metadata else 2):\n                if 
page_size is None:\n                    k_cache.copy_(k_cache_saved)\n                    v_cache.copy_(v_cache_saved)\n                else:\n                    k_cache_paged.copy_(k_cache_saved)\n                    v_cache_paged.copy_(v_cache_saved)\n                # out, lse, *rest = flash_attn_with_kvcache(\n                out, lse, *rest = flash_attn_varlen_func(\n                    q if not varlen_q else q_unpad,\n                    k_cache if page_size is None else k_cache_paged,\n                    v_cache if page_size is None else v_cache_paged,\n                    # k if not new_kv or not varlen_q else k_unpad,\n                    # v if not new_kv or not varlen_q else v_unpad,\n                    # qv=qv if not varlen_q else qv_unpad,\n                    # rotary_cos=cos,\n                    # rotary_sin=sin,\n                    seqused_k=cache_seqlens,\n                    # cache_batch_idx=cache_batch_idx,\n                    # cache_leftpad=cache_leftpad,\n                    page_table=page_table,\n                    cu_seqlens_q=cu_seqlens_q,\n                    # cu_seqlens_k_new=cu_seqlens_k_new,\n                    # rotary_seqlens=rotary_seqlens,\n                    causal=causal,\n                    window_size=window_size,\n                    learnable_sink=learnable_sink,\n                    # attention_chunk=attention_chunk,\n                    # rotary_interleaved=rotary_interleaved,\n                    # scheduler_metadata=scheduler_metadata,\n                    num_splits=num_splits,\n                    # return_softmax_lse=True\n                )\n                if varlen_q:\n                    out = output_pad_fn(out)\n                if is_fake_mode():\n                    # no more flash_attn cutedsl calls for the rest of the loop\n                    # skip data-dependent postprocessing\n                    continue\n                # out = flash_attn_with_kvcache(\n                #     q, k_cache, 
v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size\n                # )\n                # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)\n                # qk = torch.einsum(\"bqhd,bkhd->bhqk\", q, k_cache_ref)\n                # m = qk.amax(-1, keepdim=True)\n                # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n                # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)\n                # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)\n                # probs = torch.softmax(qk, dim=-1)\n                print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n                print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n                print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n                print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n                # breakpoint()\n\n                # Check that FlashAttention's numerical error is at most twice the numerical error\n                # of a Pytorch implementation.\n                if new_kv:\n                    if page_size is None:\n                        k_cache_select = (\n                            k_cache.to(dtype_ref)\n                            if not has_batch_idx\n                            else k_cache.to(dtype_ref)[cache_batch_idx]\n                        )\n                        v_cache_select = (\n                            v_cache.to(dtype_ref)\n                            if not has_batch_idx\n                            else v_cache.to(dtype_ref)[cache_batch_idx]\n                        )\n                    else:\n                        k_cache_select = rearrange(\n                            k_cache_paged.to(dtype_ref)[\n                                (\n                                    page_table\n                                    if not has_batch_idx\n                                    else 
page_table[cache_batch_idx]\n                                ).flatten()\n                            ],\n                            \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n                            b=batch_size,\n                        )[:, :seqlen_k].to(dtype_ref)\n                        v_cache_select = rearrange(\n                            v_cache_paged.to(dtype_ref)[\n                                (\n                                    page_table\n                                    if not has_batch_idx\n                                    else page_table[cache_batch_idx]\n                                ).flatten()\n                            ],\n                            \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n                            b=batch_size,\n                        )[:, :seqlen_k].to(dtype_ref)\n                    k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)\n                    v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)\n                    if dtype is not torch.float8_e4m3fn:\n                        assert torch.equal(v_cache_select, v_cache_ref)\n                    else:\n                        assert torch.allclose(\n                            v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3\n                        )\n                    # breakpoint()\n                    # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:\n                    if rotary_dim == 0:\n                        assert torch.equal(k_cache_select, k_cache_ref)\n                    else:\n                        # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):\n                        #     breakpoint()\n                        if dtype is not torch.float8_e4m3fn:\n                            assert torch.allclose(\n                                k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3\n                            )\n                        else:\n  
                          assert torch.allclose(\n                                k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1\n                            )\n                mult = 4 if dtype == torch.float8_e4m3fn else 2\n                assert (out - out_ref).abs().max().item() <= mult * (\n                    out_pt - out_ref\n                ).abs().max().item() + 1e-5\n                mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5\n                assert (out - out_ref).abs().mean().item() <= mult_mean * (\n                    out_pt - out_ref\n                ).abs().mean().item()\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [64, 128])\n@pytest.mark.parametrize(\"seqlen_q,seqlen_k\", [(128, 128), (256, 256)])\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, causal, dtype):\n    from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd\n\n    device = \"cuda\"\n    torch.random.manual_seed(42)\n    batch_size = 2\n    nheads = 4\n\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n\n    out, lse = _flash_attn_fwd(q, k, v, causal=causal, return_lse=True)\n    dout = torch.randn_like(out)\n\n    dq_ref, dk_ref, dv_ref = _flash_attn_bwd(q, k, v, out, dout, lse, causal=causal)\n\n    dq = torch.empty_like(q)\n    dk = torch.empty_like(k)\n    dv = torch.empty_like(v)\n    dq_out, dk_out, dv_out = _flash_attn_bwd(\n        q, k, v, out, dout, lse, causal=causal, dq=dq, dk=dk, dv=dv\n    )\n\n    if is_fake_mode():\n        return\n    assert dq_out is dq\n    assert dk_out is dk\n    assert dv_out is dv\n    assert 
torch.allclose(dq, dq_ref, atol=1e-5, rtol=1e-5)\n    assert torch.allclose(dk, dk_ref, atol=1e-5, rtol=1e-5)\n    assert torch.allclose(dv, dv_ref, atol=1e-5, rtol=1e-5)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [64, 128])\n@pytest.mark.parametrize(\"seqlen_q,seqlen_k\", [(128, 128), (256, 256)])\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_lse_grad(seqlen_q, seqlen_k, d, causal, dtype):\n    \"\"\"Test that gradient flows through the returned LSE tensor.\"\"\"\n    device = \"cuda\"\n    torch.random.manual_seed(42)\n    batch_size = 2\n    nheads = 4\n\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n\n    out, lse = flash_attn_func(q, k, v, causal=causal, return_lse=True)\n\n    if is_fake_mode():\n        return\n\n    assert lse is not None\n    assert lse.requires_grad\n\n    # Compute loss = sum(out * g) + sum(lse * dlse_weight) to test gradient flows through both\n    g = torch.randn_like(out)\n    dlse_weight = torch.randn_like(lse)\n    loss = (out * g).sum() + (lse * dlse_weight).sum()\n    dq, dk, dv = torch.autograd.grad(loss, (q, k, v))\n\n    # Compare against reference: manually compute what the gradients should be\n    # Reference: standard attention in float\n    q_ref = q.detach().float().requires_grad_()\n    k_ref = k.detach().float().requires_grad_()\n    v_ref = v.detach().float().requires_grad_()\n    # (batch, seqlen_q, nheads, d) -> (batch, nheads, seqlen_q, d)\n    qk = torch.einsum(\"bshd,bthd->bhst\", q_ref, k_ref) / (d ** 0.5)\n    if causal:\n        mask = torch.triu(torch.ones(seqlen_q, seqlen_k, device=device, dtype=torch.bool), diagonal=seqlen_k - seqlen_q + 
1)\n        qk = qk.masked_fill(mask, float(\"-inf\"))\n    lse_ref = torch.logsumexp(qk, dim=-1)  # (batch, nheads, seqlen_q)\n    p = torch.softmax(qk, dim=-1)\n    # v_ref: (batch, seqlen_k, nheads, d)\n    out_ref = torch.einsum(\"bhst,bthd->bshd\", p, v_ref)\n    loss_ref = (out_ref * g.float()).sum() + (lse_ref * dlse_weight.float()).sum()\n    dq_ref, dk_ref, dv_ref = torch.autograd.grad(loss_ref, (q_ref, k_ref, v_ref))\n\n    # Use relaxed tolerances since flash_attn operates in bf16 while reference is float32.\n    # The reference is also not a perfect bf16 simulation (it doesn't reorder ops), so\n    # we use a generous tolerance.\n    print(f\"dQ max diff: {(dq.float() - dq_ref).abs().max().item()}\")\n    print(f\"dK max diff: {(dk.float() - dk_ref).abs().max().item()}\")\n    print(f\"dV max diff: {(dv.float() - dv_ref).abs().max().item()}\")\n    # Absolute tolerance: bf16 has ~0.004-0.02 error for these sizes\n    atol = 0.02\n    assert (dq.float() - dq_ref).abs().max().item() <= atol, \"dQ error too large\"\n    assert (dk.float() - dk_ref).abs().max().item() <= atol, \"dK error too large\"\n    assert (dv.float() - dv_ref).abs().max().item() <= atol, \"dV error too large\"\n\n    # Also test: gradient with only dLSE (no dO)\n    out2, lse2 = flash_attn_func(q, k, v, causal=causal, return_lse=True)\n    loss_lse_only = (lse2 * dlse_weight).sum()\n    dq2, dk2, dv2 = torch.autograd.grad(loss_lse_only, (q, k, v))\n\n    q_ref2 = q.detach().float().requires_grad_()\n    k_ref2 = k.detach().float().requires_grad_()\n    qk2 = torch.einsum(\"bshd,bthd->bhst\", q_ref2, k_ref2) / (d ** 0.5)\n    if causal:\n        qk2 = qk2.masked_fill(mask, float(\"-inf\"))\n    lse_ref2 = torch.logsumexp(qk2, dim=-1)\n    loss_ref2 = (lse_ref2 * dlse_weight.float()).sum()\n    dq_ref2, dk_ref2 = torch.autograd.grad(loss_ref2, (q_ref2, k_ref2))\n\n    print(f\"LSE-only dQ max diff: {(dq2.float() - dq_ref2).abs().max().item()}\")\n    print(f\"LSE-only dK max diff: 
{(dk2.float() - dk_ref2).abs().max().item()}\")\n    # dV should be zero when only LSE gradient flows (LSE doesn't depend on V)\n    print(f\"LSE-only dV max: {dv2.abs().max().item()}\")\n    assert dv2.abs().max().item() == 0.0, \"dV should be zero when loss depends only on LSE\"\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\"seqlen_q,seqlen_k\", [(128, 128)])\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_lse_grad_unused(seqlen_q, seqlen_k, d, causal, dtype):\n    \"\"\"Test return_lse=True when LSE is returned but not used in the loss.\n\n    With set_materialize_grads(False), dlse should be None (not a zero tensor),\n    so no extra zeroing kernel is launched. Gradients should match the standard\n    backward (without return_lse).\n    \"\"\"\n    device = \"cuda\"\n    torch.random.manual_seed(42)\n    batch_size = 2\n    nheads = 4\n\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    g = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)\n\n    # Case 1: return_lse=False (standard path, lse marked non-differentiable)\n    out1, lse1 = flash_attn_func(q, k, v, causal=causal, return_lse=False)\n    if is_fake_mode():\n        return\n    dq1, dk1, dv1 = torch.autograd.grad(out1, (q, k, v), g)\n\n    # Case 2: return_lse=True but lse NOT used in loss (dlse should be None)\n    out2, lse2 = flash_attn_func(q, k, v, causal=causal, return_lse=True)\n    dq2, dk2, dv2 = torch.autograd.grad(out2, (q, k, v), g)\n\n    # Case 3: return_lse=True and lse IS used in loss\n    out3, lse3 = flash_attn_func(q, k, v, causal=causal, return_lse=True)\n    
dlse_weight = torch.randn_like(lse3)\n    loss3 = (out3 * g).sum() + (lse3 * dlse_weight).sum()\n    dq3, dk3, dv3 = torch.autograd.grad(loss3, (q, k, v))\n\n    # Cases 1 and 2 should produce identical gradients\n    assert torch.equal(dq1, dq2), \"dQ should be identical when LSE is unused\"\n    assert torch.equal(dk1, dk2), \"dK should be identical when LSE is unused\"\n    assert torch.equal(dv1, dv2), \"dV should be identical when LSE is unused\"\n\n    # Case 3 should differ from case 1 (LSE gradient adds extra contribution to dQ, dK)\n    assert not torch.equal(dq1, dq3), \"dQ should differ when LSE gradient is included\"\n    assert not torch.equal(dk1, dk3), \"dK should differ when LSE gradient is included\"\n    # dV should be the same since LSE doesn't depend on V\n    assert torch.equal(dv1, dv3), \"dV should be identical since LSE doesn't depend on V\"\n\n    print(\"Case 1 vs 2 (unused LSE): dQ diff =\", (dq1 - dq2).abs().max().item())\n    print(\"Case 1 vs 3 (used LSE):   dQ diff =\", (dq1 - dq3).abs().max().item())\n    print(\"Case 1 vs 3 (used LSE):   dK diff =\", (dk1 - dk3).abs().max().item())\n    print(\"Case 1 vs 3 (used LSE):   dV diff =\", (dv1 - dv3).abs().max().item())\n\n\ndef _generate_block_kvcache(\n    seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref\n):\n    num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3\n    k_cache_paged = (\n        torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref)\n        .to(dtype)\n        .to(dtype_ref)\n    )\n    v_cache_paged = (\n        torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref)\n        .to(dtype)\n        .to(dtype_ref)\n    )\n    page_table = rearrange(\n        torch.randperm(num_blocks, dtype=torch.int32, device=device),\n        \"(b nblocks) -> b nblocks\",\n        b=batch_size,\n    )\n    k_cache = rearrange(\n        k_cache_paged[page_table.flatten()],\n        \"(b nblocks) 
block_size ... -> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    v_cache = rearrange(\n        v_cache_paged[page_table.flatten()],\n        \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks\n\n\n@pytest.mark.parametrize(\"page_size\", [16, 64, 256])\n@pytest.mark.parametrize(\"seqlen_q\", [64, 128, 256])\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_paged_deepseek(seqlen_q, page_size):\n    \"\"\"Regression test: paged non-TMA with DeepSeek MLA shape (d=192, dv=128).\n    seqlen_q<=128 triggers q_stage=1, seqlen_q>128 triggers q_stage=2.\n    \"\"\"\n    if IS_SM90:\n        pytest.skip(\"paged KV not supported on SM90\")\n    device = \"cuda\"\n    dtype = torch.bfloat16\n    d, dv = 192, 128\n    nheads = 16\n    nheads_kv = 16\n\n    torch.random.manual_seed(0)\n    q = torch.randn(seqlen_q, nheads, d, device=device, dtype=dtype)\n    k = torch.randn(seqlen_q, nheads_kv, d, device=device, dtype=dtype)\n    v = torch.randn(seqlen_q, nheads_kv, dv, device=device, dtype=dtype)\n    cu_seqlens = torch.tensor([0, seqlen_q], dtype=torch.int32, device=device)\n\n    # Non-paged reference\n    out_ref, _ = flash_attn_varlen_func(\n        q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,\n        max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, causal=True,\n    )\n\n    # Paged\n    num_pages = (seqlen_q + page_size - 1) // page_size\n    k_cache_paged = torch.zeros(num_pages, page_size, nheads_kv, d, device=device, dtype=dtype)\n    v_cache_paged = torch.zeros(num_pages, page_size, nheads_kv, dv, device=device, dtype=dtype)\n    for i in range(seqlen_q):\n        k_cache_paged[i // page_size, i % page_size] = k[i]\n        v_cache_paged[i // page_size, i % page_size] = v[i]\n    page_table = torch.arange(num_pages, dtype=torch.int32, device=device).unsqueeze(0)\n    
cache_seqlens = torch.tensor([seqlen_q], dtype=torch.int32, device=device)\n\n    out, _ = flash_attn_varlen_func(\n        q, k_cache_paged, v_cache_paged,\n        cu_seqlens_q=cu_seqlens, cu_seqlens_k=None,\n        max_seqlen_q=seqlen_q, max_seqlen_k=None,\n        seqused_k=cache_seqlens, page_table=page_table, causal=True,\n    )\n\n    if is_fake_mode():\n        return\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    assert torch.equal(out, out_ref)\n\n\n@pytest.mark.parametrize(\"head_dim\", [4, 148, 288])\ndef test_flash_attn_invalid_head_dim(head_dim):\n    device = \"cuda\"\n    dtype = torch.bfloat16\n    batch_size, seqlen, nheads = 1, 64, 4\n\n    q = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)\n    k = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)\n    v = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)\n\n    with pytest.raises(AssertionError, match=re.escape(f\"(head_dim, head_dim_v)=({head_dim}, {head_dim}) is not supported on SM\")):\n        flash_attn_func(q, k, v)\n"
  },
  {
    "path": "tests/cute/test_flash_attn_combine.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n\nimport os\n\nimport pytest\nimport torch\n\nfrom flash_attn.cute.testing import (\n    maybe_fake_tensor_mode,\n    is_fake_mode,\n)\nfrom flash_attn.cute.interface import (\n    flash_attn_combine,\n)\n\nUSE_FAKE_TENSOR = int(os.getenv(\"FLASH_ATTENTION_FAKE_TENSOR\", 0)) == 1\n\n\ndef attention_combine_ref(out_partial, lse_partial):\n    \"\"\"\n    out_partial: (num_splits, batch_size, seqlen, nheads, d)\n    lse_partial: (num_splits, batch_size, seqlen, nheads)\n    \"\"\"\n    lse = torch.logsumexp(lse_partial, dim=0)\n    scale = torch.exp(lse_partial - lse)\n    scale = torch.where(\n        torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale\n    )\n    out = (scale.unsqueeze(-1) * out_partial).sum(0)\n    return out, lse\n\n\ndef check_combine_results(out, lse, out_ref, lse_ref, dtype):\n    \"\"\"Check combine kernel output against reference for a single (seqlen, nheads, d) chunk.\"\"\"\n    out_pt = out_ref.to(dtype)\n    print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}, \"\n          f\"Output max diff: {(out - out_ref).abs().max().item()}, \"\n          f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5)\n    assert (\n        (out - out_ref).abs().max().item()\n        <= 2 * (out_pt - out_ref).abs().max().item()\n    ) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16, torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.float32])\n# @pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n@pytest.mark.parametrize(\"d\", [64, 96, 128, 192, 256, 512])\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\"seqlen\", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024])\n# @pytest.mark.parametrize(\"seqlen\", 
[12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192])\n# @pytest.mark.parametrize(\"seqlen\", [15])\n@pytest.mark.parametrize(\"num_splits\", [1, 2, 3, 5, 17, 32, 55, 97, 133])\n# @pytest.mark.parametrize(\"num_splits\", [1, 2, 3, 5, 11])\n# @pytest.mark.parametrize(\"num_splits\", [11])\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_combine(num_splits, seqlen, d, dtype):\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(1)\n    batch_size = 5\n    nheads = 16\n    # batch_size = 1\n    # nheads = 1\n    # Create tensors in the expected format: (num_splits, batch_size, seqlen, nheads, d) and (num_splits, batch_size, seqlen, nheads)\n    out_partial = torch.randn(\n        num_splits * 2,\n        batch_size,\n        nheads,\n        seqlen,\n        d,\n        device=device,\n        dtype=torch.float32,\n    ).transpose(2, 3)[:num_splits]  # To test non-contiguous tensor\n    lse_partial = torch.randn(\n        num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32\n    ).transpose(-1, -2)[:, :, :, :nheads]  # To test non-contiguous tensor\n    # To test short-circuiting based on num_splits\n    lse_partial[num_splits // 2 :, : batch_size // 3] = -float(\"inf\")\n\n    # Test with LSE returned (default behavior)\n    out, lse = flash_attn_combine(\n        out_partial, lse_partial, out_dtype=dtype, return_lse=True\n    )\n    if is_fake_mode():\n        return\n    out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial)\n    check_combine_results(out, lse, out_ref, lse_ref, dtype)\n\n    # Test with LSE not returned\n    out_no_lse, lse_no_lse = flash_attn_combine(\n        out_partial, lse_partial, out_dtype=dtype, return_lse=False\n    )\n    assert lse_no_lse is None, \"LSE should be None when return_lse=False\"\n    assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), (\n        \"Output should be the same regardless of return_lse\"\n    )\n\n\n@pytest.mark.parametrize(\"dtype\", 
[torch.float32, torch.float16, torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"d\", [64, 96, 128, 256])\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\"seqlen\", [1, 32, 113, 256, 1024])\n# @pytest.mark.parametrize(\"seqlen\", [113])\n@pytest.mark.parametrize(\"num_splits\", [2, 5, 17, 55])\n# @pytest.mark.parametrize(\"num_splits\", [5])\n@pytest.mark.parametrize(\n    \"varlen_mode\",\n    [\"cu_seqlens\", \"seqused\", \"cu_seqlens_seqused\"],\n)\n# @pytest.mark.parametrize(\"varlen_mode\", [\"cu_seqlens\"])\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_combine_varlen(varlen_mode, num_splits, seqlen, d, dtype):\n    device = \"cuda\"\n    torch.random.manual_seed(1)\n    batch_size = 3\n    nheads = 8\n    use_cu_seqlens = \"cu_seqlens\" in varlen_mode\n    use_seqused = \"seqused\" in varlen_mode\n\n    # Generate variable-length sequences\n    seqlens = torch.randint(1, seqlen + 1, (batch_size,), device=device, dtype=torch.int32)\n    # For cu_seqlens+seqused mode, seqused < seqlen (kernel processes fewer tokens)\n    seqused_vals = (\n        torch.clamp(\n            seqlens - torch.randint(0, max(1, seqlen // 4), (batch_size,), device=device, dtype=torch.int32),\n            min=1,\n        )\n        if use_cu_seqlens and use_seqused\n        else seqlens\n    )\n\n    if use_cu_seqlens:\n        # Packed varlen layout: (num_splits, total_q, nheads, d)\n        total_q = seqlens.sum().item()\n        cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)\n        cu_seqlens_q[1:] = torch.cumsum(seqlens, dim=0)\n\n        out_partial = torch.randn(\n            num_splits * 2, total_q, nheads, d, device=device, dtype=torch.float32,\n        )[:num_splits]  # Non-contiguous in splits dim\n        # lse_partial needs stride(-2)==1 (seqlen dim contiguous)\n        lse_partial = torch.randn(\n            num_splits, nheads, total_q, device=device, 
dtype=torch.float32\n        ).transpose(-1, -2)\n        lse_partial[num_splits // 2:, :total_q // 3] = -float(\"inf\")\n\n        out, lse = flash_attn_combine(\n            out_partial, lse_partial, out_dtype=dtype,\n            cu_seqlens=cu_seqlens_q,\n            seqused=seqused_vals if use_seqused else None,\n            return_lse=True,\n        )\n        if is_fake_mode():\n            return\n\n        # Reference on full packed tensor\n        out_ref, lse_ref = attention_combine_ref(\n            out_partial.unsqueeze(1), lse_partial.unsqueeze(1)\n        )\n        out_ref = out_ref.squeeze(0)\n        lse_ref = lse_ref.squeeze(0)\n\n        # Validate per-batch (only seqused_vals tokens are guaranteed correct)\n        for i in range(batch_size):\n            start = cu_seqlens_q[i].item()\n            sl = seqused_vals[i].item()\n            check_combine_results(\n                out[start:start + sl], lse[start:start + sl],\n                out_ref[start:start + sl], lse_ref[start:start + sl], dtype,\n            )\n\n        # Also test return_lse=False\n        out_no_lse, lse_no_lse = flash_attn_combine(\n            out_partial, lse_partial, out_dtype=dtype,\n            cu_seqlens=cu_seqlens_q,\n            seqused=seqused_vals if use_seqused else None,\n            return_lse=False,\n        )\n        assert lse_no_lse is None\n        # Only compare valid positions (beyond seqused, output is undefined)\n        for i in range(batch_size):\n            start = cu_seqlens_q[i].item()\n            sl = seqused_vals[i].item()\n            assert torch.allclose(out_no_lse[start:start + sl], out[start:start + sl], atol=1e-5, rtol=1e-5)\n\n    else:\n        # seqused only — batched layout: (num_splits, batch, max_seqlen, nheads, d)\n        max_seqlen = seqlens.max().item()\n        out_partial = torch.randn(\n            num_splits, batch_size, max_seqlen, nheads, d, device=device, dtype=torch.float32,\n        )\n        # lse_partial needs 
stride(-2)==1 (seqlen dim contiguous)\n        lse_partial = torch.randn(\n            num_splits, batch_size, nheads, max_seqlen, device=device, dtype=torch.float32,\n        ).transpose(-1, -2)\n        lse_partial[num_splits // 2:, :batch_size // 2] = -float(\"inf\")\n        # Zero out / -inf beyond seqused so reference matches kernel\n        for i in range(batch_size):\n            out_partial[:, i, seqlens[i]:] = 0\n            lse_partial[:, i, seqlens[i]:] = -float(\"inf\")\n\n        out, lse = flash_attn_combine(\n            out_partial, lse_partial, out_dtype=dtype, seqused=seqlens, return_lse=True,\n        )\n        if is_fake_mode():\n            return\n\n        out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial)\n\n        # Validate per-batch (only seqused tokens)\n        for i in range(batch_size):\n            sl = seqlens[i].item()\n            check_combine_results(\n                out[i, :sl], lse[i, :sl],\n                out_ref[i, :sl], lse_ref[i, :sl], dtype,\n            )\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"d\", [64, 128, 256])\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\"seqlen\", [32, 113, 256])\n# @pytest.mark.parametrize(\"seqlen\", [113])\n@pytest.mark.parametrize(\"num_splits\", [2, 5, 17])\n# @pytest.mark.parametrize(\"num_splits\", [5])\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_combine_varlen_batch_idx(num_splits, seqlen, d, dtype):\n    \"\"\"Test that varlen_batch_idx correctly remaps virtual batch indices to real batch indices.\n\n    varlen_batch_idx maps blockIdx.z (virtual batch) -> real batch index. 
The kernel\n    reads AND writes using the remapped batch_idx, so with a permutation the output\n    should match running without varlen_batch_idx (each real batch is processed once).\n\n    We also test with seqused to verify interaction with variable-length sequences.\n    \"\"\"\n    device = \"cuda\"\n    torch.random.manual_seed(42)\n    batch_size = 4\n    nheads = 8\n\n    # Create batched input data\n    out_partial = torch.randn(\n        num_splits, batch_size, seqlen, nheads, d, device=device, dtype=torch.float32,\n    )\n    lse_partial = torch.randn(\n        num_splits, batch_size, nheads, seqlen, device=device, dtype=torch.float32,\n    ).transpose(-1, -2)  # stride(-2)==1\n    lse_partial[num_splits // 2:, :batch_size // 2] = -float(\"inf\")\n\n    # Create a permuted batch index mapping: virtual batch -> real batch\n    perm = torch.tensor([2, 0, 3, 1], device=device, dtype=torch.int32)\n    assert perm.shape[0] == batch_size\n\n    # Also test with seqused to verify interaction with varlen_batch_idx\n    seqused = torch.randint(1, seqlen + 1, (batch_size,), device=device, dtype=torch.int32)\n    # Zero out / -inf beyond seqused so reference matches kernel\n    for i in range(batch_size):\n        out_partial[:, i, seqused[i]:] = 0\n        lse_partial[:, i, seqused[i]:] = -float(\"inf\")\n\n    # Run with varlen_batch_idx and seqused via public API\n    out, lse = flash_attn_combine(\n        out_partial, lse_partial, out_dtype=dtype,\n        seqused=seqused,\n        varlen_batch_idx=perm,\n        return_lse=True,\n    )\n    if is_fake_mode():\n        return\n\n    # Reference: standard combine (no remapping needed since perm is a bijection\n    # and both reads and writes use the remapped batch_idx)\n    out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial)\n\n    # The kernel reads from input[perm[v]] and writes to output[perm[v]],\n    # so the net result is output[b] = combine(input[b]) for all b.\n    for b in 
range(batch_size):\n        sl = seqused[b].item()\n        check_combine_results(\n            out[b, :sl], lse[b, :sl],\n            out_ref[b, :sl], lse_ref[b, :sl], dtype,\n        )\n"
  },
  {
    "path": "tests/cute/test_flash_attn_fast.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n# Fast subset of test_flash_attn.py for quick iteration.\n# Covers: causal/noncausal, varlen/not varlen, MHA/GQA, split/not split, fwd+bwd.\n\nimport os\nimport random\n\nimport pytest\nimport torch\n\nfrom einops import rearrange\n\nfrom flash_attn.cute.testing import (\n    attention_ref,\n    generate_random_padding_mask,\n    generate_qkv,\n    maybe_fake_tensor_mode,\n    is_fake_mode,\n)\nfrom flash_attn.cute.interface import (\n    flash_attn_func,\n    flash_attn_varlen_func,\n    flash_attn_combine,\n)\n\nUSE_FAKE_TENSOR = int(os.getenv(\"FLASH_ATTENTION_FAKE_TENSOR\", 0)) == 1\nIS_SM90 = torch.cuda.get_device_capability()[0] == 9\n\n\n# ---------------------------------------------------------------------------\n# Forward + backward (non-varlen)\n# ---------------------------------------------------------------------------\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"gqa\"])\n@pytest.mark.parametrize(\"num_splits\", [1, 3])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [64, 128])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (128, 128),\n        (256, 256),\n        (113, 203),\n        (1024, 1024),\n    ],\n)\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_output(seqlen_q, seqlen_k, d, causal, num_splits, mha_type, dtype):\n    if IS_SM90 and num_splits > 1:\n        pytest.skip(\"SM90 fwd doesn't support num_splits > 1\")\n    device = \"cuda\"\n    torch.random.manual_seed(0)\n    random.seed(0)\n    torch.cuda.empty_cache()\n    batch_size = 4\n    nheads = 6\n    nheads_kv = nheads if mha_type == \"mha\" else 3\n\n    q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype).to(dtype).requires_grad_()\n    k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, 
device=device, dtype=dtype).to(dtype).requires_grad_()\n    v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_()\n\n    q = q_ref.detach().to(dtype).requires_grad_()\n    k = k_ref.detach().to(dtype).requires_grad_()\n    v = v_ref.detach().to(dtype).requires_grad_()\n\n    out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal)\n    out_pt, _ = attention_ref(\n        q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True,\n    )\n\n    out, lse = flash_attn_func(q, k, v, causal=causal, num_splits=num_splits)\n\n    if is_fake_mode():\n        return\n\n    fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + fwd_atol\n\n    # Backward (only for non-split, matching d)\n    can_bwd = (\n        num_splits == 1\n        and d <= 128\n        and not (causal and seqlen_k < seqlen_q)\n    )\n    if IS_SM90 and d == 64 and not causal:\n        can_bwd = False  # SM90 d=64 non-causal xfail\n    if not can_bwd:\n        return\n\n    g = torch.randn_like(out)\n    dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n\n    dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)\n    dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)\n\n    dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item()\n    dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item()\n    dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item()\n    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + dq_atol\n    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + dk_atol\n    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + dv_atol\n\n\n# ---------------------------------------------------------------------------\n# Forward + 
backward (varlen with cu_seqlens)\n# ---------------------------------------------------------------------------\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"gqa\"])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [64, 128])\n@pytest.mark.parametrize(\"seqlen\", [128, 256, 1024])\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_varlen_output(seqlen, d, causal, mha_type, dtype):\n    \"\"\"Varlen test with cu_seqlens (packed): equal seqlens so we can compare with non-varlen ref.\"\"\"\n    device = \"cuda\"\n    seed = seqlen + d + int(causal) * 2\n    torch.random.manual_seed(seed)\n    random.seed(seed)\n    batch_size = 9\n    nheads = 6\n    nheads_kv = nheads if mha_type == \"mha\" else 3\n\n    q_ref = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype).to(dtype).requires_grad_()\n    k_ref = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_()\n    v_ref = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_()\n\n    out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal)\n    out_pt, _ = attention_ref(\n        q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True,\n    )\n\n    cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen, device=device, dtype=torch.int32)\n    q_varlen = rearrange(q_ref.detach(), \"b s h d -> (b s) h d\").requires_grad_()\n    k_varlen = rearrange(k_ref.detach(), \"b s h d -> (b s) h d\").requires_grad_()\n    v_varlen = rearrange(v_ref.detach(), \"b s h d -> (b s) h d\").requires_grad_()\n\n    out_varlen, lse = flash_attn_varlen_func(\n        q_varlen, k_varlen, v_varlen,\n        cu_seqlens, cu_seqlens,\n        seqlen, seqlen,\n        causal=causal,\n    )\n\n    if is_fake_mode():\n        return\n\n    out_reshaped = rearrange(out_varlen, \"(b s) 
h d -> b s h d\", b=batch_size)\n    fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n    assert (out_reshaped - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + fwd_atol\n\n    # Backward\n    can_bwd = d <= 128\n    if not can_bwd:\n        return\n\n    g = torch.randn_like(out_varlen)\n    dq_varlen, dk_varlen, dv_varlen = torch.autograd.grad(out_varlen, (q_varlen, k_varlen, v_varlen), g)\n\n    assert dq_varlen.isfinite().all(), \"dq contains non-finite values\"\n    assert dk_varlen.isfinite().all(), \"dk contains non-finite values\"\n    assert dv_varlen.isfinite().all(), \"dv contains non-finite values\"\n    assert dq_varlen.abs().max().item() > 0, \"dq is all zeros\"\n    assert dk_varlen.abs().max().item() > 0, \"dk is all zeros\"\n    assert dv_varlen.abs().max().item() > 0, \"dv is all zeros\"\n\n\n# ---------------------------------------------------------------------------\n# Forward + backward (varlen with padding masks — all unpad combinations)\n# Covers 4 compile-key-distinct paths:\n#   (unpad_q, unpad_kv) = (T,T): cu_seqlens for both Q and K\n#   (unpad_q, unpad_kv) = (F,F): seqused for both Q and K\n#   (unpad_q, unpad_kv) = (T,F): cu_seqlens_q + seqused_k\n#   (unpad_q, unpad_kv) = (F,T): seqused_q + cu_seqlens_k\n# ---------------------------------------------------------------------------\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"gqa\"])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [64, 128])\n@pytest.mark.parametrize(\"seqlen\", [128, 256])\n@pytest.mark.parametrize(\n    \"unpad_q,unpad_kv\",\n    [(True, True), (False, False), (True, False), (False, True)],\n)\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_varlen_unpad_output(seqlen, d, causal, mha_type, unpad_q, unpad_kv, dtype):\n    \"\"\"Varlen test with all 4 (unpad_q, unpad_kv) combos: cu_seqlens vs seqused.\"\"\"\n   
 device = \"cuda\"\n    seed = seqlen + d + int(causal) * 2 + int(unpad_q) * 7 + int(unpad_kv) * 13\n    torch.random.manual_seed(seed)\n    random.seed(seed)\n    batch_size = 9\n    nheads = 6\n    nheads_kv = nheads if mha_type == \"mha\" else 3\n\n    q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)\n    k = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype)\n    v = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype)\n    q_ref = q.detach().to(dtype).requires_grad_()\n    k_ref = k.detach().to(dtype).requires_grad_()\n    v_ref = v.detach().to(dtype).requires_grad_()\n\n    query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode=\"random\")\n    key_padding_mask = query_padding_mask if causal else generate_random_padding_mask(\n        seqlen, batch_size, device, mode=\"random\"\n    )\n\n    (\n        q_unpad_t, k_unpad_t, v_unpad_t, _qv_unpad,\n        cu_seqlens_q, cu_seqlens_k,\n        seqused_q, seqused_k,\n        max_seqlen_q, max_seqlen_k,\n        q_padded, k_padded, v_padded, _qv_padded,\n        output_pad_fn, dq_pad_fn, dk_pad_fn,\n    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask)\n\n    out_ref, _ = attention_ref(\n        q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal,\n    )\n    out_pt, _ = attention_ref(\n        q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal,\n        upcast=False, reorder_ops=True,\n    )\n\n    # Select Q input: packed (unpad) or padded (seqused)\n    if unpad_q:\n        q_in = q_unpad_t.detach().to(dtype).requires_grad_()\n    else:\n        q_in = q.detach().to(dtype).requires_grad_()\n    # Select KV input: packed (unpad) or padded (seqused)\n    if unpad_kv:\n        k_in = k_unpad_t.detach().to(dtype).requires_grad_()\n        v_in = v_unpad_t.detach().to(dtype).requires_grad_()\n    else:\n        k_in = k.detach().to(dtype).requires_grad_()\n        
v_in = v.detach().to(dtype).requires_grad_()\n\n    out_unpad, lse = flash_attn_varlen_func(\n        q_in, k_in, v_in,\n        cu_seqlens_q=cu_seqlens_q if unpad_q else None,\n        cu_seqlens_k=cu_seqlens_k if unpad_kv else None,\n        max_seqlen_q=seqlen,\n        max_seqlen_k=seqlen,\n        seqused_q=seqused_q if not unpad_q else None,\n        seqused_k=seqused_k if not unpad_kv else None,\n        causal=causal,\n    )\n\n    if is_fake_mode():\n        return\n\n    # Reshape output to (batch, seqlen, nheads, d) for comparison\n    out = output_pad_fn(out_unpad) if unpad_q else out_unpad\n\n    # Mask out padding positions — kernel output at padding positions is undefined\n    q_mask = rearrange(query_padding_mask, \"b s -> b s 1 1\")\n    out_masked = out.clone().masked_fill_(~q_mask, 0.0)\n    out_ref_masked = out_ref.clone().masked_fill_(~q_mask, 0.0)\n    out_pt_masked = out_pt.clone().masked_fill_(~q_mask, 0.0)\n\n    fwd_atol = 2 * (out_ref_masked + 0.3 - 0.3 - out_ref_masked).abs().max().item()\n    assert (out_masked - out_ref_masked).abs().max().item() <= 2 * (out_pt_masked - out_ref_masked).abs().max().item() + fwd_atol\n\n    # Backward (original test skips all SM90 varlen backward)\n    can_bwd = d <= 128 and not IS_SM90\n    if not can_bwd:\n        return\n\n    g = torch.randn_like(out_unpad)\n    dq_in, dk_in, dv_in = torch.autograd.grad(out_unpad, (q_in, k_in, v_in), g)\n\n    assert dq_in.isfinite().all(), \"dq contains non-finite values\"\n    assert dk_in.isfinite().all(), \"dk contains non-finite values\"\n    assert dv_in.isfinite().all(), \"dv contains non-finite values\"\n    assert dq_in.abs().max().item() > 0, \"dq is all zeros\"\n    assert dk_in.abs().max().item() > 0, \"dk is all zeros\"\n    assert dv_in.abs().max().item() > 0, \"dv is all zeros\"\n\n\n# ---------------------------------------------------------------------------\n# Combine kernel\n# 
---------------------------------------------------------------------------\n\ndef attention_combine_ref(out_partial, lse_partial):\n    lse = torch.logsumexp(lse_partial, dim=0)\n    scale = torch.exp(lse_partial - lse)\n    scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale)\n    out = (scale.unsqueeze(-1) * out_partial).sum(0)\n    return out, lse\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"d\", [64, 128])\n@pytest.mark.parametrize(\"seqlen\", [32, 256])\n@pytest.mark.parametrize(\"num_splits\", [2, 5, 17])\n@maybe_fake_tensor_mode(USE_FAKE_TENSOR)\ndef test_flash_attn_combine(num_splits, seqlen, d, dtype):\n    device = \"cuda\"\n    torch.random.manual_seed(1)\n    batch_size = 3\n    nheads = 8\n\n    # out_partial: (num_splits, batch, seqlen, nheads, d) with stride(-1)==1\n    # lse_partial: (num_splits, batch, seqlen, nheads) with stride(-2)==1 (seqlen contiguous)\n    out_partial = torch.randn(\n        num_splits, batch_size, seqlen, nheads, d, device=device, dtype=torch.float32,\n    )\n    lse_partial = torch.randn(\n        num_splits, batch_size, nheads, seqlen, device=device, dtype=torch.float32,\n    ).transpose(-1, -2)\n    lse_partial[num_splits // 2 :, : batch_size // 3] = -float(\"inf\")\n\n    out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=True)\n    if is_fake_mode():\n        return\n    out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial)\n    out_pt = out_ref.to(dtype)\n\n    assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5)\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5)\n"
  },
  {
    "path": "tests/cute/test_flash_attn_race_condition.py",
    "content": "# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.\n\nimport math\nimport itertools\nimport os\n\nimport pytest\nimport torch\n\nfrom einops import rearrange, repeat\n\ntry:\n    from flash_attn.layers.rotary import apply_rotary_emb\nexcept ImportError:\n    apply_rotary_emb = None\n\nfrom flash_attn.cute.testing import (\n    attention_ref,\n    generate_qkv,\n    generate_random_padding_mask,\n    pad_input,\n    unpad_input,\n)\nfrom flash_attn.cute.interface import (\n    flash_attn_func,\n    flash_attn_varlen_func,\n    flash_attn_combine,\n    _flash_attn_bwd,\n)\n\n\nDISABLE_SPLIT = os.getenv(\"FLASH_ATTENTION_DISABLE_SPLIT\", \"FALSE\") == \"TRUE\"\nIS_SM90 = torch.cuda.get_device_capability()[0] == 9\nINCREASED_TRIALS = False\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"gqa\"])\n# @pytest.mark.parametrize(\"has_learnable_sink\", [False, True])\n@pytest.mark.parametrize(\"has_learnable_sink\", [False])\n# @pytest.mark.parametrize(\"has_qv\", [False, True])\n@pytest.mark.parametrize(\"has_qv\", [False])\n# @pytest.mark.parametrize(\"deterministic\", [False, True])\n@pytest.mark.parametrize(\"deterministic\", [True])\n# @pytest.mark.parametrize(\"softcap\", [0.0, 15.0])\n@pytest.mark.parametrize(\"softcap\", [0.0])\n# @pytest.mark.parametrize(\"local_enum\", [0, 1, 2, 3])\n@pytest.mark.parametrize(\"local_enum\", [0, 1])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n# @pytest.mark.parametrize(\"d\", [64, 128])\n# @pytest.mark.parametrize(\"d\", [128, 192])\n@pytest.mark.parametrize(\"d\", [64, 128, 192])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (4224, 4224),\n        (2000, 4000),\n    
],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])\ndef test_flash_attn_output(\n    seqlen_q,\n    seqlen_k,\n    d,\n    causal,\n    local_enum,\n    softcap,\n    deterministic,\n    has_qv,\n    has_learnable_sink,\n    mha_type,\n    dtype,\n):\n    local = local_enum > 0\n    if local and causal:\n        pytest.skip()\n    is_sm90 = torch.cuda.get_device_capability()[0] == 9\n    if is_sm90 and d == 192:\n        pytest.xfail(\"headdim 192 not supported on sm90\")\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    torch.cuda.empty_cache()\n    torch.cuda.synchronize()\n    batch_size = 9 if seqlen_k <= 2048 else 2\n    # batch_size = 1\n    nheads = 6\n    # nheads = 1\n    nheads_kv = nheads if mha_type == \"mha\" else (3 if mha_type == \"gqa\" else 1)\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    dv_vals = [128] if d == 192 else [d]\n    if dtype == torch.float8_e4m3fn:\n        dv_vals = [d]\n    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0]\n    attention_chunk_vals = [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        q_ref = torch.randn(\n            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref\n        )\n        if softcap > 0.0:\n            # Ensure the values of qk are at least within softcap range.\n            q_ref = q_ref * softcap / 4\n        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()\n        k_ref = (\n            torch.randn(\n                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref\n            )\n            .to(dtype)\n            .to(dtype_ref)\n            .requires_grad_()\n        )\n        v_ref = (\n            torch.randn(\n                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref\n            )\n            
.to(dtype)\n            .to(dtype_ref)\n            .requires_grad_()\n        )\n        if has_qv:\n            qv_ref = (\n                torch.randn(\n                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n        else:\n            qv_ref = None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (\n            (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()\n        )\n        if local_enum == 2:\n            window_size = (None, -window_size[1])\n        elif local_enum == 3:\n            window_size = (-window_size[0], None)\n        if local:\n            print(\"window size = \", window_size)\n        # window_size = (-1, -1) if not local else (16, 0)\n        if has_learnable_sink:\n            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)\n        else:\n            learnable_sink = None\n        if dtype == torch.float8_e4m3fn:\n            q_descale, k_descale, v_descale = [\n                torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)\n                * 2\n                for _ in range(3)\n            ]\n        else:\n            q_descale, k_descale, v_descale = None, None, None\n        q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]\n        qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None\n        out_ref, attn_ref = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            None,\n            None,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale,\n            k_descale=k_descale,\n            v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            learnable_sink=learnable_sink,\n            
softcap=softcap,\n        )\n        out_pt, attn_pt = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            None,\n            None,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale,\n            k_descale=k_descale,\n            v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            learnable_sink=learnable_sink,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n        )\n\n        # k_extended = repeat(k_ref, \"b s h d -> b s (h k) d\", k=nheads // nheads_kv)\n        # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float()\n        # # if qv is not None:\n        # #     qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float()\n        # m = qk.amax(-1, keepdim=True)\n        # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n        # exp_sum = s_tmp.sum(-1)\n        # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float())\n        # # lse_ref = torch.logsumexp(qk, dim=-1)\n\n        # Numerical error if we just do any arithmetic on out_ref\n        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n        rtol = 2 if softcap == 0.0 else 3\n\n        print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n        print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n        # num_splits_vals = [1, 3]\n        # pack_gqa_vals = [False, True, None]\n        # SplitKV is not supported for hdim >= 192\n        pack_gqa_vals = [False]\n        # num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]\n        num_splits_vals = [1]\n        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):\n            out, lse = flash_attn_func(\n                q,\n                k,\n                v,\n  
              causal=causal,\n                # qv=qv,\n                # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,\n                window_size=window_size,\n                # attention_chunk=attention_chunk,\n                softcap=softcap,\n                learnable_sink=learnable_sink,\n                pack_gqa=pack_gqa,\n                num_splits=num_splits,\n                deterministic=deterministic,\n            )\n            print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n            print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n            # if not causal:\n            #     print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n            # breakpoint()\n\n            # Check that FlashAttention's numerical error is at most twice the numerical error\n            # of a Pytorch implementation.\n            assert (out - out_ref).abs().max().item() <= rtol * (\n                out_pt - out_ref\n            ).abs().max().item() + fwd_atol\n\n        if (\n            dtype != torch.float8_e4m3fn\n            and not has_qv\n            and not dv > 256\n            and not attention_chunk != 0\n            and softcap == 0.0\n            and ((dv == d and d <= 128) or (d == 192 and dv == 128))\n            and learnable_sink is None\n            # and False\n        ):\n            if IS_SM90 and mha_type != \"mha\":\n                pytest.xfail(\"SM90 backward: GQA/MQA has tensor layout issue (qhead_per_kvhead > 1)\")\n            if IS_SM90 and local:\n                pytest.xfail(\"SM90 backward: local attention not supported yet\")\n            g = torch.randn_like(out)\n            # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)\n            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            # print(f\"dO_O max diff: {(softmax_d - do_o).abs().max().item()}\")\n            # assert (softmax_d - do_o).abs().max().item() <= 1e-5\n            # assert 
dq_accum.abs().max().item() == 0.0\n\n            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())\n            # P = torch.softmax(qk, -1)\n            # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1))\n            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())\n            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())\n            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())\n            # breakpoint()\n\n            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            dq_ref, dk_ref, dv_ref = torch.autograd.grad(\n                out_ref, (q_ref, k_ref, v_ref), g\n            )\n            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)\n            print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n            print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n            print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n            print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n            print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n            print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n            print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n            print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n            print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n            print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n            print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n            print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n            # breakpoint()\n            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dq - dq_ref).abs().max().item() <= rtol * (\n                dq_pt - dq_ref\n            ).abs().max().item() + dq_atol\n        
    dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dk - dk_ref).abs().max().item() <= rtol * (\n                dk_pt - dk_ref\n            ).abs().max().item() + dk_atol\n            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dv - dv_ref).abs().max().item() <= rtol * (\n                dv_pt - dv_ref\n            ).abs().max().item() + dv_atol\n\n            num_iters = 10_000 if INCREASED_TRIALS else 1000\n            for i in range(num_iters):\n                dq2, dk2, dv2, = _flash_attn_bwd(\n                    q, k, v, out, g, lse,\n                    causal=causal,\n                    window_size_left=window_size[0],\n                    window_size_right=window_size[1],\n                    deterministic=True,\n                )\n\n                diff_dq = (dq - dq2).abs()\n                max_idx = diff_dq.argmax()\n                print(f\"dQ max diff: {diff_dq.max().item()}\")\n                print(f\"  at index {max_idx.item()}: dQ={dq.flatten()[max_idx].item()}, dQ2={dq2.flatten()[max_idx].item()}\")\n\n                diff_dk = (dk - dk2).abs()\n                max_idx = diff_dk.argmax()\n                print(f\"dK max diff: {diff_dk.max().item()}\")\n                print(f\"  at index {max_idx.item()}: dK={dk.flatten()[max_idx].item()}, dK2={dk2.flatten()[max_idx].item()}\")\n\n                diff_dv = (dv - dv2).abs()\n                max_idx = diff_dv.argmax()\n                print(f\"dV max diff: {diff_dv.max().item()}\")\n                print(f\"  at index {max_idx.item()}: dV={dv.flatten()[max_idx].item()}, dV2={dv2.flatten()[max_idx].item()}\")\n                \n                # print(f\"dQ max diff with myself: {(dq - dq2).abs().max().item()}\")\n                # print(f\"dK max diff with myself: {(dk - 
dk2).abs().max().item()}\")\n                # print(f\"dV max diff with myself: {(dv - dv2).abs().max().item()}\")\n                # print(f\"dQ mean diff with myself: {(dq - dq2).abs().mean().item()}\")\n                # print(f\"dK mean diff with myself: {(dk - dk2).abs().mean().item()}\")\n                # print(f\"dV mean diff with myself: {(dv - dv2).abs().mean().item()}\")\n                \n                assert torch.equal(dq, dq2)\n                assert torch.equal(dk, dk2)\n                assert torch.equal(dv, dv2)\n\n                print(f\"✅ Iteration {i} passed!\")\n\n\n# @pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"gqa\"])\n# @pytest.mark.parametrize(\"has_learnable_sink\", [False, True])\n@pytest.mark.parametrize(\"has_learnable_sink\", [False])\n# @pytest.mark.parametrize(\"has_qv\", [False, True])\n@pytest.mark.parametrize(\"has_qv\", [False])\n# @pytest.mark.parametrize(\"deterministic\", [False, True])\n@pytest.mark.parametrize(\"deterministic\", [True])\n# @pytest.mark.parametrize(\"softcap\", [0.0, 15.0])\n@pytest.mark.parametrize(\"softcap\", [0.0])\n# @pytest.mark.parametrize(\"local_enum\", [0, 1, 2, 3])\n@pytest.mark.parametrize(\"local_enum\", [0, 1])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n# @pytest.mark.parametrize(\"add_unused_qkv\", [False, True])\n@pytest.mark.parametrize(\"add_unused_qkv\", [False])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [64, 96, 
128])\n@pytest.mark.parametrize(\"d\", [64, 128, 192])\n# @pytest.mark.parametrize(\"d\", [192])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1024, 1024),\n        (2048, 2048),\n    ],\n)\n@pytest.mark.parametrize(\"varlen_mode\", [\"random\", \"third\", \"full\"])\n# @pytest.mark.parametrize(\"varlen_mode\", [\"random\"])\n@pytest.mark.parametrize(\n    \"zero_lengths_q, zero_lengths_k\",\n    [\n        (False, False),\n        (True, False),\n        (False, True),\n        (True, True),\n    ],\n)\ndef test_flash_attn_varlen_output(\n    seqlen_q,\n    seqlen_k,\n    d,\n    add_unused_qkv,\n    causal,\n    local_enum,\n    softcap,\n    deterministic,\n    has_qv,\n    has_learnable_sink,\n    mha_type,\n    dtype,\n    varlen_mode,\n    zero_lengths_q,\n    zero_lengths_k,\n):\n    local = local_enum > 0\n    if local and causal:\n        pytest.skip()\n    is_sm90 = torch.cuda.get_device_capability()[0] == 9\n    if is_sm90 and local:\n        pytest.xfail(\"bwd local attention not supported on sm90\")\n    if is_sm90 and d == 192:\n        pytest.xfail(\"headdim 192 not supported on sm90\")\n    if (\n        causal or local\n    ):  # Right now reference only supports causal attention with seqlen_k == seqlen_q\n        seqlen_k = seqlen_q\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))\n    batch_size = 49 if seqlen_q <= 1024 else 7\n    nheads = 6\n    # nheads = 1\n    nheads_kv = nheads if mha_type == \"mha\" else (3 if mha_type == \"gqa\" else 1)\n    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype\n    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])\n    # dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])\n    dv_vals = [128] if d == 192 else [d]\n    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0]\n    
attention_chunk_vals = [0]\n    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):\n        q_ref = torch.randn(\n            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref\n        )\n        if softcap > 0.0:\n            # Ensure the values of qk are at least within softcap range.\n            q_ref = (q_ref * softcap / 4).detach().requires_grad_()\n        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()\n        k_ref = (\n            torch.randn(\n                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref\n            )\n            .to(dtype)\n            .to(dtype_ref)\n            .requires_grad_()\n        )\n        v_ref = (\n            torch.randn(\n                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref\n            )\n            .to(dtype)\n            .to(dtype_ref)\n            .requires_grad_()\n        )\n        if has_qv:\n            qv_ref = (\n                torch.randn(\n                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref\n                )\n                .to(dtype)\n                .to(dtype_ref)\n            )\n        else:\n            qv_ref = None\n        # Put window_size after QKV randn so that window_size changes from test to test\n        window_size = (\n            (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()\n        )\n        if local_enum == 2:\n            window_size = (None, window_size[1])\n        elif local_enum == 3:\n            window_size = (window_size[0], None)\n        if local:\n            print(\"window size = \", window_size)\n        if has_learnable_sink:\n            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)\n        else:\n            learnable_sink = None\n        if dtype == torch.float8_e4m3fn:\n            q_descale, k_descale, v_descale = [\n                torch.rand(batch_size, nheads_kv, 
device=device, dtype=torch.float32)\n                * 2\n                for _ in range(3)\n            ]\n        else:\n            q_descale, k_descale, v_descale = None, None, None\n        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]\n        qv = qv_ref.detach() if has_qv else None\n        query_padding_mask = generate_random_padding_mask(\n            seqlen_q,\n            batch_size,\n            device,\n            mode=varlen_mode,\n            zero_lengths=zero_lengths_q,\n        )\n        key_padding_mask = generate_random_padding_mask(\n            seqlen_k,\n            batch_size,\n            device,\n            mode=varlen_mode,\n            zero_lengths=zero_lengths_k,\n        )\n        def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):\n            if add_unused:\n                another_mask = generate_random_padding_mask(max_seq_len, bs, device)\n                attn_mask = torch.logical_and(padding_mask, another_mask)\n                unused_mask = torch.logical_xor(\n                    torch.logical_or(padding_mask, another_mask), attn_mask\n                )\n            else:\n                attn_mask = padding_mask\n                unused_mask = None\n            return attn_mask, unused_mask\n\n        query_padding_mask, query_unused_mask = _gen_unused_masks(\n            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device\n        )\n        # query_padding_mask[:] = True\n        # query_unused_mask = None\n        key_padding_mask, key_unused_mask = _gen_unused_masks(\n            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device\n        )\n\n        if causal or local:\n            key_padding_mask = query_padding_mask\n\n        (\n            q_unpad,\n            k_unpad,\n            v_unpad,\n            qv_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            
max_seqlen_q,\n            max_seqlen_k,\n            q,\n            k,\n            v,\n            qv,\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        ) = generate_qkv(\n            q,\n            k,\n            v,\n            query_padding_mask,\n            key_padding_mask,\n            qv=qv,\n            kvpacked=False,\n            query_unused_mask=query_unused_mask,\n            key_unused_mask=key_unused_mask,\n        )\n        print(\"cu_seqlens_q = \", cu_seqlens_q)\n        print(\"cu_seqlens_k = \", cu_seqlens_k)\n        q_unpad, k_unpad, v_unpad = [\n            x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)\n        ]\n        out_ref, attn_ref = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale,\n            k_descale=k_descale,\n            v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            learnable_sink=learnable_sink,\n            softcap=softcap,\n        )\n        out_pt, attn_pt = attention_ref(\n            q_ref,\n            k_ref,\n            v_ref,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            qv=qv_ref,\n            q_descale=q_descale,\n            k_descale=k_descale,\n            v_descale=v_descale,\n            window_size=window_size,\n            attention_chunk=attention_chunk,\n            learnable_sink=learnable_sink,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,\n        )\n\n        print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n        print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n 
       if query_unused_mask is not None:\n            q_zero_masking = rearrange(query_unused_mask, \"b s -> b s 1 1\")\n\n        # Numerical error if we just do any arithmetic on out_ref\n        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()\n        rtol = 2 if softcap == 0.0 else 3\n\n        out_unpad, lse = flash_attn_varlen_func(\n            q_unpad,\n            k_unpad,\n            v_unpad,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            # max_seqlen_k,\n            # seqused_q=seqused_q,\n            # seqused_k=seqused_k,\n            max_seqlen_q=seqlen_q,\n            max_seqlen_k=seqlen_k,\n            causal=causal,\n            # qv=qv_unpad,\n            # q_descale=q_descale,\n            # k_descale=k_descale, v_descale=v_descale,\n            window_size=window_size,\n            # attention_chunk=attention_chunk,\n            learnable_sink=learnable_sink,\n            softcap=softcap,\n            num_splits=1,\n            pack_gqa=False,\n            deterministic=deterministic,\n        )\n        out = output_pad_fn(out_unpad)\n        if query_unused_mask is not None:\n            out.masked_fill_(q_zero_masking, 0.0)\n        print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n        print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n        # if not causal:\n        #     print(f\"LSE max diff: {(lse - lse_ref).abs().max().item()}\")\n        # breakpoint()\n\n        # Check that FlashAttention's numerical error is at most 3x the numerical error\n        # of a Pytorch implementation.\n        assert (out - out_ref).abs().max().item() <= rtol * (\n            out_pt - out_ref\n        ).abs().max().item() + fwd_atol\n\n        if (\n            dtype != torch.float8_e4m3fn\n            and not has_qv\n            and not dv > 256\n            and not attention_chunk != 0\n            and ((dv == d and d <= 128) or (d == 192 and dv == 
128))\n            and not has_learnable_sink\n            and not is_sm90\n            # and False\n        ):\n            g_unpad = torch.randn_like(out_unpad)\n            # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)\n            # import flash_attn_3_cuda\n            # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(\n            #     g_unpad,\n            #     q_unpad,\n            #     k_unpad,\n            #     v_unpad,\n            #     out_unpad,\n            #     lse,\n            #     None,\n            #     None,\n            #     None,\n            #     cu_seqlens_q,\n            #     cu_seqlens_k,\n            #     None, None,\n            #     max_seqlen_q,\n            #     max_seqlen_k,\n            #     d ** (-0.5),\n            #     causal,\n            #     window_size[0], window_size[1],\n            #     softcap,\n            #     deterministic,\n            #     0,  # sm_margin\n            # )\n            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(\n                out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad\n            )\n            dq = dq_pad_fn(dq_unpad)\n            dk = dk_pad_fn(dk_unpad)\n            dv = dk_pad_fn(dv_unpad)\n            if key_unused_mask is not None:\n                k_zero_masking = rearrange(key_unused_mask, \"b s -> b s 1 1\")\n                dk.masked_fill_(k_zero_masking, 0.0)\n                dv.masked_fill_(k_zero_masking, 0.0)\n            if query_unused_mask is not None:\n                dq.masked_fill_(q_zero_masking, 0.0)\n            # print(f\"dO_O max diff: {(softmax_d - do_o).abs().max().item()}\")\n            # assert (softmax_d - do_o).abs().max().item() <= 1e-5\n            # assert dq_accum.abs().max().item() == 0.0\n            g = output_pad_fn(g_unpad)\n\n            # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()\n            # qk = torch.masked_fill(qk, 
rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())\n            # P = torch.softmax(qk, -1)\n            # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))\n            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())\n            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())\n            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())\n\n            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)\n            dq_ref, dk_ref, dv_ref = torch.autograd.grad(\n                out_ref, (q_ref, k_ref, v_ref), g\n            )\n            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)\n            print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n            print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n            print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n            print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n            print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n            print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n            print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n            print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n            print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n            print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n            print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n            print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n            # breakpoint()\n            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dq - dq_ref).abs().max().item() <= rtol * (\n                dq_pt - dq_ref\n            
).abs().max().item() + dq_atol\n            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dk - dk_ref).abs().max().item() <= rtol * (\n                dk_pt - dk_ref\n            ).abs().max().item() + dk_atol\n            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (\n                0 if softcap == 0 else 3e-4\n            )\n            assert (dv - dv_ref).abs().max().item() <= rtol * (\n                dv_pt - dv_ref\n            ).abs().max().item() + dv_atol\n\n            num_iters = 10_000 if INCREASED_TRIALS else 1000\n\n            for i in range(num_iters):\n                dq_unpad2, dk_unpad2, dv_unpad2 = _flash_attn_bwd(\n                    q_unpad, k_unpad, v_unpad, out_unpad, g_unpad, lse,\n                    causal=causal,\n                    window_size_left=window_size[0],\n                    window_size_right=window_size[1],\n                    deterministic=True,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k=cu_seqlens_k,\n                    max_seqlen_q=seqlen_q,\n                    max_seqlen_k=seqlen_k,\n                )\n\n                diff_dq = (dq_unpad - dq_unpad2).abs()\n                max_idx = diff_dq.argmax()\n                if i % 100 == 0:\n                    print(f\"dQ max diff: {diff_dq.max().item()}\")\n                    print(f\"  at index {max_idx.item()}: dQ={dq_unpad.flatten()[max_idx].item()}, dQ2={dq_unpad2.flatten()[max_idx].item()}\")\n\n                diff_dk = (dk_unpad - dk_unpad2).abs()\n                max_idx = diff_dk.argmax()\n                if i % 100 == 0:\n                    print(f\"dK max diff: {diff_dk.max().item()}\")\n                    print(f\"  at index {max_idx.item()}: dK={dk_unpad.flatten()[max_idx].item()}, dK2={dk_unpad2.flatten()[max_idx].item()}\")\n\n                diff_dv = (dv_unpad - dv_unpad2).abs()\n            
    max_idx = diff_dv.argmax()\n                if i % 100 == 0:\n                    print(f\"dV max diff: {diff_dv.max().item()}\")\n                    print(f\"  at index {max_idx.item()}: dV={dv_unpad.flatten()[max_idx].item()}, dV2={dv_unpad2.flatten()[max_idx].item()}\")\n                \n                assert torch.equal(dq_unpad, dq_unpad2)\n                assert torch.equal(dk_unpad, dk_unpad2)\n                assert torch.equal(dv_unpad, dv_unpad2)\n\n                if i % 100 == 0:\n                    print(f\"✅ Iteration {i} passed!\")\n"
  },
  {
    "path": "tests/cute/test_flash_attn_varlen.py",
    "content": "from typing import Optional\nimport pytest\n\nimport torch\nimport torch.nn.functional as F\nfrom flash_attn.cute import flash_attn_varlen_func\n\n@pytest.mark.parametrize(\"B\", [1, 7, 20])\n@pytest.mark.parametrize(\"H\", [1, 4, 6])\n@pytest.mark.parametrize(\"D\", [64, 128])\n@pytest.mark.parametrize(\"min_seq_len\", [1, 32, 128])\n@pytest.mark.parametrize(\"max_seq_len\", [8, 64, 2048])\n@pytest.mark.parametrize(\"causal\", [True, False])\n@pytest.mark.parametrize(\"softmax_scale\", [None, 0.1])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\ndef test_varlen(\n    B,\n    H,\n    D,\n    min_seq_len,\n    max_seq_len,\n    causal,\n    softmax_scale,\n    dtype,\n    mha_type,\n):\n    if min_seq_len > max_seq_len:\n        pytest.skip(\"Skipping min_seq_len > max_seq_len\")\n\n    q, k, v, cu_seqlens_q, cu_seqlens_k, total_q, total_k = generate_varlen_args(\n        batch_size=B,\n        n_heads=H,\n        d_head=D,\n        min_len=min_seq_len,\n        max_len=max_seq_len,\n        mha_type=mha_type,\n        dtype=dtype\n    )\n\n    ok = check_varlen_vs_torch_flash(\n        q, k, v,\n        cu_seqlens_q, cu_seqlens_k,\n        total_q=total_q, total_k=total_k,\n        softmax_scale=softmax_scale,\n        causal=causal,\n        mha_type=mha_type,\n    )\n    assert ok\n\ndef check_varlen_vs_torch_flash(\n    q, k, v,\n    cu_seqlens_q=None,\n    cu_seqlens_k=None,\n    seqused_q=None,\n    seqused_k=None,\n    total_q=None,\n    total_k=None,\n    softmax_scale=None,\n    causal=True,\n    mha_type='mha',\n    softcap=0.0,\n    atol=3e-2,\n    rtol=3e-2,\n):\n    assert q.requires_grad and k.requires_grad and v.requires_grad, \"Set requires_grad=True on inputs\"\n\n    def clone_like(t):\n        c = t.clone().detach().requires_grad_(True)\n        return c\n\n    q_fa, k_fa, v_fa = map(clone_like, (q, k, v))\n    q_t,  k_t,  v_t  = 
map(clone_like, (q, k, v))\n\n    if cu_seqlens_q is not None:\n        cu_seqlens_q_fa = cu_seqlens_q.clone()\n        cu_seqlens_q_t = cu_seqlens_q.clone()\n    else:\n        cu_seqlens_q_fa = None\n        cu_seqlens_q_t = None\n\n    if cu_seqlens_k is not None:\n        cu_seqlens_k_fa = cu_seqlens_k.clone()\n        cu_seqlens_k_t = cu_seqlens_k.clone()\n    else:\n        cu_seqlens_k_fa = None\n        cu_seqlens_k_t = None\n\n    out_fa, lse_fa = flash_attn_varlen_func(\n        q_fa, k_fa, v_fa,\n        cu_seqlens_q=cu_seqlens_q_fa,\n        cu_seqlens_k=cu_seqlens_k_fa,\n        seqused_q=seqused_q,\n        seqused_k=seqused_k,\n        softmax_scale=(1.0 / q.shape[-1]**0.5) if softmax_scale is None else softmax_scale,\n        causal=causal,\n        window_size=(None, None),\n        learnable_sink=None,\n        softcap=softcap,\n        pack_gqa=None,\n    )\n\n    out_t = torch_flash_ref(\n        q_t, k_t, v_t,\n        cu_seqlens_q=cu_seqlens_q_t,\n        cu_seqlens_k=cu_seqlens_k_t,\n        seqused_q=seqused_q,\n        seqused_k=seqused_k,\n        total_q=total_q,\n        total_k=total_k,\n        softmax_scale=softmax_scale,\n        causal=causal,\n        mha_type=mha_type,\n    )\n\n\n    ok_fwd = torch.allclose(out_fa.float(), out_t.float(), atol=atol, rtol=rtol)\n    if not ok_fwd:\n        return False\n\n    # Use the same upstream gradient to compare backward paths\n    grad_out = torch.randn_like(out_fa)\n\n    grad_fa = clone_like(grad_out)\n    grad_t = clone_like(grad_out)\n\n    # Cute bwd\n    out_fa.backward(grad_fa, retain_graph=False)\n    dq_fa, dk_fa, dv_fa = q_fa.grad, k_fa.grad, v_fa.grad\n\n    # Ref bwd\n    out_t.backward(grad_t, retain_graph=False)\n    dq_t, dk_t, dv_t = q_t.grad, k_t.grad, v_t.grad\n\n    # mean_ok_q = _stats(\"dQ\", dq_fa, dq_t, atol=atol, rtol=rtol)\n    # mean_ok_k = _stats(\"dK\", dk_fa, dk_t, atol=atol, rtol=rtol)\n    # mean_ok_v = _stats(\"dV\", dv_fa, dv_t, atol=atol, rtol=rtol)\n\n    
# return mean_ok_q and mean_ok_k and mean_ok_v\n\n    ok_q = torch.allclose(dq_fa.float(), dq_t.float(), atol=atol, rtol=rtol)\n    ok_k = torch.allclose(dk_fa.float(), dk_t.float(), atol=atol, rtol=rtol)\n    ok_v = torch.allclose(dv_fa.float(), dv_t.float(), atol=atol, rtol=rtol)\n    # print(f\"Close? dQ={ok_q}, dK={ok_k}, dV={ok_v}\")\n    return ok_q and ok_k and ok_v\n\ndef generate_varlen_args(\n    batch_size=8,\n    n_heads=16,\n    d_head=128,\n    min_len=32,\n    max_len=64,\n    mha_type=\"mha\",\n    dtype=torch.bfloat16,\n):\n\n    torch.manual_seed(0)\n    device = \"cuda\"\n\n    assert mha_type in [\"mha\", \"mqa\", \"gqa\"]\n\n    lens_q = torch.randint(low=min_len, high=max_len + 1, size=(batch_size,))\n    lens_k = lens_q.clone()\n\n    cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32), lens_q.cumsum(0)])\n    cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32), lens_k.cumsum(0)])\n\n    # Plain Python ints, matching the `total_q: int` / `total_k: int` hints of torch_flash_ref\n    total_q = int(cu_seqlens_q[-1])\n    total_k = int(cu_seqlens_k[-1])\n\n    cu_seqlens_q = cu_seqlens_q.contiguous().to(dtype=torch.int32, device=device)\n    cu_seqlens_k = cu_seqlens_k.contiguous().to(dtype=torch.int32, device=device)\n\n    if mha_type == \"gqa\":\n        H = 3 * n_heads\n        H_kv = n_heads\n    elif mha_type == \"mha\":\n        H = H_kv = n_heads\n    else:  # MQA\n        H = n_heads\n        H_kv = 1\n\n    d_head_v = d_head\n\n    q = torch.randn(total_q, H, d_head, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(total_k, H_kv, d_head, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(total_k, H_kv, d_head_v, device=device, dtype=dtype, requires_grad=True)\n\n    return q, k, v, cu_seqlens_q, cu_seqlens_k, total_q, total_k\n\n# Simple for loop over batch dim implementation\ndef torch_flash_ref(\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        cu_seqlens_q: Optional[torch.Tensor] = None,\n        cu_seqlens_k: Optional[torch.Tensor] = None,\n        total_q: int = 
0,\n        total_k: int = 0,\n        softmax_scale: Optional[float] = None,\n        causal: bool = False,\n        **kwargs\n    ):\n\n    \"\"\"\n    q: (total_q, H, d) if cu_seqlens_q is not None, otherwise (B, L, H, d)\n    k: (total_k, H_kv, d) if cu_seqlens_k is not None, otherwise (B, L, H_kv, d)\n    v: (total_k, H_kv, d_v) if cu_seqlens_k is not None, otherwise (B, L, H_kv, d_v)\n    cu_seqlens_q: (B+1,) int32, cumulative\n    cu_seqlens_k: (B+1,) int32, cumulative\n\n    seqused_q: (B,) int32 (accepted via **kwargs, ignored by this reference)\n    seqused_k: (B,) int32 (accepted via **kwargs, ignored by this reference)\n    Returns:\n        out packed like q: (total_q, H, d_v)\n    \"\"\"\n\n    if cu_seqlens_q is not None:\n        assert cu_seqlens_q.dim() == 1\n        assert total_q == q.shape[0]\n        assert q.dim() == 3\n        H = q.shape[1]\n        B = cu_seqlens_q.shape[0] - 1\n    else:\n        assert q.dim() == 4\n        H = q.shape[2]\n        B = q.shape[0]\n\n    if cu_seqlens_k is not None:\n        assert cu_seqlens_k.dim() == 1\n        assert total_k == k.shape[0] == v.shape[0]\n        assert k.dim() == v.dim() == 3\n        H_kv = k.shape[1]\n        B_kv = cu_seqlens_k.shape[0] - 1\n    else:\n        assert k.dim() == v.dim() == 4\n        assert k.shape[0] == v.shape[0]\n        H_kv = k.shape[2]\n        B_kv = k.shape[0]\n\n    d = q.shape[-1]\n    d_v = v.shape[-1]\n\n    assert H_kv == v.shape[-2]\n    assert d == k.shape[-1]\n    assert B == B_kv\n\n    assert q.device == k.device == v.device\n    assert q.is_floating_point() and k.is_floating_point() and v.is_floating_point()\n\n    device = q.device\n    dtype = q.dtype\n\n    # Host copies of the cumulative lengths; None selects the padded (B, L, H, d) path\n    hcseq_q = cu_seqlens_q.to(device='cpu') if cu_seqlens_q is not None else None\n    hcseq_k = cu_seqlens_k.to(device='cpu') if cu_seqlens_k is not None else None\n\n    outs = []\n    for b in range(B):\n        if hcseq_q is not None:\n            q_start, q_end = int(hcseq_q[b]), int(hcseq_q[b+1])\n            qb = q[q_start:q_end]\n        else:\n            qb = q[b]\n\n        if hcseq_k is not None:\n            k_start, k_end = int(hcseq_k[b]), 
int(hcseq_k[b+1])\n            kb = k[k_start:k_end]\n            vb = v[k_start:k_end]\n        else:\n            kb = k[b]\n            vb = v[b]\n\n        qb = qb.permute(1, 0, 2).unsqueeze(0)\n        kb = kb.permute(1, 0, 2).unsqueeze(0)\n        vb = vb.permute(1, 0, 2).unsqueeze(0)\n\n        ob = F.scaled_dot_product_attention(\n            qb, kb, vb,\n            attn_mask=None,\n            dropout_p=0.0,\n            is_causal=causal,\n            scale=softmax_scale,\n            enable_gqa=H_kv!=H\n        )\n\n        ob = ob.squeeze(0).permute(1, 0, 2).contiguous()\n        outs.append(ob)\n\n    if cu_seqlens_q is not None:\n        out = torch.cat(outs, dim=0).to(device=device, dtype=dtype)\n    else:\n        out = torch.stack(outs, dim=0).to(device=device, dtype=dtype)\n    return out\n\n@torch.no_grad()\ndef _stats(name, a, b, atol, rtol):\n    diff = (a - b).float()\n    mean_abs = diff.abs().mean().item()\n    # Convert the ratio (not just the denominator) to a Python float so the function returns a plain bool\n    mean_rel = (diff.abs().mean() / b.abs().clamp_min(1e-6).mean()).item()\n    print(f\"{name}: mean_abs={mean_abs:.4e}, mean_rel={mean_rel:.4e}, sum_fa={a.sum()}, sum_ref={b.sum()}\")\n    return mean_abs < atol and mean_rel < rtol\n"
  },
  {
    "path": "tests/cute/test_mask_mod.py",
    "content": "# mask mod test script\n# REFACTORED to use _flash_attn_fwd as the kernel entrypoint\n#\n# Test Organization:\n# - test_static_masks: Fast tests for masks that don't need per-seqlen compilation\n#   (identity, document, block_diagonal, etc.) with comprehensive seqlen coverage\n# - test_parameterized_masks: Slower tests for masks that require recompilation per\n#   seqlen pair (causal, block_causal, sliding_window) with reduced seqlen coverage\n#\n# Usage:\n#   pytest test_mask_mod.py::test_static_masks         # Run only fast tests\n#   pytest test_mask_mod.py::test_parameterized_masks  # Run only slow tests\n#   pytest test_mask_mod.py                            # Run all tests\n\nimport math\nfrom unittest import mock\n\nimport pytest\nimport torch\nimport cutlass\nimport cutlass.cute as cute\nfrom torch.nn.attention.flex_attention import create_block_mask, flex_attention\nimport torch.nn.functional as F\n\nfrom flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd\nfrom flash_attn.cute.block_sparsity import (\n    BlockSparseTensorsTorch,\n    fast_sampling,\n    normalize_block_sparse_config,\n)\nfrom flash_attn.cute.cache_utils import get_jit_cache\nfrom flash_attn.cute import utils\nfrom mask_mod_definitions import get_mask_pair, random_doc_id_tensor\nCOMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0]\n\n\n@pytest.fixture(autouse=True)\ndef reset_torch_state():\n    \"\"\"Reset torch dynamo/compile state between tests to avoid state pollution.\"\"\"\n    torch._dynamo.reset()\n    torch.cuda.empty_cache()\n\n    yield\n\n    torch._dynamo.reset()\n    torch.cuda.empty_cache()\n\ndef create_tensors(\n    batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype\n):\n    device = \"cuda\"\n    q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype)\n    k = torch.randn(\n        batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype\n    )\n    v = torch.randn(\n     
   batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype\n    )\n    out = torch.empty(\n        batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype\n    )\n    lse = torch.empty(batch_size, nheads, seqlen_q, device=device, dtype=torch.float32)\n\n    return {\n        \"q\": q,\n        \"k\": k,\n        \"v\": v,\n        \"out\": out,\n        \"lse\": lse,\n    }\n\n\ndef compute_reference_flex_attn(tensors, mask_mod_flex, block_size: tuple[int, int] | None = None):\n    \"\"\"Compute reference using flex_attention for custom mask_mods\"\"\"\n    batch_size, seqlen_q, nheads, headdim = tensors[\"q\"].shape\n    _, seqlen_k, nheads_kv, _ = tensors[\"k\"].shape\n\n    q = tensors[\"q\"].transpose(1, 2)\n    k = tensors[\"k\"].transpose(1, 2)\n    v = tensors[\"v\"].transpose(1, 2)\n\n    if nheads != nheads_kv:\n        repeat_factor = nheads // nheads_kv\n        k = k.repeat_interleave(repeat_factor, dim=1)\n        v = v.repeat_interleave(repeat_factor, dim=1)\n\n    scale = 1.0 / math.sqrt(headdim)\n\n    # Handle identity (no masking) case\n    if mask_mod_flex is None:\n        out_ref = F.scaled_dot_product_attention(q, k, v, scale=scale)\n        return out_ref.transpose(1, 2).contiguous()\n\n    block_mask_kwargs = {}\n    if block_size is not None:\n        block_mask_kwargs[\"BLOCK_SIZE\"] = block_size\n\n    block_mask = create_block_mask(\n        mask_mod_flex,\n        B=batch_size,\n        H=nheads,\n        Q_LEN=seqlen_q,\n        KV_LEN=seqlen_k,\n        device=q.device,\n        **block_mask_kwargs,\n    )\n    out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale, enable_gqa=True)\n    return out_ref.transpose(1, 2).contiguous()\n\n\ndef get_coarse_block_mask_pair(sparse_tile_m: int, tile_n: int, last_block: int):\n    @fast_sampling\n    @cute.jit\n    def _cute_coarse_block_mask(\n        batch: cute.TensorSSA,\n        head: cute.TensorSSA,\n        m_idx: cute.TensorSSA,\n        n_idx: 
cute.TensorSSA,\n        seqlen_info,\n        aux_tensors,\n    ) -> cute.TensorSSA:\n        sparse_tile_m_ssa = utils.scalar_to_ssa(sparse_tile_m, cutlass.Int32)\n        tile_n_ssa = utils.scalar_to_ssa(tile_n, cutlass.Int32)\n        q_block = m_idx // sparse_tile_m_ssa\n        n_block = n_idx // tile_n_ssa\n        zero = utils.scalar_to_ssa(0, cutlass.Int32)\n        one = utils.scalar_to_ssa(1, cutlass.Int32)\n        last = utils.scalar_to_ssa(last_block, cutlass.Int32)\n        return ((q_block == zero) & (n_block == zero)) | ((q_block == one) & (n_block == last))\n\n    def _flex_coarse_block_mask(b, h, q_idx, kv_idx):\n        q_block = q_idx // sparse_tile_m\n        n_block = kv_idx // tile_n\n        return ((q_block == 0) & (n_block == 0)) | ((q_block == 1) & (n_block == last_block))\n\n    return _cute_coarse_block_mask, _flex_coarse_block_mask\n\n\nSEQLEN_PAIRS_COMPREHENSIVE = [\n    (1, 1),\n    (64, 128),\n    (128, 192),\n    (256, 256),\n    (239, 1),\n    (799, 3),\n    (113, 203),\n    (113, 128),\n    (128, 217),\n    (113, 211),\n    (108, 256),\n    (256, 512),\n    (384, 256),\n    (640, 128),\n    (512, 256),\n    (1024, 1024),\n    (1023, 1024),\n    (1024, 1023),\n    (4096, 4096),\n    (4224, 4224),\n]\n\nSEQLEN_PAIRS_SMOKE = [\n    (128, 128),\n    (256, 256),\n    (113, 203),\n    (1024, 1024),\n    (128, 8192)\n]\n\n\ndef _run_mask_test(\n    seqlen_q,\n    seqlen_k,\n    nheads,\n    kv_mode,\n    headdim,\n    dtype,\n    mask_name,\n    window_size,\n    window_left,\n    window_right,\n    tile_m,\n    tile_n,\n    use_block_sparsity,\n    needs_backward=False,\n):\n    torch.manual_seed(42)\n\n    if mask_name == \"sliding_window\":\n        assert window_size is not None, (\n            \"window_size must be specified for sliding_window\"\n        )\n        if seqlen_q > seqlen_k:\n            pytest.skip(\n                f\"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for sliding_window\"\n            )\n\n    
# Determine nheads_kv based on mode\n    if kv_mode == \"mha\":\n        nheads_kv = nheads\n        pack_gqa = False\n    elif kv_mode == \"gqa\":\n        if COMPUTE_CAPABILITY < 9:\n            pytest.xfail(\"pack_gqa requires SM90+\")\n        nheads_kv = nheads // 4\n        pack_gqa = True\n    elif kv_mode == \"mqa\":\n        nheads_kv = 1\n        pack_gqa = False\n    else:\n        raise ValueError(f\"Unknown kv_mode: {kv_mode}\")\n\n    batch_size = 1\n    headdim_v = headdim\n\n    aux_tensors_arg = None\n    mask_mod_cute, mask_mod_flex = get_mask_pair(\n        mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size\n    )\n    if mask_name == \"document\":\n        doc_len = max(seqlen_q, seqlen_k)\n        doc_ids = random_doc_id_tensor(nheads, batch_size, doc_len, device=\"cuda\").to(\n            dtype=torch.int32, device=\"cuda\"\n        )\n        original_flex_mask = mask_mod_flex\n\n        def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids):\n            return original_flex_mask(b, h, q_idx, kv_idx, doc_ids)\n\n        aux_tensors_arg = [doc_ids]\n    elif mask_name == \"ima\":\n        bias_threshold = (seqlen_k // 4) * 3\n        bias = torch.full((seqlen_k,), bias_threshold, dtype=torch.int32, device=\"cuda\")\n        original_flex_mask = mask_mod_flex\n\n        def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias):\n            return original_flex_mask(b, h, q_idx, kv_idx, bias)\n\n        aux_tensors_arg = [bias]\n    causal = False\n\n    if causal and seqlen_k < seqlen_q:\n        pytest.skip(\"causal masking requires seqlen_k >= seqlen_q\")\n\n    tensors = create_tensors(\n        batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype\n    )\n\n    # SM100 uses sparse_tile_m = 2*tile_m to match forward q_stage=2 pipelining\n    if COMPUTE_CAPABILITY == 10:\n        sparse_tile_m = 2 * tile_m\n    else:\n        sparse_tile_m = tile_m\n\n    block_mask_nheads = 1 if pack_gqa else nheads\n   
 bm = create_block_mask(\n        mask_mod_flex,\n        batch_size,\n        block_mask_nheads,\n        seqlen_q,\n        seqlen_k,\n        device=\"cuda\",\n        BLOCK_SIZE=(sparse_tile_m, tile_n),\n    )\n    (\n        _seq_q,\n        _seq_k,\n        kv_mask_cnt,\n        kv_mask_idx,\n        full_kv_cnt,\n        full_kv_idx,\n        q_mask_cnt,\n        q_mask_idx,\n        full_q_cnt,\n        full_q_idx,\n        *_,\n    ) = bm.as_tuple()\n\n    # SM90 block-sparse backward expects BlockMask granularity (128, 128) regardless of fwd tiling.\n    sparse_tile_m_bwd = sparse_tile_m\n    tile_n_bwd = tile_n\n    if COMPUTE_CAPABILITY == 9 and use_block_sparsity and (sparse_tile_m, tile_n) != (128, 128):\n        bm_bwd = create_block_mask(\n            mask_mod_flex,\n            batch_size,\n            nheads,\n            seqlen_q,\n            seqlen_k,\n            device=\"cuda\",\n            BLOCK_SIZE=(128, 128),\n        )\n        (\n            _seq_q,\n            _seq_k,\n            _kv_mask_cnt,\n            _kv_mask_idx,\n            _full_kv_cnt,\n            _full_kv_idx,\n            q_mask_cnt,\n            q_mask_idx,\n            full_q_cnt,\n            full_q_idx,\n            *_,\n        ) = bm_bwd.as_tuple()\n        sparse_tile_m_bwd = 128\n        tile_n_bwd = 128\n\n    softmax_scale = 1.0 / math.sqrt(headdim)\n\n    block_sparse_mask_fwd = (\n        BlockSparseTensorsTorch(\n            mask_block_cnt=kv_mask_cnt,\n            mask_block_idx=kv_mask_idx,\n            full_block_cnt=full_kv_cnt,\n            full_block_idx=full_kv_idx,\n            block_size=(sparse_tile_m, tile_n),\n        )\n        if use_block_sparsity\n        else None\n    )\n\n    # Backward uses Q-direction (transposed) sparse tensors\n    block_sparse_mask_bwd = (\n        BlockSparseTensorsTorch(\n            mask_block_cnt=q_mask_cnt,\n            mask_block_idx=q_mask_idx,\n            full_block_cnt=full_q_cnt,\n            
full_block_idx=full_q_idx,\n            block_size=(sparse_tile_m_bwd, tile_n_bwd),\n        )\n        if use_block_sparsity\n        else None\n    )\n\n    out_tuple = _flash_attn_fwd(\n        q=tensors[\"q\"],\n        k=tensors[\"k\"],\n        v=tensors[\"v\"],\n        out=tensors[\"out\"],\n        lse=tensors[\"lse\"],\n        cu_seqlens_q=None,\n        cu_seqlens_k=None,\n        seqused_q=None,\n        seqused_k=None,\n        page_table=None,\n        softmax_scale=softmax_scale,\n        causal=causal,\n        softcap=None,\n        window_size_left=window_left,\n        window_size_right=window_right,\n        learnable_sink=None,\n        tile_mn=(tile_m, tile_n),\n        pack_gqa=pack_gqa,\n        _arch=None,\n        score_mod=None,\n        mask_mod=mask_mod_cute,\n        block_sparse_tensors=block_sparse_mask_fwd,\n        return_lse=True,\n        aux_tensors=aux_tensors_arg,\n    )\n\n    out_cute = out_tuple[0]\n    lse_cute = out_tuple[1]\n    tensors_fp32 = {\n        k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v\n        for k, v in tensors.items()\n    }\n\n    block_size = (tile_m, tile_n)\n    out_ref_fp32 = compute_reference_flex_attn(tensors_fp32, mask_mod_flex, block_size)\n    out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, block_size)\n    out_pt = out_ref.clone()\n\n    # Check for invalid values\n    assert out_cute.shape == out_ref_fp32.shape == out_ref.shape\n    assert not torch.isnan(out_cute).any()\n    assert not torch.isnan(out_ref_fp32).any()\n    assert torch.isfinite(out_cute).all()\n    assert torch.isfinite(out_ref_fp32).all()\n\n    # Compute numerical tolerance (matching flash attention tests)\n    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()\n    rtol = 2\n\n    ref_error = (out_ref - out_ref_fp32).abs().max().item()\n    pt_error = (out_pt - out_ref_fp32).abs().max().item()\n    cute_error = (out_cute - out_ref_fp32).abs().max().item()\n\n    
mask_desc = f\"mask_mod={mask_name}\"\n    if mask_name == \"sliding_window\" and window_size is not None:\n        mask_desc += f\"(w={window_size})\"\n\n    print(\n        f\"\\n{mask_desc} @ Q={seqlen_q}, K={seqlen_k}, H={nheads}/{nheads_kv} ({kv_mode}), \"\n        f\"D={headdim}, M={tile_m}, N={tile_n}\"\n    )\n    print(\"  Reference implementation: FlexAttention\")\n    print(f\"  Reference vs FP32: {ref_error:.2e}\")\n    print(f\"  PyTorch vs FP32: {pt_error:.2e}\")\n    print(f\"  Kernel vs FP32: {cute_error:.2e}\")\n    print(f\"  Tolerance: rtol={rtol} * {pt_error:.2e} + {fwd_atol:.2e}\")\n    print(f\"  Error ratio: {cute_error / max(pt_error, 1e-10):.2f}\")\n\n    # Debug: show some sample values if error is large\n    if cute_error > 1e-2:\n        print(f\"  DEBUG: Sample kernel output: {out_cute[0, 0, 0, :5]}\")\n        print(f\"  DEBUG: Sample reference output: {out_ref_fp32[0, 0, 0, :5]}\")\n        print(f\"  DEBUG: Max diff location: {(out_cute - out_ref_fp32).abs().argmax()}\")\n        max_diff_idx = (out_cute - out_ref_fp32).abs().argmax()\n        max_diff_coords = torch.unravel_index(max_diff_idx, out_cute.shape)\n        print(f\"  DEBUG: Max diff at coords: {max_diff_coords}\")\n        print(f\"  DEBUG: Kernel value: {out_cute[max_diff_coords]:.6f}\")\n        print(f\"  DEBUG: Reference value: {out_ref_fp32[max_diff_coords]:.6f}\")\n\n    # Use the same assertion logic as FlashAttention tests\n    assert cute_error <= rtol * pt_error + fwd_atol, (\n        f\"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}\"\n    )\n\n    if needs_backward:\n        q = tensors[\"q\"]\n        k = tensors[\"k\"]\n        v = tensors[\"v\"]\n\n        # Create grad_out once and reuse\n        grad_out = torch.randn_like(out_cute)\n\n        # Create block_mask for flex reference\n        flex_block_mask = create_block_mask(\n            mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k,\n            
device=\"cuda\", BLOCK_SIZE=(tile_m, tile_n),\n        )\n\n        dq_cute, dk_cute, dv_cute = run_cute_mask_bwd(\n            q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute,\n            block_sparse_mask_bwd=block_sparse_mask_bwd, tile_m=tile_m, tile_n=tile_n,\n            aux_tensors=aux_tensors_arg,\n        )\n        _, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd(\n            q, k, v, flex_block_mask, grad_out, dtype=torch.float32\n        )\n        _, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(\n            q, k, v, flex_block_mask, grad_out\n        )\n\n        # Check for invalid values\n        assert not torch.isnan(dq_cute).any(), \"dQ contains NaN\"\n        assert not torch.isnan(dk_cute).any(), \"dK contains NaN\"\n        assert not torch.isnan(dv_cute).any(), \"dV contains NaN\"\n\n        bwd_rtol = 2\n        min_seqlen = min(seqlen_q, seqlen_k)\n        bwd_atol_floor = 1e-5 if min_seqlen >= 64 else 3e-5\n        dq_atol = max(bwd_atol_floor, 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item())\n        dk_atol = max(bwd_atol_floor, 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item())\n        dv_atol = max(bwd_atol_floor, 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item())\n\n        dq_ref = dq_ref_fp32.to(dtype)\n        dk_ref = dk_ref_fp32.to(dtype)\n        dv_ref = dv_ref_fp32.to(dtype)\n\n        pt_dq_err = (dq_pt - dq_ref).abs().max().item()\n        pt_dk_err = (dk_pt - dk_ref).abs().max().item()\n        pt_dv_err = (dv_pt - dv_ref).abs().max().item()\n\n        cute_dq_err = (dq_cute - dq_ref).abs().max().item()\n        cute_dk_err = (dk_cute - dk_ref).abs().max().item()\n        cute_dv_err = (dv_cute - dv_ref).abs().max().item()\n\n        print(\"  Backward comparison:\")\n        print(f\"    dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}\")\n        print(f\"    dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, 
atol={dk_atol:.2e}\")\n        print(f\"    dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}\")\n\n        assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f\"dQ error too large: {cute_dq_err:.2e}\"\n        assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f\"dK error too large: {cute_dk_err:.2e}\"\n        assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f\"dV error too large: {cute_dv_err:.2e}\"\n\n\ndef test_mask_mod_ima_partial_block():\n    _run_mask_test(\n        seqlen_q=257,\n        seqlen_k=257,\n        nheads=1,\n        kv_mode=\"mha\",\n        headdim=128,\n        dtype=torch.bfloat16,\n        mask_name=\"ima\",\n        window_size=None,\n        window_left=None,\n        window_right=None,\n        tile_m=128,\n        tile_n=128,\n        use_block_sparsity=True,\n        needs_backward=True,\n    )\n\n\n# Q boundary seqlens: NOT multiples of tile_m (128)\n# These exercise the fix for is_full_block tiles not masking OOB Q rows in backward\nQ_BOUNDARY_SEQLEN_PAIRS = [\n    (200, 200),    # Last m_block: rows 128-199 valid, 200-255 should be masked\n    (300, 300),    # Last m_block: rows 256-299 valid, 300-383 should be masked\n    (129, 129),    # Just 1 element into second tile\n    (255, 255),    # Just 1 element short of 2 full tiles\n    (500, 512),    # Q boundary only (K aligned)\n    (512, 500),    # K boundary only (Q aligned)\n    (333, 444),    # Both non-aligned\n]\n\n\n@pytest.mark.parametrize(\"seqlen_q,seqlen_k\", Q_BOUNDARY_SEQLEN_PAIRS)\n@pytest.mark.parametrize(\"mask_name\", [\"block_diagonal\", \"document\"])\ndef test_q_boundary_masking_block_sparse_bwd(seqlen_q, seqlen_k, mask_name):\n    \"\"\"Test Q boundary masking for block-sparse backward pass.\n\n    This test specifically exercises the fix for the bug where Q rows beyond seqlen_q\n    were not masked in backward pass for is_full_block=True tiles.\n\n    The bug occurred because:\n    - In forward, apply_mask_sm100 always 
checks both Q and K bounds\n    - In backward, apply_mask_sm100_transposed with is_full_block=True only checked K bounds\n    - Result: partial last m_blocks had unmasked garbage Q rows contributing to gradients\n\n    Key conditions:\n    - seqlen_q NOT a multiple of tile_m (128): creates partial last m_block\n    - Block-sparse with mask_mod: exercises is_full_block=True path\n    - Backward pass: where the bug manifested\n    \"\"\"\n    _run_mask_test(\n        seqlen_q=seqlen_q,\n        seqlen_k=seqlen_k,\n        nheads=4,\n        kv_mode=\"mha\",\n        headdim=128,\n        dtype=torch.bfloat16,\n        mask_name=mask_name,\n        window_size=None,\n        window_left=None,\n        window_right=None,\n        tile_m=128,\n        tile_n=128,\n        use_block_sparsity=True,\n        needs_backward=True,\n    )\n\n\n@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason=\"Test uses SM100 block mask conventions (2*tile_m)\")\ndef test_single_doc_bwd_minimal():\n    \"\"\"Minimal test to isolate single-document backward pass bug.\n\n    This test uses batch=1, nheads=1, and a single document (all same doc_id)\n    to make debugging easier. 
The bug manifests as large numerical errors\n    in dQ, dK, dV when blocks are classified as \"full blocks\" due to\n    the mask returning True for all positions.\n\n    Run with: pytest tests/cute/test_mask_mod.py::test_single_doc_bwd_minimal -v -s\n    \"\"\"\n    import random\n    random.seed(42)\n    torch.manual_seed(42)\n\n    seqlen_q = 384\n    seqlen_k = 300\n    batch_size = 1\n    nheads = 1\n    headdim = 128\n    tile_m = 128\n    tile_n = 128\n    dtype = torch.bfloat16\n\n    # Create single-document doc_ids (all same doc_id = 0)\n    doc_ids = torch.zeros(batch_size, nheads, max(seqlen_q, seqlen_k), dtype=torch.int32, device=\"cuda\")\n\n    from mask_mod_definitions import get_mask_pair\n    mask_mod_cute, mask_mod_flex = get_mask_pair(\"document\", seqlen_q=seqlen_q, seqlen_k=seqlen_k)\n\n    original_flex_mask = mask_mod_flex\n    def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids):\n        return original_flex_mask(b, h, q_idx, kv_idx, doc_ids)\n\n    aux_tensors_arg = [doc_ids]\n\n    # Create tensors\n    q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=\"cuda\", dtype=dtype)\n    k = torch.randn(batch_size, seqlen_k, nheads, headdim, device=\"cuda\", dtype=dtype)\n    v = torch.randn(batch_size, seqlen_k, nheads, headdim, device=\"cuda\", dtype=dtype)\n    out = torch.empty(batch_size, seqlen_q, nheads, headdim, device=\"cuda\", dtype=dtype)\n    lse = torch.empty(batch_size, nheads, seqlen_q, device=\"cuda\", dtype=torch.float32)\n\n    sparse_tile_m = 2 * tile_m\n    bm = create_block_mask(\n        mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k,\n        device=\"cuda\", BLOCK_SIZE=(sparse_tile_m, tile_n),\n    )\n    (\n        _seq_q, _seq_k,\n        kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx,\n        q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, *_,\n    ) = bm.as_tuple()\n\n    block_sparse_mask_fwd = BlockSparseTensorsTorch(\n        mask_block_cnt=kv_mask_cnt,\n        
mask_block_idx=kv_mask_idx,\n        full_block_cnt=full_kv_cnt,\n        full_block_idx=full_kv_idx,\n        block_size=(sparse_tile_m, tile_n),\n    )\n    block_sparse_mask_bwd = BlockSparseTensorsTorch(\n        mask_block_cnt=q_mask_cnt,\n        mask_block_idx=q_mask_idx,\n        full_block_cnt=full_q_cnt,\n        full_block_idx=full_q_idx,\n        block_size=(sparse_tile_m, tile_n),\n    )\n\n\n    out_tuple = _flash_attn_fwd(\n        q=q, k=k, v=v, out=out, lse=lse,\n        cu_seqlens_q=None, cu_seqlens_k=None,\n        seqused_q=None, seqused_k=None, page_table=None,\n        causal=False, softcap=None,\n        window_size_left=-1, window_size_right=-1,\n        tile_mn=(tile_m, tile_n), pack_gqa=False,\n        _arch=None, score_mod=None,\n        mask_mod=mask_mod_cute,\n        block_sparse_tensors=block_sparse_mask_fwd,\n        return_lse=True, aux_tensors=aux_tensors_arg,\n    )\n    out_cute = out_tuple[0]\n    lse_cute = out_tuple[1]\n\n    # Backward pass\n    grad_out = torch.randn_like(out_cute)\n\n    dq_cute, dk_cute, dv_cute = run_cute_mask_bwd(\n        q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute,\n        block_sparse_mask_bwd=block_sparse_mask_bwd,\n        tile_m=tile_m, tile_n=tile_n,\n        aux_tensors=aux_tensors_arg,\n    )\n\n    flex_block_mask = create_block_mask(\n        mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k,\n        device=\"cuda\", BLOCK_SIZE=(tile_m, tile_n),\n    )\n    out_ref, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd(\n        q, k, v, flex_block_mask, grad_out, dtype=torch.float32\n    )\n\n    # Compare\n    dq_err = (dq_cute - dq_ref.to(dtype)).abs().max().item()\n    dk_err = (dk_cute - dk_ref.to(dtype)).abs().max().item()\n    dv_err = (dv_cute - dv_ref.to(dtype)).abs().max().item()\n\n    print(f\"dQ error: {dq_err:.2e}\")\n    print(f\"dK error: {dk_err:.2e}\")\n    print(f\"dV error: {dv_err:.2e}\")\n\n    # Assert gradients are correct (this will fail, demonstrating the 
bug)\n    assert dq_err < 0.05, f\"dQ error too large: {dq_err:.2e}\"\n    assert dk_err < 0.05, f\"dK error too large: {dk_err:.2e}\"\n    assert dv_err < 0.05, f\"dV error too large: {dv_err:.2e}\"\n\n\n@pytest.mark.parametrize(\"seqlen_q,seqlen_k\", SEQLEN_PAIRS_COMPREHENSIVE)\n@pytest.mark.parametrize(\"nheads\", [16])\n@pytest.mark.parametrize(\"kv_mode\", [\"mha\", \"gqa\", \"mqa\"])\n@pytest.mark.parametrize(\"headdim\", [128])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"use_block_sparsity\", [True, False])\n@pytest.mark.parametrize(\n    \"mask_name\",\n    [\"block_diagonal\", \"mini_causal\"],\n)\n@pytest.mark.parametrize(\"tile_m,tile_n\", [(128, 128), (128, 112)])\ndef test_static_masks(\n    seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, tile_m, tile_n\n):\n    \"\"\"Test static masks that don't require recompilation per seqlen pair.\n\n    Known good masks:\n    - block_diagonal: Masks by 64-element diagonal blocks\n    - mini_causal: Local causal within 128-element tiles\n    \"\"\"\n    if COMPUTE_CAPABILITY == 10 and (tile_m, tile_n) != (128, 128):\n        pytest.skip(\"TODO: Non-128x128 tiles currently not supported on SM 10.0. 
due to TMEM\")\n\n    _run_mask_test(\n        seqlen_q=seqlen_q,\n        seqlen_k=seqlen_k,\n        nheads=nheads,\n        kv_mode=kv_mode,\n        headdim=headdim,\n        dtype=dtype,\n        mask_name=mask_name,\n        window_size=None,\n        window_left=None,\n        window_right=None,\n        tile_m=tile_m,\n        tile_n=tile_n,\n        use_block_sparsity=use_block_sparsity,\n        needs_backward=True,\n    )\n\n\n@pytest.mark.parametrize(\"seqlen_q,seqlen_k\", SEQLEN_PAIRS_SMOKE)\n@pytest.mark.parametrize(\"nheads\", [16])\n@pytest.mark.parametrize(\"kv_mode\", [\"mha\", \"gqa\"])\n@pytest.mark.parametrize(\"headdim\", [128])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"use_block_sparsity\", [True, False])\n@pytest.mark.parametrize(\n    \"mask_name,window_size\",\n    [\n        (\"causal\", None),\n        (\"block_causal\", None),\n        (\"sliding_window\", 128),\n        (\"sliding_window\", 256),\n        (\"sliding_window\", 512),\n        (\"document\", None),\n    ],\n)\n@pytest.mark.parametrize(\"tile_m,tile_n\", [(128, 128), (128, 112), (64, 128)])\ndef test_parameterized_masks(\n    seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, window_size, tile_m, tile_n\n):\n    \"\"\"Test parameterized masks that require recompilation per seqlen pair.\n\n    Uses fewer seqlen combinations to reduce test time.\n\n    Masks tested:\n    - causal, block_causal: Require offset = seqlen_k - seqlen_q\n    - sliding_window: Requires window size and offset parameters\n    - document: Slower to check\n    \"\"\"\n    if COMPUTE_CAPABILITY == 10 and (tile_m, tile_n) != (128, 128):\n        pytest.skip(\"TODO: Non-128x128 tiles currently not supported on SM 10.0. 
due to TMEM\")\n\n    _run_mask_test(\n        seqlen_q=seqlen_q,\n        seqlen_k=seqlen_k,\n        nheads=nheads,\n        kv_mode=kv_mode,\n        headdim=headdim,\n        dtype=dtype,\n        mask_name=mask_name,\n        window_size=window_size,\n        window_left=None,\n        window_right=None,\n        tile_m=tile_m,\n        tile_n=tile_n,\n        use_block_sparsity=use_block_sparsity,\n        needs_backward=True,\n    )\n\n\ndef test_sm100_block_sparse_sink_all_masked():\n    \"\"\"Block-sparse regression for the sink path\"\"\"\n    if torch.cuda.get_device_capability()[0] != 10:\n        pytest.skip(\"SM100-only test\")\n    device = \"cuda\"\n    dtype = torch.bfloat16\n    batch_size = 1\n    seqlen_q = 256\n    seqlen_k = 128\n    nheads = 8\n    headdim = 128\n    q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device)\n    k = torch.randn(batch_size, seqlen_k, nheads, headdim, dtype=dtype, device=device)\n    v = torch.randn(batch_size, seqlen_k, nheads, headdim, dtype=dtype, device=device)\n    learnable_sink = torch.full((nheads,), 0.5, dtype=torch.bfloat16, device=device)\n    zero_cnt = torch.zeros((batch_size, nheads, 1), dtype=torch.int32, device=device)\n    zero_idx = torch.zeros((batch_size, nheads, 1, 1), dtype=torch.int32, device=device)\n    sparse = BlockSparseTensorsTorch(\n        mask_block_cnt=zero_cnt,\n        mask_block_idx=zero_idx,\n        full_block_cnt=zero_cnt,\n        full_block_idx=zero_idx,\n        block_size=(256, 128),\n    )\n    softmax_scale = 1.0 / math.sqrt(headdim)\n    _, lse = _flash_attn_fwd(\n        q=q,\n        k=k,\n        v=v,\n        softmax_scale=softmax_scale,\n        causal=False,\n        window_size_left=None,\n        window_size_right=None,\n        learnable_sink=learnable_sink,\n        tile_mn=(128, 128),\n        num_threads=384,\n        pack_gqa=False,\n        block_sparse_tensors=sparse,\n        return_lse=True,\n    )\n    # Fully masked tile ⇒ 
probability mass sits entirely on the sink, so LSE equals sink logit.\n    expected = learnable_sink.float()[None, :, None].expand_as(lse)\n    assert torch.allclose(lse, expected, atol=0.0, rtol=0.0)\n\n\n@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason=\"SM100-only test\")\ndef test_sm100_block_sparse_q_stage1():\n    from flash_attn.cute import flash_fwd_sm100\n    from flash_attn.cute.interface import _flash_attn_fwd\n\n    observed = {}\n    original_init = flash_fwd_sm100.FlashAttentionForwardSm100.__init__\n\n    def wrapped_init(self, *args, **kwargs):\n        observed[\"q_stage\"] = kwargs.get(\"q_stage\")\n        return original_init(self, *args, **kwargs)\n\n    with mock.patch.object(\n        flash_fwd_sm100.FlashAttentionForwardSm100,\n        \"__init__\",\n        wrapped_init,\n    ):\n        compile_cache = _flash_attn_fwd.compile_cache\n        _flash_attn_fwd.compile_cache = get_jit_cache(\"test_mask_mod.fwd\")\n        try:\n            _run_mask_test(\n                seqlen_q=128,\n                seqlen_k=128,\n                nheads=4,\n                kv_mode=\"mha\",\n                headdim=128,\n                dtype=torch.bfloat16,\n                mask_name=\"block_diagonal\",\n                window_size=None,\n                window_left=None,\n                window_right=None,\n                tile_m=128,\n                tile_n=128,\n                use_block_sparsity=True,\n                needs_backward=False,\n            )\n        finally:\n            _flash_attn_fwd.compile_cache.clear()\n            _flash_attn_fwd.compile_cache = compile_cache\n    assert observed.get(\"q_stage\") == 1\n\n\n@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason=\"SM100-only test\")\ndef test_sm100_block_sparse_coarse_blocks():\n    torch.manual_seed(42)\n    seqlen_q = 512\n    seqlen_k = 512\n    nheads = 4\n    headdim = 128\n    dtype = torch.bfloat16\n    tile_m = 128\n    tile_n = 128\n    sparse_tile_m = 512\n    batch_size = 
1\n\n    mask_mod_cute, mask_mod_flex = get_mask_pair(\n        \"block_diagonal\", seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=None\n    )\n    tensors = create_tensors(\n        batch_size, seqlen_q, seqlen_k, nheads, nheads, headdim, headdim, dtype\n    )\n\n    bm = create_block_mask(\n        mask_mod_flex,\n        batch_size,\n        nheads,\n        seqlen_q,\n        seqlen_k,\n        device=\"cuda\",\n        BLOCK_SIZE=(sparse_tile_m, tile_n),\n    )\n    (\n        _seq_q,\n        _seq_k,\n        kv_mask_cnt,\n        kv_mask_idx,\n        full_kv_cnt,\n        full_kv_idx,\n        *_,\n    ) = bm.as_tuple()\n\n    block_sparse_mask_fwd = BlockSparseTensorsTorch(\n        mask_block_cnt=kv_mask_cnt,\n        mask_block_idx=kv_mask_idx,\n        full_block_cnt=full_kv_cnt,\n        full_block_idx=full_kv_idx,\n        block_size=(sparse_tile_m, tile_n),\n    )\n\n    out_cute, _ = _flash_attn_fwd(\n        q=tensors[\"q\"],\n        k=tensors[\"k\"],\n        v=tensors[\"v\"],\n        out=tensors[\"out\"],\n        lse=tensors[\"lse\"],\n        cu_seqlens_q=None,\n        cu_seqlens_k=None,\n        seqused_q=None,\n        seqused_k=None,\n        page_table=None,\n        softmax_scale=1.0 / math.sqrt(headdim),\n        causal=False,\n        softcap=None,\n        window_size_left=None,\n        window_size_right=None,\n        learnable_sink=None,\n        tile_mn=(tile_m, tile_n),\n        pack_gqa=False,\n        _arch=None,\n        score_mod=None,\n        mask_mod=mask_mod_cute,\n        block_sparse_tensors=block_sparse_mask_fwd,\n        return_lse=True,\n    )\n\n    tensors_fp32 = {\n        k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v\n        for k, v in tensors.items()\n    }\n    out_ref_fp32 = compute_reference_flex_attn(\n        tensors_fp32, mask_mod_flex, (sparse_tile_m, tile_n)\n    )\n    out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, (sparse_tile_m, tile_n))\n\n    assert 
out_cute.shape == out_ref_fp32.shape == out_ref.shape\n    assert not torch.isnan(out_cute).any()\n    assert not torch.isnan(out_ref_fp32).any()\n    assert torch.isfinite(out_cute).all()\n    assert torch.isfinite(out_ref_fp32).all()\n\n    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()\n    rtol = 2\n    pt_error = (out_ref - out_ref_fp32).abs().max().item()\n    cute_error = (out_cute - out_ref_fp32).abs().max().item()\n    assert cute_error <= rtol * pt_error + fwd_atol, (\n        f\"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}\"\n    )\n\n\n@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason=\"SM100-only test\")\ndef test_sm100_block_sparse_coarse_blocks_mismatch():\n    torch.manual_seed(0)\n    seqlen_q = 1024\n    seqlen_k = 512\n    nheads = 2\n    headdim = 128\n    dtype = torch.bfloat16\n    tile_m = 128\n    tile_n = 128\n    sparse_tile_m = 512\n    batch_size = 1\n\n    mask_mod_cute, mask_mod_flex = get_coarse_block_mask_pair(\n        sparse_tile_m, tile_n, last_block=3\n    )\n    tensors = create_tensors(\n        batch_size, seqlen_q, seqlen_k, nheads, nheads, headdim, headdim, dtype\n    )\n\n    bm = create_block_mask(\n        mask_mod_flex,\n        batch_size,\n        nheads,\n        seqlen_q,\n        seqlen_k,\n        device=\"cuda\",\n        BLOCK_SIZE=(sparse_tile_m, tile_n),\n    )\n    (\n        _seq_q,\n        _seq_k,\n        kv_mask_cnt,\n        kv_mask_idx,\n        full_kv_cnt,\n        full_kv_idx,\n        *_,\n    ) = bm.as_tuple()\n\n    block_sparse_mask_fwd = BlockSparseTensorsTorch(\n        mask_block_cnt=kv_mask_cnt,\n        mask_block_idx=kv_mask_idx,\n        full_block_cnt=full_kv_cnt,\n        full_block_idx=full_kv_idx,\n        block_size=(sparse_tile_m, tile_n),\n    )\n\n    observed = {}\n    original_normalize = normalize_block_sparse_config\n\n    def wrapped_normalize(*args, **kwargs):\n        normalized, pattern, 
q_subtile_factor = original_normalize(*args, **kwargs)\n        observed[\"q_subtile_factor\"] = q_subtile_factor\n        return normalized, pattern, q_subtile_factor\n\n    with mock.patch(\"flash_attn.cute.interface.normalize_block_sparse_config\", wrapped_normalize):\n        out_cute, _ = _flash_attn_fwd(\n            q=tensors[\"q\"],\n            k=tensors[\"k\"],\n            v=tensors[\"v\"],\n            out=tensors[\"out\"],\n            lse=tensors[\"lse\"],\n            cu_seqlens_q=None,\n            cu_seqlens_k=None,\n            seqused_q=None,\n            seqused_k=None,\n            page_table=None,\n            softmax_scale=1.0 / math.sqrt(headdim),\n            causal=False,\n            softcap=None,\n            window_size_left=None,\n            window_size_right=None,\n            learnable_sink=None,\n            tile_mn=(tile_m, tile_n),\n            pack_gqa=False,\n            _arch=None,\n            score_mod=None,\n            mask_mod=mask_mod_cute,\n            block_sparse_tensors=block_sparse_mask_fwd,\n            return_lse=True,\n        )\n    assert observed.get(\"q_subtile_factor\") == 2\n\n    tensors_fp32 = {\n        k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v\n        for k, v in tensors.items()\n    }\n    out_ref_fp32 = compute_reference_flex_attn(\n        tensors_fp32, mask_mod_flex, (sparse_tile_m, tile_n)\n    )\n    out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, (sparse_tile_m, tile_n))\n\n    assert out_cute.shape == out_ref_fp32.shape == out_ref.shape\n    assert not torch.isnan(out_cute).any()\n    assert not torch.isnan(out_ref_fp32).any()\n    assert torch.isfinite(out_cute).all()\n    assert torch.isfinite(out_ref_fp32).all()\n\n    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()\n    rtol = 2\n    pt_error = (out_ref - out_ref_fp32).abs().max().item()\n    cute_error = (out_cute - out_ref_fp32).abs().max().item()\n    assert cute_error <= 
rtol * pt_error + fwd_atol, (\n        f\"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}\"\n    )\n\n\n# =============================================================================\n# Backward Helper Functions\n# =============================================================================\n\ndef run_cute_mask_bwd(\n    q, k, v, out, lse, grad_out, mask_mod_cute,\n    block_sparse_mask_bwd=None, tile_m=128, tile_n=128,\n    aux_tensors=None,\n):\n    \"\"\"Run flash attention backward with mask_mod.\n\n    Args:\n        q, k, v: Input tensors in BSHD format\n        out: Forward output tensor\n        lse: Log-sum-exp from forward pass\n        grad_out: Gradient of output\n        mask_mod_cute: CuTE mask modification function\n        block_sparse_mask_bwd: Block sparse tensors for backward pass\n        tile_m, tile_n: Tile sizes\n        aux_tensors: Auxiliary tensors for mask_mod (e.g., doc_ids for document masking)\n\n    Returns (dq, dk, dv) all in BSHD format.\n    \"\"\"\n    dq, dk, dv = _flash_attn_bwd(\n        q=q,\n        k=k,\n        v=v,\n        out=out,\n        dout=grad_out,\n        lse=lse,\n        causal=False,\n        m_block_size=tile_m,\n        n_block_size=tile_n,\n        mask_mod=mask_mod_cute,\n        block_sparse_tensors=block_sparse_mask_bwd,\n        aux_tensors=aux_tensors,\n    )\n\n    return dq, dk, dv\n\n\ndef run_flex_reference_bwd(q, k, v, block_mask, grad_out, dtype=None):\n    \"\"\"Run flex_attention forward + backward for reference.\n\n    Args:\n        q, k, v: Input tensors in BSHD format\n        block_mask: Pre-created block mask for flex_attention\n        grad_out: Gradient of output in BSHD format\n        dtype: Optional dtype to cast inputs to (e.g., torch.float32 for reference)\n\n    Returns (out, dq, dk, dv) all in BSHD format.\n    \"\"\"\n    # Transpose to BHSD for flex_attention\n    if dtype is not None:\n        q_ref = q.transpose(1, 
2).to(dtype).requires_grad_(True)\n        k_ref = k.transpose(1, 2).to(dtype).requires_grad_(True)\n        v_ref = v.transpose(1, 2).to(dtype).requires_grad_(True)\n        grad_out_ref = grad_out.transpose(1, 2).to(dtype)\n    else:\n        q_ref = q.transpose(1, 2).requires_grad_(True)\n        k_ref = k.transpose(1, 2).requires_grad_(True)\n        v_ref = v.transpose(1, 2).requires_grad_(True)\n        grad_out_ref = grad_out.transpose(1, 2)\n\n    # Use flex_attention directly without torch.compile for backward tests\n    # torch.compile can hang on certain mask patterns (e.g., mini_causal with float32)\n    out_ref = flex_attention(q_ref, k_ref, v_ref, block_mask=block_mask, enable_gqa=True)\n    dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), grad_out_ref)\n\n    # Transpose back to BSHD\n    return (\n        out_ref.transpose(1, 2),\n        dq_ref.transpose(1, 2),\n        dk_ref.transpose(1, 2),\n        dv_ref.transpose(1, 2),\n    )\n\n\ndef test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message():\n    if COMPUTE_CAPABILITY != 9:\n        pytest.skip(\"SM90-only test\")\n\n    batch_size = 1\n    seqlen_q = 256\n    seqlen_k = 256\n    nheads = 4\n    nheads_kv = nheads\n    headdim = 128\n    dtype = torch.bfloat16\n    tile_m = 80\n    tile_n = 128\n\n    tensors = create_tensors(batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim, dtype)\n    mask_mod_cute, mask_mod_flex = get_mask_pair(\"block_diagonal\", seqlen_q=seqlen_q, seqlen_k=seqlen_k)\n    bm = create_block_mask(\n        mask_mod_flex,\n        batch_size,\n        nheads,\n        seqlen_q,\n        seqlen_k,\n        device=\"cuda\",\n        BLOCK_SIZE=(tile_m, tile_n),\n    )\n    (\n        _seq_q,\n        _seq_k,\n        _kv_mask_cnt,\n        _kv_mask_idx,\n        _full_kv_cnt,\n        _full_kv_idx,\n        q_mask_cnt,\n        q_mask_idx,\n        full_q_cnt,\n        full_q_idx,\n        *_,\n    ) = 
bm.as_tuple()\n\n    block_sparse_mask_bwd = BlockSparseTensorsTorch(\n        mask_block_cnt=q_mask_cnt,\n        mask_block_idx=q_mask_idx,\n        full_block_cnt=full_q_cnt,\n        full_block_idx=full_q_idx,\n        block_size=(tile_m, tile_n),\n    )\n\n    softmax_scale = 1.0 / math.sqrt(headdim)\n    out = torch.empty(batch_size, seqlen_q, nheads, headdim, device=\"cuda\", dtype=dtype)\n    lse = torch.empty(batch_size, nheads, seqlen_q, device=\"cuda\", dtype=torch.float32)\n    grad_out = torch.randn_like(out)\n\n    with pytest.raises(\n        ValueError,\n        match=r\"Block sparsity expects sparse_block_size_q=128 for subtile_factor=2\\.\",\n    ):\n        _flash_attn_bwd(\n            q=tensors[\"q\"],\n            k=tensors[\"k\"],\n            v=tensors[\"v\"],\n            out=out,\n            dout=grad_out,\n            lse=lse,\n            softmax_scale=softmax_scale,\n            causal=False,\n            m_block_size=tile_m,\n            n_block_size=tile_n,\n            mask_mod=mask_mod_cute,\n            block_sparse_tensors=block_sparse_mask_bwd,\n        )\n\n\ndef test_gqa_block_sparse_broadcast_pattern_recompilation():\n    \"\"\"Test that different block sparse broadcast patterns trigger recompilation.\n\n    This is a regression test for a bug where:\n    1. First call with block_mask H=1 (broadcasts across all query heads)\n    2. Second call with block_mask H=nheads (no broadcast)\n    3. Second call incorrectly reused cached kernel from first call\n\n    The fix adds block_sparse_broadcast_pattern to the compile key so that\n    kernels are recompiled when broadcast patterns change. 
CuTe's\n    mark_layout_dynamic() keeps stride=0 as static, so different broadcast\n    patterns require different compiled kernels.\n    \"\"\"\n    torch.manual_seed(42)\n\n    batch_size = 2\n    nheads = 8\n    nheads_kv = 2\n    seqlen = 257\n    headdim = 64\n    dtype = torch.bfloat16\n    tile_m = 128\n    tile_n = 128\n\n    sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m\n\n    def causal_mask(b, h, q, kv):\n        return q >= kv\n\n    mask_mod_cute, _ = get_mask_pair(\"causal\", seqlen_q=seqlen, seqlen_k=seqlen)\n\n    tensors = create_tensors(batch_size, seqlen, seqlen, nheads, nheads_kv, headdim, headdim, dtype)\n    q, k, v = tensors[\"q\"], tensors[\"k\"], tensors[\"v\"]\n    grad_out = torch.randn_like(tensors[\"out\"])\n    softmax_scale = 1.0 / math.sqrt(headdim)\n\n    def run_with_block_mask_nheads(block_mask_nheads: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        bm = create_block_mask(\n            causal_mask, batch_size, block_mask_nheads, seqlen, seqlen,\n            device=\"cuda\", BLOCK_SIZE=(sparse_tile_m, tile_n),\n        )\n        (\n            _seq_q, _seq_k,\n            kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx,\n            q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, *_,\n        ) = bm.as_tuple()\n\n        block_sparse_fwd = BlockSparseTensorsTorch(\n            mask_block_cnt=kv_mask_cnt,\n            mask_block_idx=kv_mask_idx,\n            full_block_cnt=full_kv_cnt,\n            full_block_idx=full_kv_idx,\n            block_size=(sparse_tile_m, tile_n),\n        )\n        block_sparse_bwd = BlockSparseTensorsTorch(\n            mask_block_cnt=q_mask_cnt,\n            mask_block_idx=q_mask_idx,\n            full_block_cnt=full_q_cnt,\n            full_block_idx=full_q_idx,\n            block_size=(sparse_tile_m, tile_n),\n        )\n\n        out = torch.empty_like(tensors[\"out\"])\n        lse = torch.empty_like(tensors[\"lse\"])\n\n        out_tuple = 
_flash_attn_fwd(\n            q=q, k=k, v=v, out=out, lse=lse,\n            softmax_scale=softmax_scale, causal=False,\n            window_size_left=-1, window_size_right=-1,\n            tile_mn=(tile_m, tile_n), pack_gqa=False,\n            mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_fwd,\n            return_lse=True,\n        )\n        out_cute, lse_cute = out_tuple[0], out_tuple[1]\n\n        dq, dk, dv = run_cute_mask_bwd(\n            q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute,\n            block_sparse_mask_bwd=block_sparse_bwd, tile_m=tile_m, tile_n=tile_n,\n        )\n        return dq, dk, dv\n\n    flex_block_mask = create_block_mask(\n        causal_mask, batch_size, nheads, seqlen, seqlen,\n        device=\"cuda\", BLOCK_SIZE=(tile_m, tile_n),\n    )\n    _, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out, dtype=torch.float32)\n    dq_ref, dk_ref, dv_ref = dq_ref.to(dtype), dk_ref.to(dtype), dv_ref.to(dtype)\n\n    dq_broadcast, dk_broadcast, dv_broadcast = run_with_block_mask_nheads(1)\n    dq_no_broadcast, dk_no_broadcast, dv_no_broadcast = run_with_block_mask_nheads(nheads)\n\n    err_broadcast_dq = (dq_broadcast - dq_ref).abs().max().item()\n    err_no_broadcast_dq = (dq_no_broadcast - dq_ref).abs().max().item()\n\n    print(\"\\nGQA block sparse broadcast pattern test:\")\n    print(f\"  dQ error (H=1 broadcast): {err_broadcast_dq:.2e}\")\n    print(f\"  dQ error (H={nheads} no broadcast): {err_no_broadcast_dq:.2e}\")\n\n    assert err_broadcast_dq < 0.1, f\"Broadcast dQ error too large: {err_broadcast_dq:.2e}\"\n    assert err_no_broadcast_dq < 0.1, f\"No-broadcast dQ error too large: {err_no_broadcast_dq:.2e}\"\n\n\ndef test_gqa_expand_stride_zero_bug():\n    \"\"\"Test that GQA with expand()-created K/V tensors works correctly.\n\n    This is a regression test for bugs with expand()-created tensors:\n\n    Forward bug: cute.assume() fails when tensor strides are Python int 0\n    (from 
expand()) instead of MLIR values.\n    Error: AttributeError: 'int' object has no attribute 'type'\n\n    Backward bug: mark_layout_dynamic fails with expanded tensors.\n    Error: RuntimeError: Expected strides[leading_dim] == 1, but got N.\n\n    Trigger: expand() + transpose() creates stride=0 dimensions (GQA pattern).\n    \"\"\"\n    torch.manual_seed(42)\n\n    batch_size = 1\n    seqlen = 2048\n    headdim = 128\n    n_heads = 4\n    n_kv_heads = 1\n    dtype = torch.bfloat16\n    device = \"cuda\"\n\n    q = torch.randn(batch_size, seqlen, n_heads, headdim, device=device, dtype=dtype)\n    k_orig = torch.randn(batch_size, seqlen, n_kv_heads, headdim, device=device, dtype=dtype)\n    v_orig = torch.randn(batch_size, seqlen, n_kv_heads, headdim, device=device, dtype=dtype)\n\n    k = k_orig.expand(batch_size, seqlen, n_heads, headdim)\n    v = v_orig.expand(batch_size, seqlen, n_heads, headdim)\n\n    assert k.stride()[2] == 0, \"K should have stride=0 in head dim from expand()\"\n    assert v.stride()[2] == 0, \"V should have stride=0 in head dim from expand()\"\n\n    out = torch.empty_like(q)\n    lse = torch.empty(batch_size, n_heads, seqlen, device=device, dtype=torch.float32)\n    softmax_scale = 1.0 / math.sqrt(headdim)\n\n    out_tuple = _flash_attn_fwd(\n        q=q, k=k, v=v, out=out, lse=lse,\n        softmax_scale=softmax_scale,\n        causal=True,\n        tile_mn=(128, 128),\n        return_lse=True,\n    )\n    out_fwd, lse_fwd = out_tuple[0], out_tuple[1]\n\n    assert not torch.isnan(out_fwd).any(), \"Forward output contains NaN\"\n    assert torch.isfinite(out_fwd).all(), \"Forward output contains non-finite values\"\n\n    tensors_for_ref = {\"q\": q, \"k\": k, \"v\": v}\n    tensors_fp32 = {\"q\": q.float(), \"k\": k.float(), \"v\": v.float()}\n\n    def causal_mask(b, h, q_idx, kv_idx):\n        return q_idx >= kv_idx\n\n    out_ref = compute_reference_flex_attn(tensors_for_ref, causal_mask)\n    out_ref_fp32 = 
compute_reference_flex_attn(tensors_fp32, causal_mask)\n\n    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()\n    rtol = 2\n    pt_error = (out_ref - out_ref_fp32).abs().max().item()\n    cute_error = (out_fwd - out_ref_fp32).abs().max().item()\n\n    print(\"\\nGQA expand stride=0 test:\")\n    print(f\"  Forward: kernel err={cute_error:.2e}, ref err={pt_error:.2e}, atol={fwd_atol:.2e}\")\n    assert cute_error <= rtol * pt_error + fwd_atol, (\n        f\"Forward error {cute_error:.2e} exceeds {rtol}x ref error {pt_error:.2e} + {fwd_atol:.2e}\"\n    )\n\n    grad_out = torch.randn_like(out_fwd)\n    dq, dk, dv = _flash_attn_bwd(\n        q=q, k=k, v=v, out=out_fwd, dout=grad_out, lse=lse_fwd,\n        softmax_scale=softmax_scale,\n        causal=True,\n        m_block_size=128, n_block_size=128,\n    )\n\n    assert not torch.isnan(dq).any(), \"dQ contains NaN\"\n    assert not torch.isnan(dk).any(), \"dK contains NaN\"\n    assert not torch.isnan(dv).any(), \"dV contains NaN\"\n\n    flex_block_mask = create_block_mask(\n        causal_mask, batch_size, n_heads, seqlen, seqlen,\n        device=device, BLOCK_SIZE=(128, 128),\n    )\n    _, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out, dtype=torch.float32)\n\n    bwd_rtol = 2\n    bwd_atol_floor = 1e-5\n\n    dq_atol = max(bwd_atol_floor, 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item())\n    dk_atol = max(bwd_atol_floor, 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item())\n    dv_atol = max(bwd_atol_floor, 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item())\n\n    _, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out)\n\n    pt_dq_err = (dq_pt - dq_ref.to(dtype)).abs().max().item()\n    pt_dk_err = (dk_pt - dk_ref.to(dtype)).abs().max().item()\n    pt_dv_err = (dv_pt - dv_ref.to(dtype)).abs().max().item()\n\n    cute_dq_err = (dq - dq_ref.to(dtype)).abs().max().item()\n    cute_dk_err = (dk - 
dk_ref.to(dtype)).abs().max().item()\n    cute_dv_err = (dv - dv_ref.to(dtype)).abs().max().item()\n\n    print(f\"  Backward dQ: kernel err={cute_dq_err:.2e}, ref err={pt_dq_err:.2e}, atol={dq_atol:.2e}\")\n    print(f\"  Backward dK: kernel err={cute_dk_err:.2e}, ref err={pt_dk_err:.2e}, atol={dk_atol:.2e}\")\n    print(f\"  Backward dV: kernel err={cute_dv_err:.2e}, ref err={pt_dv_err:.2e}, atol={dv_atol:.2e}\")\n\n    assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f\"dQ error too large: {cute_dq_err:.2e}\"\n    assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f\"dK error too large: {cute_dk_err:.2e}\"\n    assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f\"dV error too large: {cute_dv_err:.2e}\"\n\n\n@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason=\"SM100/SM110 persistent forward only\")\ndef test_persistent_blocksparse_empty_tiles():\n    \"\"\"Regression test for persistent forward deadlock with highly-sparse block masks.\n\n    When most Q-tiles are empty (no active KV blocks), the persistent kernel\n    deadlocked due to barrier phase desync in the empty-tile paths of both the\n    softmax and correction warp groups.\n    \"\"\"\n    torch.manual_seed(5)\n    batch_size, nheads_q, nheads_kv = 2, 16, 1\n    seqlen_q, seqlen_k, headdim = 8192, 128, 128\n    tile_m, tile_n = 128, 128\n    dtype = torch.bfloat16\n\n    sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m\n    window_size = 64\n    mask_mod_cute, mask_mod_flex = get_mask_pair(\n        \"sliding_window\", seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size,\n    )\n\n    bm = create_block_mask(\n        mask_mod_flex, batch_size, nheads_q, seqlen_q, seqlen_k,\n        device=\"cuda\", BLOCK_SIZE=(sparse_tile_m, tile_n),\n    )\n    (_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple()\n    block_sparse_mask_fwd = BlockSparseTensorsTorch(\n        mask_block_cnt=kv_mask_cnt, mask_block_idx=kv_mask_idx,\n        
full_block_cnt=full_kv_cnt, full_block_idx=full_kv_idx,\n        block_size=(sparse_tile_m, tile_n),\n    )\n\n    q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, device=\"cuda\", dtype=dtype)\n    k = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device=\"cuda\", dtype=dtype)\n    v = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device=\"cuda\", dtype=dtype)\n\n    out, lse = _flash_attn_fwd(\n        q=q, k=k, v=v,\n        out=torch.empty(batch_size, seqlen_q, nheads_q, headdim, device=\"cuda\", dtype=dtype),\n        lse=torch.empty(batch_size, nheads_q, seqlen_q, device=\"cuda\", dtype=torch.float32),\n        cu_seqlens_q=None, cu_seqlens_k=None, seqused_q=None, seqused_k=None,\n        page_table=None, softmax_scale=1.0 / math.sqrt(headdim),\n        causal=False, softcap=None,\n        window_size_left=None, window_size_right=None,\n        learnable_sink=None,\n        tile_mn=(tile_m, tile_n),\n        pack_gqa=False, _arch=None,\n        score_mod=None, mask_mod=mask_mod_cute,\n        block_sparse_tensors=block_sparse_mask_fwd,\n        return_lse=True, aux_tensors=None,\n    )\n    torch.cuda.synchronize()\n    assert out.shape == (batch_size, seqlen_q, nheads_q, headdim)\n    assert not out.isnan().any()\n\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\", \"-s\"])\n"
  },
  {
    "path": "tests/cute/test_score_mod.py",
    "content": "import pytest\nimport torch\nimport cutlass\nimport cutlass.cute as cute\nfrom cutlass._mlir.dialects import math as mlir_math\nimport operator\nfrom torch.nn.attention.flex_attention import flex_attention\nfrom flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd\n\nfrom score_mod_definitions import (\n    # TensorSSA-based score mods\n    score_mod_identity as score_mod_1,\n    score_mod_causal as score_mod_2,\n    score_mod_rel_bias as score_mod_3,\n    score_mod_rel_bias_x2 as score_mod_4,\n    score_mod_times_two as score_mod_5,\n    score_mod_alibi as score_mod_6,\n    score_mod_sliding_window as score_mod_7,\n    score_mod_block_diagonal as score_mod_8,\n    score_mod_causal_v2 as score_mod_9,\n    score_mod_batch_bias as score_mod_10,\n    score_mod_dual_buffer as score_mod_11,\n)  # isort: split\nfrom score_mod_definitions import (\n    score_mod_identity_vectorized as score_mod_1_vectorized,\n    score_mod_causal_vectorized as score_mod_2_vectorized,\n    score_mod_rel_bias as score_mod_3_vectorized,\n    score_mod_rel_bias_x2_vectorized as score_mod_4_vectorized,\n    score_mod_times_two_vectorized as score_mod_5_vectorized,\n    score_mod_alibi_vectorized as score_mod_6_vectorized,\n    score_mod_batch_bias_vectorized as score_mod_10_vectorized,\n    score_mod_dual_buffer_vectorized as score_mod_11_vectorized,\n)  # isort: split\nfrom score_mod_definitions import (\n    # Eager (torch) reference score mods\n    identity_eager,\n    causal_eager as causal_mask_eager,\n    rel_bias_eager as relative_bias_eager,\n    rel_bias_x2_eager as relative_bias_v2_eager,\n    times_two_eager,\n    alibi_eager as alibi_bias_eager,\n    sliding_window_eager,\n    block_diagonal_eager,\n    causal_v2_eager as causal_mask_v2_eager,\n    batch_bias_factory as batch_bias,\n    dual_buffer_factory as dual_buffer_bias,\n)\n\nCOMPUTE_CAPABILITY = 
torch.cuda.get_device_capability()[0]\n\n# Test pairs: (cute_jit_function, eager_reference_function)\nTEST_PAIRS = [\n    (score_mod_1, None),\n    (score_mod_2, causal_mask_eager),\n    (score_mod_3, relative_bias_eager),\n    (score_mod_4, relative_bias_v2_eager),\n    (score_mod_5, times_two_eager),\n    (score_mod_6, alibi_bias_eager),\n    (score_mod_7, sliding_window_eager),\n    (score_mod_8, block_diagonal_eager),\n    (score_mod_9, causal_mask_v2_eager),\n]\n\n# Test pairs with aux_tensors: (cute_jit_function, eager_reference_function_factory)\nTEST_PAIRS_WITH_AUX_TENSORS = [\n    (score_mod_10, batch_bias),\n    (score_mod_11, dual_buffer_bias),\n]\n\n# Test pairs to compare vectorized score_mods: (cute_jit_function, cute_jit_function_vectorized)\nTEST_PAIRS_VECTORIZED = [\n    (score_mod_1, score_mod_1_vectorized),\n    (score_mod_2, score_mod_2_vectorized),\n    (score_mod_3, score_mod_3_vectorized),\n    (score_mod_4, score_mod_4_vectorized),\n    (score_mod_5, score_mod_5_vectorized),\n    (score_mod_6, score_mod_6_vectorized),\n]\n\nTEST_PAIRS_WITH_AUX_TENSORS_VECTORIZED = [\n    (score_mod_10, score_mod_10_vectorized),\n    (score_mod_11, score_mod_11_vectorized),\n]\n\nSEQLEN_CONFIGS = [\n    (1, 1),\n    (64, 128),\n    (128, 192),\n    (256, 256),\n    (239, 1),\n    (799, 3),\n    (113, 203),\n    (113, 128),\n    (128, 217),\n    (113, 211),\n    (108, 256),\n    (256, 512),\n    (384, 256),\n    (640, 128),\n    (512, 256),\n    (1024, 1024),\n    (1023, 1024),\n    (1024, 1023),\n    (4096, 4096),\n    (4224, 4224),\n]\n\nVEC_SIZES_TO_CHECK_EQUALITY = [1, 2, 4] if COMPUTE_CAPABILITY == 10 else [1, 2]\n\n\ndef create_tensors(\n    batch_size=2, num_heads=4, seqlen_q=64, seqlen_kv=64, dim=128, dtype=torch.bfloat16\n):\n    q = torch.randn(batch_size, num_heads, seqlen_q, dim, device=\"cuda\", dtype=dtype)\n    k = torch.randn(batch_size, num_heads, seqlen_kv, dim, device=\"cuda\", dtype=dtype)\n    v = torch.randn(batch_size, num_heads, 
seqlen_kv, dim, device=\"cuda\", dtype=dtype)\n    return q, k, v\n\n\ndef run_cute_flash(q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=False) -> torch.Tensor:\n    q_transposed, k_transposed, v_transposed = map(lambda x: x.transpose(1, 2), (q, k, v))\n    out = torch.empty_like(q_transposed)\n    _flash_attn_fwd(\n        q_transposed,\n        k_transposed,\n        v_transposed,\n        return_lse=True,\n        score_mod=cute_score_mod,\n        out=out,\n        lse=None,\n        aux_tensors=aux_tensors,\n        pack_gqa=pack_gqa,\n    )\n    return out.transpose(1, 2)\n\n\ndef run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor:\n    if dtype is not None:\n        q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)\n    return flex_attention(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1])\n\n\n@pytest.mark.parametrize(\"seqlen_q,seqlen_kv\", SEQLEN_CONFIGS)\n@pytest.mark.parametrize(\"qhead_per_kvhead,num_kv_heads\", [(1, 2), (4, 2)])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"score_mod_pair\", TEST_PAIRS)\ndef test_cute_vs_flex_attention(\n    seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair\n):\n    torch.random.manual_seed(42)\n    cute_score_mod, eager_score_mod = score_mod_pair\n\n    num_q_heads = num_kv_heads * qhead_per_kvhead\n    pack_gqa = qhead_per_kvhead > 1\n    q, k, v = create_tensors(\n        seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype\n    )\n    if pack_gqa:\n        k = k[:, :num_kv_heads, :, :].clone()\n        v = v[:, :num_kv_heads, :, :].clone()\n\n    out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32)\n\n    out_pt = run_flex_reference(q, k, v, eager_score_mod)\n    out_cute = run_cute_flash(q, k, v, cute_score_mod, pack_gqa=pack_gqa)\n\n    # Basic shape and NaN checks\n    assert out_cute.shape == out_ref_fp32.shape == out_pt.shape\n    
assert not torch.isnan(out_cute).any()\n    assert not torch.isnan(out_ref_fp32).any()\n    assert not torch.isnan(out_pt).any()\n    assert torch.isfinite(out_cute).all()\n    assert torch.isfinite(out_ref_fp32).all()\n    assert torch.isfinite(out_pt).all()\n\n    # Numerical error if we just do any arithmetic on out_ref\n    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()\n    rtol = 2\n\n    # Calculate actual errors\n    pt_error = (out_pt - out_ref_fp32).abs().max().item()\n    cute_error = (out_cute - out_ref_fp32).abs().max().item()\n\n    print(f\"\\nNumerical comparison for {cute_score_mod.__name__}:\")\n    print(f\"  PyTorch vs FP32 ref max error: {pt_error:.2e}\")\n    print(f\"  CuTE vs FP32 ref max error: {cute_error:.2e}\")\n    print(f\"  Dynamic absolute tolerance: {fwd_atol:.2e}\")\n    print(f\"  Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}\")\n\n    # Assert that CuTE's error is at most rtol times PyTorch's error + fwd_atol\n    assert cute_error <= rtol * pt_error + fwd_atol, (\n        f\"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}\"\n    )\n\n\n@pytest.mark.parametrize(\"seqlen_q,seqlen_kv\", SEQLEN_CONFIGS)\n@pytest.mark.parametrize(\"qhead_per_kvhead,num_kv_heads\", [(1, 1), (4, 2)])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"score_mod_vec_pair\", TEST_PAIRS_VECTORIZED)\ndef test_cute_score_mod_vectorized(\n    seqlen_q,\n    seqlen_kv,\n    qhead_per_kvhead,\n    num_kv_heads,\n    dtype,\n    score_mod_vec_pair,\n):\n    \"\"\"Tests equality between original and vectorized versions of score mods\"\"\"\n    torch.random.manual_seed(42)\n    cute_score_mod, cute_vectorized_score_mod = score_mod_vec_pair\n\n    num_q_heads = num_kv_heads * qhead_per_kvhead\n    pack_gqa = qhead_per_kvhead > 1\n    q, k, v = create_tensors(\n        seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, 
num_heads=num_q_heads, dtype=dtype\n    )\n    if pack_gqa:\n        k = k[:, :num_kv_heads, :, :].clone()\n        v = v[:, :num_kv_heads, :, :].clone()\n\n    out_ref = run_cute_flash(q, k, v, cute_score_mod, pack_gqa=pack_gqa)\n\n    for vec_size in VEC_SIZES_TO_CHECK_EQUALITY:\n        cute_vectorized_score_mod.__vec_size__ = vec_size\n        out = run_cute_flash(q, k, v, cute_vectorized_score_mod, pack_gqa=pack_gqa)\n        assert torch.equal(out, out_ref)\n\n\n@pytest.mark.parametrize(\"seqlen_q,seqlen_kv\", SEQLEN_CONFIGS)\n@pytest.mark.parametrize(\"qhead_per_kvhead,num_kv_heads\", [(1, 1), (4, 2)])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"score_mod_pair\", TEST_PAIRS_WITH_AUX_TENSORS)\ndef test_cute_vs_flex_attention_with_aux_tensors(\n    seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair\n):\n    torch.random.manual_seed(42)\n    cute_score_mod, eager_score_mod_factory = score_mod_pair\n\n    batch_size = 2\n    num_q_heads = num_kv_heads * qhead_per_kvhead\n    pack_gqa = qhead_per_kvhead > 1\n    q, k, v = create_tensors(\n        batch_size=batch_size,\n        seqlen_q=seqlen_q,\n        seqlen_kv=seqlen_kv,\n        num_heads=num_q_heads,\n        dtype=dtype,\n    )\n    if pack_gqa:\n        k = k[:, :num_kv_heads, :, :].clone()\n        v = v[:, :num_kv_heads, :, :].clone()\n\n    if cute_score_mod == score_mod_10:\n        buffer = torch.randn(batch_size, device=\"cuda\", dtype=dtype) * 0.1\n        aux_tensors = [buffer]\n        eager_score_mod = eager_score_mod_factory(buffer)\n        assert buffer.shape == (batch_size,)\n    elif cute_score_mod == score_mod_11:\n        head_bias = torch.randn(num_q_heads, device=\"cuda\", dtype=dtype) * 0.2\n        pos_scale = torch.arange(seqlen_q, device=\"cuda\", dtype=dtype) * 0.01\n        aux_tensors = [head_bias, pos_scale]\n        eager_score_mod = eager_score_mod_factory(head_bias, pos_scale)\n        assert 
head_bias.shape == (num_q_heads,)\n        assert pos_scale.shape == (seqlen_q,)\n\n    out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32)\n\n    out_pt = run_flex_reference(q, k, v, eager_score_mod)\n    out_cute = run_cute_flash(q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa)\n\n    # Basic shape and NaN checks\n    assert out_cute.shape == out_ref_fp32.shape == out_pt.shape\n    assert not torch.isnan(out_cute).any()\n    assert not torch.isnan(out_ref_fp32).any()\n    assert not torch.isnan(out_pt).any()\n    assert torch.isfinite(out_cute).all()\n    assert torch.isfinite(out_ref_fp32).all()\n    assert torch.isfinite(out_pt).all()\n\n    # Numerical error if we just do any arithmetic on out_ref\n    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()\n    rtol = 2\n\n    # Calculate actual errors\n    pt_error = (out_pt - out_ref_fp32).abs().max().item()\n    cute_error = (out_cute - out_ref_fp32).abs().max().item()\n\n    print(f\"\\nNumerical comparison for {cute_score_mod.__name__}:\")\n    print(f\"  PyTorch vs FP32 ref max error: {pt_error:.2e}\")\n    print(f\"  CuTE vs FP32 ref max error: {cute_error:.2e}\")\n    print(f\"  Dynamic absolute tolerance: {fwd_atol:.2e}\")\n    print(f\"  Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}\")\n\n    # Assert that CuTE's error is at most rtol times PyTorch's error + fwd_atol\n    assert cute_error <= rtol * pt_error + fwd_atol, (\n        f\"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}\"\n    )\n\n\n@pytest.mark.parametrize(\"seqlen_q,seqlen_kv\", SEQLEN_CONFIGS)\n@pytest.mark.parametrize(\"qhead_per_kvhead,num_kv_heads\", [(1, 1), (4, 2)])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"score_mod_vec_pair\", TEST_PAIRS_WITH_AUX_TENSORS_VECTORIZED)\ndef test_cute_score_mod_with_aux_tensors_vectorized(\n    seqlen_q,\n    
seqlen_kv,\n    qhead_per_kvhead,\n    num_kv_heads,\n    dtype,\n    score_mod_vec_pair,\n):\n    \"\"\"Tests equality between original and vectorized versions of score mods\"\"\"\n    torch.random.manual_seed(42)\n    cute_score_mod, cute_vectorized_score_mod = score_mod_vec_pair\n    batch_size = 2\n\n    num_q_heads = num_kv_heads * qhead_per_kvhead\n    pack_gqa = qhead_per_kvhead > 1\n    q, k, v = create_tensors(\n        seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype\n    )\n    if pack_gqa:\n        k = k[:, :num_kv_heads, :, :].clone()\n        v = v[:, :num_kv_heads, :, :].clone()\n\n    if cute_score_mod == score_mod_10:\n        buffer = torch.randn(batch_size, device=\"cuda\", dtype=dtype) * 0.1\n        aux_tensors = [buffer]\n        assert buffer.shape == (batch_size,)\n    elif cute_score_mod == score_mod_11:\n        head_bias = torch.randn(num_q_heads, device=\"cuda\", dtype=dtype) * 0.2\n        pos_scale = torch.arange(seqlen_q, device=\"cuda\", dtype=dtype) * 0.01\n        aux_tensors = [head_bias, pos_scale]\n        assert head_bias.shape == (num_q_heads,)\n        assert pos_scale.shape == (seqlen_q,)\n\n    out_ref = run_cute_flash(q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa)\n\n    for vec_size in VEC_SIZES_TO_CHECK_EQUALITY:\n        cute_vectorized_score_mod.__vec_size__ = vec_size\n        out = run_cute_flash(\n            q,\n            k,\n            v,\n            cute_vectorized_score_mod,\n            aux_tensors=aux_tensors,\n            pack_gqa=pack_gqa,\n        )\n        assert torch.equal(out, out_ref)\n\n\ndef _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, device, dtype):\n    import math\n    from einops import rearrange\n\n    num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3\n    k_cache_paged = torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype)\n    v_cache_paged = torch.randn(num_blocks, page_size, 
nheads_k, d, device=device, dtype=dtype)\n    page_table = rearrange(\n        torch.randperm(num_blocks, dtype=torch.int32, device=device),\n        \"(b nblocks) -> b nblocks\",\n        b=batch_size,\n    )\n    k_cache_bshd = rearrange(\n        k_cache_paged[page_table.flatten()],\n        \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    v_cache_bshd = rearrange(\n        v_cache_paged[page_table.flatten()],\n        \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    k_cache = k_cache_bshd.transpose(1, 2)\n    v_cache = v_cache_bshd.transpose(1, 2)\n    return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"page_size\", [None, 1, 4, 128])\n@pytest.mark.parametrize(\"qhead_per_kvhead,num_kv_heads\", [(1, 2), (4, 2)])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_kv\",\n    [\n        (1, 128),\n        (64, 256),\n        (64, 800),\n        (256, 256),\n        (113, 203),\n    ],\n)\n@pytest.mark.parametrize(\"score_mod_pair\", TEST_PAIRS)\n@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason=\"Paged KV cache only supported on SM100\")\ndef test_score_mod_with_paged_kvcache(\n    seqlen_q,\n    seqlen_kv,\n    qhead_per_kvhead,\n    num_kv_heads,\n    page_size,\n    dtype,\n    score_mod_pair,\n):\n    if COMPUTE_CAPABILITY == 9:\n        pytest.xfail(\"Paged KV cache only supported on SM100\")\n    if page_size is not None and seqlen_kv % page_size != 0:\n        pytest.skip()\n\n    torch.random.manual_seed(42)\n    cute_score_mod, eager_score_mod = score_mod_pair\n\n    batch_size = 2\n    num_q_heads = num_kv_heads * qhead_per_kvhead\n    pack_gqa = qhead_per_kvhead > 1\n    dim = 128\n    device = \"cuda\"\n\n    q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype)\n\n    if 
page_size is None:\n        k_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype)\n        v_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype)\n        page_table = None\n        k_cache_paged = None\n        v_cache_paged = None\n    else:\n        (\n            k_cache,\n            v_cache,\n            page_table,\n            k_cache_paged,\n            v_cache_paged,\n            num_blocks,\n        ) = _generate_block_kvcache(\n            seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype\n        )\n\n    cache_seqlens = torch.randint(1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device)\n\n    from einops import rearrange\n\n    arange = rearrange(torch.arange(seqlen_kv, device=device), \"s -> 1 s\")\n    cache_seqlens_expanded = rearrange(cache_seqlens, \"b -> b 1\")\n    key_padding_mask = arange < cache_seqlens_expanded\n\n    if pack_gqa:\n        k_cache_rep = k_cache.repeat_interleave(qhead_per_kvhead, dim=1)\n        v_cache_rep = v_cache.repeat_interleave(qhead_per_kvhead, dim=1)\n    else:\n        k_cache_rep = k_cache\n        v_cache_rep = v_cache\n\n    def make_masked_score_mod(base_score_mod, seqused_k_tensor):\n        seqused_k_dev = seqused_k_tensor\n\n        def masked_score_mod(score, b, h, q_idx, kv_idx):\n            if base_score_mod is not None:\n                score = base_score_mod(score, b, h, q_idx, kv_idx)\n            seqlen_limit = torch.gather(seqused_k_dev, 0, b.long())\n            valid_mask = kv_idx < seqlen_limit\n            return torch.where(valid_mask, score, torch.full_like(score, float(\"-inf\")))\n\n        return masked_score_mod\n\n    masked_score_mod_fp32 = make_masked_score_mod(eager_score_mod, cache_seqlens)\n    masked_score_mod = make_masked_score_mod(eager_score_mod, cache_seqlens)\n\n    out_ref_fp32 = run_flex_reference(\n        q, k_cache_rep, v_cache_rep, masked_score_mod_fp32, 
dtype=torch.float32\n    )\n    out_pt = run_flex_reference(q, k_cache_rep, v_cache_rep, masked_score_mod)\n\n    q_bshd = q.transpose(1, 2)\n    out_cute = torch.empty_like(q_bshd)\n\n    if page_size is None:\n        k_bshd = k_cache.transpose(1, 2)\n        v_bshd = v_cache.transpose(1, 2)\n        _flash_attn_fwd(\n            q_bshd,\n            k_bshd,\n            v_bshd,\n            seqused_k=cache_seqlens,\n            return_lse=True,\n            score_mod=cute_score_mod,\n            out=out_cute,\n            lse=None,\n            pack_gqa=pack_gqa,\n        )\n    else:\n        _flash_attn_fwd(\n            q_bshd,\n            k_cache_paged,\n            v_cache_paged,\n            seqused_k=cache_seqlens,\n            page_table=page_table,\n            return_lse=True,\n            score_mod=cute_score_mod,\n            out=out_cute,\n            lse=None,\n            pack_gqa=pack_gqa,\n        )\n\n    out_cute = out_cute.transpose(1, 2)\n\n    assert out_cute.shape == out_ref_fp32.shape == out_pt.shape\n    assert not torch.isnan(out_cute).any()\n    assert not torch.isnan(out_ref_fp32).any()\n    assert not torch.isnan(out_pt).any()\n    assert torch.isfinite(out_cute).all()\n    assert torch.isfinite(out_ref_fp32).all()\n    assert torch.isfinite(out_pt).all()\n\n    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()\n    rtol = 2\n\n    pt_error = (out_pt - out_ref_fp32).abs().max().item()\n    cute_error = (out_cute - out_ref_fp32).abs().max().item()\n\n    print(f\"\\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):\")\n    print(f\"  PyTorch vs FP32 ref max error: {pt_error:.2e}\")\n    print(f\"  CuTE vs FP32 ref max error: {cute_error:.2e}\")\n    print(f\"  Dynamic absolute tolerance: {fwd_atol:.2e}\")\n    print(f\"  Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}\")\n\n    assert cute_error <= rtol * pt_error + fwd_atol, (\n        f\"CuTE error 
{cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}\"\n    )\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"page_size\", [None, 128])\n@pytest.mark.parametrize(\"qhead_per_kvhead,num_kv_heads\", [(1, 1), (4, 2)])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_kv\",\n    [\n        (64, 128),\n        (128, 256),\n        (256, 256),\n    ],\n)\n@pytest.mark.parametrize(\"score_mod_pair\", TEST_PAIRS_WITH_AUX_TENSORS)\n@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason=\"Paged KV cache only supported on SM100\")\ndef test_score_mod_with_paged_kvcache_aux_tensors(\n    seqlen_q,\n    seqlen_kv,\n    qhead_per_kvhead,\n    num_kv_heads,\n    page_size,\n    dtype,\n    score_mod_pair,\n):\n    if COMPUTE_CAPABILITY == 9:\n        pytest.xfail(\"Paged KV cache only supported on SM100\")\n    if page_size is not None and seqlen_kv % page_size != 0:\n        pytest.skip()\n\n    torch.random.manual_seed(42)\n    cute_score_mod, eager_score_mod_factory = score_mod_pair\n\n    batch_size = 2\n    num_q_heads = num_kv_heads * qhead_per_kvhead\n    pack_gqa = qhead_per_kvhead > 1\n    dim = 128\n    device = \"cuda\"\n\n    q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype)\n\n    if page_size is None:\n        k_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype)\n        v_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype)\n        page_table = None\n        k_cache_paged = None\n        v_cache_paged = None\n    else:\n        (\n            k_cache,\n            v_cache,\n            page_table,\n            k_cache_paged,\n            v_cache_paged,\n            num_blocks,\n        ) = _generate_block_kvcache(\n            seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype\n        )\n\n    cache_seqlens = torch.randint(1, seqlen_kv + 1, (batch_size,), 
dtype=torch.int32, device=device)\n\n    if cute_score_mod == score_mod_10:\n        buffer = torch.randn(batch_size, device=device, dtype=dtype) * 0.1\n        aux_tensors = [buffer]\n        eager_score_mod = eager_score_mod_factory(buffer)\n    elif cute_score_mod == score_mod_11:\n        head_bias = torch.randn(num_q_heads, device=device, dtype=dtype) * 0.2\n        pos_scale = torch.arange(seqlen_q, device=device, dtype=dtype) * 0.01\n        aux_tensors = [head_bias, pos_scale]\n        eager_score_mod = eager_score_mod_factory(head_bias, pos_scale)\n\n    from einops import rearrange\n\n    arange = rearrange(torch.arange(seqlen_kv, device=device), \"s -> 1 s\")\n    cache_seqlens_expanded = rearrange(cache_seqlens, \"b -> b 1\")\n    key_padding_mask = arange < cache_seqlens_expanded\n\n    if pack_gqa:\n        k_cache_rep = k_cache.repeat_interleave(qhead_per_kvhead, dim=1)\n        v_cache_rep = v_cache.repeat_interleave(qhead_per_kvhead, dim=1)\n    else:\n        k_cache_rep = k_cache\n        v_cache_rep = v_cache\n\n    def make_masked_score_mod(base_score_mod, seqused_k_tensor):\n        seqused_k_dev = seqused_k_tensor\n\n        def masked_score_mod(score, b, h, q_idx, kv_idx):\n            if base_score_mod is not None:\n                score = base_score_mod(score, b, h, q_idx, kv_idx)\n            seqlen_limit = torch.gather(seqused_k_dev, 0, b.long())\n            valid_mask = kv_idx < seqlen_limit\n            return torch.where(valid_mask, score, torch.full_like(score, float(\"-inf\")))\n\n        return masked_score_mod\n\n    masked_score_mod_fp32 = make_masked_score_mod(eager_score_mod, cache_seqlens)\n    masked_score_mod = make_masked_score_mod(eager_score_mod, cache_seqlens)\n\n    out_ref_fp32 = run_flex_reference(\n        q, k_cache_rep, v_cache_rep, masked_score_mod_fp32, dtype=torch.float32\n    )\n    out_pt = run_flex_reference(q, k_cache_rep, v_cache_rep, masked_score_mod)\n\n    q_bshd = q.transpose(1, 2)\n    out_cute = 
torch.empty_like(q_bshd)\n\n    if page_size is None:\n        k_bshd = k_cache.transpose(1, 2)\n        v_bshd = v_cache.transpose(1, 2)\n        _flash_attn_fwd(\n            q_bshd,\n            k_bshd,\n            v_bshd,\n            seqused_k=cache_seqlens,\n            return_lse=True,\n            score_mod=cute_score_mod,\n            out=out_cute,\n            lse=None,\n            aux_tensors=aux_tensors,\n            pack_gqa=pack_gqa,\n        )\n    else:\n        _flash_attn_fwd(\n            q_bshd,\n            k_cache_paged,\n            v_cache_paged,\n            seqused_k=cache_seqlens,\n            page_table=page_table,\n            return_lse=True,\n            score_mod=cute_score_mod,\n            out=out_cute,\n            lse=None,\n            aux_tensors=aux_tensors,\n            pack_gqa=pack_gqa,\n        )\n\n    out_cute = out_cute.transpose(1, 2)\n\n    assert out_cute.shape == out_ref_fp32.shape == out_pt.shape\n    assert not torch.isnan(out_cute).any()\n    assert not torch.isnan(out_ref_fp32).any()\n    assert not torch.isnan(out_pt).any()\n    assert torch.isfinite(out_cute).all()\n    assert torch.isfinite(out_ref_fp32).all()\n    assert torch.isfinite(out_pt).all()\n\n    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()\n    rtol = 2\n\n    pt_error = (out_pt - out_ref_fp32).abs().max().item()\n    cute_error = (out_cute - out_ref_fp32).abs().max().item()\n\n    print(f\"\\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):\")\n    print(f\"  PyTorch vs FP32 ref max error: {pt_error:.2e}\")\n    print(f\"  CuTE vs FP32 ref max error: {cute_error:.2e}\")\n    print(f\"  Dynamic absolute tolerance: {fwd_atol:.2e}\")\n    print(f\"  Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}\")\n\n    assert cute_error <= rtol * pt_error + fwd_atol, (\n        f\"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}\"\n    
)\n\n\n@cute.jit\ndef score_mod_bwd_5(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    \"\"\"Backward for score_mod_5 (times_two): d(score*2)/d(score) = 2.\"\"\"\n    return grad * cute.full_like(grad, 2.0)\n\n\n@cute.jit\ndef score_mod_bwd_3(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    \"\"\"Backward for score_mod_3 (relative_bias): d(score + |q-kv|)/d(score) = 1.\"\"\"\n    return grad\n\n\n@cute.jit\ndef score_mod_bwd_identity(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    return grad\n\n\n@cute.jit\ndef score_mod_bwd_causal(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    \"\"\"Backward for causal masking: d(where(mask, score, -inf))/d(score) = where(mask, 1, 0).\n\n    At unmasked positions (q_idx >= kv_idx), grad passes through.\n    At masked positions (q_idx < kv_idx), the kernel already zeros grad because P=0.\n    \"\"\"\n    return grad\n\n\n@cute.jit\ndef score_mod_squared(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    \"\"\"Forward: score ** 2.\"\"\"\n    return tSrS_ssa * tSrS_ssa\n\n\n@cute.jit\ndef score_mod_bwd_squared(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):\n    \"\"\"Backward for score**2: d(score**2)/d(score) = 2*score.\"\"\"\n    return grad * cute.full_like(grad, 2.0) * score\n\n\ndef score_squared_eager(score, b, h, q_idx, kv_idx):\n    return score * score\n\n\nBWD_TEST_PAIRS = [\n    (score_mod_5, score_mod_bwd_5, times_two_eager),\n    (score_mod_3, score_mod_bwd_3, relative_bias_eager),\n    (score_mod_squared, score_mod_bwd_squared, score_squared_eager),\n    (score_mod_2, score_mod_bwd_causal, causal_mask_eager),\n]\n\nBWD_TEST_PAIRS_WITH_AUX = [\n    (score_mod_10, score_mod_bwd_identity, batch_bias),\n    (score_mod_11, score_mod_bwd_identity, dual_buffer_bias),\n]\n\nBWD_TEST_PAIRS_PACK_GQA = [\n    (score_mod_5, score_mod_bwd_5, times_two_eager),\n    (score_mod_3, 
score_mod_bwd_3, relative_bias_eager),\n]\n\n\ndef run_cute_flash_bwd(\n    q, k, v, cute_score_mod, cute_score_mod_bwd, aux_tensors=None, pack_gqa=False\n):\n    \"\"\"Run flash attention forward + backward with score_mod.\"\"\"\n    q_t = q.transpose(1, 2)\n    k_t = k.transpose(1, 2)\n    v_t = v.transpose(1, 2)\n\n    out, lse = _flash_attn_fwd(\n        q_t,\n        k_t,\n        v_t,\n        return_lse=True,\n        score_mod=cute_score_mod,\n        aux_tensors=aux_tensors,\n        pack_gqa=pack_gqa,\n    )\n\n    grad_out = torch.randn_like(out)\n\n    dq, dk, dv = _flash_attn_bwd(\n        q_t,\n        k_t,\n        v_t,\n        out,\n        grad_out,\n        lse,\n        score_mod=cute_score_mod,\n        score_mod_bwd=cute_score_mod_bwd,\n        aux_tensors=aux_tensors,\n        pack_gqa=pack_gqa,\n    )\n\n    return (\n        out.transpose(1, 2),\n        grad_out.transpose(1, 2),\n        dq.transpose(1, 2),\n        dk.transpose(1, 2),\n        dv.transpose(1, 2),\n    )\n\n\ndef run_flex_reference_bwd(q, k, v, eager_score_mod, grad_out, dtype=None):\n    \"\"\"Run flex_attention forward + backward for reference.\"\"\"\n    if dtype is not None:\n        q = q.to(dtype).requires_grad_(True)\n        k = k.to(dtype).requires_grad_(True)\n        v = v.to(dtype).requires_grad_(True)\n        grad_out = grad_out.to(dtype)\n    else:\n        q = q.requires_grad_(True)\n        k = k.requires_grad_(True)\n        v = v.requires_grad_(True)\n\n    compiled_flex = torch.compile(flex_attention)\n    out = compiled_flex(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1])\n    dq, dk, dv = torch.autograd.grad(out, (q, k, v), grad_out)\n\n    return out, dq, dk, dv\n\n\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_kv\",\n    [\n        (64, 64),\n        (128, 128),\n        (256, 256),\n        (512, 512),\n        (799, 3),\n        (3, 799),\n        (128, 256),\n        (256, 128),\n        (113, 203),\n    
],\n)\n@pytest.mark.parametrize(\"dim\", [64, 128])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16])\n@pytest.mark.parametrize(\"score_mod_triple\", BWD_TEST_PAIRS)\ndef test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_mod_triple):\n    \"\"\"Test backward pass with score_mod against flex_attention reference.\"\"\"\n    if COMPUTE_CAPABILITY == 9 and dim == 64:\n        pytest.skip(\"head_dim=64 not supported on SM90 for backward\")\n\n    torch.random.manual_seed(42)\n    cute_fwd, cute_bwd, eager_ref = score_mod_triple\n\n    q, k, v = create_tensors(\n        seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=4, dim=dim, dtype=dtype\n    )\n\n    out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(q, k, v, cute_fwd, cute_bwd)\n    out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd(\n        q, k, v, eager_ref, grad_out, dtype=torch.float32\n    )\n    out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out)\n\n    assert not torch.isnan(dq_cute).any(), \"dQ contains NaN\"\n    assert not torch.isnan(dk_cute).any(), \"dK contains NaN\"\n    assert not torch.isnan(dv_cute).any(), \"dV contains NaN\"\n\n    rtol = 2\n    dq_atol = 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item()\n    dk_atol = 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item()\n    dv_atol = 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item()\n\n    dq_ref = dq_ref_fp32.to(dtype)\n    dk_ref = dk_ref_fp32.to(dtype)\n    dv_ref = dv_ref_fp32.to(dtype)\n\n    pt_dq_err = (dq_pt - dq_ref).abs().max().item()\n    pt_dk_err = (dk_pt - dk_ref).abs().max().item()\n    pt_dv_err = (dv_pt - dv_ref).abs().max().item()\n\n    cute_dq_err = (dq_cute - dq_ref).abs().max().item()\n    cute_dk_err = (dk_cute - dk_ref).abs().max().item()\n    cute_dv_err = (dv_cute - dv_ref).abs().max().item()\n\n    print(f\"\\nBackward comparison for {cute_fwd.__name__}:\")\n    
print(f\"  dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}\")\n    print(f\"  dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}\")\n    print(f\"  dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}\")\n\n    assert cute_dq_err <= rtol * pt_dq_err + dq_atol, f\"dQ error too large: {cute_dq_err:.2e}\"\n    assert cute_dk_err <= rtol * pt_dk_err + dk_atol, f\"dK error too large: {cute_dk_err:.2e}\"\n    assert cute_dv_err <= rtol * pt_dv_err + dv_atol, f\"dV error too large: {cute_dv_err:.2e}\"\n\n\ndef make_aux_tensors_for_bwd(cute_score_mod, eager_factory, seqlen_q, num_heads, batch_size, dtype):\n    if cute_score_mod == score_mod_10:\n        buffer = torch.randn(batch_size, device=\"cuda\", dtype=dtype) * 0.1\n        return [buffer], eager_factory(buffer)\n    head_bias = torch.randn(num_heads, device=\"cuda\", dtype=dtype) * 0.2\n    pos_scale = torch.arange(seqlen_q, device=\"cuda\", dtype=dtype) * 0.01\n    return [head_bias, pos_scale], eager_factory(head_bias, pos_scale)\n\n\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_kv\",\n    [\n        (64, 64),\n        (128, 128),\n        (256, 128),\n    ],\n)\n@pytest.mark.parametrize(\"dim\", [64, 128])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16])\n@pytest.mark.parametrize(\"score_mod_triple\", BWD_TEST_PAIRS_WITH_AUX)\ndef test_cute_vs_flex_attention_backward_with_aux(\n    seqlen_q, seqlen_kv, dim, dtype, score_mod_triple\n):\n    if COMPUTE_CAPABILITY == 9 and dim == 64:\n        pytest.skip(\"head_dim=64 not supported on SM90 for backward\")\n\n    torch.random.manual_seed(42)\n    cute_fwd, cute_bwd, eager_factory = score_mod_triple\n\n    q, k, v = create_tensors(\n        seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=4, dim=dim, dtype=dtype\n    )\n\n    aux_tensors, eager_ref = make_aux_tensors_for_bwd(\n        cute_fwd, eager_factory, seqlen_q, q.shape[1], q.shape[0], dtype\n    )\n\n    
out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(\n        q, k, v, cute_fwd, cute_bwd, aux_tensors=aux_tensors\n    )\n    out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd(\n        q, k, v, eager_ref, grad_out, dtype=torch.float32\n    )\n    out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out)\n\n    assert not torch.isnan(dq_cute).any()\n    assert not torch.isnan(dk_cute).any()\n    assert not torch.isnan(dv_cute).any()\n\n    rtol = 3\n    dq_atol = 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item()\n    dk_atol = 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item()\n    dv_atol = 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item()\n\n    dq_ref = dq_ref_fp32.to(dtype)\n    dk_ref = dk_ref_fp32.to(dtype)\n    dv_ref = dv_ref_fp32.to(dtype)\n\n    pt_dq_err = (dq_pt - dq_ref).abs().max().item()\n    pt_dk_err = (dk_pt - dk_ref).abs().max().item()\n    pt_dv_err = (dv_pt - dv_ref).abs().max().item()\n\n    cute_dq_err = (dq_cute - dq_ref).abs().max().item()\n    cute_dk_err = (dk_cute - dk_ref).abs().max().item()\n    cute_dv_err = (dv_cute - dv_ref).abs().max().item()\n\n    print(f\"\\nBackward comparison with aux for {cute_fwd.__name__}:\")\n    print(f\"  dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}\")\n    print(f\"  dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}\")\n    print(f\"  dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}\")\n\n    assert cute_dq_err <= rtol * pt_dq_err + dq_atol, f\"dQ error too large: {cute_dq_err:.2e}\"\n    assert cute_dk_err <= rtol * pt_dk_err + dk_atol, f\"dK error too large: {cute_dk_err:.2e}\"\n    assert cute_dv_err <= rtol * pt_dv_err + dv_atol, f\"dV error too large: {cute_dv_err:.2e}\"\n\n\n@pytest.mark.parametrize(\"seqlen_q,seqlen_kv\", [(128, 128), (128, 256)])\n@pytest.mark.parametrize(\"dim\", [64, 
128])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16])\n@pytest.mark.parametrize(\"qhead_per_kvhead,num_kv_heads\", [(4, 2)])\n@pytest.mark.parametrize(\"score_mod_triple\", BWD_TEST_PAIRS_PACK_GQA)\ndef test_cute_vs_flex_attention_backward_pack_gqa(\n    seqlen_q, seqlen_kv, dim, dtype, qhead_per_kvhead, num_kv_heads, score_mod_triple\n):\n    if COMPUTE_CAPABILITY == 9:\n        pytest.xfail(\"pack_gqa backward not yet implemented on SM90\")\n\n    torch.random.manual_seed(42)\n    cute_fwd, cute_bwd, eager_ref = score_mod_triple\n\n    num_q_heads = num_kv_heads * qhead_per_kvhead\n    q, k, v = create_tensors(\n        seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dim=dim, dtype=dtype\n    )\n    k = k[:, :num_kv_heads, :, :].clone()\n    v = v[:, :num_kv_heads, :, :].clone()\n\n    out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(\n        q, k, v, cute_fwd, cute_bwd, pack_gqa=True\n    )\n    out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd(\n        q, k, v, eager_ref, grad_out, dtype=torch.float32\n    )\n    out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out)\n\n    assert not torch.isnan(dq_cute).any()\n    assert not torch.isnan(dk_cute).any()\n    assert not torch.isnan(dv_cute).any()\n\n    rtol = 3\n    dq_atol = 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item()\n    dk_atol = 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item()\n    dv_atol = 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item()\n\n    dq_ref = dq_ref_fp32.to(dtype)\n    dk_ref = dk_ref_fp32.to(dtype)\n    dv_ref = dv_ref_fp32.to(dtype)\n\n    pt_dq_err = (dq_pt - dq_ref).abs().max().item()\n    pt_dk_err = (dk_pt - dk_ref).abs().max().item()\n    pt_dv_err = (dv_pt - dv_ref).abs().max().item()\n\n    cute_dq_err = (dq_cute - dq_ref).abs().max().item()\n    cute_dk_err = (dk_cute - dk_ref).abs().max().item()\n    cute_dv_err = (dv_cute - 
dv_ref).abs().max().item()\n\n    print(f\"\\nBackward Pack-GQA comparison for {cute_fwd.__name__}:\")\n    print(f\"  dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}\")\n    print(f\"  dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}\")\n    print(f\"  dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}\")\n\n    assert cute_dq_err <= rtol * pt_dq_err + dq_atol, f\"dQ error too large: {cute_dq_err:.2e}\"\n    assert cute_dk_err <= rtol * pt_dk_err + dk_atol, f\"dK error too large: {cute_dk_err:.2e}\"\n    assert cute_dv_err <= rtol * pt_dv_err + dv_atol, f\"dV error too large: {cute_dv_err:.2e}\"\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\"])\n"
  },
  {
    "path": "tests/cute/test_score_mod_varlen.py",
    "content": "import pytest\nimport torch\nfrom torch.nn.attention.flex_attention import flex_attention\nfrom flash_attn.cute.interface import _flash_attn_fwd\nfrom test_score_mod import _generate_block_kvcache\nfrom score_mod_definitions import (\n    # TensorSSA-based score mods\n    score_mod_alibi,\n    score_mod_batch_bias,\n    score_mod_block_diagonal,\n    score_mod_causal,\n    score_mod_causal_v2,\n    score_mod_debug_global_idx,\n    score_mod_dual_buffer,\n    score_mod_global_kv_bias,\n    score_mod_global_logical_rel_plus_kv_bias,\n    score_mod_global_q_and_kv_bias,\n    score_mod_global_q_bias,\n    score_mod_global_rel_plus_kv_bias,\n    score_mod_identity,\n    score_mod_rel_bias,\n    score_mod_rel_bias_x2,\n    score_mod_sliding_window,\n    score_mod_stress_complex_arithmetic,\n    score_mod_stress_conditional_mask,\n    score_mod_stress_global_offset,\n    score_mod_stress_multi_buffer,\n    score_mod_stress_xor_pattern,\n    score_mod_times_two,\n)  # isort: split\nfrom score_mod_definitions import (\n    score_mod_identity_vectorized,\n    score_mod_causal_vectorized,\n    score_mod_rel_bias as score_mod_rel_bias_vectorized,\n    score_mod_rel_bias_x2_vectorized,\n    score_mod_times_two_vectorized,\n    score_mod_alibi_vectorized,\n    score_mod_batch_bias_vectorized,\n    score_mod_dual_buffer_vectorized,\n)  # isort: split\nfrom score_mod_definitions import (\n    # Eager (torch) reference score mods\n    identity_eager,\n    causal_eager,\n    rel_bias_eager,\n    rel_bias_x2_eager,\n    times_two_eager,\n    alibi_eager,\n    sliding_window_eager,\n    block_diagonal_eager,\n    causal_v2_eager,\n    batch_bias_factory,\n    dual_buffer_factory,\n    packed_kv_bias_factory,\n    packed_q_bias_factory,\n    packed_rel_plus_kv_bias_factory,\n    packed_q_and_kv_bias_factory,\n    packed_logical_rel_plus_kv_bias_factory,\n    stress_complex_arithmetic_factory,\n    stress_conditional_mask_factory,\n    stress_multi_buffer_factory,\n    
stress_global_offset_factory,\n    stress_xor_pattern_factory,\n    debug_global_idx_factory,\n)\n\nIS_SM90 = torch.cuda.get_device_capability()[0] == 9\nIS_SM100 = torch.cuda.get_device_capability()[0] == 10\n\n# =============================================================================\n# Test pairs\n# =============================================================================\n\n# (cute_score_mod, eager_factory_or_fn, aux_type)\n# aux_type: None, \"batch\", \"dual_buffer\"\n# All score_mods use 7-arg signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors)\nTEST_PAIRS_NO_GLOBAL = [\n    (score_mod_identity, identity_eager, None),\n    (score_mod_causal, causal_eager, None),\n    (score_mod_rel_bias, rel_bias_eager, None),\n    (score_mod_rel_bias_x2, rel_bias_x2_eager, None),\n    (score_mod_times_two, times_two_eager, None),\n    (score_mod_alibi, alibi_eager, None),\n    (score_mod_sliding_window, sliding_window_eager, None),\n    (score_mod_block_diagonal, block_diagonal_eager, None),\n    (score_mod_causal_v2, causal_v2_eager, None),\n    (score_mod_batch_bias, batch_bias_factory, \"batch\"),\n    (score_mod_dual_buffer, dual_buffer_factory, \"dual_buffer\"),\n]\n\n# Test triples to compare vectorized score_mods: (cute_jit_function, cute_jit_function_vectorized, aux_type)\nTEST_PAIRS_VECTORIZED_NO_GLOBAL = [\n    (score_mod_identity, score_mod_identity_vectorized, None),\n    (score_mod_causal, score_mod_causal_vectorized, None),\n    (score_mod_rel_bias, score_mod_rel_bias_vectorized, None),\n    (score_mod_rel_bias_x2, score_mod_rel_bias_x2_vectorized, None),\n    (score_mod_times_two, score_mod_times_two_vectorized, None),\n    (score_mod_alibi, score_mod_alibi_vectorized, None),\n    (score_mod_batch_bias, score_mod_batch_bias_vectorized, \"batch\"),\n    (score_mod_dual_buffer, score_mod_dual_buffer_vectorized, \"dual_buffer\"),\n]\n# (cute_score_mod, eager_factory, aux_type, requires_global)\n# aux_type: \"kv\", \"q\", \"q_and_kv\", 
\"q_concat\", \"kv_with_cu\", \"multi_buffer\"\n# requires_global: \"q\" (needs varlen_q), \"kv\" (needs varlen_k), \"both\" (needs both)\n# All score_mods use 7-arg signature and compute global indices from seqlen_info\nTEST_PAIRS_WITH_GLOBAL = [\n    (score_mod_global_kv_bias, packed_kv_bias_factory, \"kv\", \"kv\"),\n    (score_mod_global_q_bias, packed_q_bias_factory, \"q\", \"q\"),\n    (score_mod_global_rel_plus_kv_bias, packed_rel_plus_kv_bias_factory, \"kv\", \"kv\"),\n    (score_mod_global_q_and_kv_bias, packed_q_and_kv_bias_factory, \"q_and_kv\", \"both\"),\n    (\n        score_mod_global_logical_rel_plus_kv_bias,\n        packed_logical_rel_plus_kv_bias_factory,\n        \"kv\",\n        \"kv\",\n    ),\n    (\n        score_mod_stress_complex_arithmetic,\n        stress_complex_arithmetic_factory,\n        \"q_concat\",\n        \"q\",\n    ),\n    (\n        score_mod_stress_conditional_mask,\n        stress_conditional_mask_factory,\n        \"kv_with_cu\",\n        \"both\",\n    ),\n    (\n        score_mod_stress_multi_buffer,\n        stress_multi_buffer_factory,\n        \"multi_buffer\",\n        \"both\",\n    ),\n    (score_mod_stress_global_offset, stress_global_offset_factory, \"kv\", \"kv\"),\n    (score_mod_stress_xor_pattern, stress_xor_pattern_factory, \"kv_with_cu\", \"kv\"),\n    (score_mod_debug_global_idx, debug_global_idx_factory, \"kv\", \"kv\"),\n]\n\nSEQLEN_CONFIGS = [\n    ([1], [1]),\n    ([1, 1], [1, 1]),\n    ([2, 3], [2, 3]),\n    ([8, 16], [8, 16]),\n    ([32, 32], [32, 32]),\n    ([64, 128], [64, 128]),\n    ([64, 56, 128], [64, 56, 128]),\n    ([256, 512], [256, 512]),\n    ([113, 203], [113, 203]),\n    ([239, 1], [239, 1]),\n    ([64], [64]),\n    ([128, 128], [128, 128]),\n    ([32, 32, 32, 32], [32, 32, 32, 32]),\n    ([16, 32, 64, 128, 256], [16, 32, 64, 128, 256]),\n    ([1, 1024], [1, 1024]),\n    ([1024, 1], [1024, 1]),\n    ([1, 256, 1], [1, 256, 1]),\n    ([256, 1, 256], [256, 1, 256]),\n    ([17, 33, 65], [17, 
33, 65]),\n    ([64, 128], [32, 64]),\n    ([100, 100], [50, 50]),\n    ([256, 512, 256], [128, 256, 128]),\n    ([2, 1], [16384, 32 * 1024]),\n    ([1, 1], [128 * 1024] * 2),\n    ([2, 1], [8192, 8192]),\n    ([1, 3], [8192, 8192]),\n    ([3, 3], [8192, 8192]),\n    ([128, 128], [8192, 8192]),\n    ([2, 2, 2], [8 * 1024] * 3),\n    ([2, 1], [1024 * 32, 16384]),\n    ([1, 2], [1024 * 32, 16384]),\n    ([1, 1, 1], [128 * 1024] * 3),\n    ([1, 1, 1], [256 * 1024] * 3),\n]\n\nVEC_SIZES_TO_CHECK_EQUALITY = [1, 2, 4] if IS_SM100 else [1, 2]\n\n# =============================================================================\n# Helper functions\n# =============================================================================\n\n\ndef run_cute_flash(\n    q,\n    k,\n    v,\n    score_mod,\n    aux_tensors=None,\n    pack_gqa=False,\n    cu_seqlens_q=None,\n    cu_seqlens_k=None,\n    page_table=None,\n    seqused_k=None,\n):\n    \"\"\"Run CuTE flash attention.\"\"\"\n    if cu_seqlens_q is not None or cu_seqlens_k is not None:\n        out = torch.empty_like(q)\n        _flash_attn_fwd(\n            q,\n            k,\n            v,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n            seqused_k=seqused_k,\n            page_table=page_table,\n            return_lse=True,\n            score_mod=score_mod,\n            out=out,\n            lse=None,\n            aux_tensors=aux_tensors,\n            pack_gqa=pack_gqa,\n        )\n        return out\n\n    out = torch.empty_like(q)\n    _flash_attn_fwd(\n        q,\n        k,\n        v,\n        seqused_k=seqused_k,\n        page_table=page_table,\n        return_lse=True,\n        score_mod=score_mod,\n        out=out,\n        lse=None,\n        aux_tensors=aux_tensors,\n        pack_gqa=pack_gqa,\n    )\n    return out\n\n\ndef run_flex_varlen_ref(q, k, v, cu_seqlens_q, cu_seqlens_k, score_mod, dtype=None):\n    \"\"\"Run flex_attention per-sequence for varlen reference.\"\"\"\n   
 if cu_seqlens_q is not None:\n        num_batches = len(cu_seqlens_q) - 1\n    else:\n        num_batches = len(cu_seqlens_k) - 1\n\n    results = []\n    for i in range(num_batches):\n        # Get Q slice\n        if cu_seqlens_q is not None:\n            q_slice = (\n                q[cu_seqlens_q[i] : cu_seqlens_q[i + 1]].unsqueeze(0).transpose(1, 2)\n            )\n        else:\n            q_slice = q[i : i + 1].transpose(1, 2)\n\n        # Get K/V slices\n        if cu_seqlens_k is not None:\n            k_slice = (\n                k[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze(0).transpose(1, 2)\n            )\n            v_slice = (\n                v[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze(0).transpose(1, 2)\n            )\n        else:\n            k_slice = k[i : i + 1].transpose(1, 2)\n            v_slice = v[i : i + 1].transpose(1, 2)\n\n        if dtype is not None:\n            q_slice, k_slice, v_slice = (\n                q_slice.to(dtype),\n                k_slice.to(dtype),\n                v_slice.to(dtype),\n            )\n\n        def wrapped_mod(score, b, h, q_idx, kv_idx):\n            return score_mod(score, i, h, q_idx, kv_idx)\n\n        out = flex_attention(\n            q_slice,\n            k_slice,\n            v_slice,\n            score_mod=wrapped_mod,\n            enable_gqa=q_slice.shape[1] != k_slice.shape[1],\n        )\n        results.append(out.transpose(1, 2).squeeze(0))\n\n    return torch.cat(results, dim=0)\n\n\ndef setup_tensors(seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype):\n    \"\"\"Create Q, K, V tensors and cu_seqlens based on varlen flags.\"\"\"\n    batch_size = len(seqlens_q)\n\n    if varlen_q:\n        total_q = sum(seqlens_q)\n        q = torch.randn(total_q, num_heads, head_dim, device=\"cuda\", dtype=dtype)\n        cu_seqlens_q = torch.tensor(\n            [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()),\n            device=\"cuda\",\n            
dtype=torch.int32,\n        )\n    else:\n        seqlen_q = seqlens_q[0]  # All sequences have the same length for non-varlen\n        q = torch.randn(\n            batch_size, seqlen_q, num_heads, head_dim, device=\"cuda\", dtype=dtype\n        )\n        cu_seqlens_q = None\n\n    if varlen_k:\n        total_k = sum(seqlens_k)\n        k = torch.randn(total_k, num_heads, head_dim, device=\"cuda\", dtype=dtype)\n        v = torch.randn(total_k, num_heads, head_dim, device=\"cuda\", dtype=dtype)\n        cu_seqlens_k = torch.tensor(\n            [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()),\n            device=\"cuda\",\n            dtype=torch.int32,\n        )\n    else:\n        seqlen_k = seqlens_k[0]  # All sequences have the same length for non-varlen\n        k = torch.randn(\n            batch_size, seqlen_k, num_heads, head_dim, device=\"cuda\", dtype=dtype\n        )\n        v = torch.randn(\n            batch_size, seqlen_k, num_heads, head_dim, device=\"cuda\", dtype=dtype\n        )\n        cu_seqlens_k = None\n\n    return q, k, v, cu_seqlens_q, cu_seqlens_k\n\n\ndef prepare_ref_tensors(\n    q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q\n):\n    \"\"\"Prepare tensors for flex_attention reference (handle mixed varlen formats).\"\"\"\n    num_heads = q.shape[1] if varlen_q else q.shape[2]\n\n    if not varlen_q and varlen_k:\n        seqlen_q = q.shape[1]\n        q_packed = q.reshape(-1, num_heads, q.shape[-1])\n        ref_cu_seqlens_q = torch.tensor(\n            [seqlen_q * i for i in range(batch_size + 1)],\n            device=\"cuda\",\n            dtype=torch.int32,\n        )\n        return q_packed, k, v, ref_cu_seqlens_q, cu_seqlens_k\n\n    if varlen_q and not varlen_k:\n        return q, k, v, cu_seqlens_q, None\n\n    return q, k, v, cu_seqlens_q, cu_seqlens_k\n\n\ndef check_results(\n    out_cute,\n    out_ref_fp32,\n    out_pt,\n    test_name,\n    rtol=2,\n    extra_atol=1e-4,\n    
seqlens_q=None,\n    cu_seqlens_q=None,\n):\n    \"\"\"Compare CuTE output against references.\"\"\"\n    assert not torch.isnan(out_cute).any(), f\"{test_name}: NaN in output\"\n    assert torch.isfinite(out_cute).all(), f\"{test_name}: Inf in output\"\n\n    varlen_q = cu_seqlens_q is not None\n\n    if varlen_q:\n        # Unpack and compare per-sequence\n        assert seqlens_q is not None, \"varlen_q requires use of seqlens_q\"\n        num_seqs = len(seqlens_q)\n        max_cute_error = 0.0\n        max_pt_error = 0.0\n\n        for i in range(num_seqs):\n            # Extract sequences using cu_seqlens (all outputs are in packed format)\n            start_q = cu_seqlens_q[i]\n            end_q = cu_seqlens_q[i + 1]\n            cute_seq = out_cute[start_q:end_q]\n            ref_seq = out_ref_fp32[start_q:end_q]\n            pt_seq = out_pt[start_q:end_q]\n\n            max_cute_error = max(\n                max_cute_error, (cute_seq - ref_seq).abs().max().item()\n            )\n            max_pt_error = max(max_pt_error, (pt_seq - ref_seq).abs().max().item())\n\n        cute_error = max_cute_error\n        pt_error = max_pt_error\n    else:\n        # Direct comparison\n        pt_error = (out_pt - out_ref_fp32).abs().max().item()\n        cute_error = (out_cute - out_ref_fp32).abs().max().item()\n\n    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()\n\n    print(f\"\\n{test_name}:\")\n    print(f\"  PyTorch vs FP32 ref: {pt_error:.2e}\")\n    print(f\"  CuTE vs FP32 ref: {cute_error:.2e}\")\n\n    tol = rtol * pt_error + fwd_atol + extra_atol\n    assert cute_error <= tol, (\n        f\"{test_name}: CuTE error {cute_error:.2e} exceeds tolerance {tol:.2e}\"\n    )\n\n\n# =============================================================================\n# Tests\n# =============================================================================\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, 
torch.bfloat16])\n@pytest.mark.parametrize(\"varlen_q\", [True, False])\n@pytest.mark.parametrize(\"varlen_k\", [True, False])\n@pytest.mark.parametrize(\"qhead_per_kvhead,num_kv_heads\", [(4, 2)])\n@pytest.mark.parametrize(\"seqlens_q,seqlens_k\", SEQLEN_CONFIGS)\n@pytest.mark.parametrize(\"score_mod_tuple\", TEST_PAIRS_NO_GLOBAL)\ndef test_varlen_with_score_mod(\n    seqlens_q,\n    seqlens_k,\n    varlen_q,\n    varlen_k,\n    qhead_per_kvhead,\n    num_kv_heads,\n    dtype,\n    score_mod_tuple,\n):\n    \"\"\"Test varlen attention with score_mod functions that don't use global indices.\n\n    Covers: both varlen, varlen Q only, varlen K only.\n    Skips: neither varlen\n    \"\"\"\n    if not varlen_q and not varlen_k:\n        pytest.skip(\n            \"At least one of varlen_q or varlen_k must be True for varlen tests\"\n        )\n\n    # For non-varlen dimension, all sequences must have same length\n    if not varlen_q:\n        seqlens_q = [seqlens_q[0]] * len(seqlens_q)\n    if not varlen_k:\n        seqlens_k = [seqlens_k[0]] * len(seqlens_k)\n\n    torch.random.manual_seed(42)\n    cute_score_mod, eager_factory, aux_type = score_mod_tuple\n\n    num_heads = num_kv_heads * qhead_per_kvhead\n    pack_gqa = qhead_per_kvhead > 1\n    head_dim = 128\n    batch_size = len(seqlens_q)\n\n    q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors(\n        seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype\n    )\n\n    if pack_gqa:\n        if varlen_k:\n            k = k[:, :num_kv_heads, :].clone()\n            v = v[:, :num_kv_heads, :].clone()\n        else:\n            k = k[:, :, :num_kv_heads, :].clone()\n            v = v[:, :, :num_kv_heads, :].clone()\n\n    aux_tensors = None\n    if aux_type == \"batch\":\n        bias = torch.zeros(batch_size, device=\"cuda\", dtype=dtype) * 0.1\n        aux_tensors = [bias]\n        eager_score_mod = eager_factory(bias)\n    elif aux_type == \"dual_buffer\":\n        seqlen_q = seqlens_q[0] if 
not varlen_q else max(seqlens_q)\n        head_bias = torch.randn(num_heads, device=\"cuda\", dtype=dtype) * 0.2\n        pos_bias = torch.arange(seqlen_q, device=\"cuda\", dtype=dtype) * 0.01\n        aux_tensors = [head_bias, pos_bias]\n        eager_score_mod = eager_factory(head_bias, pos_bias)\n    else:\n        eager_score_mod = eager_factory\n\n    # Prepare reference tensors\n    q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors(\n        q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q\n    )\n\n    out_ref_fp32 = run_flex_varlen_ref(\n        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32\n    )\n    out_pt = run_flex_varlen_ref(\n        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype\n    )\n    out_cute = run_cute_flash(\n        q,\n        k,\n        v,\n        cute_score_mod,\n        aux_tensors=aux_tensors,\n        pack_gqa=pack_gqa,\n        cu_seqlens_q=cu_seqlens_q,\n        cu_seqlens_k=cu_seqlens_k,\n    )\n\n    if not varlen_q and varlen_k:\n        seqlen_q = q.shape[1]\n        out_ref_fp32 = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim)\n        out_pt = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim)\n\n    assert out_cute.shape == out_ref_fp32.shape, (\n        f\"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}\"\n    )\n\n    test_name = f\"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k})\"\n    extra_atol = 2e-3\n    check_results(\n        out_cute,\n        out_ref_fp32,\n        out_pt,\n        test_name,\n        extra_atol=extra_atol,\n        seqlens_q=seqlens_q if varlen_q else None,\n        cu_seqlens_q=cu_seqlens_q if varlen_q else None,\n    )\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"varlen_q\", [True, False])\n@pytest.mark.parametrize(\"varlen_k\", [True, 
False])\n@pytest.mark.parametrize(\"qhead_per_kvhead,num_kv_heads\", [(4, 2)])\n@pytest.mark.parametrize(\"seqlens_q,seqlens_k\", SEQLEN_CONFIGS)\n@pytest.mark.parametrize(\"score_mod_vec_tuple\", TEST_PAIRS_VECTORIZED_NO_GLOBAL)\ndef test_varlen_with_score_mod_vectorized(\n    seqlens_q,\n    seqlens_k,\n    varlen_q,\n    varlen_k,\n    qhead_per_kvhead,\n    num_kv_heads,\n    dtype,\n    score_mod_vec_tuple,\n):\n    \"\"\"Tests equality between original and vectorized versions of score mods\"\"\"\n    if not varlen_q and not varlen_k:\n        pytest.skip(\n            \"At least one of varlen_q or varlen_k must be True for varlen tests\"\n        )\n\n    # For non-varlen dimension, all sequences must have same length\n    if not varlen_q:\n        seqlens_q = [seqlens_q[0]] * len(seqlens_q)\n    if not varlen_k:\n        seqlens_k = [seqlens_k[0]] * len(seqlens_k)\n    torch.random.manual_seed(42)\n    cute_score_mod, cute_vectorized_score_mod, aux_type = score_mod_vec_tuple\n\n    num_heads = num_kv_heads * qhead_per_kvhead\n    pack_gqa = qhead_per_kvhead > 1\n    head_dim = 128\n    batch_size = len(seqlens_q)\n\n    q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors(\n        seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype\n    )\n    aux_tensors = None\n    if aux_type == \"batch\":\n        bias = torch.zeros(batch_size, device=\"cuda\", dtype=dtype) * 0.1\n        aux_tensors = [bias]\n    elif aux_type == \"dual_buffer\":\n        seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q)\n        head_bias = torch.randn(num_heads, device=\"cuda\", dtype=dtype) * 0.2\n        pos_bias = torch.arange(seqlen_q, device=\"cuda\", dtype=dtype) * 0.01\n        aux_tensors = [head_bias, pos_bias]\n\n    if pack_gqa:\n        if varlen_k:\n            k = k[:, :num_kv_heads, :].clone()\n            v = v[:, :num_kv_heads, :].clone()\n        else:\n            k = k[:, :, :num_kv_heads, :].clone()\n            v = v[:, :, 
:num_kv_heads, :].clone()\n\n    out_ref = run_cute_flash(\n        q,\n        k,\n        v,\n        cute_score_mod,\n        aux_tensors=aux_tensors,\n        pack_gqa=pack_gqa,\n        cu_seqlens_q=cu_seqlens_q,\n        cu_seqlens_k=cu_seqlens_k,\n    )\n\n    for vec_size in VEC_SIZES_TO_CHECK_EQUALITY:\n        cute_vectorized_score_mod.__vec_size__ = vec_size\n        out = run_cute_flash(\n            q,\n            k,\n            v,\n            cute_vectorized_score_mod,\n            aux_tensors=aux_tensors,\n            pack_gqa=pack_gqa,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n        )\n        assert torch.equal(out, out_ref)\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"varlen_q\", [True, False])\n@pytest.mark.parametrize(\"varlen_k\", [True, False])\n@pytest.mark.parametrize(\"qhead_per_kvhead,num_kv_heads\", [(1, 1), (4, 2)])\n@pytest.mark.parametrize(\"seqlens_q,seqlens_k\", SEQLEN_CONFIGS)\n@pytest.mark.parametrize(\"score_mod_tuple\", TEST_PAIRS_WITH_GLOBAL)\ndef test_varlen_with_global_idx_score_mod(\n    seqlens_q,\n    seqlens_k,\n    varlen_q,\n    varlen_k,\n    qhead_per_kvhead,\n    num_kv_heads,\n    dtype,\n    score_mod_tuple,\n):\n    \"\"\"Test varlen attention with score_mod functions that use global indices.\n\n    These score_mods compute q_idx_global and/or kv_idx_global from seqlen_info for packed tensor indexing.\n    Skips tests where required global indices aren't available.\n    \"\"\"\n    if not varlen_q and not varlen_k:\n        pytest.skip(\n            \"At least one of varlen_q or varlen_k must be True for varlen tests\"\n        )\n\n    cute_score_mod, eager_factory, aux_type, requires_global = score_mod_tuple\n\n    # Skip if score_mod requires global indices we can't provide\n    if requires_global == \"q\" and not varlen_q:\n        pytest.skip(f\"{cute_score_mod.__name__} requires varlen_q for q_idx_global\")\n    
if requires_global == \"kv\" and not varlen_k:\n        pytest.skip(f\"{cute_score_mod.__name__} requires varlen_k for kv_idx_global\")\n    if requires_global == \"both\" and (not varlen_q or not varlen_k):\n        pytest.skip(f\"{cute_score_mod.__name__} requires both varlen_q and varlen_k\")\n\n    # For non-varlen dimension, all sequences must have same length\n    if not varlen_q:\n        seqlens_q = [seqlens_q[0]] * len(seqlens_q)\n    if not varlen_k:\n        seqlens_k = [seqlens_k[0]] * len(seqlens_k)\n\n    torch.random.manual_seed(42)\n\n    num_heads = num_kv_heads * qhead_per_kvhead\n    pack_gqa = qhead_per_kvhead > 1\n    head_dim = 128\n    batch_size = len(seqlens_q)\n    max_rel_pos = 512\n\n    total_q = sum(seqlens_q)\n    total_k = sum(seqlens_k)\n\n    cu_seqlens_q = torch.tensor(\n        [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()),\n        device=\"cuda\",\n        dtype=torch.int32,\n    )\n    cu_seqlens_k = torch.tensor(\n        [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()),\n        device=\"cuda\",\n        dtype=torch.int32,\n    )\n\n    if varlen_q:\n        q = torch.randn(total_q, num_heads, head_dim, device=\"cuda\", dtype=dtype)\n    else:\n        seqlen_q = seqlens_q[0]\n        q = torch.randn(\n            batch_size, seqlen_q, num_heads, head_dim, device=\"cuda\", dtype=dtype\n        )\n\n    if varlen_k:\n        k = torch.randn(total_k, num_heads, head_dim, device=\"cuda\", dtype=dtype)\n        v = torch.randn(total_k, num_heads, head_dim, device=\"cuda\", dtype=dtype)\n    else:\n        seqlen_k = seqlens_k[0]\n        k = torch.randn(\n            batch_size, seqlen_k, num_heads, head_dim, device=\"cuda\", dtype=dtype\n        )\n        v = torch.randn(\n            batch_size, seqlen_k, num_heads, head_dim, device=\"cuda\", dtype=dtype\n        )\n\n    if pack_gqa:\n        if varlen_k:\n            k = k[:, :num_kv_heads, :].clone()\n            v = v[:, :num_kv_heads, :].clone()\n        
else:\n            k = k[:, :, :num_kv_heads, :].clone()\n            v = v[:, :, :num_kv_heads, :].clone()\n\n    # Setup aux tensors based on indexing type\n    if aux_type == \"kv\":\n        bias = torch.randn(total_k, device=\"cuda\", dtype=dtype) * 0.1\n        aux_tensors = [bias]\n        eager_score_mod = eager_factory(bias, cu_seqlens_k)\n    elif aux_type == \"q\":\n        bias = torch.randn(total_q, device=\"cuda\", dtype=dtype) * 0.1\n        aux_tensors = [bias]\n        eager_score_mod = eager_factory(bias, cu_seqlens_q)\n    elif aux_type == \"q_and_kv\":\n        q_bias = torch.randn(total_q, device=\"cuda\", dtype=dtype) * 0.1\n        kv_bias = torch.randn(total_k, device=\"cuda\", dtype=dtype) * 0.1\n        aux_tensors = [q_bias, kv_bias]\n        eager_score_mod = eager_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k)\n    elif aux_type == \"q_concat\":\n        bias = torch.randn(total_q, device=\"cuda\", dtype=dtype) * 0.1\n        aux_tensors = [bias]\n        eager_score_mod = eager_factory(bias, cu_seqlens_q)\n    elif aux_type == \"kv_with_cu\":\n        kv_bias = torch.randn(total_k, device=\"cuda\", dtype=dtype) * 0.1\n        aux_tensors = [kv_bias]\n        eager_score_mod = eager_factory(kv_bias, cu_seqlens_q, cu_seqlens_k)\n    elif aux_type == \"multi_buffer\":\n        batch_bias = torch.randn(batch_size, device=\"cuda\", dtype=dtype) * 0.1\n        head_scale = torch.randn(num_heads, device=\"cuda\", dtype=dtype) * 0.1 + 1.0\n        q_pos_bias = torch.randn(total_q, device=\"cuda\", dtype=dtype) * 0.1\n        kv_pos_bias = torch.randn(total_k, device=\"cuda\", dtype=dtype) * 0.1\n        rel_pos_scale = (\n            torch.randn(max_rel_pos * 2 + 1, device=\"cuda\", dtype=dtype) * 0.1\n        )\n        aux_tensors = [batch_bias, head_scale, q_pos_bias, kv_pos_bias, rel_pos_scale]\n        eager_score_mod = eager_factory(\n            batch_bias,\n            head_scale,\n            q_pos_bias,\n            
kv_pos_bias,\n            rel_pos_scale,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_rel_pos,\n        )\n    else:\n        raise ValueError(f\"Unknown aux_type: {aux_type}\")\n\n    # Prepare reference tensors for flex_attention\n    q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors(\n        q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q\n    )\n\n    out_ref_fp32 = run_flex_varlen_ref(\n        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32\n    )\n    out_pt = run_flex_varlen_ref(\n        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype\n    )\n\n    kernel_cu_seqlens_q = cu_seqlens_q if varlen_q else None\n    kernel_cu_seqlens_k = cu_seqlens_k if varlen_k else None\n    out_cute = run_cute_flash(\n        q,\n        k,\n        v,\n        cute_score_mod,\n        aux_tensors=aux_tensors,\n        pack_gqa=pack_gqa,\n        cu_seqlens_q=kernel_cu_seqlens_q,\n        cu_seqlens_k=kernel_cu_seqlens_k,\n    )\n\n    if varlen_q:\n        out_ref_final = out_ref_fp32\n        out_pt_final = out_pt\n        out_cute_final = out_cute\n    else:\n        seqlen_q = seqlens_q[0]\n        out_ref_final = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim)\n        out_pt_final = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim)\n        out_cute_final = out_cute\n\n    assert out_cute_final.shape == out_ref_final.shape, (\n        f\"Shape mismatch: {out_cute_final.shape} vs {out_ref_final.shape}\"\n    )\n\n    test_name = f\"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k}, {aux_type})\"\n\n    check_results(\n        out_cute_final,\n        out_ref_final,\n        out_pt_final,\n        test_name,\n        extra_atol=1e-3,\n        seqlens_q=seqlens_q if varlen_q else None,\n        cu_seqlens_q=cu_seqlens_q if varlen_q else None,\n    )\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, 
torch.bfloat16])\n@pytest.mark.parametrize(\"page_size\", [None, 128])\n@pytest.mark.parametrize(\"varlen_q\", [True, False])\n@pytest.mark.parametrize(\"varlen_k\", [True, False])\n@pytest.mark.parametrize(\"qhead_per_kvhead,num_kv_heads\", [(4, 2)])\n@pytest.mark.parametrize(\"seqlens_q,seqlens_k\", SEQLEN_CONFIGS)\n@pytest.mark.parametrize(\"score_mod_tuple\", TEST_PAIRS_NO_GLOBAL)\ndef test_varlen_score_mod_kvcache(\n    seqlens_q,\n    seqlens_k,\n    varlen_q,\n    varlen_k,\n    qhead_per_kvhead,\n    num_kv_heads,\n    page_size,\n    dtype,\n    score_mod_tuple,\n):\n    \"\"\"Test varlen attention with score_mod and paged KV cache.\"\"\"\n    if IS_SM90 and page_size is not None:\n        pytest.xfail(\"paged KV not supported on SM90\")\n\n    if not varlen_q and not varlen_k:\n        pytest.skip(\n            \"At least one of varlen_q or varlen_k must be True for varlen tests\"\n        )\n\n    if page_size is not None and varlen_k:\n        pytest.skip(\"Paged KV requires batched (non-varlen) K\")\n\n    if not varlen_q:\n        seqlens_q = [seqlens_q[0]] * len(seqlens_q)\n    if not varlen_k:\n        seqlens_k = [seqlens_k[0]] * len(seqlens_k)\n\n    # Skip if page_size doesn't divide seqlens evenly (for simplicity)\n    if page_size is not None and not varlen_k:\n        if seqlens_k[0] % page_size != 0:\n            pytest.skip(\"page_size must divide seqlen_k\")\n\n    torch.random.manual_seed(42)\n    cute_score_mod, eager_factory, aux_type = score_mod_tuple\n\n    num_heads = num_kv_heads * qhead_per_kvhead\n    pack_gqa = qhead_per_kvhead > 1\n    head_dim = 128\n    batch_size = len(seqlens_q)\n    device = \"cuda\"\n\n    # Setup tensors\n    q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors(\n        seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype\n    )\n\n    if pack_gqa:\n        if varlen_k:\n            k = k[:, :num_kv_heads, :].clone()\n            v = v[:, :num_kv_heads, :].clone()\n        else:\n           
 k = k[:, :, :num_kv_heads, :].clone()\n            v = v[:, :, :num_kv_heads, :].clone()\n\n    page_table = None\n    k_cache_paged = None\n    v_cache_paged = None\n    k_cache = k\n    v_cache = v\n\n    if page_size is not None:\n        seqlen_k = seqlens_k[0]\n        (\n            k_cache_bhsd,\n            v_cache_bhsd,\n            page_table,\n            k_cache_paged,\n            v_cache_paged,\n            num_blocks,\n        ) = _generate_block_kvcache(\n            seqlen_k, page_size, batch_size, num_kv_heads, head_dim, device, dtype\n        )\n        k_cache = k_cache_bhsd.transpose(1, 2)  # BHSD -> BSHD\n        v_cache = v_cache_bhsd.transpose(1, 2)\n        seqused_k = torch.tensor(seqlens_k, dtype=torch.int32, device=device)\n    else:\n        seqused_k = None\n\n    # Setup aux tensors and eager score_mod\n    aux_tensors = None\n    if aux_type == \"batch\":\n        bias = torch.zeros(batch_size, device=device, dtype=dtype) * 0.1\n        aux_tensors = [bias]\n        eager_score_mod = eager_factory(bias)\n    elif aux_type == \"dual_buffer\":\n        seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q)\n        head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.2\n        pos_bias = torch.arange(seqlen_q, device=device, dtype=dtype) * 0.01\n        aux_tensors = [head_bias, pos_bias]\n        eager_score_mod = eager_factory(head_bias, pos_bias)\n    else:\n        eager_score_mod = eager_factory\n\n    # Prepare reference tensors\n    q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors(\n        q,\n        k_cache,\n        v_cache,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        varlen_q,\n        varlen_k,\n        batch_size,\n        seqlens_q,\n    )\n\n    out_ref_fp32 = run_flex_varlen_ref(\n        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32\n    )\n    out_pt = run_flex_varlen_ref(\n        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, 
dtype=dtype\n    )\n\n    k_input = k_cache_paged if page_size is not None else k_cache\n    v_input = v_cache_paged if page_size is not None else v_cache\n\n    out_cute = run_cute_flash(\n        q,\n        k_input,\n        v_input,\n        cute_score_mod,\n        aux_tensors=aux_tensors,\n        pack_gqa=pack_gqa,\n        cu_seqlens_q=cu_seqlens_q if varlen_q else None,\n        cu_seqlens_k=cu_seqlens_k if (varlen_k and page_size is None) else None,\n        page_table=page_table if page_size is not None else None,\n        seqused_k=seqused_k if page_size is not None else None,\n    )\n\n    if not varlen_q and varlen_k:\n        seqlen_q = q.shape[1]\n        out_ref_fp32 = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim)\n        out_pt = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim)\n\n    assert out_cute.shape == out_ref_fp32.shape, (\n        f\"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}\"\n    )\n\n    test_name = f\"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k}, paged={page_size is not None})\"\n    extra_atol = 2e-3\n    check_results(\n        out_cute,\n        out_ref_fp32,\n        out_pt,\n        test_name,\n        extra_atol=extra_atol,\n        seqlens_q=seqlens_q if varlen_q else None,\n        cu_seqlens_q=cu_seqlens_q if varlen_q else None,\n    )\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"page_size\", [None, 128])\n@pytest.mark.parametrize(\"varlen_q\", [True, False])\n@pytest.mark.parametrize(\"varlen_k\", [True, False])\n@pytest.mark.parametrize(\"qhead_per_kvhead,num_kv_heads\", [(1, 1), (4, 2)])\n@pytest.mark.parametrize(\"seqlens_q,seqlens_k\", SEQLEN_CONFIGS)\n@pytest.mark.parametrize(\"score_mod_tuple\", TEST_PAIRS_WITH_GLOBAL)\ndef test_varlen_score_mod_with_paged_kvcache_global(\n    seqlens_q,\n    seqlens_k,\n    varlen_q,\n    varlen_k,\n    qhead_per_kvhead,\n    num_kv_heads,\n    page_size,\n    
dtype,\n    score_mod_tuple,\n):\n    \"\"\"Test varlen attention with global idx score_mod and paged KV cache.\"\"\"\n    if IS_SM90 and page_size is not None:\n        pytest.xfail(\"paged KV not supported on SM90\")\n\n    if page_size is not None and varlen_k:\n        pytest.skip(\"Paged KV cache requires batched (non-varlen) K\")\n\n    if not varlen_q and not varlen_k:\n        pytest.skip(\n            \"At least one of varlen_q or varlen_k must be True for varlen tests\"\n        )\n\n    if not varlen_q:\n        seqlens_q = [seqlens_q[0]] * len(seqlens_q)\n    if not varlen_k:\n        seqlens_k = [seqlens_k[0]] * len(seqlens_k)\n\n    if page_size is not None and not varlen_k:\n        if seqlens_k[0] % page_size != 0:\n            pytest.skip(\"page_size must divide seqlen_k\")\n\n    cute_score_mod, eager_factory, aux_type, requires_global = score_mod_tuple\n\n    if requires_global == \"q\" and not varlen_q:\n        pytest.skip(f\"{cute_score_mod.__name__} requires varlen_q for q_idx_global\")\n    if requires_global == \"kv\" and not varlen_k:\n        pytest.skip(f\"{cute_score_mod.__name__} requires varlen_k for kv_idx_global\")\n    if requires_global == \"both\" and (not varlen_q or not varlen_k):\n        pytest.skip(f\"{cute_score_mod.__name__} requires both varlen_q and varlen_k\")\n\n    torch.random.manual_seed(42)\n\n    num_heads = num_kv_heads * qhead_per_kvhead\n    pack_gqa = qhead_per_kvhead > 1\n    head_dim = 128\n    batch_size = len(seqlens_q)\n    max_rel_pos = 512\n    device = \"cuda\"\n\n    total_q = sum(seqlens_q)\n    total_k = sum(seqlens_k)\n\n    cu_seqlens_q = torch.tensor(\n        [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()),\n        device=device,\n        dtype=torch.int32,\n    )\n    cu_seqlens_k = torch.tensor(\n        [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()),\n        device=device,\n        dtype=torch.int32,\n    )\n    cu_seqlens_k_for_kernel = cu_seqlens_k if varlen_k else None\n\n 
   q = torch.randn(total_q, num_heads, head_dim, device=device, dtype=dtype)\n    if varlen_k:\n        k = torch.randn(total_k, num_heads, head_dim, device=device, dtype=dtype)\n        v = torch.randn(total_k, num_heads, head_dim, device=device, dtype=dtype)\n    else:\n        seqlen_k = seqlens_k[0]\n        k = torch.randn(\n            batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype\n        )\n        v = torch.randn(\n            batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype\n        )\n\n    if pack_gqa:\n        if varlen_k:\n            k = k[:, :num_kv_heads, :].clone()\n            v = v[:, :num_kv_heads, :].clone()\n        else:\n            k = k[:, :, :num_kv_heads, :].clone()\n            v = v[:, :, :num_kv_heads, :].clone()\n\n    page_table = None\n    k_cache_paged = None\n    v_cache_paged = None\n    k_cache = k\n    v_cache = v\n\n    if page_size is not None:\n        seqlen_k = seqlens_k[0]\n        (\n            k_cache_bhsd,\n            v_cache_bhsd,\n            page_table,\n            k_cache_paged,\n            v_cache_paged,\n            num_blocks,\n        ) = _generate_block_kvcache(\n            seqlen_k, page_size, batch_size, num_kv_heads, head_dim, device, dtype\n        )\n        k_cache = k_cache_bhsd.transpose(1, 2)  # BHSD -> BSHD\n        v_cache = v_cache_bhsd.transpose(1, 2)\n        seqused_k = torch.tensor(seqlens_k, dtype=torch.int32, device=device)\n    else:\n        seqused_k = None\n\n    if aux_type == \"kv\":\n        bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1\n        aux_tensors = [bias]\n        eager_score_mod = eager_factory(bias, cu_seqlens_k)\n    elif aux_type == \"q\":\n        bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1\n        aux_tensors = [bias]\n        eager_score_mod = eager_factory(bias, cu_seqlens_q)\n    elif aux_type == \"q_and_kv\":\n        q_bias = torch.randn(total_q, device=device, dtype=dtype) * 
0.1\n        kv_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1\n        aux_tensors = [q_bias, kv_bias]\n        eager_score_mod = eager_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k)\n    elif aux_type == \"q_concat\":\n        bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1\n        aux_tensors = [bias]\n        eager_score_mod = eager_factory(bias, cu_seqlens_q)\n    elif aux_type == \"kv_with_cu\":\n        kv_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1\n        aux_tensors = [kv_bias]\n        eager_score_mod = eager_factory(kv_bias, cu_seqlens_q, cu_seqlens_k)\n    elif aux_type == \"multi_buffer\":\n        batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.1\n        head_scale = torch.randn(num_heads, device=device, dtype=dtype) * 0.1 + 1.0\n        q_pos_bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1\n        kv_pos_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1\n        rel_pos_scale = (\n            torch.randn(max_rel_pos * 2 + 1, device=device, dtype=dtype) * 0.1\n        )\n        aux_tensors = [batch_bias, head_scale, q_pos_bias, kv_pos_bias, rel_pos_scale]\n        eager_score_mod = eager_factory(\n            batch_bias,\n            head_scale,\n            q_pos_bias,\n            kv_pos_bias,\n            rel_pos_scale,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_rel_pos,\n        )\n    else:\n        raise ValueError(f\"Unknown aux_type: {aux_type}\")\n\n    q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors(\n        q,\n        k_cache,\n        v_cache,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        True,\n        varlen_k,\n        batch_size,\n        seqlens_q,\n    )\n\n    out_ref_fp32 = run_flex_varlen_ref(\n        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32\n    )\n    out_pt = run_flex_varlen_ref(\n        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, 
eager_score_mod, dtype=dtype\n    )\n\n    # Run CuTE\n    k_input = k_cache_paged if page_size is not None else k_cache\n    v_input = v_cache_paged if page_size is not None else v_cache\n\n    out_cute = torch.empty_like(q)\n    _flash_attn_fwd(\n        q,\n        k_input,\n        v_input,\n        cu_seqlens_q=cu_seqlens_q,\n        cu_seqlens_k=cu_seqlens_k_for_kernel if page_size is None else None,\n        seqused_k=seqused_k if page_size is not None else None,\n        page_table=page_table,\n        return_lse=True,\n        score_mod=cute_score_mod,\n        out=out_cute,\n        lse=None,\n        aux_tensors=aux_tensors,\n        pack_gqa=pack_gqa,\n    )\n\n    assert out_cute.shape == out_ref_fp32.shape, (\n        f\"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}\"\n    )\n\n    test_name = f\"{cute_score_mod.__name__} (paged={page_size is not None}, {aux_type})\"\n    check_results(\n        out_cute,\n        out_ref_fp32,\n        out_pt,\n        test_name,\n        extra_atol=1e-3,\n        seqlens_q=seqlens_q,\n        cu_seqlens_q=cu_seqlens_q,\n    )\n\n\nif __name__ == \"__main__\":\n    pytest.main([__file__, \"-v\"])\n"
  },
  {
    "path": "tests/cute/test_utils.py",
    "content": "\"\"\"Unit tests for flash_attn.cute.utils module.\"\"\"\n\nimport functools\n\nfrom flash_attn.cute import utils as cute_utils\nfrom flash_attn.cute.utils import hash_callable\n\n\nclass TestHashCallable:\n    \"\"\"Tests for hash_callable function.\"\"\"\n\n    def test_returns_cute_hash_when_set_on_function(self):\n        \"\"\"hash_callable should return __cute_hash__ immediately when set on function.\"\"\"\n\n        def my_func():\n            pass\n\n        my_func.__cute_hash__ = \"precomputed-hash-123\"\n\n        result = hash_callable(my_func)\n        assert result == \"precomputed-hash-123\"\n\n    def test_returns_cute_hash_from_wrapped_function(self):\n        \"\"\"hash_callable should check __wrapped__ for __cute_hash__.\"\"\"\n\n        def inner_func():\n            pass\n\n        inner_func.__cute_hash__ = \"inner-hash-456\"\n\n        # Simulate a decorator that sets __wrapped__\n        @functools.wraps(inner_func)\n        def wrapper_func():\n            return inner_func()\n\n        result = hash_callable(wrapper_func)\n        assert result == \"inner-hash-456\"\n\n    def test_prefers_wrapper_cute_hash_over_wrapped(self):\n        \"\"\"When both wrapper and wrapped have __cute_hash__, prefer wrapper.\"\"\"\n\n        def inner_func():\n            pass\n\n        inner_func.__cute_hash__ = \"inner-hash\"\n\n        @functools.wraps(inner_func)\n        def wrapper_func():\n            return inner_func()\n\n        wrapper_func.__cute_hash__ = \"wrapper-hash\"\n\n        result = hash_callable(wrapper_func)\n        assert result == \"wrapper-hash\"\n\n    def test_fallback_to_source_hashing(self):\n        \"\"\"hash_callable should fall back to source hashing when no __cute_hash__.\"\"\"\n\n        def my_func():\n            return 42\n\n        result = hash_callable(my_func)\n        # Should return a hex string (SHA256 hash)\n        assert isinstance(result, str)\n        assert len(result) == 64  # SHA256 
produces 64 hex chars\n\n    def test_same_function_produces_same_hash(self):\n        \"\"\"Same function should produce consistent hash.\"\"\"\n\n        def my_func():\n            return 42\n\n        hash1 = hash_callable(my_func)\n        hash2 = hash_callable(my_func)\n        assert hash1 == hash2\n\n    def test_different_functions_produce_different_hashes(self):\n        \"\"\"Different functions should produce different hashes.\"\"\"\n\n        def func_a():\n            return 1\n\n        def func_b():\n            return 2\n\n        hash_a = hash_callable(func_a)\n        hash_b = hash_callable(func_b)\n        assert hash_a != hash_b\n\n    def test_fast_path_skips_expensive_hashing(self):\n        \"\"\"When __cute_hash__ is set, expensive operations should be skipped.\"\"\"\n\n        def my_func():\n            pass\n\n        my_func.__cute_hash__ = \"fast-hash\"\n\n        # Mock at module level since we loaded it directly\n        original_getsource = cute_utils.inspect.getsource\n        call_tracker = {\"getsource\": 0, \"sha256\": 0}\n\n        def tracking_getsource(*args, **kwargs):\n            call_tracker[\"getsource\"] += 1\n            return original_getsource(*args, **kwargs)\n\n        original_sha256 = cute_utils.hashlib.sha256\n\n        def tracking_sha256(*args, **kwargs):\n            call_tracker[\"sha256\"] += 1\n            return original_sha256(*args, **kwargs)\n\n        cute_utils.inspect.getsource = tracking_getsource\n        cute_utils.hashlib.sha256 = tracking_sha256\n        try:\n            result = hash_callable(my_func)\n        finally:\n            cute_utils.inspect.getsource = original_getsource\n            cute_utils.hashlib.sha256 = original_sha256\n\n        # Neither inspect.getsource nor hashlib.sha256 should be called\n        assert call_tracker[\"getsource\"] == 0, \"getsource should not be called\"\n        assert call_tracker[\"sha256\"] == 0, \"sha256 should not be called\"\n        assert 
result == \"fast-hash\"\n\n    def test_fast_path_on_wrapped_skips_expensive_hashing(self):\n        \"\"\"When __cute_hash__ is on __wrapped__, expensive operations should be skipped.\"\"\"\n\n        def inner_func():\n            pass\n\n        inner_func.__cute_hash__ = \"wrapped-fast-hash\"\n\n        @functools.wraps(inner_func)\n        def wrapper_func():\n            return inner_func()\n\n        # Mock at module level\n        original_getsource = cute_utils.inspect.getsource\n        call_tracker = {\"getsource\": 0, \"sha256\": 0}\n\n        def tracking_getsource(*args, **kwargs):\n            call_tracker[\"getsource\"] += 1\n            return original_getsource(*args, **kwargs)\n\n        original_sha256 = cute_utils.hashlib.sha256\n\n        def tracking_sha256(*args, **kwargs):\n            call_tracker[\"sha256\"] += 1\n            return original_sha256(*args, **kwargs)\n\n        cute_utils.inspect.getsource = tracking_getsource\n        cute_utils.hashlib.sha256 = tracking_sha256\n        try:\n            result = hash_callable(wrapper_func)\n        finally:\n            cute_utils.inspect.getsource = original_getsource\n            cute_utils.hashlib.sha256 = original_sha256\n\n        assert call_tracker[\"getsource\"] == 0, \"getsource should not be called\"\n        assert call_tracker[\"sha256\"] == 0, \"sha256 should not be called\"\n        assert result == \"wrapped-fast-hash\"\n\n    def test_closure_values_affect_hash(self):\n        \"\"\"Functions with different closure values should have different hashes.\"\"\"\n        value1 = 10\n        value2 = 20\n\n        def make_func(val):\n            def inner():\n                return val\n\n            return inner\n\n        func1 = make_func(value1)\n        func2 = make_func(value2)\n\n        hash1 = hash_callable(func1)\n        hash2 = hash_callable(func2)\n        assert hash1 != hash2\n\n\nclass TestHashCallableIntegration:\n    \"\"\"Integration tests for hash_callable 
with flash attention.\"\"\"\n\n    def test_repeated_calls_use_cached_hash(self):\n        \"\"\"Repeated calls with same score_mod should use cached/fast hash path.\"\"\"\n\n        def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):\n            return tSrS_ssa\n\n        # Set __cute_hash__ to simulate Inductor-generated code\n        score_mod.__cute_hash__ = \"inductor-generated-hash\"\n\n        original_getsource = cute_utils.inspect.getsource\n        call_count = [0]  # Use list for mutable counter in nested function\n\n        def counting_getsource(*args, **kwargs):\n            call_count[0] += 1\n            return original_getsource(*args, **kwargs)\n\n        cute_utils.inspect.getsource = counting_getsource\n        try:\n            # Call hash_callable multiple times\n            hash1 = hash_callable(score_mod)\n            hash2 = hash_callable(score_mod)\n            hash3 = hash_callable(score_mod)\n        finally:\n            cute_utils.inspect.getsource = original_getsource\n\n        # getsource should never be called because __cute_hash__ is set\n        assert call_count[0] == 0, f\"getsource was called {call_count[0]} times\"\n        assert hash1 == hash2 == hash3 == \"inductor-generated-hash\"\n\n"
  },
  {
    "path": "tests/layers/test_rotary.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport math\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_func, apply_rotary_emb_qkv_\nfrom transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX\nfrom transformers.models.gpt_neox.modeling_gpt_neox import (\n    apply_rotary_pos_emb as apply_rotary_pos_emb_neox,\n)\nfrom transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj\nfrom transformers.models.gptj.modeling_gptj import fixed_pos_embedding\n\n\n# NeoX-style rotary embedding\n@pytest.mark.parametrize(\"seqlen_offset\", [0, 711])\n@pytest.mark.parametrize(\"rotary_emb_fraction\", [0.5, 1.0])\ndef test_rotary(rotary_emb_fraction, seqlen_offset):\n    device = \"cuda\"\n    dtype = torch.float16\n    rtol, atol = (1e-3, 5e-3)\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen_total = 2048\n    seqlen = seqlen_total - seqlen_offset\n    nheads = 16\n    headdim = 128\n    rotary_dim = int(headdim * rotary_emb_fraction)\n    qkv = torch.randn(\n        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True\n    )\n    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace\n    rotary = RotaryEmbedding(rotary_dim, device=device)\n    rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device)\n    # Doesn't matter what tensor we pass in, rotary_neox only uses the device of the tensor\n    cos_neox, sin_neox = rotary_neox(qkv, seq_len=seqlen_total)\n    cos_neox, sin_neox = cos_neox.to(dtype=dtype), sin_neox.to(dtype=dtype)\n    q_pt = (\n        rearrange(qkv[:, :, 0, :, :rotary_dim], \"b s h d -> b h s d\")\n        .detach()\n        .clone()\n        .requires_grad_(True)\n    )\n    k_pt = (\n        rearrange(qkv[:, :, 1, :, :rotary_dim], \"b s h d -> b h s d\")\n        
.detach()\n        .clone()\n        .requires_grad_(True)\n    )\n    q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, cos_neox, sin_neox, offset=seqlen_offset)\n    out = rotary(qkv, seqlen_offset=seqlen_offset)\n    assert torch.allclose(\n        rotary._cos_cached, cos_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol\n    )\n    assert torch.allclose(\n        rotary._sin_cached, sin_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol\n    )\n    assert torch.allclose(\n        rearrange(q_neox, \"b h s d -> b s h d\"), out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol\n    )\n    assert torch.allclose(\n        rearrange(k_neox, \"b h s d -> b s h d\"), out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol\n    )\n    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])\n    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])\n\n    g = torch.randn_like(out)\n    g_og = g.clone().detach()  # Our implementation modifies g inplace\n    out.backward(g)\n    q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], \"b s h d -> b h s d\"))\n    k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], \"b s h d -> b h s d\"))\n    assert torch.allclose(\n        rearrange(q_pt.grad, \"b h s d -> b s h d\"),\n        qkv.grad[:, :, 0, :, :rotary_dim],\n        rtol=rtol,\n        atol=atol,\n    )\n    assert torch.allclose(\n        rearrange(k_pt.grad, \"b h s d -> b s h d\"),\n        qkv.grad[:, :, 1, :, :rotary_dim],\n        rtol=rtol,\n        atol=atol,\n    )\n    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])\n    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])\n\n\n# GPT-J-style rotary embedding\n@pytest.mark.parametrize(\"seqlen_offset\", [0, 711])\n@pytest.mark.parametrize(\"rotary_emb_fraction\", [0.5, 1.0])\ndef test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):\n    device = \"cuda\"\n    dtype = torch.float16\n    rtol, atol = 
(1e-3, 5e-3)\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen_total = 2048\n    seqlen = seqlen_total - seqlen_offset\n    nheads = 16\n    headdim = 128\n    rotary_dim = int(headdim * rotary_emb_fraction)\n    qkv = torch.randn(\n        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True\n    )\n    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace\n    rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device)\n    sincos_gptj = fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total)\n    sincos_gptj = tuple(x.to(dtype=dtype) for x in sincos_gptj)\n    q_pt = qkv[:, :, 0, :, :rotary_dim].detach().clone().requires_grad_(True)\n    k_pt = qkv[:, :, 1, :, :rotary_dim].detach().clone().requires_grad_(True)\n    q_gptj = apply_rotary_pos_emb_gptj(q_pt, sincos_gptj, offset=seqlen_offset)\n    k_gptj = apply_rotary_pos_emb_gptj(k_pt, sincos_gptj, offset=seqlen_offset)\n    out = rotary(qkv, seqlen_offset=seqlen_offset)\n    assert torch.allclose(rotary._cos_cached, sincos_gptj[1], rtol=rtol, atol=atol)\n    assert torch.allclose(rotary._sin_cached, sincos_gptj[0], rtol=rtol, atol=atol)\n    assert torch.allclose(q_gptj, out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)\n    assert torch.allclose(k_gptj, out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)\n    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])\n    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])\n\n    g = torch.randn_like(out)\n    g_og = g.clone().detach()  # Our implementation modifies g inplace\n    out.backward(g)\n    q_gptj.backward(g_og[:, :, 0, :, :rotary_dim])\n    k_gptj.backward(g_og[:, :, 1, :, :rotary_dim])\n    assert torch.allclose(q_pt.grad, qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)\n    assert torch.allclose(k_pt.grad, qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)\n    assert torch.equal(qkv.grad[:, :, 
0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])\n    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])\n"
  },
  {
    "path": "tests/losses/test_cross_entropy.py",
    "content": "# Copyright (c) 2024, Tri Dao.\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom flash_attn.losses.cross_entropy import CrossEntropyLoss\n\nis_sm8x = torch.cuda.get_device_capability(\"cuda\")[0] >= 8\n\n\n@pytest.mark.parametrize(\n    \"dtype\", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])\n)\n# @pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"precompute_lse\", [False, True])\n# @pytest.mark.parametrize(\"precompute_lse\", [False])\n@pytest.mark.parametrize(\"inplace_backward\", [False, True])\n# @pytest.mark.parametrize(\"inplace_backward\", [False])\n@pytest.mark.parametrize(\"lse_square_scale\", [0.0, 1e-2])\n@pytest.mark.parametrize(\"return_z_loss\", [False, True])\n# @pytest.mark.parametrize(\"lse_square_scale\", [1e-2])\n@pytest.mark.parametrize(\"logit_scale\", [1.0, 0.7])\n# @pytest.mark.parametrize(\"logit_scale\", [1.0])\n@pytest.mark.parametrize(\"smoothing\", [0.0, 0.9])\n# @pytest.mark.parametrize(\"smoothing\", [0.0])\n@pytest.mark.parametrize(\"vocab_size\", [50257, 128256])  # test vocab larger than 64k for split\n# @pytest.mark.parametrize(\"vocab_size\", [12])\ndef test_cross_entropy_loss(\n    vocab_size,\n    smoothing,\n    logit_scale,\n    lse_square_scale,\n    return_z_loss,\n    inplace_backward,\n    precompute_lse,\n    dtype,\n):\n    if precompute_lse and (logit_scale != 1.0 or smoothing != 0.0):\n        pytest.skip(\"precompute_lse only works with logit_scale=1.0 and smoothing=0.0\")\n    device = \"cuda\"\n    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 1 if dtype == torch.float32 else 4  # Otherwise OOM\n    seqlen = 4096 if lse_square_scale == 0.0 and logit_scale == 1.0 else 1024  # Otherwise OOM\n    x_pt = torch.randn(\n        batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True\n    )\n    x = 
x_pt.detach().clone().requires_grad_()\n    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)\n    if batch_size * seqlen > 10:\n        y[torch.randperm(batch_size * seqlen)[:10]] = -100\n    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)\n    model = CrossEntropyLoss(\n        label_smoothing=smoothing,\n        logit_scale=logit_scale,\n        lse_square_scale=lse_square_scale,\n        return_z_loss=return_z_loss,\n        inplace_backward=inplace_backward,\n    )\n    if precompute_lse:\n        with torch.no_grad():\n            lse = torch.logsumexp(x.float(), dim=-1)\n    else:\n        lse = None\n    if return_z_loss:\n        out, out_z_loss = model(x, y, precomputed_lse=lse)\n    else:\n        out = model(x, y, precomputed_lse=lse)\n    x_pt_scaled = (x_pt.float() * logit_scale) if logit_scale != 1.0 else x_pt.float()\n    out_pt = model_pt(x_pt_scaled, y)\n    if lse_square_scale > 0.0:\n        lse_pt = torch.logsumexp(x_pt_scaled, dim=-1)\n        z_loss_pt = lse_square_scale * (lse_pt[y != -100] ** 2).mean()\n        if return_z_loss:\n            assert torch.allclose(out_z_loss, z_loss_pt, rtol=rtol, atol=atol)\n        out_pt += z_loss_pt\n    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)\n\n    g = torch.randn_like(out)\n    out_pt.backward(g)\n    out.backward(g)\n    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)\n"
  },
  {
    "path": "tests/losses/test_cross_entropy_parallel.py",
    "content": "# Run test with:\n# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/losses/test_cross_entropy_parallel.py\n\nimport math\n\nimport pytest\nimport torch\nfrom apex.transformer import parallel_state, tensor_parallel\nfrom flash_attn.losses.cross_entropy import CrossEntropyLoss\n\nis_sm8x = torch.cuda.get_device_capability(\"cuda\")[0] >= 8\n\n\n@pytest.mark.parametrize(\n    \"dtype\", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])\n)\n# @pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"precompute_lse\", [False, True])\n# @pytest.mark.parametrize(\"precompute_lse\", [False])\n@pytest.mark.parametrize(\"inplace_backward\", [False, True])\n# @pytest.mark.parametrize(\"inplace_backward\", [False])\n# @pytest.mark.parametrize(\"lse_square_scale\", [0.0, 1e-2])\n@pytest.mark.parametrize(\"lse_square_scale\", [1e-2])\n@pytest.mark.parametrize(\"logit_scale\", [1.0, 0.7])\n# @pytest.mark.parametrize(\"logit_scale\", [1.0])\n@pytest.mark.parametrize(\"smoothing\", [0.0, 0.9])\n# @pytest.mark.parametrize(\"smoothing\", [0.0])\n@pytest.mark.parametrize(\"vocab_size\", [50264, 256 * 1024])  # test vocab larger than 64k for split\n# @pytest.mark.parametrize(\"vocab_size\", [50264])  # test vocab larger than 64k for split\n# @pytest.mark.parametrize(\"world_size\", [1, 2])\n@pytest.mark.parametrize(\"world_size\", [2])\ndef test_cross_entropy_loss_parallel(\n    vocab_size,\n    world_size,\n    smoothing,\n    logit_scale,\n    lse_square_scale,\n    inplace_backward,\n    precompute_lse,\n    dtype,\n):\n    if precompute_lse and (logit_scale != 1.0 or smoothing != 0.0):\n        pytest.skip(\"precompute_lse only works with logit_scale=1.0 and smoothing=0.0\")\n    assert vocab_size % world_size == 0\n    rtol, atol = (\n        (1e-5, 2e-5)\n        if dtype == torch.float32\n        else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))\n    )\n    if not 
torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    partition_vocab_size = vocab_size // world_size\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 128\n    x_pt = (\n        torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10\n    ).requires_grad_()\n    x = (\n        tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt)\n        .detach()\n        .clone()\n        .requires_grad_()\n    )\n    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)\n    y[torch.randperm(batch_size * seqlen)[:10]] = -100\n    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction=\"none\")\n    model = CrossEntropyLoss(\n        label_smoothing=smoothing,\n        logit_scale=logit_scale,\n        reduction=\"none\",\n        lse_square_scale=lse_square_scale,\n        inplace_backward=inplace_backward,\n        process_group=parallel_state.get_tensor_model_parallel_group(),\n    )\n    if precompute_lse:\n        with torch.no_grad():\n            lse = torch.logsumexp(x.float(), dim=-1)\n    else:\n        lse = None\n    out = model(x, y, precomputed_lse=lse)\n    out_pt = model_pt(x_pt.float() * logit_scale, y)\n    if lse_square_scale > 0.0:\n        lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)\n        out_pt += lse_square_scale * lse_pt.square()\n        out_pt.masked_fill_(y == -100, 0.0)\n    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)\n\n    g = torch.randn_like(out)\n    out_pt.backward(g)\n    out.backward(g)\n    assert torch.allclose(\n        x.grad,\n        x_pt.grad[:, (rank * 
partition_vocab_size) : (rank + 1) * partition_vocab_size],\n        rtol=rtol,\n        atol=atol,\n    )\n\n    parallel_state.destroy_model_parallel()\n"
  },
  {
    "path": "tests/models/test_baichuan.py",
    "content": "# Copyright (c) 2023, Tri Dao.\nimport os\nimport time\nfrom pathlib import Path\n\nimport torch\nimport pytest\n\nfrom einops import rearrange\n\nfrom transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM\n\nfrom flash_attn.models.gpt import (\n    GPTLMHeadModel,\n    combine_state_dicts_tp,\n    shard_state_dict_tp,\n)\nfrom flash_attn.models.baichuan import (\n    remap_state_dict_hf_baichuan,\n    baichuan_config_to_gpt2_config,\n)\nfrom flash_attn.utils.distributed import all_gather_raw\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\nfrom flash_attn.utils.generation import update_graph_cache\n\n\n@pytest.mark.parametrize(\n    \"model_name\",\n    [\n        \"baichuan-inc/Baichuan-7B\",\n        \"baichuan-inc/Baichuan-13B-Base\",\n        \"baichuan-inc/Baichuan2-7B-Base\",\n        \"baichuan-inc/Baichuan2-13B-Base\",\n    ],\n)\ndef test_baichuan_state_dict(model_name):\n    config = baichuan_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    pretrained_state_dict = remap_state_dict_hf_baichuan(\n        state_dict_from_pretrained(model_name), config\n    )\n    model = GPTLMHeadModel(config, device=\"meta\")  # Without device='meta' init is very slow\n    state_dict = model.state_dict()\n    assert len(state_dict.keys()) == len(pretrained_state_dict.keys())\n    assert state_dict.keys() == pretrained_state_dict.keys()\n    for k in state_dict.keys():\n        assert state_dict[k].shape == pretrained_state_dict[k].shape\n\n\n@pytest.mark.parametrize(\n    \"model_name\",\n    [\n        \"baichuan-inc/Baichuan-7B\",\n        \"baichuan-inc/Baichuan-13B-Base\",\n        \"baichuan-inc/Baichuan2-7B-Base\",\n        \"baichuan-inc/Baichuan2-13B-Base\",\n    ],\n)\ndef test_baichuan_optimized(model_name):\n    \"\"\"Check that our implementation of Baichuan (with all optimizations enabled) matches the\n    HF implementation: the output of our forward 
pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    config = baichuan_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't have fused GatedMLP yet\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    pretrained_state_dict = remap_state_dict_hf_baichuan(\n        state_dict_from_pretrained(model_name), config\n    )\n    model = GPTLMHeadModel(config, device=device, dtype=dtype)\n    model.load_state_dict(pretrained_state_dict)\n    model.eval()\n\n    torch.manual_seed(0)\n    batch_size = 2\n    max_seqlen = 256\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n    )\n    with torch.no_grad():\n        out = model.transformer(input_ids)\n        logits = model(input_ids).logits\n    del model\n\n    # Without device_map, the model is loaded on the CPU, which is very slow\n    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n    model_ref = AutoModelForCausalLM.from_pretrained(\n        model_name, device_map=\"auto\", trust_remote_code=True\n    )\n    model_ref.eval()\n    with torch.no_grad():\n        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)\n        logits_ref = model_ref(input_ids).logits.to(device=device)\n    del model_ref\n\n    model_hf = AutoModelForCausalLM.from_pretrained(\n        model_name,\n        torch_dtype=dtype,\n        device_map={\"\": device},\n        trust_remote_code=True,\n    )\n    model_hf.eval()\n    with torch.no_grad():\n        out_hf = model_hf.model(input_ids).last_hidden_state\n       
 logits_hf = model_hf(input_ids).logits\n    del model_hf\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n    assert (logits - logits_ref).abs().max().item() < 3 * (\n        logits_hf - logits_ref\n    ).abs().max().item()\n\n\n# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k \"test_baichuan_parallel_forward\"\n@pytest.mark.parametrize(\"world_size\", [2])\n@pytest.mark.parametrize(\n    \"model_name\",\n    [\n        \"baichuan-inc/Baichuan-7B\",\n        \"baichuan-inc/Baichuan-13B-Base\",\n        \"baichuan-inc/Baichuan2-7B-Base\",\n        \"baichuan-inc/Baichuan2-13B-Base\",\n    ],\n)\ndef test_baichuan_parallel_forward(model_name, world_size):\n    \"\"\"Check that our implementation of Baichuan (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    from apex.transformer import parallel_state\n\n    dtype = torch.float16\n    config = baichuan_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't have fused GatedMLP yet\n    
config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    process_group = parallel_state.get_tensor_model_parallel_group()\n\n    pretrained_state_dict = remap_state_dict_hf_baichuan(\n        state_dict_from_pretrained(model_name), config\n    )\n\n    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)\n    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))\n    model.eval()\n\n    torch.manual_seed(0)\n    batch_size = 2\n    max_seqlen = 256\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n    )\n    with torch.no_grad():\n        out = model.transformer(input_ids)\n        out, _ = all_gather_raw(out, process_group=process_group)\n        out = rearrange(out, \"(b s) d -> b s d\", b=batch_size)\n        logits = model(input_ids).logits\n        logits = rearrange(logits, \"(b s) d -> b s d\", b=batch_size)\n        logits, _ = all_gather_raw(logits, process_group)\n        logits = rearrange(logits, \"(n b) ... d -> b ... 
(n d)\", b=batch_size)\n    del model\n    parallel_state.destroy_model_parallel()\n\n    if rank == 0:\n        # Without device_map, the model is loaded on the CPU, which is very slow\n        model_ref = AutoModelForCausalLM.from_pretrained(\n            model_name, device_map=\"auto\", trust_remote_code=True\n        )\n        model_ref.eval()\n        with torch.no_grad():\n            out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)\n            logits_ref = model_ref(input_ids).logits.to(device=device)\n        del model_ref\n\n        model_hf = AutoModelForCausalLM.from_pretrained(\n            model_name, torch_dtype=dtype, device_map=\"auto\", trust_remote_code=True\n        )\n        model_hf.eval()\n        with torch.no_grad():\n            out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)\n            logits_hf = model_hf(input_ids).logits.to(device=device)\n        del model_hf\n\n        print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n        print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n        print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n        print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()\n\n        print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n        print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n        print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n        print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n        assert (logits - logits_ref).abs().max().item() < 2 * (\n            logits_hf - logits_ref\n        ).abs().max().item()\n\n\n@pytest.mark.parametrize(\n    \"model_name\", [\"baichuan-inc/Baichuan-7B\", \"baichuan-inc/Baichuan-13B-Base\"]\n)\ndef test_baichuan_generation(model_name):\n  
  dtype = torch.float16\n    device = \"cuda\"\n    config = baichuan_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't have fused GatedMLP yet\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n    eos_token_id = tokenizer.eos_token_id\n\n    torch.manual_seed(0)\n    batch_size = 1\n    seqlen = 2048\n    max_length = 2048 + 150\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device\n    )\n\n    model_hf = AutoModelForCausalLM.from_pretrained(\n        model_name, torch_dtype=dtype, device_map={\"\": device}, trust_remote_code=True\n    )\n    model_hf.eval()\n    print(\"HF fp16\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_hf = model_hf.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        return_dict_in_generate=True,\n        output_scores=True,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n    del model_hf\n\n    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n    model_ref = AutoModelForCausalLM.from_pretrained(\n        model_name, device_map=\"auto\", trust_remote_code=True\n    )\n    model_ref.eval()\n    with torch.no_grad():\n        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)\n    del model_ref\n\n    pretrained_state_dict = remap_state_dict_hf_baichuan(\n        state_dict_from_pretrained(model_name), config\n    )\n    model = GPTLMHeadModel(config, device=device, dtype=dtype)\n    model.load_state_dict(pretrained_state_dict)\n    model.eval()\n\n    model(input_ids)  # Warm up\n    print(\"Without CUDA graph\")\n    
torch.cuda.synchronize()\n    start = time.time()\n    out = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        eos_token_id=eos_token_id,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n        teacher_outputs=out_hf.sequences,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n\n    # Capture graph outside the timing loop\n    batch_size, seqlen_og = input_ids.shape\n    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)\n    print(\"With CUDA graph\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_cg = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        cg=True,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n        teacher_outputs=out_hf.sequences,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n\n    with torch.no_grad():\n        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]\n    logits_hf = torch.stack(out_hf.scores, dim=1)\n    logits = torch.stack(out.scores, dim=1)\n    logits_cg = torch.stack(out_cg.scores, dim=1)\n\n    del model\n\n    hf_error = (logits_hf - logits_ref).abs().max().item()\n\n    print(f\"HF fp16 logits max diff: {hf_error}\")\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item() }\")\n    print(f\"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }\")\n\n    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error\n    assert (logits - logits_ref).abs().max().item() < 2 * hf_error\n    assert torch.equal(logits_cg, logits)\n\n\n# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k 
\"baichuan_parallel_generation\"\n@pytest.mark.parametrize(\"world_size\", [2])\n@pytest.mark.parametrize(\"model_name\", [\"baichuan-inc/Baichuan-7B\"])\ndef test_baichuan_parallel_generation(model_name, world_size):\n    \"\"\"Check that our implementation matches the HF implementation:\n    the scores in fp16 should be around the same as the HF scores in fp16, when compared to\n    the HF scores in fp32.\n    \"\"\"\n    from apex.transformer import parallel_state\n\n    dtype = torch.float16\n    config = baichuan_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't have fused GatedMLP yet\n    config.fused_dropout_add_ln = False\n    config.residual_in_fp32 = True\n    config.pad_vocab_size_multiple = 8 * world_size\n    config.sequence_parallel = False  # Need to set this to False for generation\n\n    os.environ[\"NCCL_ASYNC_ERROR_HANDLING\"] = \"0\"\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    process_group = parallel_state.get_tensor_model_parallel_group()\n\n    torch.manual_seed(0)\n    batch_size = 1\n    seqlen = 100\n    max_length = 150\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device\n    )\n\n    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both\n    # GPU0 and GPU1 and things would hang\n    torch.cuda.set_device(device)\n\n    pretrained_state_dict = remap_state_dict_hf_baichuan(\n        state_dict_from_pretrained(model_name), 
config\n    )\n\n    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)\n    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))\n    model.eval()\n\n    print(\"Without CUDA graph\")\n    out = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        tensor_parallel=world_size,\n        vocab_size=config.vocab_size,\n        # teacher_outputs=out_hf.sequences,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n    )\n\n    # Capture graph outside the timing loop\n    batch_size, seqlen_og = input_ids.shape\n    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)\n    print(\"With CUDA graph\")\n    out_cg = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        tensor_parallel=world_size,\n        vocab_size=config.vocab_size,\n        cg=True,\n        # teacher_outputs=out_hf.sequences,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n    )\n    del model\n    parallel_state.destroy_model_parallel()\n\n    if rank == 0:\n        # Without device_map, the model is loaded on the CPU, which is very slow\n        model_hf = AutoModelForCausalLM.from_pretrained(\n            model_name, torch_dtype=dtype, device_map=\"auto\", trust_remote_code=True\n        )\n        model_hf.eval()\n        print(\"HF fp16\")\n        torch.cuda.synchronize()\n        start = time.time()\n        with torch.inference_mode():\n            out_hf = model_hf.generate(\n                input_ids=input_ids,\n                max_length=max_length,\n                return_dict_in_generate=True,\n                output_scores=True,\n            )\n        torch.cuda.synchronize()\n        print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n        del model_hf\n\n        model_ref = 
AutoModelForCausalLM.from_pretrained(\n            model_name, device_map=\"auto\", trust_remote_code=True\n        )\n        model_ref.eval()\n        with torch.inference_mode():\n            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]\n        del model_ref\n        logits_hf = torch.stack(out_hf.scores, dim=1)\n\n        logits = torch.stack(out.scores, dim=1)\n        logits_cg = torch.stack(out_cg.scores, dim=1)\n\n        hf_error = (logits_hf - logits_ref).abs().max().item()\n        print(f\"HF fp16 logits max diff: {hf_error}\")\n        print(f\"Logits max diff: {(logits - logits_ref).abs().max().item() }\")\n        print(f\"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }\")\n        assert (logits - logits_ref).abs().max().item() < 2 * hf_error\n        assert torch.equal(logits_cg, logits)\n"
  },
  {
    "path": "tests/models/test_bert.py",
    "content": "import re\nfrom collections import OrderedDict\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom transformers import BertConfig\nfrom transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF\nfrom transformers.models.bert.modeling_bert import BertModel as BertModelHF\n\nfrom flash_attn.models.bert import (\n    BertForPreTraining,\n    BertModel,\n    inv_remap_state_dict,\n    remap_state_dict,\n)\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\n\n\n@pytest.mark.parametrize(\"model_name\", [\"bert-base-uncased\", \"bert-large-uncased\"])\n# @pytest.mark.parametrize('model_name', [\"bert-base-uncased\"])\ndef test_bert_state_dict(model_name):\n    config = BertConfig.from_pretrained(model_name)\n    pretrained_state_dict = remap_state_dict(state_dict_from_pretrained(model_name), config)\n    model = BertForPreTraining(config)\n    state_dict = model.state_dict()\n    assert state_dict.keys() == pretrained_state_dict.keys()\n    for k in state_dict.keys():\n        assert state_dict[k].shape == pretrained_state_dict[k].shape\n\n\ndef get_hf_models(model_name, config, dtype):\n    pretrained_state_dict = state_dict_from_pretrained(model_name)\n\n    def key_mapping_ln_gamma_beta(key):\n        key = re.sub(r\"LayerNorm.gamma$\", \"LayerNorm.weight\", key)\n        key = re.sub(r\"LayerNorm.beta$\", \"LayerNorm.bias\", key)\n        return key\n\n    pretrained_state_dict = OrderedDict(\n        (key_mapping_ln_gamma_beta(k), v) for k, v in pretrained_state_dict.items()\n    )\n    model_hf = BertForPreTrainingHF(config)\n    # Missing key(s) in state_dict: \"bert.embeddings.position_ids\", \"cls.predictions.decoder.bias\"\n    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.\n    model_hf.load_state_dict(pretrained_state_dict, strict=False)\n    model_hf.cuda().to(dtype=dtype)\n    return 
model_hf\n\n\n@pytest.mark.parametrize(\"model_name\", [\"bert-base-uncased\"])\ndef test_bert_non_optimized(model_name):\n    \"\"\"Check that our implementation of BERT (without any optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    config = BertConfig.from_pretrained(model_name)\n\n    model = BertForPreTraining.from_pretrained(model_name, config)\n    model = model.cuda().to(dtype=dtype)\n\n    model_ref = get_hf_models(model_name, config, torch.float32)\n    model_hf = get_hf_models(model_name, config, dtype)\n\n    model.eval()\n    model_ref.eval()\n    model_hf.eval()\n\n    torch.manual_seed(0)\n    batch_size = 4\n    max_seqlen = 512\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=\"cuda\")\n    attention_mask = torch.arange(max_seqlen, device=\"cuda\")[None, :] < seqlens[:, None]\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=\"cuda\"\n    )\n    out = model.bert(input_ids, attention_mask=attention_mask)\n    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output\n    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)\n    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output\n    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)\n    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output\n\n    print(f\"Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(sequence_output_hf - 
sequence_output_ref).abs().mean().item()}\")\n    assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (\n        sequence_output_hf - sequence_output_ref\n    ).abs().max().item()\n    assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (\n        pooled_output_hf - pooled_output_ref\n    ).abs().max().item()\n\n\n@pytest.mark.parametrize(\"model_name\", [\"bert-base-uncased\", \"bert-large-uncased\"])\n# @pytest.mark.parametrize('model_name', [\"bert-base-uncased\"])\ndef test_bert_optimized(model_name):\n    \"\"\"Check that our implementation of BERT (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    config = BertConfig.from_pretrained(model_name)\n    # Our implementation of fused_mlp assumes the activation is\n    # nn.GELU(approximate='tanh'). 
Huggingface calls it \"gelu_new\", \"gelu_fast\", or \"gelu_pytorch_tanh\".\n    # If you just want \"gelu\", disable fused_mlp.\n    config.hidden_act = \"gelu_new\"\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = True\n    config.fused_dropout_add_ln = True\n\n    model = BertForPreTraining.from_pretrained(model_name, config)\n    model = model.cuda().to(dtype=dtype)\n\n    model_ref = get_hf_models(model_name, config, torch.float32)\n    model_hf = get_hf_models(model_name, config, dtype)\n\n    model.eval()\n    model_ref.eval()\n    model_hf.eval()\n\n    torch.manual_seed(0)\n    batch_size = 4\n    max_seqlen = 512\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=\"cuda\")\n    attention_mask = torch.arange(max_seqlen, device=\"cuda\")[None, :] < seqlens[:, None]\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=\"cuda\"\n    )\n    out = model.bert(input_ids, attention_mask=attention_mask)\n    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output\n    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)\n    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output\n    # Need to zero out the padded tokens in the sequence before comparison.\n    sequence_output_hf[~attention_mask, :] = 0.0\n    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)\n    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output\n    sequence_output_ref[~attention_mask, :] = 0.0\n\n    print(\n        f\"BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}\"\n    )\n    print(\n        f\"BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}\"\n    )\n    print(\n        f\"HF fp16 BertModel max diff: {(sequence_output_hf - 
sequence_output_ref).abs().max().item()}\"\n    )\n    print(\n        f\"HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}\"\n    )\n    assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (\n        sequence_output_hf - sequence_output_ref\n    ).abs().max().item()\n    assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (\n        pooled_output_hf - pooled_output_ref\n    ).abs().max().item()\n\n    out = model(input_ids, attention_mask=attention_mask)\n    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits\n    # Need to zero out the padded tokens in the sequence before comparison.\n    prediction_scores = prediction_scores.clone()\n    prediction_scores[~attention_mask, :] = 0.0\n    out_hf = model_hf(input_ids, attention_mask=attention_mask)\n    prediction_scores_hf, seq_relationship_scores_hf = (\n        out_hf.prediction_logits,\n        out_hf.seq_relationship_logits,\n    )\n    prediction_scores_hf[~attention_mask, :] = 0.0\n    out_ref = model_ref(input_ids, attention_mask=attention_mask)\n    prediction_scores_ref, seq_relationship_scores_ref = (\n        out_ref.prediction_logits,\n        out_ref.seq_relationship_logits,\n    )\n    prediction_scores_ref[~attention_mask, :] = 0.0\n\n    print(\n        f\"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}\"\n    )\n    print(\n        f\"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}\"\n    )\n    print(\n        f\"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}\"\n    )\n    print(\n        f\"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}\"\n    )\n    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (\n        prediction_scores_hf - prediction_scores_ref\n   
 ).abs().max().item()\n    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (\n        seq_relationship_scores_hf - seq_relationship_scores_ref\n    ).abs().max().item()\n\n\n@pytest.mark.parametrize(\"last_layer_subset\", [False, True])\n# @pytest.mark.parametrize('last_layer_subset', [True])\n@pytest.mark.parametrize(\"has_key_padding_mask\", [True, False])\n# @pytest.mark.parametrize('has_key_padding_mask', [True])\n@pytest.mark.parametrize(\"model_name\", [\"bert-base-uncased\", \"bert-large-uncased\"])\n# @pytest.mark.parametrize('model_name', [\"bert-base-uncased\"])\ndef test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):\n    \"\"\"Check that our implementation of BERT (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    config = BertConfig.from_pretrained(model_name)\n    # Our implementation of fused_mlp assumes the activation is\n    # nn.GELU(approximate='tanh'). 
Huggingface calls it \"gelu_new\", \"gelu_fast\", or \"gelu_pytorch_tanh\".\n    # If you just want \"gelu\", disable fused_mlp.\n    config.hidden_act = \"gelu_new\"\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = True\n    config.fused_dropout_add_ln = True\n    config.dense_seq_output = True\n    config.last_layer_subset = last_layer_subset\n    config.use_xentropy = True\n\n    model = BertForPreTraining.from_pretrained(model_name, config)\n    model = model.cuda().to(dtype=dtype)\n\n    model_ref = get_hf_models(model_name, config, torch.float32)\n    model_hf = get_hf_models(model_name, config, dtype)\n\n    model.eval()\n    model_ref.eval()\n    model_hf.eval()\n\n    torch.manual_seed(0)\n    batch_size = 4\n    max_seqlen = 512\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=\"cuda\")\n    if has_key_padding_mask:\n        attention_mask = torch.arange(max_seqlen, device=\"cuda\")[None, :] < seqlens[:, None]\n    else:\n        attention_mask = None\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=\"cuda\"\n    )\n    labels = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=\"cuda\"\n    )\n    if attention_mask is not None:\n        labels[~attention_mask] = 0\n    labels[(torch.rand(batch_size, max_seqlen, device=\"cuda\") > 0.15)] = 0\n    masked_tokens_mask = labels.flatten() > 0\n    next_sequence_label = torch.randint(0, 2, (batch_size,), device=\"cuda\")\n\n    out = model(\n        input_ids,\n        attention_mask=attention_mask,\n        labels=labels,\n        next_sentence_label=next_sequence_label,\n    )\n    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits\n    out_hf = model_hf(\n        input_ids,\n        attention_mask=attention_mask,\n        labels=labels,\n        
next_sentence_label=next_sequence_label,\n    )\n    prediction_scores_hf, seq_relationship_scores_hf = (\n        out_hf.prediction_logits,\n        out_hf.seq_relationship_logits,\n    )\n    prediction_scores_hf = rearrange(prediction_scores_hf, \"b s d -> (b s) d\")[masked_tokens_mask]\n    out_ref = model_ref(\n        input_ids,\n        attention_mask=attention_mask,\n        labels=labels,\n        next_sentence_label=next_sequence_label,\n    )\n    prediction_scores_ref, seq_relationship_scores_ref = (\n        out_ref.prediction_logits,\n        out_ref.seq_relationship_logits,\n    )\n    prediction_scores_ref = rearrange(prediction_scores_ref, \"b s d -> (b s) d\")[masked_tokens_mask]\n\n    print(\n        f\"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}\"\n    )\n    print(\n        f\"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}\"\n    )\n    print(\n        f\"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}\"\n    )\n    print(\n        f\"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}\"\n    )\n    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (\n        prediction_scores_hf - prediction_scores_ref\n    ).abs().max().item()\n    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (\n        seq_relationship_scores_hf - seq_relationship_scores_ref\n    ).abs().max().item()\n    # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.\n    # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()\n\n\n@pytest.mark.parametrize(\"model_name\", [\"bert-base-uncased\", \"bert-large-uncased\"])\ndef test_inv_remap_state_dict(model_name: str):\n    \"\"\"\n    Verify that we can convert a HF BERT model to flash_attn and 
back.\n    \"\"\"\n\n    state_dict = state_dict_from_pretrained(model_name)\n    config = BertConfig.from_pretrained(model_name)\n\n    flash_state_dict = remap_state_dict(state_dict, config)\n    recovered_state_dict = inv_remap_state_dict(flash_state_dict, config)\n\n    assert set(state_dict.keys()) == set(recovered_state_dict.keys())\n\n    for k in state_dict.keys():\n        assert state_dict[k].shape == recovered_state_dict[k].shape\n        torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)\n"
  },
  {
    "path": "tests/models/test_bigcode.py",
    "content": "import time\n\nimport pytest\nimport torch\nfrom transformers import AutoTokenizer, GPTBigCodeConfig\nfrom transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM\n\nfrom flash_attn.models.bigcode import bigcode_config_to_gpt2_config, inv_remap_state_dict_hf_bigcode\nfrom flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_bigcode\nfrom flash_attn.utils.generation import update_graph_cache\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\n\n\n@pytest.mark.parametrize(\"model_name\", [\"bigcode/starcoderbase-1b\", \"WizardLM/WizardCoder-1B-V1.0\"])\ndef test_bigcode_state_dict(model_name):\n    config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))\n    pretrained_state_dict = remap_state_dict_hf_bigcode(\n        state_dict_from_pretrained(model_name), config\n    )\n    model = GPTLMHeadModel(config, device=\"meta\")\n    state_dict = model.state_dict()\n    assert state_dict.keys() == pretrained_state_dict.keys()\n    for k in state_dict.keys():\n        assert state_dict[k].shape == pretrained_state_dict[k].shape\n\n\n@pytest.mark.parametrize(\"model_name\", [\"bigcode/starcoderbase-1b\", \"WizardLM/WizardCoder-1B-V1.0\"])\ndef test_bigcode_optimized(model_name):\n    \"\"\"Check that our implementation of BigCode (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))\n    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256\n    config.fused_bias_fc = True\n    config.fused_mlp = True\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config, 
device=device, dtype=dtype)\n    model.eval()\n\n    torch.manual_seed(0)\n    batch_size = 2\n    max_seqlen = 256\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n    )\n    with torch.no_grad():\n        out = model.transformer(input_ids)\n        logits = model(input_ids).logits\n    del model\n\n    # Without device_map, the model is loaded on the CPU, which is very slow\n    model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={\"\": device})\n    model_ref.eval()\n    with torch.no_grad():\n        out_ref = model_ref.transformer(input_ids).last_hidden_state\n        logits_ref = model_ref(input_ids).logits\n    del model_ref\n\n    model_hf = GPTBigCodeForCausalLM.from_pretrained(\n        model_name, torch_dtype=dtype, device_map={\"\": device}\n    )\n    model_hf.eval()\n    out_hf = model_hf.transformer(input_ids).last_hidden_state\n    logits_hf = model_hf(input_ids).logits\n    del model_hf\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n    assert (logits - logits_ref).abs().max().item() < 3 * (\n        logits_hf - logits_ref\n    ).abs().max().item()\n\n\n@pytest.mark.parametrize(\"model_name\", [\"bigcode/starcoderbase-1b\", \"WizardLM/WizardCoder-1B-V1.0\"])\ndef test_bigcode_generation(model_name):\n    
\"\"\"Check that our implementation of BigCode (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))\n    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256\n    config.fused_bias_fc = True\n    config.fused_mlp = True\n    config.fused_dropout_add_ln = True\n    # Only prenorm supports residual_in_fp32\n    config.residual_in_fp32 = True\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    eos_token_id = tokenizer.eos_token_id\n\n    torch.manual_seed(0)\n    batch_size = 1\n    seqlen = 100\n    max_length = 150\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device\n    )\n\n    model_hf = GPTBigCodeForCausalLM.from_pretrained(\n        model_name, torch_dtype=dtype, device_map={\"\": device}\n    )\n    model_hf.eval()\n    print(\"HF fp16\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_hf = model_hf.generate(\n        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n    del model_hf\n\n    model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={\"\": device})\n    model_ref.eval()\n    with torch.no_grad():\n        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]\n    del model_ref\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)\n    model.eval()\n\n    print(\"Without CUDA graph\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out = model.generate(\n        
input_ids=input_ids,\n        max_length=max_length,\n        eos_token_id=eos_token_id,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n        teacher_outputs=out_hf.sequences,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n\n    # Capture graph outside the timing loop\n    batch_size, seqlen_og = input_ids.shape\n    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)\n    print(\"With CUDA graph\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_cg = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        cg=True,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n        teacher_outputs=out_hf.sequences,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n\n    with torch.no_grad():\n        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]\n    logits_hf = torch.stack(out_hf.scores, dim=1)\n    logits = torch.stack(out.scores, dim=1)\n    logits_cg = torch.stack(out_cg.scores, dim=1)\n\n    del model\n\n    hf_error = (logits_hf - logits_ref).abs().max().item()\n    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error\n\n    print(f\"HF fp16 logits max diff: {hf_error}\")\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item() }\")\n    assert (logits - logits_ref).abs().max().item() < 2 * hf_error\n    print(f\"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }\")\n    assert (logits_cg - logits_ref).abs().max().item() < 2 * hf_error\n\n\n@pytest.mark.parametrize(\"model_name\", [\"bigcode/starcoderbase-1b\", \"WizardLM/WizardCoder-1B-V1.0\"])\ndef test_inv_remap_state_dict(model_name: str):\n    \"\"\"\n    Verify that we can convert a HF BigCode model 
to flash_attn and back.\n    \"\"\"\n\n    state_dict = state_dict_from_pretrained(model_name)\n    config = GPTBigCodeConfig.from_pretrained(model_name)\n\n    flash_state_dict = remap_state_dict_hf_bigcode(state_dict, config)\n    recovered_state_dict = inv_remap_state_dict_hf_bigcode(flash_state_dict, config)\n\n    assert set(state_dict.keys()) == set(recovered_state_dict.keys())\n\n    for k in state_dict.keys():\n        assert state_dict[k].shape == recovered_state_dict[k].shape\n        torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)\n"
  },
  {
    "path": "tests/models/test_btlm.py",
    "content": "# Copyright (c) 2023, Tri Dao.\nimport time\n\nimport torch\nimport pytest\n\nfrom transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM\n\nfrom flash_attn.models.gpt import GPTLMHeadModel\nfrom flash_attn.models.btlm import btlm_config_to_gpt2_config, remap_state_dict_hf_btlm\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\nfrom flash_attn.utils.generation import update_graph_cache\n\n\n@pytest.mark.parametrize(\"model_name\", [\"cerebras/btlm-3b-8k-base\"])\ndef test_btlm_state_dict(model_name):\n    config = btlm_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n    model = GPTLMHeadModel(config, device=\"meta\")  # Without device='meta' init is very slow\n    state_dict = model.state_dict()\n    assert len(state_dict.keys()) == len(pretrained_state_dict.keys())\n    assert state_dict.keys() == pretrained_state_dict.keys()\n    for k in state_dict.keys():\n        assert state_dict[k].shape == pretrained_state_dict[k].shape\n\n\n@pytest.mark.parametrize(\"model_name\", [\"cerebras/btlm-3b-8k-base\"])\ndef test_btlm_optimized(model_name):\n    \"\"\"Check that our implementation of Btlm (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    config = btlm_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    config.fused_bias_fc = True\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n    model = GPTLMHeadModel(config, device=device, dtype=dtype)\n    
model.load_state_dict(pretrained_state_dict)\n    model.eval()\n\n    torch.manual_seed(0)\n    batch_size = 2\n    max_seqlen = 256\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n    )\n    with torch.no_grad():\n        out = model.transformer(input_ids)\n        logits = model(input_ids).logits\n    del model\n\n    # Without device_map, the model is loaded on the CPU, which is very slow\n    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n    model_ref = AutoModelForCausalLM.from_pretrained(\n        model_name, device_map=\"auto\", trust_remote_code=True\n    )\n    model_ref.eval()\n    with torch.no_grad():\n        out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n        logits_ref = model_ref(input_ids).logits.to(device=device)\n    del model_ref\n\n    model_hf = AutoModelForCausalLM.from_pretrained(\n        model_name,\n        torch_dtype=dtype,\n        device_map={\"\": device},\n        trust_remote_code=True,\n    )\n    model_hf.eval()\n    with torch.no_grad():\n        out_hf = model_hf.transformer(input_ids).last_hidden_state\n        logits_hf = model_hf(input_ids).logits\n    del model_hf\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n    
print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n    assert (logits - logits_ref).abs().max().item() < 3 * (\n        logits_hf - logits_ref\n    ).abs().max().item()\n\n\n@pytest.mark.parametrize(\"model_name\", [\"cerebras/btlm-3b-8k-base\"])\ndef test_btlm_generation(model_name):\n    dtype = torch.float16\n    device = \"cuda\"\n    config = btlm_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    config.fused_bias_fc = True\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n    eos_token_id = tokenizer.eos_token_id\n\n    torch.manual_seed(0)\n    batch_size = 1\n    seqlen = 2048\n    max_length = 2048 + 150\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device\n    )\n\n    model_hf = AutoModelForCausalLM.from_pretrained(\n        model_name, torch_dtype=dtype, device_map={\"\": device}, trust_remote_code=True\n    )\n    model_hf.eval()\n    print(\"HF fp16\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_hf = model_hf.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        return_dict_in_generate=True,\n        output_scores=True,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n    del model_hf\n\n    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n    model_ref = AutoModelForCausalLM.from_pretrained(\n        model_name, device_map=\"auto\", trust_remote_code=True\n    )\n    model_ref.eval()\n    with torch.no_grad():\n        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)\n    del model_ref\n\n    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n    model = 
GPTLMHeadModel(config, device=device, dtype=dtype)\n    model.load_state_dict(pretrained_state_dict)\n    model.eval()\n\n    model(input_ids)  # Warm up\n    print(\"Without CUDA graph\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        eos_token_id=eos_token_id,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n        teacher_outputs=out_hf.sequences,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n\n    # Capture graph outside the timing loop\n    batch_size, seqlen_og = input_ids.shape\n    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)\n    print(\"With CUDA graph\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_cg = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        cg=True,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n        teacher_outputs=out_hf.sequences,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n\n    with torch.no_grad():\n        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]\n    logits_hf = torch.stack(out_hf.scores, dim=1)\n    logits = torch.stack(out.scores, dim=1)\n    logits_cg = torch.stack(out_cg.scores, dim=1)\n\n    del model\n\n    hf_error = (logits_hf - logits_ref).abs().max().item()\n\n    print(f\"HF fp16 logits max diff: {hf_error}\")\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item() }\")\n    print(f\"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }\")\n\n    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error\n    assert (logits - logits_ref).abs().max().item() < 2 * hf_error\n    assert 
torch.equal(logits_cg, logits)\n\n\n@pytest.mark.parametrize(\"model_name\", [\"cerebras/btlm-3b-8k-base\"])\ndef test_btlm_init(model_name):\n    dtype = torch.float32\n    device = \"cuda\"\n    btlm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    config = btlm_config_to_gpt2_config(btlm_config)\n    model = GPTLMHeadModel(config, device=device, dtype=dtype)\n    model_ref = AutoModelForCausalLM.from_config(btlm_config, trust_remote_code=True).to(device)\n\n    assert model.transformer.embeddings.word_embeddings.weight.mean().abs() < 1e-4\n    assert (\n        model.transformer.embeddings.word_embeddings.weight.std()\n        - model_ref.transformer.wte.weight.std()\n    ).abs() < 1e-4\n    assert model.lm_head.weight.mean().abs() < 1e-4\n    assert (model.lm_head.weight.std() - model_ref.lm_head.weight.std()).abs() < 1e-4\n    for l in range(config.n_layer):\n        assert model.transformer.layers[l].mixer.Wqkv.weight.mean().abs() < 1e-4\n        assert (\n            model.transformer.layers[l].mixer.Wqkv.weight.std()\n            - model_ref.transformer.h[l].attn.c_attn.weight.std()\n        ).abs() < 1e-4\n        assert model.transformer.layers[l].mixer.Wqkv.bias.abs().max() == 0.0\n        assert model.transformer.layers[l].mixer.out_proj.weight.mean().abs() < 1e-4\n        assert (\n            model.transformer.layers[l].mixer.out_proj.weight.std()\n            - model_ref.transformer.h[l].attn.c_proj.weight.std()\n        ).abs() < 1e-4\n        assert model.transformer.layers[l].mixer.out_proj.bias.abs().max() == 0.0\n        assert model.transformer.layers[l].mlp.fc1.weight.mean().abs() < 1e-4\n        assert (\n            model.transformer.layers[l].mlp.fc1.weight.std()\n            - model_ref.transformer.h[l].mlp.c_fc.weight.std()\n        ).abs() < 1e-4\n        assert model.transformer.layers[l].mlp.fc1.bias.abs().max() == 0.0\n        assert model.transformer.layers[l].mlp.fc2.weight.mean().abs() < 1e-4\n        
assert (\n            model.transformer.layers[l].mlp.fc2.weight.std()\n            - model_ref.transformer.h[l].mlp.c_proj.weight.std()\n        ).abs() < 1e-4\n        assert model.transformer.layers[l].mlp.fc2.bias.abs().max() == 0.0\n"
  },
  {
    "path": "tests/models/test_falcon.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport os\nimport time\nfrom pathlib import Path\n\ncurrent_dir = Path(__file__).parent.absolute()\n\nimport pytest\nimport torch\nfrom einops import rearrange\nfrom flash_attn.models.falcon import falcon_config_to_gpt2_config, remap_state_dict_hf_falcon\nfrom flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp\nfrom flash_attn.utils.distributed import all_gather_raw\nfrom flash_attn.utils.generation import update_graph_cache\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\nfrom transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n\n\n@pytest.mark.parametrize(\"model_name\", [\"tiiuae/falcon-7b\", \"tiiuae/falcon-40b\"])\ndef test_falcon_state_dict(model_name):\n    config = falcon_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    pretrained_state_dict = remap_state_dict_hf_falcon(\n        state_dict_from_pretrained(model_name), config\n    )\n    model = GPTLMHeadModel(config, device=\"meta\")  # Without device='meta' init is very slow\n    state_dict = model.state_dict()\n    assert state_dict.keys() == pretrained_state_dict.keys()\n    for k in state_dict.keys():\n        assert state_dict[k].shape == pretrained_state_dict[k].shape\n\n\n@pytest.mark.parametrize(\"model_name\", [\"tiiuae/falcon-7b\"])\ndef test_falcon_optimized(model_name):\n    \"\"\"Check that our implementation (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    config = falcon_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't 
have fused MLP for \"gelu\" activation\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)\n    model.eval()\n\n    torch.manual_seed(0)\n    batch_size = 2\n    max_seqlen = 256\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n    )\n    with torch.no_grad():\n        out = model.transformer(input_ids)\n        logits = model(input_ids).logits\n    del model\n\n    # Without device_map, the model is loaded on the CPU, which is very slow\n    model_ref = AutoModelForCausalLM.from_pretrained(\n        model_name, device_map={\"\": device}, trust_remote_code=True\n    )\n    model_ref.eval()\n    with torch.no_grad():\n        out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n        logits_ref = model_ref(input_ids).logits.to(device=device)\n    del model_ref\n\n    model_hf = AutoModelForCausalLM.from_pretrained(\n        model_name, torch_dtype=dtype, device_map={\"\": device}, trust_remote_code=True\n    )\n    model_hf.eval()\n    out_hf = model_hf.transformer(input_ids).last_hidden_state\n    logits_hf = model_hf(input_ids).logits\n    del model_hf\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(logits_hf - 
logits_ref).abs().mean().item()}\")\n    assert (logits - logits_ref).abs().max().item() < 3 * (\n        logits_hf - logits_ref\n    ).abs().max().item()\n\n\n# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k \"falcon_parallel_forward\"\n# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough\n# memory to run the model in fp32.\n@pytest.mark.parametrize(\"world_size\", [4])\n@pytest.mark.parametrize(\"model_name\", [\"tiiuae/falcon-40b\"])\ndef test_falcon_parallel_forward(model_name, world_size):\n    from apex.transformer import parallel_state\n\n    dtype = torch.float16\n    config = falcon_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    config.use_flash_attn = False\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't have fused MLP for \"gelu\" activation\n    config.fused_dropout_add_ln = False\n    config.residual_in_fp32 = True\n\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    process_group = parallel_state.get_tensor_model_parallel_group()\n\n    pretrained_state_dict = remap_state_dict_hf_falcon(\n        state_dict_from_pretrained(model_name), config\n    )\n\n    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)\n    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))\n    model.eval()\n\n    torch.manual_seed(0)\n    batch_size = 2\n    max_seqlen = 256\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n    input_ids = torch.randint(\n   
     0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n    )\n    with torch.no_grad():\n        out = model.transformer(input_ids)\n        out, _ = all_gather_raw(out, process_group=process_group)\n        out = rearrange(out, \"(b s) d -> b s d\", b=batch_size)\n        logits = model(input_ids).logits\n        logits = rearrange(logits, \"(b s) d -> b s d\", b=batch_size)\n        logits, _ = all_gather_raw(logits, process_group)\n        logits = rearrange(logits, \"(n b) ... d -> b ... (n d)\", b=batch_size)\n    del model\n    parallel_state.destroy_model_parallel()\n\n    if rank == 0:\n        model_hf = AutoModelForCausalLM.from_pretrained(\n            model_name, torch_dtype=dtype, device_map=\"auto\", trust_remote_code=True\n        )\n        model_hf.eval()\n        out_hf = model_hf.transformer(input_ids).last_hidden_state.to(device=device)\n        logits_hf = model_hf(input_ids).logits.to(device=device)\n        del model_hf\n\n        # Without device_map, the model is loaded on the CPU, which is very slow\n        model_ref = AutoModelForCausalLM.from_pretrained(\n            model_name, device_map=\"auto\", trust_remote_code=True\n        )\n        model_ref.eval()\n        with torch.no_grad():\n            out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n            logits_ref = model_ref(input_ids).logits.to(device=device)\n        del model_ref\n\n        print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n        print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n        print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n        print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()\n\n        print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n        print(f\"Logits mean diff: {(logits - 
logits_ref).abs().mean().item()}\")\n        print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n        print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n        assert (logits - logits_ref).abs().max().item() < 2 * (\n            logits_hf - logits_ref\n        ).abs().max().item()\n\n\n@pytest.mark.parametrize(\"model_name\", [\"tiiuae/falcon-7b\"])\ndef test_falcon_generation(model_name):\n    \"\"\"Check that generation with our implementation (with all optimizations enabled) matches the\n    HF implementation: the scores in fp16 should be around the same as the HF scores in fp16,\n    when compared to the HF scores in fp32.\n    \"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    config = falcon_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't have fused MLP for \"gelu\" activation\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    eos_token_id = tokenizer.eos_token_id\n\n    torch.manual_seed(0)\n    batch_size = 1\n    seqlen = 100\n    max_length = 150\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device\n    )\n\n    model_hf = AutoModelForCausalLM.from_pretrained(\n        model_name, torch_dtype=dtype, device_map={\"\": device}, trust_remote_code=True\n    )\n    model_hf.eval()\n    print(\"HF fp16\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_hf = model_hf.generate(\n        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n    del model_hf\n\n    model_ref = 
AutoModelForCausalLM.from_pretrained(\n        model_name, device_map={\"\": device}, trust_remote_code=True\n    )\n    model_ref.eval()\n    with torch.no_grad():\n        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]\n    del model_ref\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)\n    model.eval()\n\n    print(\"Without CUDA graph\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        eos_token_id=eos_token_id,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n        teacher_outputs=out_hf.sequences,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n\n    # Capture graph outside the timing loop\n    batch_size, seqlen_og = input_ids.shape\n    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)\n    print(\"With CUDA graph\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_cg = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        cg=True,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n        teacher_outputs=out_hf.sequences,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n\n    with torch.no_grad():\n        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]\n    logits_hf = torch.stack(out_hf.scores, dim=1)\n    logits = torch.stack(out.scores, dim=1)\n    logits_cg = torch.stack(out_cg.scores, dim=1)\n\n    del model\n\n    hf_error = (logits_hf - logits_ref).abs().max().item()\n    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error\n\n    print(f\"HF fp16 logits max diff: {hf_error}\")\n    print(f\"Logits 
max diff: {(logits - logits_ref).abs().max().item() }\")\n    assert (logits - logits_ref).abs().max().item() < 2 * hf_error\n    print(f\"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }\")\n    assert torch.equal(logits_cg, logits)\n\n\n# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k \"falcon_parallel_generation\"\n# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough\n# memory to run the model in fp32.\n@pytest.mark.parametrize(\"world_size\", [4])\n@pytest.mark.parametrize(\"model_name\", [\"tiiuae/falcon-40b\"])\ndef test_falcon_parallel_generation(model_name, world_size):\n    \"\"\"Check that our implementation matches the HF implementation:\n    the scores in fp16 should be around the same as the HF scores in fp16, when compared to\n    the HF scores in fp32.\n    \"\"\"\n    from apex.transformer import parallel_state\n\n    dtype = torch.float16\n    config = falcon_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    config.use_flash_attn = False\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't have fused MLP for \"gelu\" activation\n    config.fused_dropout_add_ln = False\n    config.residual_in_fp32 = True\n    config.pad_vocab_size_multiple = 8 * world_size\n    config.sequence_parallel = False  # Need to set this to False for generation\n\n    os.environ[\"NCCL_ASYNC_ERROR_HANDLING\"] = \"0\"\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    process_group = parallel_state.get_tensor_model_parallel_group()\n\n    torch.manual_seed(0)\n   
 batch_size = 1\n    seqlen = 100\n    max_length = 150\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device\n    )\n\n    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both\n    # GPU0 and GPU1 and things would hang\n    torch.cuda.set_device(device)\n\n    pretrained_state_dict = remap_state_dict_hf_falcon(\n        state_dict_from_pretrained(model_name), config\n    )\n\n    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)\n    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))\n    model.eval()\n\n    print(\"Without CUDA graph\")\n    out = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        tensor_parallel=world_size,\n        vocab_size=config.vocab_size,\n        # teacher_outputs=out_hf.sequences,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n    )\n\n    # Capture graph outside the timing loop\n    batch_size, seqlen_og = input_ids.shape\n    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)\n    print(\"With CUDA graph\")\n    out_cg = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        tensor_parallel=world_size,\n        vocab_size=config.vocab_size,\n        cg=True,\n        # teacher_outputs=out_hf.sequences,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n    )\n    del model\n    parallel_state.destroy_model_parallel()\n\n    if rank == 0:\n        model_hf = AutoModelForCausalLM.from_pretrained(\n            model_name, torch_dtype=dtype, device_map=\"auto\", trust_remote_code=True\n        )\n        model_hf.eval()\n        print(\"HF fp16\")\n        torch.cuda.synchronize()\n        start = time.time()\n        with torch.inference_mode():\n            
out_hf = model_hf.generate(\n                input_ids=input_ids,\n                max_length=max_length,\n                return_dict_in_generate=True,\n                output_scores=True,\n            )\n        torch.cuda.synchronize()\n        print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n        del model_hf\n\n        model_ref = AutoModelForCausalLM.from_pretrained(\n            model_name, device_map=\"auto\", trust_remote_code=True\n        )\n        model_ref.eval()\n        with torch.inference_mode():\n            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]\n        del model_ref\n        logits_hf = torch.stack(out_hf.scores, dim=1)\n\n        logits = torch.stack(out.scores, dim=1)\n        logits_cg = torch.stack(out_cg.scores, dim=1)\n\n        hf_error = (logits_hf - logits_ref).abs().max().item()\n        print(f\"HF fp16 logits max diff: {hf_error}\")\n        print(f\"Logits max diff: {(logits - logits_ref).abs().max().item() }\")\n        assert (logits - logits_ref).abs().max().item() < 2 * hf_error\n        print(f\"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }\")\n        assert torch.equal(logits_cg, logits)\n"
  },
  {
    "path": "tests/models/test_gpt.py",
    "content": "import re\n\nimport pytest\nimport torch\nfrom einops import rearrange\nfrom flash_attn.models.gpt import (\n    GPTLMHeadModel,\n    remap_state_dict_hf_gpt2,\n    shard_state_dict_tp,\n    combine_state_dicts_tp,\n)\nfrom flash_attn.utils.generation import InferenceParams\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\nfrom transformers import GPT2Config, GPT2Tokenizer\nfrom transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF\n\n\n@pytest.mark.parametrize(\"model_name\", [\"gpt2\", \"gpt2-medium\"])\n# @pytest.mark.parametrize('model_name', [\"gpt2\"])\ndef test_gpt2_state_dict(model_name):\n    config = GPT2Config.from_pretrained(model_name)\n    pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config)\n    model = GPTLMHeadModel(config)\n    state_dict = model.state_dict()\n    assert state_dict.keys() == pretrained_state_dict.keys()\n    for k in state_dict.keys():\n        assert state_dict[k].shape == pretrained_state_dict[k].shape\n\n\n@pytest.mark.parametrize(\"model_name\", [\"gpt2\", \"gpt2-medium\"])\n# @pytest.mark.parametrize('model_name', [\"gpt2\"])\ndef test_gpt2_non_optimized(model_name):\n    \"\"\"Check that our implementation of GPT2 (without any optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    config = GPT2Config.from_pretrained(model_name)\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config)\n    model = model.cuda().to(dtype=dtype)\n\n    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()\n    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)\n\n    model.eval()\n    model_ref.eval()\n    model_hf.eval()\n\n    torch.manual_seed(0)\n    batch_size = 4\n    max_seqlen = 512\n    
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=\"cuda\")\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=\"cuda\"\n    )\n    out = model.transformer(input_ids)\n    out_hf = model_hf.transformer(input_ids).last_hidden_state\n    out_ref = model_ref.transformer(input_ids).last_hidden_state\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n    logits = model(input_ids).logits\n    logits_hf = model_hf(input_ids).logits\n    logits_ref = model_ref(input_ids).logits\n\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n    assert (logits - logits_ref).abs().max().item() < 3 * (\n        logits_hf - logits_ref\n    ).abs().max().item()\n\n\n@pytest.mark.parametrize(\"model_name\", [\"gpt2\", \"gpt2-medium\"])\n# @pytest.mark.parametrize('model_name', [\"gpt2\"])\ndef test_gpt2_optimized(model_name):\n    \"\"\"Check that our implementation of GPT2 (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    config = GPT2Config.from_pretrained(model_name)\n    vocab_size_og = config.vocab_size\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    
config.fused_mlp = True\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n    config.pad_vocab_size_multiple = 8\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config)\n    model = model.cuda().to(dtype=dtype)\n\n    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()\n    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)\n\n    model.eval()\n    model_ref.eval()\n    model_hf.eval()\n\n    torch.manual_seed(0)\n    batch_size = 4\n    max_seqlen = 512\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=\"cuda\")\n    input_ids = torch.randint(\n        0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, device=\"cuda\"\n    )\n    out = model.transformer(input_ids)\n    out_hf = model_hf.transformer(input_ids).last_hidden_state\n    out_ref = model_ref.transformer(input_ids).last_hidden_state\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n    logits = model(input_ids).logits[..., :vocab_size_og]\n    logits_hf = model_hf(input_ids).logits\n    logits_ref = model_ref(input_ids).logits\n\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n    assert (logits - logits_ref).abs().max().item() < 3 * (\n        logits_hf - logits_ref\n    ).abs().max().item()\n\n\n@pytest.mark.parametrize(\"optimized\", [False, True])\n# 
@pytest.mark.parametrize('optimized', [True])\n@pytest.mark.parametrize(\"rotary\", [False, True])\n# @pytest.mark.parametrize('rotary', [False])\n@pytest.mark.parametrize(\"model_name\", [\"gpt2\"])\ndef test_gpt2_generation(model_name, rotary, optimized):\n    \"\"\"Check that our implementation of GPT2 generation matches the HF implementation:\n    the scores in fp16 should be around the same as the HF scores in fp16, when compared to\n    the HF scores in fp32.\n    \"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    rtol, atol = 3e-3, 3e-1\n    config = GPT2Config.from_pretrained(model_name)\n    if rotary:\n        config.n_positions = 0\n        config.rotary_emb_fraction = 0.5\n        config.rotary_emb_base = 24000\n    config.residual_in_fp32 = True\n    if optimized:\n        config.use_flash_attn = True\n        config.fused_bias_fc = True\n        config.fused_mlp = True\n        config.fused_dropout_add_ln = True\n\n    # If rotary, we load the weights from HF but ignore the position embeddings (strict=False).\n    # The resulting model is nonsensical, but that doesn't matter for this test.\n    model = GPTLMHeadModel.from_pretrained(\n        model_name, config, strict=not rotary, device=device, dtype=dtype\n    )\n    model.eval()\n\n    if not rotary:\n        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)\n        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(\n            device=device\n        )\n        model_ref.eval()\n        model_hf.eval()\n\n    torch.manual_seed(0)\n    tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n    input_ids = tokenizer(\"Hello, my dog is cute and he\", return_tensors=\"pt\").input_ids.to(\n        device=device\n    )\n    max_length = 25\n    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')\n    # max_length = input_ids.shape[1] + 40\n\n    # Slow generation for reference\n    sequences = []\n    scores = []\n    cur_input_ids = 
input_ids\n    with torch.inference_mode():\n        scores.append(model(cur_input_ids).logits[:, -1])\n        sequences.append(scores[-1].argmax(dim=-1))\n        for _ in range(input_ids.shape[1] + 1, max_length):\n            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], \"b -> b 1\")], dim=-1)\n            scores.append(model(cur_input_ids).logits[:, -1])\n            sequences.append(scores[-1].argmax(dim=-1))\n    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)\n    scores = tuple(scores)\n\n    out = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n    )\n    print(out.sequences)\n    print(tokenizer.batch_decode(out.sequences.tolist()))\n    if getattr(config, \"use_flash_attn\", False):\n        out_cg = model.generate(\n            input_ids=input_ids,\n            max_length=max_length,\n            cg=True,\n            return_dict_in_generate=True,\n            output_scores=True,\n            enable_timing=True,\n        )\n        print(out_cg.sequences)\n        assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))\n\n    if not rotary:\n        out_hf = model_hf.generate(\n            input_ids=input_ids,\n            max_length=max_length,\n            return_dict_in_generate=True,\n            output_scores=True,\n        )\n        out_ref = model_ref.generate(\n            input_ids=input_ids,\n            max_length=max_length,\n            return_dict_in_generate=True,\n            output_scores=True,\n        )\n\n        print(\n            f\"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}\"\n        )\n        print(\n            f\"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}\"\n        )\n        print(\n            f\"HF fp16 
max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}\"\n        )\n        print(\n            f\"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}\"\n        )\n        print(tokenizer.batch_decode(out_ref.sequences.tolist()))\n\n    assert torch.all(out.sequences == sequences)\n    assert torch.allclose(\n        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol\n    )\n    if not rotary:\n        assert torch.all(out.sequences == out_ref.sequences)\n        assert torch.all(out.sequences == out_hf.sequences)\n\n        assert (\n            torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)\n        ).abs().max().item() < 3 * (\n            torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)\n        ).abs().max().item()\n\n\ndef get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):\n    out = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        teacher_outputs=teacher_outputs,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n        **kwargs,\n    )\n    return torch.stack(out.scores, dim=1)\n\n\n@pytest.mark.parametrize(\"seqlen,maxlen\", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])\n# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])\n@pytest.mark.parametrize(\"rotary\", [None, \"interleaved\", \"contiguous\"])\n# @pytest.mark.parametrize('rotary', [None])\n@pytest.mark.parametrize(\"model_name\", [\"gpt2\"])\ndef test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):\n    \"\"\"Check that decoding with CUDA graph is the same as decoding without CUDA graph.\"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    rtol, atol = 3e-3, 3e-1\n    config = GPT2Config.from_pretrained(model_name)\n    config.n_positions = 16 * 1024\n    assert seqlen <= maxlen <= config.n_positions\n    if rotary 
is not None:\n        config.n_positions = 0\n        config.rotary_emb_dim = 32\n        config.rotary_emb_interleaved = rotary == \"interleaved\"\n    config.residual_in_fp32 = True\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = True\n    config.fused_dropout_add_ln = True\n\n    model = GPTLMHeadModel(config, device=device, dtype=dtype)\n    model.eval()\n\n    torch.manual_seed(0)\n    batch_size = 1\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device\n    )\n    teacher_outputs = torch.randint(\n        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device\n    )\n\n    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)\n    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)\n    assert torch.equal(logits, logits_cg)\n\n    # Try increasing batch size and seqlen, then decrease them to see if it's still correct\n    batch_size = 3\n    maxlen += 30\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device\n    )\n    teacher_outputs = torch.randint(\n        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device\n    )\n    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)\n    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)\n    assert torch.equal(logits, logits_cg)\n\n    batch_size = 2\n    maxlen -= 35\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device\n    )\n    teacher_outputs = torch.randint(\n        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device\n    )\n    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)\n    logits_cg = get_logits(model, input_ids, maxlen, 
teacher_outputs=teacher_outputs, cg=True)\n    assert torch.equal(logits, logits_cg)\n\n\n@pytest.mark.parametrize(\"optimized\", [False, True])\n# @pytest.mark.parametrize(\"optimized\", [False])\n@pytest.mark.parametrize(\"model_name\", [\"gpt2\"])\ndef test_gpt2_multiple_token_generation(model_name, optimized):\n    \"\"\"Generation when we pass in multiple tokens at a time, not just one.\"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    rtol, atol = 3e-3, 3e-1\n    config = GPT2Config.from_pretrained(model_name)\n    config.residual_in_fp32 = True\n    if optimized:\n        config.use_flash_attn = True\n        config.fused_bias_fc = True\n        config.fused_mlp = True\n        config.fused_dropout_add_ln = True\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)\n    model.eval()\n\n    torch.manual_seed(0)\n    input_ids = torch.randint(0, config.vocab_size, (1, 20), dtype=torch.long, device=device)\n    # Reference logits\n    logits_ref = model(input_ids).logits\n\n    # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits\n    inference_params = InferenceParams(max_seqlen=20, max_batch_size=1)\n    logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits\n    inference_params.seqlen_offset += 10\n    position_ids = torch.arange(10, 14, dtype=torch.long, device=device)\n    logits_1014 = model(\n        input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params\n    ).logits\n    inference_params.seqlen_offset += 4\n    position_ids = torch.arange(14, 20, dtype=torch.long, device=device)\n    logits_1420 = model(\n        input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params\n    ).logits\n    logits = torch.cat([logits_10, logits_1014, logits_1420], dim=1)\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    print(f\"Logits mean diff: {(logits - 
logits_ref).abs().mean().item()}\")\n    assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)\n\n\n@pytest.mark.parametrize(\"cg\", [False, True])\n# @pytest.mark.parametrize(\"cg\", [True])\n@pytest.mark.parametrize(\"optimized\", [False, True])\n# @pytest.mark.parametrize(\"optimized\", [True])\n# @pytest.mark.parametrize(\"model_name\", [\"gpt2-medium\"])\n@pytest.mark.parametrize(\"model_name\", [\"gpt2-xl\"])\ndef test_gpt2_speculative_decoding(model_name, optimized, cg):\n    if cg and not optimized:\n        pytest.skip()  # CG requires use_flash_attn\n    dtype = torch.float16\n    device = \"cuda\"\n    rtol, atol = 3e-3, 3e-1\n    config = GPT2Config.from_pretrained(model_name)\n    config.residual_in_fp32 = True\n    if optimized:\n        config.use_flash_attn = True\n        config.fused_bias_fc = True\n        config.fused_mlp = True\n        config.fused_dropout_add_ln = True\n    config_draft = GPT2Config.from_pretrained(\"gpt2\")\n    config_draft.residual_in_fp32 = True\n    if optimized:\n        config_draft.use_flash_attn = True\n        config_draft.fused_bias_fc = True\n        config_draft.fused_mlp = True\n        config_draft.fused_dropout_add_ln = True\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)\n    model.eval()\n    model_draft = GPTLMHeadModel.from_pretrained(\"gpt2\", config_draft, device=device, dtype=dtype)\n    model_draft.eval()\n\n    torch.manual_seed(0)\n    tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n    input_ids = tokenizer(\"Hello, my dog is cute and he\", return_tensors=\"pt\").input_ids.to(\n        device=device\n    )\n    max_length = 100\n\n    from flash_attn.utils.generation import decode_speculative\n\n    torch.manual_seed(42)\n    print(f\"Speculative decoding, {optimized = }\")\n    out = decode_speculative(\n        input_ids,\n        model,\n        model_draft,\n        max_length=max_length,\n        top_k=5,\n        cg=cg,\n        
speculative_lookahead=4,\n        enable_timing=True,\n        # debug=True,\n    )\n    print(tokenizer.batch_decode(out.sequences))\n    print(f\"Without speculative decoding, {cg = }\")\n    out_og = model.generate(\n        input_ids,\n        max_length=max_length,\n        top_k=5,\n        cg=cg,\n        enable_timing=True,\n        return_dict_in_generate=True,\n    )\n    print(tokenizer.batch_decode(out_og.sequences))\n\n\n@pytest.mark.parametrize(\n    \"n_heads_q_kv\",\n    [\n        (8, 8),  # Regular attention\n        (8, 4),  # GQA\n        (8, 2),  # MQA\n    ],\n)\ndef test_gpt2_shard_unshard(n_heads_q_kv):\n    world_size = 2\n\n    config = GPT2Config.from_pretrained(\"gpt2\")\n    config.vocab_size = 1024\n    config.n_head, config.n_head_kv = n_heads_q_kv\n    model = GPTLMHeadModel(config, device=\"cuda\", dtype=torch.float16)\n    state_dict = model.state_dict()\n    shards = [\n        # NOTE: Shallow copy as `state_dict` is modified in-place\n        shard_state_dict_tp(dict(state_dict), config, world_size, rank)\n        for rank in range(world_size)\n    ]\n    state_dict2 = combine_state_dicts_tp(shards, config)\n    assert state_dict2.keys() == state_dict.keys()\n    for k in state_dict.keys():\n        ref = state_dict[k]\n        # Compare against the recombined state dict, not the original one\n        new = state_dict2[k]\n        assert torch.allclose(ref, new, atol=0.0, rtol=0.0)\n"
  },
  {
    "path": "tests/models/test_gpt_generation_parallel.py",
    "content": "# Run test with:\n# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation_parallel.py -k \"parallel\"\nimport os\nimport re\n\nimport pytest\nimport torch\nfrom einops import rearrange\nfrom flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2\nfrom flash_attn.utils.distributed import all_gather_raw\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\nfrom transformers import GPT2Config, GPT2Tokenizer\nfrom transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF\n\n\n# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])\n@pytest.mark.parametrize(\"world_size\", [2])\n@pytest.mark.parametrize('rotary', [False, True])\n# @pytest.mark.parametrize(\"rotary\", [False])\n@pytest.mark.parametrize(\"model_name\", [\"gpt2\"])\ndef test_tensor_parallel(model_name, rotary, world_size):\n    \"\"\"Check that our implementation of GPT2 generation matches the HF implementation:\n    the scores in fp16 should be around the same as the HF scores in fp16, when compared to\n    the HF scores in fp32.\n    \"\"\"\n    dtype = torch.float16\n    rtol, atol = 3e-3, 3e-1\n    config = GPT2Config.from_pretrained(model_name)\n    if rotary:\n        config.n_positions = 0\n        config.rotary_emb_dim = 64\n    config.residual_in_fp32 = True\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = True\n    config.fused_dropout_add_ln = True\n    config.pad_vocab_size_multiple = 8 * world_size\n    config.sequence_parallel = False  # Need to set this to False for generation\n\n    os.environ[\"NCCL_ASYNC_ERROR_HANDLING\"] = \"0\"\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    # Need this, otherwise when we capture the graph the process for 
GPU 1 would run on both\n    # GPU 0 and GPU 1 and things would hang\n    torch.cuda.set_device(device)\n\n    from apex.transformer import parallel_state\n\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    process_group = parallel_state.get_tensor_model_parallel_group()\n\n    # If rotary, we load the weights from HF but ignore the position embeddings (strict=False).\n    # The resulting model is nonsensical, but that doesn't matter for this test.\n    model = GPTLMHeadModel.from_pretrained(\n        model_name,\n        config,\n        strict=not rotary,\n        device=device,\n        dtype=dtype,\n        process_group=process_group,\n        world_size=world_size,\n        rank=rank,\n    )\n    model.eval()\n\n    if not rotary:\n        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)\n        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)\n        model_ref.eval()\n        model_hf.eval()\n\n    torch.manual_seed(0)\n    tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n    input_ids = tokenizer(\"Hello, my dog is cute and \", return_tensors=\"pt\").input_ids.to(\n        device=device\n    )\n    max_length = 30\n    # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')\n    # max_length = input_ids.shape[1] + 40\n\n    # Slow generation for reference\n    sequences = []\n    scores = []\n    cur_input_ids = input_ids\n    with torch.inference_mode():\n        logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)\n        logits = rearrange(logits, \"(n b) d -> b (n d)\", b=input_ids.shape[0])[\n            ..., : config.vocab_size\n        ]\n        scores.append(logits)\n        sequences.append(scores[-1].argmax(dim=-1))\n        for _ in range(input_ids.shape[1] + 1, max_length):\n            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 
\"b -> b 1\")], dim=-1)\n            logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)\n            logits = rearrange(logits, \"(n b) d -> b (n d)\", b=input_ids.shape[0])[\n                ..., : config.vocab_size\n            ]\n            scores.append(logits)\n            sequences.append(scores[-1].argmax(dim=-1))\n    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)\n    scores = tuple(scores)\n    print(sequences)\n\n    out = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        tensor_parallel=world_size,\n        vocab_size=config.vocab_size,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n    )\n    print(out.sequences)\n    if getattr(config, \"use_flash_attn\", False):\n        out_cg = model.generate(\n            input_ids=input_ids,\n            max_length=max_length,\n            tensor_parallel=world_size,\n            vocab_size=config.vocab_size,\n            cg=True,\n            return_dict_in_generate=True,\n            output_scores=True,\n            enable_timing=True,\n        )\n        print(out_cg.sequences)\n\n    parallel_state.destroy_model_parallel()\n\n    if not rotary:\n        out_hf = model_hf.generate(\n            input_ids=input_ids,\n            max_length=max_length,\n            return_dict_in_generate=True,\n            output_scores=True,\n        )\n        out_ref = model_ref.generate(\n            input_ids=input_ids,\n            max_length=max_length,\n            return_dict_in_generate=True,\n            output_scores=True,\n        )\n\n        print(\n            f\"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}\"\n        )\n        print(\n            f\"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}\"\n        )\n        print(\n            f\"HF fp16 max diff: 
{(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}\"\n        )\n        print(\n            f\"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}\"\n        )\n\n    assert torch.all(out.sequences == sequences)\n    assert torch.allclose(\n        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol\n    )\n    assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))\n    if not rotary:\n        assert torch.all(out.sequences == out_ref.sequences)\n        assert torch.all(out.sequences == out_hf.sequences)\n\n        assert (\n            torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)\n        ).abs().max().item() < 3 * (\n            torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)\n        ).abs().max().item()\n"
  },
  {
    "path": "tests/models/test_gpt_neox.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport time\n\nimport pytest\nimport torch\nfrom flash_attn.models.gpt import GPTLMHeadModel\nfrom flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\nfrom transformers import AutoTokenizer, GPTNeoXConfig\nfrom transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM\n\n\n@pytest.mark.parametrize(\"model_name\", [\"EleutherAI/gpt-neox-20b\"])\ndef test_gptj_state_dict(model_name):\n    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))\n    pretrained_state_dict = remap_state_dict_hf_gpt_neox(\n        state_dict_from_pretrained(model_name), config\n    )\n    model = GPTLMHeadModel(config, device=\"meta\")  # Without device='meta' init is very slow\n    state_dict = model.state_dict()\n    assert state_dict.keys() == pretrained_state_dict.keys()\n    for k in state_dict.keys():\n        assert state_dict[k].shape == pretrained_state_dict[k].shape\n\n\n@pytest.mark.parametrize(\n    \"model_name\",\n    [\n        \"EleutherAI/pythia-1b\",\n        \"EleutherAI/pythia-2.8b\",\n        \"EleutherAI/gpt-neox-20b\",\n        \"togethercomputer/RedPajama-INCITE-7B-Base\",\n    ],\n)\ndef test_gpt_neox_optimized(model_name):\n    \"\"\"Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = config.activation_function in [\n        \"gelu_fast\",\n        \"gelu_new\",\n        \"gelu_approx\",\n        \"gelu_pytorch_tanh\",\n   
 ]\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)\n    model.eval()\n\n    torch.manual_seed(0)\n    batch_size = 2\n    max_seqlen = 256\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n    )\n    with torch.no_grad():\n        out = model.transformer(input_ids)\n        logits = model(input_ids).logits\n    del model\n\n    # Need at least 2 GPUs, otherwise we'll OOM for the 20B model\n    # Without device_map, the model is loaded on the CPU, which is very slow\n    model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map=\"auto\")\n    model_ref.eval()\n    with torch.no_grad():\n        out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device)\n        logits_ref = model_ref(input_ids).logits.to(device=device)\n    del model_ref\n\n    model_hf = GPTNeoXForCausalLM.from_pretrained(\n        model_name, torch_dtype=dtype, device_map={\"\": device}\n    )\n    model_hf.eval()\n    with torch.no_grad():\n        out_hf = model_hf.gpt_neox(input_ids).last_hidden_state\n        logits_hf = model_hf(input_ids).logits\n    del model_hf\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()\n    assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()\n\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    print(f\"Logits mean diff: {(logits - 
logits_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n    assert (logits - logits_ref).abs().max().item() < 2 * (\n        logits_hf - logits_ref\n    ).abs().max().item()\n    assert (logits - logits_ref).abs().mean().item() < 2 * (\n        logits_hf - logits_ref\n    ).abs().mean().item()\n"
  },
  {
    "path": "tests/models/test_gpt_parallel.py",
    "content": "# Run test with:\n# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py\n\nimport math\n\nimport pytest\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom apex.transformer import parallel_state\nfrom einops import rearrange\nfrom flash_attn.losses.cross_entropy import CrossEntropyLoss\nfrom flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp\nfrom flash_attn.utils.distributed import allreduce_sequence_parallel_grad\nfrom transformers import GPT2Config\n\nis_sm8x = torch.cuda.get_device_capability(\"cuda\")[0] >= 8\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))\n# @pytest.mark.parametrize('dtype', [torch.bfloat16])\n@pytest.mark.parametrize(\"world_size\", [1, 2, 4, 8])\n# @pytest.mark.parametrize('world_size', [2])\n@pytest.mark.parametrize(\"sequence_parallel\", [True, False])\n# @pytest.mark.parametrize('sequence_parallel', [False])\n@pytest.mark.parametrize(\"has_pos_emb\", [True, False])\n# @pytest.mark.parametrize('has_pos_emb', [True])\n@pytest.mark.parametrize(\"dim\", [1024])\ndef test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):\n    head_dim = 64\n    assert dim % head_dim == 0\n    num_heads = dim // head_dim\n    assert num_heads % world_size == 0\n    vocab_size = 50264\n    assert vocab_size % world_size == 0\n    num_layers = 2\n    rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2)\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    process_group = parallel_state.get_tensor_model_parallel_group()\n    # set seed\n    
torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 1024\n    assert (batch_size * seqlen) % world_size == 0\n    input_ids = torch.randint(0, vocab_size, (batch_size, seqlen + 1), device=device)\n\n    # We need to generate g here so that all processes get the same gradient,\n    # as rank 0 will have an extra bias that changes the RNG.\n    g = torch.randn(batch_size * seqlen, device=device)\n\n    config = GPT2Config(\n        n_embd=dim,\n        n_head=num_heads,\n        n_layer=num_layers,\n        n_positions=seqlen if has_pos_emb else 0,\n        vocab_size=50257,\n        resid_pdrop=0.0,\n        embd_pdrop=0.0,\n        attn_pdrop=0.0,\n        scale_attn_by_inverse_layer_idx=True,\n        use_flash_attn=True,\n        fused_mlp=True,\n        fused_bias_fc=True,\n        fused_dropout_add_ln=True,\n        residual_in_fp32=True,\n        rotary_emb_fraction=0.0 if has_pos_emb else 0.5,\n        pad_vocab_size_multiple=8 * world_size,\n        sequence_parallel=sequence_parallel,\n    )\n    config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)\n    model_pt = GPTLMHeadModel(config, device=device)\n\n    def init_layer_norm(module):\n        if isinstance(module, nn.LayerNorm):\n            nn.init.normal_(module.weight)\n            nn.init.normal_(module.bias)\n\n    model_pt.apply(init_layer_norm)\n\n    model = GPTLMHeadModel(config, process_group=process_group, device=device)\n    total_nparams = sum(p.numel() for p in model_pt.parameters())\n    sharded_nparams = sum(p.numel() for p in model.parameters())\n    sharded_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)\n    torch.distributed.all_gather_into_tensor(\n        sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group\n    )\n    shared_nparams = sum(\n        p.numel() for p in model.parameters() if getattr(p, \"_shared_params\", False)\n    )\n    shared_nparams_all = 
torch.empty(world_size, dtype=torch.long, device=device)\n    torch.distributed.all_gather_into_tensor(\n        shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group\n    )\n    assert torch.all(shared_nparams_all == shared_nparams)\n    assert total_nparams == (\n        (sharded_nparams_all - shared_nparams_all).sum().item() + shared_nparams\n    )\n\n    # vocab_size has been rounded up here\n    partition_vocab_size = config.vocab_size // world_size\n    partition_dim = dim // world_size\n    partition_hidden_dim = 4 * dim // world_size\n    with torch.no_grad():\n        model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))\n        model.tie_weights()\n\n    with torch.autocast(device_type=\"cuda\", dtype=dtype):\n        out = model(input_ids[:, :-1]).logits\n        if not sequence_parallel:\n            out = rearrange(out, \"b s d -> (b s) d\")\n        out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, \"b s d -> (b s) d\")\n    partition_batch_dim = batch_size * seqlen // world_size\n    assert torch.allclose(\n        out,\n        out_pt[:, rank * partition_vocab_size : (rank + 1) * partition_vocab_size],\n        rtol=rtol,\n        atol=atol,\n    )\n    loss_fn = CrossEntropyLoss(inplace_backward=True, reduction=\"none\", process_group=process_group)\n    loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction=\"none\")\n    loss = loss_fn(out, input_ids[:, 1:].flatten())\n    loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten())\n    assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol)\n\n    loss_pt.backward(g)\n    loss.backward(g)\n    allreduce_sequence_parallel_grad(model, process_group)\n    parallel_state.destroy_model_parallel()\n\n    grad_dict = shard_state_dict_tp(\n        {k: v.grad for k, v in model_pt.named_parameters()}, config, world_size, rank\n    )\n\n    assert torch.allclose(\n        
model.transformer.embeddings.word_embeddings.weight.grad,\n        grad_dict[\"transformer.embeddings.word_embeddings.weight\"],\n        rtol=rtol,\n        atol=atol * 5,\n    )\n    if has_pos_emb:\n        assert torch.allclose(\n            model.transformer.embeddings.position_embeddings.weight.grad,\n            grad_dict[\"transformer.embeddings.position_embeddings.weight\"],\n            rtol=rtol,\n            atol=atol,\n        )\n    assert torch.allclose(\n        model.transformer.ln_f.weight.grad,\n        grad_dict[\"transformer.ln_f.weight\"],\n        rtol=rtol,\n        atol=atol,\n    )\n    assert torch.allclose(\n        model.transformer.ln_f.bias.grad, grad_dict[\"transformer.ln_f.bias\"], rtol=rtol, atol=atol\n    )\n    for i in range(num_layers):\n        assert torch.allclose(\n            model.transformer.layers[i].mixer.Wqkv.weight.grad,\n            grad_dict[f\"transformer.layers.{i}.mixer.Wqkv.weight\"],\n            rtol=rtol,\n            atol=atol * 10,\n        )\n        assert torch.allclose(\n            model.transformer.layers[i].mixer.Wqkv.bias.grad,\n            grad_dict[f\"transformer.layers.{i}.mixer.Wqkv.bias\"],\n            rtol=rtol,\n            atol=atol * 10,\n        )\n        assert torch.allclose(\n            model.transformer.layers[i].mixer.out_proj.weight.grad,\n            grad_dict[f\"transformer.layers.{i}.mixer.out_proj.weight\"],\n            rtol=rtol,\n            atol=atol * 10,\n        )\n        if rank == 0:\n            assert torch.allclose(\n                model.transformer.layers[i].mixer.out_proj.bias.grad,\n                grad_dict[f\"transformer.layers.{i}.mixer.out_proj.bias\"],\n                rtol=rtol,\n                atol=atol * 5,\n            )\n        assert torch.allclose(\n            model.transformer.layers[i].mlp.fc1.weight.grad,\n            grad_dict[f\"transformer.layers.{i}.mlp.fc1.weight\"],\n            rtol=rtol,\n            atol=atol * 10,\n        )\n      
  assert torch.allclose(\n            model.transformer.layers[i].mlp.fc1.bias.grad,\n            grad_dict[f\"transformer.layers.{i}.mlp.fc1.bias\"],\n            rtol=rtol,\n            atol=atol * 10,\n        )\n        assert torch.allclose(\n            model.transformer.layers[i].mlp.fc2.weight.grad,\n            grad_dict[f\"transformer.layers.{i}.mlp.fc2.weight\"],\n            rtol=rtol,\n            atol=atol * 10,\n        )\n        if rank == 0:\n            assert torch.allclose(\n                model.transformer.layers[i].mlp.fc2.bias.grad,\n                grad_dict[f\"transformer.layers.{i}.mlp.fc2.bias\"],\n                rtol=rtol,\n                atol=atol * 5,\n            )\n\n        assert torch.allclose(\n            model.transformer.layers[i].norm1.weight.grad,\n            grad_dict[f\"transformer.layers.{i}.norm1.weight\"],\n            rtol=rtol,\n            atol=atol,\n        )\n        assert torch.allclose(\n            model.transformer.layers[i].norm1.bias.grad,\n            grad_dict[f\"transformer.layers.{i}.norm1.bias\"],\n            rtol=rtol,\n            atol=atol,\n        )\n        assert torch.allclose(\n            model.transformer.layers[i].norm2.weight.grad,\n            grad_dict[f\"transformer.layers.{i}.norm2.weight\"],\n            rtol=rtol,\n            atol=atol,\n        )\n        assert torch.allclose(\n            model.transformer.layers[i].norm2.bias.grad,\n            grad_dict[f\"transformer.layers.{i}.norm2.bias\"],\n            rtol=rtol,\n            atol=atol,\n        )\n"
  },
  {
    "path": "tests/models/test_gptj.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport time\n\nimport pytest\nimport torch\nfrom flash_attn.models.gpt import GPTLMHeadModel\nfrom flash_attn.models.gptj import gptj_config_to_gpt2_config, remap_state_dict_hf_gptj\nfrom flash_attn.utils.generation import update_graph_cache\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\nfrom transformers import AutoTokenizer, GPTJConfig\nfrom transformers.models.gptj.modeling_gptj import GPTJForCausalLM\n\n\n@pytest.mark.parametrize(\"model_name\", [\"EleutherAI/gpt-j-6B\"])\ndef test_gptj_state_dict(model_name):\n    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))\n    pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)\n    model = GPTLMHeadModel(config, device=\"meta\")  # Without device='meta' init is very slow\n    state_dict = model.state_dict()\n    assert state_dict.keys() == pretrained_state_dict.keys()\n    for k in state_dict.keys():\n        assert state_dict[k].shape == pretrained_state_dict[k].shape\n\n\n@pytest.mark.parametrize(\"model_name\", [\"EleutherAI/gpt-j-6B\", \"togethercomputer/GPT-JT-6B-v1\"])\ndef test_gptj_optimized(model_name):\n    \"\"\"Check that our implementation of GPT-J (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))\n    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256\n    config.fused_bias_fc = True\n    config.fused_mlp = True\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)\n    model.eval()\n\n    torch.manual_seed(0)\n    batch_size = 
2\n    max_seqlen = 256\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n    )\n    with torch.no_grad():\n        out = model.transformer(input_ids)\n        logits = model(input_ids).logits\n    del model\n\n    # Without device_map, the model is loaded on the CPU, which is very slow\n    model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={\"\": device})\n    model_ref.eval()\n    with torch.no_grad():\n        out_ref = model_ref.transformer(input_ids).last_hidden_state\n        logits_ref = model_ref(input_ids).logits\n    del model_ref\n\n    model_hf = GPTJForCausalLM.from_pretrained(\n        model_name, torch_dtype=dtype, device_map={\"\": device}\n    )\n    model_hf.eval()\n    # Run under no_grad like the other forward passes so no autograd graph is kept\n    with torch.no_grad():\n        out_hf = model_hf.transformer(input_ids).last_hidden_state\n        logits_hf = model_hf(input_ids).logits\n    del model_hf\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n    assert (logits - logits_ref).abs().max().item() < 3 * (\n        logits_hf - logits_ref\n    ).abs().max().item()\n\n\n@pytest.mark.parametrize(\"model_name\", [\"EleutherAI/gpt-j-6B\"])\ndef test_gptj_generation(model_name):\n    \"\"\"Check that our implementation of GPT-J (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass 
in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))\n    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256\n    config.fused_bias_fc = True\n    config.fused_mlp = True\n    config.fused_dropout_add_ln = True\n    # Only prenorm supports residual_in_fp32\n    config.residual_in_fp32 = True\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    eos_token_id = tokenizer.eos_token_id\n\n    torch.manual_seed(0)\n    batch_size = 1\n    seqlen = 100\n    max_length = 150\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device\n    )\n\n    model_hf = GPTJForCausalLM.from_pretrained(\n        model_name, torch_dtype=dtype, device_map={\"\": device}\n    )\n    model_hf.eval()\n    print(\"HF fp16\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_hf = model_hf.generate(\n        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n    del model_hf\n\n    model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={\"\": device})\n    model_ref.eval()\n    with torch.no_grad():\n        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]\n    del model_ref\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)\n    model.eval()\n\n    print(\"Without CUDA graph\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        eos_token_id=eos_token_id,\n        return_dict_in_generate=True,\n        output_scores=True,\n        
enable_timing=True,\n        teacher_outputs=out_hf.sequences,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n\n    # Capture graph outside the timing loop\n    batch_size, seqlen_og = input_ids.shape\n    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)\n    print(\"With CUDA graph\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_cg = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        cg=True,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n        teacher_outputs=out_hf.sequences,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n\n    with torch.no_grad():\n        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]\n    logits_hf = torch.stack(out_hf.scores, dim=1)\n    logits = torch.stack(out.scores, dim=1)\n    logits_cg = torch.stack(out_cg.scores, dim=1)\n\n    del model\n\n    hf_error = (logits_hf - logits_ref).abs().max().item()\n    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error\n\n    print(f\"HF fp16 logits max diff: {hf_error}\")\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    assert (logits - logits_ref).abs().max().item() < 2 * hf_error\n    print(f\"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}\")\n    assert torch.equal(logits_cg, logits)\n"
  },
  {
    "path": "tests/models/test_llama.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\n# To run the huggingface implementation of LLaMa (1), we first need to convert the weights:\n# https://github.com/huggingface/transformers/pull/21955\n# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf\n# and repeat for 13B, 30B, 65B\n\nimport os\nimport time\nfrom pathlib import Path\n\ncurrent_dir = Path(__file__).parent.absolute()\n\nimport shutil\n\nimport pytest\nimport torch\nfrom einops import rearrange\nfrom flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp\nfrom flash_attn.models.llama import (\n    config_from_checkpoint,\n    inv_remap_state_dict_hf_llama,\n    llama_config_to_gpt2_config,\n    remap_state_dict_hf_llama,\n    remap_state_dict_meta_llama,\n    state_dicts_from_checkpoint,\n)\nfrom flash_attn.utils.distributed import all_gather_raw\nfrom flash_attn.utils.generation import update_graph_cache\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\nfrom transformers import LlamaConfig, LlamaTokenizer\nfrom transformers.models.llama.modeling_llama import LlamaForCausalLM\nfrom transformers import AutoConfig\n\n\ndef _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):\n    if checkpoint_format == \"meta\":\n        ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)\n        pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]\n        pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)\n    else:\n        pretrained_state_dict = state_dict_from_pretrained(\n            Path(checkpoint_path) / f\"{model_name}-hf\"\n        )\n        pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config)\n    return pretrained_state_dict\n\n\n@pytest.mark.parametrize(\"model_name\", [\"7B\"])\ndef 
test_llama_state_dict(model_name):\n    checkpoint_path = (\n        Path(os.environ.get(\"CHECKPOINT_DIR\", current_dir.parent.parent / \"checkpoints\")) / \"llama\"\n    )\n    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))\n    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)\n    pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)\n    model = GPTLMHeadModel(config, device=\"meta\")  # Without device='meta' init is very slow\n    state_dict = model.state_dict()\n    assert state_dict.keys() == pretrained_state_dict.keys()\n    for k in state_dict.keys():\n        assert state_dict[k].shape == pretrained_state_dict[k].shape\n\n\n# TinyLlama-1.1B is to test MQA\n@pytest.mark.parametrize(\n    \"model_name\", [\"meta-llama/Llama-2-7b-hf\", \"PY007/TinyLlama-1.1B-step-50K-105b\"]\n)\ndef test_inv_remap_state_dict_hf_llama(model_name):\n    config = llama_config_to_gpt2_config(\n        AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n    )\n    state_dict = state_dict_from_pretrained(model_name)\n    # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama\n    state_dict = {key: val for key, val in state_dict.items() if \"rotary_emb.inv_freq\" not in key}\n    pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)\n    state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)\n    assert set(state_dict_recover.keys()) == set(state_dict.keys())\n    for key in state_dict_recover.keys():\n        torch.testing.assert_close(state_dict_recover[key], state_dict[key])\n\n\n# TinyLlama-1.1B is to test MQA\n@pytest.mark.parametrize(\n    \"model_name\",\n    [\n        \"7B\",  # Llama 1\n        \"13B\",  # Llama 1\n        \"meta-llama/Llama-2-13b-hf\",\n        \"codellama/CodeLlama-7b-hf\",\n        \"codellama/CodeLlama-13b-hf\",\n        \"codellama/CodeLlama-34b-hf\",\n        
\"PY007/TinyLlama-1.1B-step-50K-105b\",\n    ],\n)\ndef test_llama_optimized(model_name):\n    \"\"\"Check that our implementation of LLaMa (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    checkpoint_path = (\n        Path(os.environ.get(\"CHECKPOINT_DIR\", current_dir.parent.parent / \"checkpoints\")) / \"llama\"\n    )\n\n    dtype = torch.float16\n    device = \"cuda\"\n    if \"/\" in model_name:  # Download from HF\n        config = llama_config_to_gpt2_config(\n            AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n        )\n    else:\n        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format=\"meta\")\n        config = llama_config_to_gpt2_config(config)\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't have fused GatedMLP yet\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    if \"/\" in model_name:  # Download from HF\n        pretrained_state_dict = remap_state_dict_hf_llama(\n            state_dict_from_pretrained(model_name), config\n        )\n    else:\n        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(\n            checkpoint_path, model_name, config, checkpoint_format=\"meta\"\n        )\n    model = GPTLMHeadModel(config, device=device, dtype=dtype)\n    model.load_state_dict(pretrained_state_dict)\n    model.eval()\n\n    torch.manual_seed(0)\n    batch_size = 2\n    max_seqlen = 256\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n    )\n    with torch.no_grad():\n        out = model.transformer(input_ids)\n        logits = 
model(input_ids).logits\n    del model\n\n    # Without device_map, the model is loaded on the CPU, which is very slow\n    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n    model_ref = LlamaForCausalLM.from_pretrained(\n        model_name if \"/\" in model_name else Path(checkpoint_path) / f\"{model_name}-hf\",\n        device_map=\"auto\",\n    )\n    model_ref.eval()\n    with torch.no_grad():\n        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)\n        logits_ref = model_ref(input_ids).logits.to(device=device)\n    del model_ref\n\n    model_hf = LlamaForCausalLM.from_pretrained(\n        model_name if \"/\" in model_name else Path(checkpoint_path) / f\"{model_name}-hf\",\n        torch_dtype=dtype,\n        device_map={\"\": device},\n    )\n    model_hf.eval()\n    with torch.no_grad():\n        out_hf = model_hf.model(input_ids).last_hidden_state\n        logits_hf = model_hf(input_ids).logits\n    del model_hf\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n    assert (logits - logits_ref).abs().max().item() < 3 * (\n        logits_hf - logits_ref\n    ).abs().max().item()\n\n\n# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k \"parallel\"\n@pytest.mark.parametrize(\"world_size\", 
[2])\n@pytest.mark.parametrize(\n    \"model_name\", [\"13B\", \"meta-llama/Llama-2-13b-hf\", \"codellama/CodeLlama-34b-hf\"]\n)\ndef test_llama_parallel(model_name, world_size):\n    \"\"\"Check that our implementation of LLaMa (with all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    from apex.transformer import parallel_state\n\n    checkpoint_path = (\n        Path(os.environ.get(\"CHECKPOINT_DIR\", current_dir.parent.parent / \"checkpoints\")) / \"llama\"\n    )\n\n    dtype = torch.float16\n    if \"/\" in model_name:  # Download from HF\n        config = llama_config_to_gpt2_config(\n            AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n        )\n    else:\n        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format=\"meta\")\n        config = llama_config_to_gpt2_config(config)\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't have fused GatedMLP yet\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    process_group = parallel_state.get_tensor_model_parallel_group()\n\n    if \"/\" in model_name:  # Download from HF\n        pretrained_state_dict = remap_state_dict_hf_llama(\n            state_dict_from_pretrained(model_name), config\n        )\n    else:\n        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(\n            checkpoint_path, 
model_name, config, checkpoint_format=\"meta\"\n        )\n    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)\n    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))\n    model.eval()\n\n    torch.manual_seed(0)\n    batch_size = 2\n    max_seqlen = 256\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n    )\n    with torch.no_grad():\n        out = model.transformer(input_ids)\n        out, _ = all_gather_raw(out, process_group=process_group)\n        out = rearrange(out, \"(b s) d -> b s d\", b=batch_size)\n        logits = model(input_ids).logits\n        logits = rearrange(logits, \"(b s) d -> b s d\", b=batch_size)\n        logits, _ = all_gather_raw(logits, process_group)\n        logits = rearrange(logits, \"(n b) ... d -> b ... (n d)\", b=batch_size)\n    del model\n\n    if rank == 0:\n        # Without device_map, the model is loaded on the CPU, which is very slow\n        model_ref = LlamaForCausalLM.from_pretrained(\n            model_name if \"/\" in model_name else Path(checkpoint_path) / f\"{model_name}-hf\",\n            device_map=\"auto\",\n        )\n        model_ref.eval()\n        with torch.no_grad():\n            out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)\n            logits_ref = model_ref(input_ids).logits.to(device=device)\n        del model_ref\n\n        model_hf = LlamaForCausalLM.from_pretrained(\n            model_name if \"/\" in model_name else Path(checkpoint_path) / f\"{model_name}-hf\",\n            torch_dtype=dtype,\n            device_map=\"auto\",\n        )\n        model_hf.eval()\n        with torch.no_grad():\n            out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)\n            logits_hf = 
model_hf(input_ids).logits.to(device=device)\n        del model_hf\n\n        print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n        print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n        print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n        print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()\n\n        print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n        print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n        print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n        print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n        assert (logits - logits_ref).abs().max().item() < 2 * (\n            logits_hf - logits_ref\n        ).abs().max().item()\n\n\n# @pytest.mark.parametrize('model_name', [\"7B\", \"13B\"])\n@pytest.mark.parametrize(\"model_name\", [\"7B\"])\n@pytest.mark.parametrize(\"checkpoint_format\", [\"meta\", \"hf\"])\ndef test_llama_generation(model_name, checkpoint_format):\n    checkpoint_path = (\n        Path(os.environ.get(\"CHECKPOINT_DIR\", current_dir.parent.parent / \"checkpoints\")) / \"llama\"\n    )\n\n    dtype = torch.float16\n    device = \"cuda\"\n    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)\n    config = llama_config_to_gpt2_config(config)\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't have fused GatedMLP yet\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f\"{model_name}-hf\")\n    eos_token_id = tokenizer.eos_token_id\n\n    torch.manual_seed(0)\n    batch_size = 1\n    seqlen = 100\n    max_length = 150\n    input_ids = torch.randint(\n        0, 
config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device\n    )\n\n    model_hf = LlamaForCausalLM.from_pretrained(\n        Path(checkpoint_path) / f\"{model_name}-hf\", torch_dtype=dtype, device_map={\"\": device}\n    )\n    model_hf.eval()\n    print(\"HF fp16\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_hf = model_hf.generate(\n        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n    del model_hf\n\n    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n    model_ref = LlamaForCausalLM.from_pretrained(\n        Path(checkpoint_path) / f\"{model_name}-hf\", device_map=\"auto\"\n    )\n    model_ref.eval()\n    with torch.no_grad():\n        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)\n    del model_ref\n\n    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(\n        checkpoint_path, model_name, config, checkpoint_format\n    )\n    model = GPTLMHeadModel(config, device=device, dtype=dtype)\n    model.load_state_dict(pretrained_state_dict)\n    model.eval()\n\n    print(\"Without CUDA graph\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        eos_token_id=eos_token_id,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n        teacher_outputs=out_hf.sequences,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n\n    # Capture graph outside the timing loop\n    batch_size, seqlen_og = input_ids.shape\n    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)\n    print(\"With CUDA graph\")\n    
torch.cuda.synchronize()\n    start = time.time()\n    out_cg = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        cg=True,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n        teacher_outputs=out_hf.sequences,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n\n    with torch.no_grad():\n        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]\n    logits_hf = torch.stack(out_hf.scores, dim=1)\n    logits = torch.stack(out.scores, dim=1)\n    logits_cg = torch.stack(out_cg.scores, dim=1)\n\n    del model\n\n    hf_error = (logits_hf - logits_ref).abs().max().item()\n\n    print(f\"HF fp16 logits max diff: {hf_error}\")\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    print(f\"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}\")\n\n    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error\n    assert (logits - logits_ref).abs().max().item() < 2 * hf_error\n    assert torch.equal(logits_cg, logits)\n\n\n# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k \"llama_parallel_generation\"\n@pytest.mark.parametrize(\"world_size\", [2])\n@pytest.mark.parametrize(\n    \"model_name\", [\"13B\", \"meta-llama/Llama-2-13b-hf\", \"codellama/CodeLlama-34b-hf\"]\n)\ndef test_llama_parallel_generation(model_name, world_size):\n    \"\"\"Check that our implementation matches the HF implementation:\n    the scores in fp16 should be around the same as the HF scores in fp16, when compared to\n    the HF scores in fp32.\n    \"\"\"\n    from apex.transformer import parallel_state\n\n    checkpoint_path = (\n        Path(os.environ.get(\"CHECKPOINT_DIR\", current_dir.parent.parent / \"checkpoints\")) / \"llama\"\n    )\n\n    dtype = torch.float16\n    if \"/\" in model_name:  # Download from HF\n        
config = llama_config_to_gpt2_config(\n            AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n        )\n    else:\n        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format=\"meta\")\n        config = llama_config_to_gpt2_config(config)\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't have fused GatedMLP yet\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n    config.pad_vocab_size_multiple = 8 * world_size\n    config.sequence_parallel = False  # Need to set this to False for generation\n\n    os.environ[\"NCCL_ASYNC_ERROR_HANDLING\"] = \"0\"\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    process_group = parallel_state.get_tensor_model_parallel_group()\n\n    torch.manual_seed(0)\n    batch_size = 1\n    seqlen = 100\n    max_length = 150\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device\n    )\n\n    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both\n    # GPU0 and GPU1 and things would hang\n    torch.cuda.set_device(device)\n\n    if \"/\" in model_name:  # Download from HF\n        pretrained_state_dict = remap_state_dict_hf_llama(\n            state_dict_from_pretrained(model_name), config\n        )\n    else:\n        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(\n            checkpoint_path, model_name, config, checkpoint_format=\"meta\"\n        )\n    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)\n    
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))\n    model.eval()\n\n    print(\"Without CUDA graph\")\n    out = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        tensor_parallel=world_size,\n        vocab_size=config.vocab_size,\n        # teacher_outputs=out_hf.sequences,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n    )\n\n    # Capture graph outside the timing loop\n    batch_size, seqlen_og = input_ids.shape\n    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)\n    print(\"With CUDA graph\")\n    out_cg = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        tensor_parallel=world_size,\n        vocab_size=config.vocab_size,\n        cg=True,\n        # teacher_outputs=out_hf.sequences,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n    )\n    del model\n    parallel_state.destroy_model_parallel()\n\n    if rank == 0:\n        # Without device_map, the model is loaded on the CPU, which is very slow\n        model_hf = LlamaForCausalLM.from_pretrained(\n            model_name if \"/\" in model_name else Path(checkpoint_path) / f\"{model_name}-hf\",\n            torch_dtype=dtype,\n            device_map=\"auto\",\n        )\n        model_hf.eval()\n        print(\"HF fp16\")\n        torch.cuda.synchronize()\n        start = time.time()\n        with torch.inference_mode():\n            out_hf = model_hf.generate(\n                input_ids=input_ids,\n                max_length=max_length,\n                return_dict_in_generate=True,\n                output_scores=True,\n            )\n        torch.cuda.synchronize()\n        print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n        del model_hf\n\n        model_ref = LlamaForCausalLM.from_pretrained(\n            
model_name if \"/\" in model_name else Path(checkpoint_path) / f\"{model_name}-hf\",\n            device_map=\"auto\",\n        )\n        model_ref.eval()\n        with torch.inference_mode():\n            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]\n        del model_ref\n        logits_hf = torch.stack(out_hf.scores, dim=1)\n\n        logits = torch.stack(out.scores, dim=1)\n        logits_cg = torch.stack(out_cg.scores, dim=1)\n\n        hf_error = (logits_hf - logits_ref).abs().max().item()\n        print(f\"HF fp16 logits max diff: {hf_error}\")\n        print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n        assert (logits - logits_ref).abs().max().item() < 2 * hf_error\n        print(f\"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}\")\n        assert torch.equal(logits_cg, logits)\n\n\n@torch.no_grad()\n@pytest.mark.parametrize(\"world_size\", [2])\ndef test_llama_parallel_uneven_num_heads(world_size):\n    from apex.transformer import parallel_state\n\n    checkpoint_path = (\n        Path(os.environ.get(\"CHECKPOINT_DIR\", current_dir.parent.parent / \"checkpoints\")) / \"llama\"\n    )\n    num_attention_heads = world_size + 1\n    model_name = f\"teeny-{num_attention_heads}-heads\"\n\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    process_group = parallel_state.get_tensor_model_parallel_group()\n\n    dtype = torch.float16\n    llama_config = LlamaConfig(\n        hidden_size=256\n        * num_attention_heads,  # ParallelGatedMlp hidden_features must be divisible by 256\n        intermediate_size=256 * num_attention_heads * 4,\n       
 num_hidden_layers=4,\n        num_attention_heads=num_attention_heads,\n        initializer_range=0.5,  # Set crazy init range so we don't have near zero weights implying a vacuous test.\n    )\n    config = llama_config_to_gpt2_config(llama_config)\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = False  # We don't have fused GatedMLP yet\n    config.fused_dropout_add_ln = True\n    config.residual_in_fp32 = True\n\n    torch.manual_seed(0)\n    batch_size = 2\n    max_seqlen = 256\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n    )\n\n    # Create a shared test model.\n    if rank == 0:\n        LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f\"{model_name}-hf\")\n    torch.distributed.barrier()\n\n    # Run the standard forward pass test.\n    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(\n        checkpoint_path, model_name, config, checkpoint_format=\"hf\"\n    )\n    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)\n    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))\n    model.eval()\n\n    # TODO: Avoid duplicate code. Modularize the comparison of two forward pass diffs.\n    out = model.transformer(input_ids)\n    out, _ = all_gather_raw(out, process_group=process_group)\n    out = rearrange(out, \"(b s) d -> b s d\", b=batch_size)\n    logits = model(input_ids).logits\n    logits = rearrange(logits, \"(b s) d -> b s d\", b=batch_size)\n    logits, _ = all_gather_raw(logits, process_group)\n    logits = rearrange(logits, \"(n b) ... d -> b ... 
(n d)\", b=batch_size)\n\n    if rank == 0:\n        model_ref = LlamaForCausalLM.from_pretrained(\n            Path(checkpoint_path) / f\"{model_name}-hf\", device_map={\"\": device}\n        )\n        model_ref = model_ref.to(device=device)\n        model_ref.eval()\n        out_ref = model_ref.model(input_ids).last_hidden_state\n        logits_ref = model_ref(input_ids).logits\n        del model_ref\n\n        model_hf = LlamaForCausalLM.from_pretrained(\n            Path(checkpoint_path) / f\"{model_name}-hf\", torch_dtype=dtype, device_map={\"\": device}\n        )\n        model_hf.eval()\n        out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)\n        logits_hf = model_hf(input_ids).logits.to(device=device)\n        del model_hf\n\n        print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n        print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n        print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n        print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()\n\n        print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n        print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n        print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n        print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n        assert (logits - logits_ref).abs().max().item() < 2 * (\n            logits_hf - logits_ref\n        ).abs().max().item()\n\n        if os.path.exists(checkpoint_path / f\"{model_name}-hf\"):\n            shutil.rmtree(checkpoint_path / f\"{model_name}-hf\")\n"
  },
  {
    "path": "tests/models/test_opt.py",
    "content": "import re\nimport time\n\nimport pytest\nimport torch\nfrom einops import rearrange\nfrom flash_attn.models.gpt import GPTLMHeadModel\nfrom flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt\nfrom flash_attn.utils.generation import update_graph_cache\nfrom flash_attn.utils.pretrained import state_dict_from_pretrained\nfrom transformers import AutoTokenizer, OPTConfig\nfrom transformers.models.opt.modeling_opt import OPTForCausalLM\n\n\n@pytest.mark.parametrize(\n    \"model_name\", [\"facebook/opt-125m\", \"facebook/opt-350m\", \"facebook/opt-1.3b\"]\n)\n# @pytest.mark.parametrize('model_name', [\"facebook/opt-350m\"])\ndef test_opt_state_dict(model_name):\n    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))\n    pretrained_state_dict = remap_state_dict_hf_opt(state_dict_from_pretrained(model_name), config)\n    model = GPTLMHeadModel(config)\n    state_dict = model.state_dict()\n    assert state_dict.keys() == pretrained_state_dict.keys()\n    for k in state_dict.keys():\n        assert state_dict[k].shape == pretrained_state_dict[k].shape\n\n\n@pytest.mark.parametrize(\n    \"model_name\", [\"facebook/opt-125m\", \"facebook/opt-350m\", \"facebook/opt-1.3b\"]\n)\n# @pytest.mark.parametrize('model_name', [\"facebook/opt-350m\"])\ndef test_opt_optimized(model_name):\n    \"\"\"Check that our implementation of OPT (without all optimizations enabled) matches the\n    HF implementation: the output of our forward pass in fp16 should be around the same as the HF\n    forward pass in fp16, when compared to the HF forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = True\n    config.fused_dropout_add_ln = True\n    # Only prenorm supports residual_in_fp32\n    config.residual_in_fp32 = getattr(config, 
\"prenorm\", True)\n    config.pad_vocab_size_multiple = 8\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)\n\n    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)\n    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)\n\n    model.eval()\n    model_ref.eval()\n    model_hf.eval()\n\n    torch.manual_seed(0)\n    batch_size = 2\n    max_seqlen = 256\n    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=\"cuda\")\n    input_ids = torch.randint(\n        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=\"cuda\"\n    )\n    if model_name != \"facebook/opt-350m\":  # The OPT-350m projects the embeddings to dimension 512\n        out = model.transformer(input_ids)\n        out_hf = model_hf.model(input_ids).last_hidden_state\n        out_ref = model_ref.model(input_ids).last_hidden_state\n\n        print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n        print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n        print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n        print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n        assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n    logits = model(input_ids).logits\n    logits_hf = model_hf(input_ids).logits\n    logits_ref = model_ref(input_ids).logits\n\n    print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n    print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n    print(f\"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}\")\n    print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n    assert (logits - logits_ref).abs().max().item() < 3 * (\n        logits_hf - logits_ref\n    ).abs().max().item()\n\n\n@pytest.mark.parametrize(\n    \"model_name\",\n    [\n  
      \"facebook/opt-125m\",\n        \"facebook/opt-350m\",\n        \"facebook/opt-1.3b\",\n        \"facebook/opt-2.7b\",\n        \"facebook/opt-6.7b\",\n    ],\n)\n# @pytest.mark.parametrize('model_name', [\"facebook/opt-125m\"])\ndef test_opt_generation(model_name):\n    \"\"\"Check that our implementation of OPT generation matches the HF implementation:\n    the scores in fp16 should be around the same as the HF scores in fp16, when compared to\n    the HF scores in fp32.\n    \"\"\"\n    print(f\"\\nMODEL: {model_name}\")\n    verbose = False\n    dtype = torch.float16\n    device = \"cuda\"\n    rtol, atol = 3e-3, 3e-1\n    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))\n    # Only prenorm supports residual_in_fp32\n    config.residual_in_fp32 = getattr(config, \"prenorm\", True)\n    config.use_flash_attn = True\n    config.fused_bias_fc = True\n    config.fused_mlp = True\n    config.fused_dropout_add_ln = True\n\n    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)\n    model.eval()\n\n    torch.manual_seed(0)\n    # OPT tokenizer requires use_fast=False\n    # https://huggingface.co/docs/transformers/model_doc/opt\n    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)\n    eos_token_id = tokenizer.eos_token_id\n\n    input_ids = tokenizer(\"Hello, my dog is cute and he\", return_tensors=\"pt\").input_ids.to(\n        device=device\n    )\n    max_length = 25\n    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')\n    # max_length = input_ids.shape[1] + 40\n\n    # Slow generation for reference\n    sequences = []\n    scores = []\n    cur_input_ids = input_ids\n    with torch.inference_mode():\n        scores.append(model(cur_input_ids).logits[:, -1])\n        sequences.append(scores[-1].argmax(dim=-1))\n        for _ in range(input_ids.shape[1] + 1, max_length):\n            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 
\"b -> b 1\")], dim=-1)\n            scores.append(model(cur_input_ids).logits[:, -1])\n            sequences.append(scores[-1].argmax(dim=-1))\n            if eos_token_id is not None and (sequences[-1] == eos_token_id).all():\n                break\n    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)\n    scores = tuple(scores)\n\n    print(\"Without CUDA graph\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out = model.generate(\n        input_ids=input_ids,\n        max_length=max_length,\n        eos_token_id=eos_token_id,\n        return_dict_in_generate=True,\n        output_scores=True,\n        enable_timing=True,\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n    if verbose:\n        print(out.sequences)\n    print(tokenizer.batch_decode(out.sequences.tolist()))\n    if getattr(config, \"use_flash_attn\", False):\n        # Capture graph outside the timing loop\n        batch_size, seqlen_og = input_ids.shape\n        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)\n        print(\"With CUDA graph\")\n        torch.cuda.synchronize()\n        start = time.time()\n        out_cg = model.generate(\n            input_ids=input_ids,\n            max_length=max_length,\n            cg=True,\n            return_dict_in_generate=True,\n            output_scores=True,\n            enable_timing=True,\n        )\n        torch.cuda.synchronize()\n        print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n        if verbose:\n            print(out_cg.sequences)\n        print(tokenizer.batch_decode(out_cg.sequences.tolist()))\n\n    del model\n\n    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)\n    model_hf.eval()\n    print(\"HF fp16\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_hf = model_hf.generate(\n  
      input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n    del model_hf\n\n    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)\n    model_ref.eval()\n    print(\"HF fp32\")\n    torch.cuda.synchronize()\n    start = time.time()\n    out_ref = model_ref.generate(\n        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True\n    )\n    torch.cuda.synchronize()\n    print(f\"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms\")\n    del model_ref\n    print(tokenizer.batch_decode(out_ref.sequences.tolist()))\n\n    if verbose:\n        print(\n            f\"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}\"\n        )\n        print(\n            f\"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}\"\n        )\n        print(\n            f\"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}\"\n        )\n        print(\n            f\"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}\"\n        )\n\n    assert torch.all(out.sequences == sequences)\n    assert torch.allclose(\n        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol\n    )\n    assert torch.all(out.sequences == out_ref.sequences)\n    assert torch.all(out.sequences == out_hf.sequences)\n\n    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (\n        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)\n    ).abs().max().item()\n"
  },
  {
    "path": "tests/models/test_vit.py",
    "content": "import re\n\nimport pytest\nimport torch\nfrom flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224\nfrom timm.models.vision_transformer import vit_base_patch16_224\n\n\n@pytest.mark.parametrize(\"fused_mlp\", [False, True])\n# @pytest.mark.parametrize('fused_mlp', [False])\n@pytest.mark.parametrize(\"optimized\", [False, True])\n# @pytest.mark.parametrize('optimized', [True])\ndef test_vit(optimized, fused_mlp):\n    \"\"\"Check that our implementation of ViT matches timm's implementation:\n    the output of our forward pass in fp16 should be around the same as\n    timm's forward pass in fp16, when compared to timm's forward pass in fp32.\n    \"\"\"\n    dtype = torch.float16\n    device = \"cuda\"\n\n    kwargs = {}\n    if optimized:\n        kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)\n    kwargs[\"fused_mlp\"] = fused_mlp\n    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)\n\n    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)\n    model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)\n\n    model.load_state_dict(model_ref.state_dict())\n\n    model.eval()\n    model_ref.eval()\n    model_timm.eval()\n\n    torch.manual_seed(0)\n    batch_size = 2\n    x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)\n    out = model(x)\n    out_timm = model_timm(x)\n    out_ref = model_ref(x.float())\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}\")\n    print(f\"timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}\")\n    rtol = 2 if not fused_mlp else 8\n    assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()\n"
  },
  {
    "path": "tests/modules/test_block_parallel.py",
    "content": "# Run test with:\n# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_block_parallel.py\n\nimport math\nfrom functools import partial\n\nimport pytest\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom apex.transformer import parallel_state, tensor_parallel\nfrom einops import rearrange\nfrom flash_attn.modules.block import Block\nfrom flash_attn.modules.mha import MHA, ParallelMHA\nfrom flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP\nfrom flash_attn.utils.distributed import allreduce_sequence_parallel_grad\n\nis_sm8x = torch.cuda.get_device_capability(\"cuda\")[0] >= 8\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))\n# @pytest.mark.parametrize('dtype', [torch.float16])\n@pytest.mark.parametrize(\"world_size\", [1, 2, 4, 8])\n# @pytest.mark.parametrize('world_size', [2])\n@pytest.mark.parametrize(\"sequence_parallel\", [True, False])\n# @pytest.mark.parametrize('sequence_parallel', [True])\n@pytest.mark.parametrize(\"dim\", [1024])\ndef test_block_parallel(dim, sequence_parallel, world_size, dtype):\n    head_dim = 64\n    assert dim % head_dim == 0\n    num_heads = dim // head_dim\n    assert num_heads % world_size == 0\n    rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    seqlen = 1024\n    assert (batch_size * seqlen) % world_size == 0\n    x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)\n    residual_pt = 
torch.randn(batch_size * seqlen, dim, device=device, requires_grad=True)\n    # We need to generate g here so that all processes get the same gradient,\n    # as rank 0 will have an extra bias that changes the RNG.\n    # If we don't divide by batch_size, the gradient gets a bit too large.\n    g = torch.randn_like(x_pt) / 32\n    if sequence_parallel:\n        x = (\n            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)\n            .detach()\n            .clone()\n            .requires_grad_()\n        )\n        residual = (\n            tensor_parallel.scatter_to_sequence_parallel_region(residual_pt)\n            .detach()\n            .clone()\n            .requires_grad_()\n        )\n    else:\n        x = x_pt.detach().clone().requires_grad_()\n        residual = residual_pt.detach().clone().requires_grad_()\n\n    mixer_cls_pt = partial(\n        MHA,\n        num_heads=num_heads,\n        rotary_emb_dim=int(head_dim // 2),\n        use_flash_attn=True,\n        device=device,\n        dtype=dtype,\n    )\n    mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)\n    norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)\n    model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)\n    with torch.no_grad():\n        nn.init.normal_(model_pt.norm1.weight)\n        nn.init.normal_(model_pt.norm1.bias)\n        nn.init.normal_(model_pt.norm2.weight)\n        nn.init.normal_(model_pt.norm2.bias)\n\n    mixer_cls = partial(\n        ParallelMHA,\n        num_heads=num_heads,\n        process_group=parallel_state.get_tensor_model_parallel_group(),\n        rotary_emb_dim=int(head_dim // 2),\n        use_flash_attn=True,\n        sequence_parallel=sequence_parallel,\n        device=device,\n        dtype=dtype,\n    )\n    mlp_cls = partial(\n        ParallelFusedMLP,\n        hidden_features=4 * dim,\n        process_group=parallel_state.get_tensor_model_parallel_group(),\n        
sequence_parallel=sequence_parallel,\n        device=device,\n        dtype=dtype,\n    )\n    model = Block(\n        dim,\n        mixer_cls,\n        mlp_cls,\n        norm_cls,\n        fused_dropout_add_ln=True,\n        sequence_parallel=sequence_parallel,\n        mark_shared_params=True,\n    )\n\n    partition_dim = dim // world_size\n    partition_hidden_dim = 4 * dim // world_size\n    with torch.no_grad():\n        model.mixer.Wqkv.weight.copy_(\n            rearrange(\n                rearrange(model_pt.mixer.Wqkv.weight, \"(three o) i -> three o i\", three=3)[\n                    :, rank * partition_dim : (rank + 1) * partition_dim\n                ],\n                \"three o i -> (three o) i\",\n            )\n        )\n        model.mixer.Wqkv.bias.copy_(\n            rearrange(\n                rearrange(model_pt.mixer.Wqkv.bias, \"(three o) -> three o\", three=3)[\n                    :, rank * partition_dim : (rank + 1) * partition_dim\n                ],\n                \"three o -> (three o)\",\n            )\n        )\n        model.mixer.out_proj.weight.copy_(\n            model_pt.mixer.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]\n        )\n        if rank == 0:\n            model.mixer.out_proj.bias.copy_(model_pt.mixer.out_proj.bias)\n        model.mlp.fc1.weight.copy_(\n            model_pt.mlp.fc1.weight[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]\n        )\n        model.mlp.fc1.bias.copy_(\n            model_pt.mlp.fc1.bias[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]\n        )\n        model.mlp.fc2.weight.copy_(\n            model_pt.mlp.fc2.weight[\n                :, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim\n            ]\n        )\n        if rank == 0:\n            model.mlp.fc2.bias.copy_(model_pt.mlp.fc2.bias)\n        model.norm1.weight.copy_(model_pt.norm1.weight)\n        model.norm1.bias.copy_(model_pt.norm1.bias)\n        
model.norm2.weight.copy_(model_pt.norm2.weight)\n        model.norm2.bias.copy_(model_pt.norm2.bias)\n\n    mixer_kwargs = {\"seqlen\": seqlen}\n    out, out_residual = model(x, residual, mixer_kwargs=mixer_kwargs)\n    out_pt, out_residual_pt = model_pt(\n        rearrange(x_pt, \"(b s) d -> b s d\", s=seqlen),\n        rearrange(residual_pt, \"(b s) d -> b s d\", s=seqlen),\n    )\n    out_pt, out_residual_pt = [rearrange(x, \"b s d -> (b s) d\") for x in [out_pt, out_residual_pt]]\n    partition_batch_dim = batch_size * seqlen // world_size\n    assert torch.allclose(\n        out,\n        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]\n        if sequence_parallel\n        else out_pt,\n        rtol=rtol,\n        atol=atol,\n    )\n    assert torch.allclose(\n        out_residual,\n        out_residual_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]\n        if sequence_parallel\n        else out_residual_pt,\n        rtol=rtol,\n        atol=atol,\n    )\n\n    (out_pt + 2 * out_residual_pt).backward(g)\n    (out + 2 * out_residual).backward(\n        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g\n    )\n    allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())\n    parallel_state.destroy_model_parallel()\n\n    assert torch.allclose(\n        x.grad,\n        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]\n        if sequence_parallel\n        else x_pt.grad,\n        rtol=rtol,\n        atol=atol / 10,  # magnitude of x.grad is quite small\n    )\n    assert torch.allclose(\n        residual.grad,\n        residual_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]\n        if sequence_parallel\n        else residual_pt.grad,\n        rtol=rtol,\n        atol=atol,\n    )\n    # The error for d_weight and d_bias is quite a bit higher\n    assert torch.allclose(\n        
model.mixer.Wqkv.weight.grad,\n        rearrange(\n            rearrange(model_pt.mixer.Wqkv.weight.grad, \"(three o) i -> three o i\", three=3)[\n                :, rank * partition_dim : (rank + 1) * partition_dim\n            ],\n            \"three o i -> (three o) i\",\n        ),\n        rtol=rtol,\n        atol=atol * 10,\n    )\n    assert torch.allclose(\n        model.mixer.Wqkv.bias.grad,\n        rearrange(\n            rearrange(model_pt.mixer.Wqkv.bias.grad, \"(three o) -> three o\", three=3)[\n                :, rank * partition_dim : (rank + 1) * partition_dim\n            ],\n            \"three o -> (three o)\",\n        ),\n        rtol=rtol,\n        atol=atol * 5,\n    )\n    assert torch.allclose(\n        model.mixer.out_proj.weight.grad,\n        model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],\n        rtol=rtol,\n        atol=atol * 10,\n    )\n    if rank == 0:\n        assert torch.allclose(\n            model.mixer.out_proj.bias.grad,\n            model_pt.mixer.out_proj.bias.grad,\n            rtol=rtol,\n            atol=atol * 5,\n        )\n    assert torch.allclose(\n        model.mlp.fc1.weight.grad,\n        model_pt.mlp.fc1.weight.grad[\n            rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim\n        ],\n        rtol=rtol,\n        atol=atol * 10,\n    )\n    assert torch.allclose(\n        model.mlp.fc1.bias.grad,\n        model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim],\n        rtol=rtol,\n        atol=atol * 5,\n    )\n    assert torch.allclose(\n        model.mlp.fc2.weight.grad,\n        model_pt.mlp.fc2.weight.grad[\n            :, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim\n        ],\n        rtol=rtol,\n        atol=atol * 10,\n    )\n    if rank == 0:\n        assert torch.allclose(\n            model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad, rtol=rtol, atol=atol * 5\n        )\n\n    
assert torch.allclose(\n        model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5\n    )\n    assert torch.allclose(model.norm1.bias.grad, model_pt.norm1.bias.grad, rtol=rtol, atol=atol * 5)\n    assert torch.allclose(\n        model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5\n    )\n    assert torch.allclose(model.norm2.bias.grad, model_pt.norm2.bias.grad, rtol=rtol, atol=atol * 5)\n"
  },
  {
    "path": "tests/modules/test_embedding_parallel.py",
    "content": "# Run test with:\n# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py\n\nimport pytest\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom apex.transformer import parallel_state\nfrom einops import rearrange\nfrom flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings\n\nis_sm8x = torch.cuda.get_device_capability(\"cuda\")[0] >= 8\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))\n# @pytest.mark.parametrize('dtype', [torch.bfloat16])\n@pytest.mark.parametrize(\"world_size\", [1, 2, 4, 8])\n# @pytest.mark.parametrize('world_size', [2])\n@pytest.mark.parametrize(\"sequence_parallel\", [True, False])\n# @pytest.mark.parametrize('sequence_parallel', [False])\n@pytest.mark.parametrize(\"has_pos_emb\", [True, False])\n# @pytest.mark.parametrize('has_pos_emb', [True])\n@pytest.mark.parametrize(\"dim\", [1024])\ndef test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):\n    vocab_size = 50264\n    assert vocab_size % world_size == 0\n    assert dim % world_size == 0\n    rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 1024\n    assert (batch_size * seqlen) % world_size == 0\n    input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)\n    input_ids = input_ids_pt.detach().clone()\n\n    model_pt = GPT2Embeddings(\n        dim, vocab_size, seqlen if has_pos_emb 
else 0, device=device, dtype=dtype\n    )\n    model = ParallelGPT2Embeddings(\n        dim,\n        vocab_size,\n        seqlen if has_pos_emb else 0,\n        parallel_state.get_tensor_model_parallel_group(),\n        sequence_parallel=sequence_parallel,\n        device=device,\n        dtype=dtype,\n    )\n    partition_vocab_size = vocab_size // world_size\n    partition_dim = dim // world_size\n    with torch.no_grad():\n        model.word_embeddings.weight.copy_(\n            model_pt.word_embeddings.weight[\n                rank * partition_vocab_size : (rank + 1) * partition_vocab_size\n            ]\n        )\n        if has_pos_emb:\n            model.position_embeddings.weight.copy_(\n                model_pt.position_embeddings.weight[\n                    :, rank * partition_dim : (rank + 1) * partition_dim\n                ]\n            )\n\n    out = model(input_ids, combine_batch_seqlen_dim=True)\n    out_pt = rearrange(model_pt(input_ids), \"b s d -> (b s) d\")\n    partition_batch_dim = batch_size * seqlen // world_size\n    assert torch.allclose(\n        out,\n        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]\n        if sequence_parallel\n        else out_pt,\n        rtol=rtol,\n        atol=atol,\n    )\n\n    g = torch.randn_like(out_pt)\n    out_pt.backward(g)\n    out.backward(\n        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g\n    )\n    parallel_state.destroy_model_parallel()\n\n    assert torch.allclose(\n        model.word_embeddings.weight.grad,\n        model_pt.word_embeddings.weight.grad[\n            rank * partition_vocab_size : (rank + 1) * partition_vocab_size\n        ],\n        rtol=rtol,\n        atol=atol,\n    )\n    if has_pos_emb:\n        assert torch.allclose(\n            model.position_embeddings.weight.grad,\n            model_pt.position_embeddings.weight.grad[\n                :, rank * partition_dim : (rank + 1) * partition_dim\n  
          ],\n            rtol=rtol,\n            atol=atol,\n        )\n"
  },
  {
    "path": "tests/modules/test_mha_parallel.py",
    "content": "# Run test with:\n# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mha_parallel.py\n\nimport math\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom apex.transformer import parallel_state, tensor_parallel\nfrom einops import rearrange\nfrom flash_attn.modules.mha import MHA, ParallelMHA\n\nis_sm8x = torch.cuda.get_device_capability(\"cuda\")[0] >= 8\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))\n# @pytest.mark.parametrize('dtype', [torch.float16])\n@pytest.mark.parametrize(\"world_size\", [1, 2, 4, 8])\n# @pytest.mark.parametrize('world_size', [2])\n@pytest.mark.parametrize(\"sequence_parallel\", [True, False])\n# @pytest.mark.parametrize('sequence_parallel', [False])\n@pytest.mark.parametrize(\"head_dim\", [64, 128])\n# @pytest.mark.parametrize('head_dim', [64])\n@pytest.mark.parametrize(\"embed_dim\", [1024, 4096])\n# @pytest.mark.parametrize('embed_dim', [1024])\ndef test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):\n    assert embed_dim % head_dim == 0\n    num_heads = embed_dim // head_dim\n    assert num_heads % world_size == 0\n    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    seqlen = 1024\n    assert (batch_size * seqlen) % world_size == 0\n    x_pt = torch.randn(\n        batch_size * seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True\n    )\n    # We need to generate g here so that all processes get the same 
gradient,\n    # as rank 0 will have an extra bias that changes the RNG.\n    # If we don't divide by batch_size, the gradient gets a bit too large.\n    g = torch.randn_like(x_pt) / 32\n    if sequence_parallel:\n        x = (\n            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)\n            .detach()\n            .clone()\n            .requires_grad_()\n        )\n    else:\n        x = x_pt.detach().clone().requires_grad_()\n\n    model_pt = MHA(\n        embed_dim,\n        num_heads,\n        rotary_emb_dim=int(head_dim // 2),\n        use_flash_attn=True,\n        device=device,\n        dtype=dtype,\n    )\n    partition_dim = embed_dim // world_size\n    model = ParallelMHA(\n        embed_dim,\n        num_heads,\n        parallel_state.get_tensor_model_parallel_group(),\n        rotary_emb_dim=int(head_dim // 2),\n        use_flash_attn=True,\n        sequence_parallel=sequence_parallel,\n        device=device,\n        dtype=dtype,\n    )\n\n    with torch.no_grad():\n        model.Wqkv.weight.copy_(\n            rearrange(\n                rearrange(model_pt.Wqkv.weight, \"(three o) i -> three o i\", three=3)[\n                    :, rank * partition_dim : (rank + 1) * partition_dim\n                ],\n                \"three o i -> (three o) i\",\n            )\n        )\n        model.Wqkv.bias.copy_(\n            rearrange(\n                rearrange(model_pt.Wqkv.bias, \"(three o) -> three o\", three=3)[\n                    :, rank * partition_dim : (rank + 1) * partition_dim\n                ],\n                \"three o -> (three o)\",\n            )\n        )\n        model.out_proj.weight.copy_(\n            model_pt.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]\n        )\n        if rank == 0:\n            model.out_proj.bias.copy_(model_pt.out_proj.bias)\n\n    out = model(x, seqlen=seqlen)\n    out_pt = rearrange(model_pt(rearrange(x_pt, \"(b s) d -> b s d\", s=seqlen)), \"b s d -> (b s) d\")\n   
 partition_batch_dim = batch_size * seqlen // world_size\n    assert torch.allclose(\n        out,\n        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]\n        if sequence_parallel\n        else out_pt,\n        rtol=rtol,\n        atol=atol,\n    )\n\n    out_pt.backward(g)\n    out.backward(\n        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g\n    )\n    parallel_state.destroy_model_parallel()\n\n    assert torch.allclose(\n        x.grad,\n        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]\n        if sequence_parallel\n        else x_pt.grad,\n        rtol=rtol,\n        atol=atol / 100,  # magnitude of x.grad is quite small\n    )\n    # The error for d_weight and d_bias is quite a bit higher\n    assert torch.allclose(\n        model.Wqkv.weight.grad,\n        rearrange(\n            rearrange(model_pt.Wqkv.weight.grad, \"(three o) i -> three o i\", three=3)[\n                :, rank * partition_dim : (rank + 1) * partition_dim\n            ],\n            \"three o i -> (three o) i\",\n        ),\n        rtol=rtol,\n        atol=atol * 10,\n    )\n    assert torch.allclose(\n        model.Wqkv.bias.grad,\n        rearrange(\n            rearrange(model_pt.Wqkv.bias.grad, \"(three o) -> three o\", three=3)[\n                :, rank * partition_dim : (rank + 1) * partition_dim\n            ],\n            \"three o -> (three o)\",\n        ),\n        rtol=rtol,\n        atol=atol * 5,\n    )\n    assert torch.allclose(\n        model.out_proj.weight.grad,\n        model_pt.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],\n        rtol=rtol,\n        atol=atol * 10,\n    )\n    if rank == 0:\n        assert torch.allclose(\n            model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5\n        )\n"
  },
  {
    "path": "tests/modules/test_mlp_parallel.py",
    "content": "# Run test with:\n# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mlp_parallel.py\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom apex.transformer import parallel_state, tensor_parallel\nfrom einops import rearrange\nfrom flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp\n\nis_sm8x = torch.cuda.get_device_capability(\"cuda\")[0] >= 8\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))\n# @pytest.mark.parametrize('dtype', [torch.float16])\n@pytest.mark.parametrize(\"world_size\", [1, 2, 4, 8])\n# @pytest.mark.parametrize('world_size', [2])\n@pytest.mark.parametrize(\"sequence_parallel\", [True, False])\n# @pytest.mark.parametrize('sequence_parallel', [False])\n@pytest.mark.parametrize(\"activation\", [F.silu, F.sigmoid])\n# @pytest.mark.parametrize('activation', [F.silu])\n@pytest.mark.parametrize(\"dim\", [1024, 4096])\n# @pytest.mark.parametrize('dim', [1024])\ndef test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):\n    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)\n\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    seqlen = 1024\n    assert (batch_size * seqlen) % world_size == 0\n    x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)\n    # We need to generate g here so that all processes get the same gradient,\n    # as rank 0 will have an extra bias that changes the RNG.\n    # If we don't divide by batch_size, the gradient gets a bit too 
large.\n    g = torch.randn_like(x_pt) / 32\n    if sequence_parallel:\n        x = (\n            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)\n            .detach()\n            .clone()\n            .requires_grad_()\n        )\n    else:\n        x = x_pt.detach().clone().requires_grad_()\n\n    model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype)\n    partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size\n    model = ParallelGatedMlp(\n        dim,\n        parallel_state.get_tensor_model_parallel_group(),\n        activation=activation,\n        sequence_parallel=sequence_parallel,\n        device=device,\n        dtype=dtype,\n    )\n\n    with torch.no_grad():\n        model.fc1.weight.copy_(\n            rearrange(\n                rearrange(model_pt.fc1.weight, \"(two o) i -> two o i\", two=2)[\n                    :, rank * partition_dim : (rank + 1) * partition_dim\n                ],\n                \"two o i -> (two o) i\",\n            )\n        )\n        model.fc1.bias.copy_(\n            rearrange(\n                rearrange(model_pt.fc1.bias, \"(two o) -> two o\", two=2)[\n                    :, rank * partition_dim : (rank + 1) * partition_dim\n                ],\n                \"two o -> (two o)\",\n            )\n        )\n        model.fc2.weight.copy_(\n            model_pt.fc2.weight[:, rank * partition_dim : (rank + 1) * partition_dim]\n        )\n        if rank == 0:\n            model.fc2.bias.copy_(model_pt.fc2.bias)\n\n    out = model(x)\n    out_pt = model_pt(x_pt)\n    partition_batch_dim = batch_size * seqlen // world_size\n    assert torch.allclose(\n        out,\n        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]\n        if sequence_parallel\n        else out_pt,\n        rtol=rtol,\n        atol=atol,\n    )\n\n    out_pt.backward(g)\n    out.backward(\n        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else 
g\n    )\n    parallel_state.destroy_model_parallel()\n\n    assert torch.allclose(\n        x.grad,\n        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]\n        if sequence_parallel\n        else x_pt.grad,\n        rtol=rtol,\n        atol=atol,\n    )\n\n    assert torch.allclose(\n        model.fc1.weight.grad,\n        rearrange(\n            rearrange(model_pt.fc1.weight.grad, \"(two o) i -> two o i\", two=2)[\n                :, rank * partition_dim : (rank + 1) * partition_dim\n            ],\n            \"two o i -> (two o) i\",\n        ),\n        rtol=rtol,\n        atol=atol,\n    )\n    assert torch.allclose(\n        model.fc1.bias.grad,\n        rearrange(\n            rearrange(model_pt.fc1.bias.grad, \"(two o) -> two o\", two=2)[\n                :, rank * partition_dim : (rank + 1) * partition_dim\n            ],\n            \"two o -> (two o)\",\n        ),\n        rtol=rtol,\n        atol=atol,\n    )\n    assert torch.allclose(\n        model.fc2.weight.grad,\n        model_pt.fc2.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],\n        rtol=rtol,\n        atol=atol,\n    )\n    if rank == 0:\n        assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol)\n"
  },
  {
    "path": "tests/ops/test_dropout_layer_norm.py",
    "content": "import math\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom flash_attn.ops.layer_norm import (\n    DropoutAddLayerNorm,\n    dropout_add_layer_norm,\n    dropout_add_layer_norm_parallel_residual,\n    dropout_add_layer_norm_subset,\n)\nfrom flash_attn.ops.rms_norm import (\n    DropoutAddRMSNorm,\n    dropout_add_rms_norm,\n    dropout_add_rms_norm_parallel_residual,\n    dropout_add_rms_norm_subset,\n)\n\ntry:\n    from apex.normalization import FusedRMSNorm\n    from apex.normalization.fused_layer_norm import fused_rms_norm_affine\nexcept:\n    FusedRMSNorm, fused_rms_norm_affine = None, None\n\n\nis_sm8x = torch.cuda.get_device_capability(\"cuda\")[0] >= 8\n\n\n@pytest.mark.parametrize(\"is_rms_norm\", [False, True])\n@pytest.mark.parametrize(\"has_colscale\", [True, False])\n# @pytest.mark.parametrize('has_colscale', [False])\n@pytest.mark.parametrize(\"has_rowscale\", [True, False])\n# @pytest.mark.parametrize('has_rowscale', [True])\n@pytest.mark.parametrize(\"has_residual\", [True, False])\n# @pytest.mark.parametrize('has_residual', [False])\n@pytest.mark.parametrize(\"dropout_p\", [0.37, 0.0])\n# @pytest.mark.parametrize('dropout_p', [0.0])\n@pytest.mark.parametrize(\"weight_dtype\", [torch.float32, torch.float16])\n# @pytest.mark.parametrize('weight_dtype', [torch.float32])\n@pytest.mark.parametrize(\n    \"input_dtype,residual_dtype\",\n    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]\n    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),\n)\n# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])\n@pytest.mark.parametrize(\n    \"hidden_size\",\n    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],\n)\n# @pytest.mark.parametrize('hidden_size', [256])\ndef test_dropout_layer_norm_training(\n    hidden_size,\n    
input_dtype,\n    residual_dtype,\n    weight_dtype,\n    dropout_p,\n    has_residual,\n    has_rowscale,\n    has_colscale,\n    is_rms_norm,\n):\n    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:\n        pytest.skip()  # Not supported\n    if is_rms_norm and FusedRMSNorm is None:\n        pytest.skip()  # We need Apex's FusedRMSNorm to test\n    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm\n    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm\n    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm\n    device = \"cuda\"\n    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)\n    rtol, atol = (1e-3, 1e-4)\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 512\n    x0_pt = torch.randn(\n        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True\n    )\n    x0 = x0_pt.detach().clone().requires_grad_()\n    x0_ref = x0_pt.detach().clone().float().requires_grad_()\n    if has_colscale:\n        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n        colscale_pt = colscale.detach().clone().requires_grad_()\n        colscale_ref = colscale.detach().clone().float().requires_grad_()\n    else:\n        colscale = None\n    if has_residual:\n        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)\n        res = res_pt.detach().clone().requires_grad_()\n        res_ref = res_pt.detach().clone().float().requires_grad_()\n    else:\n        res = None\n    if has_rowscale:\n        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)\n        survival_rate = 0.87\n        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate\n        x0_scaled_pt = x0_pt * rearrange(rowscale, \"... -> ... 1\")\n        x0_scaled_ref = x0_ref * rearrange(rowscale, \"... -> ... 
1\")\n    else:\n        rowscale = None\n        x0_scaled_pt = x0_pt\n        x0_scaled_ref = x0_ref\n    if has_colscale:\n        x0_scaled_pt = x0_scaled_pt * colscale_pt\n        x0_scaled_ref = x0_scaled_ref * colscale_ref\n    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)\n    torch.nn.init.normal_(model_pt.weight)\n    if not is_rms_norm:\n        torch.nn.init.normal_(model_pt.bias)\n    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)\n    model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)\n    with torch.no_grad():\n        model.weight.copy_(model_pt.weight)\n        model_ref.weight.copy_(model_pt.weight)\n        if not is_rms_norm:\n            model.bias.copy_(model_pt.bias)\n            model_ref.bias.copy_(model_pt.bias)\n    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32\n    out, dmask = our_layer_norm_func(\n        x0,\n        res,\n        model.weight,\n        model.bias,\n        model.p,\n        model.eps,\n        rowscale=rowscale,\n        layerscale=colscale,\n        residual_in_fp32=residual_in_fp32,\n        return_dropout_mask=True,\n    )\n    assert out.dtype == input_dtype\n    print(f\"Actual dropout fraction: {1 - dmask.float().mean().item()}\")\n    if has_residual:\n        residual_pt = (\n            (x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()\n        ).to(dtype=residual_dtype)\n        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref\n    else:\n        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(\n            dtype=residual_dtype\n        )\n        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)\n    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)\n    out_ref = model_ref(residual_ref)\n    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() 
+ 1e-4\n\n    g = torch.randn_like(out) / batch_size\n    out_pt.backward(g)\n    out.backward(g)\n    out_ref.backward(g)\n    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4\n    if has_residual:\n        assert (res.grad - res_ref.grad).abs().max() <= 4 * (\n            res_pt.grad - res_ref.grad\n        ).abs().max() + 1e-4\n    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (\n        model_pt.weight.grad - model_ref.weight.grad\n    ).abs().max() + 3e-5\n    if not is_rms_norm:\n        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (\n            model_pt.bias.grad - model_ref.bias.grad\n        ).abs().max() + 3e-5\n    if has_colscale:\n        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (\n            colscale_pt.grad - colscale_ref.grad\n        ).abs().max() + 2e-4\n\n\n@pytest.mark.parametrize(\"weight_dtype\", [torch.float32, torch.float16])\n@pytest.mark.parametrize(\n    \"input_dtype,residual_dtype\",\n    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]\n    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),\n)\n@pytest.mark.parametrize(\"hidden_size\", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])\ndef test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):\n    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:\n        pytest.skip()  # Not supported\n    device = \"cuda\"\n    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)\n    rtol, atol = (1e-3, 1e-4)\n    dropout_p = 0.37\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 32\n    seqlen = 512\n    x0_pt = torch.randn(\n        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True\n    )\n    x0 = x0_pt.detach().clone().requires_grad_()\n    x0_ref = 
x0_pt.detach().clone().float().requires_grad_()\n    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)\n    res = res_pt.detach().clone().requires_grad_()\n    res_ref = res_pt.detach().clone().float().requires_grad_()\n    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)\n    torch.nn.init.normal_(model_pt.weight)\n    torch.nn.init.normal_(model_pt.bias)\n    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)\n    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)\n    with torch.no_grad():\n        model.weight.copy_(model_pt.weight)\n        model.bias.copy_(model_pt.bias)\n        model_ref.weight.copy_(model_pt.weight)\n        model_ref.bias.copy_(model_pt.bias)\n    model_pt.eval()\n    model.eval()\n    model_ref.eval()\n    out = model(x0, res)\n    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)\n    residual_ref = x0_ref + res_ref\n    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)\n    out_ref = model_ref(residual_ref)\n    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4\n\n\n@pytest.mark.parametrize(\"is_rms_norm\", [False, True])\n@pytest.mark.parametrize(\"has_colscale\", [True, False])\n@pytest.mark.parametrize(\"has_rowscale\", [True, False])\n@pytest.mark.parametrize(\"has_residual\", [True, False])\n@pytest.mark.parametrize(\"dropout_p\", [0.37, 0.0])\n@pytest.mark.parametrize(\"weight_dtype\", [torch.float32, torch.float16])\n@pytest.mark.parametrize(\n    \"input_dtype,residual_dtype\",\n    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]\n    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),\n)\n# @pytest.mark.parametrize('has_colscale', [True])\n# @pytest.mark.parametrize('has_rowscale', [False])\n# @pytest.mark.parametrize('has_residual', [True])\n# 
@pytest.mark.parametrize('dropout_p', [0.0])\n# @pytest.mark.parametrize('weight_dtype', [torch.float32])\n# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])\n@pytest.mark.parametrize(\n    \"hidden_size\",\n    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],\n)\n# @pytest.mark.parametrize('hidden_size', [256])\ndef test_dropout_layer_norm_prenorm_training(\n    hidden_size,\n    input_dtype,\n    residual_dtype,\n    weight_dtype,\n    dropout_p,\n    has_residual,\n    has_rowscale,\n    has_colscale,\n    is_rms_norm,\n):\n    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:\n        pytest.skip()  # Not supported\n    if is_rms_norm and FusedRMSNorm is None:\n        pytest.skip()  # We need Apex's FusedRMSNorm to test\n    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm\n    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm\n    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm\n    device = \"cuda\"\n    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)\n    rtol, atol = (1e-3, 2e-4)\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 512\n    x0_pt = torch.randn(\n        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True\n    )\n    x0 = x0_pt.detach().clone().requires_grad_()\n    x0_ref = x0_pt.detach().clone().float().requires_grad_()\n    if has_colscale:\n        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n        colscale_pt = colscale.detach().clone().requires_grad_()\n        colscale_ref = colscale.detach().clone().float().requires_grad_()\n    else:\n        colscale = None\n    if has_residual:\n        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)\n        res = 
res_pt.detach().clone().requires_grad_()\n        res_ref = res_pt.detach().clone().float().requires_grad_()\n    else:\n        res = None\n    if has_rowscale:\n        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)\n        survival_rate = 0.87\n        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate\n        x0_scaled_pt = x0_pt * rearrange(rowscale, \"... -> ... 1\")\n        x0_scaled_ref = x0_ref * rearrange(rowscale, \"... -> ... 1\")\n    else:\n        rowscale = None\n        x0_scaled_pt = x0_pt\n        x0_scaled_ref = x0_ref\n    if has_colscale:\n        x0_scaled_pt = x0_scaled_pt * colscale_pt\n        x0_scaled_ref = x0_scaled_ref * colscale_ref\n    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)\n    torch.nn.init.normal_(model_pt.weight)\n    if not is_rms_norm:\n        torch.nn.init.normal_(model_pt.bias)\n    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)\n    model = our_layer_norm_cls(\n        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype\n    )\n    with torch.no_grad():\n        model.weight.copy_(model_pt.weight)\n        model_ref.weight.copy_(model_pt.weight)\n        if not is_rms_norm:\n            model.bias.copy_(model_pt.bias)\n            model_ref.bias.copy_(model_pt.bias)\n    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32\n    out, residual, dmask = our_layer_norm_func(\n        x0,\n        res,\n        model.weight,\n        model.bias,\n        model.p,\n        model.eps,\n        rowscale=rowscale,\n        layerscale=colscale,\n        prenorm=True,\n        residual_in_fp32=residual_in_fp32,\n        return_dropout_mask=True,\n    )\n    print(f\"Actual dropout fraction: {1 - dmask.float().mean().item()}\")\n    if has_residual:\n        residual_pt = (\n            (x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()\n        
).to(dtype=residual_dtype)\n        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref\n    else:\n        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(\n            dtype=residual_dtype\n        )\n        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)\n    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)\n    out_ref = model_ref(residual_ref)\n    assert out.dtype == input_dtype\n    assert residual.dtype == residual_dtype\n    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4\n    assert (residual - residual_ref).abs().max() <= 4 * (\n        residual_pt - residual_ref\n    ).abs().max() + 1e-4\n\n    g = torch.randn_like(out) / batch_size\n    (out_pt * F.sigmoid(residual_pt)).backward(g)\n    (out * F.sigmoid(residual)).backward(g)\n    (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)\n    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4\n    if has_residual:\n        assert (res.grad - res_ref.grad).abs().max() <= 4 * (\n            res_pt.grad - res_ref.grad\n        ).abs().max() + 1e-4\n    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (\n        model_pt.weight.grad - model_ref.weight.grad\n    ).abs().max() + 2e-4\n    if not is_rms_norm:\n        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (\n            model_pt.bias.grad - model_ref.bias.grad\n        ).abs().max() + 2e-4\n    if has_colscale:\n        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (\n            colscale_pt.grad - colscale_ref.grad\n        ).abs().max() + 2e-4\n\n\n@pytest.mark.parametrize(\"weight_dtype\", [torch.float32, torch.float16])\n@pytest.mark.parametrize(\n    \"input_dtype,residual_dtype\",\n    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]\n    + 
([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),\n)\n@pytest.mark.parametrize(\"hidden_size\", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])\ndef test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):\n    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:\n        pytest.skip()  # Not supported\n    device = \"cuda\"\n    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)\n    rtol, atol = (1e-3, 1e-4)\n    dropout_p = 0.37\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 32\n    seqlen = 512\n    x0_pt = torch.randn(\n        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True\n    )\n    x0 = x0_pt.detach().clone().requires_grad_()\n    x0_ref = x0_pt.detach().clone().float().requires_grad_()\n    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)\n    res = res_pt.detach().clone().requires_grad_()\n    res_ref = res_pt.detach().clone().float().requires_grad_()\n    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)\n    torch.nn.init.normal_(model_pt.weight)\n    torch.nn.init.normal_(model_pt.bias)\n    model = DropoutAddLayerNorm(\n        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype\n    )\n    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)\n    with torch.no_grad():\n        model.weight.copy_(model_pt.weight)\n        model.bias.copy_(model_pt.bias)\n        model_ref.weight.copy_(model_pt.weight)\n        model_ref.bias.copy_(model_pt.bias)\n    model_pt.eval()\n    model.eval()\n    model_ref.eval()\n    out, residual = model(x0, res)\n    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)\n    residual_ref = x0_ref + res_ref\n    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)\n    out_ref = model_ref(residual_ref)\n    
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4\n    assert (residual - residual_ref).abs().max() <= 4 * (\n        residual_pt - residual_ref\n    ).abs().max() + 1e-4\n\n\n@pytest.mark.parametrize(\"has_colscale\", [True, False])\n@pytest.mark.parametrize(\"has_residual\", [True, False])\n@pytest.mark.parametrize(\"dropout_p\", [0.37, 0.0])\n@pytest.mark.parametrize(\"weight_dtype\", [torch.float32, torch.float16])\n@pytest.mark.parametrize(\n    \"input_dtype,residual_dtype\",\n    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]\n    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),\n)\n# @pytest.mark.parametrize('has_colscale', [True])\n# @pytest.mark.parametrize('has_residual', [True])\n# @pytest.mark.parametrize('dropout_p', [0.0])\n# @pytest.mark.parametrize('weight_dtype', [torch.float32])\n# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])\n@pytest.mark.parametrize(\n    \"hidden_size\",\n    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],\n)\n# @pytest.mark.parametrize('hidden_size', [256])\ndef test_dropout_layer_norm_subset_training(\n    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale\n):\n    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:\n        pytest.skip()  # Not supported\n    device = \"cuda\"\n    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)\n    rtol, atol = (1e-3, 2e-4)\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 512\n    drop_path_rate = 0.4\n    drop_path_scale = 1 / (1 - drop_path_rate)\n\n    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):\n        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync\n        mask_batch = torch.rand(batch_size) < 1 - 
drop_path_rate\n        numrows = (mask_batch).sum().item() * seqlen\n        mask_batch = mask_batch.to(device=device, non_blocking=True)\n        mask_batch_seqlen = repeat(mask_batch, \"b -> (b s)\", s=seqlen)\n        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(\n            ~mask_batch_seqlen, 0\n        )\n        return mask_batch, numrows, rearrange(subset, \"(b s) -> b s\", b=batch_size)\n\n    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(\n        batch_size, seqlen, drop_path_rate, device\n    )\n    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(\n        batch_size, seqlen, drop_path_rate, device\n    )\n\n    x0_pt = torch.randn(\n        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True\n    )\n    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()\n    x0_ref = x0_pt.detach().clone().float().requires_grad_()\n    if has_colscale:\n        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n        colscale_pt = colscale.detach().clone().requires_grad_()\n        colscale_ref = colscale.detach().clone().float().requires_grad_()\n    else:\n        colscale = None\n    if has_residual:\n        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)\n        res = res_pt.detach().clone().requires_grad_()\n        res_ref = res_pt.detach().clone().float().requires_grad_()\n    else:\n        res = None\n\n    if has_colscale:\n        x0_scaled_pt = x0_pt * colscale_pt\n        x0_scaled_ref = x0_ref * colscale_ref\n    else:\n        x0_scaled_pt = x0_pt\n        x0_scaled_ref = x0_ref\n\n    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)\n    torch.nn.init.normal_(model_pt.weight)\n    torch.nn.init.normal_(model_pt.bias)\n    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)\n    model = DropoutAddLayerNorm(\n       
 hidden_size, prenorm=False, p=dropout_p, device=device, dtype=weight_dtype\n    )\n    with torch.no_grad():\n        model.weight.copy_(model_pt.weight)\n        model.bias.copy_(model_pt.bias)\n        model_ref.weight.copy_(model_pt.weight)\n        model_ref.bias.copy_(model_pt.bias)\n\n    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32\n    out, dmask = dropout_add_layer_norm_subset(\n        x0,\n        res,\n        model.weight,\n        model.bias,\n        model.p,\n        model.eps,\n        layerscale=colscale,\n        x0_subset=x0_subset,\n        out_subset=out_subset,\n        rowscale_const=drop_path_scale,\n        out_numrows=out_numrows,\n        prenorm=False,\n        residual_in_fp32=residual_in_fp32,\n        return_dropout_mask=True,\n    )\n    print(f\"Actual dropout fraction: {1 - dmask.float().mean().item()}\")\n\n    x0_scaled_pt = (\n        x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, \"b -> b s d\", s=seqlen, d=hidden_size), 0)\n        * drop_path_scale\n    )\n    x0_scaled_ref = (\n        x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, \"b -> b s d\", s=seqlen, d=hidden_size), 0)\n        * drop_path_scale\n    )\n    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)\n    dmask_expanded[x0_mask_batch] = dmask\n    if has_residual:\n        residual_pt = (\n            (x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()\n        ).to(dtype=residual_dtype)\n        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref\n    else:\n        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(\n            dtype=residual_dtype\n        )\n        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)\n    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]\n    out_ref = model_ref(residual_ref)[out_mask_batch]\n    assert out.dtype == 
input_dtype\n    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4\n\n    g = torch.randn_like(out) / batch_size\n    out_pt.backward(g)\n    out.backward(g)\n    out_ref.backward(g)\n    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[\n        x0_mask_batch\n    ].abs().max() + 1e-4\n    if has_residual:\n        assert (res.grad - res_ref.grad).abs().max() <= 4 * (\n            res_pt.grad - res_ref.grad\n        ).abs().max() + 1e-4\n    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (\n        model_pt.weight.grad - model_ref.weight.grad\n    ).abs().max() + 2e-4\n    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (\n        model_pt.bias.grad - model_ref.bias.grad\n    ).abs().max() + 2e-4\n    if has_colscale:\n        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (\n            colscale_pt.grad - colscale_ref.grad\n        ).abs().max() + 2e-4\n\n\n@pytest.mark.parametrize(\"has_colscale\", [True, False])\n@pytest.mark.parametrize(\"has_residual\", [True, False])\n@pytest.mark.parametrize(\"dropout_p\", [0.37, 0.0])\n@pytest.mark.parametrize(\"weight_dtype\", [torch.float32, torch.float16])\n@pytest.mark.parametrize(\n    \"input_dtype,residual_dtype\",\n    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]\n    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),\n)\n# @pytest.mark.parametrize('has_colscale', [True])\n# @pytest.mark.parametrize('has_residual', [True])\n# @pytest.mark.parametrize('dropout_p', [0.0])\n# @pytest.mark.parametrize('weight_dtype', [torch.float32])\n# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])\n@pytest.mark.parametrize(\n    \"hidden_size\",\n    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],\n)\n# 
@pytest.mark.parametrize('hidden_size', [256])\ndef test_dropout_layer_norm_subset_prenorm_training(\n    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale\n):\n    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:\n        pytest.skip()  # Not supported\n    device = \"cuda\"\n    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)\n    rtol, atol = (1e-3, 2e-4)\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 512\n    drop_path_rate = 0.4\n    drop_path_scale = 1 / (1 - drop_path_rate)\n\n    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):\n        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync\n        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate\n        numrows = (mask_batch).sum().item() * seqlen\n        mask_batch = mask_batch.to(device=device, non_blocking=True)\n        mask_batch_seqlen = repeat(mask_batch, \"b -> (b s)\", s=seqlen)\n        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(\n            ~mask_batch_seqlen, 0\n        )\n        return mask_batch, numrows, rearrange(subset, \"(b s) -> b s\", b=batch_size)\n\n    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(\n        batch_size, seqlen, drop_path_rate, device\n    )\n    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(\n        batch_size, seqlen, drop_path_rate, device\n    )\n\n    x0_pt = torch.randn(\n        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True\n    )\n    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()\n    x0_ref = x0_pt.detach().clone().float().requires_grad_()\n    if has_colscale:\n        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n        colscale_pt = colscale.detach().clone().requires_grad_()\n        colscale_ref = 
colscale.detach().clone().float().requires_grad_()\n    else:\n        colscale = None\n    if has_residual:\n        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)\n        res = res_pt.detach().clone().requires_grad_()\n        res_ref = res_pt.detach().clone().float().requires_grad_()\n    else:\n        res = None\n\n    if has_colscale:\n        x0_scaled_pt = x0_pt * colscale_pt\n        x0_scaled_ref = x0_ref * colscale_ref\n    else:\n        x0_scaled_pt = x0_pt\n        x0_scaled_ref = x0_ref\n\n    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)\n    torch.nn.init.normal_(model_pt.weight)\n    torch.nn.init.normal_(model_pt.bias)\n    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)\n    model = DropoutAddLayerNorm(\n        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype\n    )\n    with torch.no_grad():\n        model.weight.copy_(model_pt.weight)\n        model.bias.copy_(model_pt.bias)\n        model_ref.weight.copy_(model_pt.weight)\n        model_ref.bias.copy_(model_pt.bias)\n\n    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32\n    out, residual, dmask = dropout_add_layer_norm_subset(\n        x0,\n        res,\n        model.weight,\n        model.bias,\n        model.p,\n        model.eps,\n        layerscale=colscale,\n        x0_subset=x0_subset,\n        out_subset=out_subset,\n        rowscale_const=drop_path_scale,\n        out_numrows=out_numrows,\n        prenorm=True,\n        residual_in_fp32=residual_in_fp32,\n        return_dropout_mask=True,\n    )\n    print(f\"Actual dropout fraction: {1 - dmask.float().mean().item()}\")\n\n    x0_scaled_pt = (\n        x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, \"b -> b s d\", s=seqlen, d=hidden_size), 0)\n        * drop_path_scale\n    )\n    x0_scaled_ref = (\n        x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, \"b -> b s d\", s=seqlen, 
d=hidden_size), 0)\n        * drop_path_scale\n    )\n    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)\n    dmask_expanded[x0_mask_batch] = dmask\n    if has_residual:\n        residual_pt = (\n            (x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()\n        ).to(dtype=residual_dtype)\n        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref\n    else:\n        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(\n            dtype=residual_dtype\n        )\n        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)\n    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]\n    out_ref = model_ref(residual_ref)[out_mask_batch]\n    assert out.dtype == input_dtype\n    assert residual.dtype == residual_dtype\n    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4\n    assert (residual - residual_ref).abs().max() <= 4 * (\n        residual_pt - residual_ref\n    ).abs().max() + 1e-4\n\n    g = torch.randn_like(out) / batch_size\n    (out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(\n        g\n    )\n    (out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)\n    (\n        out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype))\n        + residual_ref.mean(0, keepdim=True)\n    ).backward(g)\n    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[\n        x0_mask_batch\n    ].abs().max() + 1e-4\n    if has_residual:\n        assert (res.grad - res_ref.grad).abs().max() <= 4 * (\n            res_pt.grad - res_ref.grad\n        ).abs().max() + 1e-4\n    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (\n        model_pt.weight.grad - model_ref.weight.grad\n    ).abs().max() + 2e-4\n    assert 
(model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (\n        model_pt.bias.grad - model_ref.bias.grad\n    ).abs().max() + 2e-4\n    if has_colscale:\n        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (\n            colscale_pt.grad - colscale_ref.grad\n        ).abs().max() + 2e-4\n\n\n@pytest.mark.parametrize(\"is_rms_norm\", [False, True])\n# @pytest.mark.parametrize('is_rms_norm', [False])\n@pytest.mark.parametrize(\"tied_norm\", [False, True])\n# @pytest.mark.parametrize('tied_norm', [False])\n@pytest.mark.parametrize(\"has_residual\", [True, False])\n# @pytest.mark.parametrize('has_residual', [False])\n@pytest.mark.parametrize(\"has_x1\", [True, False])\n# @pytest.mark.parametrize('has_x1', [True])\n@pytest.mark.parametrize(\"dropout_p\", [0.37, 0.0])\n# @pytest.mark.parametrize('dropout_p', [0.0])\n@pytest.mark.parametrize(\"weight_dtype\", [torch.float32, torch.float16])\n# @pytest.mark.parametrize('weight_dtype', [torch.float16])\n@pytest.mark.parametrize(\n    \"input_dtype,residual_dtype\",\n    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]\n    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),\n)\n# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])\n@pytest.mark.parametrize(\n    \"hidden_size\",\n    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],\n)\n# @pytest.mark.parametrize('hidden_size', [256])\ndef test_dropout_layer_norm_parallel_residual_training(\n    hidden_size,\n    input_dtype,\n    residual_dtype,\n    weight_dtype,\n    dropout_p,\n    has_x1,\n    has_residual,\n    tied_norm,\n    is_rms_norm,\n):\n    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:\n        pytest.skip()  # Not supported\n    if is_rms_norm and fused_rms_norm_affine is None:\n        pytest.skip()  # We need Apex's FusedRMSNorm to test\n    
our_layer_norm_func = (\n        dropout_add_layer_norm_parallel_residual\n        if not is_rms_norm\n        else dropout_add_rms_norm_parallel_residual\n    )\n    device = \"cuda\"\n    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)\n    rtol, atol = (1e-3, 1e-4)\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 512\n    x0_pt = torch.randn(\n        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True\n    )\n    x0 = x0_pt.detach().clone().requires_grad_()\n    x0_ref = x0_pt.detach().clone().float().requires_grad_()\n    if has_x1:\n        x1_pt = torch.randn(\n            batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True\n        )\n        x1 = x1_pt.detach().clone().requires_grad_()\n        x1_ref = x1_pt.detach().clone().float().requires_grad_()\n    else:\n        x1 = None\n    if has_residual:\n        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)\n        res = res_pt.detach().clone().requires_grad_()\n        res_ref = res_pt.detach().clone().float().requires_grad_()\n    else:\n        res = None\n    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n    bias0 = (\n        torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n        if not is_rms_norm\n        else None\n    )\n    weight0_pt = weight0.detach().clone().requires_grad_()\n    weight0_ref = weight0.detach().clone().float().requires_grad_()\n    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None\n    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None\n    if not tied_norm:\n        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n        bias1 = (\n            torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n          
  if not is_rms_norm\n            else None\n        )\n        weight1_pt = weight1.detach().clone().requires_grad_()\n        weight1_ref = weight1.detach().clone().float().requires_grad_()\n        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None\n        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None\n    else:\n        weight1, bias1 = None, None\n    epsilon = 1e-5\n    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32\n\n    out0, out1, dmask0, dmask1 = our_layer_norm_func(\n        x0,\n        x1,\n        res,\n        weight0,\n        bias0,\n        weight1,\n        bias1,\n        dropout_p,\n        epsilon,\n        residual_in_fp32=residual_in_fp32,\n        return_dropout_mask=True,\n    )\n    assert out0.dtype == input_dtype\n    if not tied_norm:\n        assert out1.dtype == input_dtype\n    print(f\"Actual dropout fraction: {1 - dmask0.float().mean().item()}\")\n    if has_residual:\n        if has_x1:\n            residual_pt = (\n                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)\n                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)\n                + res_pt.float()\n            ).to(dtype=residual_dtype)\n            residual_ref = (\n                (x0_ref * dmask0.float()) / (1 - dropout_p)\n                + (x1_ref * dmask1.float()) / (1 - dropout_p)\n            ) + res_ref\n        else:\n            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(\n                dtype=residual_dtype\n            )\n            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref\n    else:\n        if has_x1:\n            residual_pt = (\n                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)\n                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)\n            ).to(dtype=residual_dtype)\n            residual_ref = (x0_ref * 
dmask0.float()) / (1 - dropout_p) + (\n                x1_ref * dmask1.float()\n            ) / (1 - dropout_p)\n        else:\n            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(\n                dtype=residual_dtype\n            )\n            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)\n    if not is_rms_norm:\n        out0_pt = F.layer_norm(\n            residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon\n        ).to(dtype=input_dtype)\n        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)\n        if not tied_norm:\n            out1_pt = F.layer_norm(\n                residual_pt.to(dtype=weight_dtype),\n                (hidden_size,),\n                weight1_pt,\n                bias1_pt,\n                eps=epsilon,\n            ).to(dtype=input_dtype)\n            out1_ref = F.layer_norm(\n                residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon\n            )\n    else:\n        out0_pt = fused_rms_norm_affine(\n            residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon\n        ).to(dtype=input_dtype)\n        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)\n        if not tied_norm:\n            out1_pt = fused_rms_norm_affine(\n                residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon\n            ).to(dtype=input_dtype)\n            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)\n\n    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4\n    if not tied_norm:\n        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4\n\n    g0 = torch.randn_like(out0) / batch_size\n    if tied_norm:\n        out0.backward(g0)\n        out0_pt.backward(g0)\n        out0_ref.backward(g0)\n    else:\n        g1 = 
torch.randn_like(out1) / batch_size\n        (out0 * g0 + out1 * g1).sum().backward()\n        (out0_pt * g0 + out1_pt * g1).sum().backward()\n        (out0_ref * g0 + out1_ref * g1).sum().backward()\n    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4\n    if has_x1:\n        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (\n            x1_pt.grad - x1_ref.grad\n        ).abs().max() + 1e-4\n    if has_residual:\n        assert (res.grad - res_ref.grad).abs().max() <= 4 * (\n            res_pt.grad - res_ref.grad\n        ).abs().max() + 1e-4\n    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (\n        weight0_pt.grad - weight0_ref.grad\n    ).abs().max() + 3e-5\n    if not is_rms_norm:\n        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (\n            bias0_pt.grad - bias0_ref.grad\n        ).abs().max() + 3e-5\n    if not tied_norm:\n        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (\n            weight1_pt.grad - weight1_ref.grad\n        ).abs().max() + 3e-5\n        if not is_rms_norm:\n            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (\n                bias1_pt.grad - bias1_ref.grad\n            ).abs().max() + 3e-5\n\n\n@pytest.mark.parametrize(\"is_rms_norm\", [False, True])\n# @pytest.mark.parametrize('is_rms_norm', [False])\n@pytest.mark.parametrize(\"tied_norm\", [False, True])\n# @pytest.mark.parametrize('tied_norm', [False])\n@pytest.mark.parametrize(\"has_residual\", [True, False])\n# @pytest.mark.parametrize('has_residual', [False])\n@pytest.mark.parametrize(\"has_x1\", [True, False])\n# @pytest.mark.parametrize('has_x1', [True])\n@pytest.mark.parametrize(\"dropout_p\", [0.37, 0.0])\n# @pytest.mark.parametrize('dropout_p', [0.0])\n@pytest.mark.parametrize(\"weight_dtype\", [torch.float32, torch.float16])\n# @pytest.mark.parametrize('weight_dtype', [torch.float16])\n@pytest.mark.parametrize(\n    \"input_dtype,residual_dtype\",\n    
[(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]\n    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),\n)\n# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])\n@pytest.mark.parametrize(\n    \"hidden_size\",\n    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],\n)\n# @pytest.mark.parametrize('hidden_size', [256])\ndef test_dropout_layer_norm_parallel_residual_prenorm_training(\n    hidden_size,\n    input_dtype,\n    residual_dtype,\n    weight_dtype,\n    dropout_p,\n    has_x1,\n    has_residual,\n    tied_norm,\n    is_rms_norm,\n):\n    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:\n        pytest.skip()  # Not supported\n    if is_rms_norm and fused_rms_norm_affine is None:\n        pytest.skip()  # We need Apex's FusedRMSNorm to test\n    our_layer_norm_func = (\n        dropout_add_layer_norm_parallel_residual\n        if not is_rms_norm\n        else dropout_add_rms_norm_parallel_residual\n    )\n    device = \"cuda\"\n    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)\n    rtol, atol = (1e-3, 1e-4)\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 512\n    x0_pt = torch.randn(\n        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True\n    )\n    x0 = x0_pt.detach().clone().requires_grad_()\n    x0_ref = x0_pt.detach().clone().float().requires_grad_()\n    if has_x1:\n        x1_pt = torch.randn(\n            batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True\n        )\n        x1 = x1_pt.detach().clone().requires_grad_()\n        x1_ref = x1_pt.detach().clone().float().requires_grad_()\n    else:\n        x1 = None\n    if has_residual:\n        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)\n        res = 
res_pt.detach().clone().requires_grad_()\n        res_ref = res_pt.detach().clone().float().requires_grad_()\n    else:\n        res = None\n    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n    bias0 = (\n        torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n        if not is_rms_norm\n        else None\n    )\n    weight0_pt = weight0.detach().clone().requires_grad_()\n    weight0_ref = weight0.detach().clone().float().requires_grad_()\n    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None\n    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None\n    if not tied_norm:\n        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n        bias1 = (\n            torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n            if not is_rms_norm\n            else None\n        )\n        weight1_pt = weight1.detach().clone().requires_grad_()\n        weight1_ref = weight1.detach().clone().float().requires_grad_()\n        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None\n        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None\n    else:\n        weight1, bias1 = None, None\n    epsilon = 1e-5\n    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32\n\n    out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(\n        x0,\n        x1,\n        res,\n        weight0,\n        bias0,\n        weight1,\n        bias1,\n        dropout_p,\n        epsilon,\n        prenorm=True,\n        residual_in_fp32=residual_in_fp32,\n        return_dropout_mask=True,\n    )\n    assert out0.dtype == input_dtype\n    if not tied_norm:\n        assert out1.dtype == input_dtype\n    print(f\"Actual dropout fraction: {1 - dmask0.float().mean().item()}\")\n    if 
has_residual:\n        if has_x1:\n            residual_pt = (\n                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)\n                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)\n                + res_pt.float()\n            ).to(dtype=residual_dtype)\n            residual_ref = (\n                (x0_ref * dmask0.float()) / (1 - dropout_p)\n                + (x1_ref * dmask1.float()) / (1 - dropout_p)\n            ) + res_ref\n        else:\n            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(\n                dtype=residual_dtype\n            )\n            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref\n    else:\n        if has_x1:\n            residual_pt = (\n                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)\n                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)\n            ).to(dtype=residual_dtype)\n            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (\n                x1_ref * dmask1.float()\n            ) / (1 - dropout_p)\n        else:\n            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(\n                dtype=residual_dtype\n            )\n            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)\n    if not is_rms_norm:\n        out0_pt = F.layer_norm(\n            residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon\n        ).to(dtype=input_dtype)\n        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)\n        if not tied_norm:\n            out1_pt = F.layer_norm(\n                residual_pt.to(dtype=weight_dtype),\n                (hidden_size,),\n                weight1_pt,\n                bias1_pt,\n                eps=epsilon,\n            ).to(dtype=input_dtype)\n            out1_ref = F.layer_norm(\n                residual_ref, (hidden_size,), weight1_ref, bias1_ref, 
eps=epsilon\n            )\n    else:\n        out0_pt = fused_rms_norm_affine(\n            residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon\n        ).to(dtype=input_dtype)\n        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)\n        if not tied_norm:\n            out1_pt = fused_rms_norm_affine(\n                residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon\n            ).to(dtype=input_dtype)\n            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)\n\n    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4\n    if not tied_norm:\n        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4\n    assert (residual - residual_ref).abs().max() <= 4 * (\n        residual_pt - residual_ref\n    ).abs().max() + 1e-4\n\n    g0 = torch.randn_like(out0) / batch_size\n    if tied_norm:\n        (out0 * F.sigmoid(residual)).backward(g0)\n        (out0_pt * F.sigmoid(residual_pt)).backward(g0)\n        (out0_ref * F.sigmoid(residual_ref)).backward(g0)\n    else:\n        g1 = torch.randn_like(out1) / batch_size\n        (out0 * F.sigmoid(residual) * g0 + out1 * g1).sum().backward()\n        (out0_pt * F.sigmoid(residual_pt) * g0 + out1_pt * g1).sum().backward()\n        (out0_ref * F.sigmoid(residual_ref) * g0 + out1_ref * g1).sum().backward()\n    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4\n    if has_x1:\n        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (\n            x1_pt.grad - x1_ref.grad\n        ).abs().max() + 1e-4\n    if has_residual:\n        assert (res.grad - res_ref.grad).abs().max() <= 4 * (\n            res_pt.grad - res_ref.grad\n        ).abs().max() + 1e-4\n    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (\n        weight0_pt.grad - weight0_ref.grad\n    ).abs().max() + 
3e-5\n    if not is_rms_norm:\n        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (\n            bias0_pt.grad - bias0_ref.grad\n        ).abs().max() + 3e-5\n    if not tied_norm:\n        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (\n            weight1_pt.grad - weight1_ref.grad\n        ).abs().max() + 3e-5\n        if not is_rms_norm:\n            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (\n                bias1_pt.grad - bias1_ref.grad\n            ).abs().max() + 3e-5\n\n\ndef test_dropout_layer_norm_randomness():\n    hidden_size = 256\n    dtype = torch.float32\n    dropout_p = 0.1\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 512\n    x0 = torch.randn(\n        batch_size, seqlen, hidden_size, device=device, dtype=dtype, requires_grad=True\n    )\n    res = torch.randn_like(x0, dtype=dtype, requires_grad=True)\n    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=dtype)\n    torch.random.manual_seed(42)\n    _, dmask0 = dropout_add_layer_norm(\n        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True\n    )\n    # Subsequent call should have a different dropout mask\n    _, dmask1 = dropout_add_layer_norm(\n        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True\n    )\n    torch.random.manual_seed(42)\n    # Resetting the seed, should get the same dropout mask\n    _, dmask2 = dropout_add_layer_norm(\n        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True\n    )\n    assert not torch.equal(dmask0, dmask1)\n    assert torch.equal(dmask0, dmask2)\n"
  },
  {
    "path": "tests/ops/test_fused_dense.py",
    "content": "import math\nfrom functools import partial\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\nfrom flash_attn.ops.fused_dense import FusedDense, FusedMLP\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"return_residual\", [False, True])\n@pytest.mark.parametrize(\"has_bias\", [True, False])\n@pytest.mark.parametrize(\"out_features\", [1024, 4096])\n@pytest.mark.parametrize(\"in_features\", [1024, 4096])\ndef test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):\n    device = \"cuda\"\n    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 512\n    x_pt = torch.randn(\n        batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True\n    )\n    x = x_pt.detach().clone().requires_grad_()\n    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)\n    model = FusedDense(\n        in_features,\n        out_features,\n        bias=has_bias,\n        return_residual=return_residual,\n        device=device,\n        dtype=dtype,\n    )\n    with torch.no_grad():\n        model.weight.copy_(model_pt.weight)\n        if has_bias:\n            model.bias.copy_(model_pt.bias)\n    out_pt = model_pt(x_pt)\n    if not return_residual:\n        out = model(x)\n    else:\n        out, x_copy = model(x)\n        x_copy = (\n            x_copy[..., :out_features]\n            if out_features < in_features\n            else F.pad(x_copy, (0, out_features - in_features))\n        )\n        x_pt_copy = (\n            x_pt[..., :out_features]\n            if out_features < in_features\n            else F.pad(x_pt, (0, out_features - in_features))\n        )\n        # Just add some random function of the residual\n        out_pt = out_pt + F.gelu(x_pt_copy)\n        out = out + 
F.gelu(x_copy)\n\n    # with torch.no_grad():\n    #     out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()\n    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)\n\n    # Scale the upstream gradient down (by 32 here); otherwise the gradients get a bit too large.\n    g = torch.randn_like(out) / 32\n    out_pt.backward(g)\n    out.backward(g)\n    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)\n    # The error for d_weight and d_bias is quite a bit higher\n    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)\n    if has_bias:\n        assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n# @pytest.mark.parametrize('dtype', [torch.float16])\n@pytest.mark.parametrize(\"heuristic\", [\"auto\", -1])\n# @pytest.mark.parametrize('heuristic', ['auto'])\n@pytest.mark.parametrize(\"checkpoint_lvl\", [0, 1, 2])\n# @pytest.mark.parametrize('checkpoint_lvl', [1])\n@pytest.mark.parametrize(\"return_residual\", [False, True])\n# @pytest.mark.parametrize('return_residual', [False])\n@pytest.mark.parametrize(\"has_bias2\", [True, False])\n@pytest.mark.parametrize(\"has_bias1\", [True, False])\n# @pytest.mark.parametrize('has_bias2', [True])\n# @pytest.mark.parametrize('has_bias1', [True])\n@pytest.mark.parametrize(\"activation\", [\"gelu_approx\", \"relu\"])\n# @pytest.mark.parametrize('activation', ['relu'])\n@pytest.mark.parametrize(\"out_features\", [1024, 4096])\n@pytest.mark.parametrize(\"in_features\", [1024, 4096])\n# @pytest.mark.parametrize('out_features', [4096])\n# @pytest.mark.parametrize('in_features', [1024])\ndef test_fused_mlp(\n    in_features,\n    out_features,\n    activation,\n    has_bias1,\n    has_bias2,\n    return_residual,\n    checkpoint_lvl,\n    heuristic,\n    dtype,\n):\n    device = \"cuda\"\n    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)\n    # 
set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 512\n    x_pt = torch.randn(\n        batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True\n    )\n    x = x_pt.detach().clone().requires_grad_()\n    model_pt_fc1 = torch.nn.Linear(\n        in_features, out_features, bias=has_bias1, device=device, dtype=dtype\n    )\n    model_pt_fc2 = torch.nn.Linear(\n        out_features, in_features, bias=has_bias2, device=device, dtype=dtype\n    )\n    model = FusedMLP(\n        in_features,\n        out_features,\n        in_features,\n        activation=activation,\n        bias1=has_bias1,\n        bias2=has_bias2,\n        return_residual=return_residual,\n        checkpoint_lvl=checkpoint_lvl,\n        heuristic=heuristic,\n        device=device,\n        dtype=dtype,\n    )\n    with torch.no_grad():\n        model.fc1.weight.copy_(model_pt_fc1.weight)\n        if has_bias1:\n            model.fc1.bias.copy_(model_pt_fc1.bias)\n        model.fc2.weight.copy_(model_pt_fc2.weight)\n        if has_bias2:\n            model.fc2.bias.copy_(model_pt_fc2.bias)\n    activation_fn = (\n        partial(F.gelu, approximate=\"tanh\")\n        if activation == \"gelu_approx\"\n        else partial(F.relu, inplace=True)\n    )\n    out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt)))\n    if not return_residual:\n        out = model(x)\n    else:\n        out, x_copy = model(x)\n        # Just add some random function of the residual\n        out_pt = out_pt + F.gelu(x_pt)\n        out = out + F.gelu(x_copy)\n    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)\n\n    # Scale the upstream gradient down (by 32 here); otherwise the gradients get a bit too large.\n    g = torch.randn_like(out) / 32\n    out_pt.backward(g)\n    out.backward(g)\n    # The error for relu is higher still\n    if activation == \"relu\":\n        atol = 1e-1 if dtype == torch.bfloat16 else 5e-2\n    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)\n    
# The error for d_weight and d_bias is quite a bit higher\n    assert torch.allclose(\n        model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10\n    )\n    if has_bias1:\n        assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)\n    assert torch.allclose(\n        model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10\n    )\n    if has_bias2:\n        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)\n"
  },
  {
    "path": "tests/ops/test_fused_dense_parallel.py",
    "content": "# Run test with:\n# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/ops/test_fused_dense_parallel.py\n\nimport math\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom apex.transformer import parallel_state, tensor_parallel\nfrom flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, FusedMLP, ParallelFusedMLP\n\nis_sm8x = torch.cuda.get_device_capability(\"cuda\")[0] >= 8\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))\n# @pytest.mark.parametrize('dtype', [torch.bfloat16])\n@pytest.mark.parametrize(\"world_size\", [1, 2, 4, 8])\n# @pytest.mark.parametrize('world_size', [2])\n@pytest.mark.parametrize(\"sequence_parallel\", [True, False])\n# @pytest.mark.parametrize('sequence_parallel', [False])\n@pytest.mark.parametrize(\"has_bias\", [True, False])\n# @pytest.mark.parametrize('has_bias', [False])\n@pytest.mark.parametrize(\"out_features\", [1024])\n@pytest.mark.parametrize(\"in_features\", [4096])\ndef test_fused_linear_bias(\n    in_features, out_features, has_bias, sequence_parallel, world_size, dtype\n):\n    assert out_features % world_size == 0\n    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    seqlen = 512\n    assert batch_size * seqlen % world_size == 0\n    x_pt = torch.randn(\n        batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True\n    )\n    if sequence_parallel:\n        x = (\n            
tensor_parallel.scatter_to_sequence_parallel_region(x_pt)\n            .detach()\n            .clone()\n            .requires_grad_()\n        )\n    else:\n        x = x_pt.detach().clone().requires_grad_()\n\n    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)\n    partition_out_features = out_features // world_size\n    model = ColumnParallelLinear(\n        in_features,\n        out_features,\n        parallel_state.get_tensor_model_parallel_group(),\n        bias=has_bias,\n        sequence_parallel=sequence_parallel,\n        device=device,\n        dtype=dtype,\n    )\n    with torch.no_grad():\n        model.weight.copy_(\n            model_pt.weight[rank * partition_out_features : (rank + 1) * partition_out_features]\n        )\n        if has_bias:\n            model.bias.copy_(\n                model_pt.bias[rank * partition_out_features : (rank + 1) * partition_out_features]\n            )\n\n    out = model(x)\n    out_pt = model_pt(x_pt)\n    assert torch.allclose(\n        out,\n        out_pt[:, rank * partition_out_features : (rank + 1) * partition_out_features],\n        rtol=rtol,\n        atol=atol,\n    )\n\n    # If we don't divide by batch_size, the gradient gets a bit too large.\n    g = torch.randn_like(out_pt) / 32\n    out_pt.backward(g)\n    out.backward(g[:, rank * partition_out_features : (rank + 1) * partition_out_features])\n    parallel_state.destroy_model_parallel()\n\n    partition_batch_dim = batch_size * seqlen // world_size\n    assert torch.allclose(\n        x.grad,\n        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]\n        if sequence_parallel\n        else x_pt.grad,\n        rtol=rtol,\n        atol=atol,\n    )\n    # The error for d_weight and d_bias is quite a bit higher\n    assert torch.allclose(\n        model.weight.grad,\n        model_pt.weight.grad[rank * partition_out_features : (rank + 1) * partition_out_features],\n        
rtol=rtol,\n        atol=atol * 10,\n    )\n    if has_bias:\n        assert torch.allclose(\n            model.bias.grad,\n            model_pt.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],\n            rtol=rtol,\n            atol=atol * 5,\n        )\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))\n# @pytest.mark.parametrize('dtype', [torch.bfloat16])\n@pytest.mark.parametrize(\"world_size\", [1, 2, 4, 8])\n# @pytest.mark.parametrize('world_size', [2])\n@pytest.mark.parametrize(\"sequence_parallel\", [True, False])\n# @pytest.mark.parametrize('sequence_parallel', [False])\n@pytest.mark.parametrize(\"has_bias2\", [True, False])\n# @pytest.mark.parametrize('has_bias2', [True])\n@pytest.mark.parametrize(\"out_features\", [4096])\n@pytest.mark.parametrize(\"in_features\", [1024])\ndef test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype):\n    assert out_features % world_size == 0\n    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)\n    if not torch.distributed.is_initialized():\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    device = f\"cuda:{torch.distributed.get_rank()}\"\n    assert world_size <= torch.distributed.get_world_size()\n    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)\n    rank = parallel_state.get_tensor_model_parallel_rank()\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    seqlen = 512\n    assert batch_size * seqlen % world_size == 0\n    x_pt = torch.randn(\n        batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True\n    )\n    # We need to generate g here so that all processes get the same gradient,\n    # as rank 0 will have an extra bias that changes the RNG.\n    # If we don't divide by batch_size, the gradient gets a bit too large.\n    g = torch.randn_like(x_pt) / 
32\n    if sequence_parallel:\n        x = (\n            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)\n            .detach()\n            .clone()\n            .requires_grad_()\n        )\n    else:\n        x = x_pt.detach().clone().requires_grad_()\n\n    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)\n    model_pt_fc2 = torch.nn.Linear(\n        out_features, in_features, bias=has_bias2, device=device, dtype=dtype\n    )\n    partition_out_features = out_features // world_size\n    partition_in_features = in_features // world_size\n    model = ParallelFusedMLP(\n        in_features,\n        out_features,\n        in_features,\n        process_group=parallel_state.get_tensor_model_parallel_group(),\n        bias2=has_bias2 and rank == 0,\n        sequence_parallel=sequence_parallel,\n        device=device,\n        dtype=dtype,\n    )\n\n    with torch.no_grad():\n        model.fc1.weight.copy_(\n            model_pt_fc1.weight[rank * partition_out_features : (rank + 1) * partition_out_features]\n        )\n        model.fc1.bias.copy_(\n            model_pt_fc1.bias[rank * partition_out_features : (rank + 1) * partition_out_features]\n        )\n        model.fc2.weight.copy_(\n            model_pt_fc2.weight[\n                :, rank * partition_out_features : (rank + 1) * partition_out_features\n            ]\n        )\n        if has_bias2 and rank == 0:\n            model.fc2.bias.copy_(model_pt_fc2.bias)\n\n    out = model(x)\n    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate=\"tanh\"))\n    partition_batch_dim = batch_size * seqlen // world_size\n    assert torch.allclose(\n        out,\n        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]\n        if sequence_parallel\n        else out_pt,\n        rtol=rtol,\n        atol=atol,\n    )\n\n    out_pt.backward(g)\n    out.backward(\n        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if 
sequence_parallel else g\n    )\n    parallel_state.destroy_model_parallel()\n\n    assert torch.allclose(\n        x.grad,\n        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]\n        if sequence_parallel\n        else x_pt.grad,\n        rtol=rtol,\n        atol=atol,\n    )\n    # The error for d_weight and d_bias is quite a bit higher\n    assert torch.allclose(\n        model.fc1.weight.grad,\n        model_pt_fc1.weight.grad[\n            rank * partition_out_features : (rank + 1) * partition_out_features\n        ],\n        rtol=rtol,\n        atol=atol * 10,\n    )\n    assert torch.allclose(\n        model.fc1.bias.grad,\n        model_pt_fc1.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features],\n        rtol=rtol,\n        atol=atol * 5,\n    )\n    assert torch.allclose(\n        model.fc2.weight.grad,\n        model_pt_fc2.weight.grad[\n            :, rank * partition_out_features : (rank + 1) * partition_out_features\n        ],\n        rtol=rtol,\n        atol=atol * 10,\n    )\n    if has_bias2 and rank == 0:\n        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)\n"
  },
  {
    "path": "tests/ops/triton/test_layer_norm.py",
    "content": "# Copyright (c) 2024, Tri Dao.\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\nfrom flash_attn.ops.triton.layer_norm import (\n    layer_norm_fn,\n    layer_norm_ref,\n    rms_norm_ref,\n    layer_norm_linear_fn,\n)\n\n\nis_sm8x = torch.cuda.get_device_capability(\"cuda\")[0] >= 8\n\n\n# @pytest.mark.parametrize(\"zero_centered_weight\", [False, True])\n@pytest.mark.parametrize(\"zero_centered_weight\", [False])\n@pytest.mark.parametrize(\"has_weight1\", [False, True])\n# @pytest.mark.parametrize(\"has_weight1\", [False])\n@pytest.mark.parametrize(\"has_x1\", [False, True])\n# @pytest.mark.parametrize(\"has_x1\", [False])\n@pytest.mark.parametrize(\"has_rowscale\", [False, True])\n# @pytest.mark.parametrize(\"has_rowscale\", [False])\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.27])\n# @pytest.mark.parametrize(\"dropout_p\", [0.0])\n@pytest.mark.parametrize(\"prenorm\", [True, False])\n# @pytest.mark.parametrize(\"prenorm\", [True])\n@pytest.mark.parametrize(\"is_rms_norm\", [False, True])\n# @pytest.mark.parametrize(\"is_rms_norm\", [True])\n@pytest.mark.parametrize(\"has_residual\", [True, False])\n# @pytest.mark.parametrize(\"has_residual\", [True])\n@pytest.mark.parametrize(\n    \"weight_dtype\", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])\n)\n# @pytest.mark.parametrize(\"weight_dtype\", [torch.float32])\n@pytest.mark.parametrize(\n    \"input_dtype,residual_dtype\",\n    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]\n    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),\n)\n# @pytest.mark.parametrize(\"input_dtype,residual_dtype\", [(torch.float16, torch.float16)])\n@pytest.mark.parametrize(\"hidden_size\", [192, 2048, 2560, 3000, 4096])\n# @pytest.mark.parametrize(\"hidden_size\", [1024])\ndef test_layer_norm(\n    hidden_size,\n    input_dtype,\n    
residual_dtype,\n    weight_dtype,\n    has_residual,\n    is_rms_norm,\n    prenorm,\n    dropout_p,\n    has_rowscale,\n    has_x1,\n    has_weight1,\n    zero_centered_weight,\n):\n    if has_rowscale and has_x1:\n        pytest.skip(\"Not supported\")\n    device = \"cuda\"\n    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):\n        atol = 5e-2\n    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):\n        atol = 1e-2\n    else:\n        atol = 1e-4\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    seqlen = 512\n    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref\n    allclose = (\n        # Sometimes x0_pt.grad is NaN\n        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()\n        <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol\n        or (\n            # Sometimes x_pt and x_ref are the same (e.g. bfloat16), so we perturb x_pt a bit\n            # by multiplying and dividing by 0.3\n            (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0\n            and (x - x_ref).abs().max()\n            <= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol\n        )\n    )\n    x0 = torch.randn(\n        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True\n    )\n    x0_pt = x0.detach().clone().requires_grad_()\n    x0_ref = x0.detach().clone().requires_grad_()\n    if has_residual:\n        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)\n        res_pt = res.detach().clone().requires_grad_()\n        res_ref = res.detach().clone().requires_grad_()\n    else:\n        res, res_pt, res_ref = None, None, None\n    weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n    if not is_rms_norm:\n        bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, 
requires_grad=True)\n    else:\n        bias = None\n    weight_pt = weight.detach().clone().requires_grad_()\n    weight_ref = weight.detach().clone().requires_grad_()\n    bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None\n    bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None\n    if has_x1:\n        x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True)\n        x1_pt = x1.detach().clone().requires_grad_()\n        x1_ref = x1.detach().clone().requires_grad_()\n    else:\n        x1, x1_pt, x1_ref = None, None, None\n    if has_weight1:\n        weight1 = torch.randn(\n            hidden_size, device=device, dtype=weight_dtype, requires_grad=True\n        )\n        weight1_pt = weight1.detach().clone().requires_grad_()\n        weight1_ref = weight1.detach().clone().requires_grad_()\n        if not is_rms_norm:\n            bias1 = torch.randn(\n                hidden_size, device=device, dtype=weight_dtype, requires_grad=True\n            )\n        else:\n            bias1 = None\n        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None\n        bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None\n    else:\n        weight1, weight1_pt, weight1_ref = None, None, None\n        bias1, bias1_pt, bias1_ref = None, None, None\n\n    rowscale = (\n        torch.randn(batch_size, seqlen, dtype=input_dtype, device=device)\n        if has_rowscale\n        else None\n    )\n\n    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32\n    out, *rest = layer_norm_fn(\n        x0,\n        weight,\n        bias,\n        residual=res,\n        x1=x1,\n        weight1=weight1,\n        bias1=bias1,\n        eps=1e-6,\n        dropout_p=dropout_p,\n        rowscale=rowscale,\n        prenorm=prenorm,\n        residual_in_fp32=residual_in_fp32,\n        zero_centered_weight=zero_centered_weight,\n        
is_rms_norm=is_rms_norm,\n        return_dropout_mask=True,\n    )\n    dropout_mask = rest[-2] if dropout_p > 0.0 else None\n    dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None\n    out_pt = layer_norm_ref_fn(\n        x0_pt,\n        weight_pt,\n        bias_pt,\n        residual=res_pt,\n        x1=x1_pt,\n        weight1=weight1_pt,\n        bias1=bias1_pt,\n        eps=1e-6,\n        dropout_p=dropout_p,\n        rowscale=rowscale,\n        prenorm=prenorm,\n        zero_centered_weight=zero_centered_weight,\n        dropout_mask=dropout_mask,\n        dropout_mask1=dropout_mask1,\n    )\n    out_ref = layer_norm_ref_fn(\n        x0_ref,\n        weight_ref,\n        bias_ref,\n        residual=res_ref,\n        x1=x1_ref,\n        weight1=weight1_ref,\n        bias1=bias1_ref,\n        eps=1e-6,\n        dropout_p=dropout_p,\n        rowscale=rowscale,\n        prenorm=prenorm,\n        zero_centered_weight=zero_centered_weight,\n        dropout_mask=dropout_mask,\n        dropout_mask1=dropout_mask1,\n        upcast=True,\n    )\n    if not has_weight1:\n        if prenorm:\n            residual = rest[0]\n            out_pt, residual_pt = out_pt\n            out_ref, residual_ref = out_ref\n        out1, out1_pt, out1_ref = None, None, None\n    else:\n        out1 = rest.pop(0)\n        if prenorm:\n            residual = rest[0]\n            out_pt, out1_pt, residual_pt = out_pt\n            out_ref, out1_ref, residual_ref = out_ref\n        else:\n            out_pt, out1_pt = out_pt\n            out_ref, out1_ref = out_ref\n    assert out.dtype == input_dtype\n    if prenorm:\n        assert residual.dtype == residual_dtype\n        assert allclose(residual, residual_pt, residual_ref)\n    assert allclose(out, out_pt, out_ref)\n    if out1 is not None:\n        assert out1.dtype == input_dtype\n        assert allclose(out1, out1_pt, out1_ref)\n    if dropout_mask is not None:\n        dropout_fraction = 1.0 - 
dropout_mask.float().mean()\n        assert abs(dropout_fraction - dropout_p) < 0.01\n    if dropout_mask1 is not None:\n        dropout_fraction = 1.0 - dropout_mask1.float().mean()\n        assert abs(dropout_fraction - dropout_p) < 0.01\n        assert not torch.equal(dropout_mask, dropout_mask1)\n\n    g = torch.randn_like(out) / batch_size\n    if has_weight1:\n        out = out * F.gelu(out1)\n        out_pt = out_pt * F.gelu(out1_pt)\n        out_ref = out_ref * F.gelu(out1_ref)\n    if not prenorm:\n        out.backward(g)\n        out_pt.backward(g)\n        out_ref.backward(g)\n    else:\n        (out * F.sigmoid(residual)).backward(g)\n        (out_pt * F.sigmoid(residual_pt)).backward(g)\n        (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)\n    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)\n    if has_residual:\n        assert allclose(res.grad, res_pt.grad, res_ref.grad)\n    if has_x1:\n        assert allclose(x1.grad, x1_pt.grad, x1_ref.grad)\n    assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)\n    if bias is not None:\n        assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)\n    if has_weight1:\n        assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad)\n        if bias1 is not None:\n            assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad)\n\n\n@pytest.mark.parametrize(\"prenorm\", [True, False])\n# @pytest.mark.parametrize(\"prenorm\", [True])\n@pytest.mark.parametrize(\"is_rms_norm\", [False, True])\n# @pytest.mark.parametrize(\"is_rms_norm\", [True])\n@pytest.mark.parametrize(\"has_residual\", [True, False])\n# @pytest.mark.parametrize(\"has_residual\", [False])\n@pytest.mark.parametrize(\"weight_dtype\", [torch.float32])\n@pytest.mark.parametrize(\n    \"input_dtype,residual_dtype\",\n    [(torch.float16, torch.float16), (torch.float16, torch.float32)]\n    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),\n)\n# 
@pytest.mark.parametrize(\"input_dtype,residual_dtype\", [(torch.bfloat16, torch.float32)])\n@pytest.mark.parametrize(\"hidden_size\", [192, 2048, 2560, 3000])\n# @pytest.mark.parametrize(\"hidden_size\", [256])\ndef test_layer_norm_linear(\n    hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm\n):\n    device = \"cuda\"\n    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):\n        atol = 5e-2\n    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):\n        atol = 1e-2\n    else:\n        atol = 1e-4\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    seqlen = 512\n    # batch_size = 1\n    # seqlen = 1\n    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref\n    allclose = (\n        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()\n        <= 2 * (x_pt - x_ref).abs().max() + atol\n    )\n    x0 = torch.randn(\n        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True\n    )\n    x0_pt = x0.detach().clone().requires_grad_()\n    x0_ref = x0.detach().clone().requires_grad_()\n    if has_residual:\n        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)\n        res_pt = res.detach().clone().requires_grad_()\n        res_ref = res.detach().clone().requires_grad_()\n    else:\n        res, res_pt, res_ref = None, None, None\n    norm_weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n    if not is_rms_norm:\n        norm_bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)\n    else:\n        norm_bias = None\n    norm_weight_pt = norm_weight.detach().clone().requires_grad_()\n    norm_weight_ref = norm_weight.detach().clone().requires_grad_()\n    norm_bias_pt = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None\n    norm_bias_ref = 
norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None\n    linear_weight = torch.empty(\n        2 * hidden_size, hidden_size, device=device, dtype=weight_dtype, requires_grad=True\n    )\n    torch.nn.init.xavier_uniform_(linear_weight)\n    if not is_rms_norm:\n        linear_bias = torch.randn(\n            2 * hidden_size, device=device, dtype=weight_dtype, requires_grad=True\n        )\n    else:\n        linear_bias = None\n    linear_weight_pt = linear_weight.detach().clone().requires_grad_()\n    linear_weight_ref = linear_weight.detach().clone().requires_grad_()\n    linear_bias_pt = (\n        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None\n    )\n    linear_bias_ref = (\n        linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None\n    )\n\n    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32\n    with torch.autocast(device_type=\"cuda\", dtype=input_dtype):\n        out, *rest = layer_norm_linear_fn(\n            x0,\n            norm_weight,\n            norm_bias,\n            linear_weight,\n            linear_bias,\n            residual=res,\n            eps=1e-6,\n            prenorm=prenorm,\n            residual_in_fp32=residual_in_fp32,\n            is_rms_norm=is_rms_norm,\n        )\n    out_pt, *rest_pt = layer_norm_ref_fn(\n        x0_pt, norm_weight_pt, norm_bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm\n    )\n    with torch.autocast(device_type=\"cuda\", dtype=input_dtype):\n        out_pt = F.linear(out_pt, linear_weight_pt, linear_bias_pt)\n    out_ref, *rest_ref = layer_norm_ref_fn(\n        x0_ref,\n        norm_weight_ref,\n        norm_bias_ref,\n        residual=res_ref,\n        eps=1e-6,\n        prenorm=prenorm,\n        upcast=True,\n    )\n    out_ref = F.linear(out_ref.to(linear_weight_ref.dtype), linear_weight_ref, linear_bias_ref)\n    if prenorm:\n        residual = rest[0]\n        residual_pt = 
rest_pt[0]\n        residual_ref = rest_ref[0]\n    assert out.dtype == input_dtype\n    if prenorm:\n        assert residual.dtype == residual_dtype\n        assert allclose(residual, residual_pt, residual_ref)\n    assert allclose(out, out_pt, out_ref)\n\n    g = torch.randn_like(out) / batch_size\n    out.backward(g)\n    out_pt.backward(g)\n    out_ref.backward(g)\n    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)\n    if has_residual:\n        assert allclose(res.grad, res_pt.grad, res_ref.grad)\n    assert allclose(norm_weight.grad, norm_weight_pt.grad, norm_weight_ref.grad)\n    if norm_bias is not None:\n        assert allclose(norm_bias.grad, norm_bias_pt.grad, norm_bias_ref.grad)\n    assert allclose(linear_weight.grad, linear_weight_pt.grad, linear_weight_ref.grad)\n    if linear_bias is not None:\n        assert allclose(linear_bias.grad, linear_bias_pt.grad, linear_bias_ref.grad)\n"
  },
  {
    "path": "tests/pyproject.toml",
    "content": "[tool.black]\nline-length = 100\ntarget-version = ['py38']"
  },
  {
    "path": "tests/test_flash_attn.py",
    "content": "import math\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom flash_attn import (\n    flash_attn_func,\n    flash_attn_kvpacked_func,\n    flash_attn_qkvpacked_func,\n    flash_attn_varlen_func,\n    flash_attn_varlen_kvpacked_func,\n    flash_attn_varlen_qkvpacked_func,\n    flash_attn_with_kvcache,\n)\nfrom flash_attn.bert_padding import pad_input, unpad_input\nfrom flash_attn.flash_attn_interface import _get_block_size_n\nfrom flash_attn.layers.rotary import apply_rotary_emb\n\nMAX_HEADDIM_SM8x = 192\n\n\nis_sm75 = torch.cuda.get_device_capability(\"cuda\") == (7, 5)\nis_sm8x = torch.cuda.get_device_capability(\"cuda\")[0] == 8\nis_sm80 = torch.cuda.get_device_capability(\"cuda\") == (8, 0)\nis_sm90 = torch.cuda.get_device_capability(\"cuda\") == (9, 0)\n\n\ndef attn_bias_from_alibi_slopes(\n    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None\n):\n    batch, nheads = slopes.shape\n    device = slopes.device\n    slopes = rearrange(slopes, \"b h -> b h 1 1\")\n    if causal:\n        return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes\n    else:\n        row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n        col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n        if key_leftpad is not None:\n            key_leftpad = rearrange(key_leftpad, \"b -> b 1 1 1\")\n            col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n            col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)\n        sk = (\n            seqlen_k\n            if key_padding_mask is None\n            else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n        )\n        sq = (\n            seqlen_q\n            if query_padding_mask is None\n            else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n   
     )\n        relative_pos = torch.abs(row_idx + sk - sq - col_idx)\n        return -slopes * relative_pos.to(dtype=slopes.dtype)\n\n\ndef generate_random_padding_mask(max_seqlen, batch_size, device, mode=\"random\"):\n    assert mode in [\"full\", \"random\", \"third\"]\n    if mode == \"full\":\n        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)\n    elif mode == \"random\":\n        lengths = torch.randint(\n            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device\n        )\n    elif mode == \"third\":\n        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)\n    padding_mask = (\n        repeat(torch.arange(max_seqlen, device=device), \"s -> b s\", b=batch_size) < lengths\n    )\n    return padding_mask\n\n\ndef generate_qkv(\n    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, d)\n        k: (batch_size, seqlen_k, nheads_k, d)\n        v: (batch_size, seqlen_k, nheads_k, d)\n        query_padding_mask: (batch_size, seqlen), bool\n        key_padding_mask: (batch_size, seqlen), bool\n    \"\"\"\n    assert not (kvpacked and qkvpacked)\n    batch_size, seqlen_q, nheads, d = q.shape\n    _, seqlen_k, nheads_k, _ = k.shape\n    assert k.shape == (batch_size, seqlen_k, nheads_k, d)\n    assert v.shape == (batch_size, seqlen_k, nheads_k, d)\n\n    if query_padding_mask is not None:\n        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask)\n        output_pad_fn = lambda output_unpad: pad_input(\n            output_unpad, indices_q, batch_size, seqlen_q\n        )\n    else:\n        q_unpad = rearrange(q, \"b s h d -> (b s) h d\")\n        cu_seqlens_q = torch.arange(\n            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device\n        )\n        max_seqlen_q = 
seqlen_q\n        output_pad_fn = lambda output_unpad: rearrange(\n            output_unpad, \"(b s) h d -> b s h d\", b=batch_size\n        )\n\n    if key_padding_mask is not None:\n        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask)\n        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask)\n    else:\n        k_unpad = rearrange(k, \"b s h d -> (b s) h d\")\n        v_unpad = rearrange(v, \"b s h d -> (b s) h d\")\n        cu_seqlens_k = torch.arange(\n            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device\n        )\n        max_seqlen_k = seqlen_k\n\n    if qkvpacked:\n        assert (query_padding_mask == key_padding_mask).all()\n        assert nheads == nheads_k\n        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)\n        qkv = torch.stack([q, k, v], dim=2)\n        if query_padding_mask is not None:\n            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)\n        else:\n            dqkv_pad_fn = lambda dqkv_unpad: rearrange(\n                dqkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n            qkv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            max_seqlen_q,\n            qkv.detach().requires_grad_(),\n            output_pad_fn,\n            dqkv_pad_fn,\n        )\n    elif kvpacked:\n        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)\n        kv = torch.stack([k, v], dim=2)\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)\n        else:\n            dkv_pad_fn = lambda dkv_unpad: rearrange(\n                dkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n            q_unpad.detach().requires_grad_(),\n            kv_unpad.detach().requires_grad_(),\n            
cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            kv.detach().requires_grad_(),\n            output_pad_fn,\n            dq_pad_fn,\n            dkv_pad_fn,\n        )\n    else:\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)\n        else:\n            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, \"(b s) h d -> b s h d\", b=batch_size)\n        return (\n            q_unpad.detach().requires_grad_(),\n            k_unpad.detach().requires_grad_(),\n            v_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            k.detach().requires_grad_(),\n            v.detach().requires_grad_(),\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        )\n\n\ndef construct_local_mask(\n    seqlen_q,\n    seqlen_k,\n    window_size=(-1, -1),  # -1 means infinite window size\n    query_padding_mask=None,\n    key_padding_mask=None,\n    device=None,\n    key_leftpad=None,\n):\n    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n    if key_leftpad is not None:\n        key_leftpad = rearrange(key_leftpad, \"b -> b 1 1 1\")\n        col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)\n    sk = (\n        seqlen_k\n        if key_padding_mask is None\n        else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sq = (\n        seqlen_q\n        if query_padding_mask is None\n        else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    if 
window_size[0] < 0:\n        return col_idx > row_idx + sk - sq + window_size[1]\n    else:\n        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk\n        return torch.logical_or(\n            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),\n            col_idx < row_idx + sk - sq - window_size[0],\n        )\n\n\ndef attention_ref(\n    q,\n    k,\n    v,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    attn_bias=None,\n    dropout_p=0.0,\n    dropout_mask=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n    softcap=0.0,\n    upcast=True,\n    reorder_ops=False,\n    key_leftpad=None,\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, head_dim)\n        k: (batch_size, seqlen_k, nheads_k, head_dim)\n        v: (batch_size, seqlen_k, nheads_k, head_dim)\n        query_padding_mask: (batch_size, seqlen_q)\n        key_padding_mask: (batch_size, seqlen_k)\n        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)\n        dropout_p: float\n        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)\n        causal: whether to apply causal masking\n        window_size: (int, int), left and right window size\n        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast\n            output back to fp16/bf16.\n        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)\n            without changing the math. 
This is to estimate the numerical error from operation\n            reordering.\n    Output:\n        output: (batch_size, seqlen_q, nheads, head_dim)\n        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    dtype_og = q.dtype\n    if upcast:\n        q, k, v = q.float(), k.float(), v.float()\n    seqlen_q, seqlen_k = q.shape[1], k.shape[1]\n    k = repeat(k, \"b s h d -> b s (h g) d\", g=q.shape[2] // k.shape[2])\n    v = repeat(v, \"b s h d -> b s (h g) d\", g=q.shape[2] // v.shape[2])\n    d = q.shape[-1]\n    if not reorder_ops:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q / math.sqrt(d), k)\n    else:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q, k / math.sqrt(d))\n    if softcap > 0:\n        scores = scores / softcap\n        scores = scores.tanh()\n        scores = scores * softcap\n    if key_padding_mask is not None:\n        scores.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            query_padding_mask,\n            key_padding_mask,\n            q.device,\n            key_leftpad=key_leftpad,\n        )\n        scores.masked_fill_(local_mask, float(\"-inf\"))\n    if attn_bias is not None:\n        scores = scores + attn_bias\n    attention = torch.softmax(scores, dim=-1).to(v.dtype)\n    # Some rows might be completely masked out so we fill them with zero instead of NaN\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)\n    # We want to mask here so that the attention matrix doesn't have any NaNs\n    # Otherwise we'll get NaN in dV\n    if query_padding_mask is not None:\n        attention = 
attention.masked_fill(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), 0.0)\n    dropout_scaling = 1.0 / (1 - dropout_p)\n    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling\n    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)\n    if dropout_mask is not None:\n        attention_drop = attention.masked_fill(~dropout_mask, 0.0)\n    else:\n        attention_drop = attention\n    output = torch.einsum(\"bhts,bshd->bthd\", attention_drop, v * dropout_scaling)\n    if query_padding_mask is not None:\n        output.masked_fill_(rearrange(~query_padding_mask, \"b s -> b s 1 1\"), 0.0)\n    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)\n\n\ndef attention_kvpacked_ref(\n    q,\n    kv,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    attn_bias=None,\n    dropout_p=0.0,\n    dropout_mask=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n    softcap=0.0,\n    upcast=True,\n    reorder_ops=False,\n    key_leftpad=None,\n):\n    return attention_ref(\n        q,\n        kv[:, :, 0],\n        kv[:, :, 1],\n        query_padding_mask,\n        key_padding_mask,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        upcast=upcast,\n        causal=causal,\n        window_size=window_size,\n        softcap=softcap,\n        reorder_ops=reorder_ops,\n        key_leftpad=key_leftpad,\n    )\n\n\ndef attention_qkvpacked_ref(\n    qkv,\n    key_padding_mask=None,\n    attn_bias=None,\n    dropout_p=0.0,\n    dropout_mask=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n    softcap=0.0,\n    upcast=True,\n    reorder_ops=False,\n):\n    return attention_ref(\n        qkv[:, :, 0],\n        qkv[:, :, 1],\n        qkv[:, :, 2],\n        key_padding_mask,\n        key_padding_mask,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        upcast=upcast,\n        causal=causal,\n        
window_size=window_size,\n        softcap=softcap,\n        reorder_ops=reorder_ops,\n    )\n\n\ndef generate_sparsity_mask(seqlen, sparsity=0.3):\n    repeats = seqlen // 16 // 2\n    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),\n    #                     torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)\n    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),\n    #                     torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)\n    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)\n    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)\n    nrow, ncol = seqlen // 16, seqlen // 256\n    mask = torch.rand(nrow, ncol, device=\"cuda\") < sparsity\n    return mask\n\n\ndef attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):\n    \"\"\"\n    Arguments:\n        qkv: (batch_size, seqlen, 3, nheads, head_dim)\n        blockmask: (seqlen / 16, seqlen / 256)\n        attn_mask: (batch_size, seqlen)\n        dropout_p: float\n        dropout_mask: (batch_size, nheads, seqlen, seqlen)\n    Output:\n        output: (batch_size, seqlen, nheads, head_dim)\n        attention: softmax after dropout\n    \"\"\"\n    q, k, v = qkv.float().unbind(dim=2)\n    d = qkv.shape[-1]\n    seqlen = qkv.shape[1]\n    scores = torch.einsum(\"bthd,bshd->bhts\", q / math.sqrt(d), k)\n    scores.masked_fill_(rearrange(~attn_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n    blockmask = repeat(blockmask, \"s_16 s_256 -> (s_16 16) (s_256 256)\")\n    blockmask = blockmask[:seqlen, :seqlen]\n    scores.masked_fill_(rearrange(~blockmask, \"t s -> 1 1 t s\"), float(\"-inf\"))\n    attention = torch.softmax(scores, dim=-1)\n    attention = attention.masked_fill(rearrange(~attn_mask, \"b s -> b 1 s 1\"), 0.0)\n    attention = 
attention.masked_fill_(rearrange(~blockmask, \"t s -> 1 1 t s\"), 0.0)\n    attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)\n    output = torch.einsum(\"bhts,bshd->bthd\", attention_drop, v)\n    output.masked_fill_(rearrange(~attn_mask, \"b s -> b s 1 1\"), 0)\n    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)\n\n\ndef convert_flash_attn_S_to_softmax(\n    S,\n    seqlen_q,\n    seqlen_k,\n    query_padding_mask,\n    key_padding_mask,\n    head_dim,\n    is_dropout,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n):\n    \"\"\"FlashAttention stores the S matrix in a different way.\n    Arguments:\n        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)\n        query_padding_mask: (batch_size, seqlen_q_rounded)\n        key_padding_mask: (batch_size, seqlen_k_rounded)\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]\n    S_converted = S\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            query_padding_mask,\n            key_padding_mask,\n            S.device,\n        )\n        local_mask = F.pad(\n            local_mask,\n            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),\n            value=True,\n        )\n        S_converted = S_converted.masked_fill(local_mask, 0.0)\n\n    # Need to zero out things not in attention_mask in case S was initialized with random values\n    # and some of those values aren't overwritten.\n    seqlen_q_og = (\n        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded\n    )\n    if query_padding_mask is not None:\n        query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))\n        S_converted = 
S_converted.masked_fill(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), 0.0)\n    seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k\n    if key_padding_mask is not None:\n        key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))\n        S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), 0.0)\n    S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))\n    S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))\n    return S_converted[:, :, :seqlen_q, :seqlen_k]\n\n\ndef normalize_flash_attn_S(\n    attn_unnorm,\n    q,\n    k,\n    v,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    attn_bias=None,\n    is_dropout=False,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n):\n    \"\"\"\n    Arguments:\n        attn_unnorm: (batch_size, nheads, seqlen_q, seqlen_k), unnormalized attention\n        q: (batch_size, seqlen_q, nheads, head_dim)\n        k, v: (batch_size, seqlen_k, nheads, head_dim)\n        query_padding_mask: (batch_size, seqlen_q)\n        key_padding_mask: (batch_size, seqlen_k)\n        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)\n    Output:\n        attn_norm: same shape and dtype as attn_unnorm, rescaled per key block by\n            exp(block_max - lse) so that it is comparable to the softmax output\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    q, k, v = q.float(), k.float(), v.float()\n    _, seqlen_q, _, head_dim = q.shape\n    seqlen_k = k.shape[1]\n    scores = torch.einsum(\"bthd,bshd->bhts\", q / math.sqrt(head_dim), k)\n    if key_padding_mask is not None:\n        scores.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            query_padding_mask,\n            key_padding_mask,\n            q.device,\n        )\n        scores.masked_fill_(local_mask, float(\"-inf\"))\n    if attn_bias is not 
None:\n        scores = scores + attn_bias.to(dtype=scores.dtype)\n    block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)\n    scores_block = scores.split(block_size_n, dim=-1)\n    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)\n    lse = torch.logsumexp(lse_block, dim=-1)\n    # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf\n    # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.\n    lse[lse == float(\"-inf\")] = float(\"inf\")\n    scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)\n    cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)\n    attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)\n    attn_norm = torch.cat(\n        [\n            a * rearrange(torch.exp(m - lse), \"b h s -> b h s 1\")\n            for a, m in zip(attn_unnorm_block, cummax_block)\n        ],\n        dim=-1,\n    )\n    if query_padding_mask is not None:\n        attn_norm.masked_fill_(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), 0.0)\n    return attn_norm.to(dtype=attn_unnorm.dtype)\n\n\ndef get_dropout_fraction(\n    dropout_mask,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n):\n    \"\"\"\n    dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. 
True means keep, False means drop.\n    query_padding_mask: (batch_size, seqlen_q)\n    key_padding_mask: (batch_size, seqlen_k)\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape\n    dropped = ~dropout_mask\n    valid = torch.ones_like(dropout_mask)\n    if query_padding_mask is not None:\n        dropped.masked_fill_(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), False)\n        valid.masked_fill_(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), False)\n    if key_padding_mask is not None:\n        dropped.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), False)\n        valid.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), False)\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            query_padding_mask,\n            key_padding_mask,\n            dropout_mask.device,\n        )\n        dropped.masked_fill_(local_mask, False)\n        valid.masked_fill_(local_mask, False)\n    # Fraction of valid (unpadded, unmasked) positions that were dropped\n    dropped_total = dropped.sum()\n    return dropped_total / valid.sum()\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n# @pytest.mark.parametrize(\"deterministic\", [False])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n# @pytest.mark.parametrize(\"alibi\", [False])\n@pytest.mark.parametrize(\"local\", [False, True])\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [False])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 64, 
96, 128])\n# @pytest.mark.parametrize(\"d\", [64])\n# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])\n@pytest.mark.parametrize(\"seqlen\", [97, 128, 200, 384, 768, 1024, 1025, 2048])\n# @pytest.mark.parametrize(\"seqlen\", [512])\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\n# @pytest.mark.parametrize(\"dropout_p\", [0.0])\ndef test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):\n    if seqlen >= 2048 and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30:\n        pytest.skip()  # Reference implementation OOM\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))\n    qkv = torch.randn(\n        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True\n    )\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)\n    else:\n        alibi_slopes, attn_bias = None, None\n    out, lse, S_dmask = flash_attn_qkvpacked_func(\n        qkv,\n        dropout_p,\n        causal=causal,\n        window_size=window_size,\n        alibi_slopes=alibi_slopes,\n        deterministic=deterministic,\n        return_attn_probs=True,\n    )\n    if dropout_p > 0.0:\n        S_dmask_converted = convert_flash_attn_S_to_softmax(\n            S_dmask,\n            seqlen,\n            seqlen,\n            None,\n            None,\n            d,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_mask = S_dmask_converted >= 0\n        attn_unnorm = S_dmask_converted.abs()\n        attn = normalize_flash_attn_S(\n            attn_unnorm,\n            qkv[:, :, 0],\n            qkv[:, :, 1],\n            qkv[:, :, 2],\n       
     None,\n            None,\n            attn_bias,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_fraction = get_dropout_fraction(\n            dropout_mask, None, None, causal=causal, window_size=window_size\n        ).item()\n        print(f\"Actual dropout fraction: {dropout_fraction}\")\n    else:\n        dropout_mask = None\n\n    out_ref, attn_ref = attention_qkvpacked_ref(\n        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size\n    )\n    out_pt, attn_pt = attention_qkvpacked_ref(\n        qkv,\n        None,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n    # v = qkv[:, :, 2].float()\n    # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()\n    # if causal:\n    #     causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)\n    #     qk.masked_fill_(causal_mask, float('-inf'))\n    # m = qk.amax(-1, keepdim=True)\n    # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n    # p_tmp = torch.softmax(qk / math.sqrt(d), -1)\n    # p_dropped = p_tmp if dropout_mask is None else p_tmp.masked_fill(~dropout_mask, 0)\n    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)\n    # qk_max1 = torch.max(qk[:, :, 128:, 192:], -1, keepdim=True).values\n    # qk_max2 = torch.max(qk[:, :, 128:, 128:], -1, keepdim=True).values\n    # qk_max3 = torch.max(qk[:, :, 128:, 64:], -1, keepdim=True).values\n    # qk_max4 = torch.max(qk[:, :, 128:, :], -1, keepdim=True).values\n    # o1 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 192:] - qk_max1) / math.sqrt(d)), v[:, 192:])\n    # o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:])\n    # o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - 
qk_max3) / math.sqrt(d)), v[:, 64:])\n    # o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :])\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n    if dropout_p > 0.0:\n        print(f\"Attention max diff: {(attn - attn_ref).abs().max().item()}\")\n        print(f\"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}\")\n\n    g = torch.randn_like(out)\n    # do_o = (g.float() * out.float()).sum(-1)\n    # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])\n    # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        (dqkv,) = torch.autograd.grad(out, qkv, g)\n        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)\n        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)\n        print(f\"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}\")\n        print(f\"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}\")\n        print(f\"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}\")\n        print(f\"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}\")\n        print(f\"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n\n    if dropout_p > 0.0:\n        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()\n        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate\n        if not alibi:\n            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)\n\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize('dtype', [torch.float16])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n# @pytest.mark.parametrize(\"deterministic\", [True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n# @pytest.mark.parametrize(\"alibi\", [True])\n@pytest.mark.parametrize(\"local\", [False, True])\n# @pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [False])\n@pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [64])\n@pytest.mark.parametrize(\"seqlen\", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])\n# @pytest.mark.parametrize('seqlen', [128])\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\n# @pytest.mark.parametrize('dropout_p', [0.0])\ndef test_flash_attn_varlen_qkvpacked(\n    seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype\n):\n    if seqlen >= 2048 and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30:\n        pytest.skip()  # Reference implementation OOM\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 5\n    nheads = 6\n    window_size = (-1, 
-1) if not local else torch.randint(0, seqlen, (2,))\n    qkv = torch.randn(\n        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True\n    )\n\n    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode=\"random\")\n    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(\n            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal\n        )\n    else:\n        alibi_slopes, attn_bias = None, None\n\n    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(\n        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True\n    )\n\n    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(\n        qkv_unpad,\n        cu_seqlens,\n        max_seqlen,\n        dropout_p,\n        causal=causal,\n        window_size=window_size,\n        alibi_slopes=alibi_slopes,\n        deterministic=deterministic,\n        return_attn_probs=True,\n    )\n    out = output_pad_fn(out_unpad)\n    if dropout_p > 0.0:\n        S_dmask_converted = convert_flash_attn_S_to_softmax(\n            S_dmask,\n            seqlen,\n            seqlen,\n            key_padding_mask,\n            key_padding_mask,\n            d,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_mask = S_dmask_converted >= 0\n        attn_unnorm = S_dmask_converted.abs()\n        attn = normalize_flash_attn_S(\n            attn_unnorm,\n            qkv[:, :, 0],\n            qkv[:, :, 1],\n            qkv[:, :, 2],\n            key_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        
)\n        dropout_fraction = get_dropout_fraction(\n            dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size\n        ).item()\n        print(f\"Actual dropout fraction: {dropout_fraction}\")\n    else:\n        dropout_mask = None\n\n    out_ref, attn_ref = attention_qkvpacked_ref(\n        qkv,\n        key_padding_mask,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        causal=causal,\n        window_size=window_size,\n    )\n    out_pt, attn_pt = attention_qkvpacked_ref(\n        qkv,\n        key_padding_mask,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n    if dropout_p > 0.0:\n        print(f\"Attention max diff: {(attn - attn_ref).abs().max().item()}\")\n        print(f\"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}\")\n\n    g = torch.randn_like(out)\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)\n        dqkv = dqkv_pad_fn(dqkv_unpad)\n        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)\n        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)\n        print(f\"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}\")\n        print(f\"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}\")\n        print(f\"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}\")\n        print(f\"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - 
dqkv_ref[:, :, 0]).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}\")\n        print(f\"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n\n    if dropout_p > 0.0:\n        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()\n        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate\n        if not alibi:\n            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)\n\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()\n\n\n@pytest.mark.parametrize(\"kvpacked\", [True, False])\n# @pytest.mark.parametrize(\"kvpacked\", [False])\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n# @pytest.mark.parametrize(\"deterministic\", [True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n# @pytest.mark.parametrize(\"alibi\", [False])\n@pytest.mark.parametrize(\"local\", [False, True])\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 
64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (2048, 2048),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\n# @pytest.mark.parametrize(\"dropout_p\", [0.0])\n@pytest.mark.parametrize(\"softcap\", [0.0, 50.0])\ndef test_flash_attn_output(\n    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap\n):\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if softcap > 0.0 and dropout_p > 0.0:\n        pytest.skip(\"Softcap and dropout not supported together\")\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 2)\n    assert nheads % nheads_k == 0\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    if softcap > 0:\n        # Ensure the values of qk are at least within softcap range.\n        q = q * softcap\n    if kvpacked:\n        kv = torch.randn(\n            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n    else:\n        k = torch.randn(\n 
           batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n        v = torch.randn(\n            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)\n    else:\n        alibi_slopes, attn_bias = None, None\n\n    if kvpacked:\n        out, lse, S_dmask = flash_attn_kvpacked_func(\n            q,\n            kv,\n            dropout_p,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            deterministic=deterministic,\n            return_attn_probs=True,\n        )\n    else:\n        out, lse, S_dmask = flash_attn_func(\n            q,\n            k,\n            v,\n            dropout_p,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            deterministic=deterministic,\n            return_attn_probs=True,\n        )\n    if dropout_p > 0.0:\n        S_dmask_converted = convert_flash_attn_S_to_softmax(\n            S_dmask,\n            seqlen_q,\n            seqlen_k,\n            None,\n            None,\n            d,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_mask = S_dmask_converted >= 0\n        attn_unnorm = S_dmask_converted.abs()\n        if kvpacked:\n            kv_rep = repeat(kv, \"b s two h d -> b s two (h g) d\", g=nheads // nheads_k)\n            k_rep, v_rep = kv_rep.unbind(dim=2)\n        else:\n            k_rep = repeat(k, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n            v_rep = repeat(v, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n        attn = 
normalize_flash_attn_S(\n            attn_unnorm,\n            q,\n            k_rep,\n            v_rep,\n            None,\n            None,\n            attn_bias,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_fraction = get_dropout_fraction(\n            dropout_mask, None, None, causal=causal, window_size=window_size\n        ).item()\n        print(f\"Actual dropout fraction: {dropout_fraction}\")\n    else:\n        dropout_mask = None\n\n    if kvpacked:\n        out_ref, attn_ref = attention_kvpacked_ref(\n            q,\n            kv,\n            None,\n            None,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n        )\n        out_pt, attn_pt = attention_kvpacked_ref(\n            q,\n            kv,\n            None,\n            None,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n        )\n    else:\n        out_ref, attn_ref = attention_ref(\n            q,\n            k,\n            v,\n            None,\n            None,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n        )\n        out_pt, attn_pt = attention_ref(\n            q,\n            k,\n            v,\n            None,\n            None,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n        )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output 
mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n    if dropout_p > 0.0:\n        print(f\"Attention max diff: {(attn - attn_ref).abs().max().item()}\")\n        print(f\"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}\")\n\n    g = torch.randn_like(out)\n    do_o = (g.float() * out.float()).sum(-1)\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        if kvpacked:\n            (\n                dq,\n                dkv,\n            ) = torch.autograd.grad(out, (q, kv), g)\n            dk, dv = dkv.unbind(2)\n            (\n                dq_ref,\n                dkv_ref,\n            ) = torch.autograd.grad(out_ref, (q, kv), g)\n            dk_ref, dv_ref = dkv_ref.unbind(2)\n            (\n                dq_pt,\n                dkv_pt,\n            ) = torch.autograd.grad(out_pt, (q, kv), g)\n            dk_pt, dv_pt = dkv_pt.unbind(2)\n        else:\n            (\n                dq,\n                dk,\n                dv,\n            ) = torch.autograd.grad(out, (q, k, v), g)\n            (\n                dq_ref,\n                dk_ref,\n                dv_ref,\n            ) = torch.autograd.grad(out_ref, (q, k, v), g)\n            (\n                dq_pt,\n                dk_pt,\n                dv_pt,\n            ) = torch.autograd.grad(out_pt, (q, k, v), g)\n        print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n        print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n        print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n        print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n        print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n        print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dq_pt - 
dq_ref).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n        print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n        print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n        print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n\n    if dropout_p > 0.0:\n        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()\n        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate\n        if not alibi:\n            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)\n\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()\n        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()\n        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()\n\n\n@pytest.mark.parametrize(\"kvpacked\", [True, False])\n# @pytest.mark.parametrize('kvpacked', [False])\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize('dtype', [torch.float16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize('mha_type', [\"mqa\"])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n# @pytest.mark.parametrize(\"deterministic\", [True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n# @pytest.mark.parametrize(\"alibi\", [True])\n@pytest.mark.parametrize(\"local\", [False, True])\n# 
@pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [True])\n@pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [64])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 147),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (2048, 2048),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\n@pytest.mark.parametrize(\"softcap\", [0.0, 50.0])\n# @pytest.mark.parametrize('dropout_p', [0.0])\ndef test_flash_attn_varlen_output(\n    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap\n):\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if softcap > 0.0 and dropout_p > 0.0:\n        pytest.skip(\"Softcap and dropout not supported together\")\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 2)\n    assert nheads % nheads_k == 0\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    if softcap > 0:\n        # Ensure the values of qk are at least within softcap range.\n        q = q * softcap\n\n    if kvpacked:\n        kv = torch.randn(\n            batch_size, 
seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n    else:\n        k = torch.randn(\n            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n        v = torch.randn(\n            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n\n    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode=\"random\")\n    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode=\"random\")\n    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(\n            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal\n        )\n    else:\n        alibi_slopes, attn_bias = None, None\n\n    if kvpacked:\n        (\n            q_unpad,\n            kv_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q,\n            kv,\n            output_pad_fn,\n            dq_pad_fn,\n            dkv_pad_fn,\n        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)\n        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(\n            q_unpad,\n            kv_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            dropout_p,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            deterministic=deterministic,\n            return_attn_probs=True,\n        )\n    else:\n        (\n            q_unpad,\n            k_unpad,\n            v_unpad,\n            cu_seqlens_q,\n            
cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q,\n            k,\n            v,\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)\n        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(\n            q_unpad,\n            k_unpad,\n            v_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            dropout_p,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            deterministic=deterministic,\n            return_attn_probs=True,\n        )\n    out = output_pad_fn(out_unpad)\n    if dropout_p > 0.0:\n        S_dmask_converted = convert_flash_attn_S_to_softmax(\n            S_dmask,\n            seqlen_q,\n            seqlen_k,\n            query_padding_mask,\n            key_padding_mask,\n            d,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_mask = S_dmask_converted >= 0\n        attn_unnorm = S_dmask_converted.abs()\n        if kvpacked:\n            kv_rep = repeat(kv, \"b s two h d -> b s two (h g) d\", g=nheads // nheads_k)\n            k_rep, v_rep = kv_rep.unbind(dim=2)\n        else:\n            k_rep = repeat(k, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n            v_rep = repeat(v, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n        attn = normalize_flash_attn_S(\n            attn_unnorm,\n            q,\n            k_rep,\n            v_rep,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_fraction = get_dropout_fraction(\n            dropout_mask,\n            
query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            window_size=window_size,\n        ).item()\n        print(f\"Actual dropout fraction: {dropout_fraction}\")\n    else:\n        dropout_mask = None\n\n    if kvpacked:\n        out_ref, attn_ref = attention_kvpacked_ref(\n            q,\n            kv,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n        )\n        out_pt, attn_pt = attention_kvpacked_ref(\n            q,\n            kv,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n        )\n    else:\n        out_ref, attn_ref = attention_ref(\n            q,\n            k,\n            v,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n        )\n        out_pt, attn_pt = attention_ref(\n            q,\n            k,\n            v,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n        )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - 
out_ref).abs().mean().item()}\")\n    if dropout_p > 0.0:\n        print(f\"Attention max diff: {(attn - attn_ref).abs().max().item()}\")\n        print(f\"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}\")\n\n    g = torch.randn_like(out)\n    if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)):\n        if kvpacked:\n            (\n                dq_unpad,\n                dkv_unpad,\n            ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)\n            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)\n            (\n                dq_ref,\n                dkv_ref,\n            ) = torch.autograd.grad(out_ref, (q, kv), g)\n            dk_ref, dv_ref = dkv_ref.unbind(2)\n            (\n                dq_pt,\n                dkv_pt,\n            ) = torch.autograd.grad(out_pt, (q, kv), g)\n            dk_pt, dv_pt = dkv_pt.unbind(2)\n        else:\n            (\n                dq_unpad,\n                dk_unpad,\n                dv_unpad,\n            ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)\n            dk = dk_pad_fn(dk_unpad)\n            dv = dk_pad_fn(dv_unpad)\n            (\n                dq_ref,\n                dk_ref,\n                dv_ref,\n            ) = torch.autograd.grad(out_ref, (q, k, v), g)\n            (\n                dq_pt,\n                dk_pt,\n                dv_pt,\n            ) = torch.autograd.grad(out_pt, (q, k, v), g)\n        dq = dq_pad_fn(dq_unpad)\n        print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n        print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n        print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n        print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n        print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n        print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n        
print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n        print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n        print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n        print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n\n    if dropout_p > 0.0:\n        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()\n        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate\n        if not alibi:\n            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)\n\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()\n        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()\n        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"local\", [False, True])\n# @pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64, 128])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False, 
True])\n# @pytest.mark.parametrize(\"swap_sq_sk\", [True])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (3, 799),\n        (127, 512),\n        (127, 513),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (1023, 1024),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\ndef test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    causal = True\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)\n    out_ref, attn_ref = attention_ref(\n        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size\n    )\n    out_pt, attn_pt = attention_ref(\n        q,\n        k,\n        v,\n        None,\n        None,\n        None,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt 
- out_ref).abs().mean().item()}\")\n\n    g = torch.randn_like(out)\n    do_o = (g.float() * out.float()).sum(-1)\n    (\n        dq,\n        dk,\n        dv,\n    ) = torch.autograd.grad(out, (q, k, v), g)\n    (\n        dq_ref,\n        dk_ref,\n        dv_ref,\n    ) = torch.autograd.grad(out_ref, (q, k, v), g)\n    (\n        dq_pt,\n        dk_pt,\n        dv_pt,\n    ) = torch.autograd.grad(out_pt, (q, k, v), g)\n    print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n    print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n    print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n    print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n    print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n    print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n    print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n    print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n    print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n    print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n    print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n    print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5\n\n    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5\n    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5\n    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"local\", 
[False, True])\n# @pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False, True])\n# @pytest.mark.parametrize(\"swap_sq_sk\", [True])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (3, 799),\n        (127, 512),\n        (127, 513),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (1023, 1024),\n    ],\n)\n# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged\n@pytest.mark.parametrize(\"paged_kv_block_size\", [None, 256, 512])\n# @pytest.mark.parametrize(\"seqlen_q,seqlen_k\", [(256, 128)])\ndef test_flash_attn_varlen_causal(\n    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype\n):\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    causal = True\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n\n    if paged_kv_block_size is None:\n        k = torch.randn(\n            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True\n        )\n        v = torch.randn(\n            batch_size, seqlen_k, nheads, d, 
device=device, dtype=dtype, requires_grad=True\n        )\n        block_table = None\n    else:\n        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(\n            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype\n        )\n    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode=\"random\")\n    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode=\"random\")\n    (\n        q_unpad,\n        k_unpad,\n        v_unpad,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        q,\n        k,\n        v,\n        output_pad_fn,\n        dq_pad_fn,\n        dk_pad_fn,\n    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)\n    out_unpad = flash_attn_varlen_func(\n        q_unpad,\n        k_unpad if paged_kv_block_size is None else k_cache_paged,\n        v_unpad if paged_kv_block_size is None else v_cache_paged,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        0.0,\n        causal=causal,\n        window_size=window_size,\n        block_table=block_table,\n    )\n    out = output_pad_fn(out_unpad)\n    out_ref, attn_ref = attention_ref(\n        q,\n        k,\n        v,\n        query_padding_mask,\n        key_padding_mask,\n        None,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n    )\n    out_pt, attn_pt = attention_ref(\n        q,\n        k,\n        v,\n        query_padding_mask,\n        key_padding_mask,\n        None,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - 
out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    g = torch.randn_like(out)\n    do_o = (g.float() * out.float()).sum(-1)\n    test_backward = block_table is None\n    if test_backward:\n        (\n            dq_unpad,\n            dk_unpad,\n            dv_unpad,\n        ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)\n        dq = dq_pad_fn(dq_unpad)\n        dk = dk_pad_fn(dk_unpad)\n        dv = dk_pad_fn(dv_unpad)\n        (\n            dq_ref,\n            dk_ref,\n            dv_ref,\n        ) = torch.autograd.grad(out_ref, (q, k, v), g)\n        (\n            dq_pt,\n            dk_pt,\n            dv_pt,\n        ) = torch.autograd.grad(out_pt, (q, k, v), g)\n        print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n        print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n        print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n        print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n        print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n        print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n        print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n        print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n        print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5\n\n    if test_backward:\n        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - 
dq_ref).abs().max().item() + 1e-5\n        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5\n        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n# @pytest.mark.parametrize(\"deterministic\", [True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n# @pytest.mark.parametrize(\"alibi\", [True])\n@pytest.mark.parametrize(\"local\", [False, True])\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False, True])\n# @pytest.mark.parametrize(\"swap_sq_sk\", [False])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (3, 1024),\n        (1, 339),\n        (64, 800),\n        (3, 799),\n        (64, 2048),\n        (16, 20000),\n        (16, 100000),\n        (128, 128),\n        (256, 256),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\ndef test_flash_attn_splitkv(\n    seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype\n):\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 1\n    nheads = 12\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)\n    else:\n        alibi_slopes, attn_bias = None, None\n    out, lse, _ = flash_attn_func(\n        q,\n        k,\n        v,\n        0.0,\n        causal=causal,\n        window_size=window_size,\n        alibi_slopes=alibi_slopes,\n        deterministic=deterministic,\n        return_attn_probs=True,\n    )\n    out_ref, attn_ref = attention_ref(\n        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size\n    )\n    out_pt, attn_pt = attention_ref(\n        q,\n        k,\n        v,\n        None,\n        None,\n        attn_bias,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    g = torch.randn_like(out)\n    do_o = (g.float() * out.float()).sum(-1)\n    (\n        dq,\n        dk,\n        dv,\n    ) = torch.autograd.grad(out, (q, k, v), g)\n    (\n        dq_ref,\n        dk_ref,\n        dv_ref,\n    ) = torch.autograd.grad(out_ref, (q, k, v), g)\n    (\n        dq_pt,\n        dk_pt,\n        dv_pt,\n    ) = torch.autograd.grad(out_pt, (q, k, v), g)\n    print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n    
print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n    print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n    print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n    print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n    print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n    print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n    print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n    print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n    print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n    print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n    print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5\n\n    mult = 2 if not alibi else 8\n    assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4\n    assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4\n    assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4\n\n\n# @pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"num_splits\", [1, 0])\n# @pytest.mark.parametrize(\"num_splits\", [1])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n@pytest.mark.parametrize(\"new_kv\", [False, True])\n# @pytest.mark.parametrize(\"new_kv\", [False])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n# @pytest.mark.parametrize(\"alibi\", [False])\n@pytest.mark.parametrize(\"local\", [False, True])\n# 
@pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [False])\n@pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [True, False])\n# @pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [True])\n@pytest.mark.parametrize(\"rotary_interleaved\", [False, True])\n# @pytest.mark.parametrize(\"rotary_interleaved\", [False])\n@pytest.mark.parametrize(\"rotary_fraction\", [0.0, 0.5, 1.0])\n# @pytest.mark.parametrize(\"rotary_fraction\", [0.0])\n@pytest.mark.parametrize(\"paged_kv_block_size\", [None, 256])\n# @pytest.mark.parametrize(\"paged_kv_block_size\", [256, 512])\n# @pytest.mark.parametrize(\"paged_kv_block_size\", [None])\n@pytest.mark.parametrize(\"has_leftpad\", [False, True])\n# @pytest.mark.parametrize(\"has_leftpad\", [True])\n# @pytest.mark.parametrize(\"has_batch_idx\", [False, True])\n@pytest.mark.parametrize(\"has_batch_idx\", [False])\n@pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 128, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 128),\n        (1, 339),\n        (3, 1024),\n        (64, 800),\n        (64, 256),\n        (3, 799),\n        (64, 2048),\n        (16, 20000),\n        (1, 128 * 1024),\n        (16, 128 * 1024),\n        (128, 128),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\ndef test_flash_attn_kvcache(\n    seqlen_q,\n    seqlen_k,\n    d,\n    has_batch_idx,\n    has_leftpad,\n    paged_kv_block_size,\n    rotary_fraction,\n    rotary_interleaved,\n    seqlen_new_eq_seqlen_q,\n    causal,\n    local,\n    alibi,\n    new_kv,\n    mha_type,\n    num_splits,\n    dtype,\n):\n    if seqlen_q > seqlen_k and new_kv:\n        pytest.skip()\n    if not new_kv and 
rotary_fraction > 0.0:\n        pytest.skip()\n    if has_batch_idx and paged_kv_block_size is not None:\n        pytest.skip()\n    if has_leftpad and paged_kv_block_size is not None:\n        pytest.skip()\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2\n    nheads = 6\n    # rotary_dim must be a multiple of 16, and must be <= d\n    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 3)\n    assert nheads % nheads_k == 0\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)\n    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()\n    if new_kv:\n        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)\n        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)\n    else:\n        k, v = None, None\n    if paged_kv_block_size is None:\n        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)\n        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)\n        block_table = None\n    else:\n        (\n            k_cache,\n            v_cache,\n            block_table,\n            k_cache_paged,\n            v_cache_paged,\n            num_blocks,\n        ) = _generate_block_kvcache(\n            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype\n        )\n    cache_seqlens = torch.randint(\n        0 if new_kv else 1,\n        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough\n        (\n            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)\n            if 
new_kv\n            else (seqlen_k + 1)\n        ),\n        (batch_size,),\n        dtype=torch.int32,\n        device=device,\n    )\n    if has_leftpad:\n        cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)\n                                   if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)\n                                   for i in range(batch_size)])\n    else:\n        cache_leftpad = None\n    arange = rearrange(torch.arange(seqlen_k, device=device), \"s -> 1 s\")\n    cache_seqlens_expanded = rearrange(cache_seqlens, \"b -> b 1\")\n    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)\n    if has_leftpad:\n        key_padding_mask = torch.logical_and(\n            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)\n        )\n    if has_batch_idx:\n        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[\n            :batch_size\n        ]\n    else:\n        cache_batch_idx = None\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(\n            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad\n        )\n    else:\n        alibi_slopes, attn_bias = None, None\n    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)\n    if rotary_dim > 0:\n        angle = (\n            torch.rand(\n                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,\n                rotary_dim // 2,\n                device=device,\n            )\n            * 2\n            * math.pi\n        )\n        cos = torch.cos(angle).to(dtype=dtype)\n        sin = torch.sin(angle).to(dtype=dtype)\n        if causal or local:\n            q_ro = apply_rotary_emb(\n                q, 
cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved\n            )\n        else:\n            q_ro = rearrange(\n                apply_rotary_emb(\n                    rearrange(q, \"b s h d -> b 1 (s h) d\"),\n                    cos,\n                    sin,\n                    seqlen_offsets=cache_seqlens,\n                    interleaved=rotary_interleaved,\n                ),\n                \"b 1 (s h) d -> b s h d\",\n                s=seqlen_q,\n            )\n        # q_ro = q\n        k_ro = apply_rotary_emb(\n            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved\n        )\n    else:\n        cos, sin = None, None\n        q_ro, k_ro = q, k\n    # k_cache[:, 64:] = -1\n    k_cache_ref = (\n        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]\n    ).clone()\n    v_cache_ref = (\n        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]\n    ).clone()\n    if new_kv:\n        update_mask = torch.logical_and(\n            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new\n        )\n        k_cache_ref[update_mask] = rearrange(k_ro, \"b s ... -> (b s) ...\")\n        v_cache_ref[update_mask] = rearrange(v, \"b s ... 
-> (b s) ...\")\n    k_cache_rep = repeat(k_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n    v_cache_rep = repeat(v_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n    out = flash_attn_with_kvcache(\n        q,\n        k_cache if paged_kv_block_size is None else k_cache_paged,\n        v_cache if paged_kv_block_size is None else v_cache_paged,\n        k,\n        v,\n        rotary_cos=cos,\n        rotary_sin=sin,\n        cache_seqlens=cache_seqlens,\n        cache_batch_idx=cache_batch_idx,\n        cache_leftpad=cache_leftpad,\n        block_table=block_table,\n        causal=causal,\n        window_size=window_size,\n        rotary_interleaved=rotary_interleaved,\n        alibi_slopes=alibi_slopes,\n        num_splits=num_splits,\n    )\n    # out = flash_attn_with_kvcache(\n    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size\n    # )\n    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)\n    # qk = torch.einsum(\"bqhd,bkhd->bhqk\", q, k_cache_ref)\n    # m = qk.amax(-1, keepdim=True)\n    # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)\n    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)\n    # probs = torch.softmax(qk, dim=-1)\n    out_ref, _ = attention_ref(\n        q_ro,\n        k_cache_rep,\n        v_cache_rep,\n        None,\n        key_padding_mask,\n        attn_bias,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        key_leftpad=cache_leftpad,\n    )\n    out_pt, _ = attention_ref(\n        q_ro,\n        k_cache_rep,\n        v_cache_rep,\n        None,\n        key_padding_mask,\n        attn_bias,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n        key_leftpad=cache_leftpad,\n    )\n    print(f\"Output max diff: {(out - 
out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    # If new K/V were appended, check that the updated cache matches the reference cache.\n    # Then check that FlashAttention's numerical error is at most a small multiple (mult)\n    # of the numerical error of a Pytorch implementation.\n    if new_kv:\n        if paged_kv_block_size is None:\n            k_cache_select = (\n                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]\n            )\n            v_cache_select = (\n                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]\n            )\n        else:\n            k_cache_select = rearrange(\n                k_cache_paged[block_table.to(dtype=torch.long).flatten()],\n                \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n                b=batch_size,\n            )[:, :seqlen_k]\n            v_cache_select = rearrange(\n                v_cache_paged[block_table.to(dtype=torch.long).flatten()],\n                \"(b nblocks) block_size ... 
-> b (nblocks block_size) ...\",\n                b=batch_size,\n            )[:, :seqlen_k]\n        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)\n        assert torch.equal(v_cache_select, v_cache_ref)\n    mult = 3 if not alibi else 5\n    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5\n\n\ndef _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):\n    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3\n    k_cache_paged = torch.randn(\n        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype\n    )\n    v_cache_paged = torch.randn(\n        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype\n    )\n    block_table = rearrange(\n        torch.randperm(num_blocks, dtype=torch.int32, device=device),\n        \"(b nblocks) -> b nblocks\",\n        b=batch_size,\n    )\n    k_cache = rearrange(\n        # pytorch 1.12 doesn't have indexing with int32\n        k_cache_paged[block_table.to(dtype=torch.long).flatten()],\n        \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    v_cache = rearrange(\n        v_cache_paged[block_table.to(dtype=torch.long).flatten()],\n        \"(b nblocks) block_size ... 
-> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks\n\n\n# @pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [128])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (239, 1),\n        (3, 799),\n        (799, 3),\n        (1024, 128),\n        (97, 97),\n        (128, 128),\n        (200, 200),\n        (256, 256),\n        (257, 257),\n        (384, 384),\n        (512, 512),\n        (768, 768),\n        (1024, 1024),\n    ],\n)\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\n# @pytest.mark.parametrize(\"dropout_p\", [0.0])\ndef test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger\n    nheads = 4\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    torch.random.manual_seed(42)\n    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)\n    g = torch.randn_like(out0)\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        (\n            dq0,\n            
dk0,\n            dv0,\n        ) = torch.autograd.grad(out0, (q, k, v), g)\n        # Numerical error if we just do any arithmetic on dq\n        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()\n\n    for i in range(250):\n        torch.random.manual_seed(42)\n        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)\n        assert torch.equal(out, out0)\n        assert torch.equal(lse, lse0)\n\n        if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n            (\n                dq,\n                dk,\n                dv,\n            ) = torch.autograd.grad(out, (q, k, v), g)\n            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)\n            if not dq_equal:\n                print(f\"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}\")\n            assert torch.equal(dv, dv0)\n            assert torch.equal(dk, dk0)\n            assert dq_equal\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [False])\n@pytest.mark.parametrize(\"d\", [16, 32, 64])\n# @pytest.mark.parametrize('d', [16])\n@pytest.mark.parametrize(\"seqlen\", [1, 2, 5, 17, 128])\n# @pytest.mark.parametrize('seqlen', [2])\ndef test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):\n    \"\"\"We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,\n    in the case where seqlen % 128 != 0.\n    \"\"\"\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    nheads = 5\n    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device=\"cuda\") * 5\n    k, v = [\n        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device=\"cuda\") * 3\n        for _ in range(2)\n    ]\n    q.requires_grad_(True)\n    k.requires_grad_(True)\n    v.requires_grad_(True)\n    out = flash_attn_func(q, k, v, causal=causal)\n    g = 
torch.randn_like(out)\n    out.backward(g)\n    q_pt = q.detach().clone().requires_grad_(True)\n    k_pt = k.detach().clone().requires_grad_(True)\n    v_pt = v.detach().clone().requires_grad_(True)\n    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)\n    out_pt.backward(g)\n    q_ref = q.detach().clone().requires_grad_(True)\n    k_ref = k.detach().clone().requires_grad_(True)\n    v_ref = v.detach().clone().requires_grad_(True)\n    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)\n    out_ref.backward(g)\n    print(f\"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}\")\n    print(f\"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}\")\n    print(f\"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}\")\n    print(f\"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}\")\n    print(f\"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}\")\n    print(f\"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}\")\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n    assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (\n        q_pt.grad - q_ref.grad\n    ).abs().max().item() + 1e-3\n    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (\n        k_pt.grad - k_ref.grad\n    ).abs().max().item() + 1e-3\n    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (\n        v_pt.grad - v_ref.grad\n    ).abs().max().item() + 1e-3\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize('dtype', [torch.bfloat16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [False])\n@pytest.mark.parametrize(\"d\", [64, 128])\n# @pytest.mark.parametrize('d', [64])\n@pytest.mark.parametrize(\"seqlen\", [97, 128, 200, 256])\n# @pytest.mark.parametrize('seqlen', [128])\ndef 
test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):\n    \"\"\"We previously had a bug where we were using the wrong strides of dout, which shows up\n    when dout is not contiguous.\n    \"\"\"\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 5\n    nheads = 2\n    q, k, v = [\n        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device=\"cuda\", requires_grad=True)\n        for _ in range(3)\n    ]\n    out = rearrange(flash_attn_func(q, k, v, causal=causal), \"b s ... -> s b ...\")\n    # So g is not contiguous\n    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device=\"cuda\")[:, ::2]\n    out.backward(g)\n    q_pt = q.detach().clone().requires_grad_(True)\n    k_pt = k.detach().clone().requires_grad_(True)\n    v_pt = v.detach().clone().requires_grad_(True)\n    out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)\n    out_pt = rearrange(out_pt, \"b s ... -> s b ...\")\n    out_pt.backward(g)\n    q_ref = q.detach().clone().requires_grad_(True)\n    k_ref = k.detach().clone().requires_grad_(True)\n    v_ref = v.detach().clone().requires_grad_(True)\n    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)\n    out_ref = rearrange(out_ref, \"b s ... 
-> s b ...\")\n    out_ref.backward(g)\n    print(f\"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}\")\n    print(f\"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}\")\n    print(f\"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}\")\n    print(f\"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}\")\n    print(f\"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}\")\n    print(f\"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}\")\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n    assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (\n        q_pt.grad - q_ref.grad\n    ).abs().max().item()\n    assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (\n        k_pt.grad - k_ref.grad\n    ).abs().max().item()\n    assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (\n        v_pt.grad - v_ref.grad\n    ).abs().max().item()\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [False])\n@pytest.mark.parametrize(\"d\", [16, 32, 64])\n# @pytest.mark.parametrize('d', [16])\ndef test_flash_attn_bwd_varlen_overflow(d, causal, dtype):\n    \"\"\"We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,\n    in the case where seqlen % 128 != 0 or varlen.\n    \"\"\"\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    nheads = 5\n    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)\n    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)\n    Mq = 256\n    Mk = 3\n\n    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3\n    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]\n    q.requires_grad_(True)\n    k.requires_grad_(True)\n    v.requires_grad_(True)\n\n    out = flash_attn_varlen_func(q, k, v, 
q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)\n    g = torch.randn_like(out)\n    out.backward(g)\n\n    assert not q.grad.isnan().any()\n    assert not k.grad.isnan().any()\n    assert not v.grad.isnan().any()\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"local\", [False, True])\n# @pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False, True])\n# @pytest.mark.parametrize(\"swap_sq_sk\", [False])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (3, 799),\n        (127, 512),\n        (127, 513),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (1023, 1024),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\ndef test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, 
nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)\n\n    g = torch.randn_like(out)\n    dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)\n    for _ in range(50):\n        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)\n        assert torch.equal(dv, dv0)\n        assert torch.equal(dk, dk0)\n        assert torch.equal(dq, dq0)\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"local\", [False, True])\n# @pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False, True])\n# @pytest.mark.parametrize(\"swap_sq_sk\", [True])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (3, 799),\n        (127, 512),\n        (127, 513),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (1023, 1024),\n    ],\n)\n# @pytest.mark.parametrize(\"seqlen_q,seqlen_k\", [(256, 128)])\ndef test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):\n    if 
(\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode=\"random\")\n    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode=\"random\")\n    (\n        q_unpad,\n        k_unpad,\n        v_unpad,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        q,\n        k,\n        v,\n        output_pad_fn,\n        dq_pad_fn,\n        dk_pad_fn,\n    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)\n    out = flash_attn_varlen_func(\n        q_unpad,\n        k_unpad,\n        v_unpad,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        0.0,\n        causal=causal,\n        window_size=window_size,\n        deterministic=True,\n    )\n\n    g = torch.randn_like(out)\n    dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)\n    for _ in range(50):\n        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)\n        assert torch.equal(dv, dv0)\n        assert torch.equal(dk, dk0)\n        assert torch.equal(dq, dq0)\n"
  },
  {
    "path": "tests/test_flash_attn_ck.py",
    "content": "import math\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom flash_attn import (\n    flash_attn_func,\n    flash_attn_kvpacked_func,\n    flash_attn_qkvpacked_func,\n    flash_attn_varlen_func,\n    flash_attn_varlen_kvpacked_func,\n    flash_attn_varlen_qkvpacked_func,\n    flash_attn_with_kvcache,\n)\n\nfrom test_flash_attn import (\n    attn_bias_from_alibi_slopes,\n    convert_flash_attn_S_to_softmax,\n    generate_qkv,\n    generate_random_padding_mask,\n    _generate_block_kvcache,\n    attention_ref,\n    attention_kvpacked_ref,\n    attention_qkvpacked_ref,\n)\n\nfrom flash_attn.layers.rotary import apply_rotary_emb\n\ndef is_bwd_hdim_supported(d):\n    return d <= 256\n\n\ndef ck_randval_to_dropout_mask(randval, p):\n    # If p = 0.3, randval in 255 * (0.7, 1.0] will be dropout\n    # randval in 255 * [0, 0.7] will be kept\n    # If return dropout_mask >=0, value will be kept\n    return math.floor(255.0 * (1 - p)) - randval.to(torch.float32)\n\n\ndef pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_rounded, seqlen_k_rounded):\n    \"\"\" pad + rearrange [nheads, total_q, max_seqlen_k] into [b, nheads, seqlen_q_rounded, seqlen_k_rounded]\n    Arguments:\n        S_dmask: (nheads, total_q, max_seqlen_k)\n        cu_seqlens_q: (b + 1)\n    Output:\n        S_dmask: (b, nheads, seqlen_q_rounded, seqlen_k_rounded)\n    \"\"\"\n    batch_size = cu_seqlens_q.numel() - 1\n    seqlens_q = torch.roll(cu_seqlens_q, shifts = -1) - cu_seqlens_q\n    seqlens_q = seqlens_q[0:batch_size].tolist()\n    S_dmask = torch.split(S_dmask, seqlens_q, dim=1)\n    # [(nheads, seqlen_q0, max_seqlen_k), (nheads, seqlen_q1, max_seqlen_k), ..., (nheads, seqlen_qb, max_seqlen_k)]\n    masks = ()\n    for mask in S_dmask:\n        # (nheads, seqlen_qi, max_seqlen_k) -> (nheads, seqlen_q_rounded, seqlen_k_rounded)\n        mask = F.pad(mask, (0, seqlen_k_rounded - mask.shape[2], 0, 
seqlen_q_rounded - mask.shape[1], 0, 0)).unsqueeze(1)\n        masks = masks + (mask, )\n    S_dmask = torch.cat(masks, dim=1)\n\n    S_dmask = S_dmask.transpose(0, 1)\n    return S_dmask\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n@pytest.mark.parametrize(\"local\", [False, True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n@pytest.mark.parametrize(\"seqlen\", [97, 128, 200, 384, 768, 1024, 1025, 2048])\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\ndef test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):\n    if d > 256:\n        pytest.skip()\n\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))\n\n    qkv = torch.randn(\n        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True\n    )\n\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)\n    else:\n        alibi_slopes, attn_bias = None, None\n    out, lse, S_dmask = flash_attn_qkvpacked_func(\n        qkv,\n        dropout_p,\n        causal=causal,\n        window_size=window_size,\n        alibi_slopes=alibi_slopes,\n        deterministic=deterministic,\n        return_attn_probs=True,\n    )\n    if dropout_p > 0.0:\n        # TODO - move to c++ mha_varlen_fwd()\n        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)\n        S_dmask_converted = convert_flash_attn_S_to_softmax(\n            S_dmask,\n            seqlen,\n            seqlen,\n            None,\n            None,\n            d,\n      
      dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_mask = S_dmask_converted >= 0\n        # CK does not return P. Hence, we don't test the attn here.\n    else:\n        dropout_mask = None\n\n    out_ref, attn_ref = attention_qkvpacked_ref(\n        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size\n    )\n    out_pt, attn_pt = attention_qkvpacked_ref(\n        qkv,\n        None,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n\n    g = torch.randn_like(out)\n    if is_bwd_hdim_supported(d):\n        (dqkv,) = torch.autograd.grad(out, qkv, g)\n        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)\n        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)\n        print(f\"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}\")\n        print(f\"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}\")\n        print(f\"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}\")\n        print(f\"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}\")\n        print(f\"dV 
Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}\")\n        print(f\"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}\")\n\n        # TODO - use 10 times to check, wait for ck to fix bwd precision issue\n        assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n@pytest.mark.parametrize(\"local\", [False, True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])\n@pytest.mark.parametrize(\"seqlen\", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])\n@pytest.mark.parametrize(\"dropout_p\", [0, 0.17])\ndef test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):\n    if d > 256:\n        pytest.skip()\n\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 5\n    nheads = 6\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))\n    qkv = torch.randn(\n        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True\n    )\n\n    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode=\"random\")\n    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(\n            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal\n        )\n    else:\n        alibi_slopes, attn_bias = None, None\n\n    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(\n        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, 
qkvpacked=True\n    )\n\n    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(\n        qkv_unpad,\n        cu_seqlens,\n        max_seqlen,\n        dropout_p,\n        causal=causal,\n        window_size=window_size,\n        alibi_slopes=alibi_slopes,\n        deterministic=deterministic,\n        return_attn_probs=True,\n    )\n    out = output_pad_fn(out_unpad)\n    if dropout_p > 0.0:\n        # TODO - move to c++ mha_varlen_fwd()\n        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)\n        S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens, seqlen, seqlen)\n\n        S_dmask_converted = convert_flash_attn_S_to_softmax(\n            S_dmask,\n            seqlen,\n            seqlen,\n            key_padding_mask,\n            key_padding_mask,\n            d,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n\n        dropout_mask = S_dmask_converted >= 0\n        # CK does not return P. 
Hence, we don't test the attn here.\n    else:\n        dropout_mask = None\n\n    out_ref, attn_ref = attention_qkvpacked_ref(\n        qkv,\n        key_padding_mask,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        causal=causal,\n        window_size=window_size,\n    )\n    out_pt, attn_pt = attention_qkvpacked_ref(\n        qkv,\n        key_padding_mask,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n\n    g = torch.randn_like(out)\n    if is_bwd_hdim_supported(d):\n        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)\n        dqkv = dqkv_pad_fn(dqkv_unpad)\n        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)\n        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)\n        print(f\"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}\")\n        print(f\"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}\")\n        print(f\"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}\")\n        print(f\"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 
2]).abs().max().item()}\")\n        print(f\"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}\")\n\n        # TODO - use 10 times to check, wait for ck to fix bwd precision issue\n        assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()\n\n\n@pytest.mark.parametrize(\"kvpacked\", [True, False])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n@pytest.mark.parametrize(\"local\", [False, True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (2048, 2048),\n    ],\n)\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\ndef test_flash_attn_output(\n    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked\n):\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    nheads = 9\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 3)\n    assert nheads % nheads_k == 0\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    if kvpacked:\n        kv = torch.randn(\n            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n    else:\n        k = torch.randn(\n            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n        v = 
torch.randn(\n            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)\n    else:\n        alibi_slopes, attn_bias = None, None\n\n    if kvpacked:\n        out, lse, S_dmask = flash_attn_kvpacked_func(\n            q,\n            kv,\n            dropout_p,\n            causal=causal,\n            window_size=window_size,\n            alibi_slopes=alibi_slopes,\n            deterministic=deterministic,\n            return_attn_probs=True,\n        )\n    else:\n        out, lse, S_dmask = flash_attn_func(\n            q,\n            k,\n            v,\n            dropout_p,\n            causal=causal,\n            window_size=window_size,\n            alibi_slopes=alibi_slopes,\n            deterministic=deterministic,\n            return_attn_probs=True,\n        )\n    if dropout_p > 0.0:\n        # TODO - move to c++ mha_varlen_fwd()\n        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)\n        S_dmask_converted = convert_flash_attn_S_to_softmax(\n            S_dmask,\n            seqlen_q,\n            seqlen_k,\n            None,\n            None,\n            d,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_mask = S_dmask_converted >= 0\n        if kvpacked:\n            kv_rep = repeat(kv, \"b s two h d -> b s two (h g) d\", g=nheads // nheads_k)\n            k_rep, v_rep = kv_rep.unbind(dim=2)\n        else:\n            k_rep = repeat(k, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n            v_rep = repeat(v, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n        # CK does not return P. 
Hence, we don't test the attn here.\n    else:\n        dropout_mask = None\n\n    if kvpacked:\n        out_ref, attn_ref = attention_kvpacked_ref(\n            q,\n            kv,\n            None,\n            None,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n        )\n        out_pt, attn_pt = attention_kvpacked_ref(\n            q,\n            kv,\n            None,\n            None,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            upcast=False,\n            reorder_ops=True,\n        )\n    else:\n        out_ref, attn_ref = attention_ref(\n            q,\n            k,\n            v,\n            None,\n            None,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n        )\n        out_pt, attn_pt = attention_ref(\n            q,\n            k,\n            v,\n            None,\n            None,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            upcast=False,\n            reorder_ops=True,\n        )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n\n    g = torch.randn_like(out)\n    if is_bwd_hdim_supported(d):\n        if kvpacked:\n            (\n                dq,\n                dkv,\n 
           ) = torch.autograd.grad(out, (q, kv), g)\n            dk, dv = dkv.unbind(2)\n            (\n                dq_ref,\n                dkv_ref,\n            ) = torch.autograd.grad(out_ref, (q, kv), g)\n            dk_ref, dv_ref = dkv_ref.unbind(2)\n            (\n                dq_pt,\n                dkv_pt,\n            ) = torch.autograd.grad(out_pt, (q, kv), g)\n            dk_pt, dv_pt = dkv_pt.unbind(2)\n        else:\n            (\n                dq,\n                dk,\n                dv,\n            ) = torch.autograd.grad(out, (q, k, v), g)\n            (\n                dq_ref,\n                dk_ref,\n                dv_ref,\n            ) = torch.autograd.grad(out_ref, (q, k, v), g)\n            (\n                dq_pt,\n                dk_pt,\n                dv_pt,\n            ) = torch.autograd.grad(out_pt, (q, k, v), g)\n        print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n        print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n        print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n        print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n        print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n        print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n        print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n        print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n        print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n        # TODO - use 10 times to check, wait for ck to fix bwd precision issue\n        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()\n        assert (dk - 
dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()\n        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()\n\n\n@pytest.mark.parametrize(\"kvpacked\", [True, False])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n@pytest.mark.parametrize(\"deterministic\", [False, True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n@pytest.mark.parametrize(\"local\", [False, True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 147),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (2048, 2048),\n    ],\n)\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\ndef test_flash_attn_varlen_output(\n    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked\n):\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    nheads = 9\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 3)\n    assert nheads % nheads_k == 0\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    if kvpacked:\n        kv = torch.randn(\n            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n    else:\n        k = torch.randn(\n            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n        v = torch.randn(\n            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n  
      )\n\n    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode=\"random\")\n    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode=\"random\")\n    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(\n            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal\n        )\n    else:\n        alibi_slopes, attn_bias = None, None\n\n    if kvpacked:\n        (\n            q_unpad,\n            kv_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q,\n            kv,\n            output_pad_fn,\n            dq_pad_fn,\n            dkv_pad_fn,\n        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)\n        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(\n            q_unpad,\n            kv_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            dropout_p,\n            causal=causal,\n            window_size=window_size,\n            alibi_slopes=alibi_slopes,\n            deterministic=deterministic,\n            return_attn_probs=True,\n        )\n    else:\n        (\n            q_unpad,\n            k_unpad,\n            v_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q,\n            k,\n            v,\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)\n        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(\n            q_unpad,\n            k_unpad,\n      
      v_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            dropout_p,\n            causal=causal,\n            window_size=window_size,\n            alibi_slopes=alibi_slopes,\n            deterministic=deterministic,\n            return_attn_probs=True,\n        )\n    out = output_pad_fn(out_unpad)\n    if dropout_p > 0.0:\n        # TODO - move to c++ mha_varlen_fwd()\n        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)\n        S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q, seqlen_k)\n        S_dmask_converted = convert_flash_attn_S_to_softmax(\n            S_dmask,\n            seqlen_q,\n            seqlen_k,\n            query_padding_mask,\n            key_padding_mask,\n            d,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_mask = S_dmask_converted >= 0\n        if kvpacked:\n            kv_rep = repeat(kv, \"b s two h d -> b s two (h g) d\", g=nheads // nheads_k)\n            k_rep, v_rep = kv_rep.unbind(dim=2)\n        else:\n            k_rep = repeat(k, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n            v_rep = repeat(v, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n        # CK does not return P. 
Hence, we don't test the attn here.\n    else:\n        dropout_mask = None\n\n    if kvpacked:\n        out_ref, attn_ref = attention_kvpacked_ref(\n            q,\n            kv,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n        )\n        out_pt, attn_pt = attention_kvpacked_ref(\n            q,\n            kv,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            upcast=False,\n            reorder_ops=True,\n        )\n    else:\n        out_ref, attn_ref = attention_ref(\n            q,\n            k,\n            v,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n        )\n        out_pt, attn_pt = attention_ref(\n            q,\n            k,\n            v,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            upcast=False,\n            reorder_ops=True,\n        )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most 4 times the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item()\n\n    g = torch.randn_like(out)\n    if 
is_bwd_hdim_supported(d):\n        if kvpacked:\n            (\n                dq_unpad,\n                dkv_unpad,\n            ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)\n            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)\n            (\n                dq_ref,\n                dkv_ref,\n            ) = torch.autograd.grad(out_ref, (q, kv), g)\n            dk_ref, dv_ref = dkv_ref.unbind(2)\n            (\n                dq_pt,\n                dkv_pt,\n            ) = torch.autograd.grad(out_pt, (q, kv), g)\n            dk_pt, dv_pt = dkv_pt.unbind(2)\n        else:\n            (\n                dq_unpad,\n                dk_unpad,\n                dv_unpad,\n            ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)\n            dk = dk_pad_fn(dk_unpad)\n            dv = dk_pad_fn(dv_unpad)\n            (\n                dq_ref,\n                dk_ref,\n                dv_ref,\n            ) = torch.autograd.grad(out_ref, (q, k, v), g)\n            (\n                dq_pt,\n                dk_pt,\n                dv_pt,\n            ) = torch.autograd.grad(out_pt, (q, k, v), g)\n        dq = dq_pad_fn(dq_unpad)\n        print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n        print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n        print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n        print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n        print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n        print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n        print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n        print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n        
print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n        # TODO - use 10 times to check, wait for ck to fix bwd precision issue\n        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()\n        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()\n        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"local\", [False, True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False, True])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        # (1, 239),\n        (3, 799),\n        (127, 512),\n        (127, 513),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (1023, 1024),\n    ],\n)\ndef test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):\n    if max(seqlen_q, seqlen_k) >= 2048:\n        pytest.skip()\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    causal = True\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)\n    out_ref, attn_ref = attention_ref(\n        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size\n    )\n    out_pt, attn_pt = attention_ref(\n        q,\n        k,\n        
v,\n        None,\n        None,\n        None,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most 4 times the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item() + 1e-5\n\n    g = torch.randn_like(out)\n    if is_bwd_hdim_supported(d):\n        do_o = (g.float() * out.float()).sum(-1)\n        (\n            dq,\n            dk,\n            dv,\n        ) = torch.autograd.grad(out, (q, k, v), g)\n        (\n            dq_ref,\n            dk_ref,\n            dv_ref,\n        ) = torch.autograd.grad(out_ref, (q, k, v), g)\n        (\n            dq_pt,\n            dk_pt,\n            dv_pt,\n        ) = torch.autograd.grad(out_pt, (q, k, v), g)\n        print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n        print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n        print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n        print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n        print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n        print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n        print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n        print(f\"dK Pytorch mean diff: {(dk_pt - 
dk_ref).abs().mean().item()}\")\n        print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n        # TODO - use 10 times to check, wait for ck to fix bwd precision issue\n        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item() + 1e-4\n        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item() + 1e-4\n        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item() + 1e-4\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"local\", [False, True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False, True])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        # (1, 239),\n        (3, 799),\n        (127, 512),\n        (127, 513),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (1023, 1024),\n    ],\n)\n@pytest.mark.parametrize(\"paged_kv_block_size\", [None, 256, 512])\ndef test_flash_attn_varlen_causal(\n    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype\n):\n    if max(seqlen_q, seqlen_k) >= 2048:\n        pytest.skip()\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    causal = True\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n\n    if paged_kv_block_size is None:\n        k = torch.randn(\n            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True\n        )\n        v = torch.randn(\n            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True\n        )\n        block_table = None\n    else:\n  
      k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(\n            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype\n        )\n    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode=\"random\")\n    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode=\"random\")\n    (\n        q_unpad,\n        k_unpad,\n        v_unpad,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        q,\n        k,\n        v,\n        output_pad_fn,\n        dq_pad_fn,\n        dk_pad_fn,\n    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)\n    out_unpad = flash_attn_varlen_func(\n        q_unpad,\n        k_unpad if paged_kv_block_size is None else k_cache_paged,\n        v_unpad if paged_kv_block_size is None else v_cache_paged,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        0.0,\n        causal=causal,\n        window_size=window_size,\n        block_table=block_table,\n    )\n    out = output_pad_fn(out_unpad)\n    out_ref, attn_ref = attention_ref(\n        q,\n        k,\n        v,\n        query_padding_mask,\n        key_padding_mask,\n        None,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n    )\n    out_pt, attn_pt = attention_ref(\n        q,\n        k,\n        v,\n        query_padding_mask,\n        key_padding_mask,\n        None,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    # 
Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5\n\n    g = torch.randn_like(out)\n    if is_bwd_hdim_supported(d):\n        do_o = (g.float() * out.float()).sum(-1)\n        test_backward = block_table is None\n        if test_backward:\n            (\n                dq_unpad,\n                dk_unpad,\n                dv_unpad,\n            ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)\n            dq = dq_pad_fn(dq_unpad)\n            dk = dk_pad_fn(dk_unpad)\n            dv = dk_pad_fn(dv_unpad)\n            (\n                dq_ref,\n                dk_ref,\n                dv_ref,\n            ) = torch.autograd.grad(out_ref, (q, k, v), g)\n            (\n                dq_pt,\n                dk_pt,\n                dv_pt,\n            ) = torch.autograd.grad(out_pt, (q, k, v), g)\n            print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n            print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n            print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n            print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n            print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n            print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n            print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n            print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n            print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n            print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n            print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n            print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n        if test_backward:\n            # TODO - use 10 
times to check, wait for ck to fix bwd precision issue\n            assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item() + 1e-5\n            assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item() + 1e-5\n            assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item() + 1e-5\n\n\n# TODO - support splitkv\n# def test_flash_attn_splitkv\n\n\n# TODO - Support has_leftpad\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"num_splits\", [1, 0])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n@pytest.mark.parametrize(\"new_kv\", [False, True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n@pytest.mark.parametrize(\"local\", [False, True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [True, False])\n@pytest.mark.parametrize(\"rotary_interleaved\", [False, True])\n@pytest.mark.parametrize(\"rotary_fraction\", [0.0, 0.5, 1.0])\n@pytest.mark.parametrize(\"paged_kv_block_size\", [None, 256])\n@pytest.mark.parametrize(\"has_leftpad\", [False])\n@pytest.mark.parametrize(\"has_batch_idx\", [False, True])\n@pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 128, 256])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 128),\n        (1, 339),\n        (3, 1024),\n        (64, 800),\n        (64, 256),\n        (3, 799),\n        (64, 2048),\n        (16, 20000),\n        (1, 128 * 1024),\n        (16, 128 * 1024),\n        (128, 128),\n    ],\n)\ndef test_flash_attn_kvcache(\n    seqlen_q,\n    seqlen_k,\n    d,\n    has_batch_idx,\n    has_leftpad,\n    paged_kv_block_size,\n    rotary_fraction,\n    rotary_interleaved,\n    seqlen_new_eq_seqlen_q,\n    causal,\n    local,\n    alibi,\n    new_kv,\n    mha_type,\n    num_splits,\n    dtype,\n):\n    if seqlen_q > seqlen_k and new_kv:\n        pytest.skip()\n    if not new_kv and rotary_fraction > 
0.0:\n        pytest.skip()\n    if has_batch_idx and paged_kv_block_size is not None:\n        pytest.skip()\n    if has_leftpad and paged_kv_block_size is not None:\n        pytest.skip()\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 1\n    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2\n    nheads = 6\n    # rotary_dim must be a multiple of 16, and must be <= d\n    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 3)\n    assert nheads % nheads_k == 0\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)\n    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()\n    if new_kv:\n        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)\n        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)\n    else:\n        k, v = None, None\n    if paged_kv_block_size is None:\n        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)\n        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)\n        block_table = None\n    else:\n        (\n            k_cache,\n            v_cache,\n            block_table,\n            k_cache_paged,\n            v_cache_paged,\n            num_blocks,\n        ) = _generate_block_kvcache(\n            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype\n        )\n    cache_seqlens = torch.randint(\n        0 if new_kv else 1,\n        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough\n        (\n            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)\n            if new_kv\n            else 
(seqlen_k + 1)\n        ),\n        (batch_size,),\n        dtype=torch.int32,\n        device=device,\n    )\n    if has_leftpad:\n        cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)\n                                   if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)\n                                   for i in range(batch_size)])\n    else:\n        cache_leftpad = None\n    arange = rearrange(torch.arange(seqlen_k, device=device), \"s -> 1 s\")\n    cache_seqlens_expanded = rearrange(cache_seqlens, \"b -> b 1\")\n    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)\n    if has_leftpad:\n        key_padding_mask = torch.logical_and(\n            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)\n        )\n    if has_batch_idx:\n        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[\n            :batch_size\n        ]\n    else:\n        cache_batch_idx = None\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(\n            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad\n        )\n    else:\n        alibi_slopes, attn_bias = None, None\n    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)\n    if rotary_dim > 0:\n        angle = (\n            torch.rand(\n                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,\n                rotary_dim // 2,\n                device=device,\n            )\n            * 2\n            * math.pi\n        )\n        cos = torch.cos(angle).to(dtype=dtype)\n        sin = torch.sin(angle).to(dtype=dtype)\n        if causal or local:\n            q_ro = apply_rotary_emb(\n                q, cos, sin, 
seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved\n            )\n        else:\n            q_ro = rearrange(\n                apply_rotary_emb(\n                    rearrange(q, \"b s h d -> b 1 (s h) d\"),\n                    cos,\n                    sin,\n                    seqlen_offsets=cache_seqlens,\n                    interleaved=rotary_interleaved,\n                ),\n                \"b 1 (s h) d -> b s h d\",\n                s=seqlen_q,\n            )\n        # q_ro = q\n        k_ro = apply_rotary_emb(\n            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved\n        )\n    else:\n        cos, sin = None, None\n        q_ro, k_ro = q, k\n    # k_cache[:, 64:] = -1\n    k_cache_ref = (\n        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]\n    ).clone()\n    v_cache_ref = (\n        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]\n    ).clone()\n    if new_kv:\n        update_mask = torch.logical_and(\n            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new\n        )\n        k_cache_ref[update_mask] = rearrange(k_ro, \"b s ... -> (b s) ...\")\n        v_cache_ref[update_mask] = rearrange(v, \"b s ... 
-> (b s) ...\")\n    k_cache_rep = repeat(k_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n    v_cache_rep = repeat(v_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n    out = flash_attn_with_kvcache(\n        q,\n        k_cache if paged_kv_block_size is None else k_cache_paged,\n        v_cache if paged_kv_block_size is None else v_cache_paged,\n        k,\n        v,\n        rotary_cos=cos,\n        rotary_sin=sin,\n        cache_seqlens=cache_seqlens,\n        cache_batch_idx=cache_batch_idx,\n        cache_leftpad=cache_leftpad,\n        block_table=block_table,\n        causal=causal,\n        window_size=window_size,\n        rotary_interleaved=rotary_interleaved,\n        alibi_slopes=alibi_slopes,\n        num_splits=num_splits,\n    )\n    # out = flash_attn_with_kvcache(\n    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size\n    # )\n    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)\n    # qk = torch.einsum(\"bqhd,bkhd->bhqk\", q, k_cache_ref)\n    # m = qk.amax(-1, keepdim=True)\n    # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)\n    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)\n    # probs = torch.softmax(qk, dim=-1)\n    out_ref, _ = attention_ref(\n        q_ro,\n        k_cache_rep,\n        v_cache_rep,\n        None,\n        key_padding_mask,\n        attn_bias,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        key_leftpad=cache_leftpad,\n    )\n    out_pt, _ = attention_ref(\n        q_ro,\n        k_cache_rep,\n        v_cache_rep,\n        None,\n        key_padding_mask,\n        attn_bias,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n        key_leftpad=cache_leftpad,\n    )\n    print(f\"Output max diff: {(out - 
out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    if new_kv:\n        if paged_kv_block_size is None:\n            k_cache_select = (\n                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]\n            )\n            v_cache_select = (\n                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]\n            )\n        else:\n            k_cache_select = rearrange(\n                k_cache_paged[block_table.to(dtype=torch.long).flatten()],\n                \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n                b=batch_size,\n            )[:, :seqlen_k]\n            v_cache_select = rearrange(\n                v_cache_paged[block_table.to(dtype=torch.long).flatten()],\n                \"(b nblocks) block_size ... 
-> b (nblocks block_size) ...\",\n                b=batch_size,\n            )[:, :seqlen_k]\n        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)\n        assert torch.equal(v_cache_select, v_cache_ref)\n    # mult = 3 if f16, bf16 need 4\n    mult = 4 if not alibi else 5\n    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5\n\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (239, 1),\n        (3, 799),\n        (799, 3),\n        (1024, 128),\n        (97, 97),\n        (128, 128),\n        (200, 200),\n        (256, 256),\n        (257, 257),\n        (384, 384),\n        (512, 512),\n        (768, 768),\n        # (1024, 1024),\n    ],\n)\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\ndef test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger\n    nheads = 4\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    torch.random.manual_seed(42)\n    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)\n    g = torch.randn_like(out0)\n    if dropout_p == 0 and is_bwd_hdim_supported(d):\n        (\n            dq0,\n            dk0,\n            dv0,\n        ) = torch.autograd.grad(out0, (q, k, v), g)\n        # Numerical error if we just do any arithmetic on dq\n        
dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()\n\n    for i in range(250):\n        torch.random.manual_seed(42)\n        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)\n        assert torch.equal(out, out0)\n        assert torch.equal(lse, lse0)\n\n        # Same guard as the baseline above: dq0, dk0, dv0 and dq_atol only exist when it holds\n        if dropout_p == 0 and is_bwd_hdim_supported(d):\n            (\n                dq,\n                dk,\n                dv,\n            ) = torch.autograd.grad(out, (q, k, v), g)\n            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)\n            if not dq_equal:\n                print(f\"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}\")\n\n            assert torch.equal(dv, dv0)\n            assert torch.equal(dk, dk0)\n            assert dq_equal\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [16, 32, 64])\n@pytest.mark.parametrize(\"seqlen\", [1, 2, 5, 17, 128])\ndef test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):\n    \"\"\"We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,\n    in the case where seqlen % 128 != 0.\n    \"\"\"\n\n    # TODO - seqlen 1 or 2 might fail, need to check\n    if seqlen == 1 or seqlen == 2:\n        pytest.skip()\n\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    nheads = 5\n    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device=\"cuda\") * 5\n    k, v = [\n        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device=\"cuda\") * 3\n        for _ in range(2)\n    ]\n    q.requires_grad_(True)\n    k.requires_grad_(True)\n    v.requires_grad_(True)\n    out = flash_attn_func(q, k, v, causal=causal)\n    g = torch.randn_like(out)\n    out.backward(g)\n    q_pt = q.detach().clone().requires_grad_(True)\n    k_pt = k.detach().clone().requires_grad_(True)\n    v_pt = v.detach().clone().requires_grad_(True)\n    out_pt, _ = 
attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)\n    out_pt.backward(g)\n    q_ref = q.detach().clone().requires_grad_(True)\n    k_ref = k.detach().clone().requires_grad_(True)\n    v_ref = v.detach().clone().requires_grad_(True)\n    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)\n    out_ref.backward(g)\n    print(f\"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}\")\n    print(f\"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}\")\n    print(f\"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}\")\n    print(f\"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}\")\n    print(f\"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}\")\n    print(f\"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}\")\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n    assert (q.grad - q_ref.grad).abs().max().item() <= 7 * (\n        q_pt.grad - q_ref.grad\n    ).abs().max().item() + 1e-3\n    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (\n        k_pt.grad - k_ref.grad\n    ).abs().max().item() + 1e-3\n    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (\n        v_pt.grad - v_ref.grad\n    ).abs().max().item() + 1e-3\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [64, 128])\n@pytest.mark.parametrize(\"seqlen\", [97, 128, 200, 256])\ndef test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):\n    \"\"\"We previously had a bug where we were using the wrong strides of dout, which shows up\n    when dout is not contiguous.\n    \"\"\"\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 5\n    nheads = 2\n    q, k, v = [\n        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device=\"cuda\", requires_grad=True)\n        for _ in range(3)\n  
  ]\n    out = rearrange(flash_attn_func(q, k, v, causal=causal), \"b s ... -> s b ...\")\n    # So g is not contiguous\n    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device=\"cuda\")[:, ::2]\n    out.backward(g)\n    q_pt = q.detach().clone().requires_grad_(True)\n    k_pt = k.detach().clone().requires_grad_(True)\n    v_pt = v.detach().clone().requires_grad_(True)\n    out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)\n    out_pt = rearrange(out_pt, \"b s ... -> s b ...\")\n    out_pt.backward(g)\n    q_ref = q.detach().clone().requires_grad_(True)\n    k_ref = k.detach().clone().requires_grad_(True)\n    v_ref = v.detach().clone().requires_grad_(True)\n    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)\n    out_ref = rearrange(out_ref, \"b s ... -> s b ...\")\n    out_ref.backward(g)\n    print(f\"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}\")\n    print(f\"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}\")\n    print(f\"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}\")\n    print(f\"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}\")\n    print(f\"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}\")\n    print(f\"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}\")\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n    assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (\n        q_pt.grad - q_ref.grad\n    ).abs().max().item()\n    assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (\n        k_pt.grad - k_ref.grad\n    ).abs().max().item()\n    assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (\n        v_pt.grad - v_ref.grad\n    ).abs().max().item()\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [16, 32, 64])\ndef 
test_flash_attn_bwd_varlen_overflow(d, causal, dtype):\n    \"\"\"We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,\n    in the case where seqlen % 128 != 0 or varlen.\n    \"\"\"\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    nheads = 5\n    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)\n    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)\n    Mq = 256\n    Mk = 3\n\n    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3\n    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]\n    q.requires_grad_(True)\n    k.requires_grad_(True)\n    v.requires_grad_(True)\n\n    out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)\n    g = torch.randn_like(out)\n    out.backward(g)\n\n    assert not q.grad.isnan().any()\n    assert not k.grad.isnan().any()\n    assert not v.grad.isnan().any()\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"local\", [False, True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False, True])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (3, 799),\n        (127, 512),\n        (127, 513),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (1023, 1024),\n    ],\n)\ndef test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    # set 
seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)\n\n    g = torch.randn_like(out)\n    dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)\n    for _ in range(50):\n        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)\n        assert torch.equal(dv, dv0)\n        assert torch.equal(dk, dk0)\n        assert torch.equal(dq, dq0)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16])\n@pytest.mark.parametrize(\"local\", [False, True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False, True])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (3, 799),\n        (127, 512),\n        (127, 513),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (1023, 1024),\n    ],\n)\ndef test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    nheads = 9\n    window_size = (-1, -1) if 
not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode=\"random\")\n    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode=\"random\")\n    (\n        q_unpad,\n        k_unpad,\n        v_unpad,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        q,\n        k,\n        v,\n        output_pad_fn,\n        dq_pad_fn,\n        dk_pad_fn,\n    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)\n    out = flash_attn_varlen_func(\n        q_unpad,\n        k_unpad,\n        v_unpad,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        0.0,\n        causal=causal,\n        window_size=window_size,\n        deterministic=True,\n    )\n\n    g = torch.randn_like(out)\n    dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)\n    for _ in range(50):\n        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)\n        assert torch.equal(dv, dv0)\n        assert torch.equal(dk, dk0)\n        assert torch.equal(dq, dq0)\n\n"
  },
  {
    "path": "tests/test_flash_attn_triton_amd.py",
    "content": "import math\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\nfrom flash_attn import (\n    flash_attn_func,\n    flash_attn_kvpacked_func,\n    flash_attn_qkvpacked_func,\n    flash_attn_varlen_func,\n    flash_attn_varlen_kvpacked_func,\n    flash_attn_varlen_qkvpacked_func,\n    flash_attn_with_kvcache,\n)\nfrom flash_attn.bert_padding import pad_input, unpad_input\nfrom flash_attn.flash_attn_interface import _get_block_size_n\nfrom flash_attn.layers.rotary import apply_rotary_emb\nfrom aiter.ops.triton._triton_kernels.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_hip, get_arch\n\n\ndef _get_block_size_n_triton(device, head_dim, is_dropout, is_causal):\n    \"\"\"Get block size for Triton AMD kernel.\"\"\"\n    arch = get_arch()\n    if arch.is_rdna:\n        return 32\n    elif arch.is_cdna:\n        return 64\n    # Fall back to CUDA kernel block sizes\n    return _get_block_size_n(device, head_dim, is_dropout, is_causal)\n\n\nMAX_HEADDIM_SM8x = 192\n\n\nis_sm75 = torch.cuda.get_device_capability(\"cuda\") == (7, 5)\nis_sm8x = torch.cuda.get_device_capability(\"cuda\")[0] == 8\nis_sm80 = torch.cuda.get_device_capability(\"cuda\") == (8, 0)\nis_sm90 = torch.cuda.get_device_capability(\"cuda\") == (9, 0)\n\nskip_bfloat16 = True if is_sm75 or is_hip() else False\n\n\ndef attn_bias_from_alibi_slopes(\n    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None\n):\n    batch, nheads = slopes.shape\n    device = slopes.device\n    slopes = rearrange(slopes, \"b h -> b h 1 1\")\n    if causal:\n        return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes\n    else:\n        row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n        col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n        if key_leftpad is not None:\n            key_leftpad = 
rearrange(key_leftpad, \"b -> b 1 1 1\")\n            col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n            col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)\n        sk = (\n            seqlen_k\n            if key_padding_mask is None\n            else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n        )\n        sq = (\n            seqlen_q\n            if query_padding_mask is None\n            else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n        )\n        relative_pos = torch.abs(row_idx + sk - sq - col_idx)\n        return -slopes * relative_pos.to(dtype=slopes.dtype)\n\n\ndef generate_random_padding_mask(max_seqlen, batch_size, device, mode=\"random\"):\n    assert mode in [\"full\", \"random\", \"third\"]\n    if mode == \"full\":\n        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)\n    elif mode == \"random\":\n        lengths = torch.randint(\n            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device\n        )\n    elif mode == \"third\":\n        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)\n    padding_mask = (\n        repeat(torch.arange(max_seqlen, device=device), \"s -> b s\", b=batch_size) < lengths\n    )\n    return padding_mask\n\n\ndef generate_qkv(\n    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, d)\n        k: (batch_size, seqlen_k, nheads_k, d)\n        v: (batch_size, seqlen_k, nheads_k, d)\n        query_padding_mask: (batch_size, seqlen), bool\n        key_padding_mask: (batch_size, seqlen), bool\n    \"\"\"\n    assert not (kvpacked and qkvpacked)\n    batch_size, seqlen_q, nheads, d = q.shape\n    _, seqlen_k, nheads_k, _ = k.shape\n    assert k.shape == (batch_size, seqlen_k, nheads_k, d)\n    assert v.shape == (batch_size, 
seqlen_k, nheads_k, d)\n\n    if query_padding_mask is not None:\n        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask)\n        output_pad_fn = lambda output_unpad: pad_input(\n            output_unpad, indices_q, batch_size, seqlen_q\n        )\n    else:\n        q_unpad = rearrange(q, \"b s h d -> (b s) h d\")\n        cu_seqlens_q = torch.arange(\n            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device\n        )\n        max_seqlen_q = seqlen_q\n        output_pad_fn = lambda output_unpad: rearrange(\n            output_unpad, \"(b s) h d -> b s h d\", b=batch_size\n        )\n\n    if key_padding_mask is not None:\n        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask)\n        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask)\n    else:\n        k_unpad = rearrange(k, \"b s h d -> (b s) h d\")\n        v_unpad = rearrange(v, \"b s h d -> (b s) h d\")\n        cu_seqlens_k = torch.arange(\n            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device\n        )\n        max_seqlen_k = seqlen_k\n\n    if qkvpacked:\n        assert (query_padding_mask == key_padding_mask).all()\n        assert nheads == nheads_k\n        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)\n        qkv = torch.stack([q, k, v], dim=2)\n        if query_padding_mask is not None:\n            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)\n        else:\n            dqkv_pad_fn = lambda dqkv_unpad: rearrange(\n                dqkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n            qkv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            max_seqlen_q,\n            qkv.detach().requires_grad_(),\n            output_pad_fn,\n            dqkv_pad_fn,\n        )\n    elif kvpacked:\n        kv_unpad = 
torch.stack([k_unpad, v_unpad], dim=1)\n        kv = torch.stack([k, v], dim=2)\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)\n        else:\n            dkv_pad_fn = lambda dkv_unpad: rearrange(\n                dkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n            q_unpad.detach().requires_grad_(),\n            kv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            kv.detach().requires_grad_(),\n            output_pad_fn,\n            dq_pad_fn,\n            dkv_pad_fn,\n        )\n    else:\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)\n        else:\n            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, \"(b s) h d -> b s h d\", b=batch_size)\n        return (\n            q_unpad.detach().requires_grad_(),\n            k_unpad.detach().requires_grad_(),\n            v_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            k.detach().requires_grad_(),\n            v.detach().requires_grad_(),\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        )\n\n\ndef construct_local_mask(\n    seqlen_q,\n    seqlen_k,\n    window_size=(-1, -1),  # -1 means infinite window size\n    query_padding_mask=None,\n    key_padding_mask=None,\n    device=None,\n    key_leftpad=None,\n):\n    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n    if 
key_leftpad is not None:\n        key_leftpad = rearrange(key_leftpad, \"b -> b 1 1 1\")\n        col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)\n    sk = (\n        seqlen_k\n        if key_padding_mask is None\n        else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sq = (\n        seqlen_q\n        if query_padding_mask is None\n        else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    if window_size[0] < 0:\n        return col_idx > row_idx + sk - sq + window_size[1]\n    else:\n        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk\n        return torch.logical_or(\n            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),\n            col_idx < row_idx + sk - sq - window_size[0],\n        )\n\n\ndef attention_ref(\n    q,\n    k,\n    v,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    attn_bias=None,\n    dropout_p=0.0,\n    dropout_mask=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n    softcap=0.0,\n    upcast=True,\n    reorder_ops=False,\n    key_leftpad=None,\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, head_dim)\n        k: (batch_size, seqlen_k, nheads_k, head_dim)\n        v: (batch_size, seqlen_k, nheads_k, head_dim)\n        query_padding_mask: (batch_size, seqlen_q)\n        key_padding_mask: (batch_size, seqlen_k)\n        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)\n        dropout_p: float\n        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)\n        causal: whether to apply causal masking\n        window_size: (int, int), left and right window size\n        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast\n            output back to fp16/bf16.\n        reorder_ops: whether to change the order of 
operations (scaling k instead of scaling q, etc.)\n            without changing the math. This is to estimate the numerical error from operation\n            reordering.\n    Output:\n        output: (batch_size, seqlen_q, nheads, head_dim)\n        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    dtype_og = q.dtype\n    if upcast:\n        q, k, v = q.float(), k.float(), v.float()\n    seqlen_q, seqlen_k = q.shape[1], k.shape[1]\n    k = repeat(k, \"b s h d -> b s (h g) d\", g=q.shape[2] // k.shape[2])\n    v = repeat(v, \"b s h d -> b s (h g) d\", g=q.shape[2] // v.shape[2])\n    d = q.shape[-1]\n    if not reorder_ops:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q / math.sqrt(d), k)\n    else:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q, k / math.sqrt(d))\n    if softcap > 0:\n        scores = scores / softcap\n        scores = scores.tanh()\n        scores = scores * softcap\n    if key_padding_mask is not None:\n        scores.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            query_padding_mask,\n            key_padding_mask,\n            q.device,\n            key_leftpad=key_leftpad,\n        )\n        scores.masked_fill_(local_mask, float(\"-inf\"))\n    if attn_bias is not None:\n        scores = scores + attn_bias\n    attention = torch.softmax(scores, dim=-1).to(v.dtype)\n    # Some rows might be completely masked out so we fill them with zero instead of NaN\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)\n    # We want to mask here so that the attention matrix doesn't have any NaNs\n    # Otherwise we'll get NaN in dV\n    if 
query_padding_mask is not None:\n        attention = attention.masked_fill(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), 0.0)\n    dropout_scaling = 1.0 / (1 - dropout_p)\n    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling\n    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)\n    if dropout_mask is not None:\n        attention_drop = attention.masked_fill(~dropout_mask, 0.0)\n    else:\n        attention_drop = attention\n    output = torch.einsum(\"bhts,bshd->bthd\", attention_drop, v * dropout_scaling)\n    if query_padding_mask is not None:\n        output.masked_fill_(rearrange(~query_padding_mask, \"b s -> b s 1 1\"), 0.0)\n    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)\n\n\ndef attention_kvpacked_ref(\n    q,\n    kv,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    attn_bias=None,\n    dropout_p=0.0,\n    dropout_mask=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n    softcap=0.0,\n    upcast=True,\n    reorder_ops=False,\n    key_leftpad=None,\n):\n    return attention_ref(\n        q,\n        kv[:, :, 0],\n        kv[:, :, 1],\n        query_padding_mask,\n        key_padding_mask,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        upcast=upcast,\n        causal=causal,\n        window_size=window_size,\n        softcap=softcap,\n        reorder_ops=reorder_ops,\n        key_leftpad=key_leftpad,\n    )\n\n\ndef attention_qkvpacked_ref(\n    qkv,\n    key_padding_mask=None,\n    attn_bias=None,\n    dropout_p=0.0,\n    dropout_mask=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n    softcap=0.0,\n    upcast=True,\n    reorder_ops=False,\n):\n    return attention_ref(\n        qkv[:, :, 0],\n        qkv[:, :, 1],\n        qkv[:, :, 2],\n        key_padding_mask,\n        key_padding_mask,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        upcast=upcast,\n 
       causal=causal,\n        window_size=window_size,\n        softcap=softcap,\n        reorder_ops=reorder_ops,\n    )\n\n\ndef generate_sparsity_mask(seqlen, sparsity=0.3):\n    repeats = seqlen // 16 // 2\n    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),\n    #                     torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)\n    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),\n    #                     torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)\n    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)\n    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)\n    nrow, ncol = seqlen // 16, seqlen // 256\n    mask = torch.rand(nrow, ncol, device=\"cuda\") < sparsity\n    return mask\n\n\ndef attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):\n    \"\"\"\n    Arguments:\n        qkv: (batch_size, seqlen, 3, nheads, head_dim)\n        blockmask: (seqlen / 16, seqlen / 256)\n        attn_mask: (batch_size, seqlen)\n        dropout_p: float\n        dropout_mask: (batch_size, nheads, seqlen, seqlen)\n    Output:\n        output: (batch_size, seqlen, nheads, head_dim)\n        attention: softmax after dropout\n    \"\"\"\n    q, k, v = qkv.float().unbind(dim=2)\n    d = qkv.shape[-1]\n    seqlen = qkv.shape[1]\n    scores = torch.einsum(\"bthd,bshd->bhts\", q / math.sqrt(d), k)\n    scores.masked_fill_(rearrange(~attn_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n    blockmask = repeat(blockmask, \"s_16 s_256 -> (s_16 16) (s_256 256)\")\n    blockmask = blockmask[:seqlen, :seqlen]\n    scores.masked_fill_(rearrange(~blockmask, \"t s -> 1 1 t s\"), float(\"-inf\"))\n    attention = torch.softmax(scores, dim=-1)\n    attention = attention.masked_fill(rearrange(~attn_mask, \"b s -> b 1 s 1\"), 0.0)\n    attention = 
attention.masked_fill_(rearrange(~blockmask, \"t s -> 1 1 t s\"), 0.0)\n    attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)\n    output = torch.einsum(\"bhts,bshd->bthd\", attention_drop, v)\n    output.masked_fill_(rearrange(~attn_mask, \"b s -> b s 1 1\"), 0)\n    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)\n\n\ndef convert_flash_attn_S_to_softmax(\n    S,\n    seqlen_q,\n    seqlen_k,\n    query_padding_mask,\n    key_padding_mask,\n    head_dim,\n    is_dropout,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n):\n    \"\"\"FlashAttention stores the S matrix in a different way.\n    Arguments:\n        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)\n        query_padding_mask: (batch_size, seqlen_q_rounded)\n        key_padding_mask: (batch_size, seqlen_k_rounded)\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]\n    S_converted = S\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            query_padding_mask,\n            key_padding_mask,\n            S.device,\n        )\n        local_mask = F.pad(\n            local_mask,\n            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),\n            value=True,\n        )\n        S_converted = S_converted.masked_fill(local_mask, 0.0)\n\n    # Need to zero out things not in attention_mask in case S was initialized with random values\n    # and some of those values aren't overwritten.\n    seqlen_q_og = (\n        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded\n    )\n    if query_padding_mask is not None:\n        query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))\n        S_converted = 
S_converted.masked_fill(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), 0.0)\n    seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k\n    if key_padding_mask is not None:\n        key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))\n        S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), 0.0)\n    S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))\n    S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))\n    return S_converted[:, :, :seqlen_q, :seqlen_k]\n\n\ndef normalize_flash_attn_S(\n    attn_unnorm,\n    q,\n    k,\n    v,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    attn_bias=None,\n    is_dropout=False,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n):\n    \"\"\"\n    Arguments:\n        attn_unnorm: (batch_size, nheads, seqlen_q, seqlen_k), unnormalized attention\n        q: (batch_size, seqlen_q, nheads, head_dim)\n        k, v: (batch_size, seqlen_k, nheads, head_dim)\n        query_padding_mask: (batch_size, seqlen_q)\n        key_padding_mask: (batch_size, seqlen_k)\n        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)\n    Output:\n        attn_norm: same shape and dtype as attn_unnorm, the normalized attention probabilities\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    q, k, v = q.float(), k.float(), v.float()\n    _, seqlen_q, _, head_dim = q.shape\n    seqlen_k = k.shape[1]\n    scores = torch.einsum(\"bthd,bshd->bhts\", q / math.sqrt(head_dim), k)\n    if key_padding_mask is not None:\n        scores.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            query_padding_mask,\n            key_padding_mask,\n            q.device,\n        )\n        scores.masked_fill_(local_mask, float(\"-inf\"))\n    if attn_bias is not 
None:\n        scores = scores + attn_bias.to(dtype=scores.dtype)\n    block_size_n = _get_block_size_n_triton(scores.device, head_dim, is_dropout, causal)\n    scores_block = scores.split(block_size_n, dim=-1)\n    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)\n    lse = torch.logsumexp(lse_block, dim=-1)\n    # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf\n    # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.\n    lse[lse == float(\"-inf\")] = float(\"inf\")\n    scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)\n    cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)\n    attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)\n    attn_norm = torch.cat(\n        [\n            a * rearrange(torch.exp(m - lse), \"b h s -> b h s 1\")\n            for a, m in zip(attn_unnorm_block, cummax_block)\n        ],\n        dim=-1,\n    )\n    if query_padding_mask is not None:\n        attn_norm.masked_fill_(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), 0.0)\n    return attn_norm.to(dtype=attn_unnorm.dtype)\n\n\ndef get_dropout_fraction(\n    dropout_mask,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n):\n    \"\"\"\n    dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. 
True means keep, False means drop.\n    query_padding_mask: (batch_size, seqlen_q)\n    key_padding_mask: (batch_size, seqlen_k)\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape\n    dropped = ~dropout_mask\n    valid = torch.ones_like(dropout_mask)\n    if query_padding_mask is not None:\n        dropped.masked_fill_(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), False)\n        valid.masked_fill_(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), False)\n    if key_padding_mask is not None:\n        dropped.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), False)\n        valid.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), False)\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            query_padding_mask,\n            key_padding_mask,\n            dropout_mask.device,\n        )\n        dropped.masked_fill_(local_mask, False)\n        valid.masked_fill_(local_mask, False)\n    dropped_total = dropped.sum()\n    return dropped_total / valid.sum()\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"deterministic\", [False])\n# @pytest.mark.parametrize(\"deterministic\", [False])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n# @pytest.mark.parametrize(\"alibi\", [False])\n@pytest.mark.parametrize(\"local\", [False])\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [False])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 
128])\n# @pytest.mark.parametrize(\"d\", [64])\n# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])\n@pytest.mark.parametrize(\"seqlen\", [97, 128, 200, 384, 768, 1024, 1025, 2048])\n# @pytest.mark.parametrize(\"seqlen\", [512])\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\n# @pytest.mark.parametrize(\"dropout_p\", [0.0])\ndef test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):\n    if seqlen >= 2048 and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30:\n        pytest.skip()  # Reference implementation OOM\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))\n    qkv = torch.randn(\n        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True\n    )\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)\n    else:\n        alibi_slopes, attn_bias = None, None\n    out, lse, S_dmask = flash_attn_qkvpacked_func(\n        qkv,\n        dropout_p,\n        causal=causal,\n        window_size=window_size,\n        alibi_slopes=alibi_slopes,\n        deterministic=deterministic,\n        return_attn_probs=True,\n    )\n    if dropout_p > 0.0:\n        S_dmask_converted = convert_flash_attn_S_to_softmax(\n            S_dmask,\n            seqlen,\n            seqlen,\n            None,\n            None,\n            d,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_mask = S_dmask_converted >= 0\n        attn_unnorm = S_dmask_converted.abs()\n        attn = normalize_flash_attn_S(\n            attn_unnorm,\n            qkv[:, :, 0],\n            qkv[:, :, 1],\n            qkv[:, :, 2],\n           
 None,\n            None,\n            attn_bias,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_fraction = get_dropout_fraction(\n            dropout_mask, None, None, causal=causal, window_size=window_size\n        ).item()\n        print(f\"Actual dropout fraction: {dropout_fraction}\")\n    else:\n        dropout_mask = None\n\n    out_ref, attn_ref = attention_qkvpacked_ref(\n        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size\n    )\n    out_pt, attn_pt = attention_qkvpacked_ref(\n        qkv,\n        None,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n    # v = qkv[:, :, 2].float()\n    # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()\n    # if causal:\n    #     causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)\n    #     qk.masked_fill_(causal_mask, float('-inf'))\n    # m = qk.amax(-1, keepdim=True)\n    # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n    # p_tmp = torch.softmax(qk / math.sqrt(d), -1)\n    # p_dropped = p_tmp if dropout_mask is None else p_tmp.masked_fill(~dropout_mask, 0)\n    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)\n    # qk_max1 = torch.max(qk[:, :, 128:, 192:], -1, keepdim=True).values\n    # qk_max2 = torch.max(qk[:, :, 128:, 128:], -1, keepdim=True).values\n    # qk_max3 = torch.max(qk[:, :, 128:, 64:], -1, keepdim=True).values\n    # qk_max4 = torch.max(qk[:, :, 128:, :], -1, keepdim=True).values\n    # o1 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 192:] - qk_max1) / math.sqrt(d)), v[:, 192:])\n    # o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:])\n    # o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - 
qk_max3) / math.sqrt(d)), v[:, 64:])\n    # o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :])\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n    if dropout_p > 0.0:\n        print(f\"Attention max diff: {(attn - attn_ref).abs().max().item()}\")\n        print(f\"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}\")\n\n    g = torch.randn_like(out)\n    # do_o = (g.float() * out.float()).sum(-1)\n    # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])\n    # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        (dqkv,) = torch.autograd.grad(out, qkv, g)\n        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)\n        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)\n        print(f\"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}\")\n        print(f\"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}\")\n        print(f\"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}\")\n        print(f\"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}\")\n        print(f\"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n\n    if dropout_p > 0.0:\n        # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()\n        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate\n        if not alibi:\n            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)\n\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize('dtype', [torch.float16])\n@pytest.mark.parametrize(\"deterministic\", [False])\n# @pytest.mark.parametrize(\"deterministic\", [True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n# @pytest.mark.parametrize(\"alibi\", [True])\n@pytest.mark.parametrize(\"local\", [False])\n# @pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [False])\n@pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [64])\n@pytest.mark.parametrize(\"seqlen\", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])\n# @pytest.mark.parametrize('seqlen', [128])\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\n# @pytest.mark.parametrize('dropout_p', [0.0])\ndef test_flash_attn_varlen_qkvpacked(\n    seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype\n):\n    if seqlen >= 2048 and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30:\n        pytest.skip()  # Reference implementation OOM\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 5\n    nheads = 6\n    window_size = (-1, -1) 
if not local else torch.randint(0, seqlen, (2,))\n    qkv = torch.randn(\n        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True\n    )\n\n    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode=\"random\")\n    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(\n            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal\n        )\n    else:\n        alibi_slopes, attn_bias = None, None\n\n    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(\n        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True\n    )\n\n    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(\n        qkv_unpad,\n        cu_seqlens,\n        max_seqlen,\n        dropout_p,\n        causal=causal,\n        window_size=window_size,\n        alibi_slopes=alibi_slopes,\n        deterministic=deterministic,\n        return_attn_probs=True,\n    )\n    out = output_pad_fn(out_unpad)\n    if dropout_p > 0.0:\n        S_dmask_converted = convert_flash_attn_S_to_softmax(\n            S_dmask,\n            seqlen,\n            seqlen,\n            key_padding_mask,\n            key_padding_mask,\n            d,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_mask = S_dmask_converted >= 0\n        attn_unnorm = S_dmask_converted.abs()\n        attn = normalize_flash_attn_S(\n            attn_unnorm,\n            qkv[:, :, 0],\n            qkv[:, :, 1],\n            qkv[:, :, 2],\n            key_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n   
     dropout_fraction = get_dropout_fraction(\n            dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size\n        ).item()\n        print(f\"Actual dropout fraction: {dropout_fraction}\")\n    else:\n        dropout_mask = None\n\n    out_ref, attn_ref = attention_qkvpacked_ref(\n        qkv,\n        key_padding_mask,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        causal=causal,\n        window_size=window_size,\n    )\n    out_pt, attn_pt = attention_qkvpacked_ref(\n        qkv,\n        key_padding_mask,\n        attn_bias,\n        dropout_p,\n        dropout_mask,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n    if dropout_p > 0.0:\n        print(f\"Attention max diff: {(attn - attn_ref).abs().max().item()}\")\n        print(f\"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}\")\n\n    g = torch.randn_like(out)\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)\n        dqkv = dqkv_pad_fn(dqkv_unpad)\n        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)\n        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)\n        print(f\"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}\")\n        print(f\"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}\")\n        print(f\"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}\")\n        print(f\"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - 
dqkv_ref[:, :, 0]).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}\")\n        print(f\"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n\n    if dropout_p > 0.0:\n        # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()\n        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate\n        if not alibi:\n            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)\n\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()\n\n\n@pytest.mark.parametrize(\"kvpacked\", [True, False])\n# @pytest.mark.parametrize(\"kvpacked\", [False])\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n@pytest.mark.parametrize(\"deterministic\", [False])\n# @pytest.mark.parametrize(\"deterministic\", [True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n# @pytest.mark.parametrize(\"alibi\", [False])\n@pytest.mark.parametrize(\"local\", [False])\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 
96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (2048, 2048),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\n# @pytest.mark.parametrize(\"dropout_p\", [0.0])\n@pytest.mark.parametrize(\"softcap\", [0.0])\ndef test_flash_attn_output(\n    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap\n):\n    if USE_TRITON_ROCM:\n        if causal:\n            if seqlen_q == 1024 and seqlen_k == 1024 and d == 160:\n                pytest.skip(\"This test with causal=True is flaky\")\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if softcap > 0.0 and dropout_p > 0.0:\n        pytest.skip(\"Softcap and dropout not supported together\")\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 2)\n    assert nheads % nheads_k == 0\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    if softcap > 0:\n        # Ensure the values of qk are at least within softcap range.\n        q = q * softcap\n    if kvpacked:\n     
   kv = torch.randn(\n            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n    else:\n        k = torch.randn(\n            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n        v = torch.randn(\n            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)\n    else:\n        alibi_slopes, attn_bias = None, None\n\n    if kvpacked:\n        out, lse, S_dmask = flash_attn_kvpacked_func(\n            q,\n            kv,\n            dropout_p,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            deterministic=deterministic,\n            return_attn_probs=True,\n        )\n    else:\n        out, lse, S_dmask = flash_attn_func(\n            q,\n            k,\n            v,\n            dropout_p,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            deterministic=deterministic,\n            return_attn_probs=True,\n        )\n    if dropout_p > 0.0:\n        S_dmask_converted = convert_flash_attn_S_to_softmax(\n            S_dmask,\n            seqlen_q,\n            seqlen_k,\n            None,\n            None,\n            d,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_mask = S_dmask_converted >= 0\n        attn_unnorm = S_dmask_converted.abs()\n        if kvpacked:\n            kv_rep = repeat(kv, \"b s two h d -> b s two (h g) d\", g=nheads // nheads_k)\n            k_rep, v_rep = kv_rep.unbind(dim=2)\n        else:\n            k_rep = 
repeat(k, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n            v_rep = repeat(v, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n        attn = normalize_flash_attn_S(\n            attn_unnorm,\n            q,\n            k_rep,\n            v_rep,\n            None,\n            None,\n            attn_bias,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_fraction = get_dropout_fraction(\n            dropout_mask, None, None, causal=causal, window_size=window_size\n        ).item()\n        print(f\"Actual dropout fraction: {dropout_fraction}\")\n    else:\n        dropout_mask = None\n\n    if kvpacked:\n        out_ref, attn_ref = attention_kvpacked_ref(\n            q,\n            kv,\n            None,\n            None,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n        )\n        out_pt, attn_pt = attention_kvpacked_ref(\n            q,\n            kv,\n            None,\n            None,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n        )\n    else:\n        out_ref, attn_ref = attention_ref(\n            q,\n            k,\n            v,\n            None,\n            None,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n        )\n        out_pt, attn_pt = attention_ref(\n            q,\n            k,\n            v,\n            None,\n            None,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n      
      upcast=False,\n            reorder_ops=True,\n        )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n    if dropout_p > 0.0:\n        print(f\"Attention max diff: {(attn - attn_ref).abs().max().item()}\")\n        print(f\"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}\")\n\n    g = torch.randn_like(out)\n    do_o = (g.float() * out.float()).sum(-1)\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        if kvpacked:\n            (\n                dq,\n                dkv,\n            ) = torch.autograd.grad(out, (q, kv), g)\n            dk, dv = dkv.unbind(2)\n            (\n                dq_ref,\n                dkv_ref,\n            ) = torch.autograd.grad(out_ref, (q, kv), g)\n            dk_ref, dv_ref = dkv_ref.unbind(2)\n            (\n                dq_pt,\n                dkv_pt,\n            ) = torch.autograd.grad(out_pt, (q, kv), g)\n            dk_pt, dv_pt = dkv_pt.unbind(2)\n        else:\n            (\n                dq,\n                dk,\n                dv,\n            ) = torch.autograd.grad(out, (q, k, v), g)\n            (\n                dq_ref,\n                dk_ref,\n                dv_ref,\n            ) = torch.autograd.grad(out_ref, (q, k, v), g)\n            (\n                dq_pt,\n                dk_pt,\n                dv_pt,\n            ) = torch.autograd.grad(out_pt, (q, k, v), g)\n        print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n        print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n        print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n        print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n        print(f\"dK mean diff: {(dk - 
dk_ref).abs().mean().item()}\")\n        print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n        print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n        print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n        print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n\n    if dropout_p > 0.0:\n        # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()\n        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate\n        if not alibi:\n            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)\n\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()\n        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()\n        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()\n\n\n@pytest.mark.parametrize(\"kvpacked\", [False])\n# @pytest.mark.parametrize('kvpacked', [False])\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize('dtype', [torch.float16])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize('mha_type', [\"mqa\"])\n@pytest.mark.parametrize(\"deterministic\", [False])\n# @pytest.mark.parametrize(\"deterministic\", 
[True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n# @pytest.mark.parametrize(\"alibi\", [True])\n@pytest.mark.parametrize(\"local\", [False])\n# @pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [True])\n@pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [64])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 147),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (512, 256),\n        (1024, 1024),\n        (1023, 1024),\n        (1024, 1023),\n        (2048, 2048),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\n@pytest.mark.parametrize(\"softcap\", [0.0])\n# @pytest.mark.parametrize('dropout_p', [0.0])\ndef test_flash_attn_varlen_output(\n    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap\n):\n    if USE_TRITON_ROCM:\n        if seqlen_q == 1 and seqlen_k == 147 and kvpacked == True and dropout_p != 0.0:\n            pytest.skip(\"This config with dropout is flaky on AMD.\")\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if softcap > 0.0 and dropout_p > 0.0:\n        pytest.skip(\"Softcap and dropout not supported together\")\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 2)\n    assert nheads % nheads_k == 0\n    window_size = (-1, -1) if not 
local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    if softcap > 0:\n        # Ensure the values of qk are at least within softcap range.\n        q = q * softcap\n\n    if kvpacked:\n        kv = torch.randn(\n            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n    else:\n        k = torch.randn(\n            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n        v = torch.randn(\n            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True\n        )\n\n    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode=\"random\")\n    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode=\"random\")\n    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(\n            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal\n        )\n    else:\n        alibi_slopes, attn_bias = None, None\n\n    if kvpacked:\n        (\n            q_unpad,\n            kv_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q,\n            kv,\n            output_pad_fn,\n            dq_pad_fn,\n            dkv_pad_fn,\n        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)\n        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(\n            q_unpad,\n            kv_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            dropout_p,\n            causal=causal,\n    
        window_size=window_size,\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            deterministic=deterministic,\n            return_attn_probs=True,\n        )\n    else:\n        (\n            q_unpad,\n            k_unpad,\n            v_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q,\n            k,\n            v,\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)\n        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(\n            q_unpad,\n            k_unpad,\n            v_unpad,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            dropout_p,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            alibi_slopes=alibi_slopes,\n            deterministic=deterministic,\n            return_attn_probs=True,\n        )\n    out = output_pad_fn(out_unpad)\n    if dropout_p > 0.0:\n        S_dmask_converted = convert_flash_attn_S_to_softmax(\n            S_dmask,\n            seqlen_q,\n            seqlen_k,\n            query_padding_mask,\n            key_padding_mask,\n            d,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_mask = S_dmask_converted >= 0\n        attn_unnorm = S_dmask_converted.abs()\n        if kvpacked:\n            kv_rep = repeat(kv, \"b s two h d -> b s two (h g) d\", g=nheads // nheads_k)\n            k_rep, v_rep = kv_rep.unbind(dim=2)\n        else:\n            k_rep = repeat(k, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n            v_rep = repeat(v, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n        attn = normalize_flash_attn_S(\n            attn_unnorm,\n            q,\n            
k_rep,\n            v_rep,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p > 0.0,\n            causal=causal,\n            window_size=window_size,\n        )\n        dropout_fraction = get_dropout_fraction(\n            dropout_mask,\n            query_padding_mask,\n            key_padding_mask,\n            causal=causal,\n            window_size=window_size,\n        ).item()\n        print(f\"Actual dropout fraction: {dropout_fraction}\")\n    else:\n        dropout_mask = None\n\n    if kvpacked:\n        out_ref, attn_ref = attention_kvpacked_ref(\n            q,\n            kv,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n        )\n        out_pt, attn_pt = attention_kvpacked_ref(\n            q,\n            kv,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            upcast=False,\n            reorder_ops=True,\n        )\n    else:\n        out_ref, attn_ref = attention_ref(\n            q,\n            k,\n            v,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n        )\n        out_pt, attn_pt = attention_ref(\n            q,\n            k,\n            v,\n            query_padding_mask,\n            key_padding_mask,\n            attn_bias,\n            dropout_p,\n            dropout_mask,\n            causal=causal,\n            window_size=window_size,\n            softcap=softcap,\n            upcast=False,\n      
      reorder_ops=True,\n        )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n    if dropout_p > 0.0:\n        print(f\"Attention max diff: {(attn - attn_ref).abs().max().item()}\")\n        print(f\"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}\")\n\n    g = torch.randn_like(out)\n    if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)):\n        if kvpacked:\n            (\n                dq_unpad,\n                dkv_unpad,\n            ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)\n            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)\n            (\n                dq_ref,\n                dkv_ref,\n            ) = torch.autograd.grad(out_ref, (q, kv), g)\n            dk_ref, dv_ref = dkv_ref.unbind(2)\n            (\n                dq_pt,\n                dkv_pt,\n            ) = torch.autograd.grad(out_pt, (q, kv), g)\n            dk_pt, dv_pt = dkv_pt.unbind(2)\n        else:\n            (\n                dq_unpad,\n                dk_unpad,\n                dv_unpad,\n            ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)\n            dk = dk_pad_fn(dk_unpad)\n            dv = dk_pad_fn(dv_unpad)\n            (\n                dq_ref,\n                dk_ref,\n                dv_ref,\n            ) = torch.autograd.grad(out_ref, (q, k, v), g)\n            (\n                dq_pt,\n                dk_pt,\n                dv_pt,\n            ) = torch.autograd.grad(out_pt, (q, k, v), g)\n        dq = dq_pad_fn(dq_unpad)\n        print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n        print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n        print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n        
print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n        print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n        print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n        print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n        print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n        print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n\n    if dropout_p > 0.0:\n        # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()\n        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate\n        if not alibi:\n            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)\n\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()\n        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()\n        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"local\", [False])\n# @pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 
128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64, 128])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False])\n# @pytest.mark.parametrize(\"swap_sq_sk\", [True])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (3, 799),\n        (127, 512),\n        (127, 513),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (1023, 1024),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\ndef test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):\n    if USE_TRITON_ROCM:\n        if get_arch().is_rdna:\n            if seqlen_q == 1 and seqlen_k == 239 and d == 256:\n                pytest.skip(\"This config does not work on RDNA devices.\")\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    causal = True\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)\n    out_ref, attn_ref = attention_ref(\n        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size\n    )\n    out_pt, attn_pt = 
attention_ref(\n        q,\n        k,\n        v,\n        None,\n        None,\n        None,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    g = torch.randn_like(out)\n    do_o = (g.float() * out.float()).sum(-1)\n    (\n        dq,\n        dk,\n        dv,\n    ) = torch.autograd.grad(out, (q, k, v), g)\n    (\n        dq_ref,\n        dk_ref,\n        dv_ref,\n    ) = torch.autograd.grad(out_ref, (q, k, v), g)\n    (\n        dq_pt,\n        dk_pt,\n        dv_pt,\n    ) = torch.autograd.grad(out_pt, (q, k, v), g)\n    print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n    print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n    print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n    print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n    print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n    print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n    print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n    print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n    print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n    print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n    print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n    print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - 
out_ref).abs().max().item() + 1e-5\n\n    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5\n    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5\n    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"local\", [False])\n# @pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False])\n# @pytest.mark.parametrize(\"swap_sq_sk\", [True])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (3, 799),\n        (127, 512),\n        (127, 513),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (1023, 1024),\n    ],\n)\n# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged\n@pytest.mark.parametrize(\"paged_kv_block_size\", [None])\n# @pytest.mark.parametrize(\"seqlen_q,seqlen_k\", [(256, 128)])\ndef test_flash_attn_varlen_causal(\n    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype\n):\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    causal = True\n   
 # set seed\n    torch.random.manual_seed(0)\n    batch_size = 8\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n\n    if paged_kv_block_size is None:\n        k = torch.randn(\n            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True\n        )\n        v = torch.randn(\n            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True\n        )\n        block_table = None\n    else:\n        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(\n            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype\n        )\n    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode=\"random\")\n    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode=\"random\")\n    (\n        q_unpad,\n        k_unpad,\n        v_unpad,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        q,\n        k,\n        v,\n        output_pad_fn,\n        dq_pad_fn,\n        dk_pad_fn,\n    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)\n    out_unpad = flash_attn_varlen_func(\n        q_unpad,\n        k_unpad if paged_kv_block_size is None else k_cache_paged,\n        v_unpad if paged_kv_block_size is None else v_cache_paged,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        0.0,\n        causal=causal,\n        window_size=window_size,\n        block_table=block_table,\n    )\n    out = output_pad_fn(out_unpad)\n    out_ref, attn_ref = attention_ref(\n        q,\n        k,\n        v,\n        query_padding_mask,\n        key_padding_mask,\n        None,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n 
   )\n    out_pt, attn_pt = attention_ref(\n        q,\n        k,\n        v,\n        query_padding_mask,\n        key_padding_mask,\n        None,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    g = torch.randn_like(out)\n    do_o = (g.float() * out.float()).sum(-1)\n    test_backward = block_table is None\n    if test_backward:\n        (\n            dq_unpad,\n            dk_unpad,\n            dv_unpad,\n        ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)\n        dq = dq_pad_fn(dq_unpad)\n        dk = dk_pad_fn(dk_unpad)\n        dv = dk_pad_fn(dv_unpad)\n        (\n            dq_ref,\n            dk_ref,\n            dv_ref,\n        ) = torch.autograd.grad(out_ref, (q, k, v), g)\n        (\n            dq_pt,\n            dk_pt,\n            dv_pt,\n        ) = torch.autograd.grad(out_pt, (q, k, v), g)\n        print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n        print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n        print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n        print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n        print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n        print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n        print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n        print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n        print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n        print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n        
print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n        print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5\n\n    if test_backward:\n        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5\n        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5\n        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"deterministic\", [False])\n# @pytest.mark.parametrize(\"deterministic\", [True])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n# @pytest.mark.parametrize(\"alibi\", [True])\n@pytest.mark.parametrize(\"local\", [False])\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False])\n# @pytest.mark.parametrize(\"swap_sq_sk\", [False])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (3, 1024),\n        (1, 339),\n        (64, 800),\n        (3, 799),\n        (64, 2048),\n        (16, 20000),\n        (16, 100000),\n        (128, 128),\n        (256, 
256),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\n@pytest.mark.skip()\ndef test_flash_attn_splitkv(\n    seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype\n):\n    if USE_TRITON_ROCM:\n        if seqlen_q == 1 and seqlen_k == 339 and swap_sq_sk == True:\n            pytest.skip(\"This config is flaky on AMD.\")\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 1\n    nheads = 12\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)\n    else:\n        alibi_slopes, attn_bias = None, None\n    out, lse, _ = flash_attn_func(\n        q,\n        k,\n        v,\n        0.0,\n        causal=causal,\n        window_size=window_size,\n        alibi_slopes=alibi_slopes,\n        deterministic=deterministic,\n        return_attn_probs=True,\n    )\n    out_ref, attn_ref = attention_ref(\n        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size\n    )\n    out_pt, attn_pt = attention_ref(\n        q,\n        k,\n        v,\n        None,\n        None,\n        attn_bias,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n    )\n\n    print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - 
out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    g = torch.randn_like(out)\n    do_o = (g.float() * out.float()).sum(-1)\n    (\n        dq,\n        dk,\n        dv,\n    ) = torch.autograd.grad(out, (q, k, v), g)\n    (\n        dq_ref,\n        dk_ref,\n        dv_ref,\n    ) = torch.autograd.grad(out_ref, (q, k, v), g)\n    (\n        dq_pt,\n        dk_pt,\n        dv_pt,\n    ) = torch.autograd.grad(out_pt, (q, k, v), g)\n    print(f\"dQ max diff: {(dq - dq_ref).abs().max().item()}\")\n    print(f\"dK max diff: {(dk - dk_ref).abs().max().item()}\")\n    print(f\"dV max diff: {(dv - dv_ref).abs().max().item()}\")\n    print(f\"dQ mean diff: {(dq - dq_ref).abs().mean().item()}\")\n    print(f\"dK mean diff: {(dk - dk_ref).abs().mean().item()}\")\n    print(f\"dV mean diff: {(dv - dv_ref).abs().mean().item()}\")\n    print(f\"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}\")\n    print(f\"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}\")\n    print(f\"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}\")\n    print(f\"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}\")\n    print(f\"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}\")\n    print(f\"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}\")\n\n    # Check that FlashAttention's numerical error is at most twice the numerical error\n    # of a Pytorch implementation.\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5\n\n    mult = 2 if not alibi else 8\n    assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4\n    assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4\n    assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4\n\n\n# 
@pytest.mark.parametrize(\"dtype\", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"num_splits\", [1, 0])\n# @pytest.mark.parametrize(\"num_splits\", [1])\n@pytest.mark.parametrize(\"mha_type\", [\"mha\", \"mqa\", \"gqa\"])\n# @pytest.mark.parametrize(\"mha_type\", [\"mha\"])\n@pytest.mark.parametrize(\"new_kv\", [False, True])\n# @pytest.mark.parametrize(\"new_kv\", [False])\n@pytest.mark.parametrize(\"alibi\", [False, True])\n# @pytest.mark.parametrize(\"alibi\", [False])\n@pytest.mark.parametrize(\"local\", [False])\n# @pytest.mark.parametrize(\"local\", [False])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [False])\n@pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [True, False])\n# @pytest.mark.parametrize(\"seqlen_new_eq_seqlen_q\", [True])\n@pytest.mark.parametrize(\"rotary_interleaved\", [False, True])\n# @pytest.mark.parametrize(\"rotary_interleaved\", [False])\n@pytest.mark.parametrize(\"rotary_fraction\", [0.0, 0.5, 1.0])\n# @pytest.mark.parametrize(\"rotary_fraction\", [0.0])\n@pytest.mark.parametrize(\"paged_kv_block_size\", [None, 256])\n# @pytest.mark.parametrize(\"paged_kv_block_size\", [256, 512])\n# @pytest.mark.parametrize(\"paged_kv_block_size\", [None])\n@pytest.mark.parametrize(\"has_leftpad\", [False])\n# @pytest.mark.parametrize(\"has_leftpad\", [True])\n# @pytest.mark.parametrize(\"has_batch_idx\", [False, True])\n@pytest.mark.parametrize(\"has_batch_idx\", [False])\n@pytest.mark.parametrize(\"d\", [32, 59, 64, 80, 128, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 128),\n        (1, 339),\n        (3, 1024),\n        (64, 
800),\n        (64, 256),\n        (3, 799),\n        (64, 2048),\n        (16, 20000),\n        (1, 128 * 1024),\n        (16, 128 * 1024),\n        (128, 128),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\ndef test_flash_attn_kvcache(\n    seqlen_q,\n    seqlen_k,\n    d,\n    has_batch_idx,\n    has_leftpad,\n    paged_kv_block_size,\n    rotary_fraction,\n    rotary_interleaved,\n    seqlen_new_eq_seqlen_q,\n    causal,\n    local,\n    alibi,\n    new_kv,\n    mha_type,\n    num_splits,\n    dtype,\n):\n    if seqlen_q > seqlen_k and new_kv:\n        pytest.skip()\n    if not new_kv and rotary_fraction > 0.0:\n        pytest.skip()\n    if has_batch_idx and paged_kv_block_size is not None:\n        pytest.skip()\n    if has_leftpad and paged_kv_block_size is not None:\n        pytest.skip()\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2\n    nheads = 6\n    # rotary_dim must be a multiple of 16, and must be <= d\n    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16\n    nheads_k = nheads if mha_type == \"mha\" else (1 if mha_type == \"mqa\" else 3)\n    assert nheads % nheads_k == 0\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)\n    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()\n    if new_kv:\n        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)\n        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)\n    else:\n        k, v = None, None\n    if paged_kv_block_size is None:\n        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)\n        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, 
dtype=dtype)\n        block_table = None\n    else:\n        (\n            k_cache,\n            v_cache,\n            block_table,\n            k_cache_paged,\n            v_cache_paged,\n            num_blocks,\n        ) = _generate_block_kvcache(\n            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype\n        )\n    cache_seqlens = torch.randint(\n        0 if new_kv else 1,\n        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough\n        (\n            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)\n            if new_kv\n            else (seqlen_k + 1)\n        ),\n        (batch_size,),\n        dtype=torch.int32,\n        device=device,\n    )\n    if has_leftpad:\n        cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)\n                                   if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)\n                                   for i in range(batch_size)])\n    else:\n        cache_leftpad = None\n    arange = rearrange(torch.arange(seqlen_k, device=device), \"s -> 1 s\")\n    cache_seqlens_expanded = rearrange(cache_seqlens, \"b -> b 1\")\n    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)\n    if has_leftpad:\n        key_padding_mask = torch.logical_and(\n            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)\n        )\n    if has_batch_idx:\n        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[\n            :batch_size\n        ]\n    else:\n        cache_batch_idx = None\n    if alibi:\n        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3\n        attn_bias = attn_bias_from_alibi_slopes(\n            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, 
key_leftpad=cache_leftpad\n        )\n    else:\n        alibi_slopes, attn_bias = None, None\n    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)\n    if rotary_dim > 0:\n        angle = (\n            torch.rand(\n                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,\n                rotary_dim // 2,\n                device=device,\n            )\n            * 2\n            * math.pi\n        )\n        cos = torch.cos(angle).to(dtype=dtype)\n        sin = torch.sin(angle).to(dtype=dtype)\n        if causal or local:\n            q_ro = apply_rotary_emb(\n                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved\n            )\n        else:\n            q_ro = rearrange(\n                apply_rotary_emb(\n                    rearrange(q, \"b s h d -> b 1 (s h) d\"),\n                    cos,\n                    sin,\n                    seqlen_offsets=cache_seqlens,\n                    interleaved=rotary_interleaved,\n                ),\n                \"b 1 (s h) d -> b s h d\",\n                s=seqlen_q,\n            )\n        # q_ro = q\n        k_ro = apply_rotary_emb(\n            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved\n        )\n    else:\n        cos, sin = None, None\n        q_ro, k_ro = q, k\n    # k_cache[:, 64:] = -1\n    k_cache_ref = (\n        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]\n    ).clone()\n    v_cache_ref = (\n        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]\n    ).clone()\n    if new_kv:\n        update_mask = torch.logical_and(\n            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new\n        )\n        k_cache_ref[update_mask] = rearrange(k_ro, \"b s ... -> (b s) ...\")\n        v_cache_ref[update_mask] = rearrange(v, \"b s ... 
-> (b s) ...\")\n    k_cache_rep = repeat(k_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n    v_cache_rep = repeat(v_cache_ref, \"b s h d -> b s (h g) d\", g=nheads // nheads_k)\n    out = flash_attn_with_kvcache(\n        q,\n        k_cache if paged_kv_block_size is None else k_cache_paged,\n        v_cache if paged_kv_block_size is None else v_cache_paged,\n        k,\n        v,\n        rotary_cos=cos,\n        rotary_sin=sin,\n        cache_seqlens=cache_seqlens,\n        cache_batch_idx=cache_batch_idx,\n        cache_leftpad=cache_leftpad,\n        block_table=block_table,\n        causal=causal,\n        window_size=window_size,\n        rotary_interleaved=rotary_interleaved,\n        alibi_slopes=alibi_slopes,\n        num_splits=num_splits,\n    )\n    # out = flash_attn_with_kvcache(\n    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size\n    # )\n    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)\n    # qk = torch.einsum(\"bqhd,bkhd->bhqk\", q, k_cache_ref)\n    # m = qk.amax(-1, keepdim=True)\n    # s_tmp = torch.exp((qk - m) / math.sqrt(d))\n    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)\n    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)\n    # probs = torch.softmax(qk, dim=-1)\n    out_ref, _ = attention_ref(\n        q_ro,\n        k_cache_rep,\n        v_cache_rep,\n        None,\n        key_padding_mask,\n        attn_bias,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        key_leftpad=cache_leftpad,\n    )\n    out_pt, _ = attention_ref(\n        q_ro,\n        k_cache_rep,\n        v_cache_rep,\n        None,\n        key_padding_mask,\n        attn_bias,\n        0.0,\n        None,\n        causal=causal,\n        window_size=window_size,\n        upcast=False,\n        reorder_ops=True,\n        key_leftpad=cache_leftpad,\n    )\n    print(f\"Output max diff: {(out - 
out_ref).abs().max().item()}\")\n    print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n    print(f\"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}\")\n    print(f\"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}\")\n\n    # Check that the KV cache was updated correctly, then that FlashAttention's numerical\n    # error is at most `mult` times (set below) the numerical error of a Pytorch implementation.\n    if new_kv:\n        if paged_kv_block_size is None:\n            k_cache_select = (\n                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]\n            )\n            v_cache_select = (\n                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]\n            )\n        else:\n            k_cache_select = rearrange(\n                k_cache_paged[block_table.to(dtype=torch.long).flatten()],\n                \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n                b=batch_size,\n            )[:, :seqlen_k]\n            v_cache_select = rearrange(\n                v_cache_paged[block_table.to(dtype=torch.long).flatten()],\n                \"(b nblocks) block_size ... 
-> b (nblocks block_size) ...\",\n                b=batch_size,\n            )[:, :seqlen_k]\n        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)\n        assert torch.equal(v_cache_select, v_cache_ref)\n    mult = 3 if not alibi else 5\n    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5\n\n\ndef _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):\n    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3\n    k_cache_paged = torch.randn(\n        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype\n    )\n    v_cache_paged = torch.randn(\n        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype\n    )\n    block_table = rearrange(\n        torch.randperm(num_blocks, dtype=torch.int32, device=device),\n        \"(b nblocks) -> b nblocks\",\n        b=batch_size,\n    )\n    k_cache = rearrange(\n        # pytorch 1.12 doesn't have indexing with int32\n        k_cache_paged[block_table.to(dtype=torch.long).flatten()],\n        \"(b nblocks) block_size ... -> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    v_cache = rearrange(\n        v_cache_paged[block_table.to(dtype=torch.long).flatten()],\n        \"(b nblocks) block_size ... 
-> b (nblocks block_size) ...\",\n        b=batch_size,\n    )[:, :seqlen_k]\n    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks\n\n\n# @pytest.mark.parametrize(\"dtype\", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [128])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (239, 1),\n        (3, 799),\n        (799, 3),\n        (1024, 128),\n        (97, 97),\n        (128, 128),\n        (200, 200),\n        (256, 256),\n        (257, 257),\n        (384, 384),\n        (512, 512),\n        (768, 768),\n        (1024, 1024),\n    ],\n)\n@pytest.mark.parametrize(\"dropout_p\", [0.0, 0.17])\n# @pytest.mark.parametrize(\"dropout_p\", [0.0])\n@pytest.mark.skip()\ndef test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger\n    nheads = 4\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    torch.random.manual_seed(42)\n    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)\n    g = torch.randn_like(out0)\n    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n        (\n         
   dq0,\n            dk0,\n            dv0,\n        ) = torch.autograd.grad(out0, (q, k, v), g)\n        # Numerical error if we just do any arithmetic on dq\n        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()\n\n    for i in range(250):\n        torch.random.manual_seed(42)\n        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)\n        assert torch.equal(out, out0)\n        assert torch.equal(lse, lse0)\n\n        if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):\n            (\n                dq,\n                dk,\n                dv,\n            ) = torch.autograd.grad(out, (q, k, v), g)\n            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)\n            if not dq_equal:\n                print(f\"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}\")\n            assert torch.equal(dv, dv0)\n            assert torch.equal(dk, dk0)\n            assert dq_equal\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [False])\n@pytest.mark.parametrize(\"d\", [16, 32, 64])\n# @pytest.mark.parametrize('d', [16])\n@pytest.mark.parametrize(\"seqlen\", [1, 2, 5, 17, 128])\n# @pytest.mark.parametrize('seqlen', [2])\n@pytest.mark.skip()\ndef test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):\n    \"\"\"We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,\n    in the case where seqlen % 128 != 0.\n    \"\"\"\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    nheads = 5\n    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device=\"cuda\") * 5\n    k, v = [\n        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device=\"cuda\") * 3\n        for _ in range(2)\n    ]\n    q.requires_grad_(True)\n    k.requires_grad_(True)\n    v.requires_grad_(True)\n    out = 
flash_attn_func(q, k, v, causal=causal)\n    g = torch.randn_like(out)\n    out.backward(g)\n    q_pt = q.detach().clone().requires_grad_(True)\n    k_pt = k.detach().clone().requires_grad_(True)\n    v_pt = v.detach().clone().requires_grad_(True)\n    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)\n    out_pt.backward(g)\n    q_ref = q.detach().clone().requires_grad_(True)\n    k_ref = k.detach().clone().requires_grad_(True)\n    v_ref = v.detach().clone().requires_grad_(True)\n    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)\n    out_ref.backward(g)\n    print(f\"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}\")\n    print(f\"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}\")\n    print(f\"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}\")\n    print(f\"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}\")\n    print(f\"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}\")\n    print(f\"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}\")\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n    assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (\n        q_pt.grad - q_ref.grad\n    ).abs().max().item() + 1e-3\n    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (\n        k_pt.grad - k_ref.grad\n    ).abs().max().item() + 1e-3\n    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (\n        v_pt.grad - v_ref.grad\n    ).abs().max().item() + 1e-3\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize('dtype', [torch.bfloat16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [False])\n@pytest.mark.parametrize(\"d\", [64, 128])\n# @pytest.mark.parametrize('d', [64])\n@pytest.mark.parametrize(\"seqlen\", [97, 128, 200, 256])\n# 
@pytest.mark.parametrize('seqlen', [128])\n@pytest.mark.skip()\ndef test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):\n    \"\"\"We previously had a bug where we were using the wrong strides of dout, which shows up\n    when dout is not contiguous.\n    \"\"\"\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 5\n    nheads = 2\n    q, k, v = [\n        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device=\"cuda\", requires_grad=True)\n        for _ in range(3)\n    ]\n    out = rearrange(flash_attn_func(q, k, v, causal=causal), \"b s ... -> s b ...\")\n    # So g is not contiguous\n    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device=\"cuda\")[:, ::2]\n    out.backward(g)\n    q_pt = q.detach().clone().requires_grad_(True)\n    k_pt = k.detach().clone().requires_grad_(True)\n    v_pt = v.detach().clone().requires_grad_(True)\n    out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)\n    out_pt = rearrange(out_pt, \"b s ... -> s b ...\")\n    out_pt.backward(g)\n    q_ref = q.detach().clone().requires_grad_(True)\n    k_ref = k.detach().clone().requires_grad_(True)\n    v_ref = v.detach().clone().requires_grad_(True)\n    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)\n    out_ref = rearrange(out_ref, \"b s ... 
-> s b ...\")\n    out_ref.backward(g)\n    print(f\"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}\")\n    print(f\"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}\")\n    print(f\"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}\")\n    print(f\"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}\")\n    print(f\"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}\")\n    print(f\"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}\")\n    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()\n    assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (\n        q_pt.grad - q_ref.grad\n    ).abs().max().item()\n    assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (\n        k_pt.grad - k_ref.grad\n    ).abs().max().item()\n    assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (\n        v_pt.grad - v_ref.grad\n    ).abs().max().item()\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize('causal', [False])\n@pytest.mark.parametrize(\"d\", [16, 32, 64])\n# @pytest.mark.parametrize('d', [16])\n@pytest.mark.skip()\ndef test_flash_attn_bwd_varlen_overflow(d, causal, dtype):\n    \"\"\"We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,\n    in the case where seqlen % 128 != 0 or varlen.\n    \"\"\"\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    nheads = 5\n    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)\n    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)\n    Mq = 256\n    Mk = 3\n\n    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3\n    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]\n    q.requires_grad_(True)\n    k.requires_grad_(True)\n    v.requires_grad_(True)\n\n    out = 
flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)\n    g = torch.randn_like(out)\n    out.backward(g)\n\n    assert not q.grad.isnan().any()\n    assert not k.grad.isnan().any()\n    assert not v.grad.isnan().any()\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"local\", [False])\n# @pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False])\n# @pytest.mark.parametrize(\"swap_sq_sk\", [False])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (3, 799),\n        (127, 512),\n        (127, 513),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (1023, 1024),\n    ],\n)\n# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])\n@pytest.mark.skip()\ndef test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 4\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, 
(2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)\n\n    g = torch.randn_like(out)\n    dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)\n    for _ in range(50):\n        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)\n        assert torch.equal(dv, dv0)\n        assert torch.equal(dk, dk0)\n        assert torch.equal(dq, dq0)\n\n\n@pytest.mark.parametrize(\"dtype\", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))\n# @pytest.mark.parametrize(\"dtype\", [torch.bfloat16])\n@pytest.mark.parametrize(\"local\", [False])\n# @pytest.mark.parametrize(\"local\", [True])\n@pytest.mark.parametrize(\"causal\", [False, True])\n# @pytest.mark.parametrize(\"causal\", [True])\n@pytest.mark.parametrize(\"d\", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize(\"d\", [32, 64, 96, 128, 160, 192, 224, 256])\n# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])\n# @pytest.mark.parametrize('d', [56, 80])\n# @pytest.mark.parametrize(\"d\", [64])\n@pytest.mark.parametrize(\"swap_sq_sk\", [False])\n# @pytest.mark.parametrize(\"swap_sq_sk\", [True])\n@pytest.mark.parametrize(\n    \"seqlen_q,seqlen_k\",\n    [\n        (1, 239),\n        (3, 799),\n        (127, 512),\n        (127, 513),\n        (113, 203),\n        (128, 217),\n        (113, 211),\n        (108, 256),\n        (256, 512),\n        (1023, 1024),\n    ],\n)\n# @pytest.mark.parametrize(\"seqlen_q,seqlen_k\", [(256, 128)])\n@pytest.mark.skip()\ndef 
test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):\n    if (\n        max(seqlen_q, seqlen_k) >= 2048\n        and torch.cuda.get_device_properties(\"cuda\").total_memory <= 16 * 2**30\n    ):\n        pytest.skip()  # Reference implementation OOM\n    if swap_sq_sk:\n        seqlen_q, seqlen_k = seqlen_k, seqlen_q\n    device = \"cuda\"\n    # set seed\n    torch.random.manual_seed(0)\n    batch_size = 2\n    nheads = 9\n    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))\n    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)\n    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode=\"random\")\n    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode=\"random\")\n    (\n        q_unpad,\n        k_unpad,\n        v_unpad,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        q,\n        k,\n        v,\n        output_pad_fn,\n        dq_pad_fn,\n        dk_pad_fn,\n    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)\n    out = flash_attn_varlen_func(\n        q_unpad,\n        k_unpad,\n        v_unpad,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        0.0,\n        causal=causal,\n        window_size=window_size,\n        deterministic=True,\n    )\n\n    g = torch.randn_like(out)\n    dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)\n    for _ in range(50):\n        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)\n        assert torch.equal(dv, dv0)\n        assert 
torch.equal(dk, dk0)\n        assert torch.equal(dq, dq0)\n"
  },
  {
    "path": "tests/test_rotary.py",
    "content": "import math\nimport random\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom einops import rearrange\n\nimport triton\n\nfrom flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch\nfrom flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_\nfrom flash_attn.bert_padding import pad_input, unpad_input\n\nis_sm8x = torch.cuda.get_device_capability(\"cuda\") >= (8, 0)\n\n\ndef generate_cos_sin(seqlen, rotary_dim, device, dtype):\n    assert rotary_dim % 2 == 0\n    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi\n    cos = torch.cos(angle).to(dtype=dtype)\n    sin = torch.sin(angle).to(dtype=dtype)\n    return cos, sin\n\n\ndef generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device):\n    if seqlen_offsets_type == 0:\n        return 0\n    elif seqlen_offsets_type is int:\n        return torch.randint(0, seqlen + 1, (1,)).item()\n    elif seqlen_offsets_type is torch.Tensor:\n        return torch.randint(0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device)\n\n\ndef index_cos_sin(cos, sin, seqlen_offsets, seqlen):\n    if isinstance(seqlen_offsets, torch.Tensor):\n        batch_size = seqlen_offsets.shape[0]\n        arange = rearrange(torch.arange(seqlen, device=cos.device), \"s -> 1 s\")\n        idx = rearrange(seqlen_offsets, \"b -> b 1\") + arange\n        cos_pt = rearrange(cos[idx.flatten()], \"(b s) d -> b s d\", b=batch_size)\n        sin_pt = rearrange(sin[idx.flatten()], \"(b s) d -> b s d\", b=batch_size)\n    else:\n        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]\n        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]\n    return cos_pt, sin_pt\n\n\n@pytest.mark.parametrize(\n    \"dtype\", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])\n)\n# @pytest.mark.parametrize('dtype', ([torch.bfloat16]))\n@pytest.mark.parametrize(\"seqlen_offsets_type\", [0, int, torch.Tensor])\n# 
@pytest.mark.parametrize(\"seqlen_offsets_type\", [0])\n@pytest.mark.parametrize(\"rotary_fraction\", [1.0, 0.5])\n# @pytest.mark.parametrize('rotary_fraction', [1.0])\n@pytest.mark.parametrize(\"interleaved\", [False, True])\n# @pytest.mark.parametrize('interleaved', [True])\n@pytest.mark.parametrize(\"inplace\", [False, True])\n# @pytest.mark.parametrize('inplace', [False])\ndef test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):\n    rtol = 1e-3\n    batch_size = 32\n    nheads = 4\n    seqlen = 217\n    headdim = 128\n    device = \"cuda\"\n    rotary_dim = int(rotary_fraction * headdim)\n    torch.manual_seed(42)\n    x = torch.randn(\n        batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True\n    )\n    x_pt = x.detach().clone().requires_grad_()\n    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)\n    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)\n    out = apply_rotary_emb(\n        x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace\n    )\n    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)\n    out_pt = apply_rotary_emb_torch(\n        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved\n    ).to(dtype=dtype)\n    print(f\"Output max diff: {(out - out_pt).abs().max().item()}\")\n\n    g = torch.randn_like(out)\n    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace\n    out.backward(g)\n    out_pt.backward(g_pt)\n    print(f\"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}\")\n\n    if not inplace:\n        assert torch.equal(x, x_pt)\n    # Numerical error if we just do any arithmetic\n    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()\n    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)\n    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()\n    assert torch.allclose(x.grad, x_pt.grad, 
rtol=rtol, atol=2 * atol)\n\n\n@pytest.mark.parametrize(\n    \"dtype\", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])\n)\n# @pytest.mark.parametrize('dtype', ([torch.float16]))\n@pytest.mark.parametrize(\"compiled\", [False, True])\n# @pytest.mark.parametrize(\"compiled\", [True])\n@pytest.mark.parametrize(\"gqa\", [False, True])\n# @pytest.mark.parametrize(\"gqa\", [False])\n@pytest.mark.parametrize(\"seqlen_offsets_type\", [0, int, torch.Tensor])\n# @pytest.mark.parametrize(\"seqlen_offsets_type\", [0])\n@pytest.mark.parametrize(\"rotary_fraction\", [1.0, 0.5])\n# @pytest.mark.parametrize('rotary_fraction', [1.0])\n@pytest.mark.parametrize(\"interleaved\", [False, True])\n# @pytest.mark.parametrize('interleaved', [False])\ndef test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, gqa, compiled, dtype):\n    if compiled:  # Don't fall back to eager just bc of recompilation\n        torch._dynamo.config.recompile_limit = 2 ** 31\n    rtol = 1e-3\n    batch_size = 32\n    nheads = 4\n    seqlen = 512\n    headdim = 128\n    device = \"cuda\"\n    rotary_dim = int(rotary_fraction * headdim)\n    torch.manual_seed(42)\n    if not gqa:\n        qkv = torch.randn(\n            batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True\n        )\n    else:\n        nheads_k = nheads // 2\n        qkv = torch.randn(\n            batch_size, seqlen, nheads + nheads_k * 2, headdim, dtype=dtype, device=device, requires_grad=True\n        )\n    qkv_pt = qkv.detach().clone().requires_grad_()\n    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)\n    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)\n    fn = apply_rotary_emb_qkv_ if not compiled else torch.compile(apply_rotary_emb_qkv_)\n    out = fn(\n        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved,\n        num_heads_q=None if not gqa else nheads\n    )\n    cos_pt, 
sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)\n    if not gqa:\n        q_pt, k_pt, v_pt = qkv_pt.unbind(2)\n    else:\n        q_pt, k_pt, v_pt = qkv_pt.split([nheads, nheads_k, nheads_k], dim=2)\n    q_pt = apply_rotary_emb_torch(\n        q_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved\n    ).to(dtype=dtype)\n    k_pt = apply_rotary_emb_torch(\n        k_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved\n    ).to(dtype=dtype)\n    if not gqa:\n        out_pt = torch.stack([q_pt, k_pt, v_pt], dim=2)\n    else:\n        out_pt = torch.cat([q_pt, k_pt, v_pt], dim=2)\n    print(f\"Output max diff: {(out - out_pt).abs().max().item()}\")\n\n    g = torch.randn_like(out)\n    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace\n    out.backward(g)\n    out_pt.backward(g_pt)\n    print(f\"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}\")\n\n    # Numerical error if we just do any arithmetic\n    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()\n    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)\n    atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()\n    assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)\n\n\n@pytest.mark.parametrize(\n    \"dtype\", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])\n)\n# @pytest.mark.parametrize('dtype', ([torch.float16]))\n@pytest.mark.parametrize(\"seqlen_offsets_type\", [0, int, torch.Tensor])\n# @pytest.mark.parametrize(\"seqlen_offsets_type\", [0])\n@pytest.mark.parametrize(\"rotary_fraction\", [1.0, 0.5])\n# @pytest.mark.parametrize('rotary_fraction', [1.0])\n@pytest.mark.parametrize(\"interleaved\", [False, True])\n# @pytest.mark.parametrize('interleaved', [False])\ndef test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):\n    rtol = 1e-3\n    batch_size = 32\n    nheads = 4\n    seqlen = 781\n    headdim = 64\n    device = \"cuda\"\n    
rotary_dim = int(rotary_fraction * headdim)\n    torch.manual_seed(42)\n    kv = torch.randn(\n        batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True\n    )\n    kv_pt = kv.detach().clone().requires_grad_()\n    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)\n    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)\n    out = apply_rotary_emb_kv_(kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved)\n    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)\n    k_pt = apply_rotary_emb_torch(\n        kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved\n    ).to(dtype=dtype)\n    out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)\n    print(f\"Output max diff: {(out - out_pt).abs().max().item()}\")\n\n    g = torch.randn_like(out)\n    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace\n    out.backward(g)\n    out_pt.backward(g_pt)\n    print(f\"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}\")\n\n    # Numerical error if we just do any arithmetic\n    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()\n    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)\n    atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()\n    assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)\n\n\n@pytest.mark.parametrize(\n    \"dtype\", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])\n)\n# @pytest.mark.parametrize(\"dtype\", ([torch.float16]))\n@pytest.mark.parametrize(\"seqlen_offsets_type\", [0, int, torch.Tensor])\n# @pytest.mark.parametrize(\"seqlen_offsets_type\", [0])\n@pytest.mark.parametrize(\"rotary_fraction\", [1.0, 0.5])\n# @pytest.mark.parametrize(\"rotary_fraction\", [1.0])\n@pytest.mark.parametrize(\"interleaved\", [False, True])\n# @pytest.mark.parametrize(\"interleaved\", [True])\n@pytest.mark.parametrize(\"inplace\", 
[False, True])\n# @pytest.mark.parametrize(\"inplace\", [False])\ndef test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):\n    rtol = 1e-3\n    batch_size = 32\n    nheads = 4\n    seqlen = 217\n    headdim = 128\n    device = \"cuda\"\n    rotary_dim = int(rotary_fraction * headdim)\n    torch.manual_seed(42)\n    x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)\n    x_pt = x.detach().clone().requires_grad_()\n    lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device)\n    padding_mask = rearrange(torch.arange(seqlen, device=device), \"s -> 1 s\") < lengths\n    x_unpad, indices, cu_seqlens, max_seqlen, _ = unpad_input(x, padding_mask)\n    x_unpad_clone = x_unpad.clone()\n    x_unpad = x_unpad.requires_grad_()\n    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)\n    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)\n    out_unpad = apply_rotary_emb(\n        x_unpad,\n        cos,\n        sin,\n        seqlen_offsets=seqlen_offsets,\n        interleaved=interleaved,\n        inplace=inplace,\n        cu_seqlens=cu_seqlens,\n        max_seqlen=max_seqlen,\n    )\n    out = pad_input(out_unpad, indices, batch_size, seqlen)\n    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)\n    out_pt = apply_rotary_emb_torch(\n        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved\n    ).to(dtype=dtype)\n    out_pt = out_pt.masked_fill(rearrange(~padding_mask, \"b s -> b s 1 1\"), 0.0)\n    print(f\"Output max diff: {(out - out_pt).abs().max().item()}\")\n\n    g = torch.randn_like(out)\n    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace\n    out.backward(g)\n    out_pt.backward(g_pt)\n    x_grad = pad_input(x_unpad.grad, indices, batch_size, seqlen)\n    print(f\"Grad max diff: {(x_grad - x_pt.grad).abs().max().item()}\")\n\n    if not inplace:\n   
     assert torch.equal(x_unpad, x_unpad_clone)\n    # Numerical error if we just do any arithmetic\n    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()\n    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)\n    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()\n    assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)\n\n\ndef test_compilation_count():\n    nheads = 4\n    headdim = 128\n    device = \"cuda\"\n    dtype = torch.float16\n    torch.manual_seed(42)\n\n    from triton.runtime.jit import JITFunction\n    from flash_attn.ops.triton.rotary import rotary_kernel\n    compilation_count = 0\n\n    def count_compilations(*args, **kwargs):\n        nonlocal compilation_count\n        compilation_count += 1\n\n    old_cache_func = JITFunction.cache_hook\n\n    try:\n        if hasattr(rotary_kernel, \"cache\"):\n            rotary_kernel.cache.clear()\n        else:  # Triton 3.3 replaces cache with per-device device_caches\n            device = triton.runtime.driver.active.get_current_device()\n            # device_caches[device] returns a 4-tuple: (kernel_cache, target, backend, binder)\n            rotary_kernel.device_caches[device][0].clear()\n\n        JITFunction.cache_hook = count_compilations\n\n        for seqlen in (128, 256):\n            for batch_size in (4, 32):\n                x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)\n                x.requires_grad_()\n                cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)\n                out = apply_rotary_emb(x, cos, sin)\n                out.backward(torch.randn_like(out))\n\n        # Only two kernels are expected to be compiled:\n        #   * for the forward pass (conjugate=False)\n        #   * for the backward pass (conjugate=True)\n        assert compilation_count == 2\n    finally:\n        JITFunction.cache_hook = old_cache_func\n"
  },
  {
    "path": "tests/test_util.py",
    "content": "import math\n\nimport torch\nfrom einops import rearrange, repeat\nfrom flash_attn.bert_padding import pad_input, unpad_input\n\n\ndef generate_random_padding_mask(max_seqlen, batch_size, device, mode=\"random\", zero_lengths=False):\n    assert mode in [\"full\", \"random\", \"third\"]\n    if mode == \"full\":\n        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)\n    elif mode == \"random\":\n        lengths = torch.randint(\n            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device\n        )\n    elif mode == \"third\":\n        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)\n\n    if zero_lengths:\n        # Generate zero-lengths every 5 batches and the last batch.\n        for i in range(batch_size):\n            if i % 5 == 0:\n                lengths[i] = 0\n        lengths[-1] = 0\n    padding_mask = (\n        repeat(torch.arange(max_seqlen, device=device), \"s -> b s\", b=batch_size) < lengths\n    )\n    return padding_mask\n\n\ndef generate_qkv(\n    q, k, v, query_padding_mask=None, key_padding_mask=None, \n    kvpacked=False, qkvpacked=False, add_unused_qkv=False,\n    query_unused_mask=None, key_unused_mask=None,\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, d)\n        k: (batch_size, seqlen_k, nheads_k, d)\n        v: (batch_size, seqlen_k, nheads_k, d)\n        query_padding_mask: (batch_size, seqlen), bool\n        key_padding_mask: (batch_size, seqlen), bool\n    \"\"\"\n    assert not (kvpacked and qkvpacked)\n    batch_size, seqlen_q, nheads, d = q.shape\n    _, seqlen_k, nheads_k, _ = k.shape\n    assert k.shape == (batch_size, seqlen_k, nheads_k, d)\n    assert v.shape == (batch_size, seqlen_k, nheads_k, d)\n    if query_unused_mask is not None or key_unused_mask is not None:\n        assert not kvpacked\n        assert not qkvpacked\n\n    if query_padding_mask is 
not None:\n        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(\n            q, query_padding_mask, query_unused_mask,\n        )\n        output_pad_fn = lambda output_unpad: pad_input(\n            output_unpad, indices_q, batch_size, seqlen_q\n        )\n    else:\n        q_unpad = rearrange(q, \"b s h d -> (b s) h d\")\n        cu_seqlens_q = torch.arange(\n            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device\n        )\n        seqused_q = None\n        max_seqlen_q = seqlen_q\n        output_pad_fn = lambda output_unpad: rearrange(\n            output_unpad, \"(b s) h d -> b s h d\", b=batch_size\n        )\n\n    if key_padding_mask is not None:\n        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(k, key_padding_mask, key_unused_mask)\n        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)\n    else:\n        k_unpad = rearrange(k, \"b s h d -> (b s) h d\")\n        v_unpad = rearrange(v, \"b s h d -> (b s) h d\")\n        cu_seqlens_k = torch.arange(\n            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device\n        )\n        seqused_k = None\n        max_seqlen_k = seqlen_k\n\n    if qkvpacked:\n        # The two masks must match. Compare elementwise only when both are given:\n        # `None == None` is a plain bool, so calling .all() on it would raise.\n        if query_padding_mask is not None and key_padding_mask is not None:\n            assert (query_padding_mask == key_padding_mask).all()\n        else:\n            assert query_padding_mask is None and key_padding_mask is None\n        assert nheads == nheads_k\n        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)\n        qkv = torch.stack([q, k, v], dim=2)\n        if query_padding_mask is not None:\n            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)\n        else:\n            dqkv_pad_fn = lambda dqkv_unpad: rearrange(\n                dqkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n            qkv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            max_seqlen_q,\n            qkv.detach().requires_grad_(),\n            
output_pad_fn,\n            dqkv_pad_fn,\n        )\n    elif kvpacked:\n        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)\n        kv = torch.stack([k, v], dim=2)\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)\n        else:\n            dkv_pad_fn = lambda dkv_unpad: rearrange(\n                dkv_unpad, \"(b s) t h d -> b s t h d\", b=batch_size\n            )\n        return (\n            q_unpad.detach().requires_grad_(),\n            kv_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            kv.detach().requires_grad_(),\n            output_pad_fn,\n            dq_pad_fn,\n            dkv_pad_fn,\n        )\n    else:\n        dq_pad_fn = output_pad_fn\n        if key_padding_mask is not None:\n            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)\n        else:\n            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, \"(b s) h d -> b s h d\", b=batch_size)\n        return (\n            q_unpad.detach().requires_grad_(),\n            k_unpad.detach().requires_grad_(),\n            v_unpad.detach().requires_grad_(),\n            cu_seqlens_q,\n            cu_seqlens_k,\n            seqused_q,\n            seqused_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            q.detach().requires_grad_(),\n            k.detach().requires_grad_(),\n            v.detach().requires_grad_(),\n            output_pad_fn,\n            dq_pad_fn,\n            dk_pad_fn,\n        )\n\n\ndef construct_local_mask(\n    seqlen_q,\n    seqlen_k,\n    window_size=(-1, -1),  # -1 means infinite window size\n    query_padding_mask=None,\n    key_padding_mask=None,\n    device=None,\n    key_leftpad=None,\n):\n    row_idx = 
rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), \"s -> s 1\")\n    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)\n    if key_leftpad is not None:\n        key_leftpad = rearrange(key_leftpad, \"b -> b 1 1 1\")\n        col_idx = repeat(col_idx, \"s -> b 1 1 s\", b=key_leftpad.shape[0])\n        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)\n    sk = (\n        seqlen_k\n        if key_padding_mask is None\n        else rearrange(key_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    sq = (\n        seqlen_q\n        if query_padding_mask is None\n        else rearrange(query_padding_mask.sum(-1), \"b -> b 1 1 1\")\n    )\n    if window_size[0] < 0:\n        return col_idx > row_idx + sk - sq + window_size[1]\n    else:\n        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk\n        return torch.logical_or(\n            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),\n            col_idx < row_idx + sk - sq - window_size[0],\n        )\n\n\ndef attention_ref(\n    q,\n    k,\n    v,\n    query_padding_mask=None,\n    key_padding_mask=None,\n    attn_bias=None,\n    dropout_p=0.0,\n    dropout_mask=None,\n    causal=False,\n    window_size=(-1, -1),  # -1 means infinite window size\n    softcap=0.0,\n    upcast=True,\n    reorder_ops=False,\n    key_leftpad=None,\n):\n    \"\"\"\n    Arguments:\n        q: (batch_size, seqlen_q, nheads, head_dim)\n        k: (batch_size, seqlen_k, nheads_k, head_dim)\n        v: (batch_size, seqlen_k, nheads_k, head_dim)\n        query_padding_mask: (batch_size, seqlen_q)\n        key_padding_mask: (batch_size, seqlen_k)\n        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)\n        dropout_p: float\n        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)\n        causal: whether to apply causal masking\n        window_size: (int, int), left and right window size\n        softcap: float. If > 0, scores are replaced by softcap * tanh(scores / softcap);\n            0.0 disables soft-capping.\n        key_leftpad: (batch_size,), number of left-padded key positions per batch element,\n            or None for no left padding.\n        upcast: whether 
to cast all inputs to fp32, do all computation in fp32, then cast\n            output back to fp16/bf16.\n        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)\n            without changing the math. This is to estimate the numerical error from operation\n            reordering.\n    Output:\n        output: (batch_size, seqlen_q, nheads, head_dim)\n        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout\n    \"\"\"\n    if causal:\n        window_size = (window_size[0], 0)\n    dtype_og = q.dtype\n    if upcast:\n        q, k, v = q.float(), k.float(), v.float()\n    seqlen_q, seqlen_k = q.shape[1], k.shape[1]\n    k = repeat(k, \"b s h d -> b s (h g) d\", g=q.shape[2] // k.shape[2])\n    v = repeat(v, \"b s h d -> b s (h g) d\", g=q.shape[2] // v.shape[2])\n    d = q.shape[-1]\n    if not reorder_ops:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q / math.sqrt(d), k)\n    else:\n        scores = torch.einsum(\"bthd,bshd->bhts\", q, k / math.sqrt(d))\n    if softcap > 0:\n        scores /= softcap\n        scores = scores.tanh()\n        scores *= softcap\n    if key_padding_mask is not None:\n        scores.masked_fill_(rearrange(~key_padding_mask, \"b s -> b 1 1 s\"), float(\"-inf\"))\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        local_mask = construct_local_mask(\n            seqlen_q,\n            seqlen_k,\n            window_size,\n            query_padding_mask,\n            key_padding_mask,\n            q.device,\n            key_leftpad=key_leftpad,\n        )\n        scores.masked_fill_(local_mask, float(\"-inf\"))\n    if attn_bias is not None:\n        scores = scores + attn_bias\n    attention = torch.softmax(scores, dim=-1).to(v.dtype)\n    # Some rows might be completely masked out so we fill them with zero instead of NaN\n    if window_size[0] >= 0 or window_size[1] >= 0:\n        attention = attention.masked_fill(torch.all(local_mask, dim=-1, 
keepdim=True), 0.0)\n    # We want to mask here so that the attention matrix doesn't have any NaNs\n    # Otherwise we'll get NaN in dV\n    if query_padding_mask is not None:\n        attention = attention.masked_fill(rearrange(~query_padding_mask, \"b s -> b 1 s 1\"), 0.0)\n    dropout_scaling = 1.0 / (1 - dropout_p)\n    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling\n    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)\n    if dropout_mask is not None:\n        attention_drop = attention.masked_fill(~dropout_mask, 0.0)\n    else:\n        attention_drop = attention\n    output = torch.einsum(\"bhts,bshd->bthd\", attention_drop, v * dropout_scaling)\n    if query_padding_mask is not None:\n        output.masked_fill_(rearrange(~query_padding_mask, \"b s -> b s 1 1\"), 0.0)\n    if key_padding_mask is not None:\n        output.masked_fill_(rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), \"b -> b 1 1 1\"), 0.0)\n    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)\n"
  },
  {
    "path": "tools/sass_diff.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Compare two SASS files, ignoring register assignments and addresses.\n\nNormalizes registers per-instruction so that two instructions doing the\nsame operation with different register allocations compare as equal.\nE.g. \"UIADD3 UR30, UP1, UR30, 0x70, URZ\" and\n     \"UIADD3 UR14, UP1, UR38, 0x70, URZ\" both normalize to\n     \"UIADD3 UR_0, UP_0, UR_1, 0x70, URZ\"\n\nUsage:\n    python scripts/sass_diff.py file_a.sass file_b.sass\n    python scripts/sass_diff.py file_a.sass file_b.sass --context 5\n    python scripts/sass_diff.py file_a.sass file_b.sass --all       # include metadata\n    python scripts/sass_diff.py file_a.sass file_b.sass --summary-only\n\"\"\"\n\nimport argparse\nimport re\nimport sys\nfrom dataclasses import dataclass, field\nfrom difflib import SequenceMatcher\n\n\n# ── Parsing ──────────────────────────────────────────────────────────────────\n\nADDR_LINE_RE = re.compile(r\"^\\s+/\\*([0-9a-f]+)\\*/\\s+(.*?)\\s*;?\\s*$\")\nLABEL_RE = re.compile(r\"^(\\.L_x_\\d+):\\s*$\")\nMETADATA_PREFIXES = (\".byte\", \".word\", \".short\", \".dword\", \".string\", \".align\")\n\n# Register pattern: match UR before R, UP before P\nREG_RE = re.compile(r\"\\b(UP|UR|P|R)(\\d+)\\b\")\n\n\n@dataclass\nclass Line:\n    \"\"\"One parsed SASS line.\"\"\"\n    addr: str           # hex address or \"\" for labels\n    raw: str            # original text (no addr prefix)\n    normalized: str     # register-normalized for comparison\n    lineno: int         # 1-based line number in file\n    is_code: bool       # True for instructions/labels\n\n\ndef _normalize_instr(text: str) -> str:\n    \"\"\"Normalize one instruction by replacing registers with positional IDs.\n\n    Each register class (R, UR, P, UP) gets its own counter, reset per\n    instruction. 
Constants RZ, URZ, PT, UPT are preserved.\n    \"\"\"\n    counters: dict[str, int] = {}\n    mapping: dict[str, str] = {}\n\n    def repl(m: re.Match) -> str:\n        name = m.group(0)\n        if name in (\"RZ\", \"URZ\", \"PT\", \"UPT\"):\n            return name\n        if name in mapping:\n            return mapping[name]\n        prefix = m.group(1)\n        idx = counters.get(prefix, 0)\n        counters[prefix] = idx + 1\n        mapping[name] = f\"{prefix}_{idx}\"\n        return mapping[name]\n\n    return REG_RE.sub(repl, text)\n\n\ndef parse_sass(path: str) -> list[Line]:\n    \"\"\"Extract instruction, label, and metadata lines from a SASS file.\"\"\"\n    lines: list[Line] = []\n\n    with open(path) as f:\n        for lineno, raw in enumerate(f, 1):\n            raw = raw.rstrip()\n\n            m = LABEL_RE.match(raw)\n            if m:\n                label = m.group(1)\n                lines.append(Line(\"\", label, label, lineno, True))\n                continue\n\n            m = ADDR_LINE_RE.match(raw)\n            if m:\n                addr, text = m.group(1), m.group(2).strip()\n                is_meta = any(text.startswith(p) for p in METADATA_PREFIXES)\n                normalized = text if is_meta else _normalize_instr(text)\n                lines.append(Line(addr, text, normalized, lineno, not is_meta))\n\n    return lines\n\n\n# ── Diffing ──────────────────────────────────────────────────────────────────\n\n@dataclass\nclass DiffBlock:\n    tag: str  # \"equal\", \"replace\", \"insert\", \"delete\"\n    a_lines: list[Line] = field(default_factory=list)\n    b_lines: list[Line] = field(default_factory=list)\n\n\ndef diff_sass(a_lines: list[Line], b_lines: list[Line]) -> list[DiffBlock]:\n    a_norm = [l.normalized for l in a_lines]\n    b_norm = [l.normalized for l in b_lines]\n    sm = SequenceMatcher(None, a_norm, b_norm, autojunk=False)\n    blocks: list[DiffBlock] = []\n    for tag, i1, i2, j1, j2 in sm.get_opcodes():\n        
blocks.append(DiffBlock(tag, a_lines[i1:i2], b_lines[j1:j2]))\n    return blocks\n\n\n# ── Display ──────────────────────────────────────────────────────────────────\n\nRED = \"\\033[31m\"\nGREEN = \"\\033[32m\"\nCYAN = \"\\033[36m\"\nDIM = \"\\033[2m\"\nRESET = \"\\033[0m\"\n\n\ndef _fmt(line: Line, prefix: str, color: str, use_color: bool, show_norm: bool) -> str:\n    addr = f\"[{line.addr}]\" if line.addr else \"       \"\n    text = line.normalized if show_norm else line.raw\n    if use_color:\n        return f\"{color}{prefix} {addr:>8s}  {text}{RESET}\"\n    return f\"{prefix} {addr:>8s}  {text}\"\n\n\ndef print_diff(blocks: list[DiffBlock], context: int = 3,\n               use_color: bool = True, show_norm: bool = False):\n    \"\"\"Unified-diff-style output with context.\"\"\"\n    groups: list[list[str]] = []\n    cur: list[str] = []\n    last_changed = False\n\n    for block in blocks:\n        if block.tag == \"equal\":\n            lines = block.a_lines\n            if last_changed:\n                for l in lines[:context]:\n                    cur.append(_fmt(l, \" \", DIM, use_color, show_norm))\n                if len(lines) > 2 * context:\n                    if cur:\n                        groups.append(cur)\n                    cur = []\n                    for l in lines[-context:]:\n                        cur.append(_fmt(l, \" \", DIM, use_color, show_norm))\n                elif len(lines) > context:\n                    for l in lines[context:]:\n                        cur.append(_fmt(l, \" \", DIM, use_color, show_norm))\n            else:\n                for l in lines[-context:]:\n                    cur.append(_fmt(l, \" \", DIM, use_color, show_norm))\n            last_changed = False\n        else:\n            last_changed = True\n            if block.tag in (\"replace\", \"delete\"):\n                for l in block.a_lines:\n                    cur.append(_fmt(l, \"-\", RED, use_color, show_norm))\n            if block.tag in 
(\"replace\", \"insert\"):\n                for l in block.b_lines:\n                    cur.append(_fmt(l, \"+\", GREEN, use_color, show_norm))\n\n    if cur:\n        groups.append(cur)\n\n    sep = f\"{CYAN}{'─' * 72}{RESET}\" if use_color else \"─\" * 72\n    for i, g in enumerate(groups):\n        if i > 0:\n            print(sep)\n        for line in g:\n            print(line)\n\n\ndef _get_opcode(raw: str) -> str | None:\n    \"\"\"Extract opcode from instruction, skipping predicates and labels.\"\"\"\n    for p in raw.split():\n        if p.startswith(\"@\") or p.startswith(\".L_\"):\n            continue\n        return p\n    return None\n\n\ndef print_summary(a_all: list[Line], b_all: list[Line], blocks: list[DiffBlock]):\n    a_code = [l for l in a_all if l.is_code]\n    b_code = [l for l in b_all if l.is_code]\n\n    n_equal = sum(len(b.a_lines) for b in blocks if b.tag == \"equal\")\n    n_delete = sum(len(b.a_lines) for b in blocks if b.tag in (\"replace\", \"delete\"))\n    n_insert = sum(len(b.b_lines) for b in blocks if b.tag in (\"replace\", \"insert\"))\n    n_changed = sum(1 for b in blocks if b.tag != \"equal\")\n\n    print(f\"  File A: {len(a_code)} instructions\")\n    print(f\"  File B: {len(b_code)} instructions\")\n    print(f\"  Identical (normalized): {n_equal}\")\n    print(f\"  Changed regions: {n_changed}\")\n    print(f\"  Removed: {n_delete}, Added: {n_insert}\")\n\n    def opcode_counts(lines):\n        counts: dict[str, int] = {}\n        for l in lines:\n            op = _get_opcode(l.raw)\n            if op:\n                counts[op] = counts.get(op, 0) + 1\n        return counts\n\n    a_ops, b_ops = opcode_counts(a_code), opcode_counts(b_code)\n    all_ops = sorted(set(a_ops) | set(b_ops))\n    diffs = {op: b_ops.get(op, 0) - a_ops.get(op, 0) for op in all_ops}\n    diffs = {op: d for op, d in diffs.items() if d != 0}\n    if diffs:\n        print(\"\\n  Opcode count changes (B - A):\")\n        for op, d in 
sorted(diffs.items(), key=lambda x: -abs(x[1])):\n            sign = \"+\" if d > 0 else \"\"\n            print(f\"    {op:30s} {sign}{d}\")\n    else:\n        print(\"\\n  Opcode counts: identical\")\n\n\n# ── Main ─────────────────────────────────────────────────────────────────────\n\ndef main():\n    p = argparse.ArgumentParser(description=\"Compare SASS files ignoring register assignments\")\n    p.add_argument(\"file_a\", help=\"First SASS file\")\n    p.add_argument(\"file_b\", help=\"Second SASS file\")\n    p.add_argument(\"-C\", \"--context\", type=int, default=3, help=\"Context lines (default: 3)\")\n    p.add_argument(\"--no-color\", action=\"store_true\", help=\"Disable color output\")\n    p.add_argument(\"--summary-only\", action=\"store_true\", help=\"Only print summary\")\n    p.add_argument(\"--all\", action=\"store_true\", help=\"Include metadata in diff\")\n    p.add_argument(\"--show-normalized\", action=\"store_true\",\n                   help=\"Show normalized form instead of raw instructions\")\n    args = p.parse_args()\n\n    a_all = parse_sass(args.file_a)\n    b_all = parse_sass(args.file_b)\n\n    if args.all:\n        a_lines, b_lines = a_all, b_all\n    else:\n        a_lines = [l for l in a_all if l.is_code]\n        b_lines = [l for l in b_all if l.is_code]\n\n    blocks = diff_sass(a_lines, b_lines)\n    use_color = not args.no_color and sys.stdout.isatty()\n\n    print(\"=== Summary ===\")\n    print_summary(a_all, b_all, blocks)\n\n    if not args.summary_only:\n        print(\"\\n=== Diff (registers normalized) ===\\n\")\n        print_diff(blocks, args.context, use_color, args.show_normalized)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "training/Dockerfile",
    "content": "# Inspired by https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile\n# ARG COMPAT=0\nARG PERSONAL=0\n# FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0\nFROM nvcr.io/nvidia/pytorch:22.12-py3 as base\n\nENV HOST docker\nENV LANG=C.UTF-8 LC_ALL=C.UTF-8\n# https://serverfault.com/questions/683605/docker-container-time-timezone-will-not-reflect-changes\nENV TZ America/Los_Angeles\nRUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone\n\n# git for installing dependencies\n# tzdata to set time zone\n# wget and unzip to download data\n# [2021-09-09] TD: zsh, stow, subversion, fasd are for setting up my personal environment.\n# [2021-12-07] TD: openmpi-bin for MPI (multi-node training)\nRUN apt-get update && apt-get install -y --no-install-recommends \\\n    build-essential \\\n    cmake \\\n    curl \\\n    ca-certificates \\\n    sudo \\\n    less \\\n    htop \\\n    git \\\n    tzdata \\\n    wget \\\n    tmux \\\n    zip \\\n    unzip \\\n    zsh stow subversion fasd \\\n    && rm -rf /var/lib/apt/lists/*\n    # openmpi-bin \\\n\n# Allow running runmpi as root\n# ENV OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1\n\n# # Create a non-root user and switch to it\n# RUN adduser --disabled-password --gecos '' --shell /bin/bash user \\\n#     && echo \"user ALL=(ALL) NOPASSWD:ALL\" > /etc/sudoers.d/90-user\n# USER user\n\n# All users can use /home/user as their home directory\nENV HOME=/home/user\nRUN mkdir -p /home/user && chmod 777 /home/user\nWORKDIR /home/user\n\n# Set up personal environment\n# FROM base-${COMPAT} as env-0\nFROM base as env-0\nFROM env-0 as env-1\n# Use ONBUILD so that the dotfiles dir doesn't need to exist unless we're building a personal image\n# https://stackoverflow.com/questions/31528384/conditional-copy-add-in-dockerfile\nONBUILD COPY dotfiles ./dotfiles\nONBUILD RUN cd ~/dotfiles && stow bash zsh tmux && sudo chsh -s /usr/bin/zsh 
$(whoami)\n# nvcr pytorch image sets SHELL=/bin/bash\nONBUILD ENV SHELL=/bin/zsh\n\nFROM env-${PERSONAL} as packages\n\n# Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for\nENV PIP_NO_CACHE_DIR=1\n\n# # apex and pytorch-fast-transformers take a while to compile so we install them first\n# TD [2022-04-28] apex is already installed. In case we need a newer commit:\n# RUN pip install --upgrade --force-reinstall --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" --global-option=\"--fast_multihead_attn\" --global-option=\"--fmha\" --global-option=\"--fast_layer_norm\" --global-option=\"--xentropy\" git+https://github.com/NVIDIA/apex.git#egg=apex\n\n# xgboost conflicts with deepspeed\nRUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.7\n\n# General packages that we don't care about the version\n# zstandard to extract the_pile dataset\n# psutil to get the number of cpu physical cores\n# twine to upload package to PyPI\nRUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine gdown \\\n    && python -m spacy download en_core_web_sm\n# hydra\nRUN pip install hydra-core==1.3.1 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich\n# Core packages\nRUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 triton==2.0.0.dev20221202 wandb==0.13.7 timm==0.6.12 torchmetrics==0.10.3\n# torchmetrics 0.11.0 broke hydra's instantiate\n\n# For MLPerf\nRUN pip install git+https://github.com/mlcommons/logging.git@2.1.0\n\n# Install FlashAttention\nRUN pip install flash-attn==2.6.3\n\n# Install CUDA extensions for fused dense\nRUN pip install git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib\n"
  },
  {
    "path": "training/README.md",
    "content": "# Optimized Transformer implementation\nThis repo contains examples of how FlashAttention can be integrated into a model\n(e.g., GPT, ViT) and trained end-to-end. We also provide optimized\nimplementations of other layers (e.g., MLP, LayerNorm, cross-entropy loss,\nrotary embedding). Overall this speeds up training by 3-5x compared to the\nbaseline implementation from Huggingface, reaching up to 189 TFLOPs/sec per A100,\nequivalent to 60.6\\% model FLOPs utilization (we don't need any activation\ncheckpointing). All without changing the model architecture (i.e., no\napproximation).\n\nGoals:\n- Performance: we optimize for model speed and memory, especially on 1-node\n  (e.g., with 8 A100s).\n- Flexibility: we provide optimized building blocks (MLP, attention, LayerNorm),\n  and the model code illustrates how these components can be put together.\n  The training code also aims to be model- & task-agnostic.\n\nNon-goals (and other resources):\n- Support as many models as possible: Huggingface's\n  [transformers](https://github.com/huggingface/transformers) and\n  [timm](https://github.com/rwightman/pytorch-image-models/) are great for this.\n- Large-scale distributed training: our codebase has been used for multi-GPU and multi-node\n  training for models up to 2.7B parameters. 
However, if you're looking for large-scale distributed\n  training techniques (e.g., pipeline parallelism, tensor parallelism),\n  check out [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/) and\n  [DeepSpeed](https://github.com/microsoft/deepspeed).\n- Inference: we currently focus on training (this might change in the future).\n  If you want fast inference, take a look at\n  [FasterTransformer](https://github.com/NVIDIA/FasterTransformer).\n- Production: this codebase was written during several research projects to validate ideas\n  on speeding up ML models.\n\n## Model Components\n\nThe GPT model is implemented\n[here](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).\nAnd here's an example to construct the GPT3-1.3B model with rotary embedding:\n```python\nfrom transformers.models.gpt2.configuration_gpt2 import GPT2Config\nfrom flash_attn.models.gpt import GPTLMHeadModel\n\nseqlen = 2048\nhidden_dim = 2048\nnheads = 16\nn_layer = 24\nrotary_emb_fraction = 0.5\nconfig = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,\n                    n_layer=n_layer, n_head=nheads, \n                    scale_attn_by_inverse_layer_idx=True, \n                    rotary_emb_fraction=rotary_emb_fraction,\n                    use_flash_attn=True, fused_mlp=True,\n                    fused_bias_fc=True, fused_dropout_add_ln=True, \n                    pad_vocab_size_multiple=8)\nmodel = GPTLMHeadModel(config)\n```\n\nWe provide the following optimized components:\n\n1. FlashAttention: fast and memory-efficient exact attention. This makes\nattention much faster and saves a lot of activation memory. As a result we don't need\nto use any activation checkpointing.\n```sh\npip install flash-attn\n```\n\n2. Fused matmul + bias (forward and backward), and fused matmul + bias + gelu\n(forward and backward), adapted from Apex's\n[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). 
We\nmake it work for bfloat16. For best performance, you should use CUDA >= 11.8. cuBLAS versions before\nthis don't have the best matmul + bias + gelu performance for bfloat16.\n```sh\ncd ../csrc/fused_dense_lib && pip install .\n```\n3. Optimized cross-entropy loss, adapted from Apex's\n[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). We make it work for bfloat16 and support in-place backward to save memory.\n```sh\ncd ../csrc/xentropy && pip install .\n```\n4. Fused rotary embedding:\n```sh\ncd ../csrc/rotary && pip install .\n```\n5. Fused dropout + residual + LayerNorm, adapted from Apex's\n[FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architectures.\nThis supports dimensions divisible by 8, up to 6144.\n```sh\ncd ../csrc/layer_norm && pip install .\n```\n\n## Training\n\nWe also provide training scripts to train GPT2 on Openwebtext and GPT3 on\nThe Pile as examples. 
Feel free to use the model in your own training setup as\nwell.\n\nWe use [Hydra](https://hydra.cc/) for configuration,\n[Pytorch-Lightning](https://github.com/Lightning-AI/lightning) for training, and\n[Wandb](https://wandb.ai/) for logging.\n\nWe use the template from `https://github.com/ashleve/lightning-hydra-template`.\nPlease read the instructions there to understand the repo structure.\n\n### Requirements\n\nPython 3.8+, Pytorch 1.12+, torchvision, einops, timm, hydra-core,\nhydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn.\nWe recommend CUDA 11.8 (e.g., using Nvidia's Pytorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch).\n\nWe provide a Dockerfile that lists all the required packages.\n\n### Dataset preparation\n\nRunning the training command automatically downloads the datasets\n(Openwebtext, Pile), tokenizes them with the GPT2 tokenizer, concatenates all the\ntokens, then saves this cache to disk. Alternatively, you can prepare the\ndatasets as a separate step.\n\nThe cached datasets are saved to `${DATA_DIR}/openwebtext` and\n`${DATA_DIR}/the_pile`. If `${DATA_DIR}` is not set, they will be saved to\n`./data/{openwebtext,the_pile}`.\n\n- Openwebtext:\n```sh\nexport PYTHONPATH=$PWD:$PYTHONPATH\npytest -q -s tests/datamodules/test_language_modeling_hf.py -k \"openwebtext\"\n```\nThis takes around 1h on a 64-core CPU. The processed dataset has size 17GB.\n\n- The Pile:\n```sh\nexport PYTHONPATH=$PWD:$PYTHONPATH\npytest -q -s tests/datamodules/test_language_modeling_hf.py -k \"pile\"\n```\nThis takes around 20h on a 64-core CPU. 
The processed dataset has size 699GB.\n\n### GPT2 training on Openwebtext\nTo train GPT2 on Openwebtext with 8 GPUs:\n```sh\npython run.py experiment=owt/gpt2s-flash trainer.devices=8  # 125M\npython run.py experiment=owt/gpt2m-flash trainer.devices=8  # 355M\npython run.py experiment=owt/gpt2l-flash trainer.devices=8  # 760M\npython run.py experiment=owt/gpt2xl-flash trainer.devices=8  # 1.6B\n```\nThe default parameters are set for 8 x A100 80GB.\n\nTo train with bf16 instead of fp16, add `trainer.precision=bf16`.\n\n### GPT3 training on The Pile\nTo train GPT3 on The Pile with 8 GPUs:\n```sh\npython run.py experiment=pile/gpt3s-flash trainer.devices=8  # 125M\npython run.py experiment=pile/gpt3m-flash trainer.devices=8  # 355M\npython run.py experiment=pile/gpt3l-flash trainer.devices=8  # 760M\npython run.py experiment=pile/gpt3xl-flash trainer.devices=8  # 1.3B\npython run.py experiment=pile/gpt3-2.7B-flash-hdim128 trainer.devices=8  # 2.7B\n```\nThe default parameters are set for 8 x A100 80GB. We train with bf16 by default.\n\nTo train with rotary embedding, run the experiments `pile/gpt3{s,m,l,xl}-flash-rotary`.\n\n### Training options\n\n**Gradient accumulation**: to adjust device batch size to fit into GPU memory\n(the global batch size stays the same, and gradient accumulation is calculated\nautomatically), set `datamodule.batch_size=N`, where `N` is the per-GPU batch size.\n\n**Multi-node**: to train on multiple nodes, add `trainer.num_nodes=N`, where `N` is the number of nodes.\n\n**Speed benchmarking**: to print out iteration time, add `+callbacks.speed_monitor.verbose=True`.\n\n**Resumable training**: give the run a name, and then set `resume=True` when\nyou resume. 
Training will resume from exactly the same batch.\n```sh\npython run.py experiment=pile/gpt3s-flash trainer.devices=8 name=pile-gpt3s-flash resume=True\n```\n\n## Training speed\n\nWe measure the wallclock training speed on one node with 8 x A100 SXM4 80GB (400W) with NVLink.\n\nFLOPs are calculated using the formula from the [Megatron-LM\npaper](https://arxiv.org/abs/2104.04473) (Section 5.1), except we scale by 3/4\nto get the model FLOPs (instead of hardware FLOPs with activation\ncheckpointing).\n\n\n### GPT2 (sequence length 1024)\n\n![GPT2 speedup](../assets/gpt2_training_efficiency.jpg)\n\nThe implementation in this repo (FlashAttention) is 3-4x faster than the\nbaseline implementation from Huggingface.\n\n### GPT3 (sequence length 2048)\n\n![GPT3 speedup](../assets/gpt3_training_efficiency.jpg)\n\nThe implementation in this repo (FlashAttention) is 3-5x faster than the\nbaseline implementation from Huggingface.\n\nFor the GPT3-2.7B model, we set head dimension to 128 (instead of 80) for better efficiency.\n\nWe include here more details on the training speed with FlashAttention on 8 x\nA100 80GB.\n\n| Model     | Batch size (tokens) | Throughput (tokens/sec)  | Hours / 1B tokens |\n| --------- | ------------------- | ------------------------ | ----------------- |\n| GPT3-125M | 0.5M                | 1310k                    |              0.21 |\n| GPT3-355M | 0.5M                | 503k                     |              0.55 |\n| GPT3-760M | 0.5M                | 245k                     |              1.13 |\n| GPT3-1.3B | 1M                  | 169k                     |              1.64 |\n| GPT3-2.7B | 1M                  | 85k                      |              3.27 |\n\nAs an example, this means that one can train a GPT3-1.3B model on 26B tokens\n(compute-optimal according to Chinchilla scaling) in about 43 hours on 8 x A100\n(26 x 1.64 hours).\n\n## Training quality\n\nWe include here the loss curve for GPT2 on Openwebtext, trained for 200B tokens.\nFor GPT2, the 
runs with FlashAttention yield the same loss curve as the runs\nwith the baseline implementation from Huggingface for 125M and 355M models. For\nlarger models the baseline implementation just takes too long.\n\n![GPT2 training curve](../assets/gpt2_training_curve.jpg)\n\nWe include here the loss curve for GPT3 on The Pile, trained for 400B tokens.\nThe 125M, 355M, 760M models have batch size 512k tokens so this translates to\n800k training steps, while the 1.3B and 2.7B models have batch size 1M tokens,\nwhich translates to 400k training steps.\n\n![GPT3 training curve](../assets/gpt3_training_curve.jpg)\n"
  },
  {
    "path": "training/configs/callbacks/causality-monitor.yaml",
    "content": "causality-monitor:\n  _target_: src.callbacks.causality_monitor.CausalityMonitor"
  },
  {
    "path": "training/configs/callbacks/default.yaml",
    "content": "# rich_progress_bar:\n#   _target_: pytorch_lightning.callbacks.RichProgressBar\n\nrich_model_summary:\n  _target_: pytorch_lightning.callbacks.RichModelSummary\n\nmodel_checkpoint:\n  _target_: pytorch_lightning.callbacks.ModelCheckpoint\n  monitor: \"val/acc\" # name of the logged metric which determines when model is improving\n  mode: \"max\" # can be \"max\" or \"min\"\n  save_top_k: 1 # save k best models (determined by above metric)\n  save_last: True # additionally always save model from last epoch\n  verbose: False\n  dirpath: ${oc.env:CHECKPOINT_DIR,checkpoints}/${oc.select:name,''}\n  filename: \"epoch_{epoch:03d}\"\n  auto_insert_metric_name: False\n\nearly_stopping:\n  _target_: pytorch_lightning.callbacks.EarlyStopping\n  monitor: \"val/acc\" # name of the logged metric which determines when model is improving\n  mode: \"max\" # can be \"max\" or \"min\"\n  patience: 100 # how many epochs of not improving until training stops\n  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement\n\nlearning_rate_monitor:\n  _target_: pytorch_lightning.callbacks.LearningRateMonitor\n  logging_interval: step\n\nspeed_monitor:\n  _target_: src.callbacks.speed_monitor.SpeedMonitor\n  intra_step_time: True\n  inter_step_time: True\n  epoch_time: True\n\nloss_scale_monitor:\n  _target_: src.callbacks.loss_scale_monitor.LossScaleMonitor\n\nparams_log:\n  _target_: src.callbacks.params_log.ParamsLog\n  total_params_log: True\n  trainable_params_log: True\n  non_trainable_params_log: True\n\ngpu_affinity:\n  _target_: src.callbacks.gpu_affinity.GpuAffinity\n"
  },
  {
    "path": "training/configs/callbacks/ema.yaml",
    "content": "ema:\n  _target_: src.callbacks.ema.EMACallback\n  decay: ???\n  use_num_updates: False\n"
  },
  {
    "path": "training/configs/callbacks/flop-count.yaml",
    "content": "flop_count:\n  _target_: src.callbacks.flop_count.FlopCount\n  profilers: ['fvcore']\n  input_size: [3, 224, 224]\n  device: null\n"
  },
  {
    "path": "training/configs/callbacks/gpu-monitor.yaml",
    "content": "defaults:\n  - default.yaml\n\ngpu_stats_monitor:\n  _target_: pytorch_lightning.callbacks.GPUStatsMonitor\n  # [2021-08-13] TD: I just want the intra_step_size but it'll error if I\n  # don't have memory_utilization and gpu_utilization.\n  # Maybe I should write a callback with just the intra_step_size.\n  memory_utilization: True\n  gpu_utilization: True\n  intra_step_time: True\n"
  },
  {
    "path": "training/configs/callbacks/model-summary.yaml",
    "content": "model_summary:\n  _target_: pytorch_lightning.callbacks.RichModelSummary\n"
  },
  {
    "path": "training/configs/callbacks/none.yaml",
    "content": ""
  },
  {
    "path": "training/configs/callbacks/norm-monitor.yaml",
    "content": "norm_monitor:\n  _target_: src.callbacks.norm_monitor.NormMonitor\n"
  },
  {
    "path": "training/configs/callbacks/params-log.yaml",
    "content": "params_log:\n  _target_: src.callbacks.params_log.ParamsLog\n  total_params_log: True\n  trainable_params_log: True\n  non_trainable_params_log: True\n"
  },
  {
    "path": "training/configs/callbacks/wandb.yaml",
    "content": "defaults:\n  - default.yaml\n\nwatch_model:\n  _target_: src.callbacks.wandb_callbacks.WatchModel\n  log: \"all\"\n  log_freq: 100\n\nupload_code_as_artifact:\n  _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact\n  code_dir: ${work_dir}/src\n\nupload_ckpts_as_artifact:\n  _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n  ckpt_dir: \"checkpoints/\"\n  upload_best_only: True\n\nlog_f1_precision_recall_heatmap:\n  _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap\n\nlog_confusion_matrix:\n  _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix\n\nlog_image_predictions:\n  _target_: src.callbacks.wandb_callbacks.LogImagePredictions\n  num_samples: 8\n"
  },
  {
    "path": "training/configs/config.yaml",
    "content": "# @package _global_\n\n# specify here default training configuration\ndefaults:\n  - _self_\n  - trainer: default\n  - optimizer: adamw\n  - scheduler: null\n  - task: sequence-model\n  - model: null\n  - datamodule: null\n  - callbacks: default # set this to null if you don't want to use callbacks\n  - metrics: null\n  - logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`)\n\n  - mode: default\n\n  - experiment: null\n  - hparams_search: null\n\n  # enable color logging\n  - override hydra/hydra_logging: colorlog\n  - override hydra/job_logging: colorlog\n\n# path to original working directory\n# hydra hijacks working directory by changing it to the current log directory,\n# so it's useful to have this path as a special variable\n# https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory\nwork_dir: ${hydra:runtime.cwd}\n\n# path to folder with data\ndata_dir: ${work_dir}/data/\n\n# pretty print config at the start of the run using Rich library\nprint_config: True\n\n# disable python warnings if they annoy you\nignore_warnings: True\n\n# check performance on test set, using the best model achieved during training\n# lightning chooses best model based on metric specified in checkpoint callback\ntest_after_training: True\n\nresume: False\n\n# seed for random number generators in pytorch, numpy and python.random\nseed: null\n\n# name of the run, accessed by loggers\nname: null\n"
  },
  {
    "path": "training/configs/datamodule/openwebtext.yaml",
    "content": "_target_: src.datamodules.language_modeling_hf.LMDataModule\ndataset_name: openwebtext\ndataset_config_name: null\ntokenizer_name: gpt2\ncache_dir: ${oc.env:DATA_DIR,${data_dir}}/openwebtext/cache\nmax_length: 1024\nval_ratio: 0.0005\nval_split_seed: 2357\nadd_eos: True\nbatch_size: 8  # per GPU\nbatch_size_eval: ${eval:${.batch_size} * 2}\nnum_workers: 32  # For preprocessing only\nshuffle: True\npin_memory: True\n__train_len: ${div_up:9035582198, ${.max_length}}\n"
  },
  {
    "path": "training/configs/datamodule/thepile.yaml",
    "content": "_target_: src.datamodules.language_modeling_hf.LMDataModule\ndataset_name: the_pile\ndataset_config_name: null\ntokenizer_name: gpt2\ncache_dir: ${oc.env:DATA_DIR,${data_dir}}/the_pile/cache\nmax_length: 2048\nadd_eos: True\nbatch_size: 4  # per GPU\nbatch_size_eval: ${eval:${.batch_size} * 2}\nnum_workers: 64  # For preprocessing only\nuse_shmem: False\nshuffle: True\npin_memory: True\n__train_len: ${div_up:374337375694, ${.max_length}}\n"
  },
  {
    "path": "training/configs/experiment/owt/base.yaml",
    "content": "# @package _global_\ndefaults:\n  - override /trainer: default # choose trainer from 'configs/trainer/'\n  - override /model: null\n  - override /datamodule: openwebtext\n  # FusedAdam from apex speeds up the optimizer step a bit, for GPT2-small time\n  # per global step (i.e. batch size 512) on 8 A100s goes from 376ms to 368ms.\n  # For GPT2-medium time per global goes from 997ms to 972ms.\n  - override /optimizer: adamw-apex\n  - override /scheduler: linear-warmup\n  - override /callbacks: [default, norm-monitor]\n  - override /metrics: [perplexity, num-tokens]\n  - override /logger: wandb\n\n# all parameters below will be merged with parameters from default configurations set above\n# this allows you to overwrite only specified parameters\n\ntask:\n  _target_: src.tasks.seq.SequenceLMModel\n\nseed: 1111\n\ntrainer:\n  accelerator: gpu\n  devices: 8\n  num_nodes: 1\n  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}\n  max_steps: 400000\n  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}\n  check_val_every_n_epoch: null  # We don't care about epoch boundary\n  precision: 16\n  gradient_clip_val: 1.0\n  strategy: null\n\ndatamodule:\n  batch_size: 16  # Per GPU\n  batch_size_eval: ${.batch_size}  # Fused dense only support batch size at most 64k\n  max_length: 1024\n  fault_tolerant: True\n  ddp: ${eval:\"${trainer.devices} > 1\"}\n\ntrain:\n  gpu_mem: ${eval:\"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)\"}\n  global_batch_size: 512\n  optimizer:\n    lr: 6e-4\n    weight_decay: 0.1\n  optimizer_param_grouping:\n    bias_weight_decay: False\n    normalization_weight_decay: False\n  scheduler:\n    num_warmup_steps: ${eval:0.01 * ${trainer.max_steps}}\n    num_training_steps: ${trainer.max_steps}\n  loss_fn:\n    # This is faster 
and uses less memory than torch.nn.CrossEntropyLoss.\n    # It's also more numerically stable if we're using DeepSpeed 16 bits.\n    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss\n    inplace_backward: True  # to save memory\n\neval:\n  log_on_step: True  # 1 training epoch takes too long, we want to see metrics per train step\n\ncallbacks:\n  model_checkpoint:\n    monitor: val/loss\n    mode: min\n    save_top_k: 3\n    save_last: True\n    every_n_train_steps: 1000\n    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}\n    filename: step_{step}\n    auto_insert_metric_name: False\n  model_checkpoint_progress:\n    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine\n    fault_tolerant: True\n    every_n_train_steps: 50000\n    save_last: False\n    save_top_k: -1  # Save all the checkpoints\n    dirpath: ${..model_checkpoint.dirpath}\n    filename: progress_step_{step}\n    auto_insert_metric_name: False\n  early_stopping: null\n"
  },
  {
    "path": "training/configs/experiment/owt/gpt2l-flash.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/owt/gpt2m-flash.yaml\n  - override /model/gpt2model: gpt2-large\n  # TD [2022-08-03] Surprisingly it's faster to use the ZeRO optimizer than just AdamW.\n  # Still, fairscale is even faster and uses less memory.\n  # I think it's because Pytorch is using ZeRO stage 1 and fairscale is using ZeRO stage 2?\n  # However, fairscale has issues with saving checkpoint (either OOM or very\n  # slow since it goes through the CPU?). Fairscale says Pytorch ZeRO is the\n  # upstream version of OSS\n  # https://github.com/facebookresearch/fairscale/issues/937\n  # Pytorch ZeRO as also very slow for saving checkpoints due to\n  # consolidate_state_dict(), but I've fixed it to save separate checkpoint per GPU.\n  - override /optimizer: adamw-zero\n\n  # FusedAdam doesn't seem to speed things up here, time per global step\n  # (i.e. batch size 512) on 8 A100s is around 2056ms for both AdamW and FusedAdam.\n  # This could be because each GPU is only doing the optimizer step for 1 /\n  # world_size of the parameters.\n  # Maybe the bottleneck here is the NCCL call to exchange parameters (ZeRO).\n  # - override /optimizer: adamw-apex-zero\n\n# Can enable mlp_chekcpoint_lvl to fit batch_size 16 on A100 40GB\n# model:\n#   config:\n#     # mlp_checkpoint_lvl: ${eval:\"[1] * 18 + [2] * 18\"}\n#     mlp_checkpoint_lvl: 1\n\ndatamodule:\n  # batch_size: 16\n  batch_size: ${eval:\"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))\"}\n\ntrainer:\n  # strategy: null\n  # strategy: ${eval:\"None if ${trainer.devices} == 1 else 'ddp_sharded'\"}\n  strategy:\n    _target_: src.utils.ddp_zero1.DDPStrategyZero1\n    find_unused_parameters: False\n    gradient_as_bucket_view: True\n  # TD [2022-08-03] Deepspeed makes the ppl curve go wild\n  # strategy: deepspeed_stage_1\n"
  },
  {
    "path": "training/configs/experiment/owt/gpt2l-hf.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/owt/gpt2m-hf.yaml\n  - override /model/gpt2model: gpt2-large\n  - override /optimizer: adamw-zero\n\ndatamodule:\n  batch_size: 2\n\ntrainer:\n  strategy:\n    _target_: src.utils.ddp_zero1.DDPStrategyZero1\n    find_unused_parameters: False\n    gradient_as_bucket_view: True\n"
  },
  {
    "path": "training/configs/experiment/owt/gpt2l.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/owt/gpt2m.yaml\n  - override /model/gpt2model: gpt2-large\n  - override /optimizer: adamw-zero\n\ndatamodule:\n  batch_size: 4  # Per GPU\n\ntrainer:\n  strategy:\n    _target_: src.utils.ddp_zero1.DDPStrategyZero1\n    find_unused_parameters: False\n    gradient_as_bucket_view: True\n"
  },
  {
    "path": "training/configs/experiment/owt/gpt2m-flash.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/owt/gpt2s-flash.yaml\n  - override /model/gpt2model: gpt2-medium\n\n# Can enable mlp_checkpoint_lvl to fit batch_size 32 to A100 40GB\n# model:\n#   config:\n#     mlp_checkpoint_lvl: 1\n\ndatamodule:\n  # batch_size: 32\n  batch_size: ${eval:\"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else (32 if ${train.gpu_mem} < 80 else 64))\"}\n\ntrain:\n  optimizer:\n    lr: 1.5e-4\n"
  },
  {
    "path": "training/configs/experiment/owt/gpt2m-hf.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/owt/gpt2s-hf.yaml\n  - override /model/gpt2model: gpt2-medium\n\ndatamodule:\n  batch_size: 4\n\ntrain:\n  optimizer:\n    lr: 1.5e-4\n"
  },
  {
    "path": "training/configs/experiment/owt/gpt2m.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/owt/gpt2s.yaml\n  - override /model/gpt2model: gpt2-medium\n\ndatamodule:\n  batch_size: 8  # Per GPU\n\ntrain:\n  optimizer:\n    lr: 1.5e-4\n"
  },
  {
    "path": "training/configs/experiment/owt/gpt2s-flash.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/owt/base.yaml\n  - override /model: gpt2\n  - override /model/gpt2model: gpt2-small\n\nmodel:\n  config:\n    # n_positions is already set to ${datamodule.max_length}\n    residual_in_fp32: True\n    use_flash_attn: True\n    fused_bias_fc: True\n    fused_mlp: True\n    fused_dropout_add_ln: True\n    pad_vocab_size_multiple: 8\n\ndatamodule:\n  # batch_size: 64\n  batch_size: ${eval:\"16 if ${train.gpu_mem} < 24 else (32 if ${train.gpu_mem} < 40 else 64)\"}\n"
  },
  {
    "path": "training/configs/experiment/owt/gpt2s-hf.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/owt/base.yaml\n  - override /model: gpt2-hf\n  - override /model/gpt2model: gpt2-small\n  - override /callbacks: [default, norm-monitor, flop-count]\n\ndatamodule:\n  batch_size: 8\n\ntrain:\n  # Use the standard torch.nn.CrossEntropyLoss\n  loss_fn: null\n\ncallbacks:\n  flop_count:\n    input_size:\n      - ${datamodule.max_length}\n    input_dtype:\n      # It's surprisingly hard to get hydra to return torch.long since it's not a callable\n      _target_: torch.__getattribute__\n      _args_:\n        - long\n"
  },
  {
    "path": "training/configs/experiment/owt/gpt2s.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/owt/base.yaml\n  - override /model: gpt2\n  - override /model/gpt2model: gpt2-small\n\ndatamodule:\n  batch_size: ${eval:\"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)\"}\n"
  },
  {
    "path": "training/configs/experiment/owt/gpt2xl-flash.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/owt/gpt2l-flash.yaml\n  - override /model/gpt2model: gpt2-xlarge\n\n# Can enable mlp_checkpoint_lvl to fit to A100 40GB\n# model:\n#   config:\n#     # mlp_checkpoint_lvl: ${eval:\"[1] * 18 + [2] * 18\"}\n#     mlp_checkpoint_lvl: 1\n\ndatamodule:\n  batch_size: ${eval:\"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))\"}\n  # With adamw-zero optimizer, on A100 40GB:\n  # checkpoint_lvl=1, batch size = 4: mem 37GB, 4650ms / batch of 512 (285ms * 15 + 375ms * 1)\n  # checkpoint_lvl=1, batch size = 8: mem 46GB, 4330ms / batch of 512 (530ms * 7 + 620ms * 1)\n  # checkpoint_lvl=2, batch size = 8: mem 41GB, 4570ms / batch of 512 (560ms * 7 + 650ms * 1)\n  # With adamw-apex-distributed optimizer:\n  # checkpoint_lvl=1, batch size = 8: mem 41.5GB, 4500ms / batch of 512 (550ms * 7 + 650ms * 1)\n  # checkpoint_lvl=1 for 24 layers and checkpoint_lvl=2 for 24 layers,\n  # batch size = 8: mem 39GB, 4640ms / batch of 512 (565ms * 7 + 675ms * 1)\n"
  },
  {
    "path": "training/configs/experiment/owt/gpt2xl-hf.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/owt/gpt2l-hf.yaml\n  - override /model/gpt2model: gpt2-xlarge\n\ndatamodule:\n  batch_size: 1\n"
  },
  {
    "path": "training/configs/experiment/owt/gpt2xl.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/owt/gpt2m.yaml\n  - override /model/gpt2model: gpt2-xlarge\n  - override /optimizer: adamw-zero\n\ndatamodule:\n  batch_size: 2  # Per GPU\n\ntrainer:\n  strategy:\n    _target_: src.utils.ddp_zero1.DDPStrategyZero1\n    find_unused_parameters: False\n    gradient_as_bucket_view: True\n"
  },
  {
    "path": "training/configs/experiment/pile/base.yaml",
    "content": "# @package _global_\ndefaults:\n  - override /trainer: default # choose trainer from 'configs/trainer/'\n  - override /model: null\n  - override /datamodule: thepile\n  - override /optimizer: adamw-apex  # slight speedup (1-2%) over Pytorch AdamW\n  - override /scheduler: cosine-warmup-timm\n  - override /callbacks: [default, norm-monitor]\n  - override /metrics: [perplexity, num-tokens]\n  - override /logger: wandb\n\n# all parameters below will be merged with parameters from default configurations set above\n# this allows you to overwrite only specified parameters\n\ntask:\n  _target_: src.tasks.seq.SequenceLMModel\n\nseed: 1111\n\ntrainer:\n  accelerator: gpu\n  devices: 8\n  num_nodes: 1\n  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}\n  max_steps: 800000\n  val_check_interval: ${eval:2000 * ${.accumulate_grad_batches}}\n  check_val_every_n_epoch: null  # We don't care about epoch boundary\n  precision: bf16\n  gradient_clip_val: 1.0\n  strategy: null\n\ndatamodule:\n  batch_size: 16  # Per GPU\n  batch_size_eval: ${.batch_size}  # Fused dense only support batch size at most 64k\n  max_length: 2048\n  fault_tolerant: True\n  ddp: ${eval:\"${trainer.devices} > 1\"}\n\ntrain:\n  gpu_mem: ${eval:\"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)\"}\n  global_batch_size: 256\n  optimizer:\n    lr: 6e-4\n    weight_decay: 0.1\n  optimizer_param_grouping:\n    bias_weight_decay: False\n    normalization_weight_decay: False\n  scheduler:\n    t_in_epochs: False\n    t_initial: 600000\n    warmup_lr_init: 1e-6\n    warmup_t: ${eval:0.01 * ${trainer.max_steps}}\n    lr_min: ${eval:0.1 * ${train.optimizer.lr}}\n  loss_fn:\n    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.\n    # It's also more numerically stable if we're using 
DeepSpeed 16 bits.\n    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss\n    inplace_backward: True  # to save memory\n\neval:\n  log_on_step: True  # 1 training epoch takes too long, we want to see metrics per train step\n\ncallbacks:\n  model_checkpoint:\n    monitor: val/loss\n    mode: min\n    save_top_k: 3\n    save_last: True\n    every_n_train_steps: 1000\n    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}\n    filename: step_{step}\n    auto_insert_metric_name: False\n  model_checkpoint_progress:\n    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine\n    # fault_tolerant: True  # The .pl_auto_save.ckpt doesn't get saved by all workers\n    every_n_train_steps: 50000\n    save_last: False\n    save_top_k: -1  # Save all the checkpoints\n    dirpath: ${..model_checkpoint.dirpath}\n    filename: progress_step_{step}\n    auto_insert_metric_name: False\n  early_stopping: null\n\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3-2.7B-flash-8k.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-flash-8k.yaml\n\nmodel:\n  config:\n    n_embd: 2560\n    n_head: 32\n    n_layer: 32\n    initializer_range: ${eval:\"(2 / (${.n_embd} * 5)) ** 0.5\"}\n    mlp_checkpoint_lvl: 0\n\ndatamodule:\n  batch_size: ${eval:\"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)\"}\n\ntrain:\n  optimizer:\n    lr: 1.6e-4\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary-8k.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-flash-rotary-8k.yaml\n\nmodel:\n  config:\n    n_embd: 2560\n    n_head: 20\n    n_layer: 32\n    initializer_range: ${eval:\"(2 / (${.n_embd} * 5)) ** 0.5\"}\n    mlp_checkpoint_lvl: 0\n\ndatamodule:\n  batch_size: ${eval:\"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))\"}\n\ntrain:\n  optimizer:\n    lr: 1.6e-4\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-flash-rotary.yaml\n\nmodel:\n  config:\n    n_embd: 2560\n    n_head: 20\n    n_layer: 32\n    initializer_range: ${eval:\"(2 / (${.n_embd} * 5)) ** 0.5\"}\n    mlp_checkpoint_lvl: 0\n\ndatamodule:\n  batch_size: ${eval:\"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))\"}\n\ntrain:\n  optimizer:\n    lr: 1.6e-4\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3-2.7B-flash-hdim128.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-flash.yaml\n\nmodel:\n  config:\n    n_embd: 2560\n    n_head: 20  # Headdim 128 is faster than headdim 80\n    n_layer: 32\n    initializer_range: ${eval:\"(2 / (${.n_embd} * 5)) ** 0.5\"}\n    mlp_checkpoint_lvl: 0\n\ndatamodule:\n  batch_size: ${eval:\"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)\"}\n\ntrain:\n  optimizer:\n    lr: 1.6e-4\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3-2.7B-flash-rotary-8k.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-flash-rotary-8k.yaml\n\nmodel:\n  config:\n    n_embd: 2560\n    n_head: 32\n    n_layer: 32\n    initializer_range: ${eval:\"(2 / (${.n_embd} * 5)) ** 0.5\"}\n    mlp_checkpoint_lvl: 0\n\ndatamodule:\n  batch_size: ${eval:\"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))\"}\n\ntrain:\n  optimizer:\n    lr: 1.6e-4\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3-2.7B-flash-rotary.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-flash-rotary.yaml\n\nmodel:\n  config:\n    n_embd: 2560\n    n_head: 32\n    n_layer: 32\n    initializer_range: ${eval:\"(2 / (${.n_embd} * 5)) ** 0.5\"}\n    mlp_checkpoint_lvl: 0\n\ndatamodule:\n  batch_size: ${eval:\"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))\"}\n\ntrain:\n  optimizer:\n    lr: 1.6e-4\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3-2.7B-flash.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-flash.yaml\n\nmodel:\n  config:\n    n_embd: 2560\n    n_head: 32\n    n_layer: 32\n    initializer_range: ${eval:\"(2 / (${.n_embd} * 5)) ** 0.5\"}\n    mlp_checkpoint_lvl: 0\n\ndatamodule:\n  batch_size: ${eval:\"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)\"}\n\ntrain:\n  optimizer:\n    lr: 1.6e-4\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3-2.7B-hf-hdim128.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-hf.yaml\n\nmodel:\n  config:\n    n_embd: 2560\n    n_head: 128\n    n_layer: 32\n\n# OOM on A100 80GB even with batch_size = 1\ndatamodule:\n  batch_size: 1\n\ntrain:\n  optimizer:\n    lr: 1.6e-4\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3-2.7B-hf.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-hf.yaml\n\nmodel:\n  config:\n    n_embd: 2560\n    n_head: 32\n    n_layer: 32\n\ndatamodule:\n  batch_size: 1\n\ntrain:\n  optimizer:\n    lr: 1.6e-4\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3l-flash-8k.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3l-flash.yaml\n\ndatamodule:\n  max_length: 8192\n  batch_size: ${eval:\"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)\"}\n\ntrain:\n  global_batch_size: 64\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3l-flash-rotary-30B.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3l-flash-rotary.yaml\n\ntrainer:\n  max_steps: 60000\n\ntrain:\n  scheduler:\n    t_initial: ${trainer.max_steps}\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3l-flash-rotary-8k.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3l-flash-8k.yaml\n\nmodel:\n  config:\n    max_position_embeddings: 0  # Disable absolute position embedding\n    rotary_emb_fraction: 0.5\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3l-flash-rotary.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3l-flash.yaml\n\nmodel:\n  config:\n    max_position_embeddings: 0  # Disable absolute position embedding\n    rotary_emb_fraction: 0.5\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3l-flash.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3s-flash.yaml\n  - override /optimizer: adamw-zero\n\nmodel:\n  config:\n    n_embd: 1536\n    n_head: 16\n    n_layer: 24\n    # mlp_checkpoint_lvl: 1  # To fit batch_size 8\n\ndatamodule:\n  batch_size: ${eval:\"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))\"}\n\ntrain:\n  optimizer:\n    lr: 2.5e-4\n\ntrainer:\n  strategy:\n    _target_: src.utils.ddp_zero1.DDPStrategyZero1\n    find_unused_parameters: False\n    gradient_as_bucket_view: True\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3l-hf.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3s-hf.yaml\n\nmodel:\n  config:\n    n_embd: 1536\n    n_head: 16\n    n_layer: 24\n\ndatamodule:\n  batch_size: 2\n\ntrain:\n  optimizer:\n    lr: 2.5e-4\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3m-flash-8k.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3m-flash.yaml\n\ndatamodule:\n  max_length: 8192\n  batch_size: ${eval:\"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)\"}\n\ntrain:\n  global_batch_size: 64\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3m-flash-rotary-30B.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3m-flash-rotary.yaml\n\ntrainer:\n  max_steps: 60000\n\ntrain:\n  scheduler:\n    t_initial: ${trainer.max_steps}\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3m-flash-rotary-8k.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3m-flash-8k.yaml\n\nmodel:\n  config:\n    max_position_embeddings: 0  # Disable absolute position embedding\n    rotary_emb_fraction: 0.5\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3m-flash-rotary.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3m-flash.yaml\n\nmodel:\n  config:\n    max_position_embeddings: 0  # Disable absolute position embedding\n    rotary_emb_fraction: 0.5\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3m-flash.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3s-flash.yaml\n  - override /model/gpt2model: gpt2-medium\n\n# Can enable mlp_checkpoint_lvl to fit batch_size 16 to A100 40GB\n# model:\n#   config:\n#     mlp_checkpoint_lvl: 1\n\ndatamodule:\n  batch_size: ${eval:\"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))\"}\n\ntrain:\n  optimizer:\n    lr: 3.0e-4\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3m-hf.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3s-hf.yaml\n  - override /model/gpt2model: gpt2-medium\n\ndatamodule:\n  batch_size: 4\n\ntrain:\n  optimizer:\n    lr: 3.0e-4\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3s-flash-8k.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3s-flash.yaml\n\ndatamodule:\n  max_length: 8192\n  batch_size: ${eval:\"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)\"}\n\ntrain:\n  global_batch_size: 64\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3s-flash-rotary-30B.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3s-flash-rotary.yaml\n\ntrainer:\n  max_steps: 60000\n\ntrain:\n  scheduler:\n    t_initial: ${trainer.max_steps}\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3s-flash-rotary-8k.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3s-flash-8k.yaml\n\nmodel:\n  config:\n    max_position_embeddings: 0  # Disable absolute position embedding\n    rotary_emb_fraction: 0.5\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3s-flash-rotary.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3s-flash.yaml\n\nmodel:\n  config:\n    max_position_embeddings: 0  # Disable absolute position embedding\n    rotary_emb_fraction: 0.5\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3s-flash.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/base.yaml\n  - override /model: gpt2\n  - override /model/gpt2model: gpt2-small\n\nmodel:\n  config:\n    # n_positions is already set to ${datamodule.max_length}\n    residual_in_fp32: True\n    use_flash_attn: True\n    fused_dropout_add_ln: True\n    fused_mlp: True\n    fused_bias_fc: True\n    pad_vocab_size_multiple: 8\n\ndatamodule:\n  batch_size: ${eval:\"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)\"}\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3s-hf.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/base.yaml\n  - override /model: gpt2-hf\n  - override /model/gpt2model: gpt2-small\n\ndatamodule:\n  batch_size: 8\n\ntrain:\n  # Use the standard torch.nn.CrossEntropyLoss\n  loss_fn: null\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3xl-flash-8k.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-flash.yaml\n\ndatamodule:\n  max_length: 8192\n  batch_size: ${eval:\"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)\"}\n\ntrain:\n  global_batch_size: 128\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3xl-flash-rotary-60B.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-flash-rotary.yaml\n\ntrainer:\n  max_steps: 60000\n\ntrain:\n  scheduler:\n    t_initial: ${trainer.max_steps}\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3xl-flash-rotary-8k.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-flash-8k.yaml\n\nmodel:\n  config:\n    max_position_embeddings: 0  # Disable absolute position embedding\n    rotary_emb_fraction: 0.5\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3xl-flash-rotary.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3xl-flash.yaml\n\nmodel:\n  config:\n    max_position_embeddings: 0  # Disable absolute position embedding\n    rotary_emb_fraction: 0.5\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3xl-flash.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3s-flash.yaml\n  - override /optimizer: adamw-zero\n\nmodel:\n  config:\n    n_embd: 2048\n    n_head: 16\n    n_layer: 24\n\ndatamodule:\n  batch_size: ${eval:\"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))\"}\n\ntrain:\n  global_batch_size: 512\n  optimizer:\n    lr: 2.0e-4\n  scheduler:\n    t_initial: 300000\n\ntrainer:\n  strategy:\n    _target_: src.utils.ddp_zero1.DDPStrategyZero1\n    find_unused_parameters: False\n    gradient_as_bucket_view: True\n  max_steps: 400000\n  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}\n\ncallbacks:\n  model_checkpoint:\n    every_n_train_steps: 1000\n  model_checkpoint_progress:\n    every_n_train_steps: 12500\n    fault_tolerant: False  # Saving takes too long\n"
  },
  {
    "path": "training/configs/experiment/pile/gpt3xl-hf.yaml",
    "content": "# @package _global_\ndefaults:\n  - /experiment/pile/gpt3s-hf.yaml\n  - override /optimizer: adamw-zero\n\nmodel:\n  config:\n    n_embd: 2048\n    n_head: 16\n    n_layer: 24\n\ndatamodule:\n  batch_size: 2\n\ntrain:\n  global_batch_size: 512\n  optimizer:\n    lr: 2.0e-4\n  scheduler:\n    t_initial: 300000\n\ntrainer:\n  strategy:\n    _target_: src.utils.ddp_zero1.DDPStrategyZero1\n    find_unused_parameters: False\n    gradient_as_bucket_view: True\n  max_steps: 400000\n  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}\n\ncallbacks:\n  model_checkpoint:\n    every_n_train_steps: 1000\n  model_checkpoint_progress:\n    every_n_train_steps: 12500\n    fault_tolerant: False  # Saving takes too long\n"
  },
  {
    "path": "training/configs/logger/comet.yaml",
    "content": "# https://www.comet.ml\n\ncomet:\n  _target_: pytorch_lightning.loggers.comet.CometLogger\n  api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable\n  project_name: \"template-tests\"\n  experiment_name: ${name}\n"
  },
  {
    "path": "training/configs/logger/csv.yaml",
    "content": "# csv logger built in lightning\n\ncsv:\n  _target_: pytorch_lightning.loggers.csv_logs.CSVLogger\n  save_dir: \".\"\n  name: \"csv/\"\n  version: ${name}\n  prefix: \"\"\n"
  },
  {
    "path": "training/configs/logger/many_loggers.yaml",
    "content": "# train with many loggers at once\n\ndefaults:\n  # - comet.yaml\n  - csv.yaml\n  # - mlflow.yaml\n  # - neptune.yaml\n  # - tensorboard.yaml\n  - wandb.yaml\n"
  },
  {
    "path": "training/configs/logger/mlflow.yaml",
    "content": "# https://mlflow.org\n\nmlflow:\n  _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger\n  experiment_name: ${name}\n  tracking_uri: null\n  tags: null\n  save_dir: ./mlruns\n  prefix: \"\"\n  artifact_location: null\n"
  },
  {
    "path": "training/configs/logger/neptune.yaml",
    "content": "# https://neptune.ai\n\nneptune:\n  _target_: pytorch_lightning.loggers.neptune.NeptuneLogger\n  api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable\n  project_name: your_name/template-tests\n  close_after_fit: True\n  offline_mode: False\n  experiment_name: ${name}\n  experiment_id: null\n  prefix: \"\"\n"
  },
  {
    "path": "training/configs/logger/tensorboard.yaml",
    "content": "# https://www.tensorflow.org/tensorboard/\n\ntensorboard:\n  _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger\n  save_dir: \"tensorboard/\"\n  name: \"default\"\n  version: ${name}\n  log_graph: False\n  default_hp_metric: True\n  prefix: \"\"\n"
  },
  {
    "path": "training/configs/logger/wandb.yaml",
    "content": "# https://wandb.ai\n\nwandb:\n  _target_: pytorch_lightning.loggers.wandb.WandbLogger\n  project: attention\n  name: ${name}\n  save_dir: \".\"\n  mode: online # set offline to store all logs only locally\n  id: ${oc.select:name} # pass correct id to resume experiment!\n  # entity: \"\"  # set to name of your wandb team or just remove it\n  log_model: False\n  prefix: \"\"\n  job_type: \"train\"\n  group: \"\"\n  tags: []\n"
  },
  {
    "path": "training/configs/metrics/acc.yaml",
    "content": "# @package eval.metrics\nacc:\n  _target_: src.metrics.accuracy.AccuracyMine\n"
  },
  {
    "path": "training/configs/metrics/acc_ignore_index.yaml",
    "content": "# @package eval.metrics\nacc:\n  _target_: torchmetrics.Accuracy\n  ignore_index: -100\n"
  },
  {
    "path": "training/configs/metrics/acctop5.yaml",
    "content": "# @package eval.metrics\nacctop5:\n  _target_: src.metrics.accuracy.AccuracyMine\n  top_k: 5\n"
  },
  {
    "path": "training/configs/metrics/mse.yaml",
    "content": "# @package eval.metrics\nmse:\n  _target_: torchmetrics.MeanSquaredError\n"
  },
  {
    "path": "training/configs/metrics/num-tokens.yaml",
    "content": "# @package eval.metrics\nnum-tokens:\n  _target_: src.metrics.num_tokens.NumTokens\n"
  },
  {
    "path": "training/configs/metrics/perplexity.yaml",
    "content": "# @package eval.metrics\nppl:\n  _target_: src.metrics.perplexity.Perplexity\n"
  },
  {
    "path": "training/configs/mode/debug.yaml",
    "content": "# @package _global_\n\n# run in debug mode with:\n# `python run.py mode=debug`\n\ndefaults:\n  - override /trainer: debug.yaml\n\ndebug_mode: True\n\nhydra:\n  # sets level of all command line loggers to 'DEBUG'\n  verbose: True\n\n  # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/\n  # sets level of only chosen command line loggers to 'DEBUG'\n  # verbose: [src.train, src.utils.utils]\n\n  # sets output paths for all file logs to 'logs/debug/'\n  run:\n    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S}\n  sweep:\n    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S}\n    subdir: ${hydra.job.num}\n\n# disable rich config printing, since it will be already printed by hydra when `verbose: True`\nprint_config: False\n"
  },
  {
    "path": "training/configs/mode/default.yaml",
    "content": "# @package _global_\n\n# default running mode\n\ndefault_mode: True\n\nhydra:\n  # default output paths for all file logs\n  run:\n    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}\n  sweep:\n    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/multiruns/${now:%Y-%m-%d_%H-%M-%S}\n    subdir: ${hydra.job.num}\n"
  },
  {
    "path": "training/configs/mode/exp.yaml",
    "content": "# @package _global_\n\n# run in experiment mode with:\n# `python run.py mode=exp name=experiment_name`\n\nexperiment_mode: True\n\n# allows for custom naming of the experiment\nname: ???\n\nhydra:\n  # sets output paths for all file logs to `logs/experiment/name'\n  run:\n    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/experiments/${name}\n  sweep:\n    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/experiments/${name}\n    subdir: ${hydra.job.num}\n"
  },
  {
    "path": "training/configs/mode/profile.yaml",
    "content": "# @package _global_\n# Run the Pytorch profiler\n\ntrainer:\n  profiler:\n    _target_: pytorch_lightning.profilers.PyTorchProfiler\n    dirpath: ${hydra.run.dir}\n    schedule:\n      _target_: torch.profiler.schedule\n      wait: 5\n      warmup: 5\n      active: 5\n    use_cuda: True\n  max_steps: 20\n\nlogger:\n  wandb:\n    mode: disabled\n\ncallbacks:\n  model_checkpoint: null\n  model_checkpoint_progress: null\n  early_stopping: null\n\nhydra:\n  # sets output paths for all file logs to 'logs/profile/'\n  run:\n    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/profile/${now:%Y-%m-%d}/${now:%H-%M-%S}\n  sweep:\n    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/profile/multirun_${now:%Y-%m-%d_%H-%M-%S}\n    subdir: ${hydra.job.num}\n"
  },
  {
    "path": "training/configs/mode/smoke.yaml",
    "content": "# @package _global_\n# Smoke test: disable logging and model checkpointing\n\nlogger:\n  wandb:\n    mode: disabled\n\ncallbacks:\n  model_checkpoint: null\n  model_checkpoint_progress: null\n\nhydra:\n  # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/\n  # sets level of only chosen command line loggers to 'DEBUG'\n  # verbose: [src.train, src.utils.utils]\n\n  # sets output paths for all file logs to 'logs/debug/'\n  run:\n    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S}\n  sweep:\n    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S}\n    subdir: ${hydra.job.num}\n"
  },
  {
    "path": "training/configs/model/gpt2-hf.yaml",
    "content": "defaults:\n  - _self_\n  - gpt2model: gpt2-small\n\n_target_: transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel\n_recursive_: True\nconfig:\n  _target_: transformers.GPT2Config\n  # Mistral's config: https://github.com/stanford-crfm/mistral/blob/main/conf/models/gpt2-small.yaml\n  # However, reorder_and_upcast_attn slows things down\n  reorder_and_upcast_attn: false\n  scale_attn_by_inverse_layer_idx: true\n  n_positions: ${datamodule.max_length}\n"
  },
  {
    "path": "training/configs/model/gpt2.yaml",
    "content": "defaults:\n  - _self_\n  - gpt2model: gpt2-small\n\n_target_: flash_attn.models.gpt.GPTLMHeadModel\n_recursive_: True\nconfig:\n  _target_: transformers.GPT2Config\n  # Mistral's config: # https://github.com/stanford-crfm/mistral/blob/main/conf/models/mistral-small.yaml\n  # However, reorder_and_upcast_attn slows things down\n  reorder_and_upcast_attn: false\n  scale_attn_by_inverse_layer_idx: true\n  n_positions: ${datamodule.max_length}\n"
  },
  {
    "path": "training/configs/model/gpt2model/gpt2-large.yaml",
    "content": "# @package _global_\nmodel:\n  config:\n    n_embd: 1280\n    n_head: 20\n    n_layer: 36\n"
  },
  {
    "path": "training/configs/model/gpt2model/gpt2-medium.yaml",
    "content": "# @package _global_\nmodel:\n  config:\n    n_embd: 1024\n    n_head: 16\n    n_layer: 24\n"
  },
  {
    "path": "training/configs/model/gpt2model/gpt2-small.yaml",
    "content": "# @package _global_\nmodel:\n  config:\n    n_embd: 768\n    n_head: 12\n    n_layer: 12\n"
  },
  {
    "path": "training/configs/model/gpt2model/gpt2-xlarge.yaml",
    "content": "# @package _global_\nmodel:\n  config:\n    n_embd: 1600\n    n_head: 25\n    n_layer: 48\n"
  },
  {
    "path": "training/configs/optimizer/adam.yaml",
    "content": "# @package train.optimizer\n_target_: torch.optim.Adam\n"
  },
  {
    "path": "training/configs/optimizer/adamw-apex-distributed.yaml",
    "content": "# @package train.optimizer\n_target_: apex.contrib.optimizers.distributed_fused_adam.DistributedFusedAdam\nadam_w_mode: True\n"
  },
  {
    "path": "training/configs/optimizer/adamw-apex-zero.yaml",
    "content": "# @package train.optimizer\n_target_: torch.distributed.optim.ZeroRedundancyOptimizer\n_recursive_: True\noptimizer_class:\n  _target_: apex.optimizers.FusedAdam\n  _partial_: True\n  adam_w_mode: True\n"
  },
  {
    "path": "training/configs/optimizer/adamw-apex.yaml",
    "content": "# @package train.optimizer\n_target_: apex.optimizers.FusedAdam\nadam_w_mode: True\n"
  },
  {
    "path": "training/configs/optimizer/adamw-zero.yaml",
    "content": "# @package train.optimizer\n_target_: torch.distributed.optim.ZeroRedundancyOptimizer\n_recursive_: True\noptimizer_class:\n  _target_: torch.optim.__getattribute__\n  _args_:\n    - \"AdamW\"\n"
  },
  {
    "path": "training/configs/optimizer/adamw.yaml",
    "content": "# @package train.optimizer\n_target_: torch.optim.AdamW\n"
  },
  {
    "path": "training/configs/optimizer/fusedlamb-ds.yaml",
    "content": "# @package train.optimizer\n_target_: deepspeed.ops.lamb.FusedLamb\n"
  },
  {
    "path": "training/configs/optimizer/fusedlamb.yaml",
    "content": "# @package train.optimizer\n_target_: apex.optimizers.FusedLAMB\n"
  },
  {
    "path": "training/configs/optimizer/sgd.yaml",
    "content": "# @package train.optimizer\n_target_: torch.optim.SGD\n"
  },
  {
    "path": "training/configs/scheduler/cosine-warmup-timm.yaml",
    "content": "# @package train.scheduler\n_target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler\n"
  },
  {
    "path": "training/configs/scheduler/cosine-warmup.yaml",
    "content": "# @package train.scheduler\n_target_: transformers.get_cosine_schedule_with_warmup\n"
  },
  {
    "path": "training/configs/scheduler/invsqrt.yaml",
    "content": "# @package train.scheduler\n_target_: src.optim.lr_scheduler.InvSqrt\nnum_warmup_steps: ???\n"
  },
  {
    "path": "training/configs/scheduler/linear-warmup.yaml",
    "content": "# @package train.scheduler\n_target_: transformers.get_linear_schedule_with_warmup\n"
  },
  {
    "path": "training/configs/scheduler/multi-step.yaml",
    "content": "# @package train.scheduler\n_target_: torch.optim.lr_scheduler.MultiStepLR\n"
  },
  {
    "path": "training/configs/scheduler/plateau.yaml",
    "content": "# @package _global_\ntrain:\n  scheduler_interval: epoch\n  scheduler_monitor: ???\n  scheduler:\n    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau\n    factor: 0.2  # Decay factor when ReduceLROnPlateau is used\n    patience: 20\n    min_lr: 0.0  # Minimum learning rate during annealing\n"
  },
  {
    "path": "training/configs/scheduler/poly-warmup.yaml",
    "content": "# @package train.scheduler\n_target_: transformers.get_polynomial_decay_schedule_with_warmup\n"
  },
  {
    "path": "training/configs/scheduler/step.yaml",
    "content": "# @package train.scheduler\n_target_: torch.optim.lr_scheduler.StepLR\nstep_size: ???\n"
  },
  {
    "path": "training/configs/task/sequence-model.yaml",
    "content": "_target_: src.tasks.seq.SequenceModel\n"
  },
  {
    "path": "training/configs/trainer/all_params.yaml",
    "content": "_target_: pytorch_lightning.Trainer\n\n# default values for all trainer parameters\ncheckpoint_callback: True\ndefault_root_dir: null\ngradient_clip_val: 0.0\nprocess_position: 0\nnum_nodes: 1\nnum_processes: 1\ngpus: null\nauto_select_gpus: False\ntpu_cores: null\nlog_gpu_memory: null\noverfit_batches: 0.0\ntrack_grad_norm: -1\ncheck_val_every_n_epoch: 1\nfast_dev_run: False\naccumulate_grad_batches: 1\nmax_epochs: 1\nmin_epochs: 1\nmax_steps: null\nmin_steps: null\nlimit_train_batches: 1.0\nlimit_val_batches: 1.0\nlimit_test_batches: 1.0\nval_check_interval: 1.0\nflush_logs_every_n_steps: 100\nlog_every_n_steps: 50\naccelerator: null\nsync_batchnorm: False\nprecision: 32\nweights_summary: \"top\"\nweights_save_path: null\nnum_sanity_val_steps: 2\ntruncated_bptt_steps: null\nresume_from_checkpoint: null\nprofiler: null\nbenchmark: False\ndeterministic: False\nreload_dataloaders_every_epoch: False\nauto_lr_find: False\nreplace_sampler_ddp: True\nterminate_on_nan: False\nauto_scale_batch_size: False\nprepare_data_per_node: True\nplugins: null\namp_backend: \"native\"\namp_level: \"O2\"\nmove_metrics_to_cpu: False\n"
  },
  {
    "path": "training/configs/trainer/ddp.yaml",
    "content": "defaults:\n  - default.yaml\n\naccelerator: gpu\ndevices: 4\nstrategy: ddp\n"
  },
  {
    "path": "training/configs/trainer/debug.yaml",
    "content": "defaults:\n  - default.yaml\n\ngpus: 0\n\nmin_epochs: 1\nmax_epochs: 2\n\n# prints\nweights_summary: \"full\"\nprofiler: null\n\n# debugs\nfast_dev_run: true\nnum_sanity_val_steps: 2\noverfit_batches: 0\nlimit_train_batches: 1.0\nlimit_val_batches: 1.0\nlimit_test_batches: 1.0\ntrack_grad_norm: -1\nterminate_on_nan: true\n"
  },
  {
    "path": "training/configs/trainer/default.yaml",
    "content": "_target_: pytorch_lightning.Trainer\n\n# set `gpu` to train on GPU, null to train on CPU only\naccelerator: null\n\nmin_epochs: 1\nmax_epochs: 1000\n"
  },
  {
    "path": "training/run.py",
    "content": "from typing import Callable\n\nimport dotenv\nimport hydra\nfrom omegaconf import OmegaConf, DictConfig\n\n# load environment variables from `.env` file if it exists\n# recursively searches for `.env` in all folders starting from work dir\ndotenv.load_dotenv(override=True)\n\nOmegaConf.register_new_resolver('eval', eval)\nOmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)\n# Delay the evaluation until we have the datamodule\n# So we want the resolver to yield the same string.\nOmegaConf.register_new_resolver('datamodule', lambda attr: '${datamodule:' + str(attr) + '}')\n\n# Turn on TensorFloat32\nimport torch.backends\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\n\n\ndef dictconfig_filter_key(d: DictConfig, fn: Callable) -> DictConfig:\n    \"\"\"Only keep keys where fn(key) is True. Support nested DictConfig.\n    \"\"\"\n    # Using d.items_ex(resolve=False) instead of d.items() since we want to keep the\n    # ${datamodule:foo} unresolved for now.\n    return DictConfig({k: dictconfig_filter_key(v, fn) if isinstance(v, DictConfig) else v\n                       # for k, v in d.items_ex(resolve=False) if fn(k)})\n                       for k, v in d.items() if fn(k)})\n\n\n@hydra.main(config_path=\"configs/\", config_name=\"config.yaml\")\ndef main(config: DictConfig):\n\n    # Remove config keys that start with '__'. 
These are meant to be used only in computing\n    # other entries in the config.\n    config = dictconfig_filter_key(config, lambda k: not k.startswith('__'))\n\n    # Imports should be nested inside @hydra.main to optimize tab completion\n    # Read more here: https://github.com/facebookresearch/hydra/issues/934\n    from src.train import train\n    from src.eval import evaluate\n    from src.utils import utils\n\n    # A couple of optional utilities:\n    # - disabling python warnings\n    # - forcing debug-friendly configuration\n    # - verifying experiment name is set when running in experiment mode\n    # You can safely get rid of this line if you don't want those\n    utils.extras(config)\n\n    # Pretty print config using Rich library\n    if config.get(\"print_config\"):\n        utils.print_config(config, resolve=True)\n\n    # Train model\n    mode = config.get('mode', 'train')\n    if mode not in ['train', 'eval']:\n        raise NotImplementedError(f'mode {mode} not supported')\n    if mode == 'train':\n        return train(config)\n    elif mode == 'eval':\n        return evaluate(config)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "training/src/callbacks/__init__.py",
    "content": ""
  },
  {
    "path": "training/src/callbacks/causality_monitor.py",
    "content": "\nimport pytorch_lightning as pl\nfrom pytorch_lightning import Callback\nfrom pytorch_lightning.utilities import rank_zero_only\n\nimport torch\nfrom torch.autograd import grad\n\nclass CausalityMonitor(Callback):\n    r\"\"\"Monitor causality of a model by tracking gradient leakage  forward in time.\n    In a fully causal model, dy[k]du[s] ~= 0 for all k < s.\n\n    Args:\n        seq_len (int): Length of the sequence to monitor.\n        input_dim (int): Dimension of the input to monitor. If 0, the callback assumes\n            the task to be language modeling, and skips the embedding layer. If > 0,\n            input_dim is interpreted as the input channel dimension, i.e. D with\n            dummy input of dimension [B, L, D].\n    \n    Notes:\n        This callback assumes that `pl_module.model` has a `net` or `s4seq` attribute,\n        indicating the primary model to monitor. For LMs, `net` or `s4seq` should \n        be after the embedding layer.\n    \"\"\"\n\n    def __init__(self, seq_len: int  = 10, input_dim: int = 0):\n        super().__init__()\n        self.seq_len = seq_len\n        self.input_dim = input_dim\n\n    @rank_zero_only\n    def on_train_epoch_end(self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\") -> None:\n        model = pl_module.model\n\n        with torch.enable_grad():\n            if self.input_dim == 0: \n                # [MP] LongTensors cannot have gradients - we start from post\n                # embedding in the LM case\n                input_dim = model.d_model\n                x = torch.randn((2, self.seq_len, input_dim), \\\n                    requires_grad=True).to(pl_module.device)\n                # [DF] HACK: we need to get the layer that comes after the embedding\n                if hasattr(model, 'net'):\n                    y = model.net(x)\n                else:\n                    y = model.s4seq(x)\n            else:\n                x = torch.randn(1, self.seq_len, 
self.input_dim, \\\n                    requires_grad=True).to(pl_module.device)\n                y =  model(x)\n\n            stats = {}\n            for i in range(self.seq_len):\n                # total gradients flowing from y_i to x \n                g =  grad(y[0,0,i].mean(), x, retain_graph=True, allow_unused=True)[0]\n                g = g[0,i+1:,:].abs().mean()\n                stats[f'stats/causality_{i}'] = g.item()\n\n        if trainer.loggers is not None:\n            for logger in trainer.loggers:\n                logger.log_metrics(stats, step=trainer.global_step)\n"
  },
  {
    "path": "training/src/callbacks/ema.py",
    "content": "# Inspired by https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/stochastic_weight_avg.py\n# https://github.com/PyTorchLightning/Lightning-Bolts/blob/master/pl_bolts/callbacks/byol_updates.py\n# https://forums.pytorchlightning.ai/t/adopting-exponential-moving-average-ema-for-pl-pipeline/488/2\n# https://github.com/PyTorchLightning/pytorch-lightning/issues/8100\n\nfrom typing import Dict, Any\n\nfrom pytorch_lightning import Callback, Trainer\nfrom pytorch_lightning.utilities import rank_zero_only\nfrom pytorch_lightning.utilities.parsing import AttributeDict\nfrom pytorch_lightning.utilities.types import STEP_OUTPUT\n\nfrom src.utils.ema import ExponentialMovingAverage\n\n\nclass EMACallback(Callback):\n    \"\"\"TD [2021-08-31]: saving and loading from checkpoint should work.\n    \"\"\"\n    def __init__(self, decay: float, use_num_updates: bool = True):\n        \"\"\"\n        decay: The exponential decay.\n        use_num_updates: Whether to use number of updates when computing\n            averages.\n        \"\"\"\n        super().__init__()\n        self.decay = decay\n        self.use_num_updates = use_num_updates\n        self.ema = None\n\n    def on_train_start(self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\"):\n        # It's possible that we already loaded EMA from the checkpoint\n        if self.ema is None:\n          self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad],\n                                              decay=self.decay, use_num_updates=self.use_num_updates)\n\n    # Ideally we want on_after_optimizer_step but pytorch-lightning doesn't have it\n    # We only want to update when parameters are changing.\n    # Because of gradient accumulation, this doesn't happen every training step.\n    # https://github.com/PyTorchLightning/pytorch-lightning/issues/11688\n    def on_train_batch_end(\n        self,\n        trainer: 
\"pl.Trainer\",\n        pl_module: \"pl.LightningModule\",\n        outputs: STEP_OUTPUT,\n        batch: Any,\n        batch_idx: int,\n    ) -> None:\n        if (batch_idx + 1) % trainer.accumulate_grad_batches == 0:\n          self.ema.update()\n\n    def on_validation_start(self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\") -> None:\n        # During the initial validation we don't have self.ema yet\n        if self.ema is not None:\n            self.ema.store()\n            self.ema.copy_to()\n\n    def on_validation_end(self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\") -> None:\n        if self.ema is not None:\n            self.ema.restore()\n\n    def on_test_start(self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\") -> None:\n        if self.ema is not None:\n            self.ema.store()\n            self.ema.copy_to()\n\n    def on_test_end(self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\") -> None:\n        if self.ema is not None:\n            self.ema.restore()\n\n    def on_save_checkpoint(\n        self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\", checkpoint: Dict[str, Any]\n    ) -> Dict[str, Any]:\n        return self.ema.state_dict()\n\n    def on_load_checkpoint(\n        self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\",\n        checkpoint: Dict[str, Any]\n    ) -> None:\n        if self.ema is None:\n            self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad],\n                                                decay=self.decay, use_num_updates=self.use_num_updates)\n        self.ema.load_state_dict(checkpoint)\n"
  },
  {
    "path": "training/src/callbacks/flop_count.py",
    "content": "# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py\nfrom typing import Any, List, Sequence\n\nimport torch\n\nfrom pytorch_lightning import Callback, Trainer, LightningModule\nfrom pytorch_lightning.utilities import rank_zero_only\nfrom pytorch_lightning.utilities.parsing import AttributeDict\n\nfrom src.utils.flops import has_deepspeed_profiling, has_fvcore_profiling\nfrom src.utils.flops import profile_deepspeed, profile_fvcore\n\n\nclass FlopCount(Callback):\n    \"\"\"Counter the number of FLOPs used by the model\n    \"\"\"\n    def __init__(self, profilers: List[str] = ['fvcore', 'deepspeed'],\n                 input_size: tuple = (3, 224, 224), input_dtype=torch.float32, device=None):\n        if not isinstance(profilers, Sequence):\n            profilers = [profilers]\n        if any(p not in ['fvcore', 'deepspeed'] for p in profilers):\n            raise NotImplementedError('Only support fvcore and deepspeed profilers')\n        if 'fvcore' in profilers and not has_fvcore_profiling:\n            raise ImportError('fvcore is not installed. 
Install it by running `pip install fvcore`')\n        elif 'deepspeed' in profilers and not has_deepspeed_profiling:\n            raise ImportError('deepspeed is not installed')\n        super().__init__()\n        self.profilers = profilers\n        self.input_size = tuple(input_size)\n        self.input_dtype = input_dtype\n        self.device = device\n\n    @rank_zero_only\n    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:\n        if 'fvcore' in self.profilers:\n            _, macs, _, acts = profile_fvcore(pl_module.to(self.device), input_size=self.input_size,\n                                              input_dtype=self.input_dtype, detailed=True)\n            trainer.logger.log_hyperparams({'GMACs': macs * 1e-9, 'MActs': acts * 1e-6})\n        if 'deepspeed' in self.profilers:\n            macs, _ = profile_deepspeed(pl_module.to(self.device), input_size=self.input_size,\n                                        input_dtype=self.input_dtype, detailed=True)\n            if 'fvcore' not in self.profilers:  # fvcore's MACs seem more accurate\n                trainer.logger.log_hyperparams({'GMACs': macs * 1e-9})\n"
  },
  {
    "path": "training/src/callbacks/gpu_affinity.py",
    "content": "import torch\n\nfrom pytorch_lightning import Callback, Trainer, LightningModule\n\nimport logging\n\nlog = logging.getLogger(__name__)  # We want a logger for each process, not just the rank 0\n\n\ndef l2_promote():\n    import ctypes\n    _libcudart = ctypes.CDLL('libcudart.so')\n    # Set device limit on the current device\n    # cudaLimitMaxL2FetchGranularity = 0x05\n    pValue = ctypes.cast((ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int))\n    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))\n    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))\n    assert pValue.contents.value == 128\n\n\ndef set_affinity(trainer):\n    try:\n        from src.utils.gpu_affinity import set_affinity\n        nproc_per_node = torch.cuda.device_count()\n        affinity = set_affinity(trainer.local_rank, nproc_per_node, 'socket_unique_continuous')\n        log.info(f'{trainer.local_rank}: thread affinity: {affinity}')\n        # TD [2022-05-07] Somehow calling this causes GPU 0 to allocate extra ~800MB of memory per\n        # number of GPUs (e.g., 6.4GB of extra memory in a 8-GPU setup). H/t Dan.\n        # l2_promote()\n    except:\n        pass\n\n\nclass GpuAffinity(Callback):\n    \"\"\"Set GPU affinity and increase the L2 fetch granularity.\n    Adapted from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/Transformer-XL\n    \"\"\"\n\n    def setup(self, trainer: Trainer, pl_module: LightningModule, stage=None) -> None:\n        set_affinity(trainer)\n"
  },
  {
    "path": "training/src/callbacks/loss_scale_monitor.py",
    "content": "# Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/lr_monitor.py.\nfrom typing import Any\n\nfrom pytorch_lightning import Callback, Trainer\nfrom pytorch_lightning.utilities import rank_zero_only\nfrom pytorch_lightning.strategies import DeepSpeedStrategy\n\n\nclass LossScaleMonitor(Callback):\n    \"\"\"Monitor the loss scale for AMP (fp16).\n    \"\"\"\n\n    # Use on_before_optimizer_step instead of on_train_batch_start since there might be\n    # gradient accumulation and we only care about the loss scale when it could change (i.e.,\n    # optimizer.step).\n    @rank_zero_only\n    def on_before_optimizer_step(self, trainer: Trainer, *args: Any, **kwargs: Any) -> None:\n        if not trainer._logger_connector.should_update_logs:\n            return\n        stats = {}\n        if isinstance(trainer.strategy, DeepSpeedStrategy):\n            stats = {'scalar/scale': trainer.model.optimizer.loss_scale}\n        if hasattr(trainer, 'precision_plugin') and hasattr(trainer.precision_plugin, 'scaler'):\n            scaler = trainer.precision_plugin.scaler\n            if scaler is not None:\n                stats = {\n                    'scaler/scale': scaler.get_scale(),\n                    'scaler/growth_tracker': scaler._get_growth_tracker(),\n                }\n        if stats and trainer.loggers is not None:\n            for logger in trainer.loggers:\n                logger.log_metrics(stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)\n"
  },
  {
    "path": "training/src/callbacks/model_checkpoint.py",
    "content": "# Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/fault_tolerance.py\nfrom typing import Any\nfrom pathlib import Path\n\nimport pytorch_lightning as pl\n\n\nclass ModelCheckpointMine(pl.callbacks.model_checkpoint.ModelCheckpoint):\n\n    def __init__(self, *args, fault_tolerant=False, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.fault_tolerant = fault_tolerant\n\n    def on_exception(self, trainer: \"pl.Trainer\", *_: Any, **__: Any) -> None:\n        if self.fault_tolerant:\n            # overwrite if necessary\n            trainer.save_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt'))\n\n    # def teardown(self, trainer: \"pl.Trainer\", *_: Any, **__: Any) -> None:\n    #     if self.fault_tolerant:\n    #         trainer.strategy.remove_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt'))\n\n\n# TD [2022-07-17] I was trying to make resuming from standard checkpoint fault-tolerant.\n# However, when it resumes it's off by 1 iteration. My attempt to fix it in seq.py (below) didn't work.\n# So I decided to just copy _FaultToleranceCheckpoint and just save on_exception.\n\n    # def on_save_checkpoint(self, checkpoint):\n    #     # TD [2022-07-12] The \"completed\" counter is off by 1 so when it resumes\n    #     # it's off by 1 iteration. However, the data is still off by 1 iteration, probably\n    #     # because the dataloader_state_dict['counter'] is off by @batch_size, and idk how\n    #     # to fix it cleanly.\n    #     checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] += 1\n    #     checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] += 1\n    #     checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] += 1\n    #     checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['dataloader_state_dict'][0]['state'][0]['num_batches_fetched'] += 1\n"
  },
  {
    "path": "training/src/callbacks/norm_monitor.py",
    "content": "# Inspired by https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/utilities/grads.py\n# However, they compute grad at every iteration (I think), and the .item() calls incur a lot of overhead\n# (6-7% slow down on GPT-2 small). Instead we only compute for iterations where we need to log, and don't\n# call .item() explicitly.\n\nfrom typing import Any\nfrom collections import OrderedDict\n\nfrom pytorch_lightning import Callback, Trainer\nfrom pytorch_lightning.utilities import rank_zero_only\nfrom pytorch_lightning.strategies import DeepSpeedStrategy\n\nimport torch\nimport torch.nn as nn\n\ntry:\n    from apex.contrib.layer_norm import FastLayerNorm\nexcept ImportError:\n    FastLayerNorm = None\n\n\nclass NormMonitor(Callback):\n    \"\"\"Monitor the scales of weights and gradients.\n    \"\"\"\n\n    def __init__(self, layer_norm_only: bool = False):\n        super().__init__()\n        self.layer_norm_only = layer_norm_only\n\n    # Use on_before_optimizer_step instead of on_train_batch_start since there might be\n    # gradient accumulation and we only care about  scale when it could change (i.e., optimizer.step).\n    @rank_zero_only\n    def on_before_optimizer_step(self, trainer: Trainer, pl_module, *args: Any, **kwargs: Any) -> None:\n        if not trainer._logger_connector.should_update_logs:\n            return\n        model = pl_module.model\n        named_parameters = {}\n        if self.layer_norm_only:\n            ln_modules = (nn.LayerNorm, nn.Embedding)\n            if FastLayerNorm is not None:\n                ln_modules += (FastLayerNorm,)\n            for mn, m in model.named_modules():\n                if isinstance(m, ln_modules):\n                    for pn, p in m.named_parameters():\n                        fpn = '%s.%s' % (mn, pn) if mn else pn # full param name\n                        named_parameters[fpn] = p\n        else:\n            named_parameters = dict(model.named_parameters())\n\n      
  if isinstance(trainer.strategy, DeepSpeedStrategy):\n            loss_scale = trainer.model.optimizer.loss_scale\n        else:\n            loss_scale = 1.0\n\n        stats = {}\n        param_l1_norm, grad_l1_norm = [], []\n        for param_name, param in named_parameters.items():\n            param_abs = param.abs()\n            param_abs_mean = param_abs.mean(dtype=torch.float32)\n            stats[f'stats/{param_name}_max'] = param_abs.max()\n            stats[f'stats/{param_name}_mean'] = param_abs_mean\n            param_l1_norm.append(param_abs_mean * param.numel())\n            if param.grad is not None:\n                # If using AMP, gradient is already unscaled by the AMP loss scaler at this point\n                # https://github.com/Lightning-AI/lightning/pull/9606\n                # However, if using DeepSpeed, we need to scale it ourselves\n                param_grad_abs = param.grad.abs()\n                param_grad_abs_mean = param_grad_abs.mean(dtype=torch.float32) / loss_scale\n                stats[f'stats/{param_name}_grad_max'] = param_grad_abs.max() / loss_scale\n                stats[f'stats/{param_name}_grad_mean'] = param_grad_abs_mean\n                grad_l1_norm.append(param_grad_abs_mean * param.grad.numel())\n        stats['total_param_l1_norm'] = torch.stack(param_l1_norm).sum()\n        if grad_l1_norm:\n            stats['total_grad_l1_norm'] = torch.stack(grad_l1_norm).sum()\n        # Sort by params name\n        stats = OrderedDict(sorted(stats.items()))\n        if trainer.loggers is not None:\n            for logger in trainer.loggers:\n                logger.log_metrics(stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped)\n"
  },
  {
    "path": "training/src/callbacks/params_log.py",
    "content": "from typing import Any\n\nfrom pytorch_lightning import Callback, Trainer, LightningModule\nfrom pytorch_lightning.utilities import rank_zero_only\nfrom pytorch_lightning.utilities.parsing import AttributeDict\n\n\nclass ParamsLog(Callback):\n    \"\"\"Log the number of parameters of the model\n    \"\"\"\n    def __init__(self, total_params_log: bool = True, trainable_params_log: bool = True,\n                 non_trainable_params_log: bool = True):\n        super().__init__()\n        self._log_stats = AttributeDict(\n            {\n                'total_params_log': total_params_log,\n                'trainable_params_log': trainable_params_log,\n                'non_trainable_params_log': non_trainable_params_log,\n            }\n        )\n\n    @rank_zero_only\n    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:\n        logs = {}\n        if self._log_stats.total_params_log:\n            logs[\"model/params_total\"] = sum(p.numel() for p in pl_module.parameters())\n        if self._log_stats.trainable_params_log:\n            logs[\"model/params_trainable\"] = sum(p.numel() for p in pl_module.parameters()\n                                             if p.requires_grad)\n        if self._log_stats.non_trainable_params_log:\n            logs[\"model/params_not_trainable\"] = sum(p.numel() for p in pl_module.parameters()\n                                                     if not p.requires_grad)\n        if trainer.logger is not None:\n            trainer.logger.log_hyperparams(logs)\n"
  },
  {
    "path": "training/src/callbacks/speed_monitor.py",
    "content": "# Adapted from https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor\n# We only need the speed monitoring, not the GPU monitoring\nimport time\nfrom typing import Any\n\nfrom pytorch_lightning import Callback, Trainer\nfrom pytorch_lightning.utilities import rank_zero_only\nfrom pytorch_lightning.utilities.parsing import AttributeDict\nfrom pytorch_lightning.utilities.types import STEP_OUTPUT\n\n\nclass SpeedMonitor(Callback):\n    \"\"\"Monitor the speed of each step and each epoch.\n    \"\"\"\n    def __init__(self, intra_step_time: bool = True, inter_step_time: bool = True,\n                 epoch_time: bool = True, verbose=False):\n        super().__init__()\n        self._log_stats = AttributeDict(\n            {\n                'intra_step_time': intra_step_time,\n                'inter_step_time': inter_step_time,\n                'epoch_time': epoch_time,\n            }\n        )\n        self.verbose = verbose\n\n    def on_train_start(self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\") -> None:\n        self._snap_epoch_time = None\n\n    def on_train_epoch_start(self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\") -> None:\n        self._snap_intra_step_time = None\n        self._snap_inter_step_time = None\n        self._snap_epoch_time = time.time()\n\n    def on_validation_epoch_start(self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\") -> None:\n        self._snap_inter_step_time = None\n\n    def on_test_epoch_start(self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\") -> None:\n        self._snap_inter_step_time = None\n\n    @rank_zero_only\n    def on_train_batch_start(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\",\n        batch: Any,\n        batch_idx: int,\n    ) -> None:\n        if self._log_stats.intra_step_time:\n            self._snap_intra_step_time 
= time.time()\n\n        if not trainer._logger_connector.should_update_logs:\n            return\n\n        logs = {}\n        if self._log_stats.inter_step_time and self._snap_inter_step_time:\n            # First log at beginning of second step\n            logs[\"time/inter_step (ms)\"] = (time.time() - self._snap_inter_step_time) * 1000\n\n        if trainer.logger is not None:\n            trainer.logger.log_metrics(logs, step=trainer.global_step)\n\n    @rank_zero_only\n    def on_train_batch_end(\n        self,\n        trainer: \"pl.Trainer\",\n        pl_module: \"pl.LightningModule\",\n        outputs: STEP_OUTPUT,\n        batch: Any,\n        batch_idx: int,\n    ) -> None:\n        if self._log_stats.inter_step_time:\n            self._snap_inter_step_time = time.time()\n\n        if self.verbose and self._log_stats.intra_step_time and self._snap_intra_step_time:\n            pl_module.print(f\"time/intra_step (ms): {(time.time() - self._snap_intra_step_time) * 1000}\")\n\n        if not trainer._logger_connector.should_update_logs:\n            return\n\n        logs = {}\n        if self._log_stats.intra_step_time and self._snap_intra_step_time:\n            logs[\"time/intra_step (ms)\"] = (time.time() - self._snap_intra_step_time) * 1000\n\n        if trainer.logger is not None:\n            trainer.logger.log_metrics(logs, step=trainer.global_step)\n\n    @rank_zero_only\n    def on_train_epoch_end(self, trainer: \"pl.Trainer\", pl_module: \"pl.LightningModule\",) -> None:\n        logs = {}\n        if self._log_stats.epoch_time and self._snap_epoch_time:\n            logs[\"time/epoch (s)\"] = time.time() - self._snap_epoch_time\n        if trainer.logger is not None:\n            trainer.logger.log_metrics(logs, step=trainer.global_step)\n\n"
  },
  {
    "path": "training/src/callbacks/wandb_callbacks.py",
    "content": "import subprocess\nfrom pathlib import Path\nfrom typing import List\n\nimport matplotlib.pyplot as plt\nimport seaborn as sn\nimport torch\nimport wandb\nfrom pytorch_lightning import Callback, Trainer\nfrom pytorch_lightning.loggers import LoggerCollection, WandbLogger\nfrom pytorch_lightning.utilities import rank_zero_only\nfrom sklearn import metrics\nfrom sklearn.metrics import f1_score, precision_score, recall_score\n\n\ndef get_wandb_logger(trainer: Trainer) -> WandbLogger:\n    \"\"\"Safely get Weights&Biases logger from Trainer.\"\"\"\n\n    if trainer.fast_dev_run:\n        raise Exception(\n            \"Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode.\"\n        )\n\n    if isinstance(trainer.logger, WandbLogger):\n        return trainer.logger\n\n    if isinstance(trainer.logger, LoggerCollection):\n        for logger in trainer.logger:\n            if isinstance(logger, WandbLogger):\n                return logger\n\n    raise Exception(\n        \"You are using wandb related callback, but WandbLogger was not found for some reason...\"\n    )\n\n\nclass WatchModel(Callback):\n    \"\"\"Make wandb watch model at the beginning of the run.\"\"\"\n\n    def __init__(self, log: str = \"gradients\", log_freq: int = 100):\n        self.log = log\n        self.log_freq = log_freq\n\n    @rank_zero_only\n    def on_train_start(self, trainer, pl_module):\n        logger = get_wandb_logger(trainer=trainer)\n        logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)\n\n\nclass UploadCodeAsArtifact(Callback):\n    \"\"\"Upload all code files to wandb as an artifact, at the beginning of the run.\"\"\"\n\n    def __init__(self, code_dir: str, use_git: bool = True):\n        \"\"\"\n\n        Args:\n            code_dir: the code directory\n            use_git: if using git, then upload all files that are not ignored by git.\n            if not using git, then upload all '*.py' 
files\n        \"\"\"\n        self.code_dir = code_dir\n        self.use_git = use_git\n\n    @rank_zero_only\n    def on_train_start(self, trainer, pl_module):\n        logger = get_wandb_logger(trainer=trainer)\n        experiment = logger.experiment\n\n        code = wandb.Artifact(\"project-source\", type=\"code\")\n\n        if self.use_git:\n            # get .git folder\n            # https://alexwlchan.net/2020/11/a-python-function-to-ignore-a-path-with-git-info-exclude/\n            git_dir_path = Path(\n                subprocess.check_output([\"git\", \"rev-parse\", \"--git-dir\"]).strip().decode(\"utf8\")\n            ).resolve()\n\n            for path in Path(self.code_dir).resolve().rglob(\"*\"):\n                if (\n                    path.is_file()\n                    # ignore files in .git\n                    and not str(path).startswith(str(git_dir_path))  # noqa: W503\n                    # ignore files ignored by git\n                    and (  # noqa: W503\n                        subprocess.run([\"git\", \"check-ignore\", \"-q\", str(path)]).returncode == 1\n                    )\n                ):\n                    # path is absolute here, so resolve code_dir too before taking the relative path\n                    code.add_file(str(path), name=str(path.relative_to(Path(self.code_dir).resolve())))\n\n        else:\n            for path in Path(self.code_dir).resolve().rglob(\"*.py\"):\n                # path is absolute here, so resolve code_dir too before taking the relative path\n                code.add_file(str(path), name=str(path.relative_to(Path(self.code_dir).resolve())))\n\n        experiment.log_artifact(code)\n\n\nclass UploadCheckpointsAsArtifact(Callback):\n    \"\"\"Upload checkpoints to wandb as an artifact, at the end of the run.\"\"\"\n\n    def __init__(self, ckpt_dir: str = \"checkpoints/\", upload_best_only: bool = False):\n        self.ckpt_dir = ckpt_dir\n        self.upload_best_only = upload_best_only\n\n    @rank_zero_only\n    def on_keyboard_interrupt(self, trainer, pl_module):\n        self.on_train_end(trainer, pl_module)\n\n    @rank_zero_only\n    def on_train_end(self, trainer, pl_module):\n        logger = 
get_wandb_logger(trainer=trainer)\n        experiment = logger.experiment\n\n        ckpts = wandb.Artifact(\"experiment-ckpts\", type=\"checkpoints\")\n\n        if self.upload_best_only:\n            ckpts.add_file(trainer.checkpoint_callback.best_model_path)\n        else:\n            for path in Path(self.ckpt_dir).rglob(\"*.ckpt\"):\n                ckpts.add_file(str(path))\n\n        experiment.log_artifact(ckpts)\n\n\nclass LogConfusionMatrix(Callback):\n    \"\"\"Generate confusion matrix every epoch and send it to wandb.\n    Expects validation step to return predictions and targets.\n    \"\"\"\n\n    def __init__(self):\n        self.preds = []\n        self.targets = []\n        self.ready = True\n\n    def on_sanity_check_start(self, trainer, pl_module) -> None:\n        self.ready = False\n\n    def on_sanity_check_end(self, trainer, pl_module):\n        \"\"\"Start executing this callback only after all validation sanity checks end.\"\"\"\n        self.ready = True\n\n    def on_validation_batch_end(\n        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx\n    ):\n        \"\"\"Gather data from single batch.\"\"\"\n        if self.ready:\n            self.preds.append(outputs[\"preds\"])\n            self.targets.append(outputs[\"targets\"])\n\n    def on_validation_epoch_end(self, trainer, pl_module):\n        \"\"\"Generate confusion matrix.\"\"\"\n        if self.ready:\n            logger = get_wandb_logger(trainer)\n            experiment = logger.experiment\n\n            preds = torch.cat(self.preds).cpu().numpy()\n            targets = torch.cat(self.targets).cpu().numpy()\n\n            confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds)\n\n            # set figure size\n            plt.figure(figsize=(14, 8))\n\n            # set labels size\n            sn.set(font_scale=1.4)\n\n            # set font size\n            sn.heatmap(confusion_matrix, annot=True, annot_kws={\"size\": 8}, 
fmt=\"g\")\n\n            # names should be unique, or else charts from different experiments in wandb will overlap\n            experiment.log({f\"confusion_matrix/{experiment.name}\": wandb.Image(plt)}, commit=False)\n\n            # according to wandb docs this should also work but it crashes\n            # experiment.log({f\"confusion_matrix/{experiment.name}\": plt})\n\n            # reset plot\n            plt.clf()\n\n            self.preds.clear()\n            self.targets.clear()\n\n\nclass LogF1PrecRecHeatmap(Callback):\n    \"\"\"Generate f1, precision, recall heatmap every epoch and send it to wandb.\n    Expects validation step to return predictions and targets.\n    \"\"\"\n\n    def __init__(self, class_names: List[str] = None):\n        self.preds = []\n        self.targets = []\n        self.ready = True\n\n    def on_sanity_check_start(self, trainer, pl_module):\n        self.ready = False\n\n    def on_sanity_check_end(self, trainer, pl_module):\n        \"\"\"Start executing this callback only after all validation sanity checks end.\"\"\"\n        self.ready = True\n\n    def on_validation_batch_end(\n        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx\n    ):\n        \"\"\"Gather data from a single batch.\"\"\"\n        if self.ready:\n            self.preds.append(outputs[\"preds\"])\n            self.targets.append(outputs[\"targets\"])\n\n    def on_validation_epoch_end(self, trainer, pl_module):\n        \"\"\"Generate f1, precision and recall heatmap.\"\"\"\n        if self.ready:\n            logger = get_wandb_logger(trainer=trainer)\n            experiment = logger.experiment\n\n            preds = torch.cat(self.preds).cpu().numpy()\n            targets = torch.cat(self.targets).cpu().numpy()\n            f1 = f1_score(targets, preds, average=None)\n            r = recall_score(targets, preds, average=None)\n            p = precision_score(targets, preds, average=None)\n            data = [f1, p, r]\n\n            
# set figure size\n            plt.figure(figsize=(14, 3))\n\n            # set labels size\n            sn.set(font_scale=1.2)\n\n            # set font size\n            sn.heatmap(\n                data,\n                annot=True,\n                annot_kws={\"size\": 10},\n                fmt=\".3f\",\n                yticklabels=[\"F1\", \"Precision\", \"Recall\"],\n            )\n\n            # names should be unique, or else charts from different experiments in wandb will overlap\n            experiment.log({f\"f1_p_r_heatmap/{experiment.name}\": wandb.Image(plt)}, commit=False)\n\n            # reset plot\n            plt.clf()\n\n            self.preds.clear()\n            self.targets.clear()\n\n\nclass LogImagePredictions(Callback):\n    \"\"\"Logs a validation batch and its predictions to wandb.\n    Example adapted from:\n        https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY\n    \"\"\"\n\n    def __init__(self, num_samples: int = 8):\n        super().__init__()\n        self.num_samples = num_samples\n        self.ready = True\n\n    def on_sanity_check_start(self, trainer, pl_module):\n        self.ready = False\n\n    def on_sanity_check_end(self, trainer, pl_module):\n        \"\"\"Start executing this callback only after all validation sanity checks end.\"\"\"\n        self.ready = True\n\n    def on_validation_epoch_end(self, trainer, pl_module):\n        if self.ready:\n            logger = get_wandb_logger(trainer=trainer)\n            experiment = logger.experiment\n\n            # get a validation batch from the validation data loader\n            val_samples = next(iter(trainer.datamodule.val_dataloader()))\n            val_imgs, val_labels = val_samples\n\n            # run the batch through the network\n            val_imgs = val_imgs.to(device=pl_module.device)\n            logits = pl_module(val_imgs)\n            preds = torch.argmax(logits, dim=-1)\n\n            # log 
the images as wandb Image\n            experiment.log(\n                {\n                    f\"Images/{experiment.name}\": [\n                        wandb.Image(x, caption=f\"Pred:{pred}, Label:{y}\")\n                        for x, pred, y in zip(\n                            val_imgs[: self.num_samples],\n                            preds[: self.num_samples],\n                            val_labels[: self.num_samples],\n                        )\n                    ]\n                }\n            )\n"
  },
  {
    "path": "training/src/datamodules/datasets/detokenizer.py",
    "content": "# Copied from https://github.com/stanford-crfm/mistral/blob/main/src/corpora/detokenization.py\n# Which was originally from https://github.com/NVIDIA/Megatron-LM/blob/aed2f75e209e525c842aec7c044af7acae2a4614/tasks/zeroshot_gpt/detokenizer.py\n\n\"\"\"\nHandle detokenization for different dataset for zero-shot LM evaluation.\n\"\"\"\nimport re\n\n\ndef wikitext_detokenize(string: str) -> str:\n    \"\"\"\n    Wikitext is whitespace tokenized and we remove these whitespaces.\n    Taken from https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt2/detokenizer.py\n    \"\"\"\n    # Contractions\n    string = string.replace(\"s '\", \"s'\")\n    string = re.sub(r\"/' [0-9]/\", r\"/'[0-9]/\", string)\n\n    # Number Separators\n    string = string.replace(\" @-@ \", \"-\")\n    string = string.replace(\" @,@ \", \",\")\n    string = string.replace(\" @.@ \", \".\")\n\n    # Punctuation\n    string = string.replace(\" : \", \": \")\n    string = string.replace(\" ; \", \"; \")\n    string = string.replace(\" . \", \". \")\n    string = string.replace(\" ! \", \"! \")\n    string = string.replace(\" ? \", \"? 
\")\n    string = string.replace(\" , \", \", \")\n\n    # Double Brackets\n    string = re.sub(r\"\\(\\s*([^\\)]*?)\\s*\\)\", r\"(\\1)\", string)\n    string = re.sub(r\"\\[\\s*([^\\]]*?)\\s*\\]\", r\"[\\1]\", string)\n    string = re.sub(r\"{\\s*([^}]*?)\\s*}\", r\"{\\1}\", string)\n    string = re.sub(r\"\\\"\\s*([^\\\"]*?)\\s*\\\"\", r'\"\\1\"', string)\n    string = re.sub(r\"'\\s*([^']*?)\\s*'\", r\"'\\1'\", string)\n\n    # Miscellaneous\n    string = string.replace(\"= = = =\", \"====\")\n    string = string.replace(\"= = =\", \"===\")\n    string = string.replace(\"= =\", \"==\")\n    string = string.replace(\" \" + chr(176) + \" \", chr(176))\n    string = string.replace(\" \\n\", \"\\n\")\n    string = string.replace(\"\\n \", \"\\n\")\n    string = string.replace(\" N \", \" 1 \")\n    string = string.replace(\" 's\", \"'s\")\n\n    return string\n\n\n# Set Registry for Various Datasets\nDATASET_TOKENIZATION_REGISTRY = {\"wikitext\": wikitext_detokenize}\n"
  },
  {
    "path": "training/src/datamodules/datasets/lm_dataset.py",
    "content": "# Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py\n# Except we don't pad the last block and don't use overlapping eval\n# And we return both the input and the target\nimport math\nimport numpy as np\n\nimport torch\n\n\nclass LMDataset(torch.utils.data.Dataset):\n\n    def __init__(self, tokens, seq_len, drop_last=True):\n        \"\"\"tokens should be a numpy array\n        \"\"\"\n        self.seq_len = seq_len\n        ntokens = len(tokens)\n        if drop_last:\n            ntokens = ((ntokens - 1) // seq_len) * seq_len + 1\n        self.ntokens = ntokens\n        # We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset,\n        # and slicing would load it to memory.\n        self.tokens = tokens\n        self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len)\n\n    def __len__(self):\n        return self.total_sequences\n\n    def __getitem__(self, idx):\n        start_idx = idx * self.seq_len\n        seq_len = min(self.seq_len, self.ntokens - 1 - start_idx)\n        data = torch.as_tensor(self.tokens[start_idx:(start_idx + seq_len + 1)].astype(np.int64))\n        return data[:-1], data[1:].clone()\n"
  },
  {
    "path": "training/src/datamodules/fault_tolerant_sampler.py",
    "content": "# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397\nfrom typing import Iterator\nimport math\n\nimport torch\nfrom torch.utils.data import RandomSampler, DistributedSampler\n\n\nclass RandomFaultTolerantSampler(RandomSampler):\n\n    def __init__(self, *args, generator=None, **kwargs):\n        # generator = torch.Generator().manual_seed(seed)\n        # super().__init__(*args, generator=generator, **kwargs)\n        # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,\n        # which should be reproducible if pl.seed_everything was called before hand.\n        # This means that changing the seed of the experiment will also change the\n        # sampling order.\n        if generator is None:\n            seed = int(torch.empty((), dtype=torch.int64).random_().item())\n            generator = torch.Generator().manual_seed(seed)\n        super().__init__(*args, generator=generator, **kwargs)\n        self.counter = 0\n        # self.start_counter = 0\n        self.restarting = False\n\n    def state_dict(self):\n        return {\"random_state\": self.state, \"counter\": self.counter}\n\n    def load_state_dict(self, state_dict):\n        self.generator.set_state(state_dict.get(\"random_state\"))\n        self.counter = state_dict[\"counter\"]\n        # self.start_counter = self.counter\n        self.restarting = True\n\n    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per\n    # epoch, and subsequent epoch will have very few batches.\n    # def __len__(self):\n    #     # We need a separate self.start_counter because PL seems to call len repeatedly.\n    #     # If we use len(self.data_source) - self.counter then PL will think the epoch ends\n    #     # when we're only half way through.\n    #     return len(self.data_source) - self.start_counter\n\n    def 
__iter__(self) -> Iterator[int]:\n        n = len(self.data_source)\n\n        self.state = self.generator.get_state()\n        indices = torch.randperm(n, generator=self.generator).tolist()\n\n        if not self.restarting:\n            self.counter = 0\n        else:\n            indices = indices[self.counter:]\n            self.restarting = False\n        # self.start_counter = self.counter\n\n        for index in indices:\n            self.counter += 1\n            yield index\n\n        self.counter = 0\n        # self.start_counter = self.counter\n\n\nclass FaultTolerantDistributedSampler(DistributedSampler):\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.counter = 0\n        # self.start_counter = 0\n        self.restarting = False\n\n    def state_dict(self):\n        return {\"epoch\": self.epoch, \"counter\": self.counter}\n\n    def load_state_dict(self, state_dict):\n        self.epoch = state_dict[\"epoch\"]\n        self.counter = state_dict[\"counter\"]\n        # self.start_counter = self.counter\n        self.restarting = True\n\n    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per\n    # epoch, and subsequent epoch will have very few batches.\n    # def __len__(self) -> int:\n        # return self.num_samples - self.start_counter\n\n    def __iter__(self):\n        if self.shuffle:\n            # deterministically shuffle based on epoch and seed\n            g = torch.Generator()\n            g.manual_seed(self.seed + self.epoch)\n            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]\n        else:\n            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]\n\n        if not self.drop_last:\n            # add extra samples to make it evenly divisible\n            padding_size = self.total_size - len(indices)\n            if padding_size <= len(indices):\n                indices 
+= indices[:padding_size]\n            else:\n                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]\n        else:\n            # remove tail of data to make it evenly divisible.\n            indices = indices[:self.total_size]\n        assert len(indices) == self.total_size\n\n        # subsample\n        indices = indices[self.rank:self.total_size:self.num_replicas]\n        assert len(indices) == self.num_samples\n\n        if not self.restarting:\n            self.counter = 0\n        else:\n            indices = indices[self.counter:]\n            self.restarting = False\n        # self.start_counter = self.counter\n\n        for index in indices:\n            self.counter += 1\n            yield index\n\n        self.counter = 0\n        # self.start_counter = self.counter\n"
  },
  {
    "path": "training/src/datamodules/imagenet.py",
    "content": "# Adapted from https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/datamodules/imagenet_datamodule.py\nimport os\nfrom pathlib import Path\nfrom typing import Any, List, Union, Callable, Optional\n\nimport torch\nfrom torch.utils.data import Dataset, DataLoader, SequentialSampler\nfrom torch.utils.data.dataloader import default_collate\nfrom torch.utils.data.distributed import DistributedSampler\n\nfrom pytorch_lightning import LightningDataModule\n\nfrom torchvision import transforms\nfrom torchvision.datasets import ImageFolder\n\n\nclass DictDataset(Dataset):\n\n    def __init__(self, dataset_dict, length=None):\n        \"\"\"dataset_dict: dictionary mapping from index to batch\n        length is used in the case of DistributedSampler: e.g. the dataset could have size 1k, but\n        with 8 GPUs the dataset_dict would only have 125 items.\n        \"\"\"\n        super().__init__()\n        self.dataset_dict = dataset_dict\n        self.length = length or len(self.dataset_dict)\n\n    def __getitem__(self, index):\n        return self.dataset_dict[index]\n\n    def __len__(self):\n        return self.length\n\n\n# From https://github.com/PyTorchLightning/lightning-bolts/blob/2415b49a2b405693cd499e09162c89f807abbdc4/pl_bolts/transforms/dataset_normalizations.py#L10\ndef imagenet_normalization():\n    return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n\n\nclass ImagenetDataModule(LightningDataModule):\n    \"\"\"\n    .. 
figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/\n        Sample-of-Images-from-the-ImageNet-Dataset-used-in-the-ILSVRC-Challenge.png\n        :width: 400\n        :alt: Imagenet\n    Specs:\n        - 1000 classes\n        - Each image is (3 x varies x varies) (here we default to 3 x 224 x 224)\n    Imagenet train, val and test dataloaders.\n    The train set is the imagenet train.\n    The val set is taken from the train set with `num_imgs_per_val_class` images per class.\n    For example if `num_imgs_per_val_class=2` then there will be 2,000 images in the validation set.\n    The test set is the official imagenet validation set.\n     Example::\n        from pl_bolts.datamodules import ImagenetDataModule\n        dm = ImagenetDataModule(IMAGENET_PATH)\n        model = LitModel()\n        Trainer().fit(model, datamodule=dm)\n    \"\"\"\n\n    name = \"imagenet\"\n\n    def __init__(\n        self,\n        data_dir: str,\n        image_size: int = 224,\n        train_transforms=None,\n        val_transforms=None,\n        test_transforms=None,\n        img_dtype='float32',  # Using str since OmegaConf doesn't support non-primitive type\n        cache_val_dataset=False,\n        mixup: Optional[Callable] = None,\n        num_aug_repeats: int = 0,\n        num_workers: int = 0,\n        batch_size: int = 32,\n        batch_size_eval: Optional[int] = None,\n        shuffle: bool = True,\n        pin_memory: bool = True,\n        drop_last: bool = False,\n        *args: Any,\n        **kwargs: Any,\n    ) -> None:\n        \"\"\"\n        Args:\n            data_dir: path to the imagenet dataset file\n            num_imgs_per_val_class: how many images per class for the validation set\n            image_size: final image size\n            num_workers: how many data workers\n            batch_size: batch_size\n            shuffle: If true shuffles the data every epoch\n            pin_memory: If true, the data loader 
will copy Tensors into CUDA pinned memory before\n                        returning them\n            drop_last: If true drops the last incomplete batch\n        \"\"\"\n        super().__init__(*args, **kwargs)\n\n        self.image_size = image_size\n        self.train_transforms = train_transforms\n        self.val_transforms = val_transforms\n        self.test_transforms = test_transforms\n        assert img_dtype in ['float32', 'float16', 'bfloat16']\n        self.img_dtype = torch.__getattribute__(img_dtype)\n        self.cache_val_dataset = cache_val_dataset\n        self.mixup = mixup\n        self.num_aug_repeats = num_aug_repeats\n        self.dims = (3, self.image_size, self.image_size)\n        self.data_dir = Path(data_dir).expanduser()\n        self.num_workers = num_workers\n        self.batch_size = batch_size\n        self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size\n        self.shuffle = shuffle\n        self.pin_memory = pin_memory\n        self.drop_last = drop_last\n\n    @property\n    def num_classes(self) -> int:\n        \"\"\"\n        Return:\n            1000\n        \"\"\"\n        return 1000\n\n    def _verify_splits(self, data_dir: str, split: str) -> None:\n        dirs = os.listdir(data_dir)\n\n        if split not in dirs:\n            raise FileNotFoundError(\n                f\"a {split} Imagenet split was not found in {data_dir},\"\n                f\" make sure the folder contains a subfolder named {split}\"\n            )\n\n    def prepare_data(self) -> None:\n        \"\"\"This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin.\n        .. 
warning:: Please download imagenet on your own first.\n        \"\"\"\n        self._verify_splits(self.data_dir, \"train\")\n        self._verify_splits(self.data_dir, \"val\")\n\n    def setup(self, stage: Optional[str] = None) -> None:\n        \"\"\"Creates train, val, and test dataset.\"\"\"\n        if stage == \"fit\" or stage is None:\n            train_transforms = (self.train_transform() if self.train_transforms is None\n                                else self.train_transforms)\n            val_transforms = (self.val_transform() if self.val_transforms is None\n                              else self.val_transforms)\n            if self.img_dtype is not torch.float32:\n                assert isinstance(train_transforms, transforms.Compose)\n                assert isinstance(val_transforms, transforms.Compose)\n                convert_dtype = transforms.Lambda(lambda x: x.to(dtype=self.img_dtype))\n                train_transforms.transforms.append(convert_dtype)\n                val_transforms.transforms.append(convert_dtype)\n            self.dataset_train = ImageFolder(self.data_dir / 'train', transform=train_transforms)\n            self.dataset_val = ImageFolder(self.data_dir / 'val', transform=val_transforms)\n\n        if stage == \"test\" or stage is None:\n            test_transforms = (self.val_transform() if self.test_transforms is None\n                               else self.test_transforms)\n            if self.img_dtype is not torch.float32:\n                assert isinstance(test_transforms, transforms.Compose)\n                convert_dtype = transforms.Lambda(lambda x: x.to(dtype=self.img_dtype))\n                test_transforms.transforms.append(convert_dtype)\n            self.dataset_test = ImageFolder(self.data_dir / 'val', transform=test_transforms)\n\n    def train_transform(self) -> Callable:\n        \"\"\"The standard imagenet transforms.\n        .. 
code-block:: python\n            transforms.Compose([\n                transforms.RandomResizedCrop(self.image_size),\n                transforms.RandomHorizontalFlip(),\n                transforms.ToTensor(),\n                transforms.Normalize(\n                    mean=[0.485, 0.456, 0.406],\n                    std=[0.229, 0.224, 0.225]\n                ),\n            ])\n        \"\"\"\n        preprocessing = transforms.Compose(\n            [\n                transforms.RandomResizedCrop(self.image_size),\n                transforms.RandomHorizontalFlip(),\n                transforms.ToTensor(),\n                imagenet_normalization(),\n            ]\n        )\n\n        return preprocessing\n\n    def val_transform(self) -> Callable:\n        \"\"\"The standard imagenet transforms for validation.\n        .. code-block:: python\n            transforms.Compose([\n                transforms.Resize(self.image_size + 32),\n                transforms.CenterCrop(self.image_size),\n                transforms.ToTensor(),\n                transforms.Normalize(\n                    mean=[0.485, 0.456, 0.406],\n                    std=[0.229, 0.224, 0.225]\n                ),\n            ])\n        \"\"\"\n\n        preprocessing = transforms.Compose(\n            [\n                transforms.Resize(self.image_size + 32),\n                transforms.CenterCrop(self.image_size),\n                transforms.ToTensor(),\n                imagenet_normalization(),\n            ]\n        )\n        return preprocessing\n\n    def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:\n        \"\"\" The train dataloader \"\"\"\n        if self.num_aug_repeats == 0:\n            shuffle = self.shuffle\n            sampler = None\n        else:\n            shuffle = False\n            from timm.data.distributed_sampler import RepeatAugSampler\n            sampler = RepeatAugSampler(self.dataset_train, num_repeats=self.num_aug_repeats)\n        return 
self._data_loader(self.dataset_train, batch_size=self.batch_size,\n                                 shuffle=shuffle, mixup=self.mixup, sampler=sampler)\n\n    def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:\n        \"\"\" The val dataloader \"\"\"\n        # If using RepeatAugment, we set trainer.replace_sampler_ddp=False, so we have to\n        # construct the DistributedSampler ourselves.\n        if not self.cache_val_dataset:\n            sampler = (DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last)\n                       if self.num_aug_repeats != 0 else None)\n            return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval,\n                                     sampler=sampler)\n        else:\n            print('Caching val dataset')\n            sampler = (SequentialSampler(self.dataset_val) if self.trainer.world_size <= 1\n                       else DistributedSampler(self.dataset_val, shuffle=False,\n                                               drop_last=self.drop_last))\n            indices = list(iter(sampler))\n            loader = DataLoader(self.dataset_val, batch_size=None, shuffle=False, sampler=sampler,\n                                num_workers=self.num_workers, drop_last=self.drop_last)\n            batches = list(loader)\n            assert len(batches) == len(indices)\n            self.dataset_val = DictDataset(dict(zip(indices, batches)),\n                                           length=len(self.dataset_val))\n            sampler = (DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last)\n                       if self.num_aug_repeats != 0 else None)\n            return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval,\n                                     sampler=sampler)\n\n    def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:\n        \"\"\" The test 
dataloader \"\"\"\n        sampler = (DistributedSampler(self.dataset_test, shuffle=False, drop_last=self.drop_last)\n                   if self.num_aug_repeats != 0 else None)\n        return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval, sampler=sampler)\n\n    def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,\n                     mixup: Optional[Callable] = None, sampler=None) -> DataLoader:\n        collate_fn = ((lambda batch: mixup(*default_collate(batch))) if mixup is not None\n                      else default_collate)\n        return DataLoader(\n            dataset,\n            collate_fn=collate_fn,\n            batch_size=batch_size,\n            shuffle=shuffle,\n            sampler=sampler,\n            num_workers=self.num_workers,\n            drop_last=self.drop_last,\n            pin_memory=self.pin_memory,\n            # DataLoader raises ValueError if persistent_workers=True while num_workers == 0\n            persistent_workers=self.num_workers > 0\n        )\n\n\nclass Imagenet21kPDataModule(ImagenetDataModule):\n    \"\"\"ImageNet-21k (winter 21) processed with https://github.com/Alibaba-MIIL/ImageNet21K\n    \"\"\"\n\n    @property\n    def num_classes(self) -> int:\n        \"\"\"\n        Return:\n            10450\n        \"\"\"\n        return 10450\n"
  },
  {
    "path": "training/src/datamodules/language_modeling_hf.py",
    "content": "# Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py\nfrom itertools import chain\nfrom pathlib import Path\nimport pickle\nfrom typing import Any, List, Union\nimport subprocess\nimport mmap\n\nfrom multiprocessing.shared_memory import SharedMemory\n\nimport numpy as np\n\nimport torch\nfrom torch.utils.data.dataloader import DataLoader, Dataset\nfrom transformers import AutoTokenizer\nfrom datasets import load_dataset\n\nfrom pytorch_lightning import LightningDataModule\n\nfrom src.datamodules.datasets.lm_dataset import LMDataset\nfrom src.datamodules.fault_tolerant_sampler import RandomFaultTolerantSampler\nfrom src.datamodules.fault_tolerant_sampler import FaultTolerantDistributedSampler\nfrom src.datamodules.datasets.detokenizer import DATASET_TOKENIZATION_REGISTRY\nfrom src.utils.utils import get_logger\nlogger = get_logger()\n\n\n# https://github.com/numpy/numpy/issues/18294\nclass SHMArray(np.ndarray): #copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array\n\n    def __new__(cls, input_array, shm=None):\n        obj = np.asarray(input_array).view(cls)\n        obj.shm = shm\n        return obj\n\n    def __array_finalize__(self, obj):\n        if obj is None: return\n        self.shm = getattr(obj, 'shm', None)\n\n\nclass LMDataModule(LightningDataModule):\n    def __init__(self, dataset_name, tokenizer_name, dataset_config_name=None, max_length=1024,\n                 cache_dir=None, val_ratio=0.0005, val_split_seed=2357, add_eos=True,\n                 detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1,\n                 shuffle=False, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False,\n                 fast_forward_epochs=None, fast_forward_batches=None,\n                 use_shmem=True):\n        super().__init__()\n        self.dataset_name = 
dataset_name\n        self.dataset_config_name = dataset_config_name\n        self.tokenizer_name = tokenizer_name\n        self.cache_dir = None if cache_dir is None else Path(cache_dir).expanduser()\n        self.max_length = max_length\n        self.val_ratio = val_ratio\n        self.val_split_seed = val_split_seed\n        self.val_only = val_only\n        self.add_eos = add_eos\n        self.detokenize = detokenize\n        self.batch_size = batch_size\n        self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size\n        self.num_workers = num_workers\n        self.shuffle = shuffle\n        self.pin_memory = pin_memory\n        self.drop_last = drop_last\n        if fault_tolerant:\n            assert self.shuffle\n        self.fault_tolerant = fault_tolerant\n        if ddp:\n            assert fault_tolerant\n        self.ddp = ddp\n        self.fast_forward_epochs = fast_forward_epochs\n        self.fast_forward_batches = fast_forward_batches\n        if self.fast_forward_epochs is not None or self.fast_forward_batches is not None:\n            assert ddp and fault_tolerant\n\n        self.use_shmem = use_shmem\n        if self.use_shmem:\n            assert cache_dir is not None\n\n    def prepare_data(self):\n        if self.cache_dir is None:  # Just download the dataset\n            load_dataset(self.dataset_name, self.dataset_config_name)\n        else:  # Process the dataset and save it\n            self.process_dataset()\n\n    def setup(self, stage=None):\n        if stage == 'test' and hasattr(self, 'dataset_test'):\n            return\n        concat_ids, self.tokenizer = self.process_dataset()\n        self.vocab_size = len(self.tokenizer)\n        # Create all splits\n        self.dataset_train, self.dataset_val, self.dataset_test = [\n            LMDataset(concat_ids[split], seq_len=self.max_length)\n            for split in ['train', 'validation', 'test']\n        ]\n\n    def process_dataset(self):\n   
     cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name\n        if cache_dir is not None:\n            if cache_dir.is_dir():\n                return self._load_from_cache(cache_dir)\n\n        raw_datasets = load_dataset(self.dataset_name, self.dataset_config_name)\n        # https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py\n        if 'validation' not in raw_datasets:\n            assert \"train\" in raw_datasets, \"You must have train in raw_datasets to make a validation raw_datasets\"\n            raw_datasets = raw_datasets[\"train\"].train_test_split(\n                test_size=self.val_ratio, seed=self.val_split_seed,\n                shuffle=True  # Otherwise test will be at the end of the dataset\n            )\n            raw_datasets['validation'] = raw_datasets['test']\n\n        if self.val_only:  # Should only be used for evaluation, not for training\n            raw_datasets['train'] = raw_datasets['validation']\n\n        # [2021-12-25] TD: Running the detokenizer on wikitext-103 makes ppl worse\n        # (GPT2-small val ppl after 10 epochs ~22 -> ~25)\n        # However, it's useful for zero-shot transfer from Openwebtext,\n        # as after detokenization it's closer to Openwebtext's format.\n        # https://github.com/stanford-crfm/mistral/issues/12\n        if self.detokenize:\n            if self.dataset_name in DATASET_TOKENIZATION_REGISTRY:\n                detokenizer = DATASET_TOKENIZATION_REGISTRY[self.dataset_name]\n                raw_datasets = raw_datasets.map(\n                    lambda example: {'text': detokenizer(example['text'])},\n                    num_proc=max(self.num_workers, 1),\n                    desc='Running detokenizer on dataset'\n                )\n\n        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True)\n        # Preprocessing the datasets.\n        # First we tokenize all the texts.\n        column_names = 
raw_datasets[\"train\"].column_names\n        text_column_name = \"text\" if \"text\" in column_names else column_names[0]\n        # [2021-12-25] TD: For wikitext, don't need to add the EOS since each example already ends\n        # with '\\n', and there are no other '\\n' in the examples.\n        # assert all([t.count('\\n') == 1 for t in raw_datasets['train']['text'] if t])\n        # Add EOS token to the end of the text if the text is not empty\n        # https://github.com/stanford-crfm/mistral/issues/91\n        # https://github.com/stanford-crfm/mistral/pull/98\n        if self.add_eos:\n            add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq\n            add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs]\n            tokenize = lambda example: tokenizer(add_eos_batched(example[text_column_name]))\n        else:\n            tokenize = lambda example: tokenizer(example[text_column_name])\n        # tokenized_datasets = raw_datasets.map(\n        #     tokenize,\n        #     batched=True,\n        #     num_proc=max(self.num_workers, 1),\n        #     remove_columns=column_names,\n        #     desc=\"Running tokenizer on dataset\",\n        # )\n        dtype = np.uint16 if tokenizer.vocab_size < 64 * 1024 else np.int32\n        def tokenize_concat(examples):\n            # We just need 'input_ids', not 'attention_mask' (since it's all 1)\n            input_ids = np.fromiter(chain(*tokenize(examples)['input_ids']), dtype=dtype)\n            # Need to return a list since we're doing batched processing\n            return {'input_ids': [input_ids], 'len': [len(input_ids)]}\n        tokenized_datasets = raw_datasets.map(\n            tokenize_concat,\n            batched=True,\n            num_proc=max(self.num_workers, 1),\n            remove_columns=column_names,\n            desc=\"Running tokenizer on dataset\",\n        )\n\n        if self.use_shmem:\n            # Concatenate all input_ids into an array in shared 
memory\n            def write_ids_to_shm(example, shm_name, array_len):\n                shm = SharedMemory(name=shm_name)\n                shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf)\n                start_idx = example['len_offset'] - len(example['input_ids'])\n                shm_arr[start_idx:example['len_offset']] = example['input_ids']\n                shm.close()\n            concat_ids = {}\n            for name, ds in tokenized_datasets.items():\n                tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len']))\n                array_len = tokenized_datasets[name][-1]['len_offset']\n                shm = SharedMemory(create=True, size=array_len * np.dtype(dtype).itemsize)\n                shm_name = shm.name\n                tokenized_datasets[name].map(\n                    write_ids_to_shm,\n                    fn_kwargs={'shm_name': shm_name, 'array_len': array_len},\n                    batched=False,\n                    num_proc=max(self.num_workers, 1),\n                    desc=\"Concatenating examples\",\n                )\n                shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf)\n                # We need to keep a reference to the shared memory, otherwise it gets garbage-collected\n                # when it goes out of scope, and that memory is gone.\n                # https://github.com/numpy/numpy/issues/18294\n                concat_ids[name] = SHMArray(shm_arr, shm=shm)\n        else:\n            # Use disk\n            concat_ids = {}\n            assert cache_dir is not None\n            cache_dir.mkdir(parents=True, exist_ok=True)\n            def write_ids_to_disk(example, filename):\n                with open(filename, 'r+b') as f:\n                    mm = mmap.mmap(f.fileno(), 0)\n                    start_idx = example['len_offset'] - len(example['input_ids'])\n                    array_len = len(example['input_ids'])\n                    arr = 
np.ndarray((array_len,), dtype=dtype, buffer=mm,\n                                     offset=np.dtype(dtype).itemsize * start_idx)\n                    arr[:] = example['input_ids']\n                    mm.flush()\n            for name, ds in tokenized_datasets.items():\n                tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len']))\n                array_len = tokenized_datasets[name][-1]['len_offset']\n                filename = cache_dir / f'{name}.bin'\n                # Need to create the file with this specific size first\n                # https://ostechnix.com/create-files-certain-size-linux/\n                subprocess.run(['truncate', '-s', str(array_len * np.dtype(dtype).itemsize),\n                                str(filename)], check=True)\n                tokenized_datasets[name].map(\n                    write_ids_to_disk,\n                    fn_kwargs={'filename': filename},\n                    batched=False,\n                    num_proc=max(self.num_workers, 1),\n                    desc=\"Concatenating examples\",\n                )\n                concat_ids[name] = np.memmap(filename, dtype=dtype, mode='r', shape=(array_len,))\n\n        if cache_dir is not None:\n            self._save_to_cache(concat_ids, tokenizer, cache_dir)\n            if not self.use_shmem:\n                for name in concat_ids:\n                    Path(cache_dir / f'{name}.bin').unlink()\n        return concat_ids, tokenizer\n\n    def _save_to_cache(self, concat_ids, tokenizer, cache_dir):\n        cache_dir.mkdir(parents=True, exist_ok=True)\n        logger.info(f'Saving to cache at {str(cache_dir)}')\n        for k, v in concat_ids.items():\n            np.save(cache_dir / f'{k}.npy', v)\n        with open(cache_dir / 'tokenizer.pkl', 'wb') as f:\n            pickle.dump(tokenizer, f)\n\n    def _load_from_cache(self, cache_dir):\n        assert cache_dir.is_dir()\n        logger.info(f'Load from cache at {str(cache_dir)}')\n      
  concat_ids = {split: np.load(cache_dir / f'{split}.npy', mmap_mode='r')\n                      for split in ['train', 'validation', 'test']}\n        with open(cache_dir / 'tokenizer.pkl', 'rb') as f:\n            tokenizer = pickle.load(f)\n        return concat_ids, tokenizer\n\n    @property\n    def _cache_dir_name(self):\n        return f'tokenizer_name-{self.tokenizer_name}-val_ratio-{self.val_ratio}-val_split_seed-{self.val_split_seed}-add_eos-{self.add_eos}-detokenize-{self.detokenize}'\n\n    def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:\n        \"\"\" The train dataloader \"\"\"\n        if self.shuffle and self.fault_tolerant:\n            shuffle = False\n            sampler = (FaultTolerantDistributedSampler(self.dataset_train) if self.ddp\n                       else RandomFaultTolerantSampler(self.dataset_train))\n            # TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now\n            # We assume that it's being resumed with the same number of GPUs\n            if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None:\n                sampler.load_state_dict({\n                    'epoch': self.fast_forward_epochs,\n                    'counter': self.fast_forward_batches * self.batch_size\n                })\n        else:\n            shuffle = self.shuffle\n            sampler = None\n        return self._data_loader(self.dataset_train, batch_size=self.batch_size,\n                                 shuffle=shuffle, sampler=sampler)\n\n    def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:\n        \"\"\" The val dataloader \"\"\"\n        return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval)\n\n    def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:\n        \"\"\" The test dataloader \"\"\"\n        return self._data_loader(self.dataset_test, 
batch_size=self.batch_size_eval)\n\n    def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,\n                     sampler=None) -> DataLoader:\n        return DataLoader(\n            dataset,\n            batch_size=batch_size,\n            num_workers=1,  # Data is already in memory, we don't need many workers\n            shuffle=shuffle,\n            sampler=sampler,\n            drop_last=self.drop_last,\n            pin_memory=self.pin_memory,\n            # persistent_workers=True\n        )\n\n    def load_state_dict(self, checkpoint):\n        if self.fault_tolerant:\n            self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']\n            # TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration\n            # behind, so we're using the optimizer's progress. This is set correctly in seq.py.\n            self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed']\n        # At this point the train loader hasn't been constructed yet\n"
  },
  {
    "path": "training/src/datamodules/timm_mixup.py",
    "content": "import torch\n\nfrom timm.data import Mixup\nfrom timm.data.mixup import mixup_target\n\n\nclass TimmMixup(Mixup):\n    \"\"\" Wrapper around timm.data.Mixup that requires an even batch size only in 'pair' mode,\n    instead of asserting it unconditionally.\n    \"\"\"\n    def __call__(self, x, target):\n        if self.mode == 'elem':\n            lam = self._mix_elem(x)\n        elif self.mode == 'pair':\n            # The assert is moved here from the beginning of the function\n            assert len(x) % 2 == 0, 'Batch size should be even when using this'\n            lam = self._mix_pair(x)\n        else:\n            lam = self._mix_batch(x)\n        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)\n        return x, target\n"
  },
  {
    "path": "training/src/distributed/ddp_comm_hooks.py",
    "content": "# Adapted from https://pytorch.org/docs/stable/_modules/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.html\n# We divide by world_size first before converting to fp16, so it's safer.\nfrom typing import Any, Callable\n\nimport torch\nimport torch.distributed as dist\n\n\ndef fp16_compress_hook(\n    process_group: dist.ProcessGroup, bucket: dist.GradBucket\n) -> torch.futures.Future[torch.Tensor]:\n    \"\"\"\n    This DDP communication hook implements a simple gradient compression\n    approach that divides the ``GradBucket`` tensor by the process group size\n    and casts the result to half-precision floating-point format (``torch.float16``).\n    The division happens before the cast (fused via the ``out`` argument of ``torch.div``),\n    which is safer than dividing after the cast.\n    It then allreduces those ``float16`` gradient tensors. Once the compressed gradient\n    tensors are allreduced, the chained callback ``decompress`` casts them back to the\n    input data type (such as ``float32``).\n\n    Example::\n        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)\n    \"\"\"\n    group_to_use = process_group if process_group is not None else dist.group.WORLD\n    world_size = group_to_use.size()\n\n    # Divide first before converting to fp16\n    # Use out argument to fuse the division and the conversion.\n    compressed_tensor = torch.div(bucket.buffer(), world_size,\n                                  out=torch.empty_like(bucket.buffer(), dtype=torch.float16))\n\n    fut = dist.all_reduce(\n        compressed_tensor, group=group_to_use, async_op=True\n    ).get_future()\n\n    def decompress(fut):\n        decompressed_tensor = bucket.buffer()\n        # Decompress in place to reduce the peak memory.\n        # See: https://github.com/pytorch/pytorch/issues/45968\n        decompressed_tensor.copy_(fut.value()[0])\n        return decompressed_tensor\n\n    # TODO: maybe have a backoff strategy: check if the buffer has inf / NaN, in that case\n    # resend with fp32?\n    return fut.then(decompress)\n"
  },
  {
    "path": "training/src/eval.py",
    "content": "from typing import List, Optional\nfrom pathlib import Path\n\nimport torch\n\nimport hydra\nfrom omegaconf import OmegaConf, DictConfig\nfrom pytorch_lightning import (\n    Callback,\n    LightningDataModule,\n    LightningModule,\n    Trainer,\n    seed_everything,\n)\nfrom pytorch_lightning.loggers import LightningLoggerBase\n\nfrom src.utils import utils\n\nlog = utils.get_logger(__name__)\n\n\ndef remove_prefix(text: str, prefix: str):\n    if text.startswith(prefix):\n        return text[len(prefix) :]\n    return text  # or whatever\n\n\ndef load_checkpoint(path, device='cpu'):\n    path = Path(path).expanduser()\n    if path.is_dir():\n        path /= 'last.ckpt'\n    # dst = f'cuda:{torch.cuda.current_device()}'\n    log.info(f'Loading checkpoint from {str(path)}')\n    state_dict = torch.load(path, map_location=device)\n    # T2T-ViT checkpoint is nested in the key 'state_dict_ema'\n    if state_dict.keys() == {'state_dict_ema'}:\n        state_dict = state_dict['state_dict_ema']\n    # Swin checkpoint is nested in the key 'model'\n    if state_dict.keys() == {'model'}:\n        state_dict = state_dict['model']\n    # Lightning checkpoint contains extra stuff, we only want the model state dict\n    if 'pytorch-lightning_version' in state_dict:\n        state_dict = {remove_prefix(k, 'model.'): v for k, v in state_dict['state_dict'].items()}\n    return state_dict\n\n\ndef evaluate(config: DictConfig) -> None:\n    \"\"\"Example of inference with trained model.\n    It loads trained image classification model from checkpoint.\n    Then it loads example image and predicts its label.\n    \"\"\"\n\n    # load model from checkpoint\n    # model __init__ parameters will be loaded from ckpt automatically\n    # you can also pass some parameter explicitly to override it\n\n    # We want to add fields to config so need to call OmegaConf.set_struct\n    OmegaConf.set_struct(config, False)\n\n    # load model\n    checkpoint_type = 
config.eval.get('checkpoint_type', 'pytorch')\n    if checkpoint_type not in ['lightning', 'pytorch']:\n        raise NotImplementedError(f'checkpoint_type ${checkpoint_type} not supported')\n\n    if checkpoint_type == 'lightning':\n        cls = hydra.utils.get_class(config.task._target_)\n        model = cls.load_from_checkpoint(checkpoint_path=config.eval.ckpt)\n    elif checkpoint_type == 'pytorch':\n        model_cfg = config.model_pretrained if 'model_pretrained' in config else None\n        trained_model: LightningModule = hydra.utils.instantiate(config.task, cfg=config,\n                                                                 model_cfg=model_cfg,\n                                                                 _recursive_=False)\n        if 'ckpt' in config.eval:\n            load_return = trained_model.model.load_state_dict(\n                load_checkpoint(config.eval.ckpt, device=trained_model.device), strict=False\n            )\n            log.info(load_return)\n        if 'model_pretrained' in config:\n            ...\n        else:\n            model = trained_model\n\n    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)\n    # datamodule: LightningDataModule = model._datamodule\n    datamodule.prepare_data()\n    datamodule.setup()\n\n    # print model hyperparameters\n    log.info(f'Model hyperparameters: {model.hparams}')\n\n    # Init Lightning callbacks\n    callbacks: List[Callback] = []\n    if \"callbacks\" in config:\n        for _, cb_conf in config[\"callbacks\"].items():\n            if cb_conf is not None and \"_target_\" in cb_conf:\n                log.info(f\"Instantiating callback <{cb_conf._target_}>\")\n                callbacks.append(hydra.utils.instantiate(cb_conf))\n\n    # Init Lightning loggers\n    logger: List[LightningLoggerBase] = []\n    if \"logger\" in config:\n        for _, lg_conf in config[\"logger\"].items():\n            if lg_conf is not None and \"_target_\" in lg_conf:\n 
               log.info(f\"Instantiating logger <{lg_conf._target_}>\")\n                logger.append(hydra.utils.instantiate(lg_conf))\n\n    # Init Lightning trainer\n    log.info(f\"Instantiating trainer <{config.trainer._target_}>\")\n    trainer: Trainer = hydra.utils.instantiate(\n        config.trainer, callbacks=callbacks, logger=logger,  _convert_=\"partial\"\n    )\n\n    # Evaluate the model\n    log.info(\"Starting evaluation!\")\n    if config.eval.get('run_val', True):\n        trainer.validate(model=model, datamodule=datamodule)\n    if config.eval.get('run_test', True):\n        trainer.test(model=model, datamodule=datamodule)\n\n    # Make sure everything closed properly\n    log.info(\"Finalizing!\")\n    utils.finish(\n        config=config,\n        model=model,\n        datamodule=datamodule,\n        trainer=trainer,\n        callbacks=callbacks,\n        logger=logger,\n    )\n"
  },
  {
    "path": "training/src/metrics/accuracy.py",
    "content": "import torch\nfrom torch import Tensor\n\nfrom torchmetrics import Metric, Accuracy\n\n\nclass AccuracyMine(Accuracy):\n    \"\"\"Wrap torchmetrics.Accuracy to take argmax of y in case of Mixup.\n    \"\"\"\n    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore\n        super().update(preds, target.argmax(dim=-1) if target.is_floating_point() else target)\n"
  },
  {
    "path": "training/src/metrics/num_tokens.py",
    "content": "from typing import Any, Dict, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torchmetrics import Metric\n\n\nclass NumTokens(Metric):\n    \"\"\"Keep track of how many tokens we've seen.\n    \"\"\"\n    # TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch\n    # of the next epoch.\n    # Right now the hack is that we override reset(), which would mess up the forward method.\n    # We then override forward to do the right thing.\n\n    is_differentiable = False\n    higher_is_better = False\n    full_state_update = False\n    count: Tensor\n\n    def __init__(self, **kwargs: Dict[str, Any]):\n        super().__init__(**kwargs)\n        self.add_state(\"count\", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx=\"sum\",\n                       persistent=True)  # We want the count to be saved to state-dict\n\n    def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None:  # type: ignore\n        self.count += target.numel()\n\n    def compute(self) -> Tensor:\n        return self.count\n\n    def reset(self):\n        count = self.count\n        super().reset()\n        self.count = count\n\n    # Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py\n    def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:\n        \"\"\"forward computation using single call to `update` to calculate the metric value on the current batch and\n        accumulate global state.\n        This can be done when the global metric state is a sinple reduction of batch states.\n        \"\"\"\n        self.update(*args, **kwargs)\n        return self.compute()\n"
  },
  {
    "path": "training/src/metrics/perplexity.py",
    "content": "# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py\n# But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll))\n# Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py\n# But we pass in the loss to avoid recomputation\n\nfrom typing import Any, Dict, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torchmetrics import Metric\n\ntry:\n    from flash_attn.losses.cross_entropy import CrossEntropyLoss\nexcept ImportError:\n    CrossEntropyLoss = torch.nn.CrossEntropyLoss\n\n__all__ = ['Perplexity']\n\n\nclass Perplexity(Metric):\n    r\"\"\"\n    Perplexity measures how well a language model predicts a text sample. It's calculated as the average number of bits\n    per word a model needs to represent the sample.\n    Args:\n        kwargs:\n            Additional keyword arguments, see :ref:`Metric kwargs` for more info.\n    Examples:\n        >>> import torch\n        >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))\n        >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))\n        >>> target[0, 6:] = -100\n        >>> metric = Perplexity(ignore_index=-100)\n        >>> metric(preds, target)\n        tensor(5.2545)\n    \"\"\"\n    is_differentiable = True\n    higher_is_better = False\n    full_state_update = False\n    total_log_probs: Tensor\n    count: Tensor\n\n    def __init__(self, **kwargs: Dict[str, Any]):\n        super().__init__(**kwargs)\n        self.add_state(\"total_log_probs\", default=torch.tensor(0.0, dtype=torch.float64),\n                       dist_reduce_fx=\"sum\")\n        self.add_state(\"count\", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx=\"sum\")\n\n        self.loss_fn = CrossEntropyLoss()\n\n    def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None:  # type: ignore\n      
  \"\"\"Compute and store intermediate statistics for Perplexity.\n        Args:\n            preds:\n                Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].\n            target:\n                Ground truth values with a shape [batch_size, seq_len].\n        \"\"\"\n        count = target.numel()\n        if loss is None:\n            loss = self.loss_fn(preds, target)\n        self.total_log_probs += loss.double() * count\n        self.count += count\n\n    def compute(self) -> Tensor:\n        \"\"\"Compute the Perplexity.\n        Returns:\n           Perplexity\n        \"\"\"\n        return torch.exp(self.total_log_probs / self.count)\n"
  },
  {
    "path": "training/src/models/modules/seq_common.py",
    "content": "import math\nfrom functools import partial\nfrom collections import namedtuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.modules.utils import _pair\n\nimport hydra\n\nfrom einops import reduce, rearrange\n\n\ndef pooling(x, pooling_mode='CLS', key_padding_mask=None, batch_first=True):\n    if pooling_mode not in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN']:\n        raise NotImplementedError(f'pooling_mode must be MEAN, SUM, CLS, LAST, FLATTEN')\n    if pooling_mode in ['MEAN', 'SUM']:\n        if key_padding_mask is not None:\n            mask = rearrange(~key_padding_mask.bool_matrix,\n                             'b s -> b s 1' if batch_first else 'b s -> s b 1')\n            x = x.masked_fill(mask, 0)\n        s = reduce(x, 'b s ... -> b ...' if batch_first else 's b ... -> b ...', 'sum')\n        if pooling_mode == 'SUM':\n            return s\n        else:\n            if key_padding_mask is None:\n                return s / x.shape[1 if batch_first else 0]\n            else:\n                lengths = rearrange(key_padding_mask._lengths, 'b -> b 1')\n                return s / lengths\n    elif pooling_mode == 'CLS':\n        return x[:, 0] if batch_first else x[0]\n    elif pooling_mode == 'LAST':\n        if key_padding_mask is None:\n            return x[:, -1] if batch_first else x[-1]\n        else:\n            lengths = key_padding_mask._lengths\n            if batch_first:\n                batch_size = x.shape[0]\n                return x[torch.arange(batch_size, device=x.device), lengths - 1]\n            else:\n                batch_size = x.shape[1]\n                return x[lengths - 1, torch.arange(batch_size, device=x.device)]\n    elif pooling_mode == 'FLATTEN':\n        return rearrange(x, 'b ... -> b (...)' if batch_first else 's b ... 
-> b (s ...)')\n\n\nclass ClassificationHeadLinear(nn.Module):\n    \"\"\"Head for sentence-level classification tasks.\"\"\"\n\n    def __init__(self, d_model, num_classes, pooling_mode='MEAN',\n                 batch_first=False, **kwargs):\n        super().__init__()\n        assert pooling_mode in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'], 'pooling_mode not supported'\n        self.pooling_mode = pooling_mode\n        self.batch_first = batch_first\n        self.out_proj = nn.Linear(d_model, num_classes)\n\n    def forward(self, hidden_states, key_padding_mask=None, **kwargs):\n        \"\"\"\n            hidden_states: (B, S, D) if batch_first else (S, B, D)\n        \"\"\"\n        hidden_states = pooling(hidden_states, pooling_mode=self.pooling_mode,\n                                key_padding_mask=key_padding_mask, batch_first=self.batch_first)\n        hidden_states = self.out_proj(hidden_states)\n        return hidden_states\n\n\n# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/models/reformer/modeling_reformer.py\nclass ClassificationHead(nn.Module):\n    \"\"\"Head for sentence-level classification tasks.\"\"\"\n\n    def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling_mode='MEAN',\n                 batch_first=False):\n        super().__init__()\n        assert pooling_mode in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'], 'pooling_mode not supported'\n        self.pooling_mode = pooling_mode\n        self.batch_first = batch_first\n        self.dense = nn.Linear(d_model, d_inner)\n        self.dropout = nn.Dropout(dropout)\n        self.out_proj = nn.Linear(d_inner, num_classes)\n\n    def forward(self, hidden_states, key_padding_mask=None, **kwargs):\n        \"\"\"\n            hidden_states: (B, S, D) if batch_first else (S, B, D)\n        \"\"\"\n        hidden_states = pooling(hidden_states, pooling_mode=self.pooling_mode,\n                                key_padding_mask=key_padding_mask, 
batch_first=self.batch_first)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.dense(hidden_states)\n        # Huggingface uses tanh instead of relu\n        hidden_states = torch.relu(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.out_proj(hidden_states)\n        return hidden_states\n\n\nclass ClassificationHeadDual(nn.Module):\n    \"\"\"Head for sentence-level classification tasks.\"\"\"\n\n    def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling_mode='MEAN',\n                 batch_first=False, interaction='NLI'):\n        super().__init__()\n        assert pooling_mode in ['MEAN', 'SUM', 'CLS'], 'pooling_mode not supported'\n        assert interaction in [None, 'NLI'], 'interaction not supported'\n        self.pooling_mode = pooling_mode\n        self.batch_first = batch_first\n        self.interaction = interaction\n        self.dense = nn.Linear(d_model * (4 if self.interaction == 'NLI' else 2), d_inner)\n        self.dropout = nn.Dropout(dropout)\n        self.out_proj = nn.Linear(d_inner, num_classes)\n\n    def forward(self, hidden_states1, hidden_states2,\n                key_padding_mask1=None, key_padding_mask2=None, **kwargs):\n        \"\"\"\n            hidden_states: (B, S, D) if batch_first else (S, B, D)\n        \"\"\"\n        x1 = pooling(hidden_states1, pooling_mode=self.pooling_mode,\n                     key_padding_mask=key_padding_mask1, batch_first=self.batch_first)\n        x2 = pooling(hidden_states2, pooling_mode=self.pooling_mode,\n                     key_padding_mask=key_padding_mask2, batch_first=self.batch_first)\n        hidden_states = (torch.cat([x1, x2, x1 * x2, x1 - x2], dim=-1) if self.interaction == 'NLI'\n                         else torch.cat([x1, x2], dim=-1))\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.dense(hidden_states)\n        # Huggingface uses tanh instead of relu\n      
  hidden_states = torch.relu(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.out_proj(hidden_states)\n        return hidden_states\n\n\nclass LMHead(nn.Module):\n\n    def __init__(self, d_model, num_classes, batch_first=True, bias=True):\n        super().__init__()\n        self.lm_head = nn.Linear(d_model, num_classes, bias=bias)\n\n    def forward(self, hidden_states, **kwargs):\n        \"\"\"\n            hidden_states: (B, S, D) if batch_first else (S, B, D)\n        \"\"\"\n        CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])\n        return CausalLMOutput(self.lm_head(hidden_states))\n\n\ndef sinusoidal_init_(tensor):\n    \"\"\"\n        tensor: (max_len, d_model)\n    \"\"\"\n    max_len, d_model = tensor.shape\n    position = rearrange(torch.arange(0.0, max_len), 's -> s 1')\n    div_term = torch.exp(-math.log(10000.0) * torch.arange(0.0, d_model, 2.0) / d_model)\n    tensor[:, 0::2] = torch.sin(position * div_term)\n    tensor[:, 1::2] = torch.cos(position * div_term)\n    return tensor\n\n\n# Adapted from https://github.com/pytorch/examples/blob/master/word_language_model/model.py\nclass PositionalEncoding(nn.Module):\n    r\"\"\"Inject some information about the relative or absolute position of the tokens\n        in the sequence. The positional encodings have the same dimension as\n        the embeddings, so that the two can be summed. Here, we use sine and cosine\n        functions of different frequencies.\n    .. math::\n        \\text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))\n        \\text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))\n        \\text{where pos is the word position and i is the embed idx)\n    Args:\n        d_model: the embed dim (required).\n        dropout: the dropout value (default=0.1).\n        max_len: the max. 
length of the incoming sequence (default=5000).\n    Examples:\n        >>> pos_encoder = PositionalEncoding(d_model)\n    \"\"\"\n\n    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False, initializer=None):\n        super().__init__()\n        self.batch_first = batch_first\n        self.dropout = nn.Dropout(p=dropout)\n        pe = torch.empty(max_len, d_model)\n        if initializer is None:\n            sinusoidal_init_(pe)\n            pe = rearrange(pe, 's d -> 1 s d' if self.batch_first else 's d -> s 1 d')\n            self.register_buffer('pe', pe)\n        else:\n            hydra.utils.call(initializer, pe)\n            pe = rearrange(pe, 's d -> 1 s d' if self.batch_first else 's d -> s 1 d')\n            self.pe = nn.Parameter(pe)\n\n    def forward(self, x):\n        r\"\"\"Inputs of forward function\n        Args:\n            x: the sequence fed to the positional encoder model (required).\n        Shape:\n            x: [sequence length, batch size, embed dim] if not batch_first else [B, S, D]\n            output: [sequence length, batch size, embed dim] if not batch_first else [B, S, D]\n        Examples:\n            >>> output = pos_encoder(x)\n        \"\"\"\n        x = x + (self.pe[:, :x.size(1)] if self.batch_first else self.pe[:x.size(0)])\n        return self.dropout(x)\n\n\n# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py\nclass Mlp(nn.Module):\n    \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n    \"\"\"\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,\n                 act_fn=None, drop=0., device=None, dtype=None):\n        \"\"\"TD [2021-10-27] act_fn takes precedence over act_layer if set.\n        This is to support Pytorch 1.10 Transformer interface that construct the activation\n        *function*, not the activation *layer*.\n        \"\"\"\n        factory_kwargs = {'device': 
device, 'dtype': dtype}\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        drop_probs = _pair(drop)\n        self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)\n        self.act = act_layer() if act_fn is None else act_fn\n        self.drop1 = nn.Dropout(drop_probs[0])\n        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)\n        self.drop2 = nn.Dropout(drop_probs[1])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop1(x)\n        x = self.fc2(x)\n        x = self.drop2(x)\n        return x\n\n\nclass MlpBig(nn.Module):\n    \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n    \"\"\"\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,\n                 act_fn=None, drop=0., device=None, dtype=None):\n        \"\"\"Copied from Mlp above. 
If num_layers > 2, add more Mlp layers, doubling each time.\n        \"\"\"\n        factory_kwargs = {'device': device, 'dtype': dtype}\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        cur_hidden_features = hidden_features\n        layers = []\n        for _ in range(4):\n            layers.append(nn.Linear(in_features, cur_hidden_features, **factory_kwargs))\n            layers.append(act_layer())\n            layers.append(nn.Dropout(drop))\n            in_features = cur_hidden_features\n            cur_hidden_features *= 2\n        layers.append(nn.Linear(in_features, out_features, **factory_kwargs))\n        layers.append(nn.Dropout(drop))\n        self.fwd = nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.fwd(x)\n\nclass GluMlp(nn.Module):\n    \"\"\" MLP w/ GLU style gating\n    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202\n    \"\"\"\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        assert hidden_features % 2 == 0\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features // 2, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def init_weights(self):\n        # override init of fc1 w/ gate portion set to weight near zero, bias=1\n        fc1_mid = self.fc1.bias.shape[0] // 2\n        nn.init.ones_(self.fc1.bias[fc1_mid:])\n        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x, gates = x.chunk(2, dim=-1)\n        x = x * self.act(gates)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass 
GatedMlp(nn.Module):\n    \"\"\" MLP as used in gMLP\n    \"\"\"\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,\n                 gate_layer=None, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        if gate_layer is not None:\n            assert hidden_features % 2 == 0\n            self.gate = gate_layer(hidden_features)\n            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?\n        else:\n            self.gate = nn.Identity()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.gate(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass ConvMlp(nn.Module):\n    \"\"\" MLP using 1x1 convs that keeps spatial dims\n    \"\"\"\n    def __init__(\n            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)\n        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()\n        self.act = act_layer()\n        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.norm(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        return x\n\n"
  },
  {
    "path": "training/src/optim/param_grouping.py",
    "content": "import inspect\n\nimport torch.nn as nn\n\nimport hydra\n\ntry:\n    from apex.contrib.layer_norm import FastLayerNorm\nexcept ImportError:\n    FastLayerNorm = None\n\nfrom src.models.modules.seq_common import PositionalEncoding\n\n\ndef group_parameters_for_optimizer(model, optimizer_cfg, bias_weight_decay=False,\n                                   normalization_weight_decay=False):\n    \"\"\"Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with\n    attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for\n    normalization parameters if normalization_weight_decay==False\n    \"\"\"\n    # Get the weight decay from the config, or from the default value of the optimizer constructor\n    # if it's not specified in the config.\n    if 'weight_decay' in optimizer_cfg:\n        weight_decay = optimizer_cfg.weight_decay\n    else:\n        # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value\n        signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_))\n        if 'weight_decay' in signature.parameters:\n            weight_decay = signature.parameters['weight_decay'].default\n            if weight_decay is inspect.Parameter.empty:\n                weight_decay = 0.0\n        else:\n            weight_decay = 0.0\n\n    # If none of the parameters have weight decay anyway, and there are no parameters with special\n    # optimization params\n    if weight_decay == 0.0 and not any(hasattr(p, '_optim') for p in model.parameters()):\n        return model.parameters()\n\n    skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set()\n    skip_keywords = (model.no_weight_decay_keywords() if hasattr(model, 'no_weight_decay_keywords')\n                     else set())\n\n    # Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134\n    \"\"\"\n    This long function is unfortunately doing something 
very simple and is being very defensive:\n    We are separating out all parameters of the model into two buckets: those that will experience\n    weight decay for regularization and those that won't (biases, and layernorm/embedding weights).\n    We are then returning the PyTorch optimizer object.\n    \"\"\"\n\n    # separate out all parameters to those that will and won't experience regularizing weight decay\n    decay = set()\n    no_decay = set()\n    special = set()\n    whitelist_weight_modules = (nn.Linear, )\n    blacklist_weight_modules = (nn.Embedding, PositionalEncoding)\n    if not normalization_weight_decay:\n        blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,\n                                     nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d,\n                                     nn.GroupNorm, nn.SyncBatchNorm,\n                                     nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,\n                                     nn.LayerNorm, nn.LocalResponseNorm)\n    if FastLayerNorm is not None:\n        blacklist_weight_modules += (FastLayerNorm,)\n\n    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}\n    for mn, m in model.named_modules():\n        for pn, p in m.named_parameters():\n            fpn = '%s.%s' % (mn, pn) if mn else pn # full param name\n            # In case of parameter sharing, some parameters show up here but are not in\n            # param_dict.keys()\n            if not p.requires_grad or fpn not in param_dict:\n                continue  # frozen weights\n            if hasattr(p, '_optim'):\n                special.add(fpn)\n            elif fpn in skip or any(skip_keyword in fpn for skip_keyword in skip_keywords):\n                no_decay.add(fpn)\n            elif getattr(p, '_no_weight_decay', False):\n                no_decay.add(fpn)\n            elif not bias_weight_decay and pn.endswith('bias'):\n                
no_decay.add(fpn)\n            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):\n                # weights of whitelist modules will be weight decayed\n                decay.add(fpn)\n            elif isinstance(m, blacklist_weight_modules):\n                # weights of blacklist modules will NOT be weight decayed\n                no_decay.add(fpn)\n\n    decay |= (param_dict.keys() - no_decay - special)\n    # validate that we considered every parameter\n    inter_params = decay & no_decay\n    union_params = decay | no_decay\n    assert len(inter_params) == 0, f\"Parameters {str(inter_params)} made it into both decay/no_decay sets!\"\n    assert len(param_dict.keys() - special - union_params) == 0, f\"Parameters {str(param_dict.keys() - special - union_params)} were not separated into either decay/no_decay set!\"\n\n    if weight_decay == 0.0 or not no_decay:\n        param_groups = [{\"params\": [param_dict[pn] for pn in sorted(list(no_decay | decay))],\n                         \"weight_decay\": weight_decay}]\n    else:\n        # We need sorted(list()) so that the order is deterministic. Otherwise when we resume\n        # the order could change and resume will fail. [H/t Albert]\n        param_groups = [\n            {\"params\": [param_dict[pn] for pn in sorted(list(decay))], \"weight_decay\": weight_decay},\n            {\"params\": [param_dict[pn] for pn in sorted(list(no_decay))], \"weight_decay\": 0.0},\n        ]\n    # Add parameters with special hyperparameters\n    # Unique dicts\n    hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)]\n    for hp in hps:\n        params = [param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp]\n        param_groups.append({\"params\": params, **hp})\n\n    return param_groups\n"
  },
  {
    "path": "training/src/optim/timm_lr_scheduler.py",
    "content": "import torch\nfrom torch.optim import Optimizer\n\nfrom timm.scheduler import CosineLRScheduler\n\n\n# We need to subclass torch.optim.lr_scheduler._LRScheduler, or Pytorch-lightning will complain\nclass TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler):\n    \"\"\" Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch.\n    It supports resuming as well.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._last_epoch = -1\n        self.step(epoch=0)\n\n    def step(self, epoch=None):\n        if epoch is None:\n            self._last_epoch += 1\n        else:\n            self._last_epoch = epoch\n        # We call either step or step_update, depending on whether we're using the scheduler every\n        # epoch or every step.\n        # Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set\n        # scheduler interval to \"step\", then the learning rate update will be wrong.\n        if self.t_in_epochs:\n            super().step(epoch=self._last_epoch)\n        else:\n            super().step_update(num_updates=self._last_epoch)\n"
  },
  {
    "path": "training/src/tasks/seq.py",
    "content": "from typing import Any, List\nimport inspect\n\nimport torch\nimport hydra\nfrom pytorch_lightning import LightningModule, LightningDataModule\nfrom torchmetrics import MetricCollection\n\nfrom einops import rearrange\n\nfrom omegaconf import OmegaConf\n\nfrom src.utils.utils import get_logger\nfrom src.optim.param_grouping import group_parameters_for_optimizer\nfrom src.utils.checkpoint import load_checkpoint\n\nlogger = get_logger(__name__)\n\n\nclass SequenceModel(LightningModule):\n\n    def __init__(self, cfg, model_cfg=None):\n        \"\"\"If model_cfg is passed, it will take precedence over cfg.model\n        \"\"\"\n        super().__init__()\n        # this line ensures params passed to LightningModule will be saved to ckpt\n        # it also allows to access params with 'self.hparams' attribute\n        self.save_hyperparameters(cfg)\n        self.cfg = cfg\n        self.model_cfg = model_cfg or self.cfg.model\n\n        self.instantiate_datamodule()\n        self.instantiate_model()\n        self.warmstart()\n        self.instantiate_loss()\n        self.instantiate_metrics()\n\n    def instantiate_datamodule(self):\n        logger.info(f\"Instantiating datamodule <{self.cfg.datamodule._target_}>\")\n        # Calling this self.datamodule will mess with PL since it also assigns self.datamodule\n        self._datamodule: LightningDataModule = hydra.utils.instantiate(self.cfg.datamodule)\n        self._datamodule.prepare_data()\n        self._datamodule.setup()\n        OmegaConf.clear_resolver('datamodule')\n        OmegaConf.register_new_resolver('datamodule', lambda attr: getattr(self._datamodule, attr))\n\n    def instantiate_model(self):\n        # if hasattr(self._datamodule, 'num_classes'):\n        #     self.model_cfg.num_classes = self._datamodule.num_classes\n        # if (hasattr(self._datamodule, 'vocab_size')\n        #     and self.model_cfg.get('embedding_cfg', None) is not None\n        #     and 
self.model_cfg.embedding_cfg._target_ == \"torch.nn.Embedding\"):\n        #     self.model_cfg.embedding_cfg.num_embeddings = self._datamodule.vocab_size\n        logger.info(f\"Instantiating model <{self.model_cfg._target_}>\")\n        recursive = getattr(self.model_cfg, '_recursive_', False)\n        self.model = hydra.utils.instantiate(self.model_cfg, _recursive_=recursive)\n\n    def instantiate_loss(self):\n        loss_fn_cfg = self.cfg.train.get('loss_fn')\n        if loss_fn_cfg is None:\n            loss_fn_cfg = {'_target_': 'torch.nn.CrossEntropyLoss'}\n        self.loss_fn = hydra.utils.instantiate(loss_fn_cfg)\n        loss_fn_val_cfg = self.cfg.train.get('loss_fn_val', loss_fn_cfg)\n        self.loss_fn_val = hydra.utils.instantiate(loss_fn_val_cfg)\n\n    def instantiate_metrics(self):\n        # use separate metric instance for train, val and test step\n        # to ensure a proper reduction over the epoch\n        if 'eval' in self.cfg and 'metrics' in self.cfg.eval:\n            metrics_cfg = self.cfg.eval.metrics\n        else:\n            metrics_cfg = {'acc': {'_target_': 'torchmetrics.Accuracy'}}\n        metrics = MetricCollection({name: hydra.utils.instantiate(cfg)\n                                    for name, cfg in metrics_cfg.items()})\n        self.train_metrics = metrics.clone(prefix='train/')\n        self.val_metrics = metrics.clone(prefix='val/')\n        self.test_metrics = metrics.clone(prefix='test/')\n\n    def warmstart(self):\n        if self.cfg.train.get('warmstart', None) is not None:\n            logger.info(f\"Warm-starting with weights from {self.cfg.train.warmstart.path}\")\n            strict = self.cfg.train.warmstart.get('strict', True)\n            state_dict = load_checkpoint(self.cfg.train.warmstart.path)\n            if self.cfg.train.warmstart.get('post_process', None) is not None:\n                state_dict = hydra.utils.instantiate(self.cfg.train.warmstart.post_process,\n                                    
                 state_dict)\n            load_return = self.model.load_state_dict(state_dict, strict=strict)\n            logger.info(load_return)\n\n    def forward(self, *args, **kwargs):\n        return self.model(*args, **kwargs)\n\n    def step(self, batch: Any, is_train=True):\n        try:\n            x, y, lengths = batch\n        except ValueError:\n            x, y = batch\n            lengths = None\n        output = self.forward(x) if lengths is None else self.forward(x, lengths=lengths)\n        loss = self.loss_fn(output, y) if is_train else self.loss_fn_val(output, y)\n        return loss, output, y\n\n    def shared_step(self, batch: Any, batch_idx: int, phase='train'):\n        loss, output, targets = self.step(batch, is_train=(phase == 'train'))\n        metrics = getattr(self, f'{phase}_metrics')\n        metrics(output, targets)\n        log_on_step = 'eval' in self.cfg and self.cfg.eval.get('log_on_step', False) and phase == 'train'\n        self.log(f\"{phase}/loss\", loss, on_step=log_on_step, on_epoch=True,\n                 prog_bar=False, sync_dist=True)\n        # https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training\n        # We need to log the Metrics object, not the metric result, since otherwise\n        # pytorch-lightning will use torch.mean to reduce it.\n        # This would be wrong for perplexity, for example.\n        self.log_dict(metrics, on_step=log_on_step, on_epoch=True, prog_bar=True, sync_dist=True)\n        return {\"loss\": loss, \"output\": output, \"targets\": targets}\n\n    def training_step(self, batch: Any, batch_idx: int):\n        return self.shared_step(batch, batch_idx, phase='train')\n\n    def validation_step(self, batch: Any, batch_idx: int):\n        return self.shared_step(batch, batch_idx, phase='val')\n\n    def test_step(self, batch: Any, batch_idx: int):\n        return self.shared_step(batch, batch_idx, phase='test')\n\n    def 
configure_optimizers(self):\n        if 'optimizer_param_grouping' in self.cfg.train:  # Set zero weight decay for some params\n            parameters = group_parameters_for_optimizer(self.model, self.cfg.train.optimizer,\n                                                        **self.cfg.train.optimizer_param_grouping)\n        else:\n            # parameters = self.model.parameters()\n            parameters = self.parameters() # [21-09-08] AG: this will train task specific parameters such as Retrieval head for AAN\n        optimizer = hydra.utils.instantiate(self.cfg.train.optimizer, parameters)\n\n        # Log optimizer info\n        for i, g in enumerate(optimizer.param_groups):\n            ntensors = len(g['params'])\n            nparams = sum(p.numel() for p in g['params'])\n            hparams = {k: v for k, v in g.items() if k != 'params'}\n            logger.info(f'Optimizer group {i}: {ntensors} tensors, {nparams} parameters, {hparams}')\n\n        if 'scheduler' not in self.cfg.train:\n            return optimizer\n        else:\n            # lr_scheduler should be called either every step (default) or every epoch\n            lr_scheduler = hydra.utils.instantiate(self.cfg.train.scheduler, optimizer)\n            return [optimizer], {'scheduler': lr_scheduler,\n                                 'interval': self.cfg.train.get('scheduler_interval', 'step'),\n                                 'monitor': self.cfg.train.get('scheduler_monitor', 'val/loss')}\n\n    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):\n        # https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html#set-grads-to-none\n        # TD [2022-04-30]: DeepSpeed optimizer uses the kwarg set_grad_to_none instead of set_to_none\n        if 'set_to_none' in inspect.signature(optimizer.zero_grad).parameters:\n            optimizer.zero_grad(set_to_none=True)\n        else:\n            optimizer.zero_grad()\n\n    def on_save_checkpoint(self, 
checkpoint):\n        # TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration\n        # behind, so we're using the optimizer's progress.\n        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] = checkpoint['loops']['fit_loop']['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']['total']['completed'] * self.trainer.accumulate_grad_batches\n        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] = checkpoint['loops']['fit_loop']['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']['current']['completed'] * self.trainer.accumulate_grad_batches\n        # _batches_that_stepped tracks the number of global steps, not the number\n        # of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.\n        checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] = checkpoint['loops']['fit_loop']['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']['total']['completed']\n\n\nclass SequenceLMModel(SequenceModel):\n\n    def step(self, batch: Any, is_train=True):\n        x, y = batch\n        output = self.forward(x).logits\n        output = rearrange(output, '... C -> (...) C')\n        y = rearrange(y, '... 
-> (...)')\n        loss = self.loss_fn(output, y) if is_train else self.loss_fn_val(output, y)\n        return loss, output, y\n\n    def shared_step(self, batch: Any, batch_idx: int, phase='train'):\n        loss, output, targets = self.step(batch, is_train=(phase == 'train'))\n        # Passing the loss to the perplexity metrics to avoid recomputation\n        metrics = getattr(self, f'{phase}_metrics')\n        metrics(output, targets, loss=loss)\n        log_on_step = 'eval' in self.cfg and self.cfg.eval.get('log_on_step', False) and phase == 'train'\n        self.log(f\"{phase}/loss\", loss, on_step=log_on_step, on_epoch=True,\n                 prog_bar=False, sync_dist=True)\n        # https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training\n        # We need to log the Metrics object, not the metric result, since otherwise\n        # pytorch-lightning will use torch.mean to reduce it.\n        # This would be wrong for perplexity, for example.\n        self.log_dict(metrics, on_step=log_on_step, on_epoch=True, prog_bar=True, sync_dist=True)\n        return {\"loss\": loss, \"output\": output, \"targets\": targets}\n"
  },
  {
    "path": "training/src/train.py",
    "content": "from typing import List, Optional, Sequence\nfrom pathlib import Path\n\nimport hydra\nfrom omegaconf import OmegaConf, DictConfig\nfrom pytorch_lightning import (\n    Callback,\n    LightningDataModule,\n    LightningModule,\n    Trainer,\n    seed_everything,\n)\nfrom pytorch_lightning.loggers import LightningLoggerBase\n\nfrom src.utils import utils\n\nlog = utils.get_logger(__name__)\n\n\ndef last_modification_time(path):\n    \"\"\"Including files / directory 1-level below the path\n    \"\"\"\n    path = Path(path)\n    if path.is_file():\n        return path.stat().st_mtime\n    elif path.is_dir():\n        return max(child.stat().st_mtime for child in path.iterdir())\n    else:\n        return None\n\n\ndef train(config: DictConfig) -> Optional[float]:\n    \"\"\"Contains training pipeline.\n    Instantiates all PyTorch Lightning objects from config.\n\n    Args:\n        config (DictConfig): Configuration composed by Hydra.\n\n    Returns:\n        Optional[float]: Metric score for hyperparameter optimization.\n    \"\"\"\n\n    # Set seed for random number generators in pytorch, numpy and python.random\n    if config.get(\"seed\"):\n        seed_everything(config.seed, workers=True)\n\n    # We want to add fields to config so need to call OmegaConf.set_struct\n    OmegaConf.set_struct(config, False)\n    # Init lightning model\n    model: LightningModule = hydra.utils.instantiate(config.task, cfg=config, _recursive_=False)\n    datamodule: LightningDataModule = model._datamodule\n\n    # Init lightning callbacks\n    callbacks: List[Callback] = []\n    if \"callbacks\" in config:\n        for _, cb_conf in config.callbacks.items():\n            if cb_conf is not None and \"_target_\" in cb_conf:\n                log.info(f\"Instantiating callback <{cb_conf._target_}>\")\n                callbacks.append(hydra.utils.instantiate(cb_conf))\n\n    # Init lightning loggers\n    logger: List[LightningLoggerBase] = []\n    if \"logger\" in 
config:\n        for _, lg_conf in config.logger.items():\n            if lg_conf is not None and \"_target_\" in lg_conf:\n                log.info(f\"Instantiating logger <{lg_conf._target_}>\")\n                logger.append(hydra.utils.instantiate(lg_conf))\n\n    ckpt_cfg = {}\n    if config.get('resume'):\n        try:\n            checkpoint_path = Path(config.callbacks.model_checkpoint.dirpath)\n            if checkpoint_path.is_dir():\n                last_ckpt = checkpoint_path / 'last.ckpt'\n                autosave_ckpt = checkpoint_path / '.pl_auto_save.ckpt'\n                if not (last_ckpt.exists() or autosave_ckpt.exists()):\n                    raise FileNotFoundError(\"Resume requires either last.ckpt or .pl_auto_save.ckpt\")\n                if ((not last_ckpt.exists())\n                    or (autosave_ckpt.exists()\n                       and last_modification_time(autosave_ckpt) > last_modification_time(last_ckpt))):\n                    # autosave_ckpt = autosave_ckpt.replace(autosave_ckpt.with_name('.pl_auto_save_loaded.ckpt'))\n                    checkpoint_path = autosave_ckpt\n                else:\n                    checkpoint_path = last_ckpt\n            # DeepSpeed's checkpoint is a directory, not a file\n            if checkpoint_path.is_file() or checkpoint_path.is_dir():\n                ckpt_cfg = {'ckpt_path': str(checkpoint_path)}\n            else:\n                log.info(f'Checkpoint file {str(checkpoint_path)} not found. 
Will start training from scratch')\n        except (KeyError, FileNotFoundError):\n            pass\n\n    # Configure ddp automatically\n    n_devices = config.trainer.get('devices', 1)\n    if isinstance(n_devices, Sequence):  # trainer.devices could be [1, 3] for example\n        n_devices = len(n_devices)\n    if n_devices > 1 and config.trainer.get('strategy', None) is None:\n        config.trainer.strategy = dict(\n            _target_='pytorch_lightning.strategies.DDPStrategy',\n            find_unused_parameters=False,\n            gradient_as_bucket_view=True,  # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations\n        )\n\n    # Init lightning trainer\n    log.info(f\"Instantiating trainer <{config.trainer._target_}>\")\n    trainer: Trainer = hydra.utils.instantiate(\n        config.trainer, callbacks=callbacks, logger=logger)\n\n    # Train the model\n    log.info(\"Starting training!\")\n    trainer.fit(model=model, datamodule=datamodule, **ckpt_cfg)\n\n    # Evaluate model on test set, using the best model achieved during training\n    if config.get(\"test_after_training\") and not config.trainer.get(\"fast_dev_run\"):\n        log.info(\"Starting testing!\")\n        trainer.test(model=model, datamodule=datamodule)\n\n    # Make sure everything closed properly\n    log.info(\"Finalizing!\")\n    utils.finish(\n        config=config,\n        model=model,\n        datamodule=datamodule,\n        trainer=trainer,\n        callbacks=callbacks,\n        logger=logger,\n    )\n\n    # Print path to best checkpoint\n    if not config.trainer.get(\"fast_dev_run\"):\n        log.info(f\"Best model ckpt: {trainer.checkpoint_callback.best_model_path}\")\n\n    # Return metric score for hyperparameter optimization\n    optimized_metric = config.get(\"optimized_metric\")\n    if optimized_metric:\n        return trainer.callback_metrics[optimized_metric]\n"
  },
  {
    "path": "training/src/utils/checkpoint.py",
    "content": "import re\nfrom pathlib import Path\n\nimport torch\nimport math\nfrom einops import rearrange\n\ndef load_checkpoint(path, device='cpu'):\n    path = Path(path).expanduser()\n    is_deepspeed = False\n    if path.is_dir():  # DeepSpeed checkpoint\n        is_deepspeed = True\n        latest_path = path / 'latest'\n        if latest_path.is_file():\n            with open(latest_path, 'r') as fd:\n                tag = fd.read().strip()\n        else:\n            raise ValueError(f\"Unable to find 'latest' file at {latest_path}\")\n        path /= f'{tag}/mp_rank_00_model_states.pt'\n    state_dict = torch.load(path, map_location=device)\n    if is_deepspeed:\n        state_dict = state_dict['module']\n\n        # Replace the names of some of the submodules\n        def key_mapping(key):\n            return re.sub(r'^module.model.', '', key)\n\n        state_dict = {key_mapping(k): v for k, v in state_dict.items()}\n    return state_dict\n\n\ndef blockdiag_to_dense_mlp_bert(state_dict):\n    from src.ops.blockdiag_multiply import blockdiag_weight_to_dense_weight\n    names = {name for name in state_dict\n             if re.match('bert.encoder.layer.(\\d+).(mlp.fc(1|2)|(intermediate|output).dense).weight',\n                         name)}\n    for name in names:\n        state_dict[name] = blockdiag_weight_to_dense_weight(state_dict[name])\n    return state_dict\n\ndef interpolate_pos_embedding(state_dict, out_seqlen, pos_embedding_name='model.pos_encoder.pe', interleave=False):\n    orig_emb = state_dict['state_dict'][pos_embedding_name]\n    assert (out_seqlen % orig_emb.shape[1]) == 0, 'out_seqlen must be a multiple of the original sequence length'\n    reps = [1 for i in orig_emb.shape]\n    reps[1] = out_seqlen // orig_emb.shape[1]\n    \n    if interleave:\n        assert math.isqrt(orig_emb.shape[1]) ** 2 == orig_emb.shape[1], 'interleave only works for square lengths'\n        assert math.isqrt(out_seqlen) ** 2 == out_seqlen, 'interleave only 
works for square lengths'\n        assert math.isqrt(reps[1]) ** 2 == reps[1], 'out_seqlen / seqlen must be a perfect square'\n\n        emb_square = rearrange(orig_emb, 'b (h w) d -> b h w d', h = math.isqrt(orig_emb.shape[1]))\n        emb_square_expanded = emb_square.repeat_interleave(math.isqrt(reps[1]), axis=1).repeat_interleave(math.isqrt(reps[1]), axis=2)\n        new_emb = rearrange(emb_square_expanded, 'b h w d -> b (h w) d')\n        state_dict['state_dict'][pos_embedding_name] = new_emb\n    else:\n        state_dict['state_dict'][pos_embedding_name] = orig_emb.repeat(*reps)\n\n    ret = remove_model_prefix(state_dict)\n    # # HACK: this is a hack for block-sparse flash attention\n    ret = {\n        k: v\n        for k, v in ret.items()\n        if not k.endswith('inner_attn.layout')\n    }\n    return ret\n\ndef remove_model_prefix(state_dict):\n    # HACK: this is a hack to get the model to load properly, get rid of 'model.' prefix\n    for key in list(state_dict['state_dict'].keys()):\n        if key.startswith('model.'):\n            new_key = key[len('model.'):]\n            state_dict['state_dict'][new_key] = state_dict['state_dict'].pop(key)\n\n    # HACK: something is wrong with the state dict being loaded...\n    return state_dict['state_dict']\n"
  },
  {
    "path": "training/src/utils/ddp_zero1.py",
    "content": "# Meant to work with Pytorch's ZeroRedundancyOptimizer\n\nfrom typing import Any, Callable, Dict, List, Optional, Union\nfrom pathlib import Path\n\nimport torch\nfrom torch.optim.optimizer import Optimizer\nfrom torch.distributed.optim import ZeroRedundancyOptimizer\n\nfrom pytorch_lightning.strategies.ddp import DDPStrategy\nfrom pytorch_lightning.core.optimizer import LightningOptimizer\ntry:  # pytorch_lightning <= 1.7\n    from pytorch_lightning.utilities.types import _PATH\nexcept ImportError:  # pytorch_lightning >= 1.8\n    try:\n        from lightning_lite.utilities.types import _PATH\n    except ImportError:  # pytorch_lightning >= 1.9\n        from lightning_fabric.utilities.types import _PATH\n\n\n# Copied from Pytorch's ZeroRedundancyOptimizer's state_dict method, but we only get\n# the local state dict to avoid synchronization across GPUs.\n# https://github.com/pytorch/pytorch/blob/0c7ca2d97ba5980a2af7dcd6b8106dc915e591cd/torch/distributed/optim/zero_redundancy_optimizer.py#L1131\ndef get_zero_optimizer_state_dict_local(optimizer, global_rank):\n    optimizer._check_overlap_initialized()\n\n    # Sync the exposed `param_groups` attributes to the local optimizer in\n    # case they have been updated\n    optimizer._sync_param_groups(optimizer.param_groups, optimizer.optim.param_groups)\n\n    local_state_dict = optimizer.optim.state_dict()\n    state_dict = super(ZeroRedundancyOptimizer, optimizer).state_dict()\n\n    # Update the global optimizer state with local state information,\n    # factoring in the translation from local to global indexing\n    rank = global_rank\n    # TODO: recursive copy to device\n    local_param_groups = local_state_dict[\"param_groups\"]\n    global_param_groups = optimizer._partition_parameters()[rank]\n    assert len(local_param_groups) == len(global_param_groups), \\\n        \"Mismatch between number of local and global parameter groups\"\n\n    for local_param_group, global_param_group in 
zip(local_param_groups, global_param_groups):\n        # `local_param_group` stores local indices, while\n        # `global_param_group` stores the tensors directly\n        local_param_indices = local_param_group[\"params\"]\n        global_params = global_param_group[\"params\"]\n\n        assert len(local_param_indices) == len(global_params), \\\n            \"Mismatch between number of local and global parameters in parameter group\"\n        for local_param_index, global_param in zip(local_param_indices, global_params):\n            # Update the global parameter state, if any\n            if local_param_index in local_state_dict[\"state\"]:\n                global_param_index = optimizer._param_to_index[global_param]\n                state_dict[\"state\"][global_param_index] = local_state_dict[\"state\"][local_param_index]\n\n    # Sort the parameters in the state\n    state_dict[\"state\"] = dict(sorted(state_dict[\"state\"].items()))\n    return state_dict\n\n\nclass DDPStrategyZero1(DDPStrategy):\n    \"\"\"To use ZeroRedundancyOptimizer, we need to shard the optimizer states when\n    saving/loading checkpoints.\n    \"\"\"\n\n    strategy_name = \"ddp_zero1\"\n\n    def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]:\n        if isinstance(optimizer, LightningOptimizer):\n            optimizer = optimizer._optimizer\n        if isinstance(optimizer, ZeroRedundancyOptimizer):\n            return get_zero_optimizer_state_dict_local(optimizer, self.global_rank)\n        else:\n            return optimizer.state_dict()\n\n    def save_checkpoint(\n        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None\n    ) -> None:\n        \"\"\"Save model/training states as a checkpoint file through state-dump and file-write.\n        Args:\n            checkpoint: dict containing model and trainer state\n            filepath: write-target file's path\n            storage_options: parameter for how to save to 
storage, passed to ``CheckpointIO`` plugin\n        \"\"\"\n        filepath = Path(filepath)\n        filepath.mkdir(parents=True, exist_ok=True)\n        local_optimizer_states = checkpoint.pop('optimizer_states')\n        if self.is_global_zero:\n            self.checkpoint_io.save_checkpoint(checkpoint, filepath / 'model_states.pt',\n                                               storage_options=storage_options)\n        self.checkpoint_io.save_checkpoint(local_optimizer_states,\n                                           filepath / f'{self.global_rank:03d}_optim_states.pt',\n                                           storage_options=storage_options)\n\n    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:\n        torch.cuda.empty_cache()\n        checkpoint_path = Path(checkpoint_path)\n        if checkpoint_path.is_file():\n            return super().load_checkpoint(str(checkpoint_path))\n        else:\n            assert checkpoint_path.is_dir()\n            global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt')\n            local_optimizer_states = self.checkpoint_io.load_checkpoint(checkpoint_path / f'{self.global_rank:03d}_optim_states.pt')\n            global_states['optimizer_states'] = local_optimizer_states\n            return global_states\n"
  },
  {
    "path": "training/src/utils/ddp_zero2.py",
    "content": "# Meant to work with Apex's DistributeFusedAdam\n\nfrom typing import Any, Callable, Dict, List, Optional, Union\nfrom pathlib import Path\nimport types\n\nimport torch\nfrom torch.optim.optimizer import Optimizer\nfrom torch.optim import LBFGS\n\nfrom apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam\n\nfrom pytorch_lightning.strategies.ddp import DDPStrategy\nfrom pytorch_lightning.plugins.precision import PrecisionPlugin, NativeMixedPrecisionPlugin\nfrom pytorch_lightning.core.optimizer import LightningOptimizer\nfrom pytorch_lightning.utilities.exceptions import MisconfigurationException\ntry:  # pytorch_lightning <= 1.7\n    from pytorch_lightning.utilities.types import _PATH\nexcept ImportError:  # pytorch_lightning >= 1.8\n    try:\n        from lightning_lite.utilities.types import _PATH\n    except ImportError:  # pytorch_lightning >= 1.9\n        from lightning_fabric.utilities.types import _PATH\n\n\nclass DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):\n\n    def optimizer_step(  # type: ignore[override]\n        self,\n        model: \"pl.LightningModule\",\n        optimizer,\n        optimizer_idx: int,\n        closure: Callable[[], Any],\n        **kwargs: Any,\n    ) -> Any:\n        if self.scaler is None:\n            # skip scaler logic, as bfloat16 does not require scaler\n            return NativeMixedPrecisionPlugin.optimizer_step(\n                self, optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs\n            )\n        if isinstance(optimizer, LBFGS):\n            raise MisconfigurationException(\n                f\"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx}).\"\n            )\n        closure_result = closure()\n        # HACK: we don't call self.scaler.unscale_ here. 
This is because DistributedFusedAdam\n        # optimizer internally takes the scale into account.\n        # If we call unscale_ here, it would be equivalent to unscaling the gradients twice.\n        # Not unscaling has the side-effect that the NormMonitor callback will report the\n        # gradient norm to be much larger than reality.\n        # # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.\n        # self.scaler.unscale_(optimizer)\n        # This will call gradient clipping\n        self._after_closure(model, optimizer, optimizer_idx)\n        skipped_backward = closure_result is None\n        # in manual optimization, the closure does not return a value\n        if not model.automatic_optimization or not skipped_backward:\n            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found\n            step_output = self.scaler.step(optimizer, **kwargs)\n            self.scaler.update()\n            return step_output\n        return closure_result\n\n    def clip_grad_by_norm(self, optimizer: DistributedFusedAdam, clip_val: Union[int, float]) -> None:\n        \"\"\"Clip gradients by norm.\"\"\"\n        # DistributedFusedAdam wants list, not generator\n        # Gradients have not been unscaled, so we need to scale up the clip_val\n        if self.scaler is not None:\n            clip_val *= self.scaler.get_scale()\n        return optimizer.clip_grad_norm(clip_val)\n\n\nclass DDPStrategyZero2(DDPStrategy):\n    \"\"\"To use Apex's DistributedFusedAdam, we need to shard the optimizer states when\n    saving/loading checkpoints.\n    \"\"\"\n\n    strategy_name = \"ddp_zero2\"\n\n    def __init__(\n        self,\n        *args,\n        precision_plugin: Optional[PrecisionPlugin] = DistAdamNativeMixedPrecisionPlugin,\n        # precision_plugin: Optional[PrecisionPlugin] = None,\n        **kwargs: Union[Any, Dict[str, Any]],\n    ) -> None:\n        super().__init__(\n            *args, 
precision_plugin=precision_plugin, **kwargs\n        )\n\n    @property\n    def precision_plugin(self) -> PrecisionPlugin:\n        return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin()\n\n    @precision_plugin.setter\n    def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None:\n        self._precision_plugin = precision_plugin\n        # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance\n        self._precision_plugin.optimizer_step = types.MethodType(\n            DistAdamNativeMixedPrecisionPlugin.optimizer_step, self._precision_plugin\n        )\n        self._precision_plugin.clip_grad_by_norm = types.MethodType(\n            DistAdamNativeMixedPrecisionPlugin.clip_grad_by_norm, self._precision_plugin\n        )\n\n    def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]:\n        if isinstance(optimizer, LightningOptimizer):\n            optimizer = optimizer._optimizer\n        if isinstance(optimizer, DistributedFusedAdam):\n            return optimizer.state_dict(gather_on_root=False)\n        else:\n            return optimizer.state_dict()\n\n    def save_checkpoint(\n        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None\n    ) -> None:\n        \"\"\"Save model/training states as a checkpoint file through state-dump and file-write.\n        Args:\n            checkpoint: dict containing model and trainer state\n            filepath: write-target file's path\n            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin\n        \"\"\"\n        filepath = Path(filepath)\n        filepath.mkdir(parents=True, exist_ok=True)\n        local_optimizer_states = checkpoint.pop('optimizer_states')\n        if self.is_global_zero:\n            self.checkpoint_io.save_checkpoint(checkpoint, filepath / 'model_states.pt',\n                                               
storage_options=storage_options)\n        self.checkpoint_io.save_checkpoint(local_optimizer_states,\n                                           filepath / f'{self.global_rank:03d}_optim_states.pt',\n                                           storage_options=storage_options)\n\n    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:\n        torch.cuda.empty_cache()\n        checkpoint_path = Path(checkpoint_path)\n        if checkpoint_path.is_file():\n            return super().load_checkpoint(str(checkpoint_path))\n        else:\n            assert checkpoint_path.is_dir()\n            global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt')\n            local_optimizer_states = self.checkpoint_io.load_checkpoint(\n                checkpoint_path / f'{self.global_rank:03d}_optim_states.pt',\n                map_location='cuda'\n            )\n            global_states['optimizer_states'] = local_optimizer_states\n            return global_states\n"
  },
  {
    "path": "training/src/utils/distributed.py",
    "content": "# Copied from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/distributed.py\n\n# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#       http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom contextlib import contextmanager\n\nimport torch\n\n\ndef init_distributed(cuda):\n    \"\"\"\n    Initializes distributed backend.\n    :param cuda: (bool) if True initializes nccl backend, if False initializes\n        gloo backend\n    \"\"\"\n    world_size = int(os.environ.get('WORLD_SIZE', 1))\n    distributed = (world_size > 1)\n    if distributed:\n        backend = 'nccl' if cuda else 'gloo'\n        torch.distributed.init_process_group(backend=backend,\n                                             init_method='env://')\n        assert torch.distributed.is_initialized()\n    return distributed\n\n\ndef barrier():\n    \"\"\"\n    Call torch.distributed.barrier() if distritubed is in use\n    \"\"\"\n    if torch.distributed.is_available() and torch.distributed.is_initialized():\n        torch.distributed.barrier()\n\n\ndef get_rank():\n    \"\"\"\n    Gets distributed rank or returns zero if distributed is not initialized.\n    \"\"\"\n    if torch.distributed.is_available() and torch.distributed.is_initialized():\n        rank = torch.distributed.get_rank()\n    else:\n        rank = 0\n    return rank\n\n\ndef get_world_size():\n    \"\"\"\n    Gets total 
number of distributed workers or returns one if distributed is\n    not initialized.\n    \"\"\"\n    if torch.distributed.is_available() and torch.distributed.is_initialized():\n        world_size = torch.distributed.get_world_size()\n    else:\n        world_size = 1\n    return world_size\n\n\ndef all_reduce_item(value, op='sum'):\n    \"\"\"\n    All-reduces single scalar value if distributed is in use\n    \"\"\"\n    if torch.distributed.is_available() and torch.distributed.is_initialized():\n        if op == 'sum' or op == 'mean':\n            dop = torch.distributed.ReduceOp.SUM\n        elif op == 'min':\n            dop = torch.distributed.ReduceOp.MIN\n        elif op == 'max':\n            dop = torch.distributed.ReduceOp.MAX\n        elif op == 'product':\n            dop = torch.distributed.ReduceOp.PRODUCT\n        else:\n            raise RuntimeError('Unsupported reduce op')\n\n        backend = torch.distributed.get_backend()\n        if backend == torch.distributed.Backend.NCCL:\n            device = torch.device('cuda')\n        elif backend == torch.distributed.Backend.GLOO:\n            device = torch.device('cpu')\n        else:\n            raise RuntimeError('Unsupported distributed backend')\n\n        tensor = torch.tensor(value, device=device)\n        torch.distributed.all_reduce(tensor, dop)\n        if op == 'mean':\n            tensor /= get_world_size()\n        ret = tensor.item()\n    else:\n        ret = value\n    return ret\n\n\n@contextmanager\ndef sync_workers():\n    \"\"\"\n    Yields distributed rank and synchronizes all workers on exit.\n    \"\"\"\n    rank = get_rank()\n    yield rank\n    barrier()\n"
  },
  {
    "path": "training/src/utils/ema.py",
    "content": "# Copied from https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py\nfrom __future__ import division\nfrom __future__ import unicode_literals\n\nfrom typing import Iterable, Optional\nimport weakref\nimport copy\nimport contextlib\n\nimport torch\n\n\ndef to_float_maybe(x):\n    return x.float() if x.dtype in [torch.float16, torch.bfloat16] else x\n\n\n# Partially based on:\n# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py\nclass ExponentialMovingAverage:\n    \"\"\"\n    Maintains (exponential) moving average of a set of parameters.\n    Args:\n        parameters: Iterable of `torch.nn.Parameter` (typically from\n            `model.parameters()`).\n        decay: The exponential decay.\n        use_num_updates: Whether to use number of updates when computing\n            averages.\n    \"\"\"\n    def __init__(\n        self,\n        parameters: Iterable[torch.nn.Parameter],\n        decay: float,\n        use_num_updates: bool = True\n    ):\n        if decay < 0.0 or decay > 1.0:\n            raise ValueError('Decay must be between 0 and 1')\n        self.decay = decay\n        self.num_updates = 0 if use_num_updates else None\n        parameters = list(parameters)\n        self.shadow_params = [to_float_maybe(p.clone().detach())\n                              for p in parameters if p.requires_grad]\n        self.collected_params = None\n        # By maintaining only a weakref to each parameter,\n        # we maintain the old GC behaviour of ExponentialMovingAverage:\n        # if the model goes out of scope but the ExponentialMovingAverage\n        # is kept, no references to the model or its parameters will be\n        # maintained, and the model will be cleaned up.\n        self._params_refs = [weakref.ref(p) for p in parameters]\n\n    def _get_parameters(\n        self,\n        parameters: Optional[Iterable[torch.nn.Parameter]]\n    ) -> Iterable[torch.nn.Parameter]:\n        
if parameters is None:\n            parameters = [p() for p in self._params_refs]\n            if any(p is None for p in parameters):\n                raise ValueError(\n                    \"(One of) the parameters with which this \"\n                    \"ExponentialMovingAverage \"\n                    \"was initialized no longer exists (was garbage collected);\"\n                    \" please either provide `parameters` explicitly or keep \"\n                    \"the model to which they belong from being garbage \"\n                    \"collected.\"\n                )\n            return parameters\n        else:\n            parameters = list(parameters)\n            if len(parameters) != len(self.shadow_params):\n                raise ValueError(\n                    \"Number of parameters passed as argument is different \"\n                    \"from number of shadow parameters maintained by this \"\n                    \"ExponentialMovingAverage\"\n                )\n            return parameters\n\n    def update(\n        self,\n        parameters: Optional[Iterable[torch.nn.Parameter]] = None\n    ) -> None:\n        \"\"\"\n        Update currently maintained parameters.\n        Call this every time the parameters are updated, such as the result of\n        the `optimizer.step()` call.\n        Args:\n            parameters: Iterable of `torch.nn.Parameter`; usually the same set of\n                parameters used to initialize this object. 
If `None`, the\n                parameters with which this `ExponentialMovingAverage` was\n                initialized will be used.\n        \"\"\"\n        parameters = self._get_parameters(parameters)\n        decay = self.decay\n        if self.num_updates is not None:\n            self.num_updates += 1\n            decay = min(\n                decay,\n                (1 + self.num_updates) / (10 + self.num_updates)\n            )\n        one_minus_decay = 1.0 - decay\n        if parameters[0].device != self.shadow_params[0].device:\n            self.to(device=parameters[0].device)\n        with torch.no_grad():\n            parameters = [p for p in parameters if p.requires_grad]\n            for s_param, param in zip(self.shadow_params, parameters):\n                torch.lerp(s_param, param.to(dtype=s_param.dtype), one_minus_decay, out=s_param)\n\n    def copy_to(\n        self,\n        parameters: Optional[Iterable[torch.nn.Parameter]] = None\n    ) -> None:\n        \"\"\"\n        Copy current averaged parameters into given collection of parameters.\n        Args:\n            parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n                updated with the stored moving averages. If `None`, the\n                parameters with which this `ExponentialMovingAverage` was\n                initialized will be used.\n        \"\"\"\n        parameters = self._get_parameters(parameters)\n        for s_param, param in zip(self.shadow_params, parameters):\n            if param.requires_grad:\n                param.data.copy_(s_param.data)\n\n    def store(\n        self,\n        parameters: Optional[Iterable[torch.nn.Parameter]] = None\n    ) -> None:\n        \"\"\"\n        Save the current parameters for restoring later.\n        Args:\n            parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n                temporarily stored. 
If `None`, the parameters with which this\n                `ExponentialMovingAverage` was initialized will be used.\n        \"\"\"\n        parameters = self._get_parameters(parameters)\n        self.collected_params = [\n            param.clone()\n            for param in parameters\n            if param.requires_grad\n        ]\n\n    def restore(\n        self,\n        parameters: Optional[Iterable[torch.nn.Parameter]] = None\n    ) -> None:\n        \"\"\"\n        Restore the parameters stored with the `store` method.\n        Useful to validate the model with EMA parameters without affecting the\n        original optimization process. Store the parameters before the\n        `copy_to` method. After validation (or model saving), use this to\n        restore the former parameters.\n        Args:\n            parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n                updated with the stored parameters. If `None`, the\n                parameters with which this `ExponentialMovingAverage` was\n                initialized will be used.\n        \"\"\"\n        if self.collected_params is None:\n            raise RuntimeError(\n                \"This ExponentialMovingAverage has no `store()`ed weights \"\n                \"to `restore()`\"\n            )\n        parameters = self._get_parameters(parameters)\n        for c_param, param in zip(self.collected_params, parameters):\n            if param.requires_grad:\n                param.data.copy_(c_param.data)\n\n    @contextlib.contextmanager\n    def average_parameters(\n        self,\n        parameters: Optional[Iterable[torch.nn.Parameter]] = None\n    ):\n        r\"\"\"\n        Context manager for validation/inference with averaged parameters.\n        Equivalent to:\n            ema.store()\n            ema.copy_to()\n            try:\n                ...\n            finally:\n                ema.restore()\n        Args:\n            parameters: Iterable of 
`torch.nn.Parameter`; the parameters to be\n                updated with the stored parameters. If `None`, the\n                parameters with which this `ExponentialMovingAverage` was\n                initialized will be used.\n        \"\"\"\n        parameters = self._get_parameters(parameters)\n        self.store(parameters)\n        self.copy_to(parameters)\n        try:\n            yield\n        finally:\n            self.restore(parameters)\n\n    def to(self, device=None, dtype=None) -> None:\n        r\"\"\"Move internal buffers of the ExponentialMovingAverage to `device`.\n        Args:\n            device: like `device` argument to `torch.Tensor.to`\n        \"\"\"\n        # .to() on the tensors handles None correctly\n        self.shadow_params = [\n            p.to(device=device, dtype=dtype)\n            if p.is_floating_point()\n            else p.to(device=device)\n            for p in self.shadow_params\n        ]\n        if self.collected_params is not None:\n            self.collected_params = [\n                p.to(device=device, dtype=dtype)\n                if p.is_floating_point()\n                else p.to(device=device)\n                for p in self.collected_params\n            ]\n        return\n\n    def state_dict(self) -> dict:\n        r\"\"\"Returns the state of the ExponentialMovingAverage as a dict.\"\"\"\n        # Following PyTorch conventions, references to tensors are returned:\n        # \"returns a reference to the state and not its copy!\" -\n        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict\n        return {\n            \"decay\": self.decay,\n            \"num_updates\": self.num_updates,\n            \"shadow_params\": self.shadow_params,\n            \"collected_params\": self.collected_params\n        }\n\n    def load_state_dict(self, state_dict: dict) -> None:\n        r\"\"\"Loads the ExponentialMovingAverage state.\n        Args:\n            state_dict (dict): 
EMA state. Should be an object returned\n                from a call to :meth:`state_dict`.\n        \"\"\"\n        # deepcopy, to be consistent with module API\n        state_dict = copy.deepcopy(state_dict)\n        self.decay = state_dict[\"decay\"]\n        if self.decay < 0.0 or self.decay > 1.0:\n            raise ValueError('Decay must be between 0 and 1')\n        self.num_updates = state_dict[\"num_updates\"]\n        assert self.num_updates is None or isinstance(self.num_updates, int), \\\n            \"Invalid num_updates\"\n\n        self.shadow_params = state_dict[\"shadow_params\"]\n        assert isinstance(self.shadow_params, list), \\\n            \"shadow_params must be a list\"\n        assert all(\n            isinstance(p, torch.Tensor) for p in self.shadow_params\n        ), \"shadow_params must all be Tensors\"\n\n        self.collected_params = state_dict[\"collected_params\"]\n        if self.collected_params is not None:\n            assert isinstance(self.collected_params, list), \\\n                \"collected_params must be a list\"\n            assert all(\n                isinstance(p, torch.Tensor) for p in self.collected_params\n            ), \"collected_params must all be Tensors\"\n            assert len(self.collected_params) == len(self.shadow_params), \\\n                \"collected_params and shadow_params had different lengths\"\n\n        if len(self.shadow_params) == len(self._params_refs):\n            # Consistent with torch.optim.Optimizer, cast things to consistent\n            # device and dtype with the parameters\n            params = [p() for p in self._params_refs]\n            # If parameters have been garbage collected, just load the state\n            # we were given without change.\n            if not any(p is None for p in params):\n                # ^ parameter references are still good\n                for i, p in enumerate(params):\n                    self.shadow_params[i] = 
to_float_maybe(self.shadow_params[i].to(\n                        device=p.device, dtype=p.dtype\n                    ))\n                    if self.collected_params is not None:\n                        self.collected_params[i] = self.collected_params[i].to(\n                            device=p.device, dtype=p.dtype\n                        )\n        else:\n            raise ValueError(\n                \"Tried to `load_state_dict()` with the wrong number of \"\n                \"parameters in the saved state.\"\n            )\n"
  },
  {
    "path": "training/src/utils/flops.py",
    "content": "# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py\nimport torch\n\ntry:\n    from deepspeed.profiling.flops_profiler import get_model_profile\n    has_deepspeed_profiling = True\nexcept ImportError as e:\n    has_deepspeed_profiling = False\n\ntry:\n    from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table\n    from fvcore.nn import ActivationCountAnalysis\n    has_fvcore_profiling = True\nexcept ImportError as e:\n    FlopCountAnalysis = None\n    ActivationCountAnalysis = None\n    has_fvcore_profiling = False\n\n\ndef profile_deepspeed(model, input_size=(3, 224, 224), input_dtype=torch.float32,\n                      batch_size=1, detailed=False):\n    device, dtype = next(model.parameters()).device, next(model.parameters()).dtype\n    flops, macs, params = get_model_profile(\n        model=model,\n        args=torch.zeros((batch_size,) + input_size, device=device, dtype=input_dtype),\n        print_profile=detailed,  # prints the model graph with the measured profile attached to each module\n        detailed=detailed,  # print the detailed profile\n        warm_up=10,  # the number of warm-ups before measuring the time of each module\n        as_string=False,  # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)\n        output_file=None,  # path to the output file. 
If None, the profiler prints to stdout.\n        ignore_modules=None)  # the list of modules to ignore in the profiling\n    return macs, 0  # no activation count in DS\n\n\ndef profile_fvcore(model, input_size=(3, 224, 224), input_dtype=torch.float32, max_depth=4,\n                   batch_size=1, detailed=False, force_cpu=False):\n    if force_cpu:\n        model = model.to('cpu')\n    device, dtype = next(model.parameters()).device, next(model.parameters()).dtype\n    example_input = torch.zeros((batch_size,) + input_size, device=device, dtype=input_dtype)\n    fca = FlopCountAnalysis(model, example_input)\n    aca = ActivationCountAnalysis(model, example_input)\n    if detailed:\n        print(flop_count_table(fca, max_depth=max_depth))\n    return fca, fca.total(), aca, aca.total()\n"
  },
  {
    "path": "training/src/utils/gpu_affinity.py",
    "content": "import collections\nimport math\nimport os\nimport pathlib\nimport re\n\nimport pynvml\n\npynvml.nvmlInit()\n\n\ndef systemGetDriverVersion():\n    return pynvml.nvmlSystemGetDriverVersion()\n\n\ndef deviceGetCount():\n    return pynvml.nvmlDeviceGetCount()\n\n\nclass device:\n    # assume nvml returns list of 64 bit ints\n    _nvml_affinity_elements = math.ceil(os.cpu_count() / 64)\n\n    def __init__(self, device_idx):\n        super().__init__()\n        self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)\n\n    def getName(self):\n        return pynvml.nvmlDeviceGetName(self.handle)\n\n    def getCpuAffinity(self):\n        affinity_string = ''\n        for j in pynvml.nvmlDeviceGetCpuAffinity(\n            self.handle, device._nvml_affinity_elements\n        ):\n            # assume nvml returns list of 64 bit ints\n            affinity_string = '{:064b}'.format(j) + affinity_string\n        affinity_list = [int(x) for x in affinity_string]\n        affinity_list.reverse()  # so core 0 is in 0th element of list\n\n        ret = [i for i, e in enumerate(affinity_list) if e != 0]\n        return ret\n\n\ndef set_socket_affinity(gpu_id):\n    dev = device(gpu_id)\n    affinity = dev.getCpuAffinity()\n    os.sched_setaffinity(0, affinity)\n\n\ndef set_single_affinity(gpu_id):\n    dev = device(gpu_id)\n    affinity = dev.getCpuAffinity()\n    os.sched_setaffinity(0, affinity[:1])\n\n\ndef set_single_unique_affinity(gpu_id, nproc_per_node):\n    devices = [device(i) for i in range(nproc_per_node)]\n    socket_affinities = [dev.getCpuAffinity() for dev in devices]\n\n    siblings_list = get_thread_siblings_list()\n    siblings_dict = dict(siblings_list)\n\n    # remove siblings\n    for idx, socket_affinity in enumerate(socket_affinities):\n        socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))\n\n    affinities = []\n    assigned = []\n\n    for socket_affinity in socket_affinities:\n        for core in 
socket_affinity:\n            if core not in assigned:\n                affinities.append([core])\n                assigned.append(core)\n                break\n    os.sched_setaffinity(0, affinities[gpu_id])\n\n\ndef set_socket_unique_affinity(gpu_id, nproc_per_node, mode):\n    device_ids = [device(i) for i in range(nproc_per_node)]\n    socket_affinities = [dev.getCpuAffinity() for dev in device_ids]\n\n    siblings_list = get_thread_siblings_list()\n    siblings_dict = dict(siblings_list)\n\n    # remove siblings\n    for idx, socket_affinity in enumerate(socket_affinities):\n        socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))\n\n    socket_affinities_to_device_ids = collections.defaultdict(list)\n\n    for idx, socket_affinity in enumerate(socket_affinities):\n        socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)\n\n    for socket_affinity, device_ids in socket_affinities_to_device_ids.items():\n        devices_per_group = len(device_ids)\n        cores_per_device = len(socket_affinity) // devices_per_group\n        for group_id, device_id in enumerate(device_ids):\n            if device_id == gpu_id:\n                if mode == 'interleaved':\n                    affinity = list(socket_affinity[group_id::devices_per_group])\n                elif mode == 'continuous':\n                    affinity = list(socket_affinity[group_id*cores_per_device:(group_id+1)*cores_per_device])\n                else:\n                    raise RuntimeError('Unknown set_socket_unique_affinity mode')\n\n                # reintroduce siblings\n                affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]\n                os.sched_setaffinity(0, affinity)\n\n\ndef get_thread_siblings_list():\n    path = '/sys/devices/system/cpu/cpu*/topology/thread_siblings_list'\n    thread_siblings_list = []\n    pattern = re.compile(r'(\\d+)\\D(\\d+)')\n    for fname in 
pathlib.Path(path[0]).glob(path[1:]):\n        with open(fname) as f:\n            content = f.read().strip()\n            res = pattern.findall(content)\n            if res:\n                pair = tuple(map(int, res[0]))\n                thread_siblings_list.append(pair)\n    return thread_siblings_list\n\n\ndef set_affinity(gpu_id, nproc_per_node, mode='socket'):\n    if mode == 'socket':\n        set_socket_affinity(gpu_id)\n    elif mode == 'single':\n        set_single_affinity(gpu_id)\n    elif mode == 'single_unique':\n        set_single_unique_affinity(gpu_id, nproc_per_node)\n    elif mode == 'socket_unique_interleaved':\n        set_socket_unique_affinity(gpu_id, nproc_per_node, 'interleaved')\n    elif mode == 'socket_unique_continuous':\n        set_socket_unique_affinity(gpu_id, nproc_per_node, 'continuous')\n    else:\n        raise RuntimeError('Unknown affinity mode')\n\n    affinity = os.sched_getaffinity(0)\n    return affinity\n"
  },
  {
    "path": "training/src/utils/utils.py",
    "content": "import logging\nimport warnings\nfrom typing import List, Sequence\n\nimport pytorch_lightning as pl\nimport rich.syntax\nimport rich.tree\nfrom omegaconf import DictConfig, OmegaConf\nfrom pytorch_lightning.utilities import rank_zero_only\n\n\n# Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging\nclass LoggingContext:\n    def __init__(self, logger, level=None, handler=None, close=True):\n        self.logger = logger\n        self.level = level\n        self.handler = handler\n        self.close = close\n\n    def __enter__(self):\n        if self.level is not None:\n            self.old_level = self.logger.level\n            self.logger.setLevel(self.level)\n        if self.handler:\n            self.logger.addHandler(self.handler)\n\n    def __exit__(self, et, ev, tb):\n        if self.level is not None:\n            self.logger.setLevel(self.old_level)\n        if self.handler:\n            self.logger.removeHandler(self.handler)\n        if self.handler and self.close:\n            self.handler.close()\n        # implicit return of None => don't swallow exceptions\n\n\ndef get_logger(name=__name__) -> logging.Logger:\n    \"\"\"Initializes multi-GPU-friendly python logger.\"\"\"\n\n    logger = logging.getLogger(name)\n\n    # this ensures all logging levels get marked with the rank zero decorator\n    # otherwise logs would get multiplied for each GPU process in multi-GPU setup\n    for level in (\"debug\", \"info\", \"warning\", \"error\", \"exception\", \"fatal\", \"critical\"):\n        setattr(logger, level, rank_zero_only(getattr(logger, level)))\n\n    return logger\n\n\ndef extras(config: DictConfig) -> None:\n    \"\"\"A couple of optional utilities, controlled by main config file:\n    - disabling warnings\n    - forcing debug friendly configuration\n    - verifying experiment name is set when running in experiment mode\n    Modifies DictConfig in place.\n    Args:\n        
config (DictConfig): Configuration composed by Hydra.\n    \"\"\"\n\n    log = get_logger(__name__)\n\n    # disable python warnings if <config.ignore_warnings=True>\n    if config.get(\"ignore_warnings\"):\n        log.info(\"Disabling python warnings! <config.ignore_warnings=True>\")\n        warnings.filterwarnings(\"ignore\")\n\n    # verify experiment name is set when running in experiment mode\n    if config.get(\"experiment_mode\") and not config.get(\"name\"):\n        log.info(\n            \"Running in experiment mode without the experiment name specified! \"\n            \"Use `python run.py mode=exp name=experiment_name`\"\n        )\n        log.info(\"Exiting...\")\n        exit()\n\n    # force debugger friendly configuration if <config.trainer.fast_dev_run=True>\n    # debuggers don't like GPUs and multiprocessing\n    if config.trainer.get(\"fast_dev_run\"):\n        log.info(\"Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>\")\n        if config.trainer.get(\"gpus\"):\n            config.trainer.gpus = 0\n        if config.datamodule.get(\"pin_memory\"):\n            config.datamodule.pin_memory = False\n        if config.datamodule.get(\"num_workers\"):\n            config.datamodule.num_workers = 0\n\n\n@rank_zero_only\ndef print_config(\n    config: DictConfig,\n    fields: Sequence[str] = (\n        \"trainer\",\n        \"model\",\n        \"datamodule\",\n        \"train\",\n        \"eval\",\n        \"callbacks\",\n        \"logger\",\n        \"seed\",\n        \"name\",\n    ),\n    resolve: bool = True,\n) -> None:\n    \"\"\"Prints content of DictConfig using Rich library and its tree structure.\n    Args:\n        config (DictConfig): Configuration composed by Hydra.\n        fields (Sequence[str], optional): Determines which main fields from config will\n        be printed and in what order.\n        resolve (bool, optional): Whether to resolve reference fields of DictConfig.\n    \"\"\"\n\n    style = 
\"dim\"\n    tree = rich.tree.Tree(\"CONFIG\", style=style, guide_style=style)\n\n    for field in fields:\n        branch = tree.add(field, style=style, guide_style=style)\n\n        config_section = config.get(field)\n        branch_content = str(config_section)\n        if isinstance(config_section, DictConfig):\n            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)\n\n        branch.add(rich.syntax.Syntax(branch_content, \"yaml\"))\n\n    rich.print(tree)\n\n    with open(\"config_tree.txt\", \"w\") as fp:\n        rich.print(tree, file=fp)\n\n\ndef finish(\n    config: DictConfig,\n    model: pl.LightningModule,\n    datamodule: pl.LightningDataModule,\n    trainer: pl.Trainer,\n    callbacks: List[pl.Callback],\n    logger: List[pl.loggers.LightningLoggerBase],\n) -> None:\n    \"\"\"Makes sure everything closed properly.\"\"\"\n\n    # without this sweeps with wandb logger might crash!\n    for lg in logger:\n        if isinstance(lg, pl.loggers.wandb.WandbLogger):\n            import wandb\n\n            wandb.finish()\n"
  },
  {
    "path": "training/tests/datamodules/test_language_modeling_hf.py",
    "content": "import os\nfrom pathlib import Path\ncurrent_dir = Path(__file__).parent.absolute()\n\n\nimport pytest\n\nimport torch\n\nimport dotenv\n\nfrom src.datamodules.language_modeling_hf import LMDataModule\n\n# load environment variables from `.env` file if it exists\n# recursively searches for `.env` in all folders starting from work dir\ndotenv.load_dotenv(override=True)\n\n\ndef div_up(x: int, y: int) -> int:\n    return (x + y - 1) // y\n\n\n# https://stackoverflow.com/questions/1006289/how-to-find-out-the-number-of-cpus-using-python/55423170#55423170\ndef num_cpu_cores():\n    try:\n        import psutil\n        return psutil.cpu_count(logical=False)\n    except ImportError:\n        return len(os.sched_getaffinity(0))\n\n\nclass TestLMDataModule:\n\n    def test_wikitext2(self):\n        batch_size = 7\n        dataset_name = 'wikitext'\n        dataset_config_name = 'wikitext-2-raw-v1'\n        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))\n        cache_dir = data_dir / 'wikitext-2' / 'cache'\n        max_length = 1024\n        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',\n                                  dataset_config_name=dataset_config_name,\n                                  max_length=max_length, cache_dir=cache_dir,\n                                  add_eos=False, batch_size=batch_size, num_workers=4)\n        datamodule.prepare_data()\n        datamodule.setup(stage='fit')\n        train_loader = datamodule.train_dataloader()\n        val_loader = datamodule.val_dataloader()\n        datamodule.setup(stage='test')\n        test_loader = datamodule.test_dataloader()\n        train_len = 2391884\n        val_len = 247289\n        test_len = 283287\n        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)\n        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)\n        assert len(test_loader) == div_up((test_len - 1) // max_length, 
batch_size)\n        for loader in [train_loader, val_loader, test_loader]:\n            x, y = next(iter(loader))\n            assert x.dim() == 2\n            assert x.shape == (batch_size, max_length)\n            assert x.dtype == torch.long\n            assert torch.allclose(x[:, 1:], y[:, :-1])\n\n    def test_wikitext103(self):\n        batch_size = 7\n        dataset_name = 'wikitext'\n        dataset_config_name = 'wikitext-103-raw-v1'\n        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))\n        cache_dir = data_dir / 'wikitext-103' / 'cache'\n        max_length = 1024\n        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',\n                                  dataset_config_name=dataset_config_name,\n                                  max_length=max_length, cache_dir=cache_dir,\n                                  add_eos=False, batch_size=batch_size, num_workers=4)\n        datamodule.prepare_data()\n        datamodule.setup(stage='fit')\n        train_loader = datamodule.train_dataloader()\n        val_loader = datamodule.val_dataloader()\n        datamodule.setup(stage='test')\n        test_loader = datamodule.test_dataloader()\n        train_len = 117920140\n        val_len = 247289\n        test_len = 283287\n        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)\n        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)\n        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)\n        for loader in [train_loader, val_loader, test_loader]:\n            x, y = next(iter(loader))\n            assert x.dim() == 2\n            assert x.shape == (batch_size, max_length)\n            assert x.dtype == torch.long\n            assert torch.allclose(x[:, 1:], y[:, :-1])\n\n    def test_openwebtext(self):\n        batch_size = 8\n        dataset_name = 'openwebtext'\n        dataset_config_name = None\n        data_dir = 
Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))\n        cache_dir = data_dir / 'openwebtext' / 'cache'\n        max_length = 1024\n        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',\n                                  dataset_config_name=dataset_config_name,\n                                  max_length=max_length, cache_dir=cache_dir,\n                                  add_eos=True, batch_size=batch_size,\n                                  num_workers=num_cpu_cores() // 2)\n        datamodule.prepare_data()\n        datamodule.setup(stage='fit')\n        train_loader = datamodule.train_dataloader()\n        val_loader = datamodule.val_dataloader()\n        datamodule.setup(stage='test')\n        test_loader = datamodule.test_dataloader()\n        train_len = 9035582198\n        val_len = 4434897\n        test_len = 4434897\n        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)\n        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)\n        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)\n        for loader in [train_loader, val_loader, test_loader]:\n            x, y = next(iter(loader))\n            assert x.dim() == 2\n            assert x.shape == (batch_size, max_length)\n            assert x.dtype == torch.long\n            assert torch.allclose(x[:, 1:], y[:, :-1])\n\n    def test_lambada(self):\n        batch_size = 8\n        dataset_name = 'lambada'\n        dataset_config_name = None\n        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))\n        cache_dir = data_dir / 'lambada' / 'cache'\n        max_length = 1024\n        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',\n                                  dataset_config_name=dataset_config_name,\n                                  max_length=max_length, cache_dir=cache_dir,\n                                  add_eos=True, 
batch_size=batch_size,\n                                  num_workers=64)\n        datamodule.prepare_data()\n        datamodule.setup(stage='fit')\n        train_loader = datamodule.train_dataloader()\n        val_loader = datamodule.val_dataloader()\n        datamodule.setup(stage='test')\n        test_loader = datamodule.test_dataloader()\n        train_len = 9035582198\n        val_len = 4434897\n        test_len = 4434897\n        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)\n        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)\n        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)\n        for loader in [train_loader, val_loader, test_loader]:\n            x, y = next(iter(loader))\n            assert x.dim() == 2\n            assert x.shape == (batch_size, max_length)\n            assert x.dtype == torch.long\n            assert torch.allclose(x[:, 1:], y[:, :-1])\n\n    def test_the_pile(self):\n        batch_size = 8\n        dataset_name = 'the_pile'\n        dataset_config_name = None\n        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))\n        cache_dir = data_dir / 'the_pile' / 'cache'\n        max_length = 2048\n        # Dataset is too large to fit into memory, need to use disk for concatenation\n        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',\n                                  dataset_config_name=dataset_config_name,\n                                  max_length=max_length, cache_dir=cache_dir,\n                                  add_eos=True, batch_size=batch_size,\n                                  num_workers=num_cpu_cores() // 2, use_shmem=False)\n        datamodule.prepare_data()\n        datamodule.setup(stage='fit')\n        train_loader = datamodule.train_dataloader()\n        val_loader = datamodule.val_dataloader()\n        datamodule.setup(stage='test')\n        test_loader = 
datamodule.test_dataloader()\n        train_len = 374337375694\n        val_len = 383326395\n        test_len = 373297018\n        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)\n        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)\n        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)\n        for loader in [train_loader, val_loader, test_loader]:\n            x, y = next(iter(loader))\n            assert x.dim() == 2\n            assert x.shape == (batch_size, max_length)\n            assert x.dtype == torch.long\n            assert torch.allclose(x[:, 1:], y[:, :-1])\n\n    def test_pg19(self):\n        batch_size = 8\n        dataset_name = 'pg19'\n        dataset_config_name = None\n        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))\n        cache_dir = data_dir / 'pg19' / 'cache'\n        max_length = 2048\n        # Dataset is too large to fit into memory, need to use disk for concatenation\n        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',\n                                  dataset_config_name=dataset_config_name,\n                                  max_length=max_length, cache_dir=cache_dir,\n                                  add_eos=True, batch_size=batch_size,\n                                  num_workers=num_cpu_cores() // 2)\n        datamodule.prepare_data()\n        datamodule.setup(stage='fit')\n        train_loader = datamodule.train_dataloader()\n        val_loader = datamodule.val_dataloader()\n        datamodule.setup(stage='test')\n        test_loader = datamodule.test_dataloader()\n        train_len = 3066544128\n        val_len = 4653056\n        test_len = 10584064\n        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)\n        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)\n        assert len(test_loader) == div_up((test_len - 1) // 
max_length, batch_size)\n        for loader in [train_loader, val_loader, test_loader]:\n            x, y = next(iter(loader))\n            assert x.dim() == 2\n            assert x.shape == (batch_size, max_length)\n            assert x.dtype == torch.long\n            assert torch.allclose(x[:, 1:], y[:, :-1])\n"
  },
  {
    "path": "usage.md",
    "content": "# FlashAttention adoption\n\nWe've been very happy to see FlashAttention being adopted by many organizations\nand research labs to speed up their training / inference.\nThis page contains a partial list of places where FlashAttention is being used.\nIf you'd like to add links to your organization / product / codebase, please open a\nPR or email us. We'd very much like to hear from you!\n\n## Integrated into machine learning frameworks\n\n- PyTorch: [integrated](https://github.com/pytorch/pytorch/pull/81434) into core PyTorch in nn.Transformer.\n\n- Huggingface's [transformers](https://github.com/huggingface/transformers) library.\n  [On-going](https://github.com/huggingface/transformers/pull/18439), blogpost\n  coming soon.\n\n- Microsoft's [DeepSpeed](https://github.com/microsoft/DeepSpeed):\n  FlashAttention is [integrated](https://github.com/microsoft/DeepSpeed/blob/ec13da6ba7cabc44bb4745a64a208b8580792954/deepspeed/ops/transformer/inference/triton_ops.py) into DeepSpeed's inference engine.\n\n- Nvidia's [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/pull/267). This\n  library is a popular framework for training large transformer language models at scale.\n\n- MosaicML [Composer](https://github.com/mosaicml/composer)\n  [library](https://www.mosaicml.com/blog/gpt-3-quality-for-500k). Composer is a\n  library for efficient neural network training.\n  \n- EleutherAI's [GPT-NeoX](https://github.com/EleutherAI/gpt-neox/pull/725). This is a research library for training large language transformer models at scale based on NVIDIA's Megatron-LM and Microsoft's DeepSpeed.\n\n- PaddlePaddle: integrated into the framework with [API](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/nn/functional/flash_attention.py) `paddle.nn.functional.flash_attention`.\n\n## MLPerf benchmarks\n\n[MLPerf](https://mlcommons.org/en/) is a competitive machine learning performance benchmark. 
FlashAttention\nyields the fastest BERT training on cloud instances in MLPerf training 2.0 (June\n2022) and MLPerf training 2.1 (November 2022).\n\n- MLPerf 2.0: [IEEE Spectrum](https://spectrum.ieee.org/mlperf-rankings-2022) and [Forbes](https://www.forbes.com/sites/moorinsights/2022/07/12/google-dethrones-nvidia-in-latest-artificial-intelligence-benchmarking-tests/) articles about our submission to the MLPerf 2.0 benchmark using FlashAttention.\n\n- MLPerf 2.1 -\n  collaboration\n  between [Azure and Hazy Research](https://techcommunity.microsoft.com/t5/azure-high-performance-computing/azure-collaborates-with-hazy-research-and-nvidia-to-achieve/ba-p/3667511): for the first time, we can train MLPerf BERT\n  in under 2 minutes on 16 nodes.\n\n- MLPerf 2.1 -\n  [Nvidia](https://developer.nvidia.com/blog/leading-mlperf-training-2-1-with-full-stack-optimizations-for-ai/):\n  Nvidia uses techniques from FlashAttention to make their (already extremely optimized) BERT\n  implementation go even faster.\n\n- MLPerf 2.1 - [MosaicML](https://www.mosaicml.com/blog/mlperf-nlp-nov2022): FlashAttention\n  helps train BERT 2.7x faster in the open division.\n\n## Language model training & inference\n\n- [PubMedGPT 2.7B](https://crfm.stanford.edu/2022/12/15/pubmedgpt.html), a\n  domain-specific LLM for biomedicine, by Stanford CRFM, trained on\n  [MosaicML](https://www.mosaicml.com/blog/introducing-pubmed-gpt) Cloud. Just\n  using FlashAttention nearly halves the total training time.\n\n- Meta's\n  [AITemplate](https://ai.facebook.com/blog/gpu-inference-engine-nvidia-amd-open-source/)\n  uses FlashAttention as part of their approach to speed up Transformer\n  inference (up to 5.3x on BERT).\n\n- Nvidia's [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) is a\n  state-of-the-art Transformer inference library. 
As of version\n  [5.2](https://github.com/NVIDIA/FasterTransformer/commit/b672f49e256ba7a2d4fc9691d270b60b7fc1a2ff),\n  FlashAttention is used as a component of FasterTransformer to speed up GPT inference.\n\n- [Kernl](https://github.com/ELS-RD/kernl) is a library for fast Transformer\n  inference. They use FlashAttention as part of their\n  [approach](https://twitter.com/pommedeterre33/status/1585284221014245377) to\n  speed up Transformers by up to 12x.\n\n## Diffusion model training and inference\n\n- Huggingface's [diffusers](https://github.com/huggingface/diffusers) library\n  for diffusion models. FlashAttention is integrated into [diffusers\n  v0.7.0](https://github.com/huggingface/diffusers/releases/tag/v0.7.0).\n  Up to 2x faster inference and lower memory usage.\n\n- Colossal-AI's\n  [implementation](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion)\n  of Stable Diffusion: with FlashAttention as one of its components, it speeds up\n  pretraining by up to 6.5x, and reduces the hardware cost of fine-tuning by 7x.\n\n- Meta's\n  [AITemplate](https://ai.facebook.com/blog/gpu-inference-engine-nvidia-amd-open-source/),\n  with FlashAttention as one of its components, is currently the [fastest](https://twitter.com/bing_xu_/status/1590447334055632897) Stable\n  Diffusion inference engine that we know of.\n\n- Stable Diffusion inference from\n  [Labml.ai](https://twitter.com/labmlai/status/1573634095732490240): 50% speedup.\n\n- Our own Stable Diffusion [fork](https://twitter.com/realDanFu/status/1580641495991754752) uses FlashAttention to get a 3-4x speedup compared\n  to the original version.\n\n## Other models\n\n- [Uni-Fold](https://github.com/dptech-corp/Uni-Fold): Uni-Fold is an\n  open-source platform for developing protein models beyond AlphaFold. 
With\n  FlashAttention, Uni-Fold is 2.6x\n  [faster](https://twitter.com/guolin_ke/status/1580532071901995008) than AlphaFold.\n\n- [OpenFold](https://github.com/aqlaboratory/openfold): a trainable,\n  memory-efficient, and GPU-friendly PyTorch reproduction of AlphaFold 2. With\n  FlashAttention as one of its\n  [components](https://twitter.com/gahdritz/status/1595420944880779266), it is\n  up to 3x faster than AlphaFold2 to run inference on short sequences, and can\n  predict 2x longer structures.\n\n## Different implementations\n\n- [Triton](https://github.com/openai/triton): an [implementation](https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py) of\n  FlashAttention in Triton by Phil Tillet from OpenAI. Triton is a Python-based\n  language and compiler for parallel programming.\n\n- [xformers](https://github.com/facebookresearch/xformers): The xformers team\n  has implemented [memory-efficient\n  attention](https://twitter.com/fvsmassa/status/1580229170629849089) in a\n  similar spirit to FlashAttention.\n  xformers dynamically dispatches to whichever implementation is available / faster.\n\n- [Jax](https://github.com/google/jax): an [implementation](https://github.com/lucidrains/flash-attention-jax)\n  in Jax by [lucidrains](https://github.com/lucidrains/).\n\n- [Metal](https://developer.apple.com/metal): an [implementation](https://github.com/philipturner/metal-flash-attention) in Metal by Philip Turner. This ports FlashAttention to mobile GPU architectures such as Apple silicon.\n"
  }
]