Repository: XunhaoLai/native-sparse-attention-triton
Branch: main
Commit: 9bea856c911e
Files: 44
Total size: 359.0 KB
Directory structure:
gitextract_5vrtmrhk/
├── .gitignore
├── LICENSE
├── README.md
├── install_dependency.sh
├── native_sparse_attention/
│ ├── __init__.py
│ ├── infer/
│ │ ├── __init__.py
│ │ ├── inference_func.py
│ │ └── nsa_inference.py
│ ├── model/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── toy_llama.py
│ │ └── toy_nsa_llama.py
│ ├── module/
│ │ ├── __init__.py
│ │ ├── kv_cache.py
│ │ ├── native_sparse_attention.py
│ │ ├── rope.py
│ │ └── self_attention.py
│ └── ops/
│ ├── README.md
│ ├── __init__.py
│ ├── torch/
│ │ ├── __init__.py
│ │ ├── compress_key_value.py
│ │ ├── compressed_attention.py
│ │ ├── compressed_attention_decode.py
│ │ └── topk_sparse_attention.py
│ └── triton/
│ ├── __init__.py
│ ├── compressed_attention.py
│ ├── flash_attention.py
│ ├── flash_attention_decode.py
│ ├── linear_compress.py
│ ├── topk_sparse_attention.py
│ ├── topk_sparse_attention_decode.py
│ ├── utils.py
│ └── weighted_pool.py
├── setup.py
└── test/
├── test_compress_key_value.py
├── test_compressed_attention.py
├── test_flash_attention.py
├── test_kv_cache.py
├── test_linear_compress.py
├── test_nsa_infer.py
├── test_nsa_model.py
├── test_nsa_module.py
├── test_rope.py
└── test_topk_sparse_attention.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# PyPI configuration file
.pypirc
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Native Sparse Attention Triton
This repository implements the sparse attention mechanism introduced in the paper [Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention](https://arxiv.org/abs/2502.11089) and provides an efficient training implementation based on [Triton](https://github.com/triton-lang/triton).
🎉 We now support both training and inference for Native Sparse Attention (variable-length version, including prefilling, decoding, and KV cache management). A toy model is provided at `model.ToyNSALlama`; it supports the `forward` function for training and the `generate` function for inference. Feel free to try it out!
## Requirements
Ensure the following dependencies are installed:
- PyTorch >= 2.1.0
- triton >= 3.0.0
- einops >= 0.7.0
- flash_attn >= 2.6.3
## Usage
### Notes
1. PyTorch implementations (`ops.torch`) are intended for debugging only.
2. For production use, prefer Triton operators (`ops.triton`).
3. All implementations follow the varlen approach, similar to `flash_attn_varlen_func`. Please concatenate the inputs of a batch before use.
4. Only attention head dimensions up to 128 are supported for now.
### Install
You can install `native_sparse_attention` using pip:
```shell
pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git
```
### Functions
The `ops` module has implemented several functions required for native sparse attention. For detailed usage instructions, please see [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/ops#readme).
You can import those functions from the `ops` module:
```python
import torch
from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention
# input example
num_q_heads = 64
num_kv_heads = 4
head_dim = 128
kernel_size = 32
kernel_stride = 16
block_size = 64
topk = 16
cu_seqlens = torch.Tensor([0, 1024, 8192, 16384]).to(torch.int32).cuda()
query = torch.randn(16384, num_q_heads, head_dim).to(torch.bfloat16).cuda()
key = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda()
value = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda()
# weight example
w = (
torch.randn(num_kv_heads, kernel_size * head_dim, head_dim)
.to(torch.bfloat16)
.cuda()
)
pe = torch.randn(num_kv_heads, kernel_size, head_dim).to(torch.bfloat16).cuda()
# 1. key value compression
compressed_key, compressed_cu_seqlens = linear_compress(
key, w, cu_seqlens, kernel_size, kernel_stride, pe
)
compressed_value, _ = linear_compress(
value, w, cu_seqlens, kernel_size, kernel_stride, None
)
# 2. attention between query and compressed key value
compressed_attn_output, topk_idx = compressed_attention(
query,
compressed_key,
compressed_value,
kernel_size,
kernel_stride,
block_size,
topk,
cu_seqlens,
compressed_cu_seqlens,
init_blocks=1,
local_blocks=2,
)
# 3. topk sparse attention
sparse_attn_output = topk_sparse_attention(
query,
key,
value,
topk_idx,
block_size,
cu_seqlens,
)
```
### Module
The `module` directory also provides implementations based on `torch.nn.Module` for easy integration into models.
```python
from native_sparse_attention.module import NativeSparseAttention, RopeConfig
NSA_Layer = NativeSparseAttention(
compress_type="linear",
hidden_size=4096,
num_q_heads=64,
num_kv_heads=4,
head_dim=128,
kernel_size=32,
kernel_stride=16,
block_size=64,
topk=8,
init_blocks=1,
local_blocks=2,
window_size=512,
rope_config=RopeConfig(
max_position_embeddings=32768,
head_dim=128,
rope_theta=500000,
rope_scaling={
"factor": 4.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
},
),
)
```
### Model
We offer two simplified LLaMA models in the `model` directory, featuring self-attention and native sparse attention. For more details on how to use these models, please refer to [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/model#readme).
```python
from native_sparse_attention.model import ToyNSALlamaConfig, InferenceConfig, ToyNSALlama
config = ToyNSALlamaConfig(
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=8,
num_attention_heads=32,
num_key_value_heads=2,
head_dim=128,
rope_theta=500000.0,
rope_scaling={
"factor": 8.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
},
compress_type="weightedpool",
kernel_size=32,
kernel_stride=16,
block_size=64,
topk=8,
init_blocks=1,
local_blocks=2,
window_size=512,
)
inference_config = InferenceConfig(
max_batch_size=4,
max_length=8192,
max_new_tokens=128,
)
model = ToyNSALlama(config, inference_config).cuda().bfloat16()
```
## Testing
Some test scripts are available in the `test` folder and can be run directly for unit testing. For example:
```bash
python test/test_topk_sparse_attention.py
python test/test_nsa_module.py
python test/test_nsa_model.py
```
### Benchmarks
Here are the speed benchmarks conducted on a single NVIDIA A100 GPU or H100 GPU for the `topk_sparse_attention` function:
A100 GPU speed benchmarks:
```sh
** forward with block size 64 **:
N Flash Triton-Flash Triton-Top8 Triton-Top16
0 2048.0 0.414144 0.635648 0.633440 1.009184
1 4096.0 1.400304 2.267552 1.179808 1.916736
2 8192.0 5.223776 8.528160 2.266816 3.723168
3 16384.0 20.225697 32.745537 4.468128 7.359168
4 32768.0 79.587715 128.951065 8.517440 14.142848
5 65536.0 321.240479 511.652100 17.249599 30.991360
6 131072.0 1349.810425 2063.245605 36.400482 67.884544
** backward with block size 64 **:
N Flash Triton-Flash Triton-Top8 Triton-Top16
0 2048.0 1.315440 2.348560 1.941568 2.691040
1 4096.0 4.271584 8.553184 3.647744 5.032160
2 8192.0 15.323984 32.665440 5.650144 9.066112
3 16384.0 58.753281 127.675964 11.160832 17.113279
4 32768.0 227.770462 504.572693 21.723392 34.715614
5 65536.0 899.181274 2059.718506 44.517181 76.309441
6 131072.0 3587.918701 8530.726562 105.344734 182.970169
```
H100 GPU benchmarks:
```sh
** forward with block size 64 **:
N Flash Triton-Flash Triton-Top8 Triton-Top16
0 2048.0 0.259552 0.293888 0.584544 0.917664
1 4096.0 0.846848 1.029904 1.094976 1.745136
2 8192.0 3.043744 3.843392 2.128256 3.396880
3 16384.0 11.743568 14.791360 4.190528 6.704192
4 32768.0 45.968513 57.532478 7.614496 12.417440
5 65536.0 187.234375 228.093948 14.840048 24.511856
6 131072.0 810.890381 914.693970 29.470400 48.990192
** backward with block size 64 **:
N Flash Triton-Flash Triton-Top8 Triton-Top16
0 2048.0 0.798976 1.096096 1.117312 1.380016
1 4096.0 2.545680 3.826336 1.669760 2.214880
2 8192.0 9.029760 14.411633 2.772096 3.947456
3 16384.0 34.144016 58.945698 5.201344 7.538912
4 32768.0 135.718369 233.369247 9.968864 15.154192
5 65536.0 541.053894 929.337646 21.089870 33.818878
6 131072.0 2139.974854 3785.540527 54.918144 93.750717
```
Below are the speed benchmark results for the `compressed_attention` function, measured on a single NVIDIA A100 GPU or H100 GPU:
A100 GPU speed benchmarks:
```sh
** forward with kernel 32 and stride 16 **:
N Flash Triton-Flash Compressed Compressed-wo-Score
0 2048.0 0.413664 0.635488 0.655024 0.170816
1 4096.0 1.396416 2.247648 1.132304 0.377152
2 8192.0 5.234656 8.526400 2.879200 0.977952
3 16384.0 19.988865 32.755199 9.426448 2.943024
4 32768.0 79.419907 128.955170 30.284096 9.901120
5 65536.0 321.590210 511.615509 112.260544 36.001602
6 131072.0 1346.996338 2069.837891 423.099518 136.820038
** backward with kernel 32 and stride 16 **:
N Flash Triton-Flash Compressed
0 2048.0 1.322560 2.352000 0.486784
1 4096.0 4.270832 8.552608 0.971392
2 8192.0 15.515680 32.671329 2.603744
3 16384.0 59.345055 128.377472 8.499456
4 32768.0 230.626144 506.581238 30.064833
5 65536.0 919.260498 2068.642578 113.466560
6 131072.0 3646.603760 8498.374023 439.623444
```
H100 GPU speed benchmarks:
```sh
** forward with kernel 32 and stride 16 **:
N Flash Triton-Flash Compressed Compressed-wo-Score
0 2048.0 0.259488 0.297152 0.485920 0.103232
1 4096.0 0.847376 1.030400 0.710208 0.217760
2 8192.0 3.044016 3.875840 1.607360 0.516016
3 16384.0 11.823104 14.829360 4.970272 1.440288
4 32768.0 46.204750 57.527809 15.004992 4.584736
5 65536.0 187.324249 227.909958 53.009087 16.134224
6 131072.0 810.707214 910.106873 191.245728 60.154270
** backward with kernel 32 and stride 16 **:
N Flash Triton-Flash Compressed
0 2048.0 0.797728 1.090640 0.283104
1 4096.0 2.547088 3.834592 0.550464
2 8192.0 9.021520 14.421088 1.249184
3 16384.0 34.159508 58.793377 3.743440
4 32768.0 136.844070 233.447708 12.640032
5 65536.0 537.559814 929.360229 46.054817
6 131072.0 2135.629883 3782.351562 175.587296
```
All the speed benchmarks above were tested with 64 query heads, 4 key/value heads, and a head dimension of 128.
## Contributing
Contributions are welcome! Please open an issue to discuss major changes.
## Contact
For any questions or feedback, please feel free to contact laixunhao@pku.edu.cn.
## Citations
```bibtex
@inproceedings{Yuan2025NativeSA,
title = {Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention},
author = {Jingyang Yuan and Huazuo Gao and Damai Dai and Junyu Luo and Liang Zhao and Zhengyan Zhang and Zhenda Xie and Y. X. Wei and Lean Wang and Zhiping Xiao and Yuqing Wang and Chong Ruan and Ming Zhang and Wenfeng Liang and Wangding Zeng},
year = {2025},
url = {https://api.semanticscholar.org/CorpusID:276408911}
}
```
================================================
FILE: install_dependency.sh
================================================
# Install the pinned dependencies one at a time from the official PyPI index.
# The installation order is kept exactly as listed below.
PYPI_INDEX="https://pypi.org/simple"

for requirement in \
    packaging \
    numpy==1.26.4 \
    torch==2.4.0 \
    triton==3.0.0 \
    transformers==4.44.0 \
    flash_attn==2.6.3 \
    matplotlib==3.9.4 \
    pandas==2.2.3
do
    pip3 install "$requirement" -i "$PYPI_INDEX"
done
================================================
FILE: native_sparse_attention/__init__.py
================================================
================================================
FILE: native_sparse_attention/infer/__init__.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from native_sparse_attention.infer.nsa_inference import nsa_infer
# Names exported by `from native_sparse_attention.infer import *`:
# only the `nsa_infer` function imported above.
__all__ = [
"nsa_infer",
]
================================================
FILE: native_sparse_attention/infer/inference_func.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Tuple, Callable, Optional
from flash_attn import flash_attn_varlen_func
from native_sparse_attention.ops import (
flash_attention_decode,
compressed_attention,
compressed_attention_decode,
topk_sparse_attention,
topk_sparse_attention_decode,
)
from native_sparse_attention.ops.triton.utils import get_compressed_seqlens
def compress_infer(
    cu_seqlens: torch.Tensor,
    step: int,
    key: torch.Tensor,
    value: torch.Tensor,
    cache,
    weight: Tuple[torch.Tensor, torch.Tensor],
    compress_func: Tuple[Callable, Callable],
    intra_block_pe: Optional[torch.Tensor],
    kernel_size: int,
    kernel_stride: int,
):
    """Compress key/value for either the prefill step or a decode step.

    Args:
        cu_seqlens: cumulative sequence lengths; ``cu_seqlens.shape[0] - 1``
            is used as the batch size.
        step: ``0`` selects the prefill path, any other value the decode path.
        key: key tensor, only read on the prefill path.
        value: value tensor, only read on the prefill path.
        cache: object exposing ``before_compress_kv_cache``, ``num_kv_heads``,
            ``head_dim`` and ``compress_kv_len``; only read on the decode path.
        weight: ``(key_weight, value_weight)`` forwarded to the compress
            functions.
        compress_func: ``(key_fn, value_fn)``; each must return a 2-tuple whose
            first item is the compressed tensor.
        intra_block_pe: forwarded to the key compress function only.
        kernel_size: compression kernel size.
        kernel_stride: compression kernel stride.

    Returns:
        ``(compressed_key, compressed_value, compress_cu_seqlens)``.
    """
    compress_key_fn = compress_func[0]
    compress_value_fn = compress_func[1]
    key_weight = weight[0]
    value_weight = weight[1]

    # Prefill: compress the full inputs; the key compress function supplies
    # the compressed cumulative lengths.
    if step == 0:
        key, compress_cu_seqlens = compress_key_fn(
            key, key_weight, cu_seqlens, kernel_size, kernel_stride, intra_block_pe
        )
        value, _ = compress_value_fn(
            value, value_weight, cu_seqlens, kernel_size, kernel_stride
        )
        return key, value, compress_cu_seqlens

    # Decode: compress the per-sequence staging buffer held by the cache,
    # flattened so that every sequence contributes exactly `kernel_size` rows.
    num_seqs = cu_seqlens.shape[0] - 1
    window_cu_seqlens = (
        torch.arange(num_seqs + 1, dtype=torch.int32).to(cu_seqlens.device)
        * kernel_size
    )
    flat_shape = (num_seqs * kernel_size, cache.num_kv_heads, cache.head_dim)
    pending_key = cache.before_compress_kv_cache[0, :num_seqs].view(flat_shape)
    key, _ = compress_key_fn(
        pending_key,
        key_weight,
        window_cu_seqlens,
        kernel_size,
        kernel_stride,
        intra_block_pe,
    )
    pending_value = cache.before_compress_kv_cache[1, :num_seqs].view(flat_shape)
    value, _ = compress_value_fn(
        pending_value,
        value_weight,
        window_cu_seqlens,
        kernel_size,
        kernel_stride,
    )

    # Report the actual compressed cumulative lengths before this token,
    # built from the lengths tracked by the cache.
    compress_cu_seqlens = torch.zeros(
        num_seqs + 1, dtype=torch.int32, device=key.device
    )
    compress_cu_seqlens[1:] = torch.cumsum(cache.compress_kv_len[:num_seqs], dim=0)
    return key, value, compress_cu_seqlens
def compressed_attention_infer(
    cu_seqlens,
    step,
    query,
    key,
    value,
    cache,
    kernel_size,
    kernel_stride,
    topk,
    block_size,
    init_blocks,
    local_blocks,
):
    """Run attention over compressed key/value for prefill or decode.

    Args:
        cu_seqlens: cumulative sequence lengths; ``cu_seqlens.shape[0] - 1``
            is used as the batch size on the decode path.
        step: ``0`` selects the prefill path, any other value the decode path.
        query: query tensor passed through to the attention kernel.
        key: compressed key, only read on the prefill path.
        value: compressed value, only read on the prefill path.
        cache: object exposing ``compress_kv_cache`` and ``compress_kv_len``;
            only read on the decode path.
        kernel_size: compression kernel size.
        kernel_stride: compression kernel stride.
        topk: forwarded to the attention kernel.
        block_size: forwarded to the attention kernel.
        init_blocks: forwarded to the attention kernel.
        local_blocks: forwarded to the attention kernel.

    Returns:
        ``(attn_output, topk_idx)`` as produced by the underlying kernel.
    """
    if step != 0:
        # Decode: read compressed key/value from the cache, trimmed to the
        # longest compressed length among the active sequences.
        num_seqs = cu_seqlens.shape[0] - 1
        total_lens = cu_seqlens[1:] - cu_seqlens[:-1] + step
        compressed_lens = cache.compress_kv_len[:num_seqs]
        longest_compressed = compressed_lens.max()
        cached_key = cache.compress_kv_cache[0, :num_seqs, :longest_compressed]
        cached_value = cache.compress_kv_cache[1, :num_seqs, :longest_compressed]
        attn_output, topk_idx = compressed_attention_decode(
            query,
            cached_key,
            cached_value,
            total_lens,
            compressed_lens,
            kernel_size,
            kernel_stride,
            block_size,
            topk,
            init_blocks,
            local_blocks,
        )
        return attn_output, topk_idx

    # Prefill: derive the compressed sequence lengths from cu_seqlens and
    # run the varlen kernel on the supplied compressed key/value.
    prompt_lens = cu_seqlens[1:] - cu_seqlens[:-1]
    compress_seqlens, compress_cu_seqlens = get_compressed_seqlens(
        cu_seqlens, kernel_size, kernel_stride
    )
    max_query_len = prompt_lens.max().item()
    max_compressed_len = compress_seqlens.max().item()
    attn_output, topk_idx = compressed_attention(
        query,
        key,
        value,
        kernel_size,
        kernel_stride,
        block_size,
        topk,
        cu_seqlens,
        compress_cu_seqlens,
        max_query_len,
        max_compressed_len,
        None,  # NOTE(review): presumably the softmax scale left at its default — confirm
        init_blocks,
        local_blocks,
    )
    return attn_output, topk_idx
def topk_sparse_attention_infer(
    cu_seqlens,
    step,
    query,
    key,
    value,
    cache,
    topk_idx,
    block_size,
):
    """Run top-k sparse attention for prefill or decode.

    Args:
        cu_seqlens: cumulative sequence lengths; ``cu_seqlens.shape[0] - 1``
            is used as the batch size on the decode path.
        step: ``0`` selects the prefill path, any other value the decode path.
        query: query tensor passed through to the attention kernel.
        key: key tensor, only read on the prefill path.
        value: value tensor, only read on the prefill path.
        cache: object exposing ``sparse_kv_cache`` and ``sparse_kv_len``;
            only read on the decode path.
        topk_idx: selected block indices forwarded to the kernel.
        block_size: forwarded to the kernel.

    Returns:
        The attention output of the underlying kernel.
    """
    if step == 0:
        # Prefill operates directly on the supplied key/value.
        return topk_sparse_attention(
            query, key, value, topk_idx, block_size, cu_seqlens
        )

    # Decode reads key/value and their lengths from the cache, restricted to
    # the active sequences.
    num_seqs = cu_seqlens.shape[0] - 1
    cached_key = cache.sparse_kv_cache[0, :num_seqs]
    cached_value = cache.sparse_kv_cache[1, :num_seqs]
    cached_lens = cache.sparse_kv_len[:num_seqs]
    return topk_sparse_attention_decode(
        query,
        cached_key,
        cached_value,
        topk_idx,
        block_size,
        cached_lens,
    )
def sliding_window_attention_infer(
    cu_seqlens, step, query, key, value, cache, window_size
):
    """Run sliding-window attention for prefill or decode.

    Args:
        cu_seqlens: cumulative sequence lengths; ``cu_seqlens.shape[0] - 1``
            is used as the batch size on the decode path.
        step: ``0`` selects the prefill path, any other value the decode path.
        query: query tensor passed through to the attention kernel.
        key: key tensor, only read on the prefill path.
        value: value tensor, only read on the prefill path.
        cache: object exposing ``sliding_kv_cache`` and ``sliding_kv_len``;
            only read on the decode path.
        window_size: window length; used as the left window on prefill and as
            the cap on the cached lengths on decode.

    Returns:
        The attention output of the underlying kernel.
    """
    if step != 0:
        # Decode: attend over the cached window, capping each cached length
        # at `window_size`.
        num_seqs = cu_seqlens.shape[0] - 1
        cached_key = cache.sliding_kv_cache[0, :num_seqs]
        cached_value = cache.sliding_kv_cache[1, :num_seqs]
        cached_lens = cache.sliding_kv_len
        window_cap = torch.zeros_like(cached_lens) + window_size
        visible_lens = torch.minimum(cached_lens, window_cap)[:num_seqs]
        return flash_attention_decode(
            query,
            cached_key,
            cached_value,
            visible_lens,
        )

    # Prefill: causal varlen attention; query and key share the same
    # cumulative lengths and the same maximum length.
    prompt_lens = cu_seqlens[1:] - cu_seqlens[:-1]
    max_query_len = prompt_lens.max().item()
    max_key_len = prompt_lens.max().item()
    return flash_attn_varlen_func(
        query,
        key,
        value,
        cu_seqlens,
        cu_seqlens,
        max_query_len,
        max_key_len,
        causal=True,
        window_size=(window_size, -1),
    )
================================================
FILE: native_sparse_attention/infer/nsa_inference.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Tuple, Callable, Optional
from native_sparse_attention.infer.inference_func import (
compress_infer,
compressed_attention_infer,
topk_sparse_attention_infer,
sliding_window_attention_infer,
)
def nsa_infer(
    cu_seqlens: torch.Tensor,
    step: int,
    # qkv for three parts
    query: torch.Tensor,
    key: torch.Tensor,  # prefill: [total_len, num_heads, head_dim], decode: [batch_size, num_heads, head_dim]
    value: torch.Tensor,
    gate_value: torch.Tensor,  # prefill: [total_len, num_heads, 3], decode: [batch_size, num_heads, 3]
    # rope and kv cache
    rope,
    cache,
    # weight for nsa compress
    compress_weight: Tuple[
        torch.Tensor, torch.Tensor
    ],  # compress weight for key and value
    compress_func: Tuple[Callable, Callable],  # compress function for key and value
    intra_block_pe: Optional[torch.Tensor],
    # nsa parameters
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    init_blocks: int,
    local_blocks: int,
    window_size: int,
) -> torch.Tensor:
    """Native sparse attention inference, covering both prefill and decode with a kv cache.

    The three branches (compressed attention, top-k sparse attention and
    sliding window attention) are evaluated in turn and mixed with the gates
    in ``gate_value``.

    Args:
        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        step (int): current inference step; step == 0 is prefill, step > 0 is a decode step.
        query (torch.Tensor): prefill shape [total_len, num_q_heads, head_dim]; decode shape [batch_size, num_q_heads, head_dim]
        key (torch.Tensor): prefill shape [total_len, num_kv_heads, head_dim]; decode shape [batch_size, num_kv_heads, head_dim]
        value (torch.Tensor): prefill shape [total_len, num_kv_heads, head_dim]; decode shape [batch_size, num_kv_heads, head_dim]
        gate_value (torch.Tensor): prefill shape [total_len, num_heads, 3]; decode shape [batch_size, num_heads, 3]
        rope (RotaryEmbedding): rope module, see native_sparse_attention.module.rope.RotaryEmbedding for details
        cache (NSACache): kv cache, see native_sparse_attention.module.kv_cache.NSACache for details
        compress_weight (Tuple[torch.Tensor, torch.Tensor]): compress weights for key and value respectively
        compress_func (Tuple[Callable, Callable]): compress functions for key and value respectively
        intra_block_pe (Optional[torch.Tensor]): intra-block positional embedding for compression, None to disable it
        kernel_size (int): kernel size of compression
        kernel_stride (int): kernel stride of compression
        block_size (int): block size of sparse attention
        topk (int): topk of sparse attention
        init_blocks (int): number of blocks at the beginning of the sequence that are always computed in sparse attention
        local_blocks (int): number of blocks in the local window of each query that are always computed in sparse attention
        window_size (int): window size for sliding window attention

    Returns:
        torch.Tensor: native sparse attention output, same shape as input query
    """
    is_prefill = step == 0

    # a new prefill starts from an empty kv cache
    if is_prefill:
        cache.reset()

    # stash the raw (pre-rope) keys/values needed by the compression step
    cache.prepare_compress(cu_seqlens, step, key, value)

    # keys and values are compressed before rope is applied
    compress_key, compress_value, compress_cu_seqlens = compress_infer(
        cu_seqlens,
        step,
        key,
        value,
        cache,
        compress_weight,
        compress_func,
        intra_block_pe,
        kernel_size,
        kernel_stride,
    )

    # rope: compressed keys use the real step in prefill and a fixed step of 1 in decode
    query = rope(query, cu_seqlens, step)
    compress_rope_step = step if is_prefill else 1
    compress_key = rope(
        compress_key,
        compress_cu_seqlens,
        compress_rope_step,
        stride=cache.kernel_stride,
    )
    key = rope(key, cu_seqlens, step)

    # write this step into the cache; the same key/value pair is passed twice
    cache.update_kv(
        cu_seqlens,
        step,
        compress_key,
        compress_value,
        key,
        value,
        key,
        value,
    )

    # branch 1: compressed attention, which also yields the selected block indices
    compress_attn_output, topk_idx = compressed_attention_infer(
        cu_seqlens,
        step,
        query,
        compress_key,
        compress_value,
        cache,
        kernel_size,
        kernel_stride,
        topk,
        block_size,
        init_blocks,
        local_blocks,
    )

    # branch 2: topk sparse attention over the selected blocks
    sparse_attn_output = topk_sparse_attention_infer(
        cu_seqlens,
        step,
        query,
        key,
        value,
        cache,
        topk_idx,
        block_size,
    )

    # branch 3: sliding window attention
    sliding_attn_output = sliding_window_attention_infer(
        cu_seqlens, step, query, key, value, cache, window_size
    )

    # gate each branch (one gate per head) and sum them in branch order
    compress_gate = gate_value[..., 0, None]
    sparse_gate = gate_value[..., 1, None]
    sliding_gate = gate_value[..., 2, None]
    return (
        compress_gate * compress_attn_output
        + sparse_gate * sparse_attn_output
        + sliding_gate * sliding_attn_output
    )
================================================
FILE: native_sparse_attention/model/README.md
================================================
# Guide for the ToyNSALlama Model
The `ToyNSALlama` model is a custom implementation of a Llama-like transformer architecture featuring a Native Sparse Attention (NSA) module. This guide explains how to integrate the NSA module into your own model.
## Overview
The `ToyNSALlama` model consists of:
- **Configuration**: Defined by `ToyNSALlamaConfig` (model structure parameters) and `InferenceConfig` (inference-specific parameters).
- **Components**: An embedding layer, multiple NativeSparseAttention modules, Feed-Forward Network (FFN) modules, normalization layers, and a language model head.
## Step-by-Step Instructions
### 1. Import Necessary Modules
```python
import torch
import torch.nn as nn
from native_sparse_attention.model import ToyNSALlama, ToyNSALlamaConfig, InferenceConfig
```
### 2. Define Configurations
Create instances of `ToyNSALlamaConfig` and `InferenceConfig` to set model and inference parameters.
#### Model Configuration
The model configuration aligns with the Transformers Llama model configuration. Adjust the following parameters to control the NSA module’s sparsity:
- `compress_type`: Compression method for keys/values. Supported options: `avgpool`, `weightedpool`, `linear`.
- `kernel_size` & `kernel_stride`: `kernel_size` determines how many tokens are compressed into one; `kernel_stride` sets the stride between consecutive compression windows (`kernel_size` must be divisible by `kernel_stride`).
- `block_size`: Block size for sparse attention (recommended: 64 or 128).
- `topk`, `init_blocks`, `local_blocks`: `topk` specifies the number of blocks selected in sparse attention; `init_blocks` and `local_blocks` define the number of initial and local blocks that must be selected.
- `window_size`: Size of the sliding window for attention.
Example:
```python
config = ToyNSALlamaConfig(
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=8,
num_attention_heads=32,
num_key_value_heads=2,
head_dim=128,
vocab_size=128288,
max_position_embeddings=131072,
rope_theta=500000.0,
rope_scaling={
"factor": 8.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
},
compress_type="weightedpool",
kernel_size=32,
kernel_stride=16,
block_size=64,
topk=8,
init_blocks=1,
local_blocks=2,
window_size=512,
)
```
#### Inference Configuration
This configuration applies during inference; the Key-Value (KV) cache is initialized from these settings. The full KV cache size is `max_batch_size × max_length × num_kv_heads × head_dim × num_layers × 2 × 2` bytes (one factor of 2 for keys and values, the other for the 2 bytes of each `bfloat16` element). Currently, only greedy decoding is supported as an example.
Example:
```python
inference_config = InferenceConfig(
max_batch_size=4,
max_length=8192,
max_new_tokens=128,
)
```
### 3. Initialize the Model
Instantiate the model and move it to the GPU with the appropriate data type (currently, only `bfloat16` is supported).
```python
model = ToyNSALlama(config, inference_config).cuda().to(torch.bfloat16)
```
### 4. Forward & Generate
The model supports two methods:
- **`forward`**: Accepts `input_ids` and `cu_seqlens`, returning final logits after the language model head. Use this for training or evaluation.
- **`generate`**: Accepts `input_ids` and `cu_seqlens`, generating output tokens via greedy sampling. This demonstrates KV cache usage for token generation (pre-filling and decoding).
Example:
```python
# Example input
batch_size = 4
seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
input_ids = torch.randint(0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda")
print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")
# Example forward
logits = model(input_ids, cu_seqlens)
print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")
# Example generate
output_tokens = model.generate(input_ids, cu_seqlens)
print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
```
## Toy Llama Model with Self-Attention
A simpler toy model with the Llama structure is available in `native_sparse_attention/model/toy_llama.py`. Compare `ToyLlama` and `ToyNSALlama` to see how to adapt a self-attention model into an NSA-based model.
The primary difference lies in replacing the `SelfAttention` module with the `NativeSparseAttention` module, along with updates to the KV cache and inference function. These changes are straightforward and easy to implement.
================================================
FILE: native_sparse_attention/model/__init__.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from native_sparse_attention.model.toy_llama import (
ToyLlamaConfig,
InferenceConfig,
ToyLlama,
)
from native_sparse_attention.model.toy_nsa_llama import (
ToyNSALlamaConfig,
InferenceConfig,
ToyNSALlama,
)
# Names exported by `from native_sparse_attention.model import *`.
# NOTE: `InferenceConfig` is imported from both toy_llama and toy_nsa_llama
# in this module; the later import (toy_nsa_llama) is the binding exported here.
__all__ = [
    "ToyLlamaConfig",
    "ToyNSALlamaConfig",
    "InferenceConfig",
    "ToyLlama",
    "ToyNSALlama",
]
================================================
FILE: native_sparse_attention/model/toy_llama.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from native_sparse_attention.module import SelfAttention, RopeConfig, KVCache
def _default_rope_scaling() -> dict:
    """Return a fresh copy of the default llama3-style rope scaling settings."""
    return {
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    }


@dataclass
class ToyLlamaConfig:
    """Structural hyper-parameters of the ToyLlama model."""

    # --- embedding ---
    vocab_size: int = 128288
    max_position_embeddings: int = 131072

    # --- transformer body ---
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 2
    head_dim: int = 128

    # --- rotary position embedding ---
    rope_theta: float = 500000.0
    # built per instance so configs never share one mutable dict
    rope_scaling: dict = field(default_factory=_default_rope_scaling)
@dataclass
class InferenceConfig:
    """Inference-time limits used to size the KV cache and bound generation."""

    # number of batch slots allocated in each layer's KV cache
    max_batch_size: int = 32
    # number of token positions allocated per sequence in each layer's KV cache
    max_length: int = 8192
    # generation length used by `generate` when its max_new_tokens argument is <= 0
    max_new_tokens: int = 128
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable scale."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # per-channel scale, initialised to one
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # added to the mean square before the reciprocal square root
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor):
        """Normalize ``hidden_states`` in float32, cast back, then apply the scale."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (upcast * inv_rms).to(original_dtype)
        return self.weight * normalized
class FFN(nn.Module):
    """Gated feed-forward block: down_proj(SiLU(gate_proj(x)) * up_proj(x))."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # all three projections are bias-free
        self.gate_proj = nn.Linear(
            in_features=hidden_size, out_features=intermediate_size, bias=False
        )
        self.up_proj = nn.Linear(
            in_features=hidden_size, out_features=intermediate_size, bias=False
        )
        self.down_proj = nn.Linear(
            in_features=intermediate_size, out_features=hidden_size, bias=False
        )
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Apply the gated projection and map back to ``hidden_size``."""
        gate = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gate * lifted)
class ToyLlamaLayer(nn.Module):
    """One pre-norm transformer layer: self-attention then FFN, each with a residual add."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_q_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rope_config: RopeConfig,
    ):
        super().__init__()
        # keep the construction arguments around as plain attributes
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.rope_config = rope_config

        # attention sub-block
        self.attn_norm = RMSNorm(hidden_size)
        self.self_attn = SelfAttention(
            hidden_size=hidden_size,
            num_q_heads=num_q_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            rope_config=rope_config,
        )

        # feed-forward sub-block
        self.ffn_norm = RMSNorm(hidden_size)
        self.ffn = FFN(hidden_size=hidden_size, intermediate_size=intermediate_size)

    def forward(self, x, cu_seqlens):
        """Apply the layer without a kv cache."""
        attn_out = self.self_attn(self.attn_norm(x), cu_seqlens)
        x = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out

    @torch.no_grad()
    def inference(self, x, cu_seqlens, step, kv_cache):
        """Apply the layer for inference step ``step`` using ``kv_cache``."""
        attn_out = self.self_attn.inference(
            self.attn_norm(x), cu_seqlens, step, kv_cache
        )
        x = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out
class ToyLlama(nn.Module):
    """Toy Llama-style language model built from ``ToyLlamaLayer`` blocks.

    Args:
        config: structural hyper-parameters of the model.
        inference_config: KV cache and generation limits. It is only needed
            by ``inference`` and ``generate``; ``forward`` works without it.
    """

    def __init__(
        self, config: ToyLlamaConfig, inference_config: Optional[InferenceConfig] = None
    ):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
        # NOTE(review): built without max_position_embeddings, unlike the
        # per-layer configs below, and not read anywhere else in this class.
        self.rope_config = RopeConfig(
            head_dim=self.config.head_dim,
            rope_theta=self.config.rope_theta,
            rope_scaling=self.config.rope_scaling,
        )
        self.layers = nn.ModuleList(
            [
                ToyLlamaLayer(
                    hidden_size=self.config.hidden_size,
                    intermediate_size=self.config.intermediate_size,
                    num_q_heads=self.config.num_attention_heads,
                    num_kv_heads=self.config.num_key_value_heads,
                    head_dim=self.config.head_dim,
                    rope_config=RopeConfig(
                        self.config.max_position_embeddings,
                        self.config.head_dim,
                        self.config.rope_theta,
                        self.config.rope_scaling,
                    ),
                )
                for _ in range(self.config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(self.config.hidden_size)
        self.lm_head = nn.Linear(
            self.config.hidden_size, self.config.vocab_size, bias=False
        )
        # inference config and kv cache (the cache is allocated lazily)
        self.inference_config = inference_config
        self.kv_cache = None

    def _require_inference_config(self) -> InferenceConfig:
        """Return the inference config, or raise a clear error if it is missing."""
        if self.inference_config is None:
            raise ValueError(
                "ToyLlama was built without an inference_config; "
                "pass one to the constructor to use inference() or generate()"
            )
        return self.inference_config

    def forward(
        self,
        input_ids: torch.LongTensor,  # shape: [total_length, ]
        cu_seqlens: torch.LongTensor,  # shape: [batch_size + 1, ]
    ):
        """Compute logits for every input token; returns [total_len, vocab_size]."""
        # embedding
        x = self.embedding(input_ids).to(torch.bfloat16)
        # layers
        for layer in self.layers:
            x = layer(x, cu_seqlens)
        # final norm
        x = self.norm(x)
        # language model head
        x = self.lm_head(x).to(torch.float32)  # [total_len, vocab_size]
        return x

    @torch.no_grad()
    def inference(
        self,
        input_ids: torch.LongTensor,  # prefill shape: [total_length, ]; decode shape: [batch_size, ]
        cu_seqlens: torch.LongTensor,  # shape: [batch_size + 1, ]
        step: int,
    ):
        """Run one inference step and return next-token logits.

        ``step == 0`` is the prefill step, after which only the last token of
        each sequence is kept; later steps are decode steps.

        Returns:
            Float32 logits of shape [batch_size, vocab_size].

        Raises:
            ValueError: if the model was built without an ``inference_config``.
        """
        # allocate one kv cache per layer on first use
        if self.kv_cache is None:
            inference_config = self._require_inference_config()
            self.kv_cache = [
                KVCache(
                    max_batch_size=inference_config.max_batch_size,
                    max_length=inference_config.max_length,
                    num_kv_heads=self.config.num_key_value_heads,
                    head_dim=self.config.head_dim,
                    dtype=torch.bfloat16,
                    # follow the layer's own device; a bare "cuda" means the
                    # current device, which may not be where the model lives
                    device=next(layer.parameters()).device,
                )
                for layer in self.layers
            ]
        # embedding
        x = self.embedding(input_ids).to(torch.bfloat16)
        # layers
        for i, layer in enumerate(self.layers):
            x = layer.inference(x, cu_seqlens, step, self.kv_cache[i])
        # final norm
        x = self.norm(x)
        # language model head; prefill keeps only the last token of each sequence
        if step == 0:
            x = x[cu_seqlens[1:] - 1, :]
        x = self.lm_head(x).to(torch.float32)  # [batch_size, vocab_size]
        return x

    def generate(
        self,
        input_ids: torch.LongTensor,
        cu_seqlens: torch.LongTensor,
        max_new_tokens: int = -1,
    ):
        """Greedily generate tokens after the given prompts.

        Args:
            input_ids: concatenated prompt tokens, shape [total_length, ].
            cu_seqlens: cumulative prompt lengths, shape [batch_size + 1, ].
            max_new_tokens: number of tokens to generate; values <= 0 fall
                back to ``inference_config.max_new_tokens``.

        Returns:
            Generated token ids of shape [batch_size, max_new_tokens].

        Raises:
            ValueError: if the model was built without an ``inference_config``.
        """
        output_tokens = []
        if max_new_tokens <= 0:
            max_new_tokens = self._require_inference_config().max_new_tokens
        for step in range(max_new_tokens):
            logits = self.inference(
                input_ids, cu_seqlens, step
            )  # shape: [batch_size, vocab_size]
            next_token = torch.argmax(logits, dim=-1)  # shape: [batch_size, ]
            # the token just produced is the only input of the next decode step
            input_ids = next_token
            output_tokens.append(next_token)
        output_tokens = torch.stack(
            output_tokens, dim=1
        )  # shape: [batch_size, max_new_tokens]
        return output_tokens
if __name__ == "__main__":
    # Demo: build a small ToyLlama, run one forward pass, then greedy generation.
    torch.manual_seed(42)
    # initialize model
    config = ToyLlamaConfig(
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=8,
        num_attention_heads=32,
        num_key_value_heads=2,
        head_dim=128,
        rope_theta=500000.0,
        rope_scaling={
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        },
    )
    inference_config = InferenceConfig(
        max_batch_size=4,
        max_length=8192,
        max_new_tokens=128,
    )
    model = ToyLlama(config, inference_config).cuda().bfloat16()
    print(f"\nMODEL CONFIG:\n{config}\n")
    print(f"\nINFERENCE CONFIG:\n{inference_config}\n")
    print(f"\nMODEL:\n{model}\n")

    # example input
    batch_size = 4
    # Sample lengths from [1, 4096): an empty sequence has no last token for
    # `inference` to take the next-token logits from.
    seqlens = torch.randint(1, 4096, (batch_size,), dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    total_len = int(cu_seqlens[-1])
    input_ids = torch.randint(
        0, config.vocab_size, (total_len,), dtype=torch.int64, device="cuda"
    )
    print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")

    # example output
    logits = model(input_ids, cu_seqlens)
    print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")

    # example generate
    output_tokens = model.generate(input_ids, cu_seqlens, 64)
    print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
================================================
FILE: native_sparse_attention/model/toy_nsa_llama.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from native_sparse_attention.module import NativeSparseAttention, RopeConfig, NSACache
def _default_rope_scaling() -> dict:
    """Return a fresh copy of the default llama3-style rope scaling settings."""
    return {
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    }


@dataclass
class ToyNSALlamaConfig:
    """Structural hyper-parameters of the ToyNSALlama model."""

    # --- embedding ---
    vocab_size: int = 128288
    max_position_embeddings: int = 131072

    # --- transformer body ---
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 2
    head_dim: int = 128

    # --- rotary position embedding ---
    rope_theta: float = 500000.0
    # built per instance so configs never share one mutable dict
    rope_scaling: dict = field(default_factory=_default_rope_scaling)

    # --- native sparse attention ---
    compress_type: str = "weightedpool"
    kernel_size: int = 32
    kernel_stride: int = 16
    block_size: int = 64
    topk: int = 16
    init_blocks: int = 1
    local_blocks: int = 2
    window_size: int = 512
@dataclass
class InferenceConfig:
    """Inference-time limits used to size the KV cache and bound generation."""

    # number of batch slots allocated in each layer's KV cache
    max_batch_size: int = 32
    # max sequence length passed to each layer's KV cache
    max_length: int = 8192
    # generation length used by `generate` when its max_new_tokens argument is <= 0
    max_new_tokens: int = 128
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable scale."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        # per-channel scale, initialised to one
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # added to the mean square before the reciprocal square root
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor):
        """Normalize ``hidden_states`` in float32, cast back, then apply the scale."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (upcast * inv_rms).to(original_dtype)
        return self.weight * normalized
class FFN(nn.Module):
    """Gated feed-forward block: down_proj(SiLU(gate_proj(x)) * up_proj(x))."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # all three projections are bias-free
        self.gate_proj = nn.Linear(
            in_features=hidden_size, out_features=intermediate_size, bias=False
        )
        self.up_proj = nn.Linear(
            in_features=hidden_size, out_features=intermediate_size, bias=False
        )
        self.down_proj = nn.Linear(
            in_features=intermediate_size, out_features=hidden_size, bias=False
        )
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Apply the gated projection and map back to ``hidden_size``."""
        gate = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gate * lifted)
class ToyNSALlamaLayer(nn.Module):
    """One pre-norm transformer layer: native sparse attention then FFN, each with a residual add."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_q_heads: int,
        num_kv_heads: int,
        head_dim: int,
        compress_type: str,
        kernel_size: int,
        kernel_stride: int,
        block_size: int,
        topk: int,
        init_blocks: int,
        local_blocks: int,
        window_size: int,
        rope_config: RopeConfig,
    ):
        super().__init__()
        # keep the construction arguments around as plain attributes
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.compress_type = compress_type
        self.kernel_size = kernel_size
        self.kernel_stride = kernel_stride
        self.block_size = block_size
        self.topk = topk
        self.init_blocks = init_blocks
        self.local_blocks = local_blocks
        self.window_size = window_size
        self.rope_config = rope_config

        # attention sub-block
        self.attn_norm = RMSNorm(hidden_size)
        self.nsa = NativeSparseAttention(
            hidden_size=hidden_size,
            num_q_heads=num_q_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            compress_type=compress_type,
            kernel_size=kernel_size,
            kernel_stride=kernel_stride,
            block_size=block_size,
            topk=topk,
            init_blocks=init_blocks,
            local_blocks=local_blocks,
            window_size=window_size,
            rope_config=rope_config,
        )

        # feed-forward sub-block
        self.ffn_norm = RMSNorm(hidden_size)
        self.ffn = FFN(hidden_size=hidden_size, intermediate_size=intermediate_size)

    def forward(self, x, cu_seqlens):
        """Apply the layer without a kv cache."""
        attn_out = self.nsa(self.attn_norm(x), cu_seqlens)
        x = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out

    @torch.no_grad()
    def inference(self, x, cu_seqlens, step, kv_cache):
        """Apply the layer for inference step ``step`` using ``kv_cache``."""
        attn_out = self.nsa.inference(self.attn_norm(x), cu_seqlens, step, kv_cache)
        x = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out
class ToyNSALlama(nn.Module):
    """Toy Llama-style language model built from ``ToyNSALlamaLayer`` blocks.

    Args:
        config: structural and native-sparse-attention hyper-parameters.
        inference_config: KV cache and generation limits. It is only needed
            by ``inference`` and ``generate``; ``forward`` works without it.
    """

    def __init__(
        self,
        config: ToyNSALlamaConfig,
        inference_config: Optional[InferenceConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
        # NOTE(review): built without max_position_embeddings, unlike the
        # per-layer configs below, and not read anywhere else in this class.
        self.rope_config = RopeConfig(
            head_dim=self.config.head_dim,
            rope_theta=self.config.rope_theta,
            rope_scaling=self.config.rope_scaling,
        )
        self.layers = nn.ModuleList(
            [
                ToyNSALlamaLayer(
                    hidden_size=self.config.hidden_size,
                    intermediate_size=self.config.intermediate_size,
                    num_q_heads=self.config.num_attention_heads,
                    num_kv_heads=self.config.num_key_value_heads,
                    head_dim=self.config.head_dim,
                    compress_type=self.config.compress_type,
                    kernel_size=self.config.kernel_size,
                    kernel_stride=self.config.kernel_stride,
                    block_size=self.config.block_size,
                    topk=self.config.topk,
                    init_blocks=self.config.init_blocks,
                    local_blocks=self.config.local_blocks,
                    window_size=self.config.window_size,
                    rope_config=RopeConfig(
                        self.config.max_position_embeddings,
                        self.config.head_dim,
                        self.config.rope_theta,
                        self.config.rope_scaling,
                    ),
                )
                for _ in range(self.config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(self.config.hidden_size)
        self.lm_head = nn.Linear(
            self.config.hidden_size, self.config.vocab_size, bias=False
        )
        # inference config and kv cache (the cache is allocated lazily)
        self.inference_config = inference_config
        self.kv_cache = None

    def _require_inference_config(self) -> InferenceConfig:
        """Return the inference config, or raise a clear error if it is missing."""
        if self.inference_config is None:
            raise ValueError(
                "ToyNSALlama was built without an inference_config; "
                "pass one to the constructor to use inference() or generate()"
            )
        return self.inference_config

    def forward(
        self,
        input_ids: torch.LongTensor,  # shape: [total_length, ]
        cu_seqlens: torch.LongTensor,  # shape: [batch_size + 1, ]
    ):
        """Compute logits for every input token; returns [total_len, vocab_size]."""
        # embedding
        x = self.embedding(input_ids).to(torch.bfloat16)
        # layers
        for layer in self.layers:
            x = layer(x, cu_seqlens)
        # final norm
        x = self.norm(x)
        # language model head
        x = self.lm_head(x).to(torch.float32)  # [total_len, vocab_size]
        return x

    @torch.no_grad()
    def inference(
        self,
        input_ids: torch.LongTensor,  # prefill shape: [total_length, ]; decode shape: [batch_size, ]
        cu_seqlens: torch.LongTensor,  # shape: [batch_size + 1, ]
        step: int,
    ):
        """Run one inference step and return next-token logits.

        ``step == 0`` is the prefill step, after which only the last token of
        each sequence is kept; later steps are decode steps.

        Returns:
            Float32 logits of shape [batch_size, vocab_size].

        Raises:
            ValueError: if the model was built without an ``inference_config``.
        """
        # allocate one NSA cache per layer on first use
        if self.kv_cache is None:
            inference_config = self._require_inference_config()
            self.kv_cache = [
                NSACache(
                    max_batch_size=inference_config.max_batch_size,
                    max_length=inference_config.max_length,
                    num_kv_heads=self.config.num_key_value_heads,
                    head_dim=self.config.head_dim,
                    kernel_size=self.config.kernel_size,
                    kernel_stride=self.config.kernel_stride,
                    window_size=self.config.window_size,
                    dtype=torch.bfloat16,
                    # follow the layer's own device; a bare "cuda" means the
                    # current device, which may not be where the model lives
                    device=next(layer.parameters()).device,
                )
                for layer in self.layers
            ]
        # embedding
        x = self.embedding(input_ids).to(torch.bfloat16)
        # layers
        for i, layer in enumerate(self.layers):
            x = layer.inference(x, cu_seqlens, step, self.kv_cache[i])
        # final norm
        x = self.norm(x)
        # language model head; prefill keeps only the last token of each sequence
        if step == 0:
            x = x[cu_seqlens[1:] - 1, :]
        x = self.lm_head(x).to(torch.float32)  # [batch_size, vocab_size]
        return x

    def generate(
        self,
        input_ids: torch.LongTensor,
        cu_seqlens: torch.LongTensor,
        max_new_tokens: int = -1,
    ):
        """Greedily generate tokens after the given prompts.

        Args:
            input_ids: concatenated prompt tokens, shape [total_length, ].
            cu_seqlens: cumulative prompt lengths, shape [batch_size + 1, ].
            max_new_tokens: number of tokens to generate; values <= 0 fall
                back to ``inference_config.max_new_tokens``.

        Returns:
            Generated token ids of shape [batch_size, max_new_tokens].

        Raises:
            ValueError: if the model was built without an ``inference_config``.
        """
        output_tokens = []
        if max_new_tokens <= 0:
            max_new_tokens = self._require_inference_config().max_new_tokens
        for step in range(max_new_tokens):
            logits = self.inference(
                input_ids, cu_seqlens, step
            )  # shape: [batch_size, vocab_size]
            next_token = torch.argmax(logits, dim=-1)  # shape: [batch_size, ]
            # the token just produced is the only input of the next decode step
            input_ids = next_token
            output_tokens.append(next_token)
        output_tokens = torch.stack(
            output_tokens, dim=1
        )  # shape: [batch_size, max_new_tokens]
        return output_tokens
if __name__ == "__main__":
    # Demo: build a small ToyNSALlama, run one forward pass, then greedy generation.
    torch.manual_seed(42)
    # initialize model
    config = ToyNSALlamaConfig(
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=8,
        num_attention_heads=32,
        num_key_value_heads=2,
        head_dim=128,
        rope_theta=500000.0,
        rope_scaling={
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        },
        compress_type="weightedpool",
        kernel_size=32,
        kernel_stride=16,
        block_size=64,
        topk=8,
        init_blocks=1,
        local_blocks=2,
        window_size=512,
    )
    inference_config = InferenceConfig(
        max_batch_size=4,
        max_length=8192,
        max_new_tokens=128,
    )
    model = ToyNSALlama(config, inference_config).cuda().bfloat16()
    print(f"\nMODEL CONFIG:\n{config}\n")
    print(f"\nINFERENCE CONFIG:\n{inference_config}\n")
    print(f"\nMODEL:\n{model}\n")

    # example input
    batch_size = 4
    # Sample lengths from [1, 4096): an empty sequence has no last token for
    # `inference` to take the next-token logits from.
    seqlens = torch.randint(1, 4096, (batch_size,), dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    total_len = int(cu_seqlens[-1])
    input_ids = torch.randint(
        0, config.vocab_size, (total_len,), dtype=torch.int64, device="cuda"
    )
    print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")

    # example output
    logits = model(input_ids, cu_seqlens)
    print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")

    # example generate
    output_tokens = model.generate(input_ids, cu_seqlens, 64)
    print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
================================================
FILE: native_sparse_attention/module/__init__.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from native_sparse_attention.module.native_sparse_attention import NativeSparseAttention
from native_sparse_attention.module.self_attention import SelfAttention
from native_sparse_attention.module.rope import RotaryEmbedding, RopeConfig
from native_sparse_attention.module.kv_cache import NSACache, KVCache
# Names exported by `from native_sparse_attention.module import *`.
# Every public name imported by this package is listed, including KVCache.
__all__ = [
    "SelfAttention",
    "NativeSparseAttention",
    "RotaryEmbedding",
    "RopeConfig",
    "NSACache",
    "KVCache",
]
================================================
FILE: native_sparse_attention/module/kv_cache.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
import triton.language as tl
from typing import Union
from native_sparse_attention.ops.triton.utils import get_compressed_seqlens
class KVCache:
    """Preallocated key/value cache with a per-sequence length counter.

    Args:
        max_batch_size (int): number of batch slots to allocate
        max_length (int): number of token positions to allocate per slot
        num_kv_heads (int): number of key/value heads
        head_dim (int): head dim
        dtype (torch.dtype): data type of the cache tensor
        device (Union[str, torch.device]): device of the cache tensors

    Attributes:
        kv_cache: tensor of shape
            [2, max_batch_size, max_length, num_kv_heads, head_dim]; index 0
            of the first dimension holds keys and index 1 holds values.
        kv_len: int32 tensor of shape [max_batch_size] counting the entries
            written for each batch slot.
    """

    def __init__(
        self,
        max_batch_size: int,
        max_length: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype,
        device: Union[str, torch.device],
    ):
        self.max_batch_size = max_batch_size
        self.max_length = max_length
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.dtype = dtype
        self.device = device
        # keys and values live in one tensor, stacked along the first dimension
        cache_shape = (2, max_batch_size, max_length, num_kv_heads, head_dim)
        self.kv_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
        self.kv_len = torch.zeros(max_batch_size, dtype=torch.int32, device=device)

    def reset(self):
        """Zero the cache contents and the length counters in place."""
        for buffer in (self.kv_cache, self.kv_len):
            buffer.zero_()

    def update_kv(
        self,
        cu_seqlens: torch.Tensor,
        step: int,
        key: torch.Tensor,
        value: torch.Tensor,
    ):
        """Write ``key``/``value`` into the cache.

        ``step == 0`` takes the prefill path, any other step the decode path.
        """
        writer = self._update_kv_prefill if step == 0 else self._update_kv_decode
        writer(cu_seqlens, step, key, value)

    def _update_kv_prefill(
        self,
        cu_seqlens: torch.Tensor,
        step: int,
        key: torch.Tensor,
        value: torch.Tensor,
    ):
        """Store whole prompts; ``key``/``value`` are [total_len, num_kv_heads, head_dim]."""
        assert step == 0
        seq_start, seq_end = cu_seqlens[:-1], cu_seqlens[1:]
        seqlens = seq_end - seq_start
        batch_size = seqlens.shape[0]
        # shape check
        total_len, num_heads, head_dim = key.shape
        assert key.shape == value.shape
        assert num_heads == self.num_kv_heads and head_dim == self.head_dim
        assert total_len == cu_seqlens[-1].item()
        # hand the cache and the per-sequence bounds to the fill helper
        _fill_kv_cache(self.kv_cache, key, value, seq_start, seq_end)
        self.kv_len[:batch_size] = seqlens

    def _update_kv_decode(
        self,
        cu_seqlens: torch.Tensor,
        step: int,
        key: torch.Tensor,
        value: torch.Tensor,
    ):
        """Append one token per sequence; ``key``/``value`` are [batch_size, num_kv_heads, head_dim]."""
        assert step > 0
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        # shape check
        batch_size, num_heads, head_dim = key.shape
        assert batch_size == seqlens.shape[0]
        assert key.shape == value.shape
        assert num_heads == self.num_kv_heads and head_dim == self.head_dim
        # each sequence writes at its current length, then the length advances
        rows = torch.arange(batch_size, dtype=torch.int32, device=key.device)
        write_pos = self.kv_len[:batch_size]
        key_cache = self.kv_cache[0, :batch_size]
        value_cache = self.kv_cache[1, :batch_size]
        key_cache[rows, write_pos] = key
        value_cache[rows, write_pos] = value
        self.kv_len[:batch_size] += 1
class NSACache:
"""KV cache manager for native sparse attention.
Args:
max_batch_size (int): max batch size
max_length (int): max length, including prompt len and response len
num_kv_heads (int): number of key/value heads
head_dim (int): head dim
kernel_size (int): kernel size of compression
kernel_stride (int): kernel stride of compression
window_size (int): window size for sliding window attention
dtype (torch.dtype): data type for kv cache, should be same as model weight dtype
device (Union[str, torch.device]): default to 'cuda'
Methods:
reset: reset kv cache, should be called before prefilling
prepare_compress: store keys/values for compression, should be called before key/value compression at both prefilling and decoding
update_kv: update key/value cache, should be called after rope
"""
def __init__(
self,
max_batch_size: int,
max_length: int,
num_kv_heads: int,
head_dim: int,
kernel_size: int,
kernel_stride: int,
window_size: int,
dtype: torch.dtype,
device: Union[str, torch.device] = "cuda",
):
self.max_batch_size = max_batch_size
self.max_length = max_length
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.kernel_size = kernel_size
self.kernel_stride = kernel_stride
self.window_size = window_size
self.dtype = dtype
self.device = device
# alloc kv cache tensor for topk sparse attention
self.sparse_kv_cache = torch.zeros(
2,
self.max_batch_size,
self.max_length,
self.num_kv_heads,
self.head_dim,
dtype=self.dtype,
device=self.device,
)
self.sparse_kv_len = torch.zeros(
self.max_batch_size, dtype=torch.int32, device=self.device
)
# alloc kv cache tensor for compressed attention
self.max_comp_length = (
self.max_length - self.kernel_size
) // self.kernel_stride + 1
self.compress_kv_cache = torch.zeros(
2,
self.max_batch_size,
self.max_comp_length,
self.num_kv_heads,
self.head_dim,
dtype=self.dtype,
device=self.device,
)
self.compress_kv_len = torch.zeros(
self.max_batch_size, dtype=torch.int32, device=self.device
)
self.before_compress_kv_cache = torch.zeros(
2,
self.max_batch_size,
self.kernel_size,
self.num_kv_heads,
self.head_dim,
dtype=self.dtype,
device=self.device,
)
self.before_compress_kv_len = torch.zeros(
self.max_batch_size, dtype=torch.int32, device=self.device
)
# alloc kv cache for sliding window attention
self.sliding_kv_cache = torch.zeros(
2,
self.max_batch_size,
self.window_size,
self.num_kv_heads,
self.head_dim,
dtype=self.dtype,
device=self.device,
)
self.sliding_kv_len = torch.zeros(
self.max_batch_size, dtype=torch.int32, device=self.device
)
def reset(self):
self.compress_kv_cache.zero_()
self.compress_kv_len.zero_()
self.before_compress_kv_cache.zero_()
self.before_compress_kv_len.zero_()
self.sparse_kv_cache.zero_()
self.sparse_kv_len.zero_()
self.sliding_kv_cache.zero_()
self.sliding_kv_len.zero_()
def prepare_compress(
self,
cu_seqlens: torch.Tensor,
step: int,
key: torch.Tensor,
value: torch.Tensor,
):
if step == 0:
self._prepare_compress_prefill(cu_seqlens, step, key, value)
else:
self._prepare_compress_decode(cu_seqlens, step, key, value)
def _prepare_compress_prefill(
self,
cu_seqlens: torch.Tensor,
step: int,
key: torch.Tensor,
value: torch.Tensor,
):
assert step == 0
# compress part kv, shape check
batch_size = cu_seqlens.shape[0] - 1
total_len, num_heads, head_dim = key.shape
assert key.shape == value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
comp_seqlens, comp_cu_seqlens = get_compressed_seqlens(
cu_seqlens, self.kernel_size, self.kernel_stride
)
assert total_len == cu_seqlens[-1].item()
# fill tmp cache
seq_start = cu_seqlens[:-1] + comp_seqlens * self.kernel_stride
seq_end = cu_seqlens[1:]
_fill_kv_cache(
self.before_compress_kv_cache,
key,
value,
seq_start,
seq_end,
)
self.before_compress_kv_len[:batch_size] = seq_end - seq_start
def _prepare_compress_decode(
self,
cu_seqlens: torch.Tensor,
step: int,
key: torch.Tensor,
value: torch.Tensor,
):
assert step > 0
# compress part kv, shape check
batch_size, num_heads, head_dim = key.shape
assert key.shape == value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
assert batch_size == cu_seqlens.shape[0] - 1
# sequence with full before_compress_kv
idx = torch.where(self.before_compress_kv_len == self.kernel_size)[0].squeeze()
self.before_compress_kv_cache[
:, idx, : self.kernel_size - self.kernel_stride, :, :
] = self.before_compress_kv_cache[:, idx, self.kernel_stride :, :, :]
self.before_compress_kv_len[idx] -= self.kernel_stride
# fill new kv
brange = torch.arange(batch_size, dtype=torch.int32, device=key.device)
self.before_compress_kv_cache[0, :batch_size][
brange, self.before_compress_kv_len[:batch_size]
] = key
self.before_compress_kv_cache[1, :batch_size][
brange, self.before_compress_kv_len[:batch_size]
] = value
# update kv len
self.before_compress_kv_len[:batch_size] += 1
def update_kv(
self,
cu_seqlens: torch.Tensor,
step: int,
compress_key: torch.Tensor,
compress_value: torch.Tensor,
sparse_key: torch.Tensor,
sparse_value: torch.Tensor,
sliding_key: torch.Tensor,
sliding_value: torch.Tensor,
):
if step == 0:
self._update_kv_prefill(
cu_seqlens,
step,
compress_key,
compress_value,
sparse_key,
sparse_value,
sliding_key,
sliding_value,
)
else:
self._update_kv_decode(
cu_seqlens,
step,
compress_key,
compress_value,
sparse_key,
sparse_value,
sliding_key,
sliding_value,
)
def _update_kv_prefill(
self,
cu_seqlens: torch.Tensor,
step: int,
compress_key: torch.Tensor,
compress_value: torch.Tensor,
sparse_key: torch.Tensor,
sparse_value: torch.Tensor,
sliding_key: torch.Tensor,
sliding_value: torch.Tensor,
):
assert step == 0
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
batch_size = seqlens.shape[0]
# sparse part kv, shape check
total_len, num_heads, head_dim = sparse_key.shape
assert sparse_key.shape == sparse_value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
assert total_len == cu_seqlens[-1].item()
# compress part kv, shape check
total_comp_len, num_heads, head_dim = compress_key.shape
assert compress_key.shape == compress_value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
comp_seqlens, comp_cu_seqlens = get_compressed_seqlens(
cu_seqlens, self.kernel_size, self.kernel_stride
)
assert total_comp_len == comp_cu_seqlens[-1].item()
# sliding window part kv, shape check
total_len, num_heads, head_dim = sliding_key.shape
assert sliding_key.shape == sliding_value.shape
# fill compress part kv cache
seq_start, seq_end = comp_cu_seqlens[:-1], comp_cu_seqlens[1:]
_fill_kv_cache(
self.compress_kv_cache,
compress_key,
compress_value,
seq_start,
seq_end,
)
self.compress_kv_len[:batch_size] = comp_seqlens
# fill sparse part kv cache
seq_start, seq_end = cu_seqlens[:-1], cu_seqlens[1:]
_fill_kv_cache(
self.sparse_kv_cache,
sparse_key,
sparse_value,
seq_start,
seq_end,
)
self.sparse_kv_len[:batch_size] = seqlens
# fill sliding part kv cache
seq_start = torch.maximum(cu_seqlens[1:] - self.window_size, cu_seqlens[:-1])
seq_end = cu_seqlens[1:]
_fill_kv_cache(
self.sliding_kv_cache,
sliding_key,
sliding_value,
seq_start,
seq_end,
)
self.sliding_kv_len[:batch_size] = seq_end - seq_start
def _update_kv_decode(
self,
cu_seqlens: torch.Tensor,
step: int,
compress_key: torch.Tensor,
compress_value: torch.Tensor,
sparse_key: torch.Tensor,
sparse_value: torch.Tensor,
sliding_key: torch.Tensor,
sliding_value: torch.Tensor,
):
assert step > 0
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
# sparse part kv, shape check
batch_size, num_heads, head_dim = sparse_key.shape
assert batch_size == seqlens.shape[0]
assert sparse_key.shape == sparse_value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
# compress part kv, shape check
batch_size, num_heads, head_dim = compress_key.shape
assert batch_size == seqlens.shape[0]
assert compress_key.shape == compress_value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
# sliding window part kv, shape check
total_len, num_heads, head_dim = sliding_key.shape
assert sliding_key.shape == sliding_value.shape
# fill compress part kv cache, only full block need to be compress
idx = torch.where(self.before_compress_kv_len == self.kernel_size)[0].squeeze()
self.compress_kv_cache[0][idx, self.compress_kv_len[idx]] = compress_key[idx]
self.compress_kv_cache[1][idx, self.compress_kv_len[idx]] = compress_value[idx]
self.compress_kv_len[idx] += 1
# fill sparse part kv cache
brange = torch.arange(batch_size, dtype=torch.int32, device=sparse_key.device)
self.sparse_kv_cache[0, :batch_size][
brange, self.sparse_kv_len[:batch_size]
] = sparse_key
self.sparse_kv_cache[1, :batch_size][
brange, self.sparse_kv_len[:batch_size]
] = sparse_value
self.sparse_kv_len[:batch_size] += 1
# fill sliding window kv cache
self.sliding_kv_cache[0, :batch_size][
brange, self.sliding_kv_len[:batch_size] % self.window_size
] = sliding_key
self.sliding_kv_cache[1, :batch_size][
brange, self.sliding_kv_len[:batch_size] % self.window_size
] = sliding_value
self.sliding_kv_len[:batch_size] += 1
@triton.jit
def _fill_kv_cache_kernel(
    cache_ptr,
    k_ptr,
    v_ptr,
    seq_start,
    seq_end,
    head_dim,
    stride_c2,
    stride_cb,
    stride_cn,
    stride_ch,
    stride_cd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    # Copy one tile of key or value rows of one sequence/head into the cache.
    #
    # Program ids:
    #   axis 0 -> 2 * batch index + (0 for key, 1 for value)
    #   axis 1 -> head index
    #   axis 2 -> tile index along the token dimension (BLOCK_SIZE_N tokens each)
    # Source rows [seq_start[b], seq_end[b]) of k_ptr / v_ptr are written to
    # cache rows [0, seq_end[b] - seq_start[b]) of cache[key-or-value, b].
    # Only columns [0, BLOCK_SIZE_D) are touched because the offset along the
    # head dimension is always 0.

    # get batch id and head id
    pid_2b = tl.program_id(0)
    pid_2 = pid_2b % 2
    pid_b = pid_2b // 2
    pid_h = tl.program_id(1)
    pid_n = tl.program_id(2)
    # get kv start and len after rmpad
    kv_start = tl.load(seq_start + pid_b)
    kv_end = tl.load(seq_end + pid_b)
    kv_len = kv_end - kv_start
    # tiles that start past the end of this sequence have nothing to copy
    if pid_n * BLOCK_SIZE_N >= kv_len:
        return
    # load key or value
    if pid_2 == 0:
        kv_ptrs = tl.make_block_ptr(
            base=k_ptr + kv_start * stride_kn + pid_h * stride_kh,
            shape=(kv_len, head_dim),
            strides=(stride_kn, stride_kd),
            offsets=(pid_n * BLOCK_SIZE_N, 0),
            block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
            order=(1, 0),
        )
        kv = tl.load(kv_ptrs, boundary_check=(0, 1))
    else:
        kv_ptrs = tl.make_block_ptr(
            base=v_ptr + kv_start * stride_vn + pid_h * stride_vh,
            shape=(kv_len, head_dim),
            strides=(stride_vn, stride_vd),
            offsets=(pid_n * BLOCK_SIZE_N, 0),
            block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
            order=(1, 0),
        )
        kv = tl.load(kv_ptrs, boundary_check=(0, 1))
    # store to cache; the destination always starts at token 0 of the cache row
    cache_ptrs = tl.make_block_ptr(
        base=cache_ptr + pid_2 * stride_c2 + pid_b * stride_cb + pid_h * stride_ch,
        shape=(kv_len, head_dim),
        strides=(stride_cn, stride_cd),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # cast to the cache element type before the bounds-checked store
    tl.store(cache_ptrs, kv.to(cache_ptr.dtype.element_ty), boundary_check=(0, 1))
def _fill_kv_cache(
kv_cache: torch.Tensor, # shape: [2, b, n, h, d]
key: torch.Tensor, # shape: [l, h, d]
value: torch.Tensor, # shape: [l, h, d]
seq_start: torch.Tensor, # shape: [b]
seq_end: torch.Tensor, # shape: [b]
):
total_len, num_heads, head_dim = key.shape
batch_size = seq_start.shape[0]
max_kv_len = (seq_end - seq_start).max().item()
# no kv cache to fill
if max_kv_len == 0:
return
BLOCK_SIZE_N = min(1024, triton.next_power_of_2(max_kv_len))
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
grid = (2 * batch_size, num_heads, triton.cdiv(max_kv_len, BLOCK_SIZE_N))
_fill_kv_cache_kernel[grid](
kv_cache,
key,
value,
seq_start,
seq_end,
head_dim,
kv_cache.stride(0),
kv_cache.stride(1),
kv_cache.stride(2),
kv_cache.stride(3),
kv_cache.stride(4),
key.stride(0),
key.stride(1),
key.stride(2),
value.stride(0),
value.stride(1),
value.stride(2),
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_D=BLOCK_SIZE_D,
)
================================================
FILE: native_sparse_attention/module/native_sparse_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from flash_attn import flash_attn_varlen_func
from native_sparse_attention.ops import (
compressed_attention,
topk_sparse_attention,
avgpool_compress,
weightedpool_compress,
linear_compress,
)
from einops import rearrange
from native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding
from native_sparse_attention.infer import nsa_infer
from native_sparse_attention.module.kv_cache import NSACache
# compress type name -> function that compresses a packed key/value tensor
COMPRESS_TYPE_TO_FUNC = dict(
    avgpool=avgpool_compress,
    weightedpool=weightedpool_compress,
    linear=linear_compress,
)
def _avgpool_weight(num_heads, head_dim, kernel_size):
    """Average pooling has no learnable weight."""
    return None


def _weightedpool_weight(num_heads, head_dim, kernel_size):
    """One pooling weight per head and per position inside the kernel window."""
    return torch.nn.Parameter(torch.zeros(num_heads, kernel_size))


def _linear_weight(num_heads, head_dim, kernel_size):
    """Per-head projection from a flattened kernel window to a single token."""
    return torch.nn.Parameter(
        torch.zeros(num_heads, head_dim * kernel_size, head_dim)
    )


# compress type name -> factory(num_heads, head_dim, kernel_size) that builds
# the zero-initialised weight of that compression method (or None)
COMPRESS_TYPE_TO_WEIGHT = {
    "avgpool": _avgpool_weight,
    "weightedpool": _weightedpool_weight,
    "linear": _linear_weight,
}
class NativeSparseAttention(torch.nn.Module):
    """Native sparse attention module, usable for both training and inference.

    The output is a per-head gated sum of three branches: attention over
    compressed key/value, top-k block-sparse attention, and sliding window
    attention.

    Args:
        compress_type (str): key value compression type, one of ['linear', 'avgpool', 'weightedpool']
        hidden_size (int): hidden dimension
        num_q_heads (int): number of query heads
        num_kv_heads (int): number of key/value heads (presumably num_q_heads
            has to be a multiple of it; this is not checked here)
        head_dim (int): head dim, must equal ``rope_config.head_dim``
        kernel_size (int): kernel size of compression
        kernel_stride (int): kernel stride of compression
        block_size (int): block size of sparse attention
        topk (int): topk of sparse attention
        init_blocks (int): number of blocks at the beginning of the sequence that are always computed in sparse attention
        local_blocks (int): number of blocks in the local window of each query that are always computed in sparse attention
        window_size (int): window size for sliding window attention
        rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details
        rope_device (str): device used to store rope freqs
    """

    def __init__(
        self,
        compress_type: str,
        hidden_size: int,
        num_q_heads: int,
        num_kv_heads: int,
        head_dim: int,
        kernel_size: int,
        kernel_stride: int,
        block_size: int,
        topk: int,
        init_blocks: int,
        local_blocks: int,
        window_size: int,
        rope_config: RopeConfig,
        rope_device: str = "cuda",
    ):
        super().__init__()
        # configs
        self.compress_type = compress_type
        self.hidden_size = hidden_size
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.kernel_size = kernel_size
        self.kernel_stride = kernel_stride
        self.block_size = block_size
        self.topk = topk
        self.init_blocks = init_blocks
        self.local_blocks = local_blocks
        self.window_size = window_size
        self.rope_config = rope_config
        assert self.head_dim == self.rope_config.head_dim

        # qkv proj and o proj (registration order is kept: q, k, v, o)
        q_width = self.num_q_heads * self.head_dim
        kv_width = self.num_kv_heads * self.head_dim
        self.proj_q = torch.nn.Linear(self.hidden_size, q_width, bias=False)
        self.proj_k = torch.nn.Linear(self.hidden_size, kv_width, bias=False)
        self.proj_v = torch.nn.Linear(self.hidden_size, kv_width, bias=False)
        self.proj_o = torch.nn.Linear(q_width, self.hidden_size, bias=False)

        # nsa compress func
        self.compress_func = COMPRESS_TYPE_TO_FUNC[self.compress_type]

        # nsa parameters: one compression weight for keys and one for values
        make_weight = COMPRESS_TYPE_TO_WEIGHT[self.compress_type]
        self.compress_key = make_weight(num_kv_heads, head_dim, kernel_size)
        self.compress_value = make_weight(num_kv_heads, head_dim, kernel_size)
        self.intra_block_pe = torch.nn.Parameter(
            torch.zeros(self.num_kv_heads, self.kernel_size, self.head_dim)
        )

        # gate function: three sigmoid gates per query head
        gate_proj = torch.nn.Linear(self.hidden_size, self.num_q_heads * 3, bias=False)
        self.gate = torch.nn.Sequential(gate_proj, torch.nn.Sigmoid())

        # rope
        self.rope = RotaryEmbedding(self.rope_config, device=rope_device)

        # init parameters
        self.init_params()

    def init_params(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for param in self.parameters():
            if len(param.shape) <= 1:
                continue
            torch.nn.init.xavier_uniform_(param)

    def forward(
        self,
        x: torch.Tensor,  # shape: [total_len, hidden_size]
        cu_seqlens: torch.Tensor,  # shape: [batch_size + 1]
    ):
        """Run native sparse attention over packed variable-length sequences.

        Args:
            x (torch.Tensor): packed hidden states, bfloat16 or float16.
            cu_seqlens (torch.Tensor): cumulative sequence lengths.

        Returns:
            torch.Tensor: output of ``proj_o`` applied to the gated attention output.
        """
        # dtype and shape check
        assert x.dtype in (torch.bfloat16, torch.float16)
        assert x.shape[-1] == self.hidden_size
        cu_seqlens = cu_seqlens.to(torch.int32)
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = seqlens.max().item()

        # qkv proj
        q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
        k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
        v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)

        # compressed key and value before rope; only the key gets intra_block_pe
        compressed_k, compressed_cu_seqlens = self.compress_func(
            k,
            self.compress_key,
            cu_seqlens,
            self.kernel_size,
            self.kernel_stride,
            self.intra_block_pe,
        )
        compressed_v, _ = self.compress_func(
            v,
            self.compress_value,
            cu_seqlens,
            self.kernel_size,
            self.kernel_stride,
            None,
        )

        # do rope for query and compressed key
        q = self.rope(q, cu_seqlens)
        compressed_k = self.rope(
            compressed_k, compressed_cu_seqlens, stride=self.kernel_stride
        )

        # attention between query and compressed key value
        compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
        max_compressed_seqlen = compressed_seqlens.max().item()
        compressed_attn_output, topk_idx = compressed_attention(
            q,
            compressed_k,
            compressed_v,
            self.kernel_size,
            self.kernel_stride,
            self.block_size,
            self.topk,
            cu_seqlens,
            compressed_cu_seqlens,
            max_seqlen,
            max_compressed_seqlen,
            None,
            self.init_blocks,
            self.local_blocks,
        )

        # do rope for original key
        k = self.rope(k, cu_seqlens)

        # topk sparse attention
        sparse_attn_output = topk_sparse_attention(
            q, k, v, topk_idx, self.block_size, cu_seqlens, None
        )

        # sliding window attention
        sliding_attn_output = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            causal=True,
            window_size=(self.window_size, -1),
        )

        # gate average: one gate per branch and per query head
        gate = rearrange(self.gate(x), "n (h g) -> n h g", g=3)
        compress_gate = gate[..., 0:1]
        sparse_gate = gate[..., 1:2]
        sliding_gate = gate[..., 2:3]
        attn_output = (
            compress_gate * compressed_attn_output
            + sparse_gate * sparse_attn_output
            + sliding_gate * sliding_attn_output
        )

        # rearrange and output proj
        return self.proj_o(rearrange(attn_output, "n h d -> n (h d)"))

    @torch.no_grad()
    def inference(
        self,
        x: torch.Tensor,  # shape: [total_len, hidden_size]
        cu_seqlens: torch.Tensor,  # shape: [batch_size + 1]
        step: int,
        cache: NSACache,
    ):
        """Run one inference step (prefill when ``step == 0``, decode otherwise).

        At prefill ``x`` holds all prompt tokens; at decode it holds one token
        per sequence. The attention itself is delegated to ``nsa_infer``.
        """
        # dtype and shape check
        assert x.dtype in (torch.bfloat16, torch.float16)
        assert x.shape[-1] == self.hidden_size
        cu_seqlens = cu_seqlens.to(torch.int32)
        assert step >= 0
        if step == 0:
            assert x.shape[0] == cu_seqlens[-1]
        else:
            assert x.shape[0] == cu_seqlens.shape[0] - 1

        # qkv proj
        q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
        k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
        v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)

        # gate proj
        gate = rearrange(self.gate(x), "n (h g) -> n h g", g=3)

        # nsa infer
        compress_weights = [self.compress_key, self.compress_value]
        compress_funcs = [self.compress_func, self.compress_func]
        output = nsa_infer(
            cu_seqlens,
            step,
            q,
            k,
            v,
            gate,
            self.rope,
            cache,
            compress_weights,
            compress_funcs,
            self.intra_block_pe,
            self.kernel_size,
            self.kernel_stride,
            self.block_size,
            self.topk,
            self.init_blocks,
            self.local_blocks,
            self.window_size,
        )

        # output proj
        return self.proj_o(rearrange(output, "n h d -> n (h d)"))
================================================
FILE: native_sparse_attention/module/rope.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from dataclasses import dataclass, field
from torch import nn
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
# default to llama3.1 rope config
@dataclass
class RopeConfig:
    """Settings consumed by RotaryEmbedding, mirroring the transformers llama config fields."""

    max_position_embeddings: int = 131072
    head_dim: int = 128
    rope_theta: float = 500000
    rope_scaling: dict = field(
        default_factory=lambda: dict(
            factor=8.0,
            high_freq_factor=4.0,
            low_freq_factor=1.0,
            original_max_position_embeddings=8192,
            rope_type="llama3",
        )
    )
    # kept only for compatibility; both are overwritten in __post_init__,
    # configure head_dim instead
    hidden_size: int = 1
    num_attention_heads: int = 1

    def __post_init__(self):
        # a single "head" of width head_dim, whatever the caller passed in
        self.num_attention_heads, self.hidden_size = 1, self.head_dim
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Swap the two halves of the last dimension and negate the new first half."""
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat([back.neg(), front], dim=-1)
# copied and modified from huggingface transformers
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
class RotaryEmbedding(nn.Module):
    """Rotary embedding for packed (varlen) tensors.

    The cos/sin tables are cached on the class (``RotaryEmbedding.cos`` /
    ``RotaryEmbedding.sin``) and therefore shared by every instance.
    NOTE(review): instances built from different configs would overwrite each
    other's tables — confirm that all instances in a process share one config.

    Args:
        config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details
        device: device used for the rope frequencies; when omitted the current
            CUDA device is used if CUDA is available, otherwise the CPU
    """

    # class-level cache of the cos/sin tables, indexed by position
    cos = None
    sin = None

    def __init__(self, config: "RopeConfig", device=None):
        super().__init__()
        if device is None:
            # Resolve the default at construction time. Evaluating
            # torch.cuda.current_device() in the signature would initialise
            # CUDA while the module is imported and fail on CPU-only hosts.
            if torch.cuda.is_available():
                device = torch.device(torch.cuda.current_device())
            else:
                device = torch.device("cpu")
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len
            )
            self.register_buffer(
                "inv_freq", inv_freq, persistent=False
            )  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len
        if (
            seq_len < self.original_max_seq_len
            and self.max_seq_len_cached > self.original_max_seq_len
        ):  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def generate_cos_sin(self, x: torch.Tensor, position_ids):
        """Rebuild the class-level cos/sin tables for ``position_ids``.

        Args:
            x (torch.Tensor): only its device and dtype are used.
            position_ids: positions of shape [1, num_positions].

        Returns:
            The new ``(cos, sin)`` tables, each [num_positions, 2 * len(inv_freq)].
        """
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)
        # Core RoPE block
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = (
            device_type
            if isinstance(device_type, str) and device_type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            # the halves are duplicated below, after scaling and dtype casting
            emb = freqs
            cos = emb.cos()
            sin = emb.sin()
        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = (cos * self.attention_scaling).to(dtype=x.dtype).squeeze(0)
        sin = (sin * self.attention_scaling).to(dtype=x.dtype).squeeze(0)
        # save cos sin
        RotaryEmbedding.cos = torch.cat([cos, cos], dim=-1)
        RotaryEmbedding.sin = torch.cat([sin, sin], dim=-1)
        return RotaryEmbedding.cos, RotaryEmbedding.sin

    @torch.no_grad()
    def generate_pos_embs(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        seqlens: torch.Tensor,
        step: int = 0,
        stride: int = 1,
    ):
        """Gather the cos/sin rows matching every token of a packed batch.

        Prefill (``step == 0``): sequence ``i`` of length ``r`` uses positions
        ``0, stride, ..., (r - 1) * stride``. Decode (``step > 0``): each
        sequence contributes the single position ``(r + step - 1) * stride``.

        Args:
            x (torch.Tensor): only its device and dtype are used.
            cu_seqlens (torch.Tensor): cumulative sequence lengths.
            seqlens (torch.Tensor): per-sequence lengths, used for the table size.
            step (int): 0 for prefill, positive for decoding.
            stride (int): distance between consecutive positions.

        Returns:
            ``(cos_embs, sin_embs)`` concatenated over all sequences.

        Raises:
            ValueError: if ``step`` is negative.
        """
        if step < 0:
            raise ValueError(f"step must be non-negative, got {step}")
        # The largest row read below is (max_len + step - 1) * stride, so the
        # table has to be sized with the stride taken into account.
        max_len = int(seqlens.max())
        needed = max((max_len + step - 1) * stride + 1, 0)
        cached = RotaryEmbedding.cos
        if (
            cached is None
            or needed > cached.shape[0]
            # a table built for another device/dtype must not be reused
            or cached.device != x.device
            or cached.dtype != x.dtype
        ):
            self.generate_cos_sin(x, torch.arange(needed).to(x.device)[None, :])
        cos_table, sin_table = RotaryEmbedding.cos, RotaryEmbedding.sin
        # one host transfer instead of slicing with a 0-dim tensor per sequence
        lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        cos_embs = []
        sin_embs = []
        for length in lengths:
            if step == 0:  # prefilling
                rows = slice(0, length * stride, stride)
            else:  # decoding
                pos = (length + step - 1) * stride
                rows = slice(pos, pos + 1)
            cos_embs.append(cos_table[rows])
            sin_embs.append(sin_table[rows])
        cos_embs = torch.cat(cos_embs, dim=0)
        sin_embs = torch.cat(sin_embs, dim=0)
        return cos_embs, sin_embs

    def forward(self, x, cu_seqlens, step=0, stride=1):
        """Apply rotary embedding to packed ``x`` of shape [N, num_heads, head_dim]."""
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        cos_embs, sin_embs = self.generate_pos_embs(
            x,
            cu_seqlens,
            seqlens,
            step=step,
            stride=stride,
        )
        N, D = x.shape[0], x.shape[-1]
        # broadcast the per-token tables over the head dimension
        x = x * cos_embs.view(N, 1, D) + rotate_half(x) * sin_embs.view(N, 1, D)
        return x
================================================
FILE: native_sparse_attention/module/self_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from flash_attn import flash_attn_varlen_func
from einops import rearrange
from native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding
from native_sparse_attention.module.kv_cache import KVCache
from native_sparse_attention.ops import flash_attention_decode
class SelfAttention(torch.nn.Module):
    """Causal self attention module over packed variable-length sequences.

    Args:
        hidden_size (int): hidden dimension
        num_q_heads (int): number of query heads
        num_kv_heads (int): number of key/value heads (presumably num_q_heads
            has to be a multiple of it; this is not checked here)
        head_dim (int): head dim, must equal ``rope_config.head_dim``
        rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details
        rope_device (str): device used to store rope freqs
    """

    def __init__(
        self,
        hidden_size: int,
        num_q_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rope_config: "RopeConfig",
        rope_device: str = "cuda",
    ):
        super().__init__()
        # configs
        self.hidden_size = hidden_size
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.rope_config = rope_config
        assert self.head_dim == self.rope_config.head_dim
        # qkv proj and o proj
        self.proj_q = torch.nn.Linear(
            self.hidden_size, self.num_q_heads * self.head_dim, bias=False
        )
        self.proj_k = torch.nn.Linear(
            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
        )
        self.proj_v = torch.nn.Linear(
            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
        )
        self.proj_o = torch.nn.Linear(
            self.num_q_heads * self.head_dim, self.hidden_size, bias=False
        )
        # rope
        self.rope = RotaryEmbedding(self.rope_config, device=rope_device)
        # init parameters
        self.init_params()

    def init_params(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for p in self.parameters():
            # xavier_uniform_ needs at least 2 dims; skip 1-D parameters, the
            # same way NativeSparseAttention.init_params does
            if len(p.shape) > 1:
                torch.nn.init.xavier_uniform_(p)

    def forward(
        self,
        x: torch.Tensor,  # shape: [total_len, hidden_size]
        cu_seqlens: torch.Tensor,  # shape: [batch_size + 1]
    ):
        """Run causal attention over packed sequences.

        Args:
            x (torch.Tensor): packed hidden states, bfloat16 or float16.
            cu_seqlens (torch.Tensor): cumulative sequence lengths.

        Returns:
            torch.Tensor: output of ``proj_o`` applied to the attention output.
        """
        # dtype and shape check
        assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
        assert x.shape[-1] == self.hidden_size
        cu_seqlens = cu_seqlens.to(torch.int32)
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = seqlens.max().item()
        # qkv proj
        q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
        k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
        v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)
        # do rope for query and key
        q = self.rope(q, cu_seqlens)
        k = self.rope(k, cu_seqlens)
        # self attention
        attn_output = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            causal=True,
        )
        # rearrange and output proj
        attn_output = rearrange(attn_output, "n h d -> n (h d)")
        attn_output = self.proj_o(attn_output)
        return attn_output

    @torch.no_grad()
    def inference(
        self,
        x: torch.Tensor,  # shape: [total_len, hidden_size]
        cu_seqlens: torch.Tensor,  # shape: [batch_size + 1]
        step: int,
        cache: "KVCache",
    ):
        """Run one inference step (prefill when ``step == 0``, decode otherwise).

        At prefill ``x`` holds all prompt tokens and the cache is reset before
        being filled; at decode ``x`` holds one token per sequence and
        attention reads the keys/values stored in ``cache``.
        """
        # dtype and shape check
        assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
        assert x.shape[-1] == self.hidden_size
        cu_seqlens = cu_seqlens.to(torch.int32)
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        assert step >= 0
        if step == 0:
            assert x.shape[0] == cu_seqlens[-1]
        else:
            assert x.shape[0] == cu_seqlens.shape[0] - 1
        batch_size = cu_seqlens.shape[0] - 1
        # qkv proj
        q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
        k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
        v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)
        # do rope for query and key
        q = self.rope(q, cu_seqlens, step)
        k = self.rope(k, cu_seqlens, step)
        # reset and update kv cache
        if step == 0:
            cache.reset()
        cache.update_kv(cu_seqlens, step, k, v)
        # self attention
        if step == 0:
            cu_seqlens_q = cu_seqlens_k = cu_seqlens
            max_seqlen_in_batch_q = max_seqlen_in_batch_k = seqlens.max().item()
            output = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                causal=True,
            )
        else:
            # decode against the cached keys/values of the active sequences
            output = flash_attention_decode(
                q,
                cache.kv_cache[0, :batch_size],
                cache.kv_cache[1, :batch_size],
                cache.kv_len[:batch_size],
            )
        # rearrange and output proj
        output = rearrange(output, "n h d -> n (h d)")
        output = self.proj_o(output)
        return output
================================================
FILE: native_sparse_attention/ops/README.md
================================================
# Triton Functions for Native Sparse Attention
This folder provides efficient Triton-based implementations of components for Native Sparse Attention. This README introduces the available functions, explains how to set them up, and offers guidance on their usage.
---
## Overview of Functions
The functions are organized into two main categories:
1. **Compression Methods**: Techniques for compressing key and value tensors.
2. **Attention Mechanisms**: Methods for computing attention between queries and compressed key/value tensors, including top-k sparse attention.
---
## Function Descriptions
### Compression Methods
These functions compress key and value tensors using a sliding window approach. Within each window, `kernel_size` tokens are compressed into a single token, with a stride of `kernel_stride`. All compression functions share similar input parameters and return formats.
**Parameters:**
- `x`: Input tensor (`total_len, num_heads, head_dim`)
- `w`: Weight tensor (shape varies by compression method)
- `cu_seqlens`: Cumulative sequence lengths (`batch_size + 1`)
- `kernel_size`: Size of the compression window
- `kernel_stride`: Stride of the compression window
- `pe`: Optional positional embedding (`num_heads, kernel_size, head_dim`)
**Returns:**
- Compressed tensor (`total_compress_len, num_heads, head_dim`)
- Cumulative sequence lengths (`com_cu_seqlens`) for the compressed tensor
#### `weightedpool_compress`
Compresses the input tensor using weighted pooling, applying a weighted sum over each block:
$\hat{k} = w_1 k_1 + \dots + w_m k_m$
- **Weight shape**: `(num_heads, kernel_size)`
#### `avgpool_compress`
Compresses the input tensor using average pooling:
$\hat{k} = (k_1 + \dots + k_m) / m$
- **Weight**: Must be `None`
#### `linear_compress`
Compresses the input tensor via linear projection, mapping each block to a single vector using learned weights:
$\hat{k} = \text{cat}(k_1, \dots, k_m) W$
- **Weight shape**: `(num_heads, kernel_size * head_dim, head_dim)`
---
### Attention Mechanisms
These functions compute attention using either full or sparse mechanisms.
#### `flash_attention_varlen`
A variable-length implementation of flash attention, similar to `flash_attn_varlen_func` from the `flash_attn` package.
**Parameters:**
- `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`)
- `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for queries and keys
- `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths in the batch
- `causal`: Apply causal masking (default: `False`)
- `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)
**Returns:**
- Attention output tensor (`total_q_len, num_q_heads, head_dim`)
#### `compressed_attention`
Computes attention between a query and compressed key/value tensors, identifying the top-k blocks for sparse attention.
**Parameters:**
- `q`: Query tensor (`total_len, num_heads, head_dim`)
- `k`, `v`: Compressed key and value tensors (`total_compress_len, num_heads, head_dim`)
- `kernel_size`, `kernel_stride`: Compression parameters
- `block_size`: Size of blocks for sparse attention
- `topk`: Number of top blocks to select
- `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for query and compressed key/value
- `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths for query and compressed key/value
- `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)
- `init_blocks`: Number of initial blocks forced to be selected (default: `1`)
- `local_blocks`: Number of local blocks forced to be selected (default: `2`)
**Returns:**
- Tuple containing:
- Attention output tensor
- Top-k block indices
#### `topk_sparse_attention`
Performs sparse attention using precomputed top-k block indices. If a query attends to fewer than `topk` key/value blocks, the `topk_idx` should be padded with `-1` on the right.
**Parameters:**
- `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`)
- `topk_idx`: Precomputed top-k indices (`num_kv_heads, total_len, topk`)
- `block_size`: Block size for sparse attention (recommended: `64` or `128`)
- `cu_seqlens`: Cumulative sequence lengths
- `softmax_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)
**Returns:**
- Attention output tensor (`total_len, num_q_heads, head_dim`)
---
## Usage Example
Below is a typical workflow demonstrating how to combine these sparse attention functions:
```python
import torch
from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention
# Example input setup
num_q_heads = 64
num_kv_heads = 4
head_dim = 128
cu_seqlens = torch.tensor([0, 1024, 8192, 16384], dtype=torch.int32).cuda()
# Query, key, and value tensors
query = torch.randn(16384, num_q_heads, head_dim, dtype=torch.bfloat16).cuda()
key = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda()
value = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda()
# Compression weights and positional embeddings
kernel_size = 32
kernel_stride = 16
wk = torch.randn(num_kv_heads, kernel_size * head_dim, head_dim, dtype=torch.bfloat16).cuda()
wv = torch.randn_like(wk)
pe = torch.randn(num_kv_heads, kernel_size, head_dim, dtype=torch.bfloat16).cuda()
# Parameters for top-k sparse attention
block_size = 64
topk = 16
# 1. Compress key and value tensors
compressed_key, compressed_cu_seqlens = linear_compress(
key, wk, cu_seqlens, kernel_size, kernel_stride, pe
)
compressed_value, _ = linear_compress(
value, wv, cu_seqlens, kernel_size, kernel_stride, None
)
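# 2. Compute attention with compressed key/value and get top-k indices
#    (max_seqlen_q / max_seqlen_k are the longest original / compressed sequence lengths)
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
compressed_attn_output, topk_idx = compressed_attention(
    query,
    compressed_key,
    compressed_value,
    kernel_size,
    kernel_stride,
    block_size,
    topk,
    cu_seqlens,
    compressed_cu_seqlens,
    seqlens.max().item(),
    compressed_seqlens.max().item(),
    init_blocks=1,
    local_blocks=2,
)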
# 3. Perform top-k sparse attention
sparse_attn_output = topk_sparse_attention(
query,
key,
value,
topk_idx,
block_size,
cu_seqlens,
)
# 4. Combine attention outputs (e.g., average)
attn_output = (compressed_attn_output + sparse_attn_output) / 2
```
For a complete implementation of the Native Sparse Attention module, see `native_sparse_attention/module/native_sparse_attention.py`.
================================================
FILE: native_sparse_attention/ops/__init__.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# compress method
from native_sparse_attention.ops.triton.weighted_pool import (
weightedpool_compress,
avgpool_compress,
)
from native_sparse_attention.ops.triton.linear_compress import linear_compress
# prefill attention
from native_sparse_attention.ops.triton.flash_attention import flash_attention_varlen
from native_sparse_attention.ops.triton.compressed_attention import compressed_attention
from native_sparse_attention.ops.triton.topk_sparse_attention import (
topk_sparse_attention,
)
# decode attention
from native_sparse_attention.ops.triton.flash_attention_decode import (
flash_attention_decode,
)
from native_sparse_attention.ops.torch.compressed_attention_decode import (
compressed_attention_decode,
)
from native_sparse_attention.ops.triton.topk_sparse_attention_decode import (
topk_sparse_attention_decode,
)
# Public names re-exported by this package; each entry corresponds to one of
# the imports above, grouped the same way (compression / prefill / decode).
__all__ = [
    # compress method
    "avgpool_compress",
    "weightedpool_compress",
    "linear_compress",
    # prefill attention, trainable
    "flash_attention_varlen",
    "compressed_attention",
    "topk_sparse_attention",
    # decode attention, no grad
    "flash_attention_decode",
    "compressed_attention_decode",
    "topk_sparse_attention_decode",
]
================================================
FILE: native_sparse_attention/ops/torch/__init__.py
================================================
================================================
FILE: native_sparse_attention/ops/torch/compress_key_value.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Optional
from einops import rearrange, einsum
def avgpool_compress_torch(
    x: torch.Tensor,
    w: torch.Tensor,
    cu_seqlens,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress key/value states by average pooling over sliding windows.

    Every window of ``kernel_size`` consecutive tokens (moving by
    ``kernel_stride``) inside one sequence is averaged into a single token.
    Windows never cross sequence boundaries, and a sequence shorter than
    ``kernel_size`` contributes no compressed token.

    Args:
        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim),
            dtype float16 or bfloat16.
        w (torch.Tensor): unused for average pooling, must be None.
        cu_seqlens (torch.Tensor): int32, shape [batch_size + 1], similar to cu_seqlens_q
            in flash_attn_func_varlen.
        kernel_size (int): each (kernel_size, head_dim) window is compressed to (1, head_dim).
        kernel_stride (int): stride between consecutive windows; must divide kernel_size.
        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape
            (num_heads, kernel_size, head_dim). Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: compressed states of shape
        (total_compressed_len, num_heads, head_dim) and the matching int32 cu_seqlens.
    """
    # dtype check
    assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
    assert cu_seqlens.dtype == torch.int32
    assert x.dtype == pe.dtype if pe is not None else True
    # shape check
    total_len, num_heads, head_dim = x.shape
    batch_size = cu_seqlens.shape[0] - 1
    assert w is None, "don't need additional weight for avgpool"
    assert kernel_size % kernel_stride == 0
    assert kernel_size in {16, 32, 64, 128}
    # compute seqlens after compression
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
    # corner case, if sequence_length < kernel_size, no compression for this sequence
    y_seqlens[seqlens < kernel_size] = 0
    y_cu_seqlens = torch.cat(
        [
            # follow the input's device instead of assuming "cuda"
            torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
            torch.cumsum(y_seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    # avg_pool1d rejects inputs shorter than the kernel; when no sequence is
    # long enough there is simply nothing to compress
    if y_seqlens.max().item() == 0:
        return x.new_zeros((0, num_heads, head_dim)), y_cu_seqlens
    # pad every sequence to the longest one: (batch, max_len, heads * dim)
    flat = x.reshape(total_len, num_heads * head_dim)
    padded = torch.nn.utils.rnn.pad_sequence(
        torch.split(flat, seqlens.tolist(), 0), batch_first=True
    )
    # avg_pool1d pools along the last axis, so move the time axis there
    pooled = torch.nn.functional.avg_pool1d(
        padded.transpose(1, 2), kernel_size=kernel_size, stride=kernel_stride
    )
    # (batch, heads * dim, n) -> (batch, n, heads, dim)
    pooled = pooled.transpose(1, 2)
    pooled = pooled.reshape(batch_size, pooled.shape[1], num_heads, head_dim)
    # only keep useful part: drop windows that lie in the padding
    y = torch.cat(
        [pooled[i, :n] for i, n in enumerate(y_seqlens.tolist())], dim=0
    )
    # position embedding as a bias: pooling is linear, so adding pe to every
    # window before pooling equals adding the pooled pe afterwards
    if pe is not None:
        bias = torch.mean(pe, dim=1)
        y = y + bias.unsqueeze(0)
    return y, y_cu_seqlens
def weightedpool_compress_torch(
    x: torch.Tensor,
    w: torch.Tensor,  # [num_heads, kernel_size]
    cu_seqlens,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress key/value states with a per-head weighted sum over sliding windows.

    Every window of ``kernel_size`` consecutive tokens (moving by
    ``kernel_stride``) inside one sequence is reduced to one token by a
    weighted sum whose weights are ``w[head]``. Windows never cross sequence
    boundaries, and a sequence shorter than ``kernel_size`` contributes no
    compressed token.

    Args:
        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim),
            dtype float16 or bfloat16.
        w (torch.Tensor): weight for each head, shape (num_heads, kernel_size), same dtype as x.
        cu_seqlens (torch.Tensor): int32, shape [batch_size + 1], similar to cu_seqlens_q
            in flash_attn_func_varlen.
        kernel_size (int): each (kernel_size, head_dim) window is compressed to (1, head_dim).
        kernel_stride (int): stride between consecutive windows; must divide kernel_size.
        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape
            (num_heads, kernel_size, head_dim). Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: compressed states of shape
        (total_compressed_len, num_heads, head_dim) and the matching int32 cu_seqlens.
    """
    # dtype check
    assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
    assert x.dtype == w.dtype
    assert x.dtype == pe.dtype if pe is not None else True
    assert cu_seqlens.dtype == torch.int32
    # shape check
    total_len, num_heads, head_dim = x.shape
    batch_size = cu_seqlens.shape[0] - 1
    assert w.shape[0] == num_heads
    assert w.shape[1] == kernel_size
    assert kernel_size % kernel_stride == 0
    assert kernel_size in {16, 32, 64, 128}
    # compute seqlens after compression
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
    # corner case, if sequence_length < kernel_size, no compression for this sequence
    y_seqlens[seqlens < kernel_size] = 0
    y_cu_seqlens = torch.cat(
        [
            # follow the input's device instead of assuming "cuda"
            torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
            torch.cumsum(y_seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    # pad every sequence to the longest one: (batch, max_len, heads * dim)
    flat = x.reshape(total_len, num_heads * head_dim)
    padded = torch.nn.utils.rnn.pad_sequence(
        torch.split(flat, seqlens.tolist(), 0), batch_first=True
    )
    # (batch, max_len, heads * dim) -> (batch, heads, max_len, dim)
    padded = padded.reshape(
        batch_size, padded.shape[1], num_heads, head_dim
    ).permute(0, 2, 1, 3)
    # overlapping-window view (batch, heads, n_windows, kernel_size, dim):
    # window n starts at token n * kernel_stride; no data is copied
    windows = padded.as_strided(
        size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim),
        stride=(
            padded.stride(0),
            padded.stride(1),
            kernel_stride * padded.stride(2),
            padded.stride(2),
            padded.stride(3),
        ),
    )
    y = torch.einsum("bhnkd,hk->bnhd", windows, w)
    # only keep useful part: drop windows that lie in the padding
    y = torch.cat([y[i, :n] for i, n in enumerate(y_seqlens.tolist())], dim=0)
    # position embedding as a bias: the compression is linear, so adding pe to
    # every window beforehand equals adding the compressed pe afterwards
    if pe is not None:
        bias = torch.einsum("hkd,hk->hd", pe, w)
        y = y + bias.unsqueeze(0)
    return y, y_cu_seqlens
def linear_compress_torch(
    x: torch.Tensor,
    w: torch.Tensor,  # [num_heads, kernel_size * head_dim, head_dim]
    cu_seqlens,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress key/value states with a per-head linear map over sliding windows.

    Every window of ``kernel_size`` consecutive tokens (moving by
    ``kernel_stride``) inside one sequence is flattened to
    ``kernel_size * head_dim`` values and projected to ``head_dim`` with
    ``w[head]``. Similar to conv_compress. Windows never cross sequence
    boundaries, and a sequence shorter than ``kernel_size`` contributes no
    compressed token.

    Args:
        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim),
            dtype float16 or bfloat16.
        w (torch.Tensor): weight for each head, shape (num_heads, kernel_size * head_dim, head_dim),
            same dtype as x.
        cu_seqlens (torch.Tensor): int32, shape [batch_size + 1], similar to cu_seqlens_q
            in flash_attn_func_varlen.
        kernel_size (int): each (kernel_size, head_dim) window is compressed to (1, head_dim).
        kernel_stride (int): stride between consecutive windows; must divide kernel_size.
        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape
            (num_heads, kernel_size, head_dim). Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: compressed states of shape
        (total_compressed_len, num_heads, head_dim) and the matching int32 cu_seqlens.
    """
    # dtype check
    assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
    assert x.dtype == w.dtype
    assert x.dtype == pe.dtype if pe is not None else True
    assert cu_seqlens.dtype == torch.int32
    # shape check
    total_len, num_heads, head_dim = x.shape
    batch_size = cu_seqlens.shape[0] - 1
    assert w.shape[0] == num_heads
    assert w.shape[1] == kernel_size * head_dim
    assert w.shape[2] == head_dim
    assert kernel_size % kernel_stride == 0
    assert kernel_size in {16, 32, 64, 128}
    # compute seqlens after compression
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
    # corner case, if sequence_length < kernel_size, no compression for this sequence
    y_seqlens[seqlens < kernel_size] = 0
    y_cu_seqlens = torch.cat(
        [
            # follow the input's device instead of assuming "cuda"
            torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
            torch.cumsum(y_seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    # pad every sequence to the longest one: (batch, max_len, heads * dim)
    flat = x.reshape(total_len, num_heads * head_dim)
    padded = torch.nn.utils.rnn.pad_sequence(
        torch.split(flat, seqlens.tolist(), 0), batch_first=True
    )
    # (batch, max_len, heads * dim) -> (batch, heads, max_len, dim)
    padded = padded.reshape(
        batch_size, padded.shape[1], num_heads, head_dim
    ).permute(0, 2, 1, 3)
    # overlapping-window view (batch, heads, n_windows, kernel_size, dim):
    # window n starts at token n * kernel_stride; no data is copied
    windows = padded.as_strided(
        size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim),
        stride=(
            padded.stride(0),
            padded.stride(1),
            kernel_stride * padded.stride(2),
            padded.stride(2),
            padded.stride(3),
        ),
    )
    # split the flattened (kernel_size * head_dim) input axis of w back into
    # (kernel_size, head_dim) so it lines up with the window axes
    w_windowed = w.reshape(num_heads, kernel_size, head_dim, head_dim)
    y = torch.einsum("bhnkd,hkde->bnhe", windows, w_windowed)
    # only keep useful part: drop windows that lie in the padding
    y = torch.cat([y[i, :n] for i, n in enumerate(y_seqlens.tolist())], dim=0)
    # position embedding as a bias: the compression is linear, so adding pe to
    # every window beforehand equals adding the compressed pe afterwards
    if pe is not None:
        pe_flat = pe.reshape(num_heads, kernel_size * head_dim)
        bias = torch.einsum("hc,hcd->hd", pe_flat, w)
        y = y + bias.unsqueeze(0)
    return y, y_cu_seqlens
================================================
FILE: native_sparse_attention/ops/torch/compressed_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
from typing import Tuple
from collections import Counter
from einops import rearrange
def transform_score(
    score: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    init_blocks: int = 1,
    local_blocks: int = 2,
) -> torch.Tensor:
    """Turn per-compressed-key scores into per-selection-block scores.

    Args:
        score (torch.Tensor): shape [num_k_heads, total_query_len, max_seqlen_k],
            attention probability of each query on each compressed key.
        kernel_size (int): kernel size used for compression.
        kernel_stride (int): stride used for compression.
        block_size (int): size of one selection block.
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], cumulative query lengths.
        cu_seqlens_k (torch.Tensor): unused here, kept for a uniform signature.
        max_seqlen_q (int): longest query sequence; fixes the number of blocks.
        max_seqlen_k (int): unused here, kept for a uniform signature.
        init_blocks (int, optional): leading blocks forced to +inf. Defaults to 1.
        local_blocks (int, optional): blocks at/before the query's own block
            forced to +inf. Defaults to 2.

    Returns:
        torch.Tensor: float32 block scores, shape
        [num_k_heads, total_query_len, ceil(max_seqlen_q / block_size)].
    """
    num_k_heads, total_query_len, _ = score.shape
    device = score.device
    strides_per_kernel = kernel_size // kernel_stride
    strides_per_block = block_size // kernel_stride
    # zero margin on both sides so every shifted strided read stays in range
    margin = strides_per_kernel - 1
    padded = torch.nn.functional.pad(score, (margin, margin), value=0)
    num_blocks = math.ceil(max_seqlen_q / block_size)
    num_full_blocks = max_seqlen_q // block_size
    block_score = torch.zeros(
        (num_k_heads, total_query_len, num_blocks),
        dtype=torch.float32,
        device=device,
    )
    # multiplicity[s] = number of (i, j) pairs with i + j == s, where i runs
    # over stride slots of a kernel and j over stride slots of a block
    multiplicity = Counter(
        i + j
        for i in range(strides_per_kernel)
        for j in range(strides_per_block)
    )
    for shift, weight in multiplicity.items():
        # element b of this strided read is padded[..., b * strides_per_block + shift]
        shifted = padded[..., shift::strides_per_block]
        block_score[..., :num_full_blocks] += weight * shifted[..., :num_full_blocks]
    # block index of every query inside its own sequence
    seq_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    q_block = (
        torch.cat([torch.arange(n, device=device) for n in seq_lens], dim=0)
        // block_size
    )
    block_ids = torch.arange(num_blocks, device=device)
    # force the initial blocks to be selected
    block_score[..., :init_blocks] = torch.inf
    # force the query's own block and the (local_blocks - 1) blocks before it
    distance = q_block[:, None] - block_ids[None, :]
    is_local = (distance >= 0) & (distance < local_blocks)
    block_score.masked_fill_(is_local[None, :, :], torch.inf)
    return block_score
def compressed_attention_torch(
    q: torch.Tensor,  # [total_query_len, num_q_heads, head_dim]
    k: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
    v: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float = None,
    init_blocks: int = 1,
    local_blocks: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference (torch, debug-only) attention of queries over compressed keys/values.

    Besides the attention output, the attention probabilities are aggregated
    into per-block scores from which the top-k blocks of every query are chosen.

    Args:
        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        kernel_size (int): kernel size in compress_key_value
        kernel_stride (int): stride of compress_key_value
        block_size (int): key value block size for topk sparse attention.
        topk (int): number of blocks for each query.
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
        max_seqlen_q (int): max q len of the batch.
        max_seqlen_k (int): max k len of the batch.
        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
        init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
        local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
    """
    assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0
    total_query_len, num_q_heads, head_dim = q.shape
    total_key_len, num_k_heads, _ = k.shape
    group_size = num_q_heads // num_k_heads
    device = q.device
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)
    # per-sequence (q_begin, q_end, k_begin, k_end) offsets as python ints
    q_bounds = cu_seqlens_q.tolist()
    k_bounds = cu_seqlens_k.tolist()
    spans = list(zip(q_bounds[:-1], q_bounds[1:], k_bounds[:-1], k_bounds[1:]))
    # a query may see a compressed key only once the key's whole window
    # (ending at j * kernel_stride + kernel_size - 1) lies at or before it
    visible = torch.zeros(
        (total_query_len, total_key_len), dtype=torch.bool, device=device
    )
    for q_lo, q_hi, k_lo, k_hi in spans:
        q_pos = torch.arange(q_hi - q_lo, device=device)
        window_end = (
            torch.arange(k_hi - k_lo, device=device) * kernel_stride + kernel_size - 1
        )
        visible[q_lo:q_hi, k_lo:k_hi] = q_pos[:, None] >= window_end[None, :]
    # attention
    k_rep = k.repeat_interleave(group_size, 1)
    v_rep = v.repeat_interleave(group_size, 1)
    qk = torch.einsum("qhd,khd->hqk", q, k_rep) * sm_scale
    qk = qk.masked_fill_(~visible[None, ...], -torch.inf)
    # query from beginning of the sequence can't attend to any compressed key,
    # their fully-masked softmax rows are nan and are zeroed here
    qk = qk.softmax(dim=-1, dtype=torch.float32)
    qk = qk.nan_to_num(0)
    attn_output = torch.einsum("hqk,khd->qhd", qk.to(v.dtype), v_rep)
    with torch.no_grad():
        # sum probabilities over the query heads sharing one kv head
        # result shape [num_k_heads, total_q_len, total_k_len]
        head_prob = rearrange(qk, "(h g) q k -> h g q k", h=num_k_heads).sum(1)
        # re-pack every sequence's keys so they start at column 0
        score = torch.zeros(
            num_k_heads,
            q_bounds[-1],
            max_seqlen_k,
            dtype=torch.float32,
            device=device,
        )
        for q_lo, q_hi, k_lo, k_hi in spans:
            score[:, q_lo:q_hi, : k_hi - k_lo] = head_prob[:, q_lo:q_hi, k_lo:k_hi]
        # transform score to block-wise score
        score = transform_score(
            score,
            kernel_size,
            kernel_stride,
            block_size,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            init_blocks,
            local_blocks,
        )
        # block index of every query inside its own sequence
        q_block = (
            torch.cat(
                [torch.arange(q_hi - q_lo, device=device) for q_lo, q_hi, _, _ in spans],
                dim=0,
            )
            // block_size
        )
        # get topk, indices sorted ascending
        topk = min(topk, score.shape[-1])
        topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
        # blocks after the query's own block are not causal: mark as padding
        topk_idx[topk_idx > q_block[None, :, None]] = -1
        topk_idx = topk_idx.to(torch.int32)
    return attn_output, topk_idx
================================================
FILE: native_sparse_attention/ops/torch/compressed_attention_decode.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
from typing import Tuple, Optional
from collections import Counter
from einops import rearrange
def transform_score(
    score: torch.Tensor,
    seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    init_blocks: int = 1,
    local_blocks: int = 2,
) -> torch.Tensor:
    """Turn per-compressed-key scores of the decoding query into per-block scores.

    Args:
        score (torch.Tensor): shape [num_k_heads, batch_size, kv_len], attention
            probability of the current query on each compressed key.
        seqlens (torch.Tensor): shape [batch_size], original kv length of each sequence.
        kernel_size (int): kernel size used for compression.
        kernel_stride (int): stride used for compression.
        block_size (int): size of one selection block.
        init_blocks (int, optional): leading blocks forced to +inf. Defaults to 1.
        local_blocks (int, optional): blocks at/before the query's own block
            forced to +inf. Defaults to 2.

    Returns:
        torch.Tensor: float32 block scores, shape
        [num_k_heads, batch_size, ceil(max(seqlens) / block_size)].
    """
    num_k_heads, batch_size, kv_len = score.shape
    device = score.device
    strides_per_kernel = kernel_size // kernel_stride
    strides_per_block = block_size // kernel_stride
    # zero margin on both sides so every shifted strided read stays in range
    margin = strides_per_kernel - 1
    padded = torch.nn.functional.pad(score, (margin, margin), value=0)
    longest = seqlens.max().item()
    num_blocks = math.ceil(longest / block_size)
    num_full_blocks = longest // block_size
    block_score = torch.zeros(
        (num_k_heads, batch_size, num_blocks),
        dtype=torch.float32,
        device=device,
    )
    # multiplicity[s] = number of (i, j) pairs with i + j == s, where i runs
    # over stride slots of a kernel and j over stride slots of a block
    multiplicity = Counter(
        i + j
        for i in range(strides_per_kernel)
        for j in range(strides_per_block)
    )
    for shift, weight in multiplicity.items():
        # element b of this strided read is padded[..., b * strides_per_block + shift]
        shifted = padded[..., shift::strides_per_block]
        block_score[..., :num_full_blocks] += weight * shifted[..., :num_full_blocks]
    # the decoding query is the last token, so its block is (seqlen - 1) // block_size
    q_block = (seqlens - 1) // block_size
    block_ids = torch.arange(num_blocks, device=device)
    # force the initial blocks to be selected
    block_score[..., :init_blocks] = torch.inf
    # force the query's own block and the (local_blocks - 1) blocks before it
    distance = q_block[:, None] - block_ids[None, :]
    is_local = (distance >= 0) & (distance < local_blocks)
    block_score.masked_fill_(is_local[None, :, :], torch.inf)
    # nan scores become 0 while +/-inf are kept as they are
    return block_score.nan_to_num(0, torch.inf, -torch.inf)
def compressed_attention_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seqlens: torch.Tensor,
    compress_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    init_blocks: int = 1,
    local_blocks: int = 2,
    sm_scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decode-step attention of one query per sequence over padded compressed keys/values.

    Besides the attention output, the attention probabilities are aggregated
    into per-block scores from which the top-k blocks of every query are chosen.

    Args:
        q (torch.Tensor): shape [batch_size, num_q_heads, head_dim]
        k (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]
        seqlens (torch.Tensor): original kv length for each sequence
        compress_seqlens (torch.Tensor): kv length for each sequence after compression
        kernel_size (int): kernel size in compress_key_value
        kernel_stride (int): stride of compress_key_value
        block_size (int): key value block size for topk sparse attention.
        topk (int): number of blocks for each query.
        init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
        local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention_decode
    """
    assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0
    batch_size, num_q_heads, head_dim = q.shape
    batch_size, kv_len, num_k_heads, _ = k.shape
    group_size = num_q_heads // num_k_heads
    scale = 1.0 / math.sqrt(head_dim) if sm_scale is None else sm_scale
    # input is too short to have a valid block: zero output, single index 0
    if kv_len == 0:
        fallback_idx = torch.zeros(
            num_k_heads, batch_size, 1, device=q.device, dtype=torch.int32
        )
        return torch.zeros_like(q), fallback_idx
    # mark the compressed keys that really exist for each sequence (rest is padding)
    key_pos = torch.arange(
        kv_len, device=compress_seqlens.device, dtype=compress_seqlens.dtype
    )
    valid = compress_seqlens[:, None] > key_pos[None, :]
    # expose the (kv head, group) split on q and a broadcastable group axis on k / v
    q_grouped = rearrange(q, "b (h g) d -> b 1 h g d", g=group_size)
    k_grouped = rearrange(k, "b j h d -> b j h 1 d")
    v_grouped = rearrange(v, "b k h d -> b k h 1 d")
    # attention
    qk = torch.einsum("bihgd, bjhgd -> bhgij", q_grouped, k_grouped) * scale
    qk = qk.masked_fill_(~valid[:, None, None, None, :], -torch.inf)
    qk = qk.softmax(dim=-1, dtype=torch.float32)
    qk = qk.nan_to_num_(0)  # qk is nan when seqlen == 0
    attn_output = torch.einsum("bhgij, bjhgd -> bihgd", qk.to(v.dtype), v_grouped)
    attn_output = rearrange(attn_output, "b 1 h g d -> b (h g) d")
    # get score: sum over the query heads of a group, drop the length-1 query axis
    score = rearrange(qk.sum(2).squeeze(2), "b h j -> h b j")
    # transform score to block-wise score
    score = transform_score(
        score,
        seqlens,
        kernel_size,
        kernel_stride,
        block_size,
        init_blocks,
        local_blocks,
    )
    # get topk, indices sorted ascending
    last_block = (seqlens - 1) // block_size
    num_selected = min(topk, score.shape[-1])
    topk_idx = score.topk(num_selected, dim=-1).indices.sort(-1).values
    # blocks after the query's own block do not exist yet: mark as padding
    topk_idx[topk_idx > last_block[None, :, None]] = -1
    return attn_output, topk_idx.to(torch.int32)
================================================
FILE: native_sparse_attention/ops/torch/topk_sparse_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
from typing import Optional
def topk_sparse_attention_torch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    topk_idx: torch.Tensor,
    block_size_k: int,
    cu_seqlens: torch.Tensor,
    softmax_scale: Optional[float] = None,
    block_size_q: int = 1,
) -> torch.Tensor:
    """Simple topk sparse attention varlen version implemented in torch. Extremely slow, only for debugging.

    A dense [num_kv_heads, total_len, total_len] mask is materialised, so this
    is only usable for small inputs.

    Args:
        q (torch.Tensor): shape [total_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
        topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk].
            -1 means padding and selects nothing. Every query must keep at least one
            selected block that is causally visible, otherwise its output row is nan.
        block_size_k (int): key value block size.
        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen.
        softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim).
        block_size_q (int): query block size. Defaults to 1.

    Returns:
        torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim]
    """
    total_seqlen, num_q_heads, head_dim = q.shape
    total_seqlen, num_kv_heads, head_dim = k.shape
    num_share_q_heads = num_q_heads // num_kv_heads
    device = q.device
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)
    # get mask
    mask = torch.zeros(
        (num_kv_heads, total_seqlen, total_seqlen), dtype=torch.bool, device=device
    )
    bounds = cu_seqlens.tolist()
    q_block_start = 0  # offset of the current sequence along topk_idx's query-block axis
    for seq_start, seq_end in zip(bounds[:-1], bounds[1:]):
        seqlen = seq_end - seq_start
        num_q_blocks = -(-seqlen // block_size_q)  # integer ceil division
        num_kv_blocks = -(-seqlen // block_size_k)
        for h in range(num_kv_heads):
            blocks = topk_idx[h, q_block_start : q_block_start + num_q_blocks].long()
            rows = torch.arange(num_q_blocks, device=device)[:, None].expand_as(blocks)
            # padding entries (-1) must not select anything; previously they
            # were clamped to 0 and silently enabled block 0
            selected = blocks >= 0
            block_mask = torch.zeros(
                num_q_blocks, num_kv_blocks, dtype=torch.bool, device=device
            )
            block_mask[rows[selected], blocks[selected]] = True
            # expand the block-level mask to token level and trim the tail
            token_mask = torch.repeat_interleave(block_mask, block_size_q, dim=0)
            token_mask = torch.repeat_interleave(token_mask, block_size_k, dim=1)
            mask[h, seq_start:seq_end, seq_start:seq_end] = token_mask[
                :seqlen, :seqlen
            ]
        q_block_start += num_q_blocks
    # causal constraint (lower triangle), then one copy per query head of a group
    positions = torch.arange(total_seqlen, device=device)
    mask = mask & (positions[:, None] >= positions[None, :])
    mask = mask.repeat_interleave(num_share_q_heads, 0)
    # qk attn
    qk = (
        torch.einsum("qhd,khd->hqk", q, k.repeat_interleave(num_share_q_heads, 1))
        * softmax_scale
    )
    qk = torch.masked_fill(qk, ~mask, -torch.inf)
    qk = torch.softmax(qk, dim=-1, dtype=torch.float32).to(q.dtype)
    o = torch.einsum("hqk,khd->qhd", qk, v.repeat_interleave(num_share_q_heads, 1))
    return o
================================================
FILE: native_sparse_attention/ops/triton/__init__.py
================================================
================================================
FILE: native_sparse_attention/ops/triton/compressed_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Tuple, Union
from collections import Counter
import torch
import triton
import triton.language as tl
import warnings
from native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu
# Evaluated once, at import time, by calling the helper imported from
# native_sparse_attention.ops.triton.utils.
# NOTE(review): presumably reports whether the current device is a Hopper-class
# GPU; the helper's body is not visible here, confirm in utils.py.
IS_HOPPER_GPU = is_hopper_gpu()
@triton.jit
def forward_kernel(
    q_ptr,  # Q: n x h x d
    k_ptr,  # K: n x h x d
    v_ptr,  # V: n x h x d
    o_ptr,  # O: n x h x d
    lse_ptr,  # LSE: h x n
    # size and stride at compression
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_on,
    stride_oh,
    stride_od,
    stride_lh,
    stride_ln,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    # Forward pass of attention between queries and compressed keys/values,
    # computed tile by tile with an online (streaming) softmax.
    #
    # Program grid: axis 0 = sequence in the batch, axis 1 = query head,
    # axis 2 = query tile of BLOCK_SIZE_Q rows. Each program writes one tile of
    # O and the matching entries of LSE.
    #
    # Query rows 0 .. kernel_size - 2 of every sequence are never touched by
    # this kernel (tiles start at row kernel_size - 1), so their O / LSE values
    # are whatever the buffers held before the launch.
    # NOTE(review): presumably the caller pre-initialises those buffers - confirm.

    # 1.44269504 ~= log2(e): scaling the logits by it lets the kernel use
    # exp2 / log2 in place of exp / log
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_q = tl.program_id(2)
    # kv head shared by this query head (grouped-query attention)
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # skip first kernel_size query block, because they do no attend to any keys
    # (tile 0 therefore starts at in-sequence row kernel_size - 1)
    q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
    if q_start_in_seq >= q_len:
        return
    # init qkv pointer
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_qn, stride_qd),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # K is addressed transposed, as (HEAD_DIM, k_len), so tl.dot(q, k) yields q @ k^T
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # load q (out-of-range rows / columns are read as zero)
    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init statistics
    # off_q: in-sequence position of every query row of this tile
    off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
    # off_k: for compressed key j of a tile starting at compressed index 0, the
    # last original position its window covers (j * kernel_stride + kernel_size - 1)
    off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
    m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)  # running row max
    lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)  # running log2-sum-exp2
    acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32)  # unnormalised output
    # attention
    lo = 0
    # number of compressed keys whose window ends at or before the last query
    # row of this tile (position q_start_in_seq + BLOCK_SIZE_Q - 1), capped at k_len
    hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
    for i in range(lo, hi, BLOCK_SIZE_K):
        i = tl.multiple_of(i, BLOCK_SIZE_K)
        # load k
        k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        # causal mask: a query sees a compressed key only if the key's window
        # ends at or before the query position; other entries become -inf
        qk += tl.where(
            off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
        )
        qk += tl.dot(q, k) * qk_scale
        # compute m_ij and l_ij
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp2(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        # scale acc_o (rescale what was accumulated under the previous max)
        acc_o_scale = tl.exp2(m_i - m_ij)
        acc_o = acc_o * acc_o_scale[:, None]
        # load v and update acc_o
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # update statistics
        m_i = m_ij
        lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
        # update ptrs
        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
    # final scale: divide by the softmax denominator, exp2(lse_i - m_i)
    acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
    # save output
    o_ptrs = tl.make_block_ptr(
        base=o_ptr + q_start * stride_on + pid_h * stride_oh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_on, stride_od),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
    # save lse (kept in the base-2 domain of the scaled logits; rows past q_len are masked)
    l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
    tl.store(l_ptrs, lse_i, mask=off_q < q_len)
@triton.jit
def backward_sum_o_do(
    o_ptr,  # O: n x h x d
    do_ptr,  # dO: n x h x d
    delta_ptr,  # D: h x n
    o_len,
    HEAD_DIM,
    stride_on,
    stride_oh,
    stride_od,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dh,
    stride_dn,
    BLOCK_SIZE_O: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    # Computes delta[h, n] = sum_d O[n, h, d] * dO[n, h, d] in float32.
    #
    # Program grid: axis 0 = tile of BLOCK_SIZE_O rows of O, axis 1 = head.
    # The whole head dimension is read as a single tile of BLOCK_SIZE_D columns.
    # NOTE(review): only columns < BLOCK_SIZE_D are read, so this assumes
    # BLOCK_SIZE_D >= HEAD_DIM - confirm at the launch site.
    pid_n = tl.program_id(0)
    pid_h = tl.program_id(1)
    off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    off_d = tl.arange(0, BLOCK_SIZE_D)
    # rows >= o_len and columns >= HEAD_DIM are read as 0, so they add nothing
    # to the row sums
    o = tl.load(
        o_ptr
        + off_n[:, None] * stride_on
        + pid_h * stride_oh
        + off_d[None, :] * stride_od,
        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    do = tl.load(
        do_ptr
        + off_n[:, None] * stride_don
        + pid_h * stride_doh
        + off_d[None, :] * stride_dod,
        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    # row-wise dot product of O and dO
    delta = tl.sum(o * do, axis=1)
    tl.store(
        delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len
    )
@triton.jit
def backward_dkdv(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dk_ptr,  # DK: sh x n x kh x d
    dv_ptr,  # DV: sh x n x kh x d
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dks,
    stride_dkn,
    stride_dkh,
    stride_dkd,
    stride_dvs,
    stride_dvn,
    stride_dvh,
    stride_dvd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    """Accumulate dK and dV for one (sequence, query head, key block) program.

    Program axes: 0 = sequence index in the varlen batch, 1 = query head,
    2 = block of BLOCK_SIZE_K compressed keys. The K/V tile is loaded once
    and the kernel loops over every query block that can see it.

    Each query head writes into its own slice of dk/dv, selected by
    ``pid_sh`` along the leading dimension. The per-head partial gradients
    are therefore not combined in this kernel.

    Compressed key ``j`` is treated as visible to the query at position
    ``t`` only if ``t >= j * kernel_stride + kernel_size - 1``.
    """
    # 1.44269504 ~= log2(e); folding it into the scale lets exp2 stand in for exp
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    # kv head serving this query head, and the query head's index inside its group
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    pid_sh = pid_h % NUM_SHARE_Q_HEADS
    pid_k = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # this key block lies entirely past the end of the sequence
    if BLOCK_SIZE_K * pid_k >= k_len:
        return
    # init pointers
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_kn, stride_kd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dk_ptrs = tl.make_block_ptr(
        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
        shape=(k_len, HEAD_DIM),
        strides=(stride_dkn, stride_dkd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dv_ptrs = tl.make_block_ptr(
        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
        shape=(k_len, HEAD_DIM),
        strides=(stride_dvn, stride_dvd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q)
    # smallest query position allowed to attend each compressed key of this block
    off_k = (
        pid_k * BLOCK_SIZE_K * kernel_stride
        + tl.arange(0, BLOCK_SIZE_K) * kernel_stride
        + kernel_size
        - 1
    )
    # load k and v once; they are reused for every query block
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dk dv
    dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
    # first query position that can see the first key of this block
    q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1
    # q and do are addressed transposed, as [HEAD_DIM, q_len]
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(HEAD_DIM, q_len),
        strides=(stride_qd, stride_qn),
        offsets=(0, q_lo),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
        order=(0, 1),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(HEAD_DIM, q_len),
        strides=(stride_dod, stride_don),
        offsets=(0, q_lo),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
        order=(0, 1),
    )
    # delta and lse are read as [1, BLOCK_SIZE_Q] rows so they broadcast over keys
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(1, q_len),
        strides=(0, stride_dn),
        offsets=(0, q_lo),
        block_shape=(1, BLOCK_SIZE_Q),
        order=(1, 0),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(1, q_len),
        strides=(0, stride_ln),
        offsets=(0, q_lo),
        block_shape=(1, BLOCK_SIZE_Q),
        order=(0, 1),
    )
    # loop for q blocks
    for i in range(q_lo, q_len, BLOCK_SIZE_Q):
        # load; out-of-range queries are zero-padded
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk
        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
        # causal mask: -inf wherever the query position is below the key's threshold
        qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf"))
        qk += tl.dot(k, q) * qk_scale
        # compute p, ds
        # [BLOCK_SIZE_K, BLOCK_SIZE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
        p = tl.exp2(qk - lse)
        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
        dp = tl.dot(v, do)
        ds = sm_scale * p * (dp - d)
        # cast dtype so the dot products below run in the input precision
        p = p.to(do.dtype)
        ds = ds.to(q.dtype)
        # update dk and dv
        # [BLOCK_SIZE_K, BLOCK_SIZE_Q] @ [BLOCK_SIZE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM]
        dk += tl.dot(ds, tl.trans(q))
        dv += tl.dot(p, tl.trans(do))
        # increment pointers
        q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q))
        do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q))
        lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q))
        d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q))
    # save dk dv
    tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def backward_dq(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dq_ptr,
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dqn,
    stride_dqh,
    stride_dqd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    """Compute dQ for one (sequence, query head, query block) program.

    Program axes: 0 = sequence index in the varlen batch, 1 = query head,
    2 = query block. Query blocks start at position ``kernel_size - 1``, so
    earlier rows of dq are never written by this kernel.

    The Q/dO/LSE/Delta tile is loaded once and the kernel loops over the
    compressed keys this block can see. Key ``j`` is visible to the query at
    position ``t`` only if ``t >= j * kernel_stride + kernel_size - 1``.
    """
    # 1.44269504 ~= log2(e); folding it into the scale lets exp2 stand in for exp
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_q = tl.program_id(2)
    # kv head shared by this query head
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # skip the first kernel_size - 1 queries: they cannot attend to any key
    q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
    if q_start_in_seq >= q_len:
        return
    # init pointers
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_qn, stride_qd),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dq_ptrs = tl.make_block_ptr(
        base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_dqn, stride_dqd),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_kn, stride_kd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # v is addressed transposed, as [HEAD_DIM, k_len], so do @ v needs no tl.trans
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_vd, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(q_len, HEAD_DIM),
        strides=(stride_don, stride_dod),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # delta and lse are read as [BLOCK_SIZE_Q, 1] columns so they broadcast over keys
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(q_len, 1),
        strides=(stride_dn, stride_dh),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(q_len, 1),
        strides=(stride_ln, stride_lh),
        offsets=(q_start_in_seq, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
    # smallest query position allowed to attend each key of the first key block
    off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
    # load q, do, lse, delta once; they are reused for every key block
    q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
    do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
    lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
    d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dq
    dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32)
    lo = 0
    # one past the last key index visible to the last query of this block
    hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
    for i in range(lo, hi, BLOCK_SIZE_K):
        # load
        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk; masked positions start at -inf so exp2 maps them to 0
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.where(
            off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
        )
        qk += tl.dot(q, tl.trans(k)) * qk_scale
        # compute p, ds
        p = tl.exp2(qk - lse)
        dp = tl.dot(do, v)
        ds = sm_scale * p * (dp - d)
        # cast dtype so the dot product below runs in the input precision
        ds = ds.to(q.dtype)
        # update dq
        dq += tl.dot(ds, k)
        # increment pointers
        k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))
        v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K))
    # save dq
    tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
def _compressed_attention_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kernel_size: int,
kernel_stride: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
sm_scale: float,
):
# dtype check
assert k.dtype == q.dtype and v.dtype == q.dtype
assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
# shape
q_len, num_q_heads, head_dim = q.shape
k_len, num_k_heads, head_dim = k.shape
v_len, num_v_heads, head_dim = v.shape
batch_size = cu_seqlens_q.shape[0] - 1
assert k_len == v_len and q_len > k_len
# gqa
assert num_k_heads == num_v_heads
assert num_q_heads % num_k_heads == 0
num_share_q_heads = num_q_heads // num_k_heads
# output tensor
o = torch.zeros_like(q)
lse = torch.full(
(num_q_heads, q_len),
fill_value=-torch.inf,
dtype=torch.float32,
device=q.device,
)
# launch kernel
grid = lambda META: (
batch_size,
num_q_heads,
triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
)
BLOCK_SIZE_Q = 128
BLOCK_SIZE_K = 128
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
forward_kernel[grid](
q,
k,
v,
o,
lse,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
num_k_heads,
num_share_q_heads,
head_dim,
sm_scale,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
lse.stride(0),
lse.stride(1),
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_D=BLOCK_SIZE_D,
num_warps=num_warps,
num_stages=num_stages,
)
return o, lse
def _compressed_attention_bwd(
    o: torch.Tensor,
    do: torch.Tensor,
    lse: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
):
    """Run the three backward kernels and return ``(dq, dk, dv)``.

    The kernels run in this order:
      1. ``backward_sum_o_do`` fills ``delta[h, n] = sum_d o * do``.
      2. ``backward_dkdv`` writes one dk/dv slice per query head of a GQA
         group; the slices are summed over the leading axis afterwards.
      3. ``backward_dq`` fills dq.
    """
    # every tensor must be 3-D; head_dim ends up being the one read from o
    q_len, num_q_heads, _ = q.shape
    k_len, num_k_heads, _ = k.shape
    _, _, _ = v.shape
    o_len, num_o_heads, head_dim = o.shape
    heads_per_group = num_q_heads // num_k_heads
    batch_size = cu_seqlens_q.shape[0] - 1
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)

    # ---- step 1: delta ----
    delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
    delta_grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads)
    delta_block_o = 256
    warps, stages = get_num_warps_stages(head_dim, delta_block_o, IS_HOPPER_GPU)
    backward_sum_o_do[delta_grid](
        o,
        do,
        delta,
        o_len,
        head_dim,
        o.stride(0), o.stride(1), o.stride(2),
        do.stride(0), do.stride(1), do.stride(2),
        delta.stride(0), delta.stride(1),
        BLOCK_SIZE_O=delta_block_o,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=warps,
        num_stages=stages,
    )

    # ---- step 2: dk / dv, one slice per query head of the group ----
    partial_shape = (heads_per_group, k_len, num_k_heads, head_dim)
    dk = torch.zeros(partial_shape, device=k.device, dtype=k.dtype)
    dv = torch.zeros(partial_shape, device=k.device, dtype=k.dtype)
    dkdv_grid = lambda META: (
        batch_size,
        num_q_heads,
        triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
    )
    dkdv_block_q, dkdv_block_k = 64, 128
    warps, stages = get_num_warps_stages(head_dim, dkdv_block_k, IS_HOPPER_GPU)
    backward_dkdv[dkdv_grid](
        q,
        k,
        v,
        lse,
        delta,
        do,
        dk,
        dv,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        heads_per_group,
        head_dim,
        sm_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        lse.stride(0), lse.stride(1),
        delta.stride(0), delta.stride(1),
        do.stride(0), do.stride(1), do.stride(2),
        dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
        dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
        BLOCK_SIZE_Q=dkdv_block_q,
        BLOCK_SIZE_K=dkdv_block_k,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=warps,
        num_stages=stages,
    )
    # fold the per-query-head partial gradients into one gradient per kv head
    dk = dk.sum(0)
    dv = dv.sum(0)

    # ---- step 3: dq ----
    dq = torch.zeros_like(q)
    dq_grid = lambda META: (
        batch_size,
        num_q_heads,
        triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
    )
    dq_block_q, dq_block_k = 128, 64
    warps, stages = get_num_warps_stages(head_dim, dq_block_q, IS_HOPPER_GPU)
    backward_dq[dq_grid](
        q,
        k,
        v,
        lse,
        delta,
        do,
        dq,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        heads_per_group,
        head_dim,
        sm_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        lse.stride(0), lse.stride(1),
        delta.stride(0), delta.stride(1),
        do.stride(0), do.stride(1), do.stride(2),
        dq.stride(0), dq.stride(1), dq.stride(2),
        BLOCK_SIZE_Q=dq_block_q,
        BLOCK_SIZE_K=dq_block_k,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=warps,
        num_stages=stages,
    )
    return dq, dk, dv
class CompressedAttention(torch.autograd.Function):
    """Autograd wrapper around the compressed-attention forward/backward launchers."""

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kernel_size: int,
        kernel_stride: int,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        sm_scale=None,
    ):
        """Validate inputs, run the forward launcher and stash state for backward.

        Returns ``(o, lse)`` exactly as produced by ``_compressed_attention_fwd``.
        """
        # only half-precision inputs are accepted, and q/k/v must agree
        assert q.dtype in (torch.bfloat16, torch.float16)
        assert q.dtype == k.dtype == v.dtype
        assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
        # default softmax scale is 1 / sqrt(head_dim)
        if sm_scale is None:
            sm_scale = 1 / math.sqrt(q.shape[-1])
        out, log_sum_exp = _compressed_attention_fwd(
            q,
            k,
            v,
            kernel_size,
            kernel_stride,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )
        # tensors go through save_for_backward, plain Python values onto ctx
        ctx.save_for_backward(q, k, v, out, log_sum_exp, cu_seqlens_q, cu_seqlens_k)
        ctx.sm_scale = sm_scale
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.kernel_size = kernel_size
        ctx.kernel_stride = kernel_stride
        return out, log_sum_exp

    @staticmethod
    def backward(ctx, do: torch.Tensor, *args) -> Any:
        """Return gradients for q, k, v and None for the seven other inputs.

        Only ``do`` (the gradient of ``o``) is used; any further incoming
        gradients collected in ``*args`` are ignored.
        """
        q, k, v, out, log_sum_exp, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        dq, dk, dv = _compressed_attention_bwd(
            out,
            do,
            log_sum_exp,
            q,
            k,
            v,
            ctx.kernel_size,
            ctx.kernel_stride,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.sm_scale,
        )
        # forward takes 10 inputs after ctx; only the first three get gradients
        return (dq, dk, dv) + (None,) * 7
@triton.jit
def score_kernel(
    q_ptr,
    k_ptr,
    lse_ptr,
    s_ptr,
    kernel_size,
    kernel_stride,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_lh,
    stride_ln,
    stride_sh,
    stride_sq,
    stride_sk,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    """Recompute attention probabilities and sum them over a GQA group.

    Program axis 0 enumerates (sequence, kv head) pairs, axis 1 the query
    block and axis 2 the key block. For each query head that shares the kv
    head, the tile ``exp2(q @ k * qk_scale - lse)`` is computed and
    accumulated; entries that fail the causal test
    ``q_pos >= k_idx * kernel_stride + kernel_size - 1`` contribute zero.
    The summed tile is stored into ``s_ptr`` at [kv head, query, key].
    """
    # 1.44269504 ~= log2(e); folding it into the scale lets exp2 stand in for exp
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_bkh = tl.program_id(0)
    pid_b = pid_bkh // NUM_KV_HEADS
    pid_kh = pid_bkh % NUM_KV_HEADS
    pid_q = tl.program_id(1)
    pid_k = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # nothing to do when the tile lies entirely outside this sequence
    if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:
        return
    # init k pointer and load k (addressed transposed, as [HEAD_DIM, k_len])
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(HEAD_DIM, k_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, pid_k * BLOCK_SIZE_K),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
    # a query may only see keys whose threshold position is not ahead of it
    causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :]
    # init score
    s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
    # loop over gqa heads
    for h in range(NUM_SHARE_Q_HEADS):
        pid_h = pid_kh * NUM_SHARE_Q_HEADS + h
        q_ptrs = tl.make_block_ptr(
            base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
            shape=(q_len, HEAD_DIM),
            strides=(stride_qn, stride_qd),
            offsets=(pid_q * BLOCK_SIZE_Q, 0),
            block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
            order=(1, 0),
        )
        # lse is read as a [BLOCK_SIZE_Q, 1] column so it broadcasts over keys
        lse_ptrs = tl.make_block_ptr(
            base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
            shape=(q_len, 1),
            strides=(stride_ln, stride_lh),
            offsets=(pid_q * BLOCK_SIZE_Q, 0),
            block_shape=(BLOCK_SIZE_Q, 1),
            order=(0, 1),
        )
        # load q and lse
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.dot(q, k) * qk_scale
        # compute score; the mask is applied after exp2, so masked entries are 0
        s += tl.where(causal_mask, tl.exp2(qk - lse), 0)
    # save output
    s_ptrs = tl.make_block_ptr(
        base=s_ptr + pid_kh * stride_sh + q_start * stride_sq,
        shape=(q_len, k_len),
        strides=(stride_sq, stride_sk),
        offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K),
        order=(1, 0),
    )
    tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1))
def _get_attention_score(
q: torch.Tensor, # [total_query_len, num_q_heads, head_dim]
k: torch.Tensor, # [total_key_len, num_k_heads, head_dim]
lse: torch.Tensor, # [num_q_heads, total_query_len]
kernel_size: int,
kernel_stride: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
sm_scale: float,
) -> torch.Tensor:
# dtype check
assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
assert q.dtype == k.dtype
assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
assert (
lse.dtype == torch.float32
) # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale)))
# shape
q_len, num_q_heads, head_dim = q.shape
k_len, num_k_heads, head_dim = k.shape
batch_size = cu_seqlens_q.shape[0] - 1
assert q_len > k_len
if sm_scale is None:
sm_scale = 1 / math.sqrt(head_dim)
# gqa
assert num_q_heads % num_k_heads == 0
num_share_q_heads = num_q_heads // num_k_heads
# init score
score = torch.zeros(
num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device
)
# launch kernel
grid = lambda META: (
batch_size * num_k_heads,
triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
)
BLOCK_SIZE_Q = 128
BLOCK_SIZE_K = 128
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
score_kernel[grid](
q,
k,
lse,
score,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
num_k_heads,
num_share_q_heads,
head_dim,
sm_scale,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
lse.stride(0),
lse.stride(1),
score.stride(0),
score.stride(1),
score.stride(2),
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_D=BLOCK_SIZE_D,
num_warps=8,
num_stages=3,
)
return score
@triton.jit
def _transform_score_kernel(
    s_ptr,  # score, shape: [num_heads, q_len, k_len]
    bs_ptr,  # block wise score: [num_heads, q_len, num_k_block]
    offs,
    cu_seqlens_q,
    # shape
    num_heads,
    num_offs,
    max_k_len,
    max_blocks,
    pad_len,
    # kernel & block size
    block_size,
    block_stride,  # block_size // kernel_stride
    init_blocks,
    local_blocks,
    # stride
    stride_sh,
    stride_sq,
    stride_sk,
    stride_bsh,
    stride_bsq,
    stride_bsk,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_O: tl.constexpr,
):
    """Pool per-key scores into per-block scores and pin always-selected blocks.

    Program axis 0 enumerates (sequence, head) pairs, axis 1 the query tile
    and axis 2 the tile of output blocks. For output block ``b`` the kernel
    reads the ``num_offs`` consecutive score columns starting at
    ``b * block_stride - pad_len`` and sums them weighted by ``offs``;
    columns outside ``[0, max_k_len)`` are read as zero.

    The pooled value is then replaced by +inf for the first ``init_blocks``
    blocks and for the ``local_blocks`` blocks ending at the query's own
    block (``query_pos // block_size``).
    """
    pid_bh = tl.program_id(0)
    pid_b = pid_bh // num_heads
    pid_h = pid_bh % num_heads
    pid_q = tl.program_id(1)
    pid_k = tl.program_id(2)
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    # index of the first output block handled by this program
    k_start = pid_k * BLOCK_SIZE_K
    if pid_q * BLOCK_SIZE_Q >= q_len:
        return
    # load weight; lanes beyond num_offs get weight 0
    off_o = tl.arange(0, BLOCK_SIZE_O)
    w = tl.load(offs + off_o, mask=off_o < num_offs, other=0)
    # load score
    off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    # first score column contributing to each output block (may be negative)
    off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len
    # [BO, BK] grid of score columns: one row per weight lane
    off_k = off_k[None, :] + off_o[:, None]
    s_ptrs = (
        s_ptr
        + q_start * stride_sq
        + pid_h * stride_sh
        + off_q[:, None, None] * stride_sq
        + off_k[None, :, :] * stride_sk
    )
    # weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK]
    s = tl.load(
        s_ptrs,
        mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len),
        other=0,
    )
    s = s * w[None, :, None]
    s = tl.sum(s, axis=1)
    # init mask and local mask
    off_bq = off_q // block_size
    off_bk = k_start + tl.arange(0, BLOCK_SIZE_K)
    # +inf ranks these blocks above every finite score
    s = tl.where(
        (
            (off_bq[:, None] >= off_bk[None, :])  # causal mask
            & (off_bq[:, None] < off_bk[None, :] + local_blocks)  # local window
        )
        | (off_bk[None, :] < init_blocks),  # init window
        float("inf"),
        s,
    )
    # store block wise score
    bs_ptrs = (
        bs_ptr
        + q_start * stride_bsq
        + pid_h * stride_bsh
        + off_q[:, None] * stride_bsq
        + off_bk[None, :] * stride_bsk
    )
    tl.store(
        bs_ptrs,
        s,
        mask=(off_q < q_len)[:, None] & (off_bk < max_blocks)[None, :],
    )
def transform_score(
    score: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    init_blocks: int = 1,
    local_blocks: int = 2,
) -> torch.Tensor:
    """Convert per-compressed-key scores into per-selection-block scores.

    ``score`` has shape [num_k_heads, total_query_len, max_key_len]. The
    result is float32 with shape
    [num_k_heads, total_query_len, ceil(max_seqlen_q / block_size)].
    ``cu_seqlens_k`` and ``max_seqlen_k`` are accepted but not used.
    """
    device = score.device
    num_k_heads, total_query_len, max_key_len = score.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    kernels_per_window = kernel_size // kernel_stride
    block_stride = block_size // kernel_stride
    pad_len = kernels_per_window - 1
    max_blocks = math.ceil(max_seqlen_q / block_size)
    block_score = torch.zeros(
        (num_k_heads, total_query_len, max_blocks),
        dtype=torch.float32,
        device=device,
    )
    # pooling weights: how many (i, j) pairs produce each sum i + j, for
    # i in [0, kernel_size // kernel_stride) and j in [0, block_size // kernel_stride)
    pair_sums = (
        torch.arange(kernels_per_window, device=device)[:, None]
        + torch.arange(block_stride, device=device)[None, :]
    ).view(-1)
    offs = torch.histc(
        pair_sums, bins=pair_sums.max() + 1, min=0, max=pair_sums.max()
    )
    num_offs = int(offs.shape[0])
    # tile sizes; the key and weight tiles are padded to powers of two
    tile_q = 8
    tile_k = min(128, triton.next_power_of_2(max_blocks))
    tile_o = triton.next_power_of_2(num_offs)
    launch_grid = (
        num_k_heads * batch_size,
        triton.cdiv(total_query_len, tile_q),
        triton.cdiv(max_blocks, tile_k),
    )
    _transform_score_kernel[launch_grid](
        score,
        block_score,
        offs,
        cu_seqlens_q,
        num_k_heads,
        num_offs,
        max_key_len,
        max_blocks,
        pad_len,
        block_size,
        block_stride,
        init_blocks,
        local_blocks,
        score.stride(0), score.stride(1), score.stride(2),
        block_score.stride(0), block_score.stride(1), block_score.stride(2),
        BLOCK_SIZE_Q=tile_q,
        BLOCK_SIZE_K=tile_k,
        BLOCK_SIZE_O=tile_o,
        num_warps=8,
        num_stages=3,
    )
    return block_score
def _select_topk_block_idx(
    q: torch.Tensor,
    k: torch.Tensor,
    lse: torch.Tensor,
    q_block_idx: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
    init_blocks: int,
    local_blocks: int,
) -> torch.Tensor:
    """Return sorted int32 top-k block indices for the given heads, -1 for future blocks."""
    # recompute score
    score = _get_attention_score(
        q,
        k,
        lse,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        sm_scale,
    )
    # transform score to block-wise score
    score = transform_score(
        score,
        kernel_size,
        kernel_stride,
        block_size,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        init_blocks,
        local_blocks,
    )
    # never ask for more blocks than exist
    topk = min(topk, score.shape[-1])
    topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
    # a block that starts after the query's own block is not selectable
    topk_idx[topk_idx > q_block_idx[None, :, None]] = -1
    return topk_idx.to(torch.int32)


def compressed_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int = None,
    max_seqlen_k: int = None,
    sm_scale: float = None,
    init_blocks: int = 1,
    local_blocks: int = 2,
    parallel_topk_compute: Union[str, bool] = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.

    Args:
        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        kernel_size (int): kernel size in compress_key_value
        kernel_stride (int): stride of compress_key_value
        block_size (int): key value block size for topk sparse attention.
        topk (int): number of blocks for each query. If <= 0, no block selection is done.
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.
        max_seqlen_q (int): max q len of the batch. Defaults to None, means (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item().
        max_seqlen_k (int): max k len of the batch. Defaults to None, means (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item().
        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
        init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
        local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
        parallel_topk_compute (str, optional): Whether to select blocks for all kv heads in one pass (True) or one kv head at a time (False).
            Only set it to False when the sequence length is too long. This can avoid a current bug.
            Defaults to "auto": True when the total number of query tokens, cu_seqlens_q[-1], is at most 32768, and False otherwise.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention.
            topk_idx is int32 with shape [num_kv_heads, total_q_len, min(topk, num_blocks)], sorted ascending,
            with -1 marking unusable slots. topk_idx is None when topk <= 0.
    """
    if max_seqlen_q is None:
        max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
    if max_seqlen_k is None:
        max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
    attn_output, lse = CompressedAttention.apply(
        q,
        k,
        v,
        kernel_size,
        kernel_stride,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        sm_scale,
    )
    # do not select topk index
    if topk <= 0:
        warnings.warn("topk <= 0, returned topk_idx will be None")
        return attn_output, None
    assert topk >= init_blocks + local_blocks
    # block selection is index computation only; keep it out of autograd
    with torch.no_grad():
        num_k_heads, num_q_heads = k.shape[1], q.shape[1]
        num_shared_q_heads = num_q_heads // num_k_heads
        batch_size = cu_seqlens_q.shape[0] - 1
        # position of every query inside its own sequence ...
        q_idx = torch.cat(
            [
                torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device)
                for i in range(batch_size)
            ],
            dim=0,
        )
        # ... turned into the index of the block that contains it
        q_idx = q_idx // block_size
        # whether to use parallel version
        if parallel_topk_compute == "auto":
            parallel_topk_compute = cu_seqlens_q[-1] <= 32768
        if parallel_topk_compute:
            # parallel version: all kv heads in one pass
            topk_idx = _select_topk_block_idx(
                q,
                k,
                lse,
                q_idx,
                kernel_size,
                kernel_stride,
                block_size,
                topk,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                sm_scale,
                init_blocks,
                local_blocks,
            )
        else:
            # non parallel version, avoid some current bugs when sequence length is too long
            # FIXME: need to fix later
            topk_idx_list = []
            for h in range(num_k_heads):
                # query heads [head_lo, head_hi) share kv head h
                head_lo = h * num_shared_q_heads
                head_hi = head_lo + num_shared_q_heads
                topk_idx_list.append(
                    _select_topk_block_idx(
                        q[:, head_lo:head_hi],
                        k[:, h : h + 1],
                        lse[head_lo:head_hi],
                        q_idx,
                        kernel_size,
                        kernel_stride,
                        block_size,
                        topk,
                        cu_seqlens_q,
                        cu_seqlens_k,
                        max_seqlen_q,
                        max_seqlen_k,
                        sm_scale,
                        init_blocks,
                        local_blocks,
                    )
                )
            topk_idx = torch.cat(topk_idx_list, dim=0)
    return attn_output, topk_idx
================================================
FILE: native_sparse_attention/ops/triton/flash_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Optional
import torch
import triton
import triton.language as tl
from native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu
IS_HOPPER_GPU = is_hopper_gpu()
@triton.jit
def forward_kernel(
    q_ptr,  # Q: n x h x d
    k_ptr,  # K: n x h x d
    v_ptr,  # V: n x h x d
    o_ptr,  # O: n x h x d
    lse_ptr,  # LSE: h x n
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    qk_head_dim,
    v_head_dim,
    # sm_scale
    sm_scale,
    # causal
    causal,
    # gqa
    gqa_interleave,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_on,
    stride_oh,
    stride_od,
    stride_lh,
    stride_ln,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
):
    """Varlen attention forward pass with an online (streaming) softmax.

    Program axes: 0 = sequence index in the varlen batch, 1 = query head,
    2 = query block. Q/K and V may use different head dimensions
    (``qk_head_dim`` / ``v_head_dim``).

    ``gqa_interleave`` selects how a query head maps to its kv head:
    ``pid_h % NUM_KV_HEADS`` when set, ``pid_h // NUM_SHARE_Q_HEADS``
    otherwise.

    The stored ``lse`` is in base 2: the scores are scaled by
    ``sm_scale * log2(e)`` and reduced with exp2/log2.
    """
    # 1.44269504 ~= log2(e); folding it into the scale lets exp2 stand in for exp
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_q = tl.program_id(2)
    if gqa_interleave:
        pid_kh = pid_h % NUM_KV_HEADS
    else:
        pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # this query block lies entirely past the end of the sequence
    if BLOCK_SIZE_Q * pid_q >= q_len:
        return
    # init qkv pointer
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, qk_head_dim),
        strides=(stride_qn, stride_qd),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    # k is addressed transposed, as [qk_head_dim, k_len], so q @ k needs no tl.trans
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(qk_head_dim, k_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_K),
        order=(0, 1),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, v_head_dim),
        strides=(stride_vn, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # load q
    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init statistics: running row max, running log2-sum-exp, output accumulator
    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
    off_k = tl.arange(0, BLOCK_SIZE_K)
    m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
    lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
    acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_VD), 0, dtype=tl.float32)
    # full attention or causal attention
    lo = 0
    if causal:
        # keys past the last query position of this block are never visible
        hi = min(k_len, (pid_q + 1) * BLOCK_SIZE_Q)
    else:
        hi = k_len
    for i in range(lo, hi, BLOCK_SIZE_K):
        # hint to the compiler that i is a multiple of BLOCK_SIZE_K
        i = tl.multiple_of(i, BLOCK_SIZE_K)
        # load k
        k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        if causal:
            # NOTE(review): only the causal test is applied here; key columns
            # at or beyond k_len are not masked separately. This looks safe
            # only when q_len <= k_len - confirm against callers.
            qk += tl.where(off_q[:, None] >= (i + off_k)[None, :], 0, float("-inf"))
        else:
            # mask key columns that fall beyond k_len in the last block
            qk += tl.where((off_k < k_len - i)[None, :], 0, float("-inf"))
        qk += tl.dot(q, k) * qk_scale
        # compute m_ij and l_ij
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.math.exp2(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        # scale acc_o: rebase the accumulator onto the new running max
        acc_o_scale = tl.math.exp2(m_i - m_ij)
        acc_o = acc_o * acc_o_scale[:, None]
        # load v and update acc_o
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # update statistics
        m_i = m_ij
        lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij)
        # update ptrs
        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
    # final scale: divide the accumulator by the softmax denominator
    acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None]
    # save output
    o_ptrs = tl.make_block_ptr(
        base=o_ptr + q_start * stride_on + pid_h * stride_oh,
        shape=(q_len, v_head_dim),
        strides=(stride_on, stride_od),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
    # save lse
    l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
    tl.store(l_ptrs, lse_i, mask=off_q < q_len)
@triton.jit
def backward_sum_o_do(
    o_ptr,  # O: n x h x d
    do_ptr,  # dO: n x h x d
    delta_ptr,  # D: h x n
    o_len,
    HEAD_DIM,
    stride_on,
    stride_oh,
    stride_od,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dh,
    stride_dn,
    BLOCK_SIZE_O: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Compute delta[h, n] = sum_d O[n, h, d] * dO[n, h, d].

    Grid axis 0 walks blocks of BLOCK_SIZE_O rows, grid axis 1 walks heads.
    Both inputs are promoted to float32 before the multiply and the row-wise
    reduction; rows >= o_len and columns >= HEAD_DIM are masked out.
    """
    row_block = tl.program_id(0)
    head = tl.program_id(1)
    # row / column indices covered by this program
    rows = row_block * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    cols = tl.arange(0, BLOCK_SIZE_D)
    # one validity mask is shared by the O and dO loads
    row_ok = rows < o_len
    tile_ok = row_ok[:, None] & (cols[None, :] < HEAD_DIM)
    # element addresses of the [BLOCK_SIZE_O, BLOCK_SIZE_D] tiles
    o_addr = (
        o_ptr
        + head * stride_oh
        + rows[:, None] * stride_on
        + cols[None, :] * stride_od
    )
    do_addr = (
        do_ptr
        + head * stride_doh
        + rows[:, None] * stride_don
        + cols[None, :] * stride_dod
    )
    o_tile = tl.load(o_addr, mask=tile_ok, other=0).to(tl.float32)
    do_tile = tl.load(do_addr, mask=tile_ok, other=0).to(tl.float32)
    # reduce over the head dimension and write one value per valid row
    row_sum = tl.sum(o_tile * do_tile, axis=1)
    delta_addr = delta_ptr + head * stride_dh + rows * stride_dn
    tl.store(delta_addr, row_sum, mask=row_ok)
@triton.jit
def backward_dkdv(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dk_ptr,  # DK: sh x n x kh x d
    dv_ptr,  # DV: sh x n x kh x d
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    qk_head_dim,
    v_head_dim,
    # sm_scale
    sm_scale,
    # causal
    causal,
    # gqa
    gqa_interleave,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dks,
    stride_dkn,
    stride_dkh,
    stride_dkd,
    stride_dvs,
    stride_dvn,
    stride_dvh,
    stride_dvd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
):
    """Accumulate dK / dV for one (sequence, query head, key block) tile.

    Grid: axis 0 = sequence in the batch, axis 1 = query head, axis 2 = key
    block. The program keeps its K/V tile loaded, iterates over the query
    blocks of the same sequence and accumulates dK / dV in float32. Results
    are written to slice `pid_sh` of DK / DV (leading `sh` dimension), so every
    query head that shares a KV head writes to its own slice.

    Head mapping:
      * gqa_interleave: kv head = q_head % NUM_KV_HEADS and
        share index = q_head // NUM_KV_HEADS (same kv-head rule as
        `backward_dq`).
      * otherwise: kv head = q_head // NUM_SHARE_Q_HEADS and
        share index = q_head % NUM_SHARE_Q_HEADS.
    """
    # exp2 is used instead of exp, so fold log2(e) into the softmax scale
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    if gqa_interleave:
        # interleaved GQA: consecutive q heads cycle through the kv heads,
        # so the modulus / divisor must be the number of kv heads
        pid_kh = pid_h % NUM_KV_HEADS
        pid_sh = pid_h // NUM_KV_HEADS
    else:
        pid_kh = pid_h // NUM_SHARE_Q_HEADS
        pid_sh = pid_h % NUM_SHARE_Q_HEADS
    pid_k = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # the grid is sized for the longest sequence; skip blocks past this one
    if BLOCK_SIZE_K * pid_k >= k_len:
        return
    # init pointers
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, qk_head_dim),
        strides=(stride_kn, stride_kd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    dk_ptrs = tl.make_block_ptr(
        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
        shape=(k_len, qk_head_dim),
        strides=(stride_dkn, stride_dkd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, v_head_dim),
        strides=(stride_vn, stride_vd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    dv_ptrs = tl.make_block_ptr(
        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
        shape=(k_len, v_head_dim),
        strides=(stride_dvn, stride_dvd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q)
    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
    # load the k / v tile once; it is reused for every q block
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dk dv
    dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_KD), dtype=tl.float32)
    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_VD), dtype=tl.float32)
    # causal: queries before the first key of this block cannot attend to it,
    # so the q loop starts at the block's first key position
    if causal:
        q_lo = pid_k * BLOCK_SIZE_K
    else:
        q_lo = 0
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, qk_head_dim),
        strides=(stride_qn, stride_qd),
        offsets=(q_lo, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(q_len, v_head_dim),
        strides=(stride_don, stride_dod),
        offsets=(q_lo, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # delta and lse are read as [BLOCK_SIZE_Q, 1] columns so they broadcast
    # against the [BLOCK_SIZE_Q, BLOCK_SIZE_K] score tile
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(q_len, 1),
        strides=(stride_dn, stride_dh),
        offsets=(q_lo, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(q_len, 1),
        strides=(stride_ln, stride_lh),
        offsets=(q_lo, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    # loop for q blocks
    for i in range(q_lo, q_len, BLOCK_SIZE_Q):
        # load
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk; the causal mask is added as 0 / -inf bias
        if causal:
            qk = tl.where(
                (off_q + i)[:, None] >= off_k[None, :], float(0.0), float("-inf")
            )
        else:
            qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.dot(q, k.T) * qk_scale
        # recompute the softmax probabilities from the saved log-sum-exp,
        # then ds = sm_scale * p * (dp - delta)
        p = tl.math.exp2(qk - lse)
        dp = tl.dot(do, v.T)
        ds = sm_scale * p * (dp - d)
        # cast dtype
        p = p.to(do.dtype)
        ds = ds.to(q.dtype)
        # update dk and dv
        dk += tl.dot(ds.T, q)
        dv += tl.dot(p.T, do)
        # increment pointers
        q_ptrs = tl.advance(q_ptrs, (BLOCK_SIZE_Q, 0))
        do_ptrs = tl.advance(do_ptrs, (BLOCK_SIZE_Q, 0))
        lse_ptrs = tl.advance(lse_ptrs, (BLOCK_SIZE_Q, 0))
        d_ptrs = tl.advance(d_ptrs, (BLOCK_SIZE_Q, 0))
    # save dk dv
    tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def backward_dq(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dq_ptr,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    qk_head_dim,
    v_head_dim,
    # sm_scale
    sm_scale,
    # causal
    causal,
    # gqa
    gqa_interleave,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dqn,
    stride_dqh,
    stride_dqd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
):
    """Compute dQ for one (sequence, query head, query block) tile.

    Grid: axis 0 = sequence in the batch, axis 1 = query head, axis 2 = query
    block. The program keeps its Q / dO / LSE / Delta tile loaded, iterates
    over the key blocks of the same sequence and accumulates dQ in float32.

    V and dO are described to the block pointers with `v_head_dim` columns
    (not `qk_head_dim`), so the boundary check zero-pads exactly the columns
    that do not exist when the two head dims differ.
    """
    # exp2 is used instead of exp, so fold log2(e) into the softmax scale
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_q = tl.program_id(2)
    if gqa_interleave:
        pid_kh = pid_h % NUM_KV_HEADS
    else:
        pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # the grid is sized for the longest sequence; skip blocks past this one
    if BLOCK_SIZE_Q * pid_q >= q_len:
        return
    # init pointers
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, qk_head_dim),
        strides=(stride_qn, stride_qd),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    dq_ptrs = tl.make_block_ptr(
        base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,
        shape=(q_len, qk_head_dim),
        strides=(stride_dqn, stride_dqd),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, qk_head_dim),
        strides=(stride_kn, stride_kd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    # V has v_head_dim columns; this is what the boundary check must enforce
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, v_head_dim),
        strides=(stride_vn, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # dO has the same column count as V
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(q_len, v_head_dim),
        strides=(stride_don, stride_dod),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # delta and lse are read as [BLOCK_SIZE_Q, 1] columns so they broadcast
    # against the [BLOCK_SIZE_Q, BLOCK_SIZE_K] score tile
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(q_len, 1),
        strides=(stride_dn, stride_dh),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(q_len, 1),
        strides=(stride_ln, stride_lh),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
    off_k = tl.arange(0, BLOCK_SIZE_K)
    # load q, do, lse, delta once; they are reused for every k block
    q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
    do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
    lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
    d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dq
    dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_KD), dtype=tl.float32)
    # causal: keys after the last query of this block are never attended
    if causal:
        k_hi = (pid_q + 1) * BLOCK_SIZE_Q
    else:
        k_hi = k_len
    for j in range(0, k_hi, BLOCK_SIZE_K):
        # load
        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk; the causal mask is added as 0 / -inf bias
        if causal:
            qk = tl.where(
                off_q[:, None] >= (off_k + j)[None, :], float(0.0), float("-inf")
            )
        else:
            qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.dot(q, k.T) * qk_scale
        # recompute the softmax probabilities from the saved log-sum-exp,
        # then ds = sm_scale * p * (dp - delta)
        p = tl.math.exp2(qk - lse)
        dp = tl.dot(do, v.T)
        ds = sm_scale * p * (dp - d)
        # cast dtype
        ds = ds.to(q.dtype)
        # update dq
        dq += tl.dot(ds, k)
        # increment pointers
        k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))
        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
    # save dq
    tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
def _flash_attention_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    causal: bool,
    sm_scale: float,
    gqa_interleave: bool = False,
):
    """Validate the inputs and launch `forward_kernel`.

    Args:
        q, k, v: packed 3-D tensors (tokens, heads, head_dim) in fp16 or bf16.
        cu_seqlens_q, cu_seqlens_k: int32 tensors with batch_size + 1 entries.
        max_seqlen_q: used to size the query-block axis of the launch grid.
        max_seqlen_k: accepted for signature symmetry; not used here.
        causal: forwarded to the kernel.
        sm_scale: softmax scale forwarded to the kernel.
        gqa_interleave: forwarded to the kernel.

    Returns:
        (o, lse): o has shape (tokens, q heads, v head dim) in q's dtype,
        lse is float32 with shape (q heads, tokens).
    """
    # half-precision inputs and int32 offsets only
    assert q.dtype in (torch.bfloat16, torch.float16)
    assert k.dtype == q.dtype and v.dtype == q.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    # unpack shapes; the qk head dim is the one reported by k
    total_q, num_q_heads, _ = q.shape
    total_k, num_k_heads, qk_head_dim = k.shape
    total_v, num_v_heads, v_head_dim = v.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    assert qk_head_dim <= 256 and v_head_dim <= 256, "head_dim must be less than 256"
    assert total_q == total_k and total_k == total_v
    # grouped-query attention: q heads must split evenly over kv heads
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads
    # outputs
    o = torch.empty(
        (total_q, num_q_heads, v_head_dim), dtype=q.dtype, device=q.device
    )
    lse = torch.empty((num_q_heads, total_q), dtype=torch.float32, device=q.device)
    # tile sizes and launch configuration
    BLOCK_SIZE_Q = 128
    BLOCK_SIZE_K = 64
    BLOCK_SIZE_KD = triton.next_power_of_2(qk_head_dim)
    BLOCK_SIZE_VD = triton.next_power_of_2(v_head_dim)
    num_warps, num_stages = get_num_warps_stages(
        max(qk_head_dim, v_head_dim), BLOCK_SIZE_Q, IS_HOPPER_GPU
    )
    grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_q, BLOCK_SIZE_Q))
    # strides in the positional order the kernel expects: q, k, v, o, lse
    strides = (*q.stride(), *k.stride(), *v.stride(), *o.stride(), *lse.stride())
    forward_kernel[grid](
        q,
        k,
        v,
        o,
        lse,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        qk_head_dim,
        v_head_dim,
        sm_scale,
        causal,
        gqa_interleave,
        *strides,
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_KD=BLOCK_SIZE_KD,
        BLOCK_SIZE_VD=BLOCK_SIZE_VD,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, lse
def _flash_attention_bwd(
    o: torch.Tensor,
    do: torch.Tensor,
    lse: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    causal: bool,
    sm_scale: float,
    gqa_interleave: bool = False,
):
    """Run the three backward kernels and return (dq, dk, dv).

    Stages:
      1. `backward_sum_o_do` fills delta (float32, heads x tokens).
      2. `backward_dkdv` writes per-shared-head partial dK / dV, which are then
         summed over their leading dimension.
      3. `backward_dq` writes dQ.
    """
    total_q, num_q_heads, _ = q.shape
    total_k, num_k_heads, qk_head_dim = k.shape
    total_v, num_v_heads, _ = v.shape
    total_o, num_o_heads, v_head_dim = o.shape
    num_share_q_heads = num_q_heads // num_k_heads
    batch_size = cu_seqlens_q.shape[0] - 1
    max_head_dim = max(qk_head_dim, v_head_dim)

    # ---- stage 1: delta = rowsum(o * do) ----
    delta = torch.empty([num_o_heads, total_o], device=o.device, dtype=torch.float32)
    BLOCK_SIZE_O = 256
    BLOCK_SIZE_VD = triton.next_power_of_2(v_head_dim)
    num_warps, num_stages = get_num_warps_stages(
        max_head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU
    )
    delta_grid = (triton.cdiv(total_o, BLOCK_SIZE_O), num_o_heads)
    backward_sum_o_do[delta_grid](
        o,
        do,
        delta,
        total_o,
        v_head_dim,
        *o.stride(),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        *delta.stride(),
        BLOCK_SIZE_O=BLOCK_SIZE_O,
        BLOCK_SIZE_D=BLOCK_SIZE_VD,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    # ---- stage 2: dk / dv, one partial result per shared q head ----
    dk = torch.empty(
        (num_share_q_heads, total_k, num_k_heads, qk_head_dim),
        device=k.device,
        dtype=k.dtype,
    )
    dv = torch.empty(
        (num_share_q_heads, total_k, num_k_heads, v_head_dim),
        device=k.device,
        dtype=k.dtype,
    )
    BLOCK_SIZE_Q = 64
    BLOCK_SIZE_K = 64
    BLOCK_SIZE_KD = triton.next_power_of_2(qk_head_dim)
    BLOCK_SIZE_VD = triton.next_power_of_2(v_head_dim)
    num_warps, num_stages = get_num_warps_stages(
        max_head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU
    )
    dkdv_grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K))
    backward_dkdv[dkdv_grid](
        q,
        k,
        v,
        lse,
        delta,
        do,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        qk_head_dim,
        v_head_dim,
        sm_scale,
        causal,
        gqa_interleave,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        lse.stride(0),
        lse.stride(1),
        *delta.stride(),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        *dk.stride(),
        *dv.stride(),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_KD=BLOCK_SIZE_KD,
        BLOCK_SIZE_VD=BLOCK_SIZE_VD,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    # fold the per-shared-head partials into the kv-head gradients
    dk = dk.sum(0)
    dv = dv.sum(0)

    # ---- stage 3: dq ----
    dq = torch.empty_like(q)
    # smaller q tile once the head dim exceeds 128
    BLOCK_SIZE_Q = 64 if max_head_dim > 128 else 128
    BLOCK_SIZE_K = 64
    num_warps, num_stages = get_num_warps_stages(
        max_head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU
    )
    dq_grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_q, BLOCK_SIZE_Q))
    backward_dq[dq_grid](
        q,
        k,
        v,
        lse,
        delta,
        do,
        dq,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        qk_head_dim,
        v_head_dim,
        sm_scale,
        causal,
        gqa_interleave,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        lse.stride(0),
        lse.stride(1),
        *delta.stride(),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        *dq.stride(),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_KD=BLOCK_SIZE_KD,
        BLOCK_SIZE_VD=BLOCK_SIZE_VD,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return dq, dk, dv
class FlashAttention(torch.autograd.Function):
    """Autograd wrapper around `_flash_attention_fwd` / `_flash_attention_bwd`."""

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        causal=True,
        sm_scale=None,
        gqa_interleave=False,
    ):
        """Run the forward pass and stash what the backward pass needs."""
        # default softmax scale: 1 / sqrt(last dim of q)
        scale = 1 / math.sqrt(q.shape[-1]) if sm_scale is None else sm_scale
        o, lse = _flash_attention_fwd(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            causal,
            scale,
            gqa_interleave,
        )
        # tensors go through save_for_backward, plain values live on ctx
        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)
        ctx.sm_scale = scale
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.causal = causal
        ctx.gqa_interleave = gqa_interleave
        return o

    @staticmethod
    def backward(ctx, do: torch.Tensor, *args) -> Any:
        """Return gradients for q, k, v and None for the seven other inputs."""
        q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        grads = _flash_attention_bwd(
            o,
            do,
            lse,
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.causal,
            ctx.sm_scale,
            ctx.gqa_interleave,
        )
        dq, dk, dv = grads
        # one slot per forward argument after ctx: only q, k, v get gradients
        return (dq, dk, dv) + (None,) * 7
def flash_attention_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    causal: bool = False,
    sm_scale: Optional[float] = None,
    gqa_interleave: bool = False,
) -> torch.Tensor:
    """Variable-length flash attention implemented with Triton kernels.

    Args:
        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], cumulative query lengths, similar to cu_seqlens_q in flash_attn_func_varlen.
        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], cumulative key lengths, similar to cu_seqlens_k in flash_attn_func_varlen.
        max_seqlen_q (int): max q len of the batch.
        max_seqlen_k (int): max k len of the batch.
        causal (bool, optional): Causal mask. Defaults to False.
        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
        gqa_interleave (bool, optional): GQA pattern. Defaults to False, use Llama style GQA.

    Returns:
        torch.Tensor: attention output with shape [total_q_len, num_q_heads, head_dim]
    """
    # autograd Functions are invoked positionally, in the order of forward()
    call_args = (
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        causal,
        sm_scale,
        gqa_interleave,
    )
    return FlashAttention.apply(*call_args)
================================================
FILE: native_sparse_attention/ops/triton/flash_attention_decode.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import triton
import triton.language as tl
from typing import Optional
@triton.jit
def decode_kernel(
    q_ptr,  # Q: b x h x d
    k_ptr,  # K: b x n x h x d
    v_ptr,  # V: b x n x h x d
    o_ptr,  # O: b x h x d
    seqlens,
    # shape
    BATCH_SIZE,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    # sm_scale
    sm_scale,
    # stride
    stride_qb,
    stride_qh,
    stride_qd,
    stride_kb,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vb,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_ob,
    stride_oh,
    stride_od,
    # META parameters
    BLOCK_SIZE_B: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Single-query (decode) attention over a block of batch entries.

    Grid: axis 0 = query head, axis 1 = block of BLOCK_SIZE_B batch entries.
    For every batch entry in the block the kernel attends the one query vector
    to the first `seqlens[b]` keys/values, using an online softmax computed in
    base 2, and writes the normalised result to O.
    """
    # exp2 is used instead of exp, so fold log2(e) into the softmax scale
    qk_scale = sm_scale * 1.44269504
    # get head id and batch-block id
    pid_h = tl.program_id(0)
    pid_b = tl.program_id(1)
    # consecutive groups of NUM_SHARE_Q_HEADS query heads share one kv head
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # per-entry kv lengths; entries past BATCH_SIZE read as length 0
    off_b = tl.arange(0, BLOCK_SIZE_B)
    kv_len = tl.load(
        seqlens + pid_b * BLOCK_SIZE_B + off_b,
        mask=pid_b * BLOCK_SIZE_B + off_b < BATCH_SIZE,
        other=0,
    )
    # the longest sequence in this batch block bounds the key loop
    max_kv_len = tl.max(kv_len)
    # init qkv pointer
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + pid_h * stride_qh,
        shape=(BATCH_SIZE, HEAD_DIM),
        strides=(stride_qb, stride_qd),
        offsets=(pid_b * BLOCK_SIZE_B, 0),
        block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_D),
        order=(1, 0),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + pid_kh * stride_kh,
        shape=(BATCH_SIZE, max_kv_len, HEAD_DIM),
        strides=(stride_kb, stride_kn, stride_kd),
        offsets=(pid_b * BLOCK_SIZE_B, 0, 0),
        block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(2, 1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + pid_kh * stride_vh,
        shape=(BATCH_SIZE, max_kv_len, HEAD_DIM),
        strides=(stride_vb, stride_vn, stride_vd),
        offsets=(pid_b * BLOCK_SIZE_B, 0, 0),
        block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(2, 1, 0),
    )
    # load q
    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init statistics: running max, running log2-sum-exp, unnormalised output
    off_k = tl.arange(0, BLOCK_SIZE_K)
    m_i = tl.full((BLOCK_SIZE_B,), float("-inf"), dtype=tl.float32)
    lse_i = tl.full((BLOCK_SIZE_B,), float("-inf"), dtype=tl.float32)
    acc_o = tl.full((BLOCK_SIZE_B, BLOCK_SIZE_D), 0, dtype=tl.float32)
    # walk the keys in chunks of BLOCK_SIZE_K
    for i in range(0, max_kv_len, BLOCK_SIZE_K):
        i = tl.multiple_of(i, BLOCK_SIZE_K)
        # load k
        k = tl.load(k_ptrs, boundary_check=(0, 1, 2), padding_option="zero")
        # compute qk; positions at or beyond each entry's kv_len get -inf
        qk = tl.zeros((BLOCK_SIZE_B, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.where(off_k[None, :] + i < kv_len[:, None], 0, float("-inf"))
        # [B, D], [B, K, D] -> [B, K]
        qk += tl.sum(q[:, None, :] * k, axis=2) * qk_scale
        # compute m_ij and l_ij
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.math.exp2(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        # rescale the accumulated output to the new running max
        acc_o_scale = tl.math.exp2(m_i - m_ij)
        acc_o = acc_o * acc_o_scale[:, None]
        # load v and update acc_o
        v = tl.load(v_ptrs, boundary_check=(0, 1, 2), padding_option="zero")
        p = p.to(v.dtype)
        # [B, K], [B, K, D] -> [B, D]
        acc_o += tl.sum(p[:, :, None] * v, axis=1)
        # update statistics
        m_i = m_ij
        lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij)
        # update ptrs
        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K, 0))
        v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K, 0))
    # final scale: divide by the softmax denominator, i.e. 2 ** (lse_i - m_i)
    acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None]
    # save output
    o_ptrs = tl.make_block_ptr(
        base=o_ptr + pid_h * stride_oh,
        shape=(BATCH_SIZE, HEAD_DIM),
        strides=(stride_ob, stride_od),
        offsets=(pid_b * BLOCK_SIZE_B, 0),
        block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_D),
        order=(1, 0),
    )
    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
def flash_attention_decode(
    q: torch.Tensor,  # [batch_size, num_heads, head_dim]
    k: torch.Tensor,  # [batch_size, max_len, num_heads, head_dim]
    v: torch.Tensor,
    seqlens: torch.Tensor,  # [batch_size, ]
    sm_scale: Optional[float] = None,
) -> torch.Tensor:
    """Decode-time attention (one query per sequence) using the Triton kernel.

    Args:
        q (torch.Tensor): query, shape [batch_size, num_q_heads, head_dim]
        k (torch.Tensor): key, shape [batch_size, kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): value, shape [batch_size, kv_len, num_kv_heads, head_dim]
        seqlens (torch.Tensor): int32 kv length of each sequence, shape [batch_size]
        sm_scale (Optional[float]): softmax scale, default to 1/sqrt(head_dim)

    Returns:
        torch.Tensor: attention output, same shape and dtype as q
    """
    # half-precision inputs and int32 lengths only
    assert q.dtype in (torch.bfloat16, torch.float16)
    assert k.dtype == q.dtype and v.dtype == q.dtype
    assert seqlens.dtype == torch.int32
    # unpack shapes; head_dim ends up being the one reported by v
    batch_size, num_q_heads, _ = q.shape
    _, k_len, num_k_heads, _ = k.shape
    _, v_len, num_v_heads, head_dim = v.shape
    assert k_len == v_len and batch_size == seqlens.shape[0]
    # grouped-query attention: q heads must split evenly over kv heads
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads
    if sm_scale is None:
        sm_scale = 1 / math.sqrt(head_dim)
    o = torch.zeros_like(q)
    # launch configuration
    num_stages = 3
    num_warps = 8 if head_dim > 64 else 4
    # there is a bug for triton 3.0.0 if BLOCK_SIZE_B > 16
    BLOCK_SIZE_B = min(16, triton.next_power_of_2(batch_size))
    BLOCK_SIZE_K = 128
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    grid = (num_q_heads, triton.cdiv(batch_size, BLOCK_SIZE_B))
    # strides in the positional order the kernel expects: q, k, v, o
    strides = (*q.stride(), *k.stride(), *v.stride(), *o.stride())
    decode_kernel[grid](
        q,
        k,
        v,
        o,
        seqlens,
        batch_size,
        num_share_q_heads,
        head_dim,
        sm_scale,
        *strides,
        BLOCK_SIZE_B=BLOCK_SIZE_B,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o
def torch_attention_decode(
    q: torch.Tensor,  # [batch_size, num_heads, head_dim]
    k: torch.Tensor,  # [batch_size, max_len, num_heads, head_dim]
    v: torch.Tensor,
    seqlens: torch.Tensor,  # [batch_size, ]
    sm_scale: Optional[float] = None,
):
    """Reference decode attention written with plain PyTorch ops.

    Each sequence has one query vector per head; it attends to the first
    `seqlens[b]` key/value positions of that sequence. KV heads are repeated
    so that consecutive groups of query heads share one KV head.

    Returns a tensor with the shape [batch_size, num_q_heads, head_dim].
    """
    # half-precision inputs and int32 lengths only
    assert q.dtype in (torch.bfloat16, torch.float16)
    assert k.dtype == q.dtype and v.dtype == q.dtype
    assert seqlens.dtype == torch.int32
    # unpack shapes; head_dim ends up being the one reported by v
    batch_size, num_q_heads, _ = q.shape
    _, k_len, num_k_heads, _ = k.shape
    _, v_len, num_v_heads, head_dim = v.shape
    assert k_len == v_len and batch_size == seqlens.shape[0]
    # grouped-query attention: q heads must split evenly over kv heads
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads
    if sm_scale is None:
        sm_scale = 1 / math.sqrt(head_dim)
    # expand kv heads so every query head has a matching kv head
    k_expanded = k.repeat_interleave(num_share_q_heads, dim=2)
    v_expanded = v.repeat_interleave(num_share_q_heads, dim=2)
    # scores: [batch, heads, 1, k_len]
    scores = torch.einsum("bqhd,bkhd->bhqk", q.unsqueeze(1), k_expanded) * sm_scale
    # hide key positions at or beyond each sequence's length
    positions = torch.arange(k_len, device=q.device)
    valid = positions[None, :] < seqlens[:, None]
    scores = scores.masked_fill(~valid[:, None, None, :], -torch.inf)
    # softmax in float32, then back to the input dtype
    probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    out = torch.einsum("bhqk,bkhd->bqhd", probs, v_expanded)
    return out.squeeze(1)
if __name__ == "__main__":
    # GPU self-check: compare the Triton decode kernel with the torch reference
    torch.manual_seed(42)
    batch_size, max_length = 76, 8192
    num_q_heads, num_kv_heads, head_dim = 32, 4, 128

    def _uniform_bf16(*shape):
        # uniform samples in [-1, 1) drawn in the default dtype, then cast
        return torch.empty(*shape, device="cuda").uniform_(-1, 1).to(torch.bfloat16)

    # lengths 1, 129, 257, ... capped at max_length, then randomly permuted
    seqlens = torch.arange(batch_size, dtype=torch.int32).cuda() * 128 + 1
    seqlens[seqlens > max_length] = max_length
    shuffle = torch.randn_like(seqlens, dtype=torch.float32).argsort(-1)
    seqlens = seqlens[shuffle]
    # inputs are drawn in the order q, k, v
    q = _uniform_bf16(batch_size, num_q_heads, head_dim)
    k = _uniform_bf16(batch_size, max_length, num_kv_heads, head_dim)
    v = _uniform_bf16(batch_size, max_length, num_kv_heads, head_dim)
    o1 = torch_attention_decode(q, k, v, seqlens)
    o2 = flash_attention_decode(q, k, v, seqlens)
    print(torch.allclose(o1, o2, atol=1e-2, rtol=1e-2))
================================================
FILE: native_sparse_attention/ops/triton/linear_compress.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Optional, Tuple, Any
import triton
import torch
import triton.language as tl
from einops import rearrange, einsum
from native_sparse_attention.ops.triton.utils import is_hopper_gpu
IS_HOPPER_GPU = is_hopper_gpu()
@triton.jit
def linear_compress_fwd_kernel(
    X,  # input pointer [total_len, num_heads, head_dim]
    Y,  # output pointer [total_compressed_len, num_heads, head_dim]
    W,  # weight matrix pointer [num_heads, kernel_size, head_dim, head_dim]
    cu_seqlens_x,  # cumulative sequence lengths before compression
    cu_seqlens_y,  # cumulative sequence lengths after compression
    stride_xn,  # stride for X's sequence dimension
    stride_xh,  # stride for X's num head dimension
    stride_xd,  # stride for X's head_dim dimension
    stride_wh,  # stride for W's num head dimension
    stride_wk,  # stride for W's kernel size dimension
    stride_wd,  # stride for W's initial head dim dimension
    stride_wD,  # stride for W's final head dim dimension
    stride_yn,  # stride for Y's sequence dimension
    stride_yh,  # stride for Y's num head dimension
    stride_yd,  # stride for Y's head_dim dimension
    NUM_HEADS: tl.constexpr,  # total num heads
    KERNEL_SIZE: tl.constexpr,  # kernel size used to compute the output
    KERNEL_STRIDE: tl.constexpr,  # kernel stride used to compute the output
    HEADd_DIM: tl.constexpr,  # initial head dimension size
    HEADD_DIM: tl.constexpr,  # final head dimension size
    BLOCK_OUTPUT_SEQ_SIZE: tl.constexpr,  # number of output rows handled per program
    BLOCK_KERNEL_SIZE: tl.constexpr,  # kernel positions loaded per program
    BLOCK_HEADd_DIM: tl.constexpr,  # chunk size along the original head dimension
    BLOCK_HEADD_DIM: tl.constexpr,  # chunk size along the final head dimension
):
    """Forward pass of the linear compression.

    Grid: axis 0 = flattened (sequence, head), axis 1 = block of output rows,
    axis 2 = block of output head-dim columns. Output row n of a sequence is
    built from the input rows n * KERNEL_STRIDE + [0, BLOCK_KERNEL_SIZE): the
    window is flattened to [kernel * input dim] and multiplied by the matching
    slice of this head's weight, accumulating in float32 over chunks of the
    input head dimension.
    """
    # split the flattened id into sequence index and head index
    pid_bh = tl.program_id(0)
    pid_b = pid_bh // NUM_HEADS
    pid_h = pid_bh % NUM_HEADS
    pid_k = tl.program_id(1)
    pid_D = tl.program_id(2)
    # start offset and length of this sequence before / after compression
    x_start = tl.load(cu_seqlens_x + pid_b)
    x_end = tl.load(cu_seqlens_x + pid_b + 1)
    x_len = x_end - x_start
    y_start = tl.load(cu_seqlens_y + pid_b)
    y_end = tl.load(cu_seqlens_y + pid_b + 1)
    y_len = y_end - y_start
    # nothing to do when this output block lies past the sequence's output length
    if pid_k * BLOCK_OUTPUT_SEQ_SIZE >= y_len:
        return
    off_kernel_size = tl.arange(0, BLOCK_KERNEL_SIZE)
    off_d = tl.arange(0, BLOCK_HEADd_DIM)
    off_output_seq_size = tl.arange(0, BLOCK_OUTPUT_SEQ_SIZE)
    # element addresses of the [output row, kernel position, input dim] gather;
    # the input row index is (output row * KERNEL_STRIDE + kernel position)
    x_base_ptrs = (
        X
        + pid_h * stride_xh
        + x_start * stride_xn
        + (
            (
                pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE
                + off_output_seq_size * KERNEL_STRIDE
            )[:, None]
            + off_kernel_size[None, :]
        )[:, :, None]
        * stride_xn
        + off_d[None, None, :] * stride_xd
    )
    # input rows at or beyond x_len are masked out (loaded as 0 below)
    x_base_mask = (
        (
            (
                pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE
                + off_output_seq_size * KERNEL_STRIDE
            )[:, None]
            + off_kernel_size[None, :]
        )
        < x_len
    )[:, :, None]
    # this head's weight, viewed as [kernel position, input dim, output dim]
    w_ptrs = tl.make_block_ptr(
        base=W + pid_h * stride_wh,
        shape=(KERNEL_SIZE, HEADd_DIM, HEADD_DIM),
        strides=(stride_wk, stride_wd, stride_wD),
        offsets=(0, 0, pid_D * BLOCK_HEADD_DIM),
        block_shape=(BLOCK_KERNEL_SIZE, BLOCK_HEADd_DIM, BLOCK_HEADD_DIM),
        order=(2, 1, 0),
    )
    y_ptrs = tl.make_block_ptr(
        base=Y + y_start * stride_yn + pid_h * stride_yh,
        shape=(y_len, HEADD_DIM),
        strides=(stride_yn, stride_yd),
        offsets=(pid_k * BLOCK_OUTPUT_SEQ_SIZE, pid_D * BLOCK_HEADD_DIM),
        block_shape=(BLOCK_OUTPUT_SEQ_SIZE, BLOCK_HEADD_DIM),
        order=(1, 0),
    )
    # float32 accumulator for the output tile
    y_d = tl.full((BLOCK_OUTPUT_SEQ_SIZE, BLOCK_HEADD_DIM), 0, dtype=tl.float32)
    # walk the input head dimension in chunks of BLOCK_HEADd_DIM
    for i in range(0, HEADd_DIM, BLOCK_HEADd_DIM):
        x_ptrs = x_base_ptrs + i * stride_xd
        x_mask = x_base_mask & ((i + off_d) < HEADd_DIM)[None, None, :]
        x = tl.load(x_ptrs, mask=x_mask, other=0)
        x = tl.reshape(x, (BLOCK_OUTPUT_SEQ_SIZE, BLOCK_KERNEL_SIZE * BLOCK_HEADd_DIM))
        # x : [n, k * bd]
        # weight entries outside the declared shape are zero-padded by the load
        w = tl.load(w_ptrs, boundary_check=(0, 1, 2), padding_option="zero")
        w = tl.reshape(w, (BLOCK_KERNEL_SIZE * BLOCK_HEADd_DIM, BLOCK_HEADD_DIM))
        # w: [k * bd, D]
        y_d += tl.dot(x, w)
        # y_d : [n, D]
        # move to the next chunk of the input head dimension
        w_ptrs = tl.advance(w_ptrs, (0, BLOCK_HEADd_DIM, 0))
    tl.store(y_ptrs, y_d.to(y_ptrs.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def linear_compress_bwd_kernel(
    DX,  # X's gradient pointer [total_len, num_heads, head_dim]
    DY,  # Y's gradient pointer [total_compressed_len, num_heads, head_dim]
    DW,  # weight's gradient pointer [num_heads, kernel_size, head_dim, head_dim]
    X,  # x pointer [total_len, num_heads, head_dim]
    W,  # weight matrix pointer [num_heads, kernel_size, head_dim, head_dim]
    cu_seqlens_x,  # cumulative sequence lengths before compression
    cu_seqlens_y,  # cumulative sequence lengths after compression
    stride_xn,  # stride for X's sequence dimension
    stride_xh,  # stride for X's num head dimension
    stride_xd,  # stride for X's head_dim dimension
    stride_wh,  # stride for W's num head dimension
    stride_wk,  # stride for W's kernel size dimension
    stride_wd,  # stride for W's initial head dim dimension
    stride_wD,  # stride for W's final head dim dimension
    stride_dxn,  # stride for DX's sequence dimension
    stride_dxh,  # stride for DX's num head dimension
    stride_dxd,  # stride for DX's head_dim dimension
    stride_dwh,  # stride for DW's num head dimension
    stride_dwk,  # stride for DW's kernel size dimension
    stride_dwd,  # stride for DW's initial head dim dimension
    stride_dwD,  # stride for DW's final head dim dimension
    stride_dyn,  # stride for DY's sequence dimension
    stride_dyh,  # stride for DY's num head dimension
    stride_dyd,  # stride for DY's head_dim dimension
    NUM_HEADS: tl.constexpr,  # total num heads
    NUM_PARALLEL_HEADd_NUM: tl.constexpr,  # number of programs along the initial head dim d
    KERNEL_SIZE: tl.constexpr,  # kernel size when calculating the output
    KERNEL_STRIDE: tl.constexpr,  # kernel stride when calculating the output
    HEADd_DIM: tl.constexpr,  # initial head dimension size
    HEADD_DIM: tl.constexpr,  # final head dimension size
    BLOCK_KERNEL_SIZE: tl.constexpr,  # loaded kernel size when calculating the output
    BLOCK_HEADd_DIM: tl.constexpr,  # loaded original head dimension size
    BLOCK_HEADD_DIM: tl.constexpr,  # loaded final head dimension size
    BLOCK_OUTPUT_SEQ_SIZE: tl.constexpr,  # loaded final output seq parallel
):
    """Backward pass of the linear compression: accumulate DX and DW from DY.

    Program layout:
      * axis 0 enumerates (batch, head) pairs as ``batch * NUM_HEADS + head``;
      * axis 1 enumerates blocks of ``BLOCK_OUTPUT_SEQ_SIZE`` compressed rows;
      * axis 2 enumerates (D-block, d-block) pairs as
        ``D_block * NUM_PARALLEL_HEADd_NUM + d_block``.

    Each program loads one ``dy`` tile, computes ``dx = dy @ w^T`` and
    ``dw = x^T @ dy`` for its slice and adds both into DX / DW with
    ``tl.atomic_add``, so the two gradient buffers are accumulated into rather
    than overwritten.
    """
    # decompose the program ids
    pid_bh = tl.program_id(0)
    pid_b = pid_bh // NUM_HEADS
    pid_h = pid_bh % NUM_HEADS
    pid_k = tl.program_id(1)
    pid_Dd = tl.program_id(2)
    pid_D = pid_Dd // NUM_PARALLEL_HEADd_NUM
    pid_d = pid_Dd % NUM_PARALLEL_HEADd_NUM
    # start offset and length of this sequence before / after compression
    x_start = tl.load(cu_seqlens_x + pid_b)
    x_end = tl.load(cu_seqlens_x + pid_b + 1)
    x_len = x_end - x_start
    y_start = tl.load(cu_seqlens_y + pid_b)
    y_end = tl.load(cu_seqlens_y + pid_b + 1)
    y_len = y_end - y_start
    # this output block lies entirely past the end of the compressed sequence
    if pid_k * BLOCK_OUTPUT_SEQ_SIZE >= y_len:
        return
    off_kernel_size = tl.arange(0, BLOCK_KERNEL_SIZE)
    off_d = tl.arange(0, BLOCK_HEADd_DIM)
    off_D = tl.arange(0, BLOCK_HEADD_DIM)
    off_output_seq_size = tl.arange(0, BLOCK_OUTPUT_SEQ_SIZE)
    # x tile pointers, shape [by, k, bd]: the input row of window element j of
    # output row r is (pid_k * BLOCK_OUTPUT_SEQ_SIZE + r) * KERNEL_STRIDE + j
    x_ptrs = (
        X
        + pid_h * stride_xh
        + x_start * stride_xn
        + (
            (
                pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE
                + off_output_seq_size * KERNEL_STRIDE
            )[:, None]
            + off_kernel_size[None, :]
        )[:, :, None]
        * stride_xn
        + (pid_d * BLOCK_HEADd_DIM + off_d)[None, None, :] * stride_xd
    )
    # valid where the input row is inside the sequence and d is inside the head dim
    x_mask = (
        (
            (
                pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE
                + off_output_seq_size * KERNEL_STRIDE
            )[:, None]
            + off_kernel_size[None, :]
        )
        < x_len
    )[:, :, None] & ((pid_d * BLOCK_HEADd_DIM + off_d) < HEADd_DIM)[None, None, :]
    # dx tile pointers: same indexing as x_ptrs but with DX's strides
    dx_ptrs = (
        DX
        + pid_h * stride_dxh
        + x_start * stride_dxn
        + (
            (
                pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE
                + off_output_seq_size * KERNEL_STRIDE
            )[:, None]
            + off_kernel_size[None, :]
        )[:, :, None]
        * stride_dxn
        + (pid_d * BLOCK_HEADd_DIM + off_d)[None, None, :] * stride_dxd
    )
    # weight tile [k, bd, D] for this head and this (d-block, D-block)
    w_ptrs = tl.make_block_ptr(
        base=W + pid_h * stride_wh,
        shape=(KERNEL_SIZE, HEADd_DIM, HEADD_DIM),
        strides=(stride_wk, stride_wd, stride_wD),
        offsets=(0, pid_d * BLOCK_HEADd_DIM, pid_D * BLOCK_HEADD_DIM),
        block_shape=(BLOCK_KERNEL_SIZE, BLOCK_HEADd_DIM, BLOCK_HEADD_DIM),
        order=(2, 1, 0),
    )
    # dw tile pointers / mask, shape [k, bd, D]
    dw_ptrs = (
        DW
        + pid_h * stride_dwh
        + off_kernel_size[:, None, None] * stride_dwk
        + (pid_d * BLOCK_HEADd_DIM + off_d)[None, :, None] * stride_dwd
        + (pid_D * BLOCK_HEADD_DIM + off_D)[None, None, :] * stride_dwD
    )
    dw_mask = (
        (off_kernel_size < KERNEL_SIZE)[:, None, None]
        & ((pid_d * BLOCK_HEADd_DIM + off_d) < HEADd_DIM)[None, :, None]
        & ((pid_D * BLOCK_HEADD_DIM + off_D) < HEADD_DIM)[None, None, :]
    )
    # dy tile [by, D]; rows past y_len are zero-padded, so they add nothing below
    dy_ptrs = tl.make_block_ptr(
        base=DY + y_start * stride_dyn + pid_h * stride_dyh,
        shape=(y_len, HEADD_DIM),
        strides=(stride_dyn, stride_dyd),
        offsets=(pid_k * BLOCK_OUTPUT_SEQ_SIZE, pid_D * BLOCK_HEADD_DIM),
        block_shape=(BLOCK_OUTPUT_SEQ_SIZE, BLOCK_HEADD_DIM),
        order=(1, 0),
    )
    dy = tl.load(dy_ptrs, boundary_check=(0, 1), padding_option="zero")
    # dy : [by, D]
    # cal dx, start
    w = tl.load(w_ptrs, boundary_check=(0, 1, 2), padding_option="zero")
    # w: [k, bd, D]
    w = tl.reshape(w, (BLOCK_KERNEL_SIZE * BLOCK_HEADd_DIM, BLOCK_HEADD_DIM))
    # w: [k * bd, D]
    dx = tl.dot(dy, tl.trans(w))
    # dx: [by, k * bd]
    dx = tl.reshape(dx, (BLOCK_OUTPUT_SEQ_SIZE, BLOCK_KERNEL_SIZE, BLOCK_HEADd_DIM))
    # dx: [by, k, bd]
    # atomic: several programs (other D-blocks, and neighbouring output rows
    # whose windows share input rows) add into the same dx elements
    tl.atomic_add(
        dx_ptrs,
        dx.to(dx_ptrs.dtype.element_ty),
        mask=x_mask,
    )
    # cal dx, end
    # cal dw, start
    x = tl.load(x_ptrs, mask=x_mask, other=0)
    x = tl.reshape(x, (BLOCK_OUTPUT_SEQ_SIZE, BLOCK_KERNEL_SIZE * BLOCK_HEADd_DIM))
    # x : [by, k * bd]
    dw = tl.dot(tl.trans(x), dy)
    # dw: [k * bd, D]
    dw = tl.reshape(dw, (BLOCK_KERNEL_SIZE, BLOCK_HEADd_DIM, BLOCK_HEADD_DIM))
    # dw: [k, bd, D]
    # atomic: every batch element and every output block contributes to the same dw tile
    tl.atomic_add(dw_ptrs, dw.to(dw_ptrs.dtype.element_ty), mask=dw_mask)
class LinearCompress(torch.autograd.Function):
    """Autograd function wrapping the Triton linear-compression kernels.

    Forward slides a window of ``kernel_size`` tokens with step ``kernel_stride``
    over every sequence and projects each window to one token with a per-head
    weight. Backward returns gradients for ``x`` and ``w`` only.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        w: torch.Tensor,
        cu_seqlens: torch.Tensor,
        kernel_size: int,
        kernel_stride: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compress key and value tensor with kernel_size and kernel_stride. Similar to conv_compress.

        Args:
            x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
            w (torch.Tensor): weight for each head, shape (num_heads, kernel_size * head_dim, head_dim)
            cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen
            kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)
            kernel_stride (int): stride for each compress kernel

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.
        """
        # dtype check
        assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
        assert x.dtype == w.dtype
        assert cu_seqlens.dtype == torch.int32
        # shape check
        _, num_heads, head_dim = x.shape
        batch_size = cu_seqlens.shape[0] - 1
        assert w.shape[0] == num_heads
        assert w.shape[1] == kernel_size * head_dim
        assert w.shape[2] == head_dim
        assert kernel_size % kernel_stride == 0
        assert kernel_size in {16, 32, 64, 128}
        assert head_dim % 8 == 0
        torch.cuda.set_device(x.device)
        # compute seqlens after compression: floor((len - kernel_size) / stride) + 1,
        # done in integer arithmetic so no float round-trip is involved
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        y_seqlens = (
            torch.div(seqlens - kernel_size, kernel_stride, rounding_mode="floor").to(
                torch.int32
            )
            + 1
        )
        # corner case: if sequence_length < kernel_size, no compression for this sequence
        y_seqlens[seqlens < kernel_size] = 0
        y_cu_seqlens = torch.cat(
            [
                torch.zeros(1, dtype=torch.int32, device=x.device),
                torch.cumsum(y_seqlens.to(x.device), dim=0),
            ],
            dim=0,
        ).to(torch.int32)
        y = torch.zeros(
            y_cu_seqlens[-1], num_heads, head_dim, dtype=x.dtype, device=x.device
        )
        # longest compressed sequence; synced to the host once and reused by backward
        max_y_len = int(y_seqlens.max().item())
        block_kernel_size = max(16, triton.next_power_of_2(kernel_size))
        block_head_dim = 8 if IS_HOPPER_GPU else 4
        block_headD_dim = 32
        block_output_seq_size = 64
        w = w.reshape(num_heads, kernel_size, head_dim, head_dim).contiguous()
        # nothing to compute when every sequence is shorter than kernel_size
        if max_y_len > 0:
            grid = lambda META: (
                batch_size * num_heads,
                triton.cdiv(max_y_len, META["BLOCK_OUTPUT_SEQ_SIZE"]),
                triton.cdiv(head_dim, META["BLOCK_HEADD_DIM"]),
            )
            linear_compress_fwd_kernel[grid](
                x,
                y,
                w,
                cu_seqlens,
                y_cu_seqlens,
                x.stride(0),
                x.stride(1),
                x.stride(2),
                w.stride(0),
                w.stride(1),
                w.stride(2),
                w.stride(3),
                y.stride(0),
                y.stride(1),
                y.stride(2),
                num_heads,
                kernel_size,
                kernel_stride,
                head_dim,
                head_dim,
                block_output_seq_size,
                block_kernel_size,
                block_head_dim,
                block_headD_dim,
            )
        # save for backward (w is saved in its 4-D contiguous layout)
        ctx.save_for_backward(x, w, cu_seqlens, y_cu_seqlens)
        # save value
        ctx.kernel_size = kernel_size
        ctx.kernel_stride = kernel_stride
        ctx.block_kernel_size = block_kernel_size
        ctx.block_headd_dim = block_head_dim
        ctx.block_headD_dim = block_headD_dim
        ctx.block_output_seq_size = block_output_seq_size
        ctx.max_y_len = max_y_len
        return y, y_cu_seqlens

    @staticmethod
    def backward(ctx, dy: torch.Tensor, *args) -> Any:
        """Compute gradients w.r.t. ``x`` and ``w``.

        Args:
            dy (torch.Tensor): gradient of the compressed output, shape
                (total_compressed_len, num_heads, head_dim).
            *args: gradients of the remaining forward outputs (ignored).

        Returns:
            ``(dx, dw, None, None, None)`` where ``dx`` has the shape of ``x`` and
            ``dw`` has shape (num_heads, kernel_size * head_dim, head_dim).
        """
        x, w, cu_seqlens, y_cu_seqlens = ctx.saved_tensors
        kernel_size = ctx.kernel_size
        kernel_stride = ctx.kernel_stride
        block_kernel_size = ctx.block_kernel_size
        block_head_dim = ctx.block_headd_dim
        block_headD_dim = ctx.block_headD_dim
        block_output_seq_size = ctx.block_output_seq_size
        max_y_len = ctx.max_y_len
        total_len, num_heads, head_dim = x.shape
        batch_size = cu_seqlens.shape[0] - 1
        # The gradient must have exactly the shape of x, so size it from x itself
        # (this also avoids a device-to-host sync on cu_seqlens[-1]).
        # float32 buffers because the kernel accumulates with atomic adds.
        dx = torch.zeros(
            total_len, num_heads, head_dim, dtype=torch.float32, device=x.device
        )
        dw = torch.zeros(
            num_heads,
            kernel_size,
            head_dim,
            head_dim,
            dtype=torch.float32,
            device=x.device,
        )
        # must match the factor used for the third grid axis below
        num_parallel_headd = triton.cdiv(head_dim, block_head_dim)
        # no compressed rows means both gradients are exactly zero
        if max_y_len > 0:
            grid = lambda META: (
                batch_size * num_heads,
                triton.cdiv(max_y_len, META["BLOCK_OUTPUT_SEQ_SIZE"]),
                triton.cdiv(head_dim, META["BLOCK_HEADD_DIM"])
                * triton.cdiv(head_dim, META["BLOCK_HEADd_DIM"]),
            )
            linear_compress_bwd_kernel[grid](
                dx,
                dy,
                dw,
                x,
                w,
                cu_seqlens,
                y_cu_seqlens,
                x.stride(0),
                x.stride(1),
                x.stride(2),
                w.stride(0),
                w.stride(1),
                w.stride(2),
                w.stride(3),
                dx.stride(0),
                dx.stride(1),
                dx.stride(2),
                dw.stride(0),
                dw.stride(1),
                dw.stride(2),
                dw.stride(3),
                dy.stride(0),
                dy.stride(1),
                dy.stride(2),
                num_heads,
                num_parallel_headd,
                kernel_size,
                kernel_stride,
                head_dim,
                head_dim,
                block_kernel_size,
                block_head_dim,
                block_headD_dim,
                block_output_seq_size,
            )
        return (
            dx.to(x.dtype),
            rearrange(dw.to(x.dtype), "n k d D -> n (k d) D"),
            None,
            None,
            None,
        )
def linear_compress(
    x: torch.Tensor,
    w: torch.Tensor,
    cu_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Linearly project sliding windows of keys/values down to single tokens.

    Args:
        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
        w (torch.Tensor): per-head weight, shape (num_heads, kernel_size * head_dim, head_dim)
        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        kernel_size (int): window length; each (kernel_size, head_dim) window becomes one (1, head_dim) token
        kernel_stride (int): step between consecutive windows
        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape
            (num_heads, kernel_size, head_dim). Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: compressed states and the matching cu_seqlens.
    """
    compressed, compressed_cu_seqlens = LinearCompress.apply(
        x, w, cu_seqlens, kernel_size, kernel_stride
    )
    if pe is None:
        return compressed, compressed_cu_seqlens
    # The positional embedding goes through the same projection as the tokens
    # and is added to every compressed token as a per-head bias.
    assert pe.dtype == x.dtype and pe.device == x.device
    flat_pe = rearrange(pe, "h k d -> h (k d)")
    bias = einsum(flat_pe, w, "h D, h D d -> h d")
    return compressed + bias.unsqueeze(0), compressed_cu_seqlens
================================================
FILE: native_sparse_attention/ops/triton/topk_sparse_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Optional
import torch
import triton
import triton.language as tl
from native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu
IS_HOPPER_GPU = is_hopper_gpu()
@triton.jit
def forward_kernel(
    q_ptr,  # Q: n x h x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    t_ptr,  # topk_idx: kh x n x k
    o_ptr,  # O: n x h x d
    lse_ptr,  # LSE: h x n
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    TOPK,
    # q loop num
    num_q_loop,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_th,
    stride_tn,
    stride_tk,
    stride_on,
    stride_oh,
    stride_od,
    stride_lh,
    stride_ln,
    # META parameters
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    """Forward pass of top-k block-sparse causal attention.

    Program layout: axis 0 = batch element, axis 1 = KV head, axis 2 = group of
    ``num_q_loop`` consecutive query positions. For every query position the
    program handles all ``NUM_SHARE_Q_HEADS`` query heads that share the KV head
    at once and attends only to the key blocks listed in ``topk_idx``, using an
    online softmax. It writes the attention output to ``o_ptr`` and the
    log-sum-exp (in base 2) to ``lse_ptr``.
    """
    # 1.44269504 ~= log2(e): scores are pre-scaled so exp2/log2 can be used below
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_kh = tl.program_id(1)
    # index of the first query head that shares this kv head
    pid_h = pid_kh * NUM_SHARE_Q_HEADS
    pid_q = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if pid_q * num_q_loop >= q_len:
        return
    # number of query positions this program really owns (last group may be short)
    real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop)
    for j in range(real_q_loop):
        pid_q_j = pid_q * num_q_loop + j
        # init topk idx pointer
        off_t = tl.arange(0, BLOCK_SIZE_T)
        t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th
        topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)
        # count the indices that are non-negative and not past the query's own block.
        # NOTE(review): the loop below reads the first real_topk entries in
        # order, so it appears to assume valid indices are stored first - confirm.
        real_topk = tl.sum(
            tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0),
            axis=0,
        )
        # init qkv pointer
        q_ptrs = tl.make_block_ptr(
            base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_qh, stride_qd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        # k is addressed transposed ([d, n]) so it can be used directly in q @ k
        k_ptrs = tl.make_block_ptr(
            base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
            shape=(HEAD_DIM, k_len),
            strides=(stride_kd, stride_kn),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
            order=(0, 1),
        )
        v_ptrs = tl.make_block_ptr(
            base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
            shape=(k_len, HEAD_DIM),
            strides=(stride_vn, stride_vd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
            order=(1, 0),
        )
        # load q
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        # init statistics
        off_h = tl.arange(0, BLOCK_SIZE_H)
        off_k = tl.arange(0, BLOCK_SIZE_K)
        m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32)
        lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32)
        acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32)
        # sparse attention
        for i in range(real_topk):
            # get current block start index
            c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K
            t_ptr_j = t_ptr_j + stride_tk
            # load k
            k = tl.load(
                tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero"
            )
            # compute qk
            qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)
            # causal mask: keys positioned after the query get -inf
            qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf"))
            # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K]
            qk += tl.dot(q, k) * qk_scale
            # compute m_ij and l_ij
            m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
            p = tl.exp2(qk - m_ij[:, None])
            l_ij = tl.sum(p, axis=1)
            # scale acc_o
            acc_o_scale = tl.exp2(m_i - m_ij)
            acc_o = acc_o * acc_o_scale[:, None]
            # load v and update acc_o
            v = tl.load(
                tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero"
            )
            p = p.to(v.dtype)
            acc_o += tl.dot(p, v)
            # update statistics
            m_i = m_ij
            lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
        # final scale
        acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
        # save output
        o_ptrs = tl.make_block_ptr(
            base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_oh, stride_od),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
        # save lse
        lse_ptrs = (
            lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh
        )
        tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS)
@triton.jit
def backward_sum_o_do(
    o_ptr,  # O: n x h x d
    do_ptr,  # dO: n x h x d
    delta_ptr,  # D: h x n
    o_len,
    HEAD_DIM,
    stride_on,
    stride_oh,
    stride_od,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dh,
    stride_dn,
    BLOCK_SIZE_O: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Compute delta[h, n] = sum_d O[n, h, d] * dO[n, h, d] in float32.

    Program axis 0 selects a block of ``BLOCK_SIZE_O`` rows, axis 1 the head.
    """
    head = tl.program_id(1)
    rows = tl.program_id(0) * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    cols = tl.arange(0, BLOCK_SIZE_D)
    # a tile element is valid when its row is inside the tensor and its column inside the head dim
    row_ok = rows < o_len
    tile_ok = (rows[:, None] < o_len) & (cols[None, :] < HEAD_DIM)
    o_tile = tl.load(
        o_ptr + rows[:, None] * stride_on + head * stride_oh + cols[None, :] * stride_od,
        mask=tile_ok,
        other=0,
    )
    do_tile = tl.load(
        do_ptr
        + rows[:, None] * stride_don
        + head * stride_doh
        + cols[None, :] * stride_dod,
        mask=tile_ok,
        other=0,
    )
    # row-wise dot product, accumulated in float32
    row_dot = tl.sum(o_tile.to(tl.float32) * do_tile.to(tl.float32), axis=1)
    out_ptrs = delta_ptr + head * stride_dh + rows * stride_dn
    tl.store(out_ptrs, row_dot, mask=row_ok)
@triton.jit
def count_kernel(
    x_ptr,  # [num_kv_heads, total_len, topk]
    y_ptr,  # [num_kv_heads, total_blocks]
    cu_seqlens,  # [batch_size + 1]
    cu_seqblocks,  # [batch_size + 1]
    topk,
    stride_xh,
    stride_xn,
    stride_xk,
    stride_yh,
    stride_yn,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_R: tl.constexpr,
):
    """Histogram the block indices stored in ``x`` for one (head, sequence) pair.

    Program axis 0 selects the KV head, axis 1 the sequence. The sequence's
    rows are scanned ``BLOCK_SIZE_N`` at a time, their values are binned with
    ``tl.histogram`` and the first ``num_blocks`` bins are written to ``y``.
    """
    head = tl.program_id(0)
    batch = tl.program_id(1)
    # start and length of this sequence, in tokens and in blocks
    seq_start = tl.load(cu_seqlens + batch)
    seq_len = tl.load(cu_seqlens + batch + 1) - seq_start
    blocks_start = tl.load(cu_seqblocks + batch)
    num_blocks = tl.load(cu_seqblocks + batch + 1) - blocks_start
    # pointers to the first [BLOCK_SIZE_N, BLOCK_SIZE_K] tile of the sequence
    cols = tl.arange(0, BLOCK_SIZE_K)
    rows = tl.arange(0, BLOCK_SIZE_N)
    seq_base = x_ptr + head * stride_xh + seq_start * stride_xn
    tile_ptrs = seq_base + rows[:, None] * stride_xn + cols[None, :] * stride_xk
    col_ok = (cols < topk)[None, :]
    counts = tl.zeros((BLOCK_SIZE_R,), dtype=tl.int32)
    for row_base in range(0, seq_len, BLOCK_SIZE_N):
        # masked-out slots are filled with -1
        tile = tl.load(
            tile_ptrs,
            mask=(rows < seq_len - row_base)[:, None] & col_ok,
            other=-1,
        )
        counts += tl.histogram(tl.ravel(tile), BLOCK_SIZE_R)
        tile_ptrs += BLOCK_SIZE_N * stride_xn
    # write back only the bins that correspond to real blocks of this sequence
    bins = tl.arange(0, BLOCK_SIZE_R)
    out_base = y_ptr + head * stride_yh + blocks_start * stride_yn
    tl.store(
        out_base + bins * stride_yn,
        counts.to(y_ptr.dtype.element_ty),
        mask=bins < num_blocks,
    )
def count_query(
    topk_idx: torch.Tensor,
    cu_seqlens: torch.Tensor,
    cu_seqblocks: torch.Tensor,
    block_size: int,
):
    """Count, per KV head and per key block, how often the block appears in ``topk_idx``.

    Args:
        topk_idx: selected block indices, shape (num_kv_heads, total_len, topk).
        cu_seqlens: cumulative sequence lengths, shape (batch_size + 1,).
        cu_seqblocks: cumulative number of blocks per sequence, shape (batch_size + 1,).
        block_size: not used by this function.

    Returns:
        int32 tensor of shape (num_kv_heads, total_blocks) with the counts.
    """
    num_kv_heads, _, topk = topk_idx.shape
    batch_size = cu_seqlens[1:].shape[0]
    blocks_per_seq = cu_seqblocks[1:] - cu_seqblocks[:-1]
    # tile sizes: about 4096 index entries per load, histogram wide enough for the longest sequence
    block_k = triton.next_power_of_2(topk)
    block_n = triton.next_power_of_2(4096 // block_k)
    block_r = triton.next_power_of_2(blocks_per_seq.max().item() + 2)
    active_query_count = torch.zeros(
        num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device
    )
    # one program per (kv head, sequence)
    count_kernel[(num_kv_heads, batch_size)](
        topk_idx,
        active_query_count,
        cu_seqlens,
        cu_seqblocks,
        topk,
        *topk_idx.stride(),
        *active_query_count.stride(),
        BLOCK_SIZE_N=block_n,
        BLOCK_SIZE_K=block_k,
        BLOCK_SIZE_R=block_r,
        num_warps=4,
        num_stages=3,
    )
    return active_query_count
@triton.jit
def pad_topk_idx_kernel(
    t_ptr,
    p_ptr,
    cu_seqlens,
    topk,
    stride_th,
    stride_tn,
    stride_tk,
    stride_pb,
    stride_ph,
    stride_pn,
    stride_pk,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    """Copy one tile of varlen top-k indices into the padded per-batch buffer.

    ``t_ptr`` is indexed as [head, token, k] over the packed sequences and
    ``p_ptr`` as [batch, head, position, k]. Program axes are
    (batch, head, row block).
    """
    batch = tl.program_id(0)
    head = tl.program_id(1)
    row_begin = tl.program_id(2) * BLOCK_SIZE_N
    # start and length of this sequence inside the packed tensor
    seq_begin = tl.load(cu_seqlens + batch)
    seq_len = tl.load(cu_seqlens + batch + 1) - seq_begin
    if row_begin >= seq_len:
        return
    # source tile in the packed layout
    src = tl.make_block_ptr(
        base=t_ptr + head * stride_th + seq_begin * stride_tn,
        shape=(seq_len, topk),
        strides=(stride_tn, stride_tk),
        offsets=(row_begin, 0),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T),
        order=(1, 0),
    )
    # destination tile in the padded layout
    dst = tl.make_block_ptr(
        base=p_ptr + batch * stride_pb + head * stride_ph,
        shape=(seq_len, topk),
        strides=(stride_pn, stride_pk),
        offsets=(row_begin, 0),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T),
        order=(1, 0),
    )
    tile = tl.load(src, boundary_check=(0, 1))
    tl.store(dst, tile, boundary_check=(0, 1))
@triton.jit
def save_topk_idx_kernel(
    p_ptr,
    t_ptr,
    cu_seqblocks,
    cu_topk_q_count,
    n_len,
    stride_pb,
    stride_ph,
    stride_pn,
    stride_th,
    stride_tn,
    stride_ch,
    stride_cn,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Copy the trailing ``c_len`` entries of a padded row into the packed output.

    For one (batch, head) pair, ``c_len`` is the difference of
    ``cu_topk_q_count`` between the sequence's last and first block boundary.
    The last ``c_len`` of the ``n_len`` entries in ``p_ptr[batch, head]`` are
    written to ``t_ptr[head]`` starting at offset ``c_start``. Program axes are
    (batch, head, chunk of ``BLOCK_SIZE_N`` entries).
    """
    batch = tl.program_id(0)
    head = tl.program_id(1)
    chunk_begin = tl.program_id(2) * BLOCK_SIZE_N
    # block range of this sequence and the matching range in the packed output
    first_block = tl.load(cu_seqblocks + batch)
    last_block = tl.load(cu_seqblocks + batch + 1)
    count_row = cu_topk_q_count + head * stride_ch
    c_start = tl.load(count_row + first_block * stride_cn)
    c_end = tl.load(count_row + last_block * stride_cn)
    c_len = c_end - c_start
    if c_len <= 0:
        return
    if chunk_begin >= c_len:
        return
    # source: the tail of the padded row (entries n_len - c_len .. n_len - 1)
    src = tl.make_block_ptr(
        base=p_ptr
        + batch * stride_pb
        + head * stride_ph
        + (n_len - c_len) * stride_pn,
        shape=(c_len,),
        strides=(stride_pn,),
        offsets=(chunk_begin,),
        block_shape=(BLOCK_SIZE_N,),
        order=(0,),
    )
    # destination: this sequence's slice of the packed per-head output
    dst = tl.make_block_ptr(
        base=t_ptr + head * stride_th + c_start * stride_tn,
        shape=(c_len,),
        strides=(stride_tn,),
        offsets=(chunk_begin,),
        block_shape=(BLOCK_SIZE_N,),
        order=(0,),
    )
    chunk = tl.load(src, boundary_check=(0,))
    tl.store(dst, chunk, boundary_check=(0,))
def reorder_topk_idx(
    topk_idx: torch.Tensor,
    cu_topk_q_count: torch.Tensor,
    cu_seqlens: torch.Tensor,
    cu_seqblocks: torch.Tensor,
    block_size: int,
):
    """Build, per KV head, the packed list of query indices ordered by selected block.

    The packed ``topk_idx`` is first padded to
    (batch_size, num_kv_heads, max_seqlen, topk) with -1, each (batch, head)
    row is flattened and arg-sorted, and the flat positions are divided by
    ``topk`` to get query indices. The trailing entries of every row are then
    copied into a packed (num_kv_heads, N) int32 tensor at the offsets given by
    ``cu_topk_q_count``.

    Args:
        topk_idx: selected block indices, shape (num_kv_heads, total_len, topk).
        cu_topk_q_count: per-head cumulative counts, indexed by block boundary.
        cu_seqlens: cumulative sequence lengths, shape (batch_size + 1,).
        cu_seqblocks: cumulative number of blocks per sequence, shape (batch_size + 1,).
        block_size: not used by this function.

    Returns:
        int32 tensor ``topk_q_idx``; unfilled slots hold -1.
    """
    num_kv_heads, _, topk = topk_idx.shape
    device = topk_idx.device
    batch_size = cu_seqlens.shape[0] - 1
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

    # step 1: [num_kv_heads, total_seqlen, topk] -> [batch_size, num_kv_heads, max_seqlen, topk]
    padded = torch.full(
        (batch_size, num_kv_heads, max_seqlen, topk),
        fill_value=-1,
        device=device,
        dtype=torch.int32,
    )
    block_t = triton.next_power_of_2(topk)
    pad_block_n = min(
        triton.next_power_of_2(max_seqlen),
        triton.next_power_of_2(8192 // block_t),
    )
    pad_grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, pad_block_n))
    pad_topk_idx_kernel[pad_grid](
        topk_idx,
        padded,
        cu_seqlens,
        topk,
        *topk_idx.stride(),
        *padded.stride(),
        BLOCK_SIZE_N=pad_block_n,
        BLOCK_SIZE_T=block_t,
    )

    # step 2: sort flat (query, slot) positions by block index, keep the query part
    flat_order = padded.view(batch_size, num_kv_heads, -1).argsort(-1)
    sorted_q = (flat_order // topk).to(torch.int32)

    # step 3: write back without padding
    topk_q_idx = torch.full(
        (num_kv_heads, cu_topk_q_count[:, -1].max().item()),
        fill_value=-1,
        device=device,
        dtype=torch.int32,
    )
    # cumulative counts sampled at sequence boundaries; adjacent differences
    # give the number of entries each (head, sequence) pair contributes
    counts_at_seq = cu_topk_q_count[:, cu_seqblocks]
    max_len = (counts_at_seq[:, 1:] - counts_at_seq[:, :-1]).max().item()
    save_block_n = min(triton.next_power_of_2(max_len), 8192)
    save_grid = (batch_size, num_kv_heads, triton.cdiv(max_len, save_block_n))
    save_topk_idx_kernel[save_grid](
        sorted_q,
        topk_q_idx,
        cu_seqblocks,
        cu_topk_q_count,
        sorted_q.shape[-1],
        *sorted_q.stride(),
        *topk_q_idx.stride(),
        *cu_topk_q_count.stride(),
        BLOCK_SIZE_N=save_block_n,
    )
    return topk_q_idx
@triton.jit
def backward_dkdv(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    tq_ptr,  # topk_q_idx: kh x N
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dk_ptr,  # DK: sh x n x kh x d
    dv_ptr,  # DV: sh x n x kh x d
    # seqlens
    cu_seqlens_q,  # [batch_size + 1]
    cu_seqlens_k,  # [batch_size + 1]
    cu_seqblocks,  # [batch_size + 1]
    cu_topk_q_count,  # [kh, total_blocks]
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    TOPK,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_tqh,
    stride_tqn,
    stride_ctqh,
    stride_ctqn,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dks,
    stride_dkn,
    stride_dkh,
    stride_dkd,
    stride_dvs,
    stride_dvn,
    stride_dvh,
    stride_dvd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
):
    """Backward pass of top-k sparse attention for the key/value gradients.

    Program layout: axis 0 = batch element, axis 1 = query head, axis 2 = key
    block. The program keeps its key/value block resident, walks over the
    queries listed for that block in ``topk_q_idx`` (range taken from
    ``cu_topk_q_count``), and accumulates ``dk`` / ``dv`` in float32. Results
    are written to the slice of DK / DV that belongs to this query head's
    position inside its KV group (``pid_sh``), so each program owns its output.
    """
    # 1.44269504 ~= log2(e): scores are pre-scaled so exp2 matches the base-2 LSE
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_kh = pid_h // NUM_SHARE_Q_HEADS
    pid_sh = pid_h % NUM_SHARE_Q_HEADS
    pid_k = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if BLOCK_SIZE_K * pid_k >= k_len:
        return
    # get topk_q_idx
    b_start = tl.load(cu_seqblocks + pid_b)  # how many blocks before current sequence
    act_q_start = tl.load(
        cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn
    )
    act_q_end = tl.load(
        cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn
    )
    # number of queries that selected this key block
    act_q_len = act_q_end - act_q_start
    tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn
    # init pointers
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_kn, stride_kd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dk_ptrs = tl.make_block_ptr(
        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
        shape=(k_len, HEAD_DIM),
        strides=(stride_dkn, stride_dkd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dv_ptrs = tl.make_block_ptr(
        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
        shape=(k_len, HEAD_DIM),
        strides=(stride_dvn, stride_dvd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q)
    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
    off_d = tl.arange(0, BLOCK_SIZE_D)
    # load k v and keep in SRAM
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dk dv
    dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
    # init ptrs
    q_ptrs = (
        q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd
    )
    do_ptrs = (
        do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod
    )
    d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh
    lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh
    # loop for q blocks
    for i in range(0, act_q_len, BLOCK_SIZE_Q):
        # load the (sequence-local) indices of the next BLOCK_SIZE_Q active queries.
        # The element offset is scaled by stride_tqn, consistent with the base
        # offset above, so a non-unit stride of topk_q_idx is handled correctly.
        idx_q = tl.load(
            tq_ptr + (i + off_q) * stride_tqn, mask=off_q < act_q_len - i, other=0
        ).to(tl.int32)
        # gather q, do, lse and delta rows of those queries; padded rows read
        # as zero and therefore contribute nothing to dk / dv
        q = tl.load(
            q_ptrs + idx_q[:, None] * stride_qn,
            mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],
            other=0,
        )
        do = tl.load(
            do_ptrs + idx_q[:, None] * stride_don,
            mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],
            other=0,
        )
        lse = tl.load(
            lse_ptrs + idx_q[:, None] * stride_ln,
            mask=(off_q < act_q_len - i)[:, None],
            other=0,
        )
        d = tl.load(
            d_ptrs + idx_q[:, None] * stride_dn,
            mask=(off_q < act_q_len - i)[:, None],
            other=0,
        )
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        # causal mask: a query only sees keys at or before its own position
        qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf"))
        qk += tl.dot(q, k.T) * qk_scale
        # compute p, ds
        p = tl.exp2(qk - lse)
        dp = tl.dot(do, v.T)
        ds = sm_scale * p * (dp - d)
        # cast dtype
        p = p.to(do.dtype)
        ds = ds.to(q.dtype)
        # update dk and dv
        dk += tl.dot(ds.T, q)
        dv += tl.dot(p.T, do)
    # save dk dv
    tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def backward_dq(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    t_ptr,  # topk_idx: kh x n x k
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dq_ptr,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    TOPK,
    # q loop num
    num_q_loop,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_th,
    stride_tn,
    stride_tk,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dqn,
    stride_dqh,
    stride_dqd,
    # META parameters
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    """Backward pass of top-k sparse attention for the query gradient.

    Program layout mirrors the forward kernel: axis 0 = batch element,
    axis 1 = KV head, axis 2 = group of ``num_q_loop`` consecutive query
    positions. For each query position the program recomputes the attention
    probabilities of the selected key blocks from Q, K and the stored LSE,
    accumulates ``dq`` for all query heads sharing the KV head, and stores it.
    """
    # 1.44269504 ~= log2(e): scores are pre-scaled so exp2 matches the base-2 LSE
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_kh = tl.program_id(1)
    pid_q = tl.program_id(2)
    # index of the first query head that shares this kv head
    pid_h = pid_kh * NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if pid_q * num_q_loop >= q_len:
        return
    # number of query positions this program really owns (last group may be short)
    real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop)
    for j in range(real_q_loop):
        pid_q_j = pid_q * num_q_loop + j
        # init topk idx pointer
        off_t = tl.arange(0, BLOCK_SIZE_T)
        t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th
        topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)
        # count the indices that are non-negative and not past the query's own block.
        # NOTE(review): the loop below reads the first real_topk entries in
        # order, so it appears to assume valid indices are stored first - confirm.
        real_topk = tl.sum(
            tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0),
            axis=0,
        )
        # init pointers
        q_ptrs = tl.make_block_ptr(
            base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_qh, stride_qd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        dq_ptrs = tl.make_block_ptr(
            base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_dqh, stride_dqd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        k_ptrs = tl.make_block_ptr(
            base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
            shape=(k_len, HEAD_DIM),
            strides=(stride_kn, stride_kd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
            order=(1, 0),
        )
        # v is addressed transposed ([d, n]) so it can be used directly in do @ v
        v_ptrs = tl.make_block_ptr(
            base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
            shape=(HEAD_DIM, k_len),
            strides=(stride_vd, stride_vn),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
            order=(0, 1),
        )
        do_ptrs = tl.make_block_ptr(
            base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_doh, stride_dod),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        # delta and lse are loaded as [BLOCK_SIZE_H, 1] columns so they broadcast over keys
        d_ptrs = tl.make_block_ptr(
            base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh,
            shape=(NUM_SHARE_Q_HEADS, 1),
            strides=(stride_dh, stride_dn),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, 1),
            order=(1, 0),
        )
        lse_ptrs = tl.make_block_ptr(
            base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh,
            shape=(NUM_SHARE_Q_HEADS, 1),
            strides=(stride_lh, stride_ln),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, 1),
            order=(1, 0),
        )
        # offsets
        off_k = tl.arange(0, BLOCK_SIZE_K)
        # load q, do, lse, delta, and keep in SRAM
        q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
        # init dq
        dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32)
        # sparse
        for i in range(real_topk):
            # get current block start index
            c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K
            t_ptr_j = t_ptr_j + stride_tk
            # load
            k = tl.load(
                tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero"
            )
            v = tl.load(
                tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero"
            )
            # compute qk
            qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)
            # causal mask: keys positioned after the query get -inf
            qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf"))
            # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K]
            qk += tl.dot(q, tl.trans(k)) * qk_scale
            # compute p, ds
            p = tl.exp2(qk - lse)
            dp = tl.dot(do, v)
            ds = sm_scale * p * (dp - d)
            # cast dtype
            ds = ds.to(q.dtype)
            # update dq
            dq += tl.dot(ds, k)
        # save dq
        tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
def _topk_sparse_attention_fwd(
    q: torch.Tensor,  # [total_len, num_q_heads, head_dim]
    k: torch.Tensor,  # [total_len, num_k_heads, head_dim]
    v: torch.Tensor,  # [total_len, num_k_heads, head_dim]
    topk_idx: torch.Tensor,  # [num_kv_heads, total_len, topk]
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
):
    """Check the inputs, allocate the outputs and launch the varlen top-k forward kernel.

    Args:
        q: queries, [total_len, num_q_heads, head_dim].
        k: keys, [total_len, num_k_heads, head_dim].
        v: values, [total_len, num_k_heads, head_dim].
        topk_idx: selected key-block ids per query, [num_kv_heads, total_len, topk].
        block_size: key/value block size; must be one of 32, 64, 128, 256.
        cu_seqlens_q: int32 cumulative query lengths, [batch_size + 1].
        cu_seqlens_k: int32 cumulative key lengths, [batch_size + 1].
        max_seqlen_q: longest query sequence; sizes the third grid axis.
        max_seqlen_k: accepted for signature symmetry; not read here.
        sm_scale: softmax scale forwarded to the kernel.

    Returns:
        Tuple ``(o, lse)``: ``o`` is shaped like ``q``; ``lse`` is a float32
        buffer of shape [num_q_heads, total_len] filled by the kernel.
    """
    # dtypes: q/k/v must match, sequence offsets must be int32
    assert q.dtype == k.dtype and q.dtype == v.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    assert block_size in (32, 64, 128, 256)

    # shapes: all three tensors are 3-D and cover the same tokens
    total_q, num_q_heads, _ = q.shape
    total_k, num_k_heads, _ = k.shape
    total_v, num_v_heads, head_dim = v.shape
    num_seqs = cu_seqlens_q.shape[0] - 1
    assert total_q == total_k and total_k == total_v
    topk = topk_idx.shape[-1]
    assert topk_idx.shape[0] == num_k_heads
    assert topk_idx.shape[1] == total_q

    # grouped-query attention: a whole number of query heads per kv head
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    group_size = num_q_heads // num_k_heads

    # outputs written by the kernel
    o = torch.zeros_like(q)
    lse = torch.zeros(num_q_heads, total_q, dtype=torch.float32, device=q.device)

    # For very long sequences each program handles several queries, which
    # keeps the third grid axis at or below 32768 entries per 32768 tokens.
    queries_per_program = max_seqlen_q // 32768 + 1
    launch_grid = (
        num_seqs,
        num_k_heads,
        triton.cdiv(max_seqlen_q, queries_per_program),
    )
    block_k = triton.next_power_of_2(block_size)
    block_d = triton.next_power_of_2(head_dim)
    block_h = max(16, triton.next_power_of_2(group_size))
    block_t = triton.next_power_of_2(topk)
    num_warps, num_stages = get_num_warps_stages(head_dim, block_k, IS_HOPPER_GPU)

    forward_kernel[launch_grid](
        # tensors
        q,
        k,
        v,
        topk_idx,
        o,
        lse,
        # sequence offsets
        cu_seqlens_q,
        cu_seqlens_k,
        # sizes
        num_k_heads,
        group_size,
        head_dim,
        topk,
        queries_per_program,
        sm_scale,
        # strides (q, k, v and o are 3-D, lse is 2-D by construction)
        *q.stride(),
        *k.stride(),
        *v.stride(),
        topk_idx.stride(0),
        topk_idx.stride(1),
        topk_idx.stride(2),
        *o.stride(),
        *lse.stride(),
        # compile-time tile sizes and launch options
        BLOCK_SIZE_K=block_k,
        BLOCK_SIZE_D=block_d,
        BLOCK_SIZE_H=block_h,
        BLOCK_SIZE_T=block_t,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, lse
def _topk_sparse_attention_bwd(
    o: torch.Tensor,
    do: torch.Tensor,
    lse: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    topk_idx: torch.Tensor,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
):
    """Launch the three backward kernels of top-k sparse attention.

    The work is done in four steps:
      1. ``backward_sum_o_do`` fills ``delta`` ([num_heads, total_len], float32)
         from ``o`` and ``do``.
      2. The per-query ``topk_idx`` is regrouped per key block
         (``count_query`` / ``reorder_topk_idx``) so the dk/dv kernel can
         look up the queries that selected each key block.
      3. ``backward_dkdv`` writes one partial dk/dv slab per query head of a
         kv group; the slabs are summed on the host.
      4. ``backward_dq`` computes dq.

    Args:
        o, do, lse: forward output, its incoming gradient, and the lse buffer
            returned by the forward pass.
        q, k, v: forward inputs, each 3-D ([total_len, heads, head_dim]).
        topk_idx: selected key-block ids, as passed to the forward pass.
        block_size: key/value block size; must be one of 32, 64, 128, 256.
        cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths.
        max_seqlen_q, max_seqlen_k: longest query / key sequence, used to size
            the launch grids.
        sm_scale: softmax scale forwarded to the kernels.

    Returns:
        Tuple ``(dq, dk, dv)``.
    """
    assert block_size in (32, 64, 128, 256)
    _, num_q_heads, _ = q.shape
    total_k, num_k_heads, _ = k.shape
    _, _, _ = v.shape  # only enforces that v is 3-D
    total_o, num_o_heads, head_dim = o.shape
    group_size = num_q_heads // num_k_heads
    topk = topk_idx.shape[-1]
    num_seqs = cu_seqlens_q.shape[0] - 1
    block_d = triton.next_power_of_2(head_dim)
    idx_device = topk_idx.device

    # ---- step 1: delta ("compute D") ------------------------------------
    delta = torch.zeros([num_o_heads, total_o], device=o.device, dtype=torch.float32)
    rows_per_program = 256
    num_warps, num_stages = get_num_warps_stages(
        head_dim, rows_per_program, IS_HOPPER_GPU
    )
    sum_grid = (triton.cdiv(total_o, rows_per_program), num_o_heads)
    backward_sum_o_do[sum_grid](
        o,
        do,
        delta,
        total_o,
        head_dim,
        *o.stride(),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        *delta.stride(),
        BLOCK_SIZE_O=rows_per_program,
        BLOCK_SIZE_D=block_d,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    # ---- step 2: regroup the selection per key block --------------------
    # number of key blocks per sequence and their running offsets
    seq_lengths = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    blocks_per_seq = torch.ceil(seq_lengths / block_size).to(torch.int32)
    leading_zero = torch.zeros(1, dtype=torch.int32, device=idx_device)
    cu_seqblocks = torch.cat(
        [leading_zero, torch.cumsum(blocks_per_seq, dim=0)]
    ).to(torch.int32)
    # count of active queries for each key block: (num_k_heads, total_k_blocks)
    topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size)
    zero_column = torch.zeros(
        topk_q_count.shape[0], 1, dtype=torch.int32, device=idx_device
    )
    cu_topk_q_count = torch.cat(
        [zero_column, torch.cumsum(topk_q_count, dim=-1)], dim=-1
    ).to(torch.int32)
    # Active query ids for sequence b, head h, kv block i live in
    # topk_q_idx[h, cu_topk_q_count[h, cu_seqblocks[b] + i]
    #             : cu_topk_q_count[h, cu_seqblocks[b] + i + 1]]
    topk_q_idx = reorder_topk_idx(
        topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size
    )

    # ---- step 3: dk / dv ------------------------------------------------
    # one partial slab per query head inside a kv group, reduced afterwards
    dk = torch.zeros(
        (group_size, total_k, num_k_heads, head_dim), device=k.device, dtype=k.dtype
    )
    dv = torch.zeros_like(dk)
    block_k = triton.next_power_of_2(block_size)
    block_q = 64
    num_warps, num_stages = get_num_warps_stages(head_dim, block_q, IS_HOPPER_GPU)
    dkdv_grid = (num_seqs, num_q_heads, triton.cdiv(max_seqlen_k, block_k))
    backward_dkdv[dkdv_grid](
        # tensors
        q,
        k,
        v,
        topk_q_idx,
        lse,
        delta,
        do,
        dk,
        dv,
        # offsets
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqblocks,
        cu_topk_q_count,
        # sizes
        num_k_heads,
        group_size,
        head_dim,
        topk,
        sm_scale,
        # strides
        *q.stride(),
        *k.stride(),
        *v.stride(),
        topk_q_idx.stride(0),
        topk_q_idx.stride(1),
        *cu_topk_q_count.stride(),
        lse.stride(0),
        lse.stride(1),
        *delta.stride(),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        *dk.stride(),
        *dv.stride(),
        BLOCK_SIZE_Q=block_q,
        BLOCK_SIZE_K=block_k,
        BLOCK_SIZE_D=block_d,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    dk = dk.sum(0)
    dv = dv.sum(0)

    # ---- step 4: dq -----------------------------------------------------
    dq = torch.zeros_like(q)
    # several queries per program when the sequence is very long
    queries_per_program = max_seqlen_q // 32768 + 1
    dq_grid = (
        num_seqs,
        num_k_heads,
        triton.cdiv(max_seqlen_q, queries_per_program),
    )
    block_h = max(16, triton.next_power_of_2(group_size))
    block_t = triton.next_power_of_2(topk)
    num_warps, num_stages = get_num_warps_stages(head_dim, block_size, IS_HOPPER_GPU)
    backward_dq[dq_grid](
        # tensors
        q,
        k,
        v,
        topk_idx,
        lse,
        delta,
        do,
        dq,
        # offsets
        cu_seqlens_q,
        cu_seqlens_k,
        # sizes
        num_k_heads,
        group_size,
        head_dim,
        topk,
        queries_per_program,
        sm_scale,
        # strides
        *q.stride(),
        *k.stride(),
        *v.stride(),
        topk_idx.stride(0),
        topk_idx.stride(1),
        topk_idx.stride(2),
        lse.stride(0),
        lse.stride(1),
        *delta.stride(),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        *dq.stride(),
        BLOCK_SIZE_K=block_size,
        BLOCK_SIZE_D=block_d,
        BLOCK_SIZE_H=block_h,
        BLOCK_SIZE_T=block_t,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return dq, dk, dv
class TopkSparseAttention(torch.autograd.Function):
    """Autograd entry point for varlen top-k block-sparse attention.

    ``forward`` delegates to ``_topk_sparse_attention_fwd`` and ``backward``
    to ``_topk_sparse_attention_bwd``; this class only validates dtypes,
    resolves the default softmax scale and carries state between the two.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,  # [total_len, num_q_heads, head_dim]
        k: torch.Tensor,  # [total_len, num_k_heads, head_dim]
        v: torch.Tensor,  # [total_len, num_k_heads, head_dim]
        topk_idx: torch.Tensor,  # [num_kv_heads, total_len, topk]
        block_size: int,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        sm_scale=None,
    ):
        """Run the forward kernel and stash what ``backward`` needs.

        Args:
            q, k, v: fp16/bf16 tensors sharing one dtype.
            topk_idx: int32 selected key-block ids per query.
            block_size: key/value block size.
            cu_seqlens_q, cu_seqlens_k: int32 cumulative sequence lengths.
            max_seqlen_q, max_seqlen_k: longest query / key sequence.
            sm_scale: softmax scale; ``None`` means ``1 / sqrt(head_dim)``.

        Returns:
            Attention output with the shape of ``q``.
        """
        # dtype check
        assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
        assert q.dtype == k.dtype and k.dtype == v.dtype
        assert topk_idx.dtype == torch.int32
        assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
        # softmax scale
        if sm_scale is None:
            sm_scale = 1 / math.sqrt(q.shape[-1])
        o, lse = _topk_sparse_attention_fwd(
            q,
            k,
            v,
            topk_idx,
            block_size,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )
        # tensors go through save_for_backward, plain Python values onto ctx
        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx)
        ctx.sm_scale = sm_scale
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.block_size = block_size
        return o

    @staticmethod
    def backward(ctx, do: torch.Tensor, *args) -> Any:
        """Compute gradients for ``q``, ``k`` and ``v``.

        Returns:
            One entry per ``forward`` input (10 in total): ``dq``, ``dk``,
            ``dv`` followed by ``None`` for the seven non-differentiable
            arguments (topk_idx, block_size, cu_seqlens_q, cu_seqlens_k,
            max_seqlen_q, max_seqlen_k, sm_scale).
        """
        q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors
        max_seqlen_q = ctx.max_seqlen_q
        max_seqlen_k = ctx.max_seqlen_k
        sm_scale = ctx.sm_scale
        block_size = ctx.block_size
        assert block_size in {32, 64, 128, 256}
        dq, dk, dv = _topk_sparse_attention_bwd(
            o,
            do,
            lse,
            q,
            k,
            v,
            topk_idx,
            block_size,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )
        # forward takes 10 inputs, so exactly 10 gradients are returned
        return dq, dk, dv, None, None, None, None, None, None, None
def topk_sparse_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    topk_idx: torch.Tensor,
    block_size: int,
    cu_seqlens: torch.Tensor,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Variable-length top-k sparse attention backed by Triton kernels.

    Queries and keys share the same packing, so ``cu_seqlens`` is used for both.

    Args:
        q (torch.Tensor): shape [total_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
        topk_idx (torch.Tensor): top-k block index for each query, shape
            [num_kv_heads, total_len, topk]. -1 means padding.
        block_size (int): key/value block size.
        cu_seqlens (torch.Tensor): shape [batch_size + 1], like cu_seqlens in
            flash_attn_func_varlen.
        softmax_scale (Optional[float], optional): defaults to None, which
            means 1/sqrt(head_dim).

    Returns:
        torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim]
    """
    # longest sequence in the batch, as a Python number (forces a host sync)
    per_seq_len = cu_seqlens[1:] - cu_seqlens[:-1]
    longest = per_seq_len.max().item()
    call_args = (
        q,
        k,
        v,
        topk_idx,
        block_size,
        cu_seqlens,  # query offsets
        cu_seqlens,  # key offsets
        longest,  # max query length
        longest,  # max key length
        softmax_scale,
    )
    return TopkSparseAttention.apply(*call_args)
================================================
FILE: native_sparse_attention/ops/triton/topk_sparse_attention_decode.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import triton
import triton.language as tl
from typing import Optional
@triton.jit
def forward_kernel(
    q_ptr,  # Q: b x h x d
    k_ptr,  # K: b x n x kh x d
    v_ptr,  # V: b x n x kh x d
    t_ptr,  # topk_idx: kh x b x k
    o_ptr,  # O: b x h x d
    # seqlens
    seqlens,
    # shape
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    TOPK,
    # sm_scale
    sm_scale,
    # stride
    stride_qb,
    stride_qh,
    stride_qd,
    stride_kb,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vb,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_th,
    stride_tb,
    stride_tk,
    stride_ob,
    stride_oh,
    stride_od,
    # META parameters
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    # Decode-step top-k block-sparse attention.
    # One program handles one (batch, kv head) pair: it loads the query heads
    # that share this kv head, walks over the selected key/value blocks and
    # accumulates the output with a running (online) softmax.
    #
    # 1.44269504 ~= log2(e): scores are pre-scaled so exp2/log2 can be used
    # in place of exp/log below.
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_kh = tl.program_id(1)
    # get kv_len
    kv_len = tl.load(seqlens + pid_b)
    # init topk idx pointer
    off_t = tl.arange(0, BLOCK_SIZE_T)
    t_ptr = t_ptr + pid_b * stride_tb + pid_kh * stride_th
    topk_idx = tl.load(t_ptr + off_t * stride_tk, mask=off_t < TOPK, other=-1)
    # Number of usable entries: non-negative and not past the block that holds
    # the last valid key. The loop below reads the first `real_topk` entries in
    # order, so it assumes valid entries come before padding.
    real_topk = tl.sum(
        tl.where((topk_idx >= 0) & (topk_idx <= (kv_len - 1) // BLOCK_SIZE_K), 1, 0),
        axis=0,
    )
    # init qkv pointer
    # q tile: the NUM_SHARE_Q_HEADS query heads mapped to this kv head
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + pid_b * stride_qb + pid_kh * NUM_SHARE_Q_HEADS * stride_qh,
        shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
        strides=(stride_qh, stride_qd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # k is viewed transposed, (HEAD_DIM, kv_len), so q @ k needs no tl.trans
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + pid_b * stride_kb + pid_kh * stride_kh,
        shape=(HEAD_DIM, kv_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + pid_b * stride_vb + pid_kh * stride_vh,
        shape=(kv_len, HEAD_DIM),
        strides=(stride_vn, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # load q
    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init statistics
    off_k = tl.arange(0, BLOCK_SIZE_K)
    # m_i: running row maximum, lse_i: running log2-sum-exp2, acc_o: output
    # accumulated relative to the current maximum
    m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32)
    lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32)
    acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32)
    # sparse attention
    for i in range(real_topk):
        # get current block start index
        c = tl.load(t_ptr).to(tl.int32) * BLOCK_SIZE_K
        t_ptr = t_ptr + stride_tk
        # load k
        k = tl.load(
            tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero"
        )
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)
        # key positions at or beyond kv_len get -inf so they receive zero weight
        qk += tl.where((kv_len > c + off_k)[None, :], 0, float("-inf"))
        # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K]
        qk += tl.dot(q, k) * qk_scale
        # compute m_ij and l_ij
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.exp2(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        # scale acc_o
        # rescale what was accumulated under the previous maximum
        acc_o_scale = tl.exp2(m_i - m_ij)
        acc_o = acc_o * acc_o_scale[:, None]
        # load v and update acc_o
        v = tl.load(
            tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero"
        )
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # update statistics
        m_i = m_ij
        lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
    # final scale
    # exp2(m_i - lse_i) divides by the softmax denominator
    acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
    # save output
    o_ptrs = tl.make_block_ptr(
        base=o_ptr + pid_b * stride_ob + pid_kh * NUM_SHARE_Q_HEADS * stride_oh,
        shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
        strides=(stride_oh, stride_od),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
        order=(1, 0),
    )
    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
def topk_sparse_attention_decode(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    topk_idx: torch.Tensor,
    block_size: int,
    seqlens: torch.Tensor,
    sm_scale: Optional[float] = None,
) -> torch.Tensor:
    """Top-k block-sparse attention for one decode step, implemented in Triton.

    Args:
        q (torch.Tensor): shape [batch_size, num_q_heads, head_dim]
        k (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]
        topk_idx (torch.Tensor): top-k block index for each query, shape
            [num_kv_heads, batch_size, topk]. -1 means padding.
        block_size (int): key/value block size.
        seqlens (torch.Tensor): int32, shape [batch_size]; number of valid
            key/value positions of each sequence.
        sm_scale (Optional[float], optional): softmax scale. Defaults to None,
            which means 1/sqrt(head_dim).

    Returns:
        torch.Tensor: sparse attention output, same shape and dtype as ``q``.
    """
    # dtype check
    assert q.dtype in (torch.bfloat16, torch.float16)
    assert q.dtype == k.dtype and q.dtype == v.dtype
    assert seqlens.dtype == torch.int32

    # shape check
    num_seqs, num_q_heads, _ = q.shape
    _, k_len, num_k_heads, _ = k.shape
    _, v_len, num_v_heads, head_dim = v.shape
    assert k_len == v_len and num_seqs == seqlens.shape[0]
    assert num_k_heads == topk_idx.shape[0] and num_seqs == topk_idx.shape[1]
    topk = topk_idx.shape[-1]

    # grouped-query attention: whole number of query heads per kv head
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    group_size = num_q_heads // num_k_heads

    if sm_scale is None:
        sm_scale = 1 / math.sqrt(head_dim)

    o = torch.zeros_like(q)

    # one program per (sequence, kv head)
    launch_grid = (num_seqs, num_k_heads)
    num_warps = 8 if head_dim > 64 else 4
    num_stages = 3
    forward_kernel[launch_grid](
        # tensors
        q,
        k,
        v,
        topk_idx,
        o,
        seqlens,
        # sizes
        group_size,
        head_dim,
        topk,
        sm_scale,
        # strides (q and o are 3-D, k and v are 4-D per the unpacking above)
        *q.stride(),
        *k.stride(),
        *v.stride(),
        topk_idx.stride(0),
        topk_idx.stride(1),
        topk_idx.stride(2),
        *o.stride(),
        # compile-time tile sizes and launch options
        BLOCK_SIZE_K=triton.next_power_of_2(block_size),
        BLOCK_SIZE_D=triton.next_power_of_2(head_dim),
        BLOCK_SIZE_H=max(16, triton.next_power_of_2(group_size)),
        BLOCK_SIZE_T=triton.next_power_of_2(topk),
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o
def torch_topk_sparse_attention_decode(
    q: torch.Tensor,  # [batch_size, num_q_heads, head_dim]
    k: torch.Tensor,  # [batch_size, kv_len, num_k_heads, head_dim]
    v: torch.Tensor,  # [batch_size, kv_len, num_k_heads, head_dim]
    topk_idx: torch.Tensor,  # [num_k_heads, batch_size, topk]
    block_size: int,
    seqlens: torch.Tensor,  # [batch_size, ]
    sm_scale: Optional[float] = None,
):
    """Pure-PyTorch reference for one decode step of top-k block-sparse attention.

    A key position is attended to when its block id appears in ``topk_idx``
    for that (kv head, sequence) and the position is below the sequence's
    ``seqlens`` entry. All work happens on ``q.device``, so the function also
    runs on CPU or on a non-default GPU.

    Args:
        q: fp16/bf16 queries, [batch_size, num_q_heads, head_dim].
        k: keys, [batch_size, kv_len, num_k_heads, head_dim].
        v: values, [batch_size, kv_len, num_k_heads, head_dim].
        topk_idx: selected block ids, [num_k_heads, batch_size, topk];
            -1 (any negative value) marks padding.
        block_size: key/value block size.
        seqlens: int32 number of valid key positions per sequence, [batch_size].
        sm_scale: softmax scale; ``None`` means 1/sqrt(head_dim).

    Returns:
        Attention output, [batch_size, num_q_heads, head_dim], dtype of ``q``.
    """
    # dtype check
    assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
    assert k.dtype == q.dtype and v.dtype == q.dtype
    assert seqlens.dtype == torch.int32
    # shape
    batch_size, num_q_heads, head_dim = q.shape
    _, k_len, num_k_heads, head_dim = k.shape
    _, v_len, num_v_heads, head_dim = v.shape
    assert k_len == v_len and batch_size == seqlens.shape[0]
    assert num_k_heads == topk_idx.shape[0] and batch_size == topk_idx.shape[1]
    # gqa
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads
    # sm scale
    if sm_scale is None:
        sm_scale = 1 / math.sqrt(head_dim)
    # mask, built in one shot instead of a per-entry Python loop
    device = q.device
    key_pos = torch.arange(k_len, device=device)
    key_block = key_pos // block_size  # block id of every key position
    # [batch, kv_heads, topk, 1] == [k_len] -> [batch, kv_heads, topk, k_len];
    # padding (-1) never equals a block id, so it selects nothing
    selected = topk_idx.to(device).permute(1, 0, 2).unsqueeze(-1) == key_block
    mask = selected.any(dim=2)  # [batch, kv_heads, k_len]
    # drop positions at or beyond each sequence's valid length
    mask = mask & (key_pos[None, None, :] < seqlens.to(device)[:, None, None])
    mask = mask.repeat_interleave(num_share_q_heads, 1)
    # attention
    attn = (
        torch.einsum(
            "bqhd,bkhd->bhqk", q.unsqueeze(1), k.repeat_interleave(num_share_q_heads, 2)
        )
        * sm_scale
    )
    attn = attn.masked_fill(~mask.unsqueeze(2), -torch.inf)
    attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
    out = torch.einsum(
        "bhqk,bkhd->bqhd", attn, v.repeat_interleave(num_share_q_heads, 2)
    ).squeeze(1)
    return out
def generate_topk_idx_example(
    seqlens: torch.Tensor, block_size: int, topk: int, num_heads: int
) -> torch.Tensor:
    """Build a random, valid ``topk_idx`` tensor for a decode step.

    For every head and sequence, up to ``topk`` distinct block ids are drawn
    at random from the blocks covered by the sequence. Each row is sorted
    ascending, its first entry is forced to block 0, and unused slots are
    filled with -1 at the end of the row. Tensors are created on
    ``seqlens.device``.

    Args:
        seqlens: integer tensor [batch_size] with the kv length of each sequence.
        block_size: key/value block size.
        topk: number of block slots per row.
        num_heads: number of kv heads.

    Returns:
        int32 tensor of shape [num_heads, batch_size, topk]; -1 means padding.
    """
    device = seqlens.device
    # ceil(seqlen / block_size) in integer arithmetic, as plain Python ints
    num_blocks = ((seqlens + block_size - 1) // block_size).tolist()
    # block id that contains the last valid position of every sequence
    last_block = ((seqlens - 1) // block_size)[:, None]
    topk_idx_all_heads = []
    for _ in range(num_heads):
        rows = []
        for n in num_blocks:
            picked = (
                torch.randn(1, n, device=device)
                .topk(min(topk, n), dim=-1)
                .indices.to(torch.int32)
            )
            # Pad with `topk`, which is larger than any valid id when padding
            # is needed, so the sort below pushes the padding to the end.
            rows.append(
                torch.nn.functional.pad(
                    picked, (0, topk - picked.shape[-1]), value=topk
                )
            )
        topk_idx = torch.cat(rows, dim=0)
        topk_idx = torch.sort(topk_idx, dim=1).values
        topk_idx[:, 0] = 0  # the first block is always selected
        topk_idx[topk_idx > last_block] = -1  # -1 means padding
        topk_idx_all_heads.append(topk_idx)
    return torch.stack(topk_idx_all_heads, dim=0)
if __name__ == "__main__":
    # Self-check: run the PyTorch reference and the Triton kernel on the same
    # random inputs and print whether they agree. Requires a CUDA device.
    torch.manual_seed(42)
    topk = 16
    block_size = 64
    batch_size = 76
    max_length = 8192
    # kv lengths 1, 129, 257, ... clipped to max_length, then shuffled
    seqlens = torch.arange(batch_size, dtype=torch.int32).cuda() * 128 + 1
    seqlens[seqlens > max_length] = max_length
    seqlens = seqlens[torch.randn_like(seqlens, dtype=torch.float32).argsort(-1)]
    # 32 query heads share 4 kv heads; head_dim is 128
    q = (
        torch.empty(batch_size, 32, 128, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    k = (
        torch.empty(batch_size, max_length, 4, 128, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    v = (
        torch.empty(batch_size, max_length, 4, 128, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    topk_idx = generate_topk_idx_example(seqlens, block_size, topk, 4)
    # o1: reference result, o2: Triton result
    o1 = torch_topk_sparse_attention_decode(q, k, v, topk_idx, block_size, seqlens)
    o2 = topk_sparse_attention_decode(q, k, v, topk_idx, block_size, seqlens)
    print(torch.allclose(o1, o2, atol=1e-3, rtol=1e-3))
    print((o1 - o2).abs().max())
================================================
FILE: native_sparse_attention/ops/triton/utils.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def is_hopper_gpu():
    """Return True when CUDA is available and the current device has compute capability 9.x."""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major == 9
def get_compressed_seqlens(
    cu_seqlens: torch.Tensor, kernel_size: int, kernel_stride: int
):
    """Compute per-sequence lengths and cumulative offsets after sliding-window compression.

    A sequence of length ``n >= kernel_size`` yields
    ``floor((n - kernel_size) / kernel_stride) + 1`` windows; shorter
    sequences yield none.

    Args:
        cu_seqlens: cumulative sequence lengths, shape [batch_size + 1].
        kernel_size: window length.
        kernel_stride: distance between consecutive window starts.

    Returns:
        Tuple ``(y_seqlens, y_cu_seqlens)``, both int32: the number of windows
        per sequence and their cumulative sum with a leading zero.
    """
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    windows = torch.floor((lengths - kernel_size) / kernel_stride).to(torch.int32) + 1
    # corner case: sequences shorter than one window are not compressed at all
    y_seqlens = windows.masked_fill_(lengths < kernel_size, 0)
    leading_zero = torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device)
    running_total = torch.cumsum(y_seqlens, dim=0).to(torch.int32)
    y_cu_seqlens = torch.cat([leading_zero, running_total])
    return y_seqlens, y_cu_seqlens
def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
    """
    Pick num_warps and num_stages for a sparse attention Triton kernel.

    The choice depends on how many of ``head_dim`` and ``block_size`` exceed 64:
      - neither: (2, 2) on every GPU;
      - both: (8, 3) on every GPU;
      - exactly one: (4, 3) on Hopper, (8, 3) otherwise.
    When ``head_dim`` is above 128, num_stages is capped at 2.

    Args:
        head_dim (int): Size of the head dimension.
        block_size (int): Size of the block in the attention matrix.
        is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU.

    Returns:
        tuple: (num_warps, num_stages) recommended values.
    """
    num_large = int(head_dim > 64) + int(block_size > 64)
    if num_large == 0:
        num_warps, num_stages = 2, 2
    elif num_large == 2 or not is_hopper_gpu:
        num_warps, num_stages = 8, 3
    else:
        # Hopper with exactly one large dimension
        num_warps, num_stages = 4, 3
    if head_dim > 128:
        num_stages = 2
    return num_warps, num_stages
================================================
FILE: native_sparse_attention/ops/triton/weighted_pool.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from einops import einsum
import torch
import triton
import triton.language as tl
from native_sparse_attention.ops.triton.utils import get_compressed_seqlens
@triton.jit
def sliding_pool_fwd_kernel(
    x_ptr,
    y_ptr,
    w_ptr,
    cu_seqlens,
    y_cu_seqlens,
    head_dim,
    kernel_size,
    kernel_stride,
    stride_xn,
    stride_xh,
    stride_xd,
    stride_yn,
    stride_yh,
    stride_yd,
    stride_wh,
    stride_wk,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    # Sliding-window pooling forward.
    # One program produces one output row: program ids are
    # (sequence, head, window index inside the sequence). With a weight the
    # window rows are combined as sum(x * w); without one they are averaged.
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_k = tl.program_id(2)
    # get start and len after rmpad
    x_start = tl.load(cu_seqlens + pid_b)
    x_len = tl.load(cu_seqlens + pid_b + 1) - x_start
    y_start = tl.load(y_cu_seqlens + pid_b)
    y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start
    # programs past this sequence's number of windows have nothing to do
    if pid_k >= y_len:
        return
    if w_ptr is not None:
        # load w
        # loaded as a (BLOCK_SIZE_K, 1) column so it broadcasts over head_dim
        w_ptrs = tl.make_block_ptr(
            base=w_ptr + pid_h * stride_wh,
            shape=(kernel_size, 1),
            strides=(stride_wk, 0),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_K, 1),
            order=(0, 1),
        )
        w = tl.load(w_ptrs, boundary_check=(0, 1), padding_option="zero")
    # load x
    # window rows start at pid_k * kernel_stride; out-of-range rows read as zero
    x_ptrs = tl.make_block_ptr(
        base=x_ptr + x_start * stride_xn + pid_h * stride_xh,
        shape=(x_len, head_dim),
        strides=(stride_xn, stride_xd),
        offsets=(pid_k * kernel_stride, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option="zero")
    # compute y
    # reduce over the window axis -> one vector of BLOCK_SIZE_D values
    if w_ptr is not None:
        y = tl.sum(x * w, axis=0)
    else:
        y = tl.sum(x, axis=0) / kernel_size
    off_d = tl.arange(0, BLOCK_SIZE_D)
    tl.store(
        y_ptr + (y_start + pid_k) * stride_yn + pid_h * stride_yh + off_d * stride_yd,
        y.to(y_ptr.dtype.element_ty),
        mask=off_d < head_dim,
    )
@triton.jit
def sliding_pool_dxdw_kernel(
    x_ptr,
    dx_ptr,
    dy_ptr,
    w_ptr,
    dw_ptr,
    cu_seqlens,
    y_cu_seqlens,
    head_dim,
    kernel_size,
    kernel_stride,
    stride_xn,
    stride_xh,
    stride_xd,
    stride_dxn,
    stride_dxh,
    stride_dxd,
    stride_dyn,
    stride_dyh,
    stride_dyd,
    stride_wh,
    stride_wk,
    stride_dwh,
    stride_dwn,
    stride_dwk,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    # Sliding-window pooling backward.
    # One program handles one output row (sequence, head, window index):
    # it adds that row's contribution to dx and, when a weight is given,
    # writes a per-window dw row that the host reduces afterwards.
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_k = tl.program_id(2)
    # get start and len after rmpad
    x_start = tl.load(cu_seqlens + pid_b)
    x_len = tl.load(cu_seqlens + pid_b + 1) - x_start
    y_start = tl.load(y_cu_seqlens + pid_b)
    y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start
    # programs past this sequence's number of windows have nothing to do
    if pid_k >= y_len:
        return
    # offsets
    off_d = tl.arange(0, BLOCK_SIZE_D)
    off_k = tl.arange(0, BLOCK_SIZE_K)
    if w_ptr is not None:
        # load w
        w_ptrs = w_ptr + pid_h * stride_wh + off_k * stride_wk
        w = tl.load(w_ptrs, mask=off_k < kernel_size, other=0)
        # load x
        # x is viewed transposed, (head_dim, x_len), giving a [D, K] tile
        x_ptrs = tl.make_block_ptr(
            base=x_ptr + x_start * stride_xn + pid_h * stride_xh,
            shape=(head_dim, x_len),
            strides=(stride_xd, stride_xn),
            offsets=(0, pid_k * kernel_stride),
            block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
            order=(0, 1),
        )
        x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option="zero")
    # load dy
    dy_ptrs = (
        dy_ptr
        + pid_h * stride_dyh
        + (y_start + pid_k) * stride_dyn
        + off_d * stride_dyd
    )
    dy = tl.load(dy_ptrs, mask=off_d < head_dim, other=0)
    if w_ptr is not None:
        # compute dx, [1, D] x [K, 1] -> [K, D]
        dx = dy[None, :] * w[:, None]
        # compute dw, [D, 1] x [D, K] -> [D, K] -> [K]
        dw = tl.sum(dy[:, None] * x, axis=0)
        # store dw
        # one dw row per output position; summed over positions by the caller
        dw_ptrs = (
            dw_ptr
            + pid_h * stride_dwh
            + (y_start + pid_k) * stride_dwn
            + off_k * stride_dwk
        )
        tl.store(dw_ptrs, dw.to(dw_ptr.dtype.element_ty), mask=off_k < kernel_size)
    else:
        # average pooling: every row of the window gets dy / kernel_size
        dx = dy[None, :] / kernel_size
    # store dx
    # Atomic accumulation: windows start every kernel_stride rows, so several
    # programs can add into the same rows of dx.
    dx_ptrs = (
        dx_ptr
        + pid_h * stride_dxh
        + (x_start + pid_k * kernel_stride + off_k[:, None]) * stride_dxn
        + off_d[None, :] * stride_dxd
    )
    tl.atomic_add(
        dx_ptrs,
        dx.to(dx_ptr.dtype.element_ty),
        mask=(off_k < x_len - pid_k * kernel_stride)[:, None]
        & (off_d < head_dim)[None, :],
    )
class SlidingWindowWeightedPool(torch.autograd.Function):
    """Autograd wrapper around the sliding-window pooling Triton kernels.

    With ``w`` given, each window is reduced with per-head weights; with
    ``w=None`` the kernels fall back to plain averaging.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,  # [total_len, num_heads, head_dim]
        w: torch.Tensor,  # [num_heads, kernel_size]
        cu_seqlens: torch.Tensor,
        kernel_size: int,
        kernel_stride: int,
    ):
        """Pool every window of ``kernel_size`` rows into one row.

        Returns:
            Tuple ``(y, y_cu_seqlens)``: pooled rows
            [total_windows, num_heads, head_dim] and their int32 cumulative
            per-sequence offsets.
        """
        has_weight = w is not None
        # dtype check
        assert x.dtype in (torch.float16, torch.bfloat16)
        if has_weight:
            assert w.dtype == x.dtype
        assert cu_seqlens.dtype == torch.int32
        # shape check
        _, num_heads, head_dim = x.shape
        num_seqs = cu_seqlens.shape[0] - 1
        if has_weight:
            assert w.shape[0] == num_heads
            assert w.shape[1] == kernel_size
        assert kernel_size % kernel_stride == 0
        assert kernel_size in (16, 32, 64, 128)
        # sequence lengths before and after compression
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        y_seqlens, y_cu_seqlens = get_compressed_seqlens(
            cu_seqlens, kernel_size, kernel_stride
        )
        # output buffer
        y = torch.zeros(
            y_cu_seqlens[-1], num_heads, head_dim, dtype=x.dtype, device=x.device
        )
        # one program per (sequence, head, window); extra windows exit early
        launch_grid = (num_seqs, num_heads, y_seqlens.max().item())
        w_strides = (w.stride(0), w.stride(1)) if has_weight else (None, None)
        sliding_pool_fwd_kernel[launch_grid](
            x,
            y,
            w,
            cu_seqlens,
            y_cu_seqlens,
            head_dim,
            kernel_size,
            kernel_stride,
            *x.stride(),
            *y.stride(),
            *w_strides,
            BLOCK_SIZE_K=triton.next_power_of_2(kernel_size),
            BLOCK_SIZE_D=triton.next_power_of_2(head_dim),
        )
        ctx.save_for_backward(x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens)
        ctx.kernel_size = kernel_size
        ctx.kernel_stride = kernel_stride
        ctx.head_dim = head_dim
        return y, y_cu_seqlens

    @staticmethod
    def backward(ctx, dy, _):
        """Compute gradients for ``x`` and, when present, ``w``.

        Returns:
            ``(dx, dw, None, None, None)``; ``dw`` is None when no weight was used.
        """
        x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens = ctx.saved_tensors
        kernel_size = ctx.kernel_size
        kernel_stride = ctx.kernel_stride
        head_dim = ctx.head_dim
        has_weight = w is not None
        num_seqs = cu_seqlens.shape[0] - 1
        num_heads = x.shape[1]
        # dx is accumulated in float32 and cast back at the end
        dx = torch.zeros_like(x, dtype=torch.float32)
        # dw holds one row per output position; reduced after the kernel
        dw = None
        if has_weight:
            dw = torch.zeros(
                num_heads,
                y_cu_seqlens[-1],
                kernel_size,
                dtype=torch.float32,
                device=w.device,
            )
        w_strides = (w.stride(0), w.stride(1)) if has_weight else (None, None)
        dw_strides = dw.stride() if has_weight else (None, None, None)
        launch_grid = (num_seqs, num_heads, y_seqlens.max().item())
        sliding_pool_dxdw_kernel[launch_grid](
            x,
            dx,
            dy,
            w,
            dw,
            cu_seqlens,
            y_cu_seqlens,
            head_dim,
            kernel_size,
            kernel_stride,
            *x.stride(),
            *dx.stride(),
            dy.stride(0),
            dy.stride(1),
            dy.stride(2),
            *w_strides,
            *dw_strides,
            BLOCK_SIZE_K=triton.next_power_of_2(kernel_size),
            BLOCK_SIZE_D=triton.next_power_of_2(head_dim),
        )
        dx = dx.to(x.dtype)
        if has_weight:
            dw = dw.sum(1).to(w.dtype)
        return dx, dw, None, None, None
def weightedpool_compress(
    x: torch.Tensor,  # [total_len, num_heads, head_dim]
    w: torch.Tensor,  # [num_heads, kernel_size]
    cu_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress key or value states with sliding-window weighted pooling.

    Args:
        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
        w (torch.Tensor): weight for each head, shape (num_heads, kernel_size)
        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        kernel_size (int): each (kernel_size, head_dim) block is compressed to (1, head_dim)
        kernel_stride (int): stride between consecutive compression windows
        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape
            (num_heads, kernel_size, head_dim). Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.
    """
    pooled, pooled_cu_seqlens = SlidingWindowWeightedPool.apply(
        x, w, cu_seqlens, kernel_size, kernel_stride
    )
    if pe is None:
        return pooled, pooled_cu_seqlens
    # The positional embedding is pooled with the same weights and added as a
    # per-head bias to every compressed row.
    assert pe.dtype == x.dtype and pe.device == x.device
    pe_bias = einsum(pe, w, "h k d, h k -> h d")
    pooled = pooled + pe_bias.unsqueeze(0)
    return pooled, pooled_cu_seqlens
def avgpool_compress(
    x: torch.Tensor,  # [total_len, num_heads, head_dim]
    w: torch.Tensor,  # don't need weight
    cu_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress key or value states with sliding-window average pooling.

    Args:
        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
        w (torch.Tensor): must be None; kept so the signature matches weightedpool_compress
        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        kernel_size (int): each (kernel_size, head_dim) block is compressed to (1, head_dim)
        kernel_stride (int): stride between consecutive compression windows
        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape
            (num_heads, kernel_size, head_dim). Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.
    """
    assert w is None, "don't need additional weight for avgpool"
    pooled, pooled_cu_seqlens = SlidingWindowWeightedPool.apply(
        x, w, cu_seqlens, kernel_size, kernel_stride
    )
    if pe is None:
        return pooled, pooled_cu_seqlens
    # The positional embedding is averaged over the window and added as a
    # per-head bias to every compressed row.
    assert pe.dtype == x.dtype and pe.device == x.device
    pe_bias = torch.mean(pe, dim=1)
    pooled = pooled + pe_bias.unsqueeze(0)
    return pooled, pooled_cu_seqlens
if __name__ == "__main__":
    # Self-check: compare the Triton weighted pooling (forward and backward)
    # with the PyTorch reference on random data. Requires a CUDA device.
    from native_sparse_attention.ops.torch.compress_key_value import (
        weightedpool_compress_torch,
    )

    torch.manual_seed(42)
    num_heads = 4
    head_dim = 128
    kernel_size = 32
    kernel_stride = 16
    # the first sequence (12) is shorter than kernel_size
    seqlens = torch.LongTensor([12, 1000, 2000, 4096]).int().cuda()
    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device="cuda"),
            torch.cumsum(seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    x = (
        torch.zeros(cu_seqlens[-1], num_heads, head_dim)
        .uniform_(-1, 1)
        .to(torch.bfloat16)
        .cuda()
        .requires_grad_()
    )
    w = (
        torch.zeros(num_heads, kernel_size)
        .uniform_(-1 / 32, 1 / 32)
        .cuda()
        .to(torch.bfloat16)
        .requires_grad_()
    )
    # reference forward
    y, y_cu_seqlens = weightedpool_compress_torch(
        x, w, cu_seqlens, kernel_size, kernel_stride, None
    )
    # Triton forward on detached copies so gradients stay separate
    x1 = x.clone().detach().requires_grad_()
    w1 = w.clone().detach().requires_grad_()
    y1, y1_cu_seqlens = weightedpool_compress(
        x1, w1, cu_seqlens, kernel_size, kernel_stride
    )
    print(torch.allclose(y, y1, rtol=1e-2, atol=1e-2))
    print(torch.abs(y - y1).max().item())
    # backward with the same random upstream gradient for both paths
    randn = torch.randn_like(y)
    randn1 = randn.clone().detach()
    loss = (y * randn).sum()
    loss1 = (y1 * randn1).sum()
    loss.backward()
    loss1.backward()
    # max absolute error of dx, max relative error of dw
    print((x.grad - x1.grad).abs().max())
    print(((w.grad - w1.grad).abs() / w.grad.abs()).max())
================================================
FILE: setup.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup, find_packages
# Read the README.md file for the long description
with open("README.md", "r", encoding="utf-8") as readme:
    long_description = readme.read()

# Runtime dependencies and their minimum versions.
install_requires = [
    "torch>=2.1.0",
    "triton>=3.0.0",
    "einops>=0.7.0",
    "flash-attn>=2.6.3",
    "transformers>=4.44.0",
]

# Trove classifiers describing the distribution.
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
    "Operating System :: OS Independent",
]

# Package metadata and build configuration.
setup(
    name="native-sparse-attention-triton",
    version="0.1.0",
    description="An efficient implementation of Native Sparse Attention using Triton",
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="XunhaoLai",
    author_email="laixunhao@pku.edu.cn",
    url="https://github.com/XunhaoLai/native-sparse-attention-triton",
    packages=find_packages(),
    install_requires=install_requires,
    classifiers=classifiers,
    python_requires=">=3.9",
)
================================================
FILE: test/test_compress_key_value.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
from native_sparse_attention.ops import linear_compress
if __name__ == "__main__":
    torch.manual_seed(42)

    # ---- smoke test: one forward/backward pass through linear_compress ----
    num_heads, head_dim = 4, 192
    kernel_size, kernel_stride = 32, 16

    seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda()
    cu_seqlens = torch.zeros(seqlens.shape[0] + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    def uniform_leaf(*shape):
        """Sample U(-1, 1) on CPU, move to CUDA, cast to bf16, mark as a grad leaf."""
        return torch.zeros(*shape).uniform_(-1, 1).cuda().bfloat16().requires_grad_()

    # Creation order (x, w, pe) matters for reproducibility under the fixed seed.
    x = uniform_leaf(cu_seqlens[-1], num_heads, head_dim)
    w = uniform_leaf(num_heads, kernel_size * head_dim, head_dim)
    pe = uniform_leaf(num_heads, kernel_size, head_dim)

    y, y_cu_seqlens = linear_compress(x, w, cu_seqlens, kernel_size, kernel_stride, pe)
    loss = (y * torch.randn_like(y)).mean()
    loss.backward()

    print(y.shape, y_cu_seqlens)
    print(y.norm(), x.grad.norm())
    w_grad_norm = None if w.grad is None else w.grad.norm()
    pe_grad_norm = None if pe.grad is None else pe.grad.norm()
    print(w_grad_norm, pe_grad_norm)

    # ---- benchmark: forward latency for the same token budget split into 1/8/32 sequences ----
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**exp for exp in range(1, 6)],
            line_arg="provider",
            line_vals=["batch1", "batch8", "batch32"],
            line_names=["batch1", "batch8", "batch32"],
            styles=[("green", "-"), ("blue", "-"), ("blue", "--")],
            ylabel="ms",
            plot_name="** forward **",
            args={"H": 4, "D": 128},
        )
    )
    def benchmark(N, H, D, provider):
        K, S = 32, 16
        x = torch.zeros(N, H, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1)
        w = torch.zeros(H, K * D, D, device="cuda", dtype=torch.bfloat16).uniform_(
            -1, 1
        )
        pe = torch.zeros(H, K, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1)

        # Provider name -> number of equal-length sequences the N tokens are split into.
        num_seqs = {"batch1": 1, "batch8": 8, "batch32": 32}[provider]
        lengths = torch.LongTensor([0] + [N // num_seqs] * num_seqs).int().cuda()
        cu_seqlens_bench = lengths.cumsum(0).to(torch.int32)

        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: linear_compress(x, w, cu_seqlens_bench, K, S, pe),
            quantiles=quantiles,
        )
        return ms, min_ms, max_ms

    benchmark.run(show_plots=True, print_data=True)
================================================
FILE: test/test_compressed_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
import math
from native_sparse_attention.ops.torch.compressed_attention import (
compressed_attention_torch,
)
from native_sparse_attention.ops.triton.compressed_attention import (
compressed_attention,
_compressed_attention_bwd,
)
from native_sparse_attention.ops import avgpool_compress
from native_sparse_attention.ops.triton.flash_attention import (
flash_attention_varlen,
_flash_attention_bwd,
)
from flash_attn import flash_attn_varlen_func
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Correctness: Triton compressed attention vs. the PyTorch reference.
    # ------------------------------------------------------------------
    torch.manual_seed(42)
    num_heads = 32
    head_dim = 96
    kernel_size = 32
    kernel_stride = 16
    block_size = 64
    topk = 16
    seqlens = torch.LongTensor([1000, 4000, 8192]).int().cuda()
    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device="cuda"),
            torch.cumsum(seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    max_seqlen = seqlens.max().item()

    # Queries use num_heads heads; keys/values use num_heads // 4 heads.
    q = (
        torch.empty(cu_seqlens[-1], num_heads, head_dim, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    k = (
        torch.empty(cu_seqlens[-1], num_heads // 4, head_dim, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    v = (
        torch.empty(cu_seqlens[-1], num_heads // 4, head_dim, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    q.requires_grad = True
    k.requires_grad = True
    v.requires_grad = True

    # avgpool_compress is only used to obtain the compressed shape and
    # cu_seqlens; the compressed keys/values themselves are then replaced by
    # fresh uniform leaf tensors.
    ck, ck_cu_seqlens = avgpool_compress(
        k, None, cu_seqlens, kernel_size, kernel_stride
    )
    ck = torch.empty_like(ck).uniform_(-1, 1)
    cv = torch.empty_like(ck).uniform_(-1, 1)
    ck.requires_grad = True
    cv.requires_grad = True
    ck_seqlens = ck_cu_seqlens[1:] - ck_cu_seqlens[:-1]
    ck_max_seqlen = ck_seqlens.max().item()

    # PyTorch reference.
    o, topk_idx = compressed_attention_torch(
        q,
        ck,
        cv,
        kernel_size,
        kernel_stride,
        block_size,
        topk,
        cu_seqlens,
        ck_cu_seqlens,
        max_seqlen,
        ck_max_seqlen,
    )
    randn = torch.randn_like(o)
    loss = (o * randn).sum()
    loss.backward()

    # Triton implementation on detached copies of the same inputs.
    torch.manual_seed(42)
    q1 = q.detach().clone().requires_grad_()
    ck1 = ck.detach().clone().requires_grad_()
    cv1 = cv.detach().clone().requires_grad_()
    o1, topk_idx1 = compressed_attention(
        q1,
        ck1,
        cv1,
        kernel_size,
        kernel_stride,
        block_size,
        topk,
        cu_seqlens,
        ck_cu_seqlens,
        max_seqlen,
        ck_max_seqlen,
    )
    randn1 = randn.clone().detach()
    loss1 = (o1 * randn1).sum()
    loss1.backward()

    # Output and gradient agreement, one (allclose, max-error) pair per tensor.
    for same_label, err_label, ref, out in (
        ("Same Output:", "Max Error:", o, o1),
        ("Same Query Gradient:", "Max Query Gradient Error:", q.grad, q1.grad),
        ("Same Key Gradient:", "Max Key Gradient Error:", ck.grad, ck1.grad),
        ("Same Value Gradient:", "Max Value Gradient Error:", cv.grad, cv1.grad),
    ):
        print(same_label, torch.allclose(ref, out, atol=0.01, rtol=0.01))
        print(err_label, (ref - out).abs().max().item())
        print()

    # There are some discrepancies in the topk indices (about 3%). These might be due to bugs and will be addressed later.
    # Copy both index tensors to the host once: indexing and .tolist() on
    # device tensors inside the double loop would force a host sync per row.
    topk_ref = topk_idx.cpu()
    topk_out = topk_idx1.cpu()
    all_num = 0
    err_num = 0
    for h in range(topk_ref.shape[0]):
        for i in range(topk_ref.shape[1]):
            ref_row = topk_ref[h, i]
            out_row = topk_out[h, i]
            # Negative entries are skipped; only non-negative indices are compared.
            s = set(ref_row[ref_row >= 0].tolist())
            s1 = set(out_row[out_row >= 0].tolist())
            all_num += len(s)
            err_num += len(s) - len(s1 & s)
    print("Topk Idx Error Rate:", err_num / all_num)

    # ------------------------------------------------------------------
    # Benchmark: forward pass.
    # ------------------------------------------------------------------
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 8)],
            line_arg="provider",
            line_vals=[
                "flash",
                "triton-flash",
                "triton-compressed",
                "triton-compressed-wo-score",
            ],
            line_names=[
                "Flash",
                "Triton-Flash",
                "Compressed",
                "Compressed-wo-Score",
            ],
            styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
            ylabel="ms",
            plot_name="** forward speed for compressed attention (kernel 32 stride 16) **",
            args={"H": 64, "D": 128},
        )
    )
    def benchmark(N, H, D, provider):
        q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        sm_scale = 1 / math.sqrt(D)
        com_k, com_cu_seqlens = avgpool_compress(k, None, cu_seqlens, 32, 16, None)
        com_v, com_cu_seqlens = avgpool_compress(v, None, cu_seqlens, 32, 16, None)
        M = (com_cu_seqlens[1:] - com_cu_seqlens[:-1]).max().item()
        quantiles = [0.5, 0.2, 0.8]

        def run_compressed(num_selected):
            # Positional order mirrors the correctness call above:
            # kernel_size=32, kernel_stride=16, block_size=64, topk=num_selected.
            return compressed_attention(
                q,
                com_k,
                com_v,
                32,
                16,
                64,
                num_selected,
                cu_seqlens,
                com_cu_seqlens,
                N,
                M,
                sm_scale,
            )

        candidates = {
            "flash": lambda: flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens,
                cu_seqlens,
                N,
                N,
                dropout_p=0.0,
                causal=True,
                softmax_scale=sm_scale,
            ),
            "triton-flash": lambda: flash_attention_varlen(
                q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
            ),
            "triton-compressed": lambda: run_compressed(16),
            "triton-compressed-wo-score": lambda: run_compressed(-1),
        }
        ms, min_ms, max_ms = triton.testing.do_bench(
            candidates[provider], quantiles=quantiles
        )
        return ms, min_ms, max_ms

    benchmark.run(show_plots=True, print_data=True)

    # ------------------------------------------------------------------
    # Benchmark: backward pass (kernels are timed on random o / do / lse).
    # ------------------------------------------------------------------
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 8)],
            line_arg="provider",
            line_vals=[
                "flash",
                "triton-flash",
                "triton-compressed",
            ],
            line_names=[
                "Flash",
                "Triton-Flash",
                "Compressed",
            ],
            styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
            ylabel="ms",
            plot_name="** backward speed for compressed attention (kernel 32 stride 16) **",
            args={"H": 64, "D": 128},
        )
    )
    def benchmark(N, H, D, provider):
        q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        o = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        do = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        lse = torch.randn((H, N), device="cuda", dtype=torch.float32)
        sm_scale = 1 / math.sqrt(D)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        # Pre-allocated gradient buffers handed to the flash-attn backward.
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)
        com_k, com_cu_seqlens = avgpool_compress(k, None, cu_seqlens, 32, 16, None)
        com_v, com_cu_seqlens = avgpool_compress(v, None, cu_seqlens, 32, 16, None)
        M = (com_cu_seqlens[1:] - com_cu_seqlens[:-1]).max().item()
        quantiles = [0.5, 0.2, 0.8]

        candidates = {
            "flash": lambda: _flash_attn_varlen_backward(
                do,
                q,
                k,
                v,
                o,
                lse.transpose(0, 1),
                dq,
                dk,
                dv,
                cu_seqlens,
                cu_seqlens,
                N,
                N,
                dropout_p=0.0,
                causal=True,
                softmax_scale=sm_scale,
                window_size=(-1, -1),
                softcap=0.0,
                alibi_slopes=None,
                deterministic=False,
            ),
            "triton-flash": lambda: _flash_attention_bwd(
                o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
            ),
            "triton-compressed": lambda: _compressed_attention_bwd(
                o,
                do,
                lse,
                q,
                com_k,
                com_v,
                32,
                16,
                cu_seqlens,
                com_cu_seqlens,
                N,
                M,
                sm_scale,
            ),
        }
        ms, min_ms, max_ms = triton.testing.do_bench(
            candidates[provider], quantiles=quantiles
        )
        return ms, min_ms, max_ms

    benchmark.run(show_plots=True, print_data=True)
================================================
FILE: test/test_flash_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
import math
from native_sparse_attention.ops.triton.flash_attention import (
flash_attention_varlen,
_flash_attention_fwd,
_flash_attention_bwd,
)
from flash_attn import flash_attn_varlen_func
from flash_attn.flash_attn_interface import (
_flash_attn_varlen_forward,
_flash_attn_varlen_backward,
)
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Correctness: Triton flash attention vs. the flash-attn package,
    # forward and backward, with and without causal masking.
    # ------------------------------------------------------------------
    for causal in [False, True]:
        # Reference: flash-attn package.
        torch.manual_seed(42)
        q = torch.randn(
            1000, 32, 128, dtype=torch.float16, device="cuda", requires_grad=True
        )
        k = torch.randn(
            1000, 16, 128, dtype=torch.float16, device="cuda", requires_grad=True
        )
        v = torch.randn(
            1000, 16, 128, dtype=torch.float16, device="cuda", requires_grad=True
        )
        cu_seqlens_q = torch.Tensor([0, 100, 384, 1000]).cuda().to(torch.int32)
        cu_seqlens_k = torch.Tensor([0, 100, 384, 1000]).cuda().to(torch.int32)
        # Convert to plain Python ints; the attention entry points take
        # max_seqlen as an integer rather than a 0-dim CUDA tensor.
        max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
        max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
        o = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            causal=causal,
        )
        randn = torch.randn_like(o)
        loss = (o * randn).sum()
        loss.backward()

        # Implementation under test: Triton flash attention from this repository.
        torch.manual_seed(42)
        q1 = q.clone().detach().requires_grad_()
        k1 = k.clone().detach().requires_grad_()
        v1 = v.clone().detach().requires_grad_()
        cu_seqlens_q1 = cu_seqlens_q.clone().detach()
        cu_seqlens_k1 = cu_seqlens_k.clone().detach()
        max_seqlen_q1 = (cu_seqlens_q1[1:] - cu_seqlens_q1[:-1]).max().item()
        max_seqlen_k1 = (cu_seqlens_k1[1:] - cu_seqlens_k1[:-1]).max().item()
        o1 = flash_attention_varlen(
            q1,
            k1,
            v1,
            cu_seqlens_q1,
            cu_seqlens_k1,
            max_seqlen_q1,
            max_seqlen_k1,
            causal=causal,
        )
        randn2 = randn.clone().detach()
        loss2 = (o1 * randn2).sum()
        loss2.backward()

        # Report agreement of the output and of each input gradient.
        print(
            f"=== Flash Attention Backward Test ({'causal' if causal else 'full'}) ==="
        )
        for same_label, err_label, ref, out in (
            ("Same Output:", "Max Error:", o, o1),
            ("Same Query Gradient:", "Max Query Gradient Error:", q.grad, q1.grad),
            ("Same Key Gradient:", "Max Key Gradient Error:", k.grad, k1.grad),
            ("Same Value Gradient:", "Max Value Gradient Error:", v.grad, v1.grad),
        ):
            print(same_label, torch.allclose(ref, out, atol=0.01, rtol=0.01))
            print(err_label, (ref - out).abs().max().item())
            print()

    # ------------------------------------------------------------------
    # Benchmark: forward pass.
    # ------------------------------------------------------------------
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 6)],
            line_arg="provider",
            line_vals=["flash", "triton-flash"],
            line_names=[
                "Flash",
                "Triton-Flash",
            ],
            styles=[("green", "-"), ("green", "--")],
            ylabel="ms",
            plot_name="** forward **",
            args={"H": 64, "D": 128},
        )
    )
    def benchmark(N, H, D, provider):
        q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        sm_scale = 1 / math.sqrt(D)
        quantiles = [0.5, 0.2, 0.8]

        candidates = {
            "flash": lambda: _flash_attn_varlen_forward(
                q,
                k,
                v,
                cu_seqlens,
                cu_seqlens,
                N,
                N,
                dropout_p=0.0,
                causal=True,
                softmax_scale=sm_scale,
            ),
            "triton-flash": lambda: _flash_attention_fwd(
                q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
            ),
        }
        ms, min_ms, max_ms = triton.testing.do_bench(
            candidates[provider], quantiles=quantiles
        )
        return ms, min_ms, max_ms

    benchmark.run(show_plots=True, print_data=True)

    # ------------------------------------------------------------------
    # Benchmark: backward pass (kernels are timed on random o / do / lse).
    # ------------------------------------------------------------------
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 6)],
            line_arg="provider",
            line_vals=["flash", "triton-flash"],
            line_names=[
                "Flash",
                "Triton-Flash",
            ],
            styles=[("green", "-"), ("green", "--")],
            ylabel="ms",
            plot_name="** backward **",
            args={"H": 64, "D": 128},
        )
    )
    def benchmark(N, H, D, provider):
        q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        o = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        do = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        lse = torch.randn((H, N), device="cuda", dtype=torch.float32)
        sm_scale = 1 / math.sqrt(D)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        # Pre-allocated gradient buffers handed to the flash-attn backward.
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)
        quantiles = [0.5, 0.2, 0.8]

        candidates = {
            "flash": lambda: _flash_attn_varlen_backward(
                do,
                q,
                k,
                v,
                o,
                lse.transpose(0, 1),
                dq,
                dk,
                dv,
                cu_seqlens,
                cu_seqlens,
                N,
                N,
                dropout_p=0.0,
                causal=True,
                softmax_scale=sm_scale,
                window_size=(-1, -1),
                softcap=0.0,
                alibi_slopes=None,
                deterministic=False,
            ),
            "triton-flash": lambda: _flash_attention_bwd(
                o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
            ),
        }
        ms, min_ms, max_ms = triton.testing.do_bench(
            candidates[provider], quantiles=quantiles
        )
        return ms, min_ms, max_ms

    benchmark.run(show_plots=True, print_data=True)
================================================
FILE: test/test_kv_cache.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from native_sparse_attention.module.kv_cache import NSACache
if __name__ == "__main__":
    # Smoke test: push one prefill step and one decode step through NSACache.
    from native_sparse_attention.ops import avgpool_compress

    torch.manual_seed(42)
    num_heads, head_dim = 4, 128

    seqlens = torch.tensor([12, 576, 12000]).to(torch.int32).cuda()
    batch_size = seqlens.shape[0]
    cu_seqlens = torch.cat([seqlens.new_zeros(1), seqlens.cumsum(0).to(torch.int32)])

    # init cache
    cache = NSACache(4, 16384, num_heads, head_dim, 32, 16, 512, torch.bfloat16, "cuda")

    def feed_cache(step, ck, cv, k, v):
        """Run the two cache calls for one step; k/v are passed twice to update_kv, as in the original script."""
        cache.prepare_compress(cu_seqlens, step, k, v)
        cache.update_kv(cu_seqlens, step, ck, cv, k, v, k, v)

    # Prefill (step 0): full packed sequences plus their avg-pooled compression.
    prefill_k = torch.randn(cu_seqlens[-1], num_heads, head_dim).cuda().bfloat16()
    prefill_v = torch.randn_like(prefill_k)
    prefill_ck, _ = avgpool_compress(prefill_k, None, cu_seqlens, 32, 16, None)
    prefill_cv, _ = avgpool_compress(prefill_v, None, cu_seqlens, 32, 16, None)
    feed_cache(0, prefill_ck, prefill_cv, prefill_k, prefill_v)

    # Decode (step 1): one new token per sequence, random compressed states.
    decode_k = torch.randn(batch_size, num_heads, head_dim).cuda().bfloat16()
    decode_v = torch.randn_like(decode_k)
    decode_ck = torch.randn_like(decode_k)
    decode_cv = torch.randn_like(decode_v)
    feed_cache(1, decode_ck, decode_cv, decode_k, decode_v)
================================================
FILE: test/test_linear_compress.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
from native_sparse_attention.ops.torch.compress_key_value import linear_compress_torch
from native_sparse_attention.ops.triton.linear_compress import linear_compress
def _report_match(
    match_label, diff_label, sample_header, ref, out, sample_index, atol, rtol
):
    """Print whether ``ref`` and ``out`` agree; on mismatch also print the max error and a sample slice."""
    same = torch.allclose(ref, out, atol=atol, rtol=rtol)
    print(f"{match_label} (atol={atol}, rtol={rtol}): {same}")
    if not same:
        max_diff = (ref - out).abs().max().item()
        print(f"{diff_label} - Maximum difference: {max_diff}")
        print(sample_header)
        print("Torch:", ref[sample_index])
        print("Triton:", out[sample_index])
    return same


def test_linear_compress(
    batch_size: int = 1,
    num_heads: int = 1,
    head_dim: int = 32,
    max_seqlen: int = 32,
    kernel_sizes: list = [16, 32],
    kernel_strides: list = [8, 16],
    use_pe: bool = True,
    dtype: torch.dtype = torch.float32,
    device: str = "cuda",
):
    """
    Test both PyTorch and Triton implementations of linear_compress for equivalence,
    including forward and backward passes.

    Results are printed rather than asserted, so a mismatch does not raise.

    Args:
        batch_size: Number of sequences in the batch
        num_heads: Number of attention heads
        head_dim: Dimension of each attention head
        max_seqlen: Maximum sequence length
        kernel_sizes: List of kernel sizes to test
        kernel_strides: List of kernel strides to test (paired with kernel_sizes)
        use_pe: Whether to test with positional encoding
        dtype: Data type for tensors
        device: Device to run the test on
    """
    torch.manual_seed(42)

    # Every sequence should hold at least one full window for every kernel
    # size under test, so the lower bound is the largest kernel size (not
    # just the first). It is clamped to max_seqlen so the sampling range
    # never becomes empty.
    min_seqlen = min(max(kernel_sizes), max_seqlen)
    seqlens = torch.randint(
        low=min_seqlen,
        high=max_seqlen + 1,
        size=(batch_size,),
        device=device,
    )
    cu_seqlens = torch.cat(
        [
            torch.tensor([0], device=device, dtype=torch.int32),
            torch.cumsum(seqlens, dim=0).to(torch.int32),
        ]
    )
    total_len = cu_seqlens[-1].item()

    atol, rtol = 1e-2, 1e-2
    # Slices used when printing a sample on mismatch.
    first_row = (0, 0, slice(None, 5))  # tensor[0, 0, :5]
    first_col = (0, slice(None, 5), 0)  # tensor[0, :5, 0]

    for kernel_size, kernel_stride in zip(kernel_sizes, kernel_strides):
        print(f"\nTesting kernel_size={kernel_size}, kernel_stride={kernel_stride}")

        # Fresh leaf tensors per configuration; the Triton path gets detached
        # copies so the two paths accumulate gradients independently.
        x_torch = torch.zeros(
            (total_len, num_heads, head_dim),
            dtype=dtype,
            device=device,
        ).uniform_(-1, 1)
        x_torch.requires_grad_(True)
        x_triton = x_torch.clone().detach().requires_grad_(True)

        # Constant weight 1 / kernel_size for every entry.
        w_torch = (
            torch.ones(
                (num_heads, kernel_size * head_dim, head_dim),
                dtype=dtype,
                device=device,
            )
            / kernel_size
        )
        w_torch.requires_grad_(True)
        w_triton = w_torch.clone().detach().requires_grad_(True)

        pe_torch = None
        pe_triton = None
        if use_pe:
            pe_torch = torch.randn(
                (num_heads, kernel_size, head_dim),
                dtype=dtype,
                device=device,
                requires_grad=True,
            )
            pe_triton = pe_torch.clone().detach().requires_grad_(True)

        # Forward passes
        y_torch, y_cu_seqlens_torch = linear_compress_torch(
            x=x_torch,
            w=w_torch,
            cu_seqlens=cu_seqlens,
            kernel_size=kernel_size,
            kernel_stride=kernel_stride,
            pe=pe_torch,
        )
        y_triton, y_cu_seqlens_triton = linear_compress(
            x=x_triton,
            w=w_triton,
            cu_seqlens=cu_seqlens,
            kernel_size=kernel_size,
            kernel_stride=kernel_stride,
            pe=pe_triton,
        )
        _report_match(
            "Forward pass - Output values match",
            "Forward pass",
            "\nSample values (first batch, first head):",
            y_torch,
            y_triton,
            first_row,
            atol,
            rtol,
        )

        # Backward passes driven by the same random upstream gradient.
        grad_output = torch.randn_like(y_torch)
        y_torch.backward(grad_output)
        y_triton.backward(grad_output)

        print("\nTesting backward pass:")
        _report_match(
            "x gradients match",
            "x gradients",
            "\nSample x gradients (first batch, first head):",
            x_torch.grad,
            x_triton.grad,
            first_row,
            atol,
            rtol,
        )
        _report_match(
            "w gradients match",
            "w gradients",
            "\nSample w gradients (first head):",
            w_torch.grad,
            w_triton.grad,
            first_col,
            atol,
            rtol,
        )
        if use_pe:
            _report_match(
                "pe gradients match",
                "pe gradients",
                "\nSample pe gradients (first head):",
                pe_torch.grad,
                pe_triton.grad,
                first_col,
                atol,
                rtol,
            )
if __name__ == "__main__":
    # Equivalence check with a single kernel configuration in fp16.
    smoke_config = dict(
        batch_size=16,
        num_heads=8,
        head_dim=128,
        max_seqlen=2048,
        kernel_sizes=[32],
        kernel_strides=[16],
        use_pe=False,
        dtype=torch.float16,
        device="cuda",
    )
    test_linear_compress(**smoke_config)

    # Benchmark: forward + backward latency, PyTorch reference vs. Triton.
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**exp for exp in range(1, 8)],
            line_arg="provider",
            line_vals=["torch", "triton"],
            line_names=["torch", "triton"],
            styles=[("green", "-"), ("blue", "-")],
            ylabel="ms",
            plot_name="** forward + backward **",
            args={"H": 4, "D": 64},
        )
    )
    def benchmark_fwdbwd(N, H, D, provider):
        K, S = 32, 16

        # Input tensors
        x = torch.zeros(N, H, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1)
        x.requires_grad_()
        w = torch.zeros(H, K * D, D, device="cuda", dtype=torch.bfloat16).uniform_(
            -1, 1
        )
        w.requires_grad_()
        pe = torch.zeros(H, K, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1)

        # 32 sequences: one long sequence of N - 32 * 31 tokens followed by
        # 31 sequences of 32 tokens each (N tokens in total).
        lengths = [0, N - 32 * 31] + [32] * 31
        cu_seqlens_b32 = torch.LongTensor(lengths).int().cuda()
        cu_seqlens_b32 = cu_seqlens_b32.cumsum(0).to(torch.int32)

        quantiles = [0.5, 0.2, 0.8]
        compress = linear_compress_torch if provider == "torch" else linear_compress

        def fwd_bwd():
            out, _ = compress(x, w, cu_seqlens_b32, K, S, pe)
            out.backward(out)  # Using output as gradient for simplicity
            return out

        ms, min_ms, max_ms = triton.testing.do_bench(fwd_bwd, quantiles=quantiles)
        return ms, min_ms, max_ms

    benchmark_fwdbwd.run(show_plots=True, print_data=True)
================================================
FILE: test/test_nsa_infer.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from native_sparse_attention.ops import linear_compress, weightedpool_compress
from native_sparse_attention.module import NSACache, RotaryEmbedding, RopeConfig
from native_sparse_attention.infer import nsa_infer
if __name__ == "__main__":
    # Smoke test: run nsa_infer for one prefill step and one decode step.
    torch.manual_seed(42)

    num_heads, head_dim = 4, 128
    kernel_size, kernel_stride = 32, 16
    block_size, window_size = 64, 512
    topk = 16
    init_blocks, local_blocks = 1, 2

    # init seqlens
    seqlens = torch.tensor([12, 576, 12000]).to(torch.int32).cuda()
    batch_size = seqlens.shape[0]
    cu_seqlens = torch.cat([seqlens.new_zeros(1), seqlens.cumsum(0).to(torch.int32)])

    # init cache and weight and rope
    cache = NSACache(4, 16384, num_heads, head_dim, 32, 16, 512, torch.bfloat16, "cuda")
    # Constant weights: one for linear_compress, one for weightedpool_compress,
    # in the same order as compress_func below.
    linear_weight = torch.ones(
        num_heads, kernel_size * head_dim, head_dim
    ).cuda().bfloat16() / (kernel_size * head_dim)
    pool_weight = torch.ones(num_heads, kernel_size).cuda().bfloat16() / kernel_size
    compress_weight = [linear_weight, pool_weight]
    compress_func = [linear_compress, weightedpool_compress]
    rope_scaling = {
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    }
    rope = RotaryEmbedding(
        RopeConfig(
            max_position_embeddings=131072,
            head_dim=128,
            rope_theta=500000,
            rope_scaling=rope_scaling,
        )
    )

    # Trailing positional arguments shared by the prefill and decode calls.
    shared_args = (
        rope,
        cache,
        compress_weight,
        compress_func,
        None,
        kernel_size,
        kernel_stride,
        block_size,
        topk,
        init_blocks,
        local_blocks,
        window_size,
    )

    # test prefill (step 0): one row per token of the packed batch;
    # queries and gates use num_heads * 16 heads.
    total_len = cu_seqlens[-1]
    q = torch.randn(total_len, num_heads * 16, head_dim).cuda().bfloat16()
    k = torch.randn(total_len, num_heads, head_dim).cuda().bfloat16()
    v = torch.randn_like(k)
    g = torch.rand(total_len, num_heads * 16, 3).cuda().bfloat16()
    o = nsa_infer(cu_seqlens, 0, q, k, v, g, *shared_args)
    print(o.shape, o.norm())

    # test decode (step 1): one row per sequence
    q = torch.randn(batch_size, num_heads * 16, head_dim).cuda().bfloat16()
    k = torch.randn(batch_size, num_heads, head_dim).cuda().bfloat16()
    v = torch.randn_like(k)
    g = torch.rand(batch_size, num_heads * 16, 3).cuda().bfloat16()
    o = nsa_infer(cu_seqlens, 1, q, k, v, g, *shared_args)
    print(o.shape, o.norm())
================================================
FILE: test/test_nsa_model.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from native_sparse_attention.model import (
ToyNSALlamaConfig,
InferenceConfig,
ToyNSALlama,
)
if __name__ == "__main__":
    # Smoke test / usage example for the toy NSA Llama model: build the model,
    # run one varlen forward pass and one short generation. Requires a CUDA GPU.
    torch.manual_seed(42)

    # initialize model
    config = ToyNSALlamaConfig(
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=8,
        num_attention_heads=32,
        num_key_value_heads=2,
        head_dim=128,
        rope_theta=500000.0,
        rope_scaling={
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        },
        compress_type="weightedpool",
        kernel_size=32,
        kernel_stride=16,
        block_size=64,
        topk=8,
        init_blocks=1,
        local_blocks=2,
        window_size=512,
    )
    inference_config = InferenceConfig(
        max_batch_size=4,
        max_length=8192,
        max_new_tokens=128,
    )
    model = ToyNSALlama(config, inference_config).cuda().bfloat16()
    print(f"\nMODEL CONFIG:\n{config}\n")
    print(f"\nINFERENCE CONFIG:\n{inference_config}\n")
    print(f"\nMODEL:\n{model}\n")

    # example input
    batch_size = 4
    # Sample lengths from [1, 4096): a lower bound of 0 could produce an empty
    # sequence inside the packed (varlen) batch.
    seqlens = torch.randint(1, 4096, (batch_size,), dtype=torch.int32, device="cuda")
    # cu_seqlens[i] is the offset of sequence i in the packed token dimension.
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    # Use a plain Python int for the tensor shape instead of a 0-dim tensor.
    total_tokens = int(cu_seqlens[-1])
    input_ids = torch.randint(
        0, 128288, (total_tokens,), dtype=torch.int64, device="cuda"
    )
    print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")

    # example output
    logits = model(input_ids, cu_seqlens)
    print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")

    # example generate
    output_tokens = model.generate(input_ids, cu_seqlens, 64)
    print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
================================================
FILE: test/test_nsa_module.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
from native_sparse_attention.module import (
SelfAttention,
NativeSparseAttention,
RopeConfig,
)
if __name__ == "__main__":
    # Smoke test and speed benchmark comparing the NativeSparseAttention module
    # against the dense SelfAttention module. Requires a CUDA GPU.
    torch.manual_seed(42)

    def build_rope_config():
        """Return a fresh RopeConfig with the settings shared by both modules."""
        return RopeConfig(
            max_position_embeddings=131072,
            head_dim=128,
            rope_theta=500000,
            rope_scaling={
                "factor": 8.0,
                "high_freq_factor": 4.0,
                "low_freq_factor": 1.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3",
            },
        )

    NSA = (
        NativeSparseAttention(
            compress_type="avgpool",
            hidden_size=8192,
            num_q_heads=64,
            num_kv_heads=4,
            head_dim=128,
            kernel_size=32,
            kernel_stride=16,
            block_size=64,
            topk=16,
            init_blocks=1,
            local_blocks=2,
            window_size=512,
            rope_config=build_rope_config(),
        )
        .cuda()
        .to(torch.bfloat16)
    )
    print("======= Init Module: Native Sparse Attention =======\n")
    for name, param in NSA.named_parameters():
        print(f"NSA Parameters, {name}, shape: {param.shape}\n")

    # random input: three packed sequences described by cumulative offsets
    seqlens = torch.LongTensor([4000, 8192, 16384]).int().cuda()
    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device="cuda"),
            torch.cumsum(seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    x = torch.zeros(cu_seqlens[-1], 8192, device="cuda", dtype=torch.bfloat16).uniform_(
        -1, 1
    )

    # forward test
    print("======= NSA Forward & Backward Test =======\n")
    y = NSA(x, cu_seqlens)
    print(f"Forward, output shape: {y.shape}, output norm: {y.norm()}\n")

    # backward test
    loss = (y * torch.randn_like(y)).sum(-1).mean()
    loss.backward()
    for name, param in NSA.named_parameters():
        # A parameter that did not take part in the forward pass has no
        # gradient; report it instead of failing on `None.shape`.
        if param.grad is None:
            print(f"Backward, {name}, no gradient\n")
            continue
        print(
            f"Backward, {name}, grad shape: {param.grad.shape}, grad norm: {param.grad.norm()}\n"
        )

    # speed benchmark
    SelfAttn = (
        SelfAttention(
            hidden_size=8192,
            num_q_heads=64,
            num_kv_heads=4,
            head_dim=128,
            rope_config=build_rope_config(),
        )
        .cuda()
        .to(torch.bfloat16)
    )

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 8)],
            line_arg="provider",
            line_vals=["Self-Attention", "Native-Sparse-Attention"],
            line_names=["Self-Attention", "Native-Sparse-Attention"],
            styles=[("green", "-"), ("blue", "-")],
            ylabel="ms",
            plot_name="** NSA forward speed benchmark **",
            args={},
        )
    )
    def benchmark_forward(N, provider):
        """Time one forward pass of the selected module on a single length-N sequence."""
        x = torch.randn(N, 8192, device="cuda", dtype=torch.bfloat16)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        quantiles = [0.5, 0.2, 0.8]
        with torch.no_grad():
            if provider == "Self-Attention":
                ms, min_ms, max_ms = triton.testing.do_bench(
                    lambda: SelfAttn(x, cu_seqlens),
                    quantiles=quantiles,
                )
            elif provider == "Native-Sparse-Attention":
                ms, min_ms, max_ms = triton.testing.do_bench(
                    lambda: NSA(x, cu_seqlens),
                    quantiles=quantiles,
                )
            else:
                raise ValueError(f"unknown provider: {provider}")
        return ms, min_ms, max_ms

    benchmark_forward.run(show_plots=True, print_data=True)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 8)],
            line_arg="provider",
            line_vals=["Self-Attention", "Native-Sparse-Attention"],
            line_names=["Self-Attention", "Native-Sparse-Attention"],
            styles=[("green", "-"), ("blue", "-")],
            ylabel="ms",
            plot_name="** NSA backward speed benchmark **",
            args={},
        )
    )
    def benchmark_backward(N, provider):
        """Time the backward pass of the selected module on a single length-N sequence."""
        x = torch.randn(N, 8192, device="cuda", dtype=torch.bfloat16)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        quantiles = [0.5, 0.2, 0.8]
        if provider == "Self-Attention":
            loss = SelfAttn(x.clone().detach().requires_grad_(), cu_seqlens).mean()
        elif provider == "Native-Sparse-Attention":
            loss = NSA(x.clone().detach().requires_grad_(), cu_seqlens).mean()
        else:
            raise ValueError(f"unknown provider: {provider}")
        # The forward pass runs once outside the timed region; retain_graph lets
        # do_bench call backward repeatedly on the same graph.
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: loss.backward(retain_graph=True),
            quantiles=quantiles,
        )
        return ms, min_ms, max_ms

    benchmark_backward.run(show_plots=True, print_data=True)
================================================
FILE: test/test_rope.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from native_sparse_attention.module import RopeConfig, RotaryEmbedding
if __name__ == "__main__":
    # Smoke test: build a RotaryEmbedding on the GPU and apply it to a packed
    # batch of three variable-length sequences.
    llama3_scaling = {
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    }
    rope = RotaryEmbedding(
        RopeConfig(
            max_position_embeddings=131072,
            head_dim=128,
            rope_theta=500000,
            rope_scaling=llama3_scaling,
        ),
        "cuda",
    )

    # random input
    torch.manual_seed(42)
    seqlens = torch.tensor([1000, 2000, 4096], dtype=torch.int32, device="cuda")
    # Prepend a zero to the running sum to get the cumulative sequence offsets.
    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
    total_tokens = cu_seqlens[-1]
    x = torch.empty(
        total_tokens, 32, 128, device="cuda", dtype=torch.bfloat16
    ).uniform_(-1, 1)
    y = rope(x, cu_seqlens)
================================================
FILE: test/test_topk_sparse_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
import math
from native_sparse_attention.ops.torch.topk_sparse_attention import (
topk_sparse_attention_torch,
)
from native_sparse_attention.ops.triton.topk_sparse_attention import (
topk_sparse_attention,
_topk_sparse_attention_fwd,
_topk_sparse_attention_bwd,
)
from native_sparse_attention.ops.triton.flash_attention import (
_flash_attention_fwd,
_flash_attention_bwd,
)
from flash_attn.flash_attn_interface import (
_flash_attn_varlen_forward,
_flash_attn_varlen_backward,
)
def generate_topk_idx_example(
    seqlens: torch.Tensor,
    block_size_k: int,
    topk: int,
    num_heads: int,
    block_size_q: int = 1,
    device: str = "cuda",
) -> torch.Tensor:
    """Generate a random, causally valid top-k block index tensor for tests.

    For every query position a random set of key/value blocks is drawn. The
    indices are sorted in ascending order, the first selected block is forced
    to be block 0, and any block that lies after the query's own block is
    replaced by -1 (padding).

    Args:
        seqlens (torch.Tensor): shape [batch_size], the length of each
            sequence (NOT cumulative; the cumulative offsets are derived here).
        block_size_k (int): key value block size
        topk (int): selected topk
        num_heads (int): number of key value heads
        block_size_q (int): query block size; only every ``block_size_q``-th
            query row of each sequence is kept. Defaults to 1.
        device (str): device on which the indices are generated. Defaults to
            "cuda", which matches the previous hard-coded behavior.

    Returns:
        torch.Tensor: int32 tensor of shape [num_heads, total_query_blocks, topk],
            where total_query_blocks = sum(ceil(seqlen / block_size_q)) over the
            batch (equal to the total sequence length when block_size_q is 1).
            -1 means padding.
    """
    # Pull the lengths to the host once instead of converting 0-dim tensors
    # to ints inside every loop iteration.
    seqlen_list = [int(s) for s in seqlens.tolist()]
    # Number of key/value blocks per sequence (ceil division).
    num_blocks = [-(-s // block_size_k) for s in seqlen_list]
    # Index of the key/value block that contains each query position. This is
    # identical for every head, so compute it once.
    q_block = (
        torch.cat([torch.arange(s, device=device) for s in seqlen_list], dim=0)
        // block_size_k
    )

    topk_idx_all_heads = []
    for _ in range(num_heads):
        per_seq = []
        for s, nb in zip(seqlen_list, num_blocks):
            picked = (
                torch.randn(s, nb, device=device)
                .topk(min(topk, nb), dim=-1)
                .indices.to(torch.int32)
            )
            # Sequences with fewer than `topk` blocks are right-padded with the
            # value `topk`. It is larger than any valid block index of such a
            # sequence, so it sorts last and is turned into -1 by the causal
            # mask below.
            per_seq.append(
                torch.nn.functional.pad(
                    picked, (0, topk - picked.shape[-1]), value=topk
                )
            )
        topk_idx = torch.cat(per_seq, dim=0)
        topk_idx = torch.sort(topk_idx, dim=1).values
        topk_idx[:, 0] = 0  # always select the first block
        topk_idx[topk_idx > q_block[:, None]] = -1  # -1 means padding
        # Keep one row per query block, restarting the stride at each sequence.
        topk_idx = torch.cat(
            [
                chunk[0::block_size_q]
                for chunk in torch.split(topk_idx, seqlen_list, dim=0)
            ],
            dim=0,
        )
        topk_idx_all_heads.append(topk_idx)
    return torch.stack(topk_idx_all_heads, dim=0)
if __name__ == "__main__":
    # Correctness check of the Triton top-k sparse attention against the torch
    # reference (forward and backward), followed by forward/backward speed
    # benchmarks against flash attention. Requires a CUDA GPU.
    torch.manual_seed(42)
    seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda()
    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device="cuda"),
            torch.cumsum(seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    q = (
        torch.empty(cu_seqlens[-1], 64, 96, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    k = (
        torch.empty(cu_seqlens[-1], 8, 96, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    v = (
        torch.empty(cu_seqlens[-1], 8, 96, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    q.requires_grad = True
    k.requires_grad = True
    v.requires_grad = True
    block_size = 64
    topk = 5
    topk_idx = generate_topk_idx_example(seqlens, block_size, topk, 8)

    # reference implementation
    o = topk_sparse_attention_torch(q, k, v, topk_idx, block_size, cu_seqlens)
    randn = torch.randn_like(o)
    loss = (o * randn).sum()
    loss.backward()

    # triton implementation, run on independent copies of every input
    torch.manual_seed(42)
    q1 = q.clone().detach().requires_grad_()
    k1 = k.clone().detach().requires_grad_()
    v1 = v.clone().detach().requires_grad_()
    topk_idx1 = topk_idx.clone().detach()
    cu_seqlens1 = cu_seqlens.clone().detach()
    o1 = topk_sparse_attention(q1, k1, v1, topk_idx1, block_size, cu_seqlens1)
    randn2 = randn.clone().detach()
    loss2 = (o1 * randn2).sum()
    loss2.backward()

    print("Same Output:", torch.allclose(o, o1, atol=0.01, rtol=0.01))
    print("Max Error:", (o - o1).abs().max().item())
    print()
    print("Same Query Gradient:", torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01))
    print("Max Query Gradient Error:", (q.grad - q1.grad).abs().max().item())
    print()
    print("Same Key Gradient:", torch.allclose(k.grad, k1.grad, atol=0.01, rtol=0.01))
    print("Max Key Gradient Error:", (k.grad - k1.grad).abs().max().item())
    print()
    print("Same Value Gradient:", torch.allclose(v.grad, v1.grad, atol=0.01, rtol=0.01))
    print("Max Value Gradient Error:", (v.grad - v1.grad).abs().max().item())
    print()

    # Number of selected blocks for each sparse benchmark provider.
    topk_by_provider = {"triton-top8": 8, "triton-top16": 16}

    # forward benchmark
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 8)],
            line_arg="provider",
            line_vals=["flash", "triton-flash", "triton-top8", "triton-top16"],
            line_names=[
                "Flash",
                "Triton-Flash",
                "Triton-Top8",
                "Triton-Top16",
            ],
            styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
            ylabel="ms",
            plot_name="** forward with block size 64 **",
            args={"H": 64, "D": 128, "K": 64},
        )
    )
    def benchmark_forward(N, H, D, K, provider):
        """Time the forward kernel of `provider` on a single length-N sequence."""
        q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        sm_scale = 1 / math.sqrt(D)
        quantiles = [0.5, 0.2, 0.8]
        if provider == "flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attn_varlen_forward(
                    q,
                    k,
                    v,
                    cu_seqlens,
                    cu_seqlens,
                    N,
                    N,
                    dropout_p=0.0,
                    causal=True,
                    softmax_scale=sm_scale,
                ),
                quantiles=quantiles,
            )
        elif provider == "triton-flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attention_fwd(
                    q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
                ),
                quantiles=quantiles,
            )
        elif provider in topk_by_provider:
            # Build the block indices only for the provider that needs them.
            sel_idx = generate_topk_idx_example(
                cu_seqlens[1:], K, topk_by_provider[provider], H // 16
            )
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _topk_sparse_attention_fwd(
                    q, k, v, sel_idx, K, cu_seqlens, cu_seqlens, N, N, sm_scale
                ),
                quantiles=quantiles,
            )
        else:
            raise ValueError(f"unknown provider: {provider}")
        return ms, min_ms, max_ms

    benchmark_forward.run(show_plots=True, print_data=True)

    # backward benchmark
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 8)],
            line_arg="provider",
            line_vals=["flash", "triton-flash", "triton-top8", "triton-top16"],
            line_names=[
                "Flash",
                "Triton-Flash",
                "Triton-Top8",
                "Triton-Top16",
            ],
            styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
            ylabel="ms",
            plot_name="** backward with block size 64 **",
            args={"H": 64, "D": 128, "K": 64},
        )
    )
    def benchmark_backward(N, H, D, K, provider):
        """Time the backward kernel of `provider` on a single length-N sequence.

        The output, output gradient and log-sum-exp tensors are random
        placeholders: only the kernel run time is measured, not its result.
        """
        q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        o = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        do = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        lse = torch.randn((H, N), device="cuda", dtype=torch.float32)
        sm_scale = 1 / math.sqrt(D)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        quantiles = [0.5, 0.2, 0.8]
        if provider == "flash":
            # Gradient buffers are only needed by the flash-attn backward API.
            dq = torch.zeros_like(q)
            dk = torch.zeros_like(k)
            dv = torch.zeros_like(v)
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attn_varlen_backward(
                    do,
                    q,
                    k,
                    v,
                    o,
                    lse.transpose(0, 1),
                    dq,
                    dk,
                    dv,
                    cu_seqlens,
                    cu_seqlens,
                    N,
                    N,
                    dropout_p=0.0,
                    causal=True,
                    softmax_scale=sm_scale,
                    window_size=(-1, -1),
                    softcap=0.0,
                    alibi_slopes=None,
                    deterministic=False,
                ),
                quantiles=quantiles,
            )
        elif provider == "triton-flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attention_bwd(
                    o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
                ),
                quantiles=quantiles,
            )
        elif provider in topk_by_provider:
            # Build the block indices only for the provider that needs them.
            sel_idx = generate_topk_idx_example(
                cu_seqlens[1:], K, topk_by_provider[provider], H // 16
            )
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _topk_sparse_attention_bwd(
                    o,
                    do,
                    lse,
                    q,
                    k,
                    v,
                    sel_idx,
                    K,
                    cu_seqlens,
                    cu_seqlens,
                    N,
                    N,
                    sm_scale,
                ),
                quantiles=quantiles,
            )
        else:
            raise ValueError(f"unknown provider: {provider}")
        return ms, min_ms, max_ms

    benchmark_backward.run(show_plots=True, print_data=True)