Repository: XunhaoLai/native-sparse-attention-triton
Branch: main
Commit: 9bea856c911e
Files: 44
Total size: 359.0 KB
Directory structure:
gitextract_5vrtmrhk/
├── .gitignore
├── LICENSE
├── README.md
├── install_dependency.sh
├── native_sparse_attention/
│   ├── __init__.py
│   ├── infer/
│   │   ├── __init__.py
│   │   ├── inference_func.py
│   │   └── nsa_inference.py
│   ├── model/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── toy_llama.py
│   │   └── toy_nsa_llama.py
│   ├── module/
│   │   ├── __init__.py
│   │   ├── kv_cache.py
│   │   ├── native_sparse_attention.py
│   │   ├── rope.py
│   │   └── self_attention.py
│   └── ops/
│       ├── README.md
│       ├── __init__.py
│       ├── torch/
│       │   ├── __init__.py
│       │   ├── compress_key_value.py
│       │   ├── compressed_attention.py
│       │   ├── compressed_attention_decode.py
│       │   └── topk_sparse_attention.py
│       └── triton/
│           ├── __init__.py
│           ├── compressed_attention.py
│           ├── flash_attention.py
│           ├── flash_attention_decode.py
│           ├── linear_compress.py
│           ├── topk_sparse_attention.py
│           ├── topk_sparse_attention_decode.py
│           ├── utils.py
│           └── weighted_pool.py
├── setup.py
└── test/
    ├── test_compress_key_value.py
    ├── test_compressed_attention.py
    ├── test_flash_attention.py
    ├── test_kv_cache.py
    ├── test_linear_compress.py
    ├── test_nsa_infer.py
    ├── test_nsa_model.py
    ├── test_nsa_module.py
    ├── test_rope.py
    └── test_topk_sparse_attention.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# PyPI configuration file
.pypirc
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<div align="center">
# Native Sparse Attention Triton
</div>
This repository implements the sparse attention mechanism introduced in the paper [Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention](https://arxiv.org/abs/2502.11089) and provides an efficient training implementation based on [Triton](https://github.com/triton-lang/triton).
🎉 We now support both training and inference for Native Sparse Attention (variable-length version, including prefilling, decoding, and KV cache management). A toy model is provided at `model.ToyNSALlama`; it exposes a `forward` function for training and a `generate` function for inference. Feel free to try it out!
## Requirements
Ensure the following dependencies are installed:
- PyTorch >= 2.1.0
- triton >= 3.0.0
- einops >= 0.7.0
- flash_attn >= 2.6.3
## Usage
### Notes
1. PyTorch implementations (`ops.torch`) are intended for debugging only.
2. For production use, prefer Triton operators (`ops.triton`).
3. All implementations follow the varlen approach, similar to `flash_attn_varlen_func`. Please concatenate the inputs of a batch before use.
4. Only attention head dimensions no larger than 128 are supported for now.
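
To make note 3 concrete, here is a minimal sketch in plain Python of how the cumulative sequence lengths for a packed (varlen) batch can be derived. The helper name `build_cu_seqlens` is hypothetical and not part of this repository; the sequence lengths are chosen to reproduce the `cu_seqlens` values used in the example further down.

```python
from itertools import accumulate
from typing import List


def build_cu_seqlens(seqlens: List[int]) -> List[int]:
    """Return cumulative sequence boundaries, starting at 0 (hypothetical helper)."""
    return [0, *accumulate(seqlens)]


# Three sequences of 1024, 7168 and 8192 tokens packed into one 16384-token batch.
cu_seqlens = build_cu_seqlens([1024, 7168, 8192])
print(cu_seqlens)  # [0, 1024, 8192, 16384]

# Tokens of sequence i occupy rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensors.
```

In practice the resulting list would be converted to an `int32` tensor on the GPU, and the per-sequence query/key/value tensors concatenated along the token dimension in the same order.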
### Install
You can install `native_sparse_attention` using pip:
```shell
pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git
```
### Functions
The `ops` module implements several functions required for native sparse attention. For detailed usage instructions, please see [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/ops#readme).
You can import those functions from the `ops` module:
```python
import torch
from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention
# input example
num_q_heads = 64
num_kv_heads = 4
head_dim = 128
kernel_size = 32
kernel_stride = 16
block_size = 64
topk = 16
cu_seqlens = torch.Tensor([0, 1024, 8192, 16384]).to(torch.int32).cuda()
query = torch.randn(16384, num_q_heads, head_dim).to(torch.bfloat16).cuda()
key = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda()
value = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda()
# weight example
w = (
    torch.randn(num_kv_heads, kernel_size * head_dim, head_dim)
    .to(torch.bfloat16)
    .cuda()
)
pe = torch.randn(num_kv_heads, kernel_size, head_dim).to(torch.bfloat16).cuda()
# 1. key value compression
compressed_key, compressed_cu_seqlens = linear_compress(
    key, w, cu_seqlens, kernel_size, kernel_stride, pe
)
compressed_value, _ = linear_compress(
    value, w, cu_seqlens, kernel_size, kernel_stride, None
)
# 2. attention between query and compressed key value
compressed_attn_output, topk_idx = compressed_attention(
    query,
    compressed_key,
    compressed_value,
    kernel_size,
    kernel_stride,
    block_size,
    topk,
    cu_seqlens,
    compressed_cu_seqlens,
    init_blocks=1,
    local_blocks=2,
)
# 3. topk sparse attention
sparse_attn_output = topk_sparse_attention(
    query,
    key,
    value,
    topk_idx,
    block_size,
    cu_seqlens,
)
```
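
The compression step above shortens each sequence, which is why `linear_compress` also returns `compressed_cu_seqlens`. As a rough sketch, assuming one compressed token is produced per full window of `kernel_size` tokens taken every `kernel_stride` tokens (with no padding of a trailing partial window), the compressed lengths can be estimated as below. The helper `compressed_lengths` is hypothetical; the authoritative rule is whatever `get_compressed_seqlens` in `ops/triton/utils.py` implements.

```python
from typing import List


def compressed_lengths(
    seqlens: List[int], kernel_size: int, kernel_stride: int
) -> List[int]:
    """Count full sliding windows per sequence (assumed rule, hypothetical helper)."""
    return [
        (n - kernel_size) // kernel_stride + 1 if n >= kernel_size else 0
        for n in seqlens
    ]


# Sequence lengths matching cu_seqlens = [0, 1024, 8192, 16384] in the example above.
lens = compressed_lengths([1024, 7168, 8192], kernel_size=32, kernel_stride=16)
print(lens)  # [63, 447, 511]
```

Under this assumption, a sequence shorter than `kernel_size` contributes no compressed tokens at all.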
### Module
The `module` directory also provides implementations based on `torch.nn.Module` for easy integration into models.
```python
from native_sparse_attention.module import NativeSparseAttention, RopeConfig
NSA_Layer = NativeSparseAttention(
    compress_type="linear",
    hidden_size=4096,
    num_q_heads=64,
    num_kv_heads=4,
    head_dim=128,
    kernel_size=32,
    kernel_stride=16,
    block_size=64,
    topk=8,
    init_blocks=1,
    local_blocks=2,
    window_size=512,
    rope_config=RopeConfig(
        max_position_embeddings=32768,
        head_dim=128,
        rope_theta=500000,
        rope_scaling={
            "factor": 4.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3",
        },
    ),
)
```
### Model
We offer two simplified LLaMA models in the `model` directory, featuring self-attention and native sparse attention. For more details on how to use these models, please refer to [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/model#readme).
```python
from native_sparse_attention.model import ToyNSALlamaConfig, InferenceConfig, ToyNSALlama
config = ToyNSALlamaConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_hidden_layers=8,
    num_attention_heads=32,
    num_key_value_heads=2,
    head_dim=128,
    rope_theta=500000.0,
    rope_scaling={
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    compress_type="weightedpool",
    kernel_size=32,
    kernel_stride=16,
    block_size=64,
    topk=8,
    init_blocks=1,
    local_blocks=2,
    window_size=512,
)
inference_config = InferenceConfig(
    max_batch_size=4,
    max_length=8192,
    max_new_tokens=128,
)
model = ToyNSALlama(config, inference_config).cuda().bfloat16()
```
## Testing
Some test scripts are available in the `test` folder and can be run directly for unit testing. For example:
```bash
python test/test_topk_sparse_attention.py
python test/test_nsa_module.py
python test/test_nsa_model.py
```
### Benchmarks
Here are the speed benchmarks conducted on a single NVIDIA A100 GPU or H100 GPU for the `topk_sparse_attention` function:
A100 GPU speed benchmarks:
```sh
** forward with block size 64 **:
N Flash Triton-Flash Triton-Top8 Triton-Top16
0 2048.0 0.414144 0.635648 0.633440 1.009184
1 4096.0 1.400304 2.267552 1.179808 1.916736
2 8192.0 5.223776 8.528160 2.266816 3.723168
3 16384.0 20.225697 32.745537 4.468128 7.359168
4 32768.0 79.587715 128.951065 8.517440 14.142848
5 65536.0 321.240479 511.652100 17.249599 30.991360
6 131072.0 1349.810425 2063.245605 36.400482 67.884544
** backward with block size 64 **:
N Flash Triton-Flash Triton-Top8 Triton-Top16
0 2048.0 1.315440 2.348560 1.941568 2.691040
1 4096.0 4.271584 8.553184 3.647744 5.032160
2 8192.0 15.323984 32.665440 5.650144 9.066112
3 16384.0 58.753281 127.675964 11.160832 17.113279
4 32768.0 227.770462 504.572693 21.723392 34.715614
5 65536.0 899.181274 2059.718506 44.517181 76.309441
6 131072.0 3587.918701 8530.726562 105.344734 182.970169
```
H100 GPU benchmarks:
```sh
** forward with block size 64 **:
N Flash Triton-Flash Triton-Top8 Triton-Top16
0 2048.0 0.259552 0.293888 0.584544 0.917664
1 4096.0 0.846848 1.029904 1.094976 1.745136
2 8192.0 3.043744 3.843392 2.128256 3.396880
3 16384.0 11.743568 14.791360 4.190528 6.704192
4 32768.0 45.968513 57.532478 7.614496 12.417440
5 65536.0 187.234375 228.093948 14.840048 24.511856
6 131072.0 810.890381 914.693970 29.470400 48.990192
** backward with block size 64 **:
N Flash Triton-Flash Triton-Top8 Triton-Top16
0 2048.0 0.798976 1.096096 1.117312 1.380016
1 4096.0 2.545680 3.826336 1.669760 2.214880
2 8192.0 9.029760 14.411633 2.772096 3.947456
3 16384.0 34.144016 58.945698 5.201344 7.538912
4 32768.0 135.718369 233.369247 9.968864 15.154192
5 65536.0 541.053894 929.337646 21.089870 33.818878
6 131072.0 2139.974854 3785.540527 54.918144 93.750717
```
Here are the speed benchmarks for the `compressed_attention` function, also measured on a single NVIDIA A100 GPU or H100 GPU:
A100 GPU speed benchmarks:
```sh
** forward with kernel 32 and stride 16 **:
N Flash Triton-Flash Compressed Compressed-wo-Score
0 2048.0 0.413664 0.635488 0.655024 0.170816
1 4096.0 1.396416 2.247648 1.132304 0.377152
2 8192.0 5.234656 8.526400 2.879200 0.977952
3 16384.0 19.988865 32.755199 9.426448 2.943024
4 32768.0 79.419907 128.955170 30.284096 9.901120
5 65536.0 321.590210 511.615509 112.260544 36.001602
6 131072.0 1346.996338 2069.837891 423.099518 136.820038
** backward with kernel 32 and stride 16 **:
N Flash Triton-Flash Compressed
0 2048.0 1.322560 2.352000 0.486784
1 4096.0 4.270832 8.552608 0.971392
2 8192.0 15.515680 32.671329 2.603744
3 16384.0 59.345055 128.377472 8.499456
4 32768.0 230.626144 506.581238 30.064833
5 65536.0 919.260498 2068.642578 113.466560
6 131072.0 3646.603760 8498.374023 439.623444
```
H100 GPU speed benchmarks:
```sh
** forward with kernel 32 and stride 16 **:
N Flash Triton-Flash Compressed Compressed-wo-Score
0 2048.0 0.259488 0.297152 0.485920 0.103232
1 4096.0 0.847376 1.030400 0.710208 0.217760
2 8192.0 3.044016 3.875840 1.607360 0.516016
3 16384.0 11.823104 14.829360 4.970272 1.440288
4 32768.0 46.204750 57.527809 15.004992 4.584736
5 65536.0 187.324249 227.909958 53.009087 16.134224
6 131072.0 810.707214 910.106873 191.245728 60.154270
** backward with kernel 32 and stride 16 **:
N Flash Triton-Flash Compressed
0 2048.0 0.797728 1.090640 0.283104
1 4096.0 2.547088 3.834592 0.550464
2 8192.0 9.021520 14.421088 1.249184
3 16384.0 34.159508 58.793377 3.743440
4 32768.0 136.844070 233.447708 12.640032
5 65536.0 537.559814 929.360229 46.054817
6 131072.0 2135.629883 3782.351562 175.587296
```
All the speed benchmarks above were tested with 64 query heads, 4 key/value heads, and a head dimension of 128.
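
To read the tables as relative speedups, the short sketch below divides the `Flash` column by the sparse columns for the longest sequence length. The numbers are copied from the `topk_sparse_attention` forward tables above (N = 131072); the helper `speedup` is hypothetical and only performs the division.

```python
# (Flash, Triton-Top8, Triton-Top16) forward timings at N = 131072,
# copied from the topk_sparse_attention tables above.
forward_131072 = {
    "A100": (1349.810425, 36.400482, 67.884544),
    "H100": (810.890381, 29.470400, 48.990192),
}


def speedup(baseline: float, candidate: float) -> float:
    """Ratio of baseline time to candidate time (hypothetical helper)."""
    return baseline / candidate


for gpu, (flash, top8, top16) in forward_131072.items():
    print(
        f"{gpu}: Top8 is {speedup(flash, top8):.1f}x, "
        f"Top16 is {speedup(flash, top16):.1f}x faster than Flash"
    )
```

The same division can be applied to any other row; at short sequence lengths the ratio drops below 1, meaning the dense kernel is faster there.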
## Contributing
Contributions are welcome! Please open an issue to discuss major changes.
## Contact
For any questions or feedback, please feel free to contact laixunhao@pku.edu.cn.
## Citations
```bibtex
@inproceedings{Yuan2025NativeSA,
title = {Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention},
author = {Jingyang Yuan and Huazuo Gao and Damai Dai and Junyu Luo and Liang Zhao and Zhengyan Zhang and Zhenda Xie and Y. X. Wei and Lean Wang and Zhiping Xiao and Yuqing Wang and Chong Ruan and Ming Zhang and Wenfeng Liang and Wangding Zeng},
year = {2025},
url = {https://api.semanticscholar.org/CorpusID:276408911}
}
```
================================================
FILE: install_dependency.sh
================================================
pip3 install packaging -i https://pypi.org/simple
pip3 install numpy==1.26.4 -i https://pypi.org/simple
pip3 install torch==2.4.0 -i https://pypi.org/simple
pip3 install triton==3.0.0 -i https://pypi.org/simple
pip3 install transformers==4.44.0 -i https://pypi.org/simple
pip3 install flash_attn==2.6.3 -i https://pypi.org/simple
pip3 install matplotlib==3.9.4 -i https://pypi.org/simple
pip3 install pandas==2.2.3 -i https://pypi.org/simple
================================================
FILE: native_sparse_attention/__init__.py
================================================
================================================
FILE: native_sparse_attention/infer/__init__.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from native_sparse_attention.infer.nsa_inference import nsa_infer
__all__ = [
    "nsa_infer",
]
================================================
FILE: native_sparse_attention/infer/inference_func.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Tuple, Callable, Optional
from flash_attn import flash_attn_varlen_func
from native_sparse_attention.ops import (
    flash_attention_decode,
    compressed_attention,
    compressed_attention_decode,
    topk_sparse_attention,
    topk_sparse_attention_decode,
)
from native_sparse_attention.ops.triton.utils import get_compressed_seqlens
def compress_infer(
    cu_seqlens: torch.Tensor,
    step: int,
    key: torch.Tensor,
    value: torch.Tensor,
    cache,
    weight: Tuple[torch.Tensor, torch.Tensor],
    compress_func: Tuple[Callable, Callable],
    intra_block_pe: Optional[torch.Tensor],
    kernel_size: int,
    kernel_stride: int,
):
    if step == 0:
        key, compress_cu_seqlens = compress_func[0](
            key,
            weight[0],
            cu_seqlens,
            kernel_size,
            kernel_stride,
            intra_block_pe,
        )
        value, _ = compress_func[1](
            value,
            weight[1],
            cu_seqlens,
            kernel_size,
            kernel_stride,
        )
    else:
        batch_size = cu_seqlens.shape[0] - 1
        aux_cu_seqlens = (
            torch.arange(batch_size + 1, dtype=torch.int32).to(cu_seqlens.device)
            * kernel_size
        )
        key, _ = compress_func[0](
            cache.before_compress_kv_cache[0, :batch_size].view(
                batch_size * kernel_size, cache.num_kv_heads, cache.head_dim
            ),
            weight[0],
            aux_cu_seqlens,
            kernel_size,
            kernel_stride,
            intra_block_pe,
        )
        value, _ = compress_func[1](
            cache.before_compress_kv_cache[1, :batch_size].view(
                batch_size * kernel_size, cache.num_kv_heads, cache.head_dim
            ),
            weight[1],
            aux_cu_seqlens,
            kernel_size,
            kernel_stride,
        )
        # return actual compress_cu_seqlens before this token
        compress_cu_seqlens = torch.zeros(
            batch_size + 1, dtype=torch.int32, device=key.device
        )
        compress_cu_seqlens[1:] = torch.cumsum(
            cache.compress_kv_len[:batch_size], dim=0
        )
    return key, value, compress_cu_seqlens
def compressed_attention_infer(
    cu_seqlens,
    step,
    query,
    key,
    value,
    cache,
    kernel_size,
    kernel_stride,
    topk,
    block_size,
    init_blocks,
    local_blocks,
):
    if step == 0:
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        compress_seqlens, compress_cu_seqlens = get_compressed_seqlens(
            cu_seqlens, kernel_size, kernel_stride
        )
        attn_output, topk_idx = compressed_attention(
            query,
            key,
            value,
            kernel_size,
            kernel_stride,
            block_size,
            topk,
            cu_seqlens,
            compress_cu_seqlens,
            seqlens.max().item(),
            compress_seqlens.max().item(),
            None,
            init_blocks,
            local_blocks,
        )
    else:
        batch_size = cu_seqlens.shape[0] - 1
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + step
        attn_output, topk_idx = compressed_attention_decode(
            query,
            cache.compress_kv_cache[
                0, :batch_size, : cache.compress_kv_len[:batch_size].max()
            ],
            cache.compress_kv_cache[
                1, :batch_size, : cache.compress_kv_len[:batch_size].max()
            ],
            seqlens,
            cache.compress_kv_len[:batch_size],
            kernel_size,
            kernel_stride,
            block_size,
            topk,
            init_blocks,
            local_blocks,
        )
    return attn_output, topk_idx
def topk_sparse_attention_infer(
    cu_seqlens,
    step,
    query,
    key,
    value,
    cache,
    topk_idx,
    block_size,
):
    if step == 0:
        attn_output = topk_sparse_attention(
            query, key, value, topk_idx, block_size, cu_seqlens
        )
    else:
        batch_size = cu_seqlens.shape[0] - 1
        attn_output = topk_sparse_attention_decode(
            query,
            cache.sparse_kv_cache[0, :batch_size],
            cache.sparse_kv_cache[1, :batch_size],
            topk_idx,
            block_size,
            cache.sparse_kv_len[:batch_size],
        )
    return attn_output
def sliding_window_attention_infer(
    cu_seqlens, step, query, key, value, cache, window_size
):
    if step == 0:
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        attn_output = flash_attn_varlen_func(
            query,
            key,
            value,
            cu_seqlens,
            cu_seqlens,
            seqlens.max().item(),
            seqlens.max().item(),
            causal=True,
            window_size=(window_size, -1),
        )
    else:
        batch_size = cu_seqlens.shape[0] - 1
        attn_output = flash_attention_decode(
            query,
            cache.sliding_kv_cache[0, :batch_size],
            cache.sliding_kv_cache[1, :batch_size],
            torch.minimum(
                cache.sliding_kv_len,
                torch.zeros_like(cache.sliding_kv_len) + window_size,
            )[:batch_size],
        )
    return attn_output
================================================
FILE: native_sparse_attention/infer/nsa_inference.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Tuple, Callable, Optional
from native_sparse_attention.infer.inference_func import (
    compress_infer,
    compressed_attention_infer,
    topk_sparse_attention_infer,
    sliding_window_attention_infer,
)
def nsa_infer(
cu_seqlens: torch.Tensor,
step: int,
# qkv for three parts
query: torch.Tensor,
key: torch.Tensor,  # prefill: [total_len, num_kv_heads, head_dim], decode: [batch_size, num_kv_heads, head_dim]
value: torch.Tensor,
gate_value: torch.Tensor,  # prefill: [total_len, num_q_heads, 3], decode: [batch_size, num_q_heads, 3]
# rope and kv cache
rope,
cache,
# weight for nsa compress
compress_weight: Tuple[
torch.Tensor, torch.Tensor
], # compress weight for key and value
compress_func: Tuple[Callable, Callable], # compress function for key and value
intra_block_pe: Optional[torch.Tensor],
# nsa parameters
kernel_size: int,
kernel_stride: int,
block_size: int,
topk: int,
init_blocks: int,
local_blocks: int,
window_size: int,
) -> torch.Tensor:
"""Inference function for native sparse attention. Support prefill and decode with kv cache.
Args:
cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.
step (int): current inference step, step == 0 means prefill, step > 0 means decode step.
query (torch.Tensor): for prefill, shape [total_len, num_q_heads, head_dim]; for decode, shape [batch_size, num_q_heads, head_dim]
key (torch.Tensor): for prefill, shape [total_len, num_kv_heads, head_dim]; for decode, shape [batch_size, num_kv_heads, head_dim]
value (torch.Tensor): for prefill, shape [total_len, num_kv_heads, head_dim]; for decode, shape [batch_size, num_kv_heads, head_dim]
gate_value (torch.Tensor): for prefill, shape [total_len, num_q_heads, 3]; for decode, shape [batch_size, num_q_heads, 3]
rope (RotaryEmbedding): rope module, see native_sparse_attention.module.rope.RotaryEmbedding for details
cache (NSACache): kv cache, see native_sparse_attention.module.kv_cache.NSACache for details
compress_weight (Tuple[torch.Tensor, torch.Tensor]): compress weight of key and value respectively
compress_func (Tuple[Callable, Callable]): compress functions for key and value respectively
intra_block_pe (Optional[torch.Tensor]): intra-block positional embedding for compression, set to None if not used
kernel_size (int): kernel size of compression
kernel_stride (int): kernel stride for compression
block_size (int): block size of sparse attention
topk (int): topk of sparse attention
init_blocks (int): number of blocks at the beginning of the sequence; these blocks are always selected in sparse attention
local_blocks (int): number of blocks in the local window of each query; these blocks are always selected in sparse attention
window_size (int): window size for sliding window attention
Returns:
torch.Tensor: native sparse attention output, same shape as input query
"""
# reset kv cache at the beginning of prefilling
if step == 0:
cache.reset()
# prepare for compress
cache.prepare_compress(cu_seqlens, step, key, value)
# compressed key and value before rope
compress_key, compress_value, compress_cu_seqlens = compress_infer(
cu_seqlens,
step,
key,
value,
cache,
compress_weight,
compress_func,
intra_block_pe,
kernel_size,
kernel_stride,
)
# do rope
query = rope(query, cu_seqlens, step)
if step == 0:
compress_key = rope(
compress_key, compress_cu_seqlens, step, stride=cache.kernel_stride
)
else:
compress_key = rope(
compress_key, compress_cu_seqlens, 1, stride=cache.kernel_stride
)
key = rope(key, cu_seqlens, step)
# update kv cache
cache.update_kv(
cu_seqlens,
step,
compress_key,
compress_value,
key,
value,
key,
value,
)
# compressed attention
compress_attn_output, topk_idx = compressed_attention_infer(
cu_seqlens,
step,
query,
compress_key,
compress_value,
cache,
kernel_size,
kernel_stride,
topk,
block_size,
init_blocks,
local_blocks,
)
# topk sparse attention
sparse_attn_output = topk_sparse_attention_infer(
cu_seqlens,
step,
query,
key,
value,
cache,
topk_idx,
block_size,
)
# sliding window attention
sliding_attn_output = sliding_window_attention_infer(
cu_seqlens, step, query, key, value, cache, window_size
)
# combine 3 attn output
attn_output = (
gate_value[..., 0, None] * compress_attn_output
+ gate_value[..., 1, None] * sparse_attn_output
+ gate_value[..., 2, None] * sliding_attn_output
)
return attn_output
================================================
FILE: native_sparse_attention/model/README.md
================================================
# Guide for the ToyNSALlama Model
The `ToyNSALlama` model is a custom implementation of a Llama-like transformer architecture featuring a Native Sparse Attention (NSA) module. This guide explains how to integrate the NSA module into your own model.
## Overview
The `ToyNSALlama` model consists of:
- **Configuration**: Defined by `ToyNSALlamaConfig` (model structure parameters) and `InferenceConfig` (inference-specific parameters).
- **Components**: An embedding layer, multiple NativeSparseAttention modules, Feed-Forward Network (FFN) modules, normalization layers, and a language model head.
## Step-by-Step Instructions
### 1. Import Necessary Modules
```python
import torch
import torch.nn as nn
from native_sparse_attention.model import ToyNSALlama, ToyNSALlamaConfig, InferenceConfig
```
### 2. Define Configurations
Create instances of `ToyNSALlamaConfig` and `InferenceConfig` to set model and inference parameters.
#### Model Configuration
The model configuration aligns with the Transformers Llama model configuration. Adjust the following parameters to control the NSA module’s sparsity:
- `compress_type`: Compression method for keys/values. Supported options: `avgpool`, `weightedpool`, `linear`.
- `kernel_size` & `kernel_stride`: `kernel_size` determines how many tokens are compressed into one; `kernel_stride` sets the stride of the compression window (`kernel_size` must be divisible by `kernel_stride`).
- `block_size`: Block size for sparse attention (recommended: 64 or 128).
- `topk`, `init_blocks`, `local_blocks`: `topk` specifies the number of blocks selected in sparse attention; `init_blocks` and `local_blocks` define the number of initial and local blocks that must be selected.
- `window_size`: Size of the sliding window for attention.
Example:
```python
config = ToyNSALlamaConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_hidden_layers=8,
    num_attention_heads=32,
    num_key_value_heads=2,
    head_dim=128,
    vocab_size=128288,
    max_position_embeddings=131072,
    rope_theta=500000.0,
    rope_scaling={
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    compress_type="weightedpool",
    kernel_size=32,
    kernel_stride=16,
    block_size=64,
    topk=8,
    init_blocks=1,
    local_blocks=2,
    window_size=512,
)
```
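To get a feel for how these parameters translate into sparsity, the sketch below counts compressed tokens and lists the blocks that are always selected for a given query position. It is a plain-Python illustration, not the library's kernel. The helper names are hypothetical, and the compressed-length formula follows the `(length - kernel_size) // kernel_stride + 1` expression that `NSACache` uses to size its compressed cache. Treating `local_blocks` as the query's own block plus the blocks immediately before it is an assumption.

```python
def compressed_length(seq_len: int, kernel_size: int, kernel_stride: int) -> int:
    """Number of compressed tokens produced by a sliding kernel over seq_len tokens."""
    if seq_len < kernel_size:
        # Assumption: a sequence shorter than one kernel yields no compressed token.
        return 0
    return (seq_len - kernel_size) // kernel_stride + 1


def forced_blocks(query_pos: int, block_size: int, init_blocks: int, local_blocks: int) -> list[int]:
    """Block indices always selected for the query at query_pos (assumed semantics)."""
    query_block = query_pos // block_size
    # Initial blocks, limited to blocks that are not in the query's future.
    init = set(range(min(init_blocks, query_block + 1)))
    # The query's own block and the blocks right before it.
    local = {query_block - i for i in range(local_blocks) if query_block - i >= 0}
    return sorted(init | local)


seq_len, kernel_size, kernel_stride = 8192, 32, 16
block_size, topk = 64, 8

num_compressed = compressed_length(seq_len, kernel_size, kernel_stride)
selected_tokens = topk * block_size  # upper bound on tokens read by the sparse branch per query
always_selected = forced_blocks(query_pos=1000, block_size=block_size, init_blocks=1, local_blocks=2)

print(num_compressed, selected_tokens, always_selected)
```

With the example configuration, an 8192-token sequence is compressed to 511 tokens, and each query reads at most `topk * block_size = 512` tokens in the sparse branch.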
#### Inference Configuration
This configuration applies during inference and determines how the Key-Value (KV) cache is allocated. With `bfloat16` (2 bytes per element), the full KV cache occupies `max_batch_size × max_length × num_kv_heads × head_dim × num_layers × 2 × 2` bytes, where the first factor of 2 accounts for keys and values. Currently, only greedy decoding is supported as an example.
Example:
```python
inference_config = InferenceConfig(
    max_batch_size=4,
    max_length=8192,
    max_new_tokens=128,
)
```
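A rough estimate of the cache footprint can be derived from the tensor shapes allocated by `NSACache`, where each buffer is `[2, max_batch_size, length, num_kv_heads, head_dim]`. The sketch below is only an illustration: the function names are hypothetical, and it assumes 2 bytes per element (`bfloat16`) and one cache object per layer.

```python
def full_kv_cache_bytes(max_batch_size, max_length, num_kv_heads, head_dim,
                        num_layers, bytes_per_elem=2):
    """Bytes used by the full key/value cache across all layers."""
    # Leading 2 = one buffer for keys, one for values.
    per_layer = 2 * max_batch_size * max_length * num_kv_heads * head_dim
    return per_layer * num_layers * bytes_per_elem


def nsa_extra_cache_bytes(max_batch_size, max_length, num_kv_heads, head_dim, num_layers,
                          kernel_size, kernel_stride, window_size, bytes_per_elem=2):
    """Bytes used by the extra NSA buffers: compressed, pre-compression, sliding window."""
    max_comp_length = (max_length - kernel_size) // kernel_stride + 1
    extra_length = max_comp_length + kernel_size + window_size
    per_layer = 2 * max_batch_size * extra_length * num_kv_heads * head_dim
    return per_layer * num_layers * bytes_per_elem


full = full_kv_cache_bytes(4, 8192, 2, 128, 8)
extra = nsa_extra_cache_bytes(4, 8192, 2, 128, 8,
                              kernel_size=32, kernel_stride=16, window_size=512)
print(f"full cache: {full / 2**20:.1f} MiB, NSA extras: {extra / 2**20:.1f} MiB")
```

For the example configuration above (8 layers, 2 KV heads, `head_dim=128`), the full cache comes to 256 MiB, and the compressed and sliding-window buffers add a comparatively small amount on top.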
### 3. Initialize the Model
Instantiate the model and move it to the GPU with the appropriate data type (currently, only `bfloat16` is supported).
```python
model = ToyNSALlama(config, inference_config).cuda().to(torch.bfloat16)
```
### 4. Forward & Generate
The model supports two methods:
- **`forward`**: Accepts `input_ids` and `cu_seqlens`, returning final logits after the language model head. Use this for training or evaluation.
- **`generate`**: Accepts `input_ids` and `cu_seqlens`, generating output tokens via greedy sampling. This demonstrates KV cache usage for token generation (pre-filling and decoding).
Example:
```python
# Example input
batch_size = 4
seqlens = torch.randint(1, 4096, (batch_size,), dtype=torch.int32, device="cuda")  # at least one token per sequence
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
input_ids = torch.randint(0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda")
print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")
# Example forward
logits = model(input_ids, cu_seqlens)
print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")
# Example generate
output_tokens = model.generate(input_ids, cu_seqlens)
print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
```
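The packed input layout and the prefill/decode split can be sketched without a GPU. The example below uses plain Python lists and a hypothetical `toy_next_token` function in place of the model and `argmax`. It only mirrors the control flow of `generate`: step 0 reads the last prompt token of each packed sequence (`cu_seqlens[1:] - 1`), and later steps feed back one token per sequence.

```python
from itertools import accumulate


def build_cu_seqlens(seqlens):
    """Cumulative sequence boundaries: [0, len0, len0 + len1, ...]."""
    return [0, *accumulate(seqlens)]


def last_token_indices(cu_seqlens):
    """Index of the final token of each packed sequence."""
    return [end - 1 for end in cu_seqlens[1:]]


def toy_next_token(token, vocab_size=10):
    """Hypothetical stand-in for the model forward pass followed by argmax."""
    return (token + 1) % vocab_size


def greedy_generate(input_ids, cu_seqlens, max_new_tokens):
    """Greedy loop: step 0 is prefill over packed prompts, later steps decode."""
    outputs = []
    for step in range(max_new_tokens):
        if step == 0:
            # Prefill: only the last prompt token of each sequence predicts the next token.
            current = [input_ids[i] for i in last_token_indices(cu_seqlens)]
        else:
            # Decode: the input is the previous step's output, one token per sequence.
            current = input_ids
        next_tokens = [toy_next_token(t) for t in current]
        input_ids = next_tokens
        outputs.append(next_tokens)
    # Transpose from [max_new_tokens, batch_size] to [batch_size, max_new_tokens].
    return [list(seq) for seq in zip(*outputs)]


cu_seqlens = build_cu_seqlens([3, 2])  # two prompts packed into one flat list
input_ids = [1, 2, 3, 7, 8]
tokens = greedy_generate(input_ids, cu_seqlens, max_new_tokens=3)
print(cu_seqlens, tokens)
```

Note that `cu_seqlens` stays the same across steps; it describes the prompts, while the cache tracks how many tokens each sequence has accumulated.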
## Toy Llama Model with Self-Attention
A simpler toy model with the Llama structure is available in `native_sparse_attention/model/toy_llama.py`. Compare `ToyLlama` and `ToyNSALlama` to see how to adapt a self-attention model into an NSA-based model.
The primary difference lies in replacing the `SelfAttention` module with the `NativeSparseAttention` module, along with updates to the KV cache and inference function. These changes are straightforward and easy to implement.
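Conceptually, the NSA inference path replaces the single attention output with a gated sum of three branch outputs (compressed, top-k sparse, and sliding window), weighted by the three entries of `gate_value`. A minimal plain-Python sketch for one head, with short lists standing in for tensors and a hypothetical function name:

```python
def combine_branches(gates, compressed, selected, sliding):
    """Gated sum of the three attention branch outputs for one head."""
    g_cmp, g_slc, g_swa = gates
    return [
        g_cmp * c + g_slc * s + g_swa * w
        for c, s, w in zip(compressed, selected, sliding)
    ]


out = combine_branches(
    gates=(0.5, 0.25, 0.25),
    compressed=[1.0, 2.0],
    selected=[3.0, 4.0],
    sliding=[5.0, 6.0],
)
print(out)
```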
================================================
FILE: native_sparse_attention/model/__init__.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from native_sparse_attention.model.toy_llama import (
ToyLlamaConfig,
InferenceConfig,
ToyLlama,
)
from native_sparse_attention.model.toy_nsa_llama import (
ToyNSALlamaConfig,
InferenceConfig,
ToyNSALlama,
)
__all__ = [
"ToyLlamaConfig",
"ToyNSALlamaConfig",
"InferenceConfig",
"ToyLlama",
"ToyNSALlama",
]
================================================
FILE: native_sparse_attention/model/toy_llama.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from native_sparse_attention.module import SelfAttention, RopeConfig, KVCache
@dataclass
class ToyLlamaConfig:
# embedding config
vocab_size: int = 128288
max_position_embeddings: int = 131072
# model config
hidden_size: int = 4096
intermediate_size: int = 14336
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_key_value_heads: int = 2
head_dim: int = 128
# rope config
rope_theta: float = 500000.0
rope_scaling: dict = field(
default_factory=lambda: {
"factor": 8.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
}
)
@dataclass
class InferenceConfig:
max_batch_size: int = 32
max_length: int = 8192
max_new_tokens: int = 128
class RMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FFN(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = nn.SiLU()
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class ToyLlamaLayer(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_q_heads: int,
num_kv_heads: int,
head_dim: int,
rope_config: RopeConfig,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.rope_config = rope_config
self.attn_norm = RMSNorm(self.hidden_size)
self.self_attn = SelfAttention(
hidden_size=self.hidden_size,
num_q_heads=self.num_q_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
rope_config=rope_config,
)
self.ffn_norm = RMSNorm(self.hidden_size)
self.ffn = FFN(
hidden_size=self.hidden_size, intermediate_size=self.intermediate_size
)
def forward(self, x, cu_seqlens):
x = x + self.self_attn(self.attn_norm(x), cu_seqlens)
x = x + self.ffn(self.ffn_norm(x))
return x
@torch.no_grad()
def inference(self, x, cu_seqlens, step, kv_cache):
x = x + self.self_attn.inference(self.attn_norm(x), cu_seqlens, step, kv_cache)
x = x + self.ffn(self.ffn_norm(x))
return x
class ToyLlama(nn.Module):
def __init__(
self, config: ToyLlamaConfig, inference_config: Optional[InferenceConfig] = None
):
super().__init__()
self.config = config
self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
self.rope_config = RopeConfig(
head_dim=self.config.head_dim,
rope_theta=self.config.rope_theta,
rope_scaling=self.config.rope_scaling,
)
self.layers = nn.ModuleList(
[
ToyLlamaLayer(
hidden_size=self.config.hidden_size,
intermediate_size=self.config.intermediate_size,
num_q_heads=self.config.num_attention_heads,
num_kv_heads=self.config.num_key_value_heads,
head_dim=self.config.head_dim,
rope_config=RopeConfig(
self.config.max_position_embeddings,
self.config.head_dim,
self.config.rope_theta,
self.config.rope_scaling,
),
)
for _ in range(self.config.num_hidden_layers)
]
)
self.norm = RMSNorm(self.config.hidden_size)
self.lm_head = nn.Linear(
self.config.hidden_size, self.config.vocab_size, bias=False
)
# inference config and kv cache
self.inference_config = inference_config
self.kv_cache = None
def forward(
self,
input_ids: torch.LongTensor, # shape: [total_length, ]
cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ]
):
# embedding
x = self.embedding(input_ids).to(torch.bfloat16)
# layers
for layer in self.layers:
x = layer(x, cu_seqlens)
# final norm
x = self.norm(x)
# language model head
x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size]
return x
@torch.no_grad()
def inference(
self,
input_ids: torch.LongTensor, # prefill shape: [total_length, ]; decode shape: [batch_size, ]
cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ]
step: int,
):
# set kv cache if self.kv_cache is None
if self.kv_cache is None:
self.kv_cache = [
KVCache(
max_batch_size=self.inference_config.max_batch_size,
max_length=self.inference_config.max_length,
num_kv_heads=self.config.num_key_value_heads,
head_dim=self.config.head_dim,
dtype=torch.bfloat16,
device="cuda",
)
for _ in range(self.config.num_hidden_layers)
]
# embedding
x = self.embedding(input_ids).to(torch.bfloat16)
# layers
for i, layer in enumerate(self.layers):
x = layer.inference(x, cu_seqlens, step, self.kv_cache[i])
# final norm
x = self.norm(x)
# language model head
if step == 0:
x = x[cu_seqlens[1:] - 1, :]
x = self.lm_head(x).to(torch.float32)  # [batch_size, vocab_size]
return x
def generate(
self,
input_ids: torch.LongTensor,
cu_seqlens: torch.LongTensor,
max_new_tokens: int = -1,
):
output_tokens = []
if max_new_tokens <= 0:
max_new_tokens = self.inference_config.max_new_tokens
for step in range(max_new_tokens):
logits = self.inference(
input_ids, cu_seqlens, step
) # shape: [batch_size, vocab_size]
next_token = torch.argmax(logits, dim=-1) # shape: [batch_size, ]
input_ids = next_token
output_tokens.append(next_token)
output_tokens = torch.stack(
output_tokens, dim=1
) # shape: [batch_size, max_new_tokens]
return output_tokens
if __name__ == "__main__":
torch.manual_seed(42)
# initialize model
config = ToyLlamaConfig(
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=8,
num_attention_heads=32,
num_key_value_heads=2,
head_dim=128,
rope_theta=500000.0,
rope_scaling={
"factor": 8.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
},
)
inference_config = InferenceConfig(
max_batch_size=4,
max_length=8192,
max_new_tokens=128,
)
model = ToyLlama(config, inference_config).cuda().bfloat16()
print(f"\nMODEL CONFIG:\n{config}\n")
print(f"\nINFERENCE CONFIG:\n{inference_config}\n")
print(f"\nMODEL:\n{model}\n")
# example input
batch_size = 4
seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
input_ids = torch.randint(
0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda"
)
print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")
# example output
logits = model(input_ids, cu_seqlens)
print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")
# example generate
output_tokens = model.generate(input_ids, cu_seqlens, 64)
print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
================================================
FILE: native_sparse_attention/model/toy_nsa_llama.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from native_sparse_attention.module import NativeSparseAttention, RopeConfig, NSACache
@dataclass
class ToyNSALlamaConfig:
# embedding config
vocab_size: int = 128288
max_position_embeddings: int = 131072
# model config
hidden_size: int = 4096
intermediate_size: int = 14336
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_key_value_heads: int = 2
head_dim: int = 128
# rope config
rope_theta: float = 500000.0
rope_scaling: dict = field(
default_factory=lambda: {
"factor": 8.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
}
)
# nsa config
compress_type: str = "weightedpool"
kernel_size: int = 32
kernel_stride: int = 16
block_size: int = 64
topk: int = 16
init_blocks: int = 1
local_blocks: int = 2
window_size: int = 512
@dataclass
class InferenceConfig:
max_batch_size: int = 32
max_length: int = 8192
max_new_tokens: int = 128
class RMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FFN(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = nn.SiLU()
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class ToyNSALlamaLayer(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_q_heads: int,
num_kv_heads: int,
head_dim: int,
compress_type: str,
kernel_size: int,
kernel_stride: int,
block_size: int,
topk: int,
init_blocks: int,
local_blocks: int,
window_size: int,
rope_config: RopeConfig,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.compress_type = compress_type
self.kernel_size = kernel_size
self.kernel_stride = kernel_stride
self.block_size = block_size
self.topk = topk
self.init_blocks = init_blocks
self.local_blocks = local_blocks
self.window_size = window_size
self.rope_config = rope_config
self.attn_norm = RMSNorm(self.hidden_size)
self.nsa = NativeSparseAttention(
hidden_size=self.hidden_size,
num_q_heads=self.num_q_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
compress_type=self.compress_type,
kernel_size=self.kernel_size,
kernel_stride=self.kernel_stride,
block_size=self.block_size,
topk=self.topk,
init_blocks=self.init_blocks,
local_blocks=self.local_blocks,
window_size=self.window_size,
rope_config=rope_config,
)
self.ffn_norm = RMSNorm(self.hidden_size)
self.ffn = FFN(
hidden_size=self.hidden_size, intermediate_size=self.intermediate_size
)
def forward(self, x, cu_seqlens):
x = x + self.nsa(self.attn_norm(x), cu_seqlens)
x = x + self.ffn(self.ffn_norm(x))
return x
@torch.no_grad()
def inference(self, x, cu_seqlens, step, kv_cache):
x = x + self.nsa.inference(self.attn_norm(x), cu_seqlens, step, kv_cache)
x = x + self.ffn(self.ffn_norm(x))
return x
class ToyNSALlama(nn.Module):
def __init__(
self,
config: ToyNSALlamaConfig,
inference_config: Optional[InferenceConfig] = None,
):
super().__init__()
self.config = config
self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
self.rope_config = RopeConfig(
head_dim=self.config.head_dim,
rope_theta=self.config.rope_theta,
rope_scaling=self.config.rope_scaling,
)
self.layers = nn.ModuleList(
[
ToyNSALlamaLayer(
hidden_size=self.config.hidden_size,
intermediate_size=self.config.intermediate_size,
num_q_heads=self.config.num_attention_heads,
num_kv_heads=self.config.num_key_value_heads,
head_dim=self.config.head_dim,
compress_type=self.config.compress_type,
kernel_size=self.config.kernel_size,
kernel_stride=self.config.kernel_stride,
block_size=self.config.block_size,
topk=self.config.topk,
init_blocks=self.config.init_blocks,
local_blocks=self.config.local_blocks,
window_size=self.config.window_size,
rope_config=RopeConfig(
self.config.max_position_embeddings,
self.config.head_dim,
self.config.rope_theta,
self.config.rope_scaling,
),
)
for _ in range(self.config.num_hidden_layers)
]
)
self.norm = RMSNorm(self.config.hidden_size)
self.lm_head = nn.Linear(
self.config.hidden_size, self.config.vocab_size, bias=False
)
# inference config and kv cache
self.inference_config = inference_config
self.kv_cache = None
def forward(
self,
input_ids: torch.LongTensor,  # shape: [total_length, ]
cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ]
):
# embedding
x = self.embedding(input_ids).to(torch.bfloat16)
# layers
for layer in self.layers:
x = layer(x, cu_seqlens)
# final norm
x = self.norm(x)
# language model head
x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size]
return x
@torch.no_grad()
def inference(
self,
input_ids: torch.LongTensor, # prefill shape: [total_length, ]; decode shape: [batch_size, ]
cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ]
step: int,
):
# set kv cache if self.kv_cache is None
if self.kv_cache is None:
self.kv_cache = [
NSACache(
max_batch_size=self.inference_config.max_batch_size,
max_length=self.inference_config.max_length,
num_kv_heads=self.config.num_key_value_heads,
head_dim=self.config.head_dim,
kernel_size=self.config.kernel_size,
kernel_stride=self.config.kernel_stride,
window_size=self.config.window_size,
dtype=torch.bfloat16,
device="cuda",
)
for _ in range(self.config.num_hidden_layers)
]
# embedding
x = self.embedding(input_ids).to(torch.bfloat16)
# layers
for i, layer in enumerate(self.layers):
x = layer.inference(x, cu_seqlens, step, self.kv_cache[i])
# final norm
x = self.norm(x)
# language model head
if step == 0:
x = x[cu_seqlens[1:] - 1, :]
x = self.lm_head(x).to(torch.float32)  # [batch_size, vocab_size]
return x
def generate(
self,
input_ids: torch.LongTensor,
cu_seqlens: torch.LongTensor,
max_new_tokens: int = -1,
):
output_tokens = []
if max_new_tokens <= 0:
max_new_tokens = self.inference_config.max_new_tokens
for step in range(max_new_tokens):
logits = self.inference(
input_ids, cu_seqlens, step
) # shape: [batch_size, vocab_size]
next_token = torch.argmax(logits, dim=-1) # shape: [batch_size, ]
input_ids = next_token
output_tokens.append(next_token)
output_tokens = torch.stack(
output_tokens, dim=1
) # shape: [batch_size, max_new_tokens]
return output_tokens
if __name__ == "__main__":
torch.manual_seed(42)
# initialize model
config = ToyNSALlamaConfig(
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=8,
num_attention_heads=32,
num_key_value_heads=2,
head_dim=128,
rope_theta=500000.0,
rope_scaling={
"factor": 8.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
},
compress_type="weightedpool",
kernel_size=32,
kernel_stride=16,
block_size=64,
topk=8,
init_blocks=1,
local_blocks=2,
window_size=512,
)
inference_config = InferenceConfig(
max_batch_size=4,
max_length=8192,
max_new_tokens=128,
)
model = ToyNSALlama(config, inference_config).cuda().bfloat16()
print(f"\nMODEL CONFIG:\n{config}\n")
print(f"\nINFERENCE CONFIG:\n{inference_config}\n")
print(f"\nMODEL:\n{model}\n")
# example input
batch_size = 4
seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
input_ids = torch.randint(
0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda"
)
print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n")
# example output
logits = model(input_ids, cu_seqlens)
print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n")
# example generate
output_tokens = model.generate(input_ids, cu_seqlens, 64)
print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n")
================================================
FILE: native_sparse_attention/module/__init__.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from native_sparse_attention.module.native_sparse_attention import NativeSparseAttention
from native_sparse_attention.module.self_attention import SelfAttention
from native_sparse_attention.module.rope import RotaryEmbedding, RopeConfig
from native_sparse_attention.module.kv_cache import NSACache, KVCache
__all__ = [
"SelfAttention",
"NativeSparseAttention",
"RotaryEmbedding",
"RopeConfig",
"NSACache",
"KVCache",
]
================================================
FILE: native_sparse_attention/module/kv_cache.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import triton
import triton.language as tl
from typing import Union
from native_sparse_attention.ops.triton.utils import get_compressed_seqlens
class KVCache:
def __init__(
self,
max_batch_size: int,
max_length: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
device: Union[str, torch.device],
):
self.max_batch_size = max_batch_size
self.max_length = max_length
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.dtype = dtype
self.device = device
# alloc kv cache tensor for self attention
self.kv_cache = torch.zeros(
2,
self.max_batch_size,
self.max_length,
self.num_kv_heads,
self.head_dim,
dtype=self.dtype,
device=self.device,
)
self.kv_len = torch.zeros(
self.max_batch_size, dtype=torch.int32, device=self.device
)
def reset(self):
self.kv_cache.zero_()
self.kv_len.zero_()
def update_kv(
self,
cu_seqlens: torch.Tensor,
step: int,
key: torch.Tensor,
value: torch.Tensor,
):
if step == 0:
self._update_kv_prefill(
cu_seqlens,
step,
key,
value,
)
else:
self._update_kv_decode(
cu_seqlens,
step,
key,
value,
)
def _update_kv_prefill(
self,
cu_seqlens: torch.Tensor,
step: int,
key: torch.Tensor,
value: torch.Tensor,
):
assert step == 0
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
batch_size = seqlens.shape[0]
# key/value shape check
total_len, num_heads, head_dim = key.shape
assert key.shape == value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
assert total_len == cu_seqlens[-1].item()
# fill kv cache
seq_start, seq_end = cu_seqlens[:-1], cu_seqlens[1:]
_fill_kv_cache(
self.kv_cache,
key,
value,
seq_start,
seq_end,
)
self.kv_len[:batch_size] = seqlens
def _update_kv_decode(
self,
cu_seqlens: torch.Tensor,
step: int,
key: torch.Tensor,
value: torch.Tensor,
):
assert step > 0
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
# key/value shape check
batch_size, num_heads, head_dim = key.shape
assert batch_size == seqlens.shape[0]
assert key.shape == value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
# append the new key/value at the current end of each sequence
brange = torch.arange(batch_size, dtype=torch.int32, device=key.device)
self.kv_cache[0, :batch_size][brange, self.kv_len[:batch_size]] = key
self.kv_cache[1, :batch_size][brange, self.kv_len[:batch_size]] = value
self.kv_len[:batch_size] += 1
class NSACache:
"""KV cache manager for native sparse attention.
Args:
max_batch_size (int): max batch size
max_length (int): max length, including prompt length and response length
num_kv_heads (int): number of key/value heads
head_dim (int): head dim
kernel_size (int): kernel size of compression
kernel_stride (int): kernel stride for compression
window_size (int): window size for sliding window attention
dtype (torch.dtype): data type for kv cache, should be same as model weight dtype
device (Union[str, torch.device]): defaults to 'cuda'
Methods:
reset: reset kv cache, should be called before prefilling
prepare_compress: store keys/values for compression, should be called before key/value compression at both prefilling and decoding
update_kv: update key/value cache, should be called after rope
"""
def __init__(
self,
max_batch_size: int,
max_length: int,
num_kv_heads: int,
head_dim: int,
kernel_size: int,
kernel_stride: int,
window_size: int,
dtype: torch.dtype,
device: Union[str, torch.device] = "cuda",
):
self.max_batch_size = max_batch_size
self.max_length = max_length
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.kernel_size = kernel_size
self.kernel_stride = kernel_stride
self.window_size = window_size
self.dtype = dtype
self.device = device
# alloc kv cache tensor for topk sparse attention
self.sparse_kv_cache = torch.zeros(
2,
self.max_batch_size,
self.max_length,
self.num_kv_heads,
self.head_dim,
dtype=self.dtype,
device=self.device,
)
self.sparse_kv_len = torch.zeros(
self.max_batch_size, dtype=torch.int32, device=self.device
)
# alloc kv cache tensor for compressed attention
self.max_comp_length = (
self.max_length - self.kernel_size
) // self.kernel_stride + 1
self.compress_kv_cache = torch.zeros(
2,
self.max_batch_size,
self.max_comp_length,
self.num_kv_heads,
self.head_dim,
dtype=self.dtype,
device=self.device,
)
self.compress_kv_len = torch.zeros(
self.max_batch_size, dtype=torch.int32, device=self.device
)
self.before_compress_kv_cache = torch.zeros(
2,
self.max_batch_size,
self.kernel_size,
self.num_kv_heads,
self.head_dim,
dtype=self.dtype,
device=self.device,
)
self.before_compress_kv_len = torch.zeros(
self.max_batch_size, dtype=torch.int32, device=self.device
)
# alloc kv cache for sliding window attention
self.sliding_kv_cache = torch.zeros(
2,
self.max_batch_size,
self.window_size,
self.num_kv_heads,
self.head_dim,
dtype=self.dtype,
device=self.device,
)
self.sliding_kv_len = torch.zeros(
self.max_batch_size, dtype=torch.int32, device=self.device
)
def reset(self):
self.compress_kv_cache.zero_()
self.compress_kv_len.zero_()
self.before_compress_kv_cache.zero_()
self.before_compress_kv_len.zero_()
self.sparse_kv_cache.zero_()
self.sparse_kv_len.zero_()
self.sliding_kv_cache.zero_()
self.sliding_kv_len.zero_()
def prepare_compress(
self,
cu_seqlens: torch.Tensor,
step: int,
key: torch.Tensor,
value: torch.Tensor,
):
if step == 0:
self._prepare_compress_prefill(cu_seqlens, step, key, value)
else:
self._prepare_compress_decode(cu_seqlens, step, key, value)
def _prepare_compress_prefill(
self,
cu_seqlens: torch.Tensor,
step: int,
key: torch.Tensor,
value: torch.Tensor,
):
assert step == 0
# compress part kv, shape check
batch_size = cu_seqlens.shape[0] - 1
total_len, num_heads, head_dim = key.shape
assert key.shape == value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
comp_seqlens, comp_cu_seqlens = get_compressed_seqlens(
cu_seqlens, self.kernel_size, self.kernel_stride
)
assert total_len == cu_seqlens[-1].item()
# fill tmp cache
seq_start = cu_seqlens[:-1] + comp_seqlens * self.kernel_stride
seq_end = cu_seqlens[1:]
_fill_kv_cache(
self.before_compress_kv_cache,
key,
value,
seq_start,
seq_end,
)
self.before_compress_kv_len[:batch_size] = seq_end - seq_start
def _prepare_compress_decode(
self,
cu_seqlens: torch.Tensor,
step: int,
key: torch.Tensor,
value: torch.Tensor,
):
assert step > 0
# compress part kv, shape check
batch_size, num_heads, head_dim = key.shape
assert key.shape == value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
assert batch_size == cu_seqlens.shape[0] - 1
# sequence with full before_compress_kv
idx = torch.where(self.before_compress_kv_len == self.kernel_size)[0].squeeze()
self.before_compress_kv_cache[
:, idx, : self.kernel_size - self.kernel_stride, :, :
] = self.before_compress_kv_cache[:, idx, self.kernel_stride :, :, :]
self.before_compress_kv_len[idx] -= self.kernel_stride
# fill new kv
brange = torch.arange(batch_size, dtype=torch.int32, device=key.device)
self.before_compress_kv_cache[0, :batch_size][
brange, self.before_compress_kv_len[:batch_size]
] = key
self.before_compress_kv_cache[1, :batch_size][
brange, self.before_compress_kv_len[:batch_size]
] = value
# update kv len
self.before_compress_kv_len[:batch_size] += 1
def update_kv(
self,
cu_seqlens: torch.Tensor,
step: int,
compress_key: torch.Tensor,
compress_value: torch.Tensor,
sparse_key: torch.Tensor,
sparse_value: torch.Tensor,
sliding_key: torch.Tensor,
sliding_value: torch.Tensor,
):
if step == 0:
self._update_kv_prefill(
cu_seqlens,
step,
compress_key,
compress_value,
sparse_key,
sparse_value,
sliding_key,
sliding_value,
)
else:
self._update_kv_decode(
cu_seqlens,
step,
compress_key,
compress_value,
sparse_key,
sparse_value,
sliding_key,
sliding_value,
)
def _update_kv_prefill(
self,
cu_seqlens: torch.Tensor,
step: int,
compress_key: torch.Tensor,
compress_value: torch.Tensor,
sparse_key: torch.Tensor,
sparse_value: torch.Tensor,
sliding_key: torch.Tensor,
sliding_value: torch.Tensor,
):
assert step == 0
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
batch_size = seqlens.shape[0]
# sparse part kv, shape check
total_len, num_heads, head_dim = sparse_key.shape
assert sparse_key.shape == sparse_value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
assert total_len == cu_seqlens[-1].item()
# compress part kv, shape check
total_comp_len, num_heads, head_dim = compress_key.shape
assert compress_key.shape == compress_value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
comp_seqlens, comp_cu_seqlens = get_compressed_seqlens(
cu_seqlens, self.kernel_size, self.kernel_stride
)
assert total_comp_len == comp_cu_seqlens[-1].item()
# sliding window part kv, shape check
total_len, num_heads, head_dim = sliding_key.shape
assert sliding_key.shape == sliding_value.shape
# fill compress part kv cache
seq_start, seq_end = comp_cu_seqlens[:-1], comp_cu_seqlens[1:]
_fill_kv_cache(
self.compress_kv_cache,
compress_key,
compress_value,
seq_start,
seq_end,
)
self.compress_kv_len[:batch_size] = comp_seqlens
# fill sparse part kv cache
seq_start, seq_end = cu_seqlens[:-1], cu_seqlens[1:]
_fill_kv_cache(
self.sparse_kv_cache,
sparse_key,
sparse_value,
seq_start,
seq_end,
)
self.sparse_kv_len[:batch_size] = seqlens
# fill sliding part kv cache
seq_start = torch.maximum(cu_seqlens[1:] - self.window_size, cu_seqlens[:-1])
seq_end = cu_seqlens[1:]
_fill_kv_cache(
self.sliding_kv_cache,
sliding_key,
sliding_value,
seq_start,
seq_end,
)
self.sliding_kv_len[:batch_size] = seq_end - seq_start
def _update_kv_decode(
self,
cu_seqlens: torch.Tensor,
step: int,
compress_key: torch.Tensor,
compress_value: torch.Tensor,
sparse_key: torch.Tensor,
sparse_value: torch.Tensor,
sliding_key: torch.Tensor,
sliding_value: torch.Tensor,
):
assert step > 0
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
# sparse part kv, shape check
batch_size, num_heads, head_dim = sparse_key.shape
assert batch_size == seqlens.shape[0]
assert sparse_key.shape == sparse_value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
# compress part kv, shape check
batch_size, num_heads, head_dim = compress_key.shape
assert batch_size == seqlens.shape[0]
assert compress_key.shape == compress_value.shape
assert num_heads == self.num_kv_heads and head_dim == self.head_dim
# sliding window part kv, shape check
total_len, num_heads, head_dim = sliding_key.shape
assert sliding_key.shape == sliding_value.shape
# fill compress part kv cache; only sequences whose pending block is full need to be compressed
idx = torch.where(self.before_compress_kv_len == self.kernel_size)[0].squeeze()
self.compress_kv_cache[0][idx, self.compress_kv_len[idx]] = compress_key[idx]
self.compress_kv_cache[1][idx, self.compress_kv_len[idx]] = compress_value[idx]
self.compress_kv_len[idx] += 1
# fill sparse part kv cache
brange = torch.arange(batch_size, dtype=torch.int32, device=sparse_key.device)
self.sparse_kv_cache[0, :batch_size][
brange, self.sparse_kv_len[:batch_size]
] = sparse_key
self.sparse_kv_cache[1, :batch_size][
brange, self.sparse_kv_len[:batch_size]
] = sparse_value
self.sparse_kv_len[:batch_size] += 1
# fill sliding window kv cache
self.sliding_kv_cache[0, :batch_size][
brange, self.sliding_kv_len[:batch_size] % self.window_size
] = sliding_key
self.sliding_kv_cache[1, :batch_size][
brange, self.sliding_kv_len[:batch_size] % self.window_size
] = sliding_value
self.sliding_kv_len[:batch_size] += 1
@triton.jit
def _fill_kv_cache_kernel(
cache_ptr,
k_ptr,
v_ptr,
seq_start,
seq_end,
head_dim,
stride_c2,
stride_cb,
stride_cn,
stride_ch,
stride_cd,
stride_kn,
stride_kh,
stride_kd,
stride_vn,
stride_vh,
stride_vd,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_D: tl.constexpr,
):
# get batch id and head id
pid_2b = tl.program_id(0)
pid_2 = pid_2b % 2
pid_b = pid_2b // 2
pid_h = tl.program_id(1)
pid_n = tl.program_id(2)
# get kv start and len after rmpad
kv_start = tl.load(seq_start + pid_b)
kv_end = tl.load(seq_end + pid_b)
kv_len = kv_end - kv_start
if pid_n * BLOCK_SIZE_N >= kv_len:
return
# load key or value
if pid_2 == 0:
kv_ptrs = tl.make_block_ptr(
base=k_ptr + kv_start * stride_kn + pid_h * stride_kh,
shape=(kv_len, head_dim),
strides=(stride_kn, stride_kd),
offsets=(pid_n * BLOCK_SIZE_N, 0),
block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
order=(1, 0),
)
kv = tl.load(kv_ptrs, boundary_check=(0, 1))
else:
kv_ptrs = tl.make_block_ptr(
base=v_ptr + kv_start * stride_vn + pid_h * stride_vh,
shape=(kv_len, head_dim),
strides=(stride_vn, stride_vd),
offsets=(pid_n * BLOCK_SIZE_N, 0),
block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
order=(1, 0),
)
kv = tl.load(kv_ptrs, boundary_check=(0, 1))
# store to cache
cache_ptrs = tl.make_block_ptr(
base=cache_ptr + pid_2 * stride_c2 + pid_b * stride_cb + pid_h * stride_ch,
shape=(kv_len, head_dim),
strides=(stride_cn, stride_cd),
offsets=(pid_n * BLOCK_SIZE_N, 0),
block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
order=(1, 0),
)
tl.store(cache_ptrs, kv.to(cache_ptr.dtype.element_ty), boundary_check=(0, 1))
def _fill_kv_cache(
kv_cache: torch.Tensor, # shape: [2, b, n, h, d]
key: torch.Tensor, # shape: [l, h, d]
value: torch.Tensor, # shape: [l, h, d]
seq_start: torch.Tensor, # shape: [b]
seq_end: torch.Tensor, # shape: [b]
):
total_len, num_heads, head_dim = key.shape
batch_size = seq_start.shape[0]
max_kv_len = (seq_end - seq_start).max().item()
# no kv cache to fill
if max_kv_len == 0:
return
BLOCK_SIZE_N = min(1024, triton.next_power_of_2(max_kv_len))
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
grid = (2 * batch_size, num_heads, triton.cdiv(max_kv_len, BLOCK_SIZE_N))
_fill_kv_cache_kernel[grid](
kv_cache,
key,
value,
seq_start,
seq_end,
head_dim,
kv_cache.stride(0),
kv_cache.stride(1),
kv_cache.stride(2),
kv_cache.stride(3),
kv_cache.stride(4),
key.stride(0),
key.stride(1),
key.stride(2),
value.stride(0),
value.stride(1),
value.stride(2),
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_D=BLOCK_SIZE_D,
)
================================================
FILE: native_sparse_attention/module/native_sparse_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from flash_attn import flash_attn_varlen_func
from native_sparse_attention.ops import (
compressed_attention,
topk_sparse_attention,
avgpool_compress,
weightedpool_compress,
linear_compress,
)
from einops import rearrange
from native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding
from native_sparse_attention.infer import nsa_infer
from native_sparse_attention.module.kv_cache import NSACache
COMPRESS_TYPE_TO_FUNC = {
"avgpool": avgpool_compress,
"weightedpool": weightedpool_compress,
"linear": linear_compress,
}
COMPRESS_TYPE_TO_WEIGHT = {
"avgpool": lambda num_heads, head_dim, kernel_size: None,
"weightedpool": lambda num_heads, head_dim, kernel_size: torch.nn.Parameter(
torch.zeros(num_heads, kernel_size)
),
"linear": lambda num_heads, head_dim, kernel_size: torch.nn.Parameter(
torch.zeros(num_heads, head_dim * kernel_size, head_dim)
),
}
class NativeSparseAttention(torch.nn.Module):
"""Native sparse attention module, support training and inference
Args:
compress_type (str): key/value compression type, currently supports ['linear', 'avgpool', 'weightedpool']
hidden_size (int): hidden dimension
num_q_heads (int): number of query heads
num_kv_heads (int): number of key/value heads, must evenly divide num_q_heads
head_dim (int): head dim
kernel_size (int): kernel size of compression
kernel_stride (int): kernel stride for compression
block_size (int): block size of sparse attention
topk (int): topk of sparse attention
init_blocks (int): number of blocks at the beginning of the sequence; these blocks are always computed in sparse attention
local_blocks (int): number of blocks in the local window of each query; these blocks are always computed in sparse attention
window_size (int): window size for sliding window attention
rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details
rope_device (str): device used to store rope freqs
"""
def __init__(
self,
compress_type: str,
hidden_size: int,
num_q_heads: int,
num_kv_heads: int,
head_dim: int,
kernel_size: int,
kernel_stride: int,
block_size: int,
topk: int,
init_blocks: int,
local_blocks: int,
window_size: int,
rope_config: RopeConfig,
rope_device: str = "cuda",
):
super().__init__()
# configs
self.compress_type = compress_type
self.hidden_size = hidden_size
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.kernel_size = kernel_size
self.kernel_stride = kernel_stride
self.block_size = block_size
self.topk = topk
self.init_blocks = init_blocks
self.local_blocks = local_blocks
self.window_size = window_size
self.rope_config = rope_config
assert self.head_dim == self.rope_config.head_dim
# qkv proj and o proj
self.proj_q = torch.nn.Linear(
self.hidden_size, self.num_q_heads * self.head_dim, bias=False
)
self.proj_k = torch.nn.Linear(
self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
)
self.proj_v = torch.nn.Linear(
self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
)
self.proj_o = torch.nn.Linear(
self.num_q_heads * self.head_dim, self.hidden_size, bias=False
)
# nsa compress func
self.compress_func = COMPRESS_TYPE_TO_FUNC[self.compress_type]
# nsa parameters
self.compress_key = COMPRESS_TYPE_TO_WEIGHT[self.compress_type](
num_kv_heads, head_dim, kernel_size
)
self.compress_value = COMPRESS_TYPE_TO_WEIGHT[self.compress_type](
num_kv_heads, head_dim, kernel_size
)
self.intra_block_pe = torch.nn.Parameter(
torch.zeros(self.num_kv_heads, self.kernel_size, self.head_dim)
)
# gate function
self.gate = torch.nn.Sequential(
torch.nn.Linear(self.hidden_size, self.num_q_heads * 3, bias=False),
torch.nn.Sigmoid(),
)
# rope
self.rope = RotaryEmbedding(self.rope_config, device=rope_device)
# init parameters
self.init_params()
def init_params(self):
for p in self.parameters():
if len(p.shape) > 1:
torch.nn.init.xavier_uniform_(p)
def forward(
self,
x: torch.Tensor, # shape: [total_len, hidden_size]
cu_seqlens: torch.Tensor, # shape: [batch_size + 1]
):
# dtype and shape check
assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
assert x.shape[-1] == self.hidden_size
cu_seqlens = cu_seqlens.to(torch.int32)
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
# qkv proj
q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)
# compressed key and value before rope
compressed_k, compressed_cu_seqlens = self.compress_func(
k,
self.compress_key,
cu_seqlens,
self.kernel_size,
self.kernel_stride,
self.intra_block_pe,
)
compressed_v, _ = self.compress_func(
v,
self.compress_value,
cu_seqlens,
self.kernel_size,
self.kernel_stride,
None,
)
# do rope for query and compressed key
q = self.rope(q, cu_seqlens)
compressed_k = self.rope(
compressed_k, compressed_cu_seqlens, stride=self.kernel_stride
)
# attention between query and compressed key value
compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
compressed_attn_output, topk_idx = compressed_attention(
q,
compressed_k,
compressed_v,
self.kernel_size,
self.kernel_stride,
self.block_size,
self.topk,
cu_seqlens,
compressed_cu_seqlens,
seqlens.max().item(),
compressed_seqlens.max().item(),
None,
self.init_blocks,
self.local_blocks,
)
# do rope for original key
k = self.rope(k, cu_seqlens)
# topk sparse attention
sparse_attn_output = topk_sparse_attention(
q, k, v, topk_idx, self.block_size, cu_seqlens, None
)
# sliding window attention
sliding_attn_output = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens,
cu_seqlens,
seqlens.max().item(),
seqlens.max().item(),
causal=True,
window_size=(self.window_size, -1),
)
# gate average
gate = self.gate(x)
gate = rearrange(gate, "n (h g) -> n h g", g=3)
attn_output = (
gate[..., 0:1] * compressed_attn_output
+ gate[..., 1:2] * sparse_attn_output
+ gate[..., 2:3] * sliding_attn_output
)
# rearrange and output proj
attn_output = rearrange(attn_output, "n h d -> n (h d)")
attn_output = self.proj_o(attn_output)
return attn_output
@torch.no_grad()
def inference(
self,
x: torch.Tensor, # shape: [total_len, hidden_size]
cu_seqlens: torch.Tensor, # shape: [batch_size + 1]
step: int,
cache: NSACache,
):
# dtype and shape check
assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
assert x.shape[-1] == self.hidden_size
cu_seqlens = cu_seqlens.to(torch.int32)
assert step >= 0
if step == 0:
assert x.shape[0] == cu_seqlens[-1]
else:
assert x.shape[0] == cu_seqlens.shape[0] - 1
# qkv proj
q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)
# gate proj
gate = self.gate(x)
gate = rearrange(gate, "n (h g) -> n h g", g=3)
# nsa infer
output = nsa_infer(
cu_seqlens,
step,
q,
k,
v,
gate,
self.rope,
cache,
[self.compress_key, self.compress_value],
[self.compress_func, self.compress_func],
self.intra_block_pe,
self.kernel_size,
self.kernel_stride,
self.block_size,
self.topk,
self.init_blocks,
self.local_blocks,
self.window_size,
)
# output proj
output = rearrange(output, "n h d -> n (h d)")
output = self.proj_o(output)
return output
================================================
FILE: native_sparse_attention/module/rope.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from dataclasses import dataclass, field
from torch import nn
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
# default to llama3.1 rope config
@dataclass
class RopeConfig:
"""Config for RotaryEmbedding, similar to transformers llama."""
max_position_embeddings: int = 131072
head_dim: int = 128
rope_theta: float = 500000
rope_scaling: dict = field(
default_factory=lambda: {
"factor": 8.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
}
)
# unused placeholders, kept only for compatibility; set head_dim instead
hidden_size: int = 1
num_attention_heads: int = 1
def __post_init__(self):
self.num_attention_heads = 1
self.hidden_size = self.head_dim
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# copied and modified from Hugging Face transformers
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
class RotaryEmbedding(nn.Module):
"""Rotary embedding
Args:
config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details
device (Union[str, torch.device]): device for the rope frequencies, defaults to the current CUDA device
"""
cos = None
sin = None
def __init__(
self, config: RopeConfig, device=torch.device(torch.cuda.current_device())
):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get(
"rope_type", config.rope_scaling.get("type")
)
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len
)
self.register_buffer(
"inv_freq", inv_freq, persistent=False
) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if (
seq_len < self.original_max_seq_len
and self.max_seq_len_cached > self.original_max_seq_len
): # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad()
def generate_cos_sin(self, x: torch.Tensor, position_ids):
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type
device_type = (
device_type
if isinstance(device_type, str) and device_type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False):
freqs = (
inv_freq_expanded.float() @ position_ids_expanded.float()
).transpose(1, 2)
# # do not use this with flash_attn
# emb = torch.cat((freqs, freqs), dim=-1)
emb = freqs
cos = emb.cos()
sin = emb.sin()
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = (cos * self.attention_scaling).to(dtype=x.dtype).squeeze(0)
sin = (sin * self.attention_scaling).to(dtype=x.dtype).squeeze(0)
# save cos sin
RotaryEmbedding.cos = torch.cat([cos, cos], dim=-1)
RotaryEmbedding.sin = torch.cat([sin, sin], dim=-1)
return RotaryEmbedding.cos, RotaryEmbedding.sin
@torch.no_grad()
def generate_pos_embs(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
seqlens: torch.Tensor,
step: int = 0,
stride: int = 1,
):
if (
RotaryEmbedding.cos is None
or seqlens.max() + step > RotaryEmbedding.cos.shape[0]
):
self.generate_cos_sin(
x, torch.arange(seqlens.max() + step).to(x.device)[None, :]
)
cos_embs = []
sin_embs = []
bsz = len(cu_seqlens) - 1
for i in range(bsz):
if step == 0: # prefilling
r = cu_seqlens[i + 1] - cu_seqlens[i]
cos_emb, sin_emb = (
RotaryEmbedding.cos[: r * stride : stride],
RotaryEmbedding.sin[: r * stride : stride],
)
elif step > 0: # decoding
r = cu_seqlens[i + 1] - cu_seqlens[i] + step - 1
cos_emb, sin_emb = (
RotaryEmbedding.cos[r * stride : r * stride + 1],
RotaryEmbedding.sin[r * stride : r * stride + 1],
)
cos_embs.append(cos_emb)
sin_embs.append(sin_emb)
cos_embs = torch.cat(cos_embs, dim=0)
sin_embs = torch.cat(sin_embs, dim=0)
return cos_embs, sin_embs
def forward(self, x, cu_seqlens, step=0, stride=1):
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
cos_embs, sin_embs = self.generate_pos_embs(
x,
cu_seqlens,
seqlens,
step=step,
stride=stride,
)
N, H, D = x.shape[0], x.shape[-2], x.shape[-1] # H: number of heads
x = x * cos_embs.view(N, 1, D) + rotate_half(x) * sin_embs.view(N, 1, D)
return x
================================================
FILE: native_sparse_attention/module/self_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from flash_attn import flash_attn_varlen_func
from einops import rearrange
from native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding
from native_sparse_attention.module.kv_cache import KVCache
from native_sparse_attention.ops import flash_attention_decode
class SelfAttention(torch.nn.Module):
"""self attention module
Args:
hidden_size (int): hidden dimension
num_q_heads (int): number of query heads
num_kv_heads (int): number of key/value heads, must evenly divide num_q_heads
head_dim (int): head dim
rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details
rope_device (str): device used to store rope freqs
"""
def __init__(
self,
hidden_size: int,
num_q_heads: int,
num_kv_heads: int,
head_dim: int,
rope_config: RopeConfig,
rope_device: str = "cuda",
):
super().__init__()
# configs
self.hidden_size = hidden_size
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.rope_config = rope_config
assert self.head_dim == self.rope_config.head_dim
# qkv proj and o proj
self.proj_q = torch.nn.Linear(
self.hidden_size, self.num_q_heads * self.head_dim, bias=False
)
self.proj_k = torch.nn.Linear(
self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
)
self.proj_v = torch.nn.Linear(
self.hidden_size, self.num_kv_heads * self.head_dim, bias=False
)
self.proj_o = torch.nn.Linear(
self.num_q_heads * self.head_dim, self.hidden_size, bias=False
)
# rope
self.rope = RotaryEmbedding(self.rope_config, device=rope_device)
# init parameters
self.init_params()
def init_params(self):
for p in self.parameters():
torch.nn.init.xavier_uniform_(p)
def forward(
self,
x: torch.Tensor, # shape: [total_len, hidden_size]
cu_seqlens: torch.Tensor, # shape: [batch_size + 1]
):
# dtype and shape check
assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
assert x.shape[-1] == self.hidden_size
cu_seqlens = cu_seqlens.to(torch.int32)
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
# qkv proj
q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)
# do rope for query and key
q = self.rope(q, cu_seqlens)
k = self.rope(k, cu_seqlens)
# self attention
attn_output = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens,
cu_seqlens,
seqlens.max().item(),
seqlens.max().item(),
causal=True,
)
# rearrange and output proj
attn_output = rearrange(attn_output, "n h d -> n (h d)")
attn_output = self.proj_o(attn_output)
return attn_output
@torch.no_grad()
def inference(
self,
x: torch.Tensor, # shape: [total_len, hidden_size]
cu_seqlens: torch.Tensor, # shape: [batch_size + 1]
step: int,
cache: KVCache,
):
# dtype and shape check
assert x.dtype == torch.bfloat16 or x.dtype == torch.float16
assert x.shape[-1] == self.hidden_size
cu_seqlens = cu_seqlens.to(torch.int32)
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
assert step >= 0
if step == 0:
assert x.shape[0] == cu_seqlens[-1]
else:
assert x.shape[0] == cu_seqlens.shape[0] - 1
batch_size = cu_seqlens.shape[0] - 1
# qkv proj
q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)
k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)
v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)
# do rope for query and key
q = self.rope(q, cu_seqlens, step)
k = self.rope(k, cu_seqlens, step)
# reset and update kv cache
if step == 0:
cache.reset()
cache.update_kv(cu_seqlens, step, k, v)
# self attention
if step == 0:
cu_seqlens_q = cu_seqlens_k = cu_seqlens
max_seqlen_in_batch_q = max_seqlen_in_batch_k = seqlens.max().item()
output = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
causal=True,
)
else:
output = flash_attention_decode(
q,
cache.kv_cache[0, :batch_size],
cache.kv_cache[1, :batch_size],
cache.kv_len[:batch_size],
)
# rearrange and output proj
output = rearrange(output, "n h d -> n (h d)")
output = self.proj_o(output)
return output
================================================
FILE: native_sparse_attention/ops/README.md
================================================
# Triton Functions for Native Sparse Attention
This folder provides efficient Triton-based implementations of components for Native Sparse Attention. This README introduces the available functions, explains how to set them up, and offers guidance on their usage.

---
## Overview of Functions
The functions are organized into two main categories:
1. **Compression Methods**: Techniques for compressing key and value tensors.
2. **Attention Mechanisms**: Methods for computing attention between queries and compressed key/value tensors, including top-k sparse attention.
---
## Function Descriptions
### Compression Methods
These functions compress key and value tensors with a sliding window. Each window covers `kernel_size` consecutive tokens and is compressed into a single token; consecutive windows start `kernel_stride` tokens apart, so they overlap when `kernel_stride < kernel_size`. In the formulas below, $m$ denotes `kernel_size`. All compression functions share similar input parameters and return formats.
**Parameters:**
- `x`: Input tensor (`total_len, num_heads, head_dim`)
- `w`: Weight tensor (shape varies by compression method)
- `cu_seqlens`: Cumulative sequence lengths (`batch_size + 1`)
- `kernel_size`: Size of the compression window
- `kernel_stride`: Stride of the compression window
- `pe`: Optional positional embedding (`num_heads, kernel_size, head_dim`)
**Returns:**
- Compressed tensor (`total_compress_len, num_heads, head_dim`)
- Cumulative sequence lengths (`com_cu_seqlens`) for the compressed tensor
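
The relationship between the input and compressed sequence lengths can be sketched in plain Python. This is an illustration only: the helper name `compressed_seqlens` is hypothetical, and it assumes that each sequence yields `(len - kernel_size) // kernel_stride + 1` windows (the same formula `NSACache` uses for `max_comp_length`) and that sequences shorter than `kernel_size` yield none.

```python
from itertools import accumulate


def compressed_seqlens(cu_seqlens, kernel_size, kernel_stride):
    """Return per-sequence compressed lengths and their cumulative sums.

    Hypothetical helper for illustration; plain lists stand in for tensors.
    """
    seqlens = [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    comp = [
        (n - kernel_size) // kernel_stride + 1 if n >= kernel_size else 0
        for n in seqlens
    ]
    comp_cu = [0] + list(accumulate(comp))
    return comp, comp_cu


# Three packed sequences of lengths 100, 30 and 10.
comp, comp_cu = compressed_seqlens([0, 100, 130, 140], kernel_size=32, kernel_stride=16)
print(comp, comp_cu)  # [5, 0, 0] [0, 5, 5, 5]
```
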
#### `weightedpool_compress`
Compresses the input tensor using weighted pooling, applying a weighted sum over each block:
$\hat{k} = w_1 k_1 + \dots + w_m k_m$
- **Weight shape**: `(num_heads, kernel_size)`
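
As a dependency-free sketch of the formula above, the following applies weighted pooling to a single head of a single sequence. The function name `weighted_pool_one_head` is hypothetical, and plain lists stand in for tensors; the Triton kernel in this folder handles all heads and packed sequences at once.

```python
def weighted_pool_one_head(keys, weights, kernel_stride):
    """Weighted sum over each sliding window; kernel_size is len(weights)."""
    kernel_size = len(weights)
    out = []
    for start in range(0, len(keys) - kernel_size + 1, kernel_stride):
        window = keys[start : start + kernel_size]
        # zip(*window) iterates over feature dimensions of the window.
        out.append([sum(w * x for w, x in zip(weights, col)) for col in zip(*window)])
    return out


keys = [[1.0, 0.0], [2.0, 1.0], [3.0, 2.0], [4.0, 3.0]]  # 4 tokens, head_dim = 2
pooled = weighted_pool_one_head(keys, weights=[0.5, 0.5], kernel_stride=2)
print(pooled)  # [[1.5, 0.5], [3.5, 2.5]]
```
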
#### `avgpool_compress`
Compresses the input tensor using average pooling:
$\hat{k} = (k_1 + \dots + k_m) / m$
- **Weight**: Must be `None`
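
The sketch below illustrates average pooling with overlapping windows (`kernel_stride < kernel_size`) for one head of one sequence. It is a plain-Python illustration under the same sliding-window assumption as above, and `avg_pool_one_head` is a hypothetical name, not a function of this package.

```python
def avg_pool_one_head(keys, kernel_size, kernel_stride):
    """Mean over each sliding window of kernel_size tokens."""
    out = []
    for start in range(0, len(keys) - kernel_size + 1, kernel_stride):
        window = keys[start : start + kernel_size]
        out.append([sum(col) / kernel_size for col in zip(*window)])
    return out


keys = [[2.0, 4.0], [4.0, 8.0], [6.0, 0.0]]  # 3 tokens, head_dim = 2
pooled = avg_pool_one_head(keys, kernel_size=2, kernel_stride=1)
print(pooled)  # [[3.0, 6.0], [5.0, 4.0]]
```

The middle token contributes to both outputs because the two windows overlap by one token.
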
#### `linear_compress`
Compresses the input tensor via linear projection, mapping each block to a single vector using learned weights:
$\hat{k} = \text{cat}(k_1, \dots, k_m) W$
- **Weight shape**: `(num_heads, kernel_size * head_dim, head_dim)`
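
A minimal sketch of the linear projection for one head follows. It assumes the window is flattened token by token (all features of $k_1$, then $k_2$, and so on) before multiplying by a `(kernel_size * head_dim, head_dim)` matrix; the actual memory layout used by the Triton kernel may differ. `linear_compress_one_head` is a hypothetical name.

```python
def linear_compress_one_head(keys, weight, kernel_size, kernel_stride):
    """Project each flattened window with a (kernel_size * head_dim, head_dim) matrix."""
    head_dim = len(weight[0])
    out = []
    for start in range(0, len(keys) - kernel_size + 1, kernel_stride):
        # cat(k_1, ..., k_m): token-major flattening of the window.
        flat = [x for token in keys[start : start + kernel_size] for x in token]
        out.append(
            [sum(flat[i] * weight[i][j] for i in range(len(flat))) for j in range(head_dim)]
        )
    return out


# With two stacked identity matrices as W, the projection reduces to k_1 + k_2.
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
compressed = linear_compress_one_head(
    [[1.0, 2.0], [3.0, 4.0]], weight, kernel_size=2, kernel_stride=2
)
print(compressed)  # [[4.0, 6.0]]
```
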
---
### Attention Mechanisms
These functions compute attention using either full or sparse mechanisms.
#### `flash_attention_varlen`
A variable-length implementation of flash attention, similar to `flash_attn_varlen_func` from the `flash_attn` package.
**Parameters:**
- `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`)
- `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for queries and keys
- `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths in the batch
- `causal`: Apply causal masking (default: `False`)
- `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)
**Returns:**
- Attention output tensor (`total_q_len, num_q_heads, head_dim`)
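
The variable-length layout packs all sequences of a batch along the first dimension, so `cu_seqlens` and `max_seqlen` have to be derived from the individual lengths. A minimal sketch in plain Python follows; the helper name `build_cu_seqlens` is illustrative. In practice the result would be converted to an `int32` tensor on the GPU.

```python
from itertools import accumulate


def build_cu_seqlens(seqlens):
    """Cumulative sequence lengths with a leading zero, shape (batch_size + 1)."""
    return [0] + list(accumulate(seqlens))


seqlens = [1024, 7168, 8192]
cu_seqlens = build_cu_seqlens(seqlens)
max_seqlen = max(seqlens)
print(cu_seqlens, max_seqlen)  # [0, 1024, 8192, 16384] 8192
```
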
#### `compressed_attention`
Computes attention between a query and compressed key/value tensors, identifying the top-k blocks for sparse attention.
**Parameters:**
- `q`: Query tensor (`total_len, num_heads, head_dim`)
- `k`, `v`: Compressed key and value tensors (`total_compress_len, num_heads, head_dim`)
- `kernel_size`, `kernel_stride`: Compression parameters
- `block_size`: Size of blocks for sparse attention
- `topk`: Number of top blocks to select
- `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for query and compressed key/value
- `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths for query and compressed key/value
- `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)
- `init_blocks`: Number of initial blocks forced to be selected (default: `1`)
- `local_blocks`: Number of local blocks forced to be selected (default: `2`)
**Returns:**
- Tuple containing:
- Attention output tensor
- Top-k block indices
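
The causal rule between queries and compressed keys can be illustrated in plain Python. In the PyTorch reference `compressed_attention_torch`, compressed key `j` covers tokens `j * kernel_stride` through `j * kernel_stride + kernel_size - 1`, and a query may attend to it only when that whole window ends at or before the query position. The helper name `visible_compressed_keys` is hypothetical.

```python
def visible_compressed_keys(q_pos, num_compressed, kernel_size, kernel_stride):
    """Indices of compressed keys whose window ends at or before q_pos."""
    return [
        j
        for j in range(num_compressed)
        if q_pos >= j * kernel_stride + kernel_size - 1
    ]


# kernel_size=32, kernel_stride=16: the first 31 queries see no compressed key
print(visible_compressed_keys(30, 63, 32, 16))   # []
print(visible_compressed_keys(31, 63, 32, 16))   # [0]
print(visible_compressed_keys(100, 63, 32, 16))  # [0, 1, 2, 3, 4]
```
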
#### `topk_sparse_attention`
Performs sparse attention using precomputed top-k block indices. If a query attends to fewer than `topk` key/value blocks, its row in `topk_idx` must be padded with `-1` on the right.
**Parameters:**
- `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`)
- `topk_idx`: Precomputed top-k indices (`num_kv_heads, total_len, topk`)
- `block_size`: Block size for sparse attention (recommended: `64` or `128`)
- `cu_seqlens`: Cumulative sequence lengths
- `softmax_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)
**Returns:**
- Attention output tensor (`total_len, num_q_heads, head_dim`)
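
The padding convention for one query can be sketched as follows. This mirrors what the PyTorch reference does after selecting blocks: indices are sorted in ascending order and blocks that lie after the query's own block are replaced by `-1`. The helper name `pad_topk_idx` is illustrative and not part of the package.

```python
def pad_topk_idx(selected_blocks, q_pos, block_size, topk):
    """Sorted, causally valid block indices for one query, right-padded with -1."""
    q_block = q_pos // block_size  # block that contains the query itself
    valid = sorted(b for b in set(selected_blocks) if b <= q_block)[:topk]
    return valid + [-1] * (topk - len(valid))


# Query at position 200 with block_size=64 sits in block 3, so block 7 is dropped
print(pad_topk_idx([0, 3, 1, 7], q_pos=200, block_size=64, topk=4))  # [0, 1, 3, -1]
```
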
---
## Usage Example
Below is a typical workflow demonstrating how to combine these sparse attention functions:
```python
import torch
from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention
# Example input setup
num_q_heads = 64
num_kv_heads = 4
head_dim = 128
cu_seqlens = torch.tensor([0, 1024, 8192, 16384], dtype=torch.int32).cuda()
# Query, key, and value tensors
query = torch.randn(16384, num_q_heads, head_dim, dtype=torch.bfloat16).cuda()
key = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda()
value = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda()
# Compression weights and positional embeddings
kernel_size = 32
kernel_stride = 16
wk = torch.randn(num_kv_heads, kernel_size * head_dim, head_dim, dtype=torch.bfloat16).cuda()
wv = torch.randn_like(wk)
pe = torch.randn(num_kv_heads, kernel_size, head_dim, dtype=torch.bfloat16).cuda()
# Parameters for top-k sparse attention
block_size = 64
topk = 16
# 1. Compress key and value tensors
compressed_key, compressed_cu_seqlens = linear_compress(
key, wk, cu_seqlens, kernel_size, kernel_stride, pe
)
compressed_value, _ = linear_compress(
value, wv, cu_seqlens, kernel_size, kernel_stride, None
)
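# 2. Compute attention with compressed key/value and get top-k indices
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
compressed_attn_output, topk_idx = compressed_attention(
    query,
    compressed_key,
    compressed_value,
    kernel_size,
    kernel_stride,
    block_size,
    topk,
    cu_seqlens,
    compressed_cu_seqlens,
    seqlens.max().item(),  # max_seqlen_q
    compressed_seqlens.max().item(),  # max_seqlen_k
    init_blocks=1,
    local_blocks=2,
)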
# 3. Perform top-k sparse attention
sparse_attn_output = topk_sparse_attention(
query,
key,
value,
topk_idx,
block_size,
cu_seqlens,
)
# 4. Combine attention outputs (e.g., average)
attn_output = (compressed_attn_output + sparse_attn_output) / 2
```
For a complete implementation of the Native Sparse Attention module, see `native_sparse_attention/module/native_sparse_attention.py`.
================================================
FILE: native_sparse_attention/ops/__init__.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# compress method
from native_sparse_attention.ops.triton.weighted_pool import (
weightedpool_compress,
avgpool_compress,
)
from native_sparse_attention.ops.triton.linear_compress import linear_compress
# prefill attention
from native_sparse_attention.ops.triton.flash_attention import flash_attention_varlen
from native_sparse_attention.ops.triton.compressed_attention import compressed_attention
from native_sparse_attention.ops.triton.topk_sparse_attention import (
topk_sparse_attention,
)
# decode attention
from native_sparse_attention.ops.triton.flash_attention_decode import (
flash_attention_decode,
)
from native_sparse_attention.ops.torch.compressed_attention_decode import (
compressed_attention_decode,
)
from native_sparse_attention.ops.triton.topk_sparse_attention_decode import (
topk_sparse_attention_decode,
)
__all__ = [
# compress method
"avgpool_compress",
"weightedpool_compress",
"linear_compress",
# prefill attention, trainable
"flash_attention_varlen",
"compressed_attention",
"topk_sparse_attention",
# decode attention, no grad
"flash_attention_decode",
"compressed_attention_decode",
"topk_sparse_attention_decode",
]
================================================
FILE: native_sparse_attention/ops/torch/__init__.py
================================================
================================================
FILE: native_sparse_attention/ops/torch/compress_key_value.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Optional
from einops import rearrange, einsum
def avgpool_compress_torch(
x: torch.Tensor,
    w: Optional[torch.Tensor],
cu_seqlens,
kernel_size: int,
kernel_stride: int,
pe: Optional[torch.Tensor] = None,
):
"""Compress key and value tensor with kernel_size and kernel_stride.
Args:
x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
w (torch.Tensor): no weight for avgpool, must be None.
        cu_seqlens (torch.Tensor): int32 tensor of shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.
kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)
kernel_stride (int): stride for each compress kernel
pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.
"""
# dtype check
assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
assert cu_seqlens.dtype == torch.int32
assert x.dtype == pe.dtype if pe is not None else True
# shape check
total_len, num_heads, head_dim = x.shape
batch_size = cu_seqlens.shape[0] - 1
assert w is None, "don't need additional weight for avgpool"
assert kernel_size % kernel_stride == 0
assert kernel_size in {16, 32, 64, 128}
# compute seqlens after compression
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
# corner case, if sequence_length < kernel_size, no compression for this sequence
y_seqlens[seqlens < kernel_size] = 0
y_cu_seqlens = torch.cat(
[
            torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
torch.cumsum(y_seqlens, dim=0),
],
dim=0,
).to(torch.int32)
# pad and rearrange x
x = rearrange(x, "n h d -> n (h d)")
splited_x = torch.split(x, seqlens.tolist(), 0)
x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True)
x = rearrange(x, "b n d -> b d n")
# avgpool
y = torch.nn.functional.avg_pool1d(x, kernel_size=kernel_size, stride=kernel_stride)
y = rearrange(y, "b (h d) n -> b n h d", h=num_heads)
# only keep useful part
y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0)
# position embedding as a bias
if pe is not None:
bias = torch.mean(pe, dim=1)
y = y + bias.unsqueeze(0)
return y, y_cu_seqlens
def weightedpool_compress_torch(
x: torch.Tensor,
w: torch.Tensor, # [num_heads, kernel_size]
cu_seqlens,
kernel_size: int,
kernel_stride: int,
pe: Optional[torch.Tensor] = None,
):
"""Compress key and value tensor with kernel_size and kernel_stride.
Args:
x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
w (torch.Tensor): weight for each head, shape (num_heads, kernel_size)
        cu_seqlens (torch.Tensor): int32 tensor of shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.
kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)
kernel_stride (int): stride for each compress kernel
pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.
"""
# dtype check
assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
assert x.dtype == w.dtype
assert x.dtype == pe.dtype if pe is not None else True
assert cu_seqlens.dtype == torch.int32
# shape check
total_len, num_heads, head_dim = x.shape
batch_size = cu_seqlens.shape[0] - 1
assert w.shape[0] == num_heads
assert w.shape[1] == kernel_size
assert kernel_size % kernel_stride == 0
assert kernel_size in {16, 32, 64, 128}
# compute seqlens after compression
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
# corner case, if sequence_length < kernel_size, no compression for this sequence
y_seqlens[seqlens < kernel_size] = 0
y_cu_seqlens = torch.cat(
[
            torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
torch.cumsum(y_seqlens, dim=0),
],
dim=0,
).to(torch.int32)
# pad and rearrange x
x = rearrange(x, "n h d -> n (h d)")
splited_x = torch.split(x, seqlens.tolist(), 0)
x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True)
x = rearrange(x, "b n (h d) -> b h n d", h=num_heads)
x = x.as_strided(
size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim),
stride=(
x.stride(0),
x.stride(1),
kernel_stride * x.stride(2),
x.stride(2),
x.stride(3),
),
)
y = einsum(x, w, "b h n k d, h k -> b n h d")
# only keep useful part
y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0)
# position embedding as a bias
if pe is not None:
bias = einsum(pe, w, "h k d, h k -> h d")
y = y + bias.unsqueeze(0)
return y, y_cu_seqlens
def linear_compress_torch(
x: torch.Tensor,
w: torch.Tensor, # [num_heads, kernel_size * head_dim, head_dim]
cu_seqlens,
kernel_size: int,
kernel_stride: int,
pe: Optional[torch.Tensor] = None,
):
"""Compress key and value tensor with kernel_size and kernel_stride. Similar to conv_compress.
Args:
x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)
w (torch.Tensor): weight for each head, shape (num_heads, kernel_size * head_dim, head_dim)
        cu_seqlens (torch.Tensor): int32 tensor of shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.
kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)
kernel_stride (int): stride for each compress kernel
pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.
"""
# dtype check
assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
assert x.dtype == w.dtype
assert x.dtype == pe.dtype if pe is not None else True
assert cu_seqlens.dtype == torch.int32
# shape check
total_len, num_heads, head_dim = x.shape
batch_size = cu_seqlens.shape[0] - 1
assert w.shape[0] == num_heads
assert w.shape[1] == kernel_size * head_dim
assert w.shape[2] == head_dim
assert kernel_size % kernel_stride == 0
assert kernel_size in {16, 32, 64, 128}
# compute seqlens after compression
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1
# corner case, if sequence_length < kernel_size, no compression for this sequence
y_seqlens[seqlens < kernel_size] = 0
y_cu_seqlens = torch.cat(
[
            torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
torch.cumsum(y_seqlens, dim=0),
],
dim=0,
).to(torch.int32)
# pad and rearrange x
x = rearrange(x, "n h d -> n (h d)")
splited_x = torch.split(x, seqlens.tolist(), 0)
x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True)
x = rearrange(x, "b n (h d) -> b h n d", h=num_heads)
x = x.as_strided(
size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim),
stride=(
x.stride(0),
x.stride(1),
kernel_stride * x.stride(2),
x.stride(2),
x.stride(3),
),
)
y = einsum(
x,
rearrange(w, "h (k d) D -> h k d D", k=kernel_size),
"b h n k d, h k d D -> b n h D",
)
# only keep useful part
y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0)
# position embedding as a bias
if pe is not None:
pe = rearrange(pe, "h k d -> h (k d)")
bias = einsum(pe, w, "h D, h D d -> h d")
y = y + bias.unsqueeze(0)
return y, y_cu_seqlens
================================================
FILE: native_sparse_attention/ops/torch/compressed_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
from typing import Tuple
from collections import Counter
from einops import rearrange
def transform_score(
score: torch.Tensor,
kernel_size: int,
kernel_stride: int,
block_size: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
init_blocks: int = 1,
local_blocks: int = 2,
) -> torch.Tensor:
num_k_heads, total_query_len, _ = score.shape
pad_len = kernel_size // kernel_stride - 1
score = torch.nn.functional.pad(score, (pad_len, pad_len), value=0)
max_blocks = math.ceil(max_seqlen_q / block_size)
full_blocks = max_seqlen_q // block_size
block_score = torch.zeros(
num_k_heads,
total_query_len,
max_blocks,
dtype=torch.float32,
device=score.device,
)
offs = (
torch.arange(kernel_size // kernel_stride)[:, None]
+ torch.arange(block_size // kernel_stride)[None, :]
).view(-1)
offs = dict(Counter(offs.tolist()))
for k, v in offs.items():
block_score[..., :full_blocks] += (
v * score[..., k :: block_size // kernel_stride][..., :full_blocks]
)
# set init block and local block score
batch_size = cu_seqlens_q.shape[0] - 1
q_idx = torch.cat(
[
torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=score.device)
for i in range(batch_size)
],
dim=0,
)
q_idx = q_idx // block_size
b_idx = torch.arange(max_blocks, device=score.device)
block_score[..., :init_blocks] = torch.inf
local_mask = (q_idx[:, None] >= b_idx[None, :]) & (
q_idx[:, None] < b_idx[None, :] + local_blocks
)
local_mask = local_mask.unsqueeze(0).expand(num_k_heads, -1, -1)
block_score[local_mask] = torch.inf
return block_score
def compressed_attention_torch(
q: torch.Tensor, # [total_query_len, num_q_heads, head_dim]
k: torch.Tensor, # [total_key_len, num_k_heads, head_dim]
v: torch.Tensor, # [total_key_len, num_k_heads, head_dim]
kernel_size: int,
kernel_stride: int,
block_size: int,
topk: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
sm_scale: float = None,
init_blocks: int = 1,
local_blocks: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Attention between query and compressed key and value. Implemented with torch, only for debug.
Args:
q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
kernel_size (int): kernel size in compress_key_value
kernel_stride (int): stride of compress_key_value
block_size (int): key value block size for topk sparse attention.
topk (int): number of blocks for each query.
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.
        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_varlen_func.
max_seqlen_q (int): max q len of the batch.
max_seqlen_k (int): max k len of the batch.
sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
Returns:
Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
"""
assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0
total_query_len, num_q_heads, head_dim = q.shape
total_key_len, num_k_heads, _ = k.shape
num_share_q_heads = num_q_heads // num_k_heads
batch_size = cu_seqlens_q.shape[0] - 1
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(head_dim)
# get mask
mask = torch.zeros(
(total_query_len, total_key_len), dtype=torch.bool, device=q.device
)
for b in range(batch_size):
q_len, k_len = (
cu_seqlens_q[b + 1] - cu_seqlens_q[b],
cu_seqlens_k[b + 1] - cu_seqlens_k[b],
)
k_max_ids = (
torch.arange(k_len, device=q.device) * kernel_stride + kernel_size - 1
)
q_ids = torch.arange(q_len, device=q.device)
mask[
cu_seqlens_q[b] : cu_seqlens_q[b + 1], cu_seqlens_k[b] : cu_seqlens_k[b + 1]
] = (q_ids[:, None] >= k_max_ids[None, :])
# attention
qk = (
torch.einsum("qhd,khd->hqk", q, k.repeat_interleave(num_share_q_heads, 1))
* sm_scale
)
qk = qk.masked_fill_(~mask[None, ...], -torch.inf)
    # queries near the start of a sequence cannot attend to any compressed key:
    # their fully masked rows become NaN after softmax and are zeroed by nan_to_num
qk = qk.softmax(dim=-1, dtype=torch.float32)
qk = qk.nan_to_num(0)
attn_output = torch.einsum(
"hqk,khd->qhd", qk.to(v.dtype), v.repeat_interleave(num_share_q_heads, 1)
)
with torch.no_grad():
# get avg score over gqa heads
# qk shape [num_k_heads, total_q_len, total_k_len]
score = torch.zeros(
num_k_heads,
cu_seqlens_q[-1],
max_seqlen_k,
dtype=torch.float32,
device=q.device,
)
qk = rearrange(qk, "(h g) q k -> h g q k", h=num_k_heads).sum(1)
for b in range(batch_size):
score[
:,
cu_seqlens_q[b] : cu_seqlens_q[b + 1],
: cu_seqlens_k[b + 1] - cu_seqlens_k[b],
] = qk[
:,
cu_seqlens_q[b] : cu_seqlens_q[b + 1],
cu_seqlens_k[b] : cu_seqlens_k[b + 1],
]
# transform score to block-wise score
score = transform_score(
score,
kernel_size,
kernel_stride,
block_size,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
init_blocks,
local_blocks,
)
# get topk
batch_size = cu_seqlens_q.shape[0] - 1
q_idx = torch.cat(
[
torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device)
for i in range(batch_size)
],
dim=0,
)
q_idx = q_idx // block_size
topk = min(topk, score.shape[-1])
topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
topk_idx[topk_idx > q_idx[None, :, None]] = -1
topk_idx = topk_idx.to(torch.int32)
return attn_output, topk_idx
================================================
FILE: native_sparse_attention/ops/torch/compressed_attention_decode.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
from typing import Tuple, Optional
from collections import Counter
from einops import rearrange
def transform_score(
score: torch.Tensor,
seqlens: torch.Tensor,
kernel_size: int,
kernel_stride: int,
block_size: int,
init_blocks: int = 1,
local_blocks: int = 2,
) -> torch.Tensor:
num_k_heads, batch_size, kv_len = score.shape
pad_len = kernel_size // kernel_stride - 1
score = torch.nn.functional.pad(score, (pad_len, pad_len), value=0)
max_seqlen = seqlens.max().item()
max_blocks = math.ceil(max_seqlen / block_size)
full_blocks = max_seqlen // block_size
block_score = torch.zeros(
num_k_heads,
batch_size,
max_blocks,
dtype=torch.float32,
device=score.device,
)
offs = (
torch.arange(kernel_size // kernel_stride)[:, None]
+ torch.arange(block_size // kernel_stride)[None, :]
).view(-1)
offs = dict(Counter(offs.tolist()))
for k, v in offs.items():
block_score[..., :full_blocks] += (
v * score[..., k :: block_size // kernel_stride][..., :full_blocks]
)
# set init block and local block score
q_idx = (seqlens - 1) // block_size
b_idx = torch.arange(max_blocks, device=score.device)
block_score[..., :init_blocks] = torch.inf
local_mask = (q_idx[:, None] >= b_idx[None, :]) & (
q_idx[:, None] < b_idx[None, :] + local_blocks
)
local_mask = local_mask.unsqueeze(0).expand(num_k_heads, -1, -1)
block_score[local_mask] = torch.inf
block_score = block_score.nan_to_num(0, torch.inf, -torch.inf)
return block_score
def compressed_attention_decode(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlens: torch.Tensor,
compress_seqlens: torch.Tensor,
kernel_size: int,
kernel_stride: int,
block_size: int,
topk: int,
init_blocks: int = 1,
local_blocks: int = 2,
sm_scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
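    """Attention between one decoding query per sequence and the compressed key/value cache. Implemented with torch.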
Args:
q (torch.Tensor): shape [batch_size, num_q_heads, head_dim]
k (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]
v (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]
seqlens (torch.Tensor): original kv length for each sequence
compress_seqlens (torch.Tensor): kv length for each sequence after compression
kernel_size (int): kernel size in compress_key_value
kernel_stride (int): stride of compress_key_value
block_size (int): key value block size for topk sparse attention.
topk (int): number of blocks for each query.
init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
Returns:
Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention_decode
"""
assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0
batch_size, num_q_heads, head_dim = q.shape
batch_size, kv_len, num_k_heads, _ = k.shape
num_share_q_heads = num_q_heads // num_k_heads
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(head_dim)
# input is too short to have a valid block
if kv_len == 0:
return torch.zeros_like(q), torch.zeros(
num_k_heads, batch_size, 1, device=q.device, dtype=torch.int32
)
# get mask
mask = (
compress_seqlens[:, None]
> torch.arange(
kv_len, device=compress_seqlens.device, dtype=compress_seqlens.dtype
)[None, :]
)
# attention
qk = (
torch.einsum(
"bihgd, bjhgd -> bhgij",
rearrange(q, "b (h g) d -> b 1 h g d", g=num_share_q_heads),
rearrange(k, "b j h d -> b j h 1 d"),
)
* sm_scale
)
qk = qk.masked_fill_(~mask[:, None, None, None, :], -torch.inf)
qk = qk.softmax(dim=-1, dtype=torch.float32)
qk = qk.nan_to_num_(0) # qk is nan when seqlen == 0
attn_output = torch.einsum(
"bhgij, bjhgd -> bihgd",
qk.to(v.dtype),
rearrange(v, "b k h d -> b k h 1 d"),
)
attn_output = rearrange(attn_output, "b 1 h g d -> b (h g) d")
# get score
score = rearrange(qk.sum(2).squeeze(2), "b h j -> h b j")
# transform score to block-wise score
score = transform_score(
score,
seqlens,
kernel_size,
kernel_stride,
block_size,
init_blocks,
local_blocks,
)
# get topk
q_idx = (seqlens - 1) // block_size
topk = min(topk, score.shape[-1])
topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
topk_idx[topk_idx > q_idx[None, :, None]] = -1
topk_idx = topk_idx.to(torch.int32)
return attn_output, topk_idx
================================================
FILE: native_sparse_attention/ops/torch/topk_sparse_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
from typing import Optional
def topk_sparse_attention_torch(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
topk_idx: torch.Tensor,
block_size_k: int,
cu_seqlens: torch.Tensor,
softmax_scale: Optional[float] = None,
block_size_q: int = 1,
) -> torch.Tensor:
    """Simple topk sparse attention varlen version implemented in torch. Extremely slow, only for debugging.
Args:
q (torch.Tensor): shape [total_len, num_q_heads, head_dim]
k (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
v (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding.
block_size_q (int): query block size.
block_size_k (int): key value block size.
        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_varlen_func.
softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim).
Returns:
torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim]
"""
total_seqlen, num_q_heads, head_dim = q.shape
total_seqlen, num_kv_heads, head_dim = k.shape
num_share_q_heads = num_q_heads // num_kv_heads
batch_size = cu_seqlens.shape[0] - 1
topk = topk_idx.shape[-1]
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
seqblocks_q = torch.ceil(seqlens / block_size_q).to(torch.int32)
cu_seqblocks_q = torch.nn.functional.pad(seqblocks_q.cumsum(0), (1, 0), value=0)
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(head_dim)
# get mask
mask = torch.zeros(
(num_kv_heads, total_seqlen, total_seqlen), dtype=torch.bool, device=q.device
)
for i in range(batch_size):
num_q_blocks = math.ceil(seqlens[i] / block_size_q)
num_kv_blocks = math.ceil(seqlens[i] / block_size_k)
for h in range(num_kv_heads):
temp_mask = torch.zeros(
num_q_blocks, num_kv_blocks, dtype=torch.bool, device=q.device
)
temp_idx = topk_idx[h, cu_seqblocks_q[i] : cu_seqblocks_q[i + 1]].clone()
temp_idx[temp_idx < 0] = 0
temp_mask[torch.arange(num_q_blocks).to(q.device)[:, None], temp_idx] = True
temp_mask = torch.repeat_interleave(temp_mask, block_size_q, dim=0)
temp_mask = torch.repeat_interleave(temp_mask, block_size_k, dim=1)
temp_mask = temp_mask[: seqlens[i], : seqlens[i]]
mask[
h, cu_seqlens[i] : cu_seqlens[i + 1], cu_seqlens[i] : cu_seqlens[i + 1]
] = temp_mask
mask = torch.tril(mask).repeat_interleave(num_share_q_heads, 0)
# qk attn
qk = (
torch.einsum("qhd,khd->hqk", q, k.repeat_interleave(num_share_q_heads, 1))
* softmax_scale
)
qk = torch.masked_fill(qk, ~mask, -torch.inf)
qk = torch.softmax(qk, dim=-1, dtype=torch.float32).to(q.dtype)
o = torch.einsum("hqk,khd->qhd", qk, v.repeat_interleave(num_share_q_heads, 1))
return o
================================================
FILE: native_sparse_attention/ops/triton/__init__.py
================================================
================================================
FILE: native_sparse_attention/ops/triton/compressed_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Tuple, Union
from collections import Counter
import torch
import triton
import triton.language as tl
import warnings
from native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu
IS_HOPPER_GPU = is_hopper_gpu()
@triton.jit
def forward_kernel(
q_ptr, # Q: n x h x d
k_ptr, # K: n x h x d
v_ptr, # V: n x h x d
o_ptr, # O: n x h x d
lse_ptr, # LSE: h x n
    # size and stride at compression
kernel_size,
kernel_stride,
# seqlens
cu_seqlens_q,
cu_seqlens_k,
# shape
NUM_KV_HEADS,
NUM_SHARE_Q_HEADS,
HEAD_DIM,
# sm_scale
sm_scale,
# stride
stride_qn,
stride_qh,
stride_qd,
stride_kn,
stride_kh,
stride_kd,
stride_vn,
stride_vh,
stride_vd,
stride_on,
stride_oh,
stride_od,
stride_lh,
stride_ln,
# META parameters
BLOCK_SIZE_Q: tl.constexpr, # q block size
BLOCK_SIZE_K: tl.constexpr, # k block size
BLOCK_SIZE_D: tl.constexpr,
):
    # fold log2(e) into the scale so the softmax can use exp2 instead of exp
    qk_scale = sm_scale * 1.44269504
# get batch id and head id
pid_b = tl.program_id(0)
pid_h = tl.program_id(1)
pid_q = tl.program_id(2)
pid_kh = pid_h // NUM_SHARE_Q_HEADS
# get q k start and len after rmpad
q_start = tl.load(cu_seqlens_q + pid_b)
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
k_start = tl.load(cu_seqlens_k + pid_b)
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # skip the first kernel_size - 1 queries, because they do not attend to any keys
q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
if q_start_in_seq >= q_len:
return
# init qkv pointer
q_ptrs = tl.make_block_ptr(
base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
shape=(q_len, HEAD_DIM),
strides=(stride_qn, stride_qd),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
order=(1, 0),
)
k_ptrs = tl.make_block_ptr(
base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
shape=(HEAD_DIM, k_len),
strides=(stride_kd, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
order=(0, 1),
)
v_ptrs = tl.make_block_ptr(
base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
shape=(k_len, HEAD_DIM),
strides=(stride_vn, stride_vd),
offsets=(0, 0),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
order=(1, 0),
)
# load q
q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
# init statistics
off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32)
# attention
lo = 0
hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
for i in range(lo, hi, BLOCK_SIZE_K):
i = tl.multiple_of(i, BLOCK_SIZE_K)
# load k
k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
# compute qk
qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
qk += tl.where(
off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
)
qk += tl.dot(q, k) * qk_scale
# compute m_ij and l_ij
m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
p = tl.exp2(qk - m_ij[:, None])
l_ij = tl.sum(p, axis=1)
# scale acc_o
acc_o_scale = tl.exp2(m_i - m_ij)
acc_o = acc_o * acc_o_scale[:, None]
# load v and update acc_o
v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
p = p.to(v.dtype)
acc_o += tl.dot(p, v)
# update statistics
m_i = m_ij
lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
# update ptrs
k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
# final scale
acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]
# save output
o_ptrs = tl.make_block_ptr(
base=o_ptr + q_start * stride_on + pid_h * stride_oh,
shape=(q_len, HEAD_DIM),
strides=(stride_on, stride_od),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
order=(1, 0),
)
tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
# save lse
l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
tl.store(l_ptrs, lse_i, mask=off_q < q_len)
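# --- Illustrative sketch, not used by the kernels in this file ---
# forward_kernel keeps its softmax statistics in base 2: scores are scaled by
# sm_scale * 1.44269504 (i.e. sm_scale * log2(e)) so exp2/log2 can be used.
# The hypothetical helper below (name and signature are assumptions made for
# illustration) replays the same running-max / running-lse update for a single
# query row on plain Python floats, one key block at a time.
import math


def _sketch_online_lse2(scores, sm_scale, block=2):
    """Return log2(sum(exp(sm_scale * s) for s in scores)), computed blockwise."""
    qk_scale = sm_scale * 1.44269504  # same constant as in the kernel
    m_i = float("-inf")  # running row maximum (base-2 domain)
    lse_i = float("-inf")  # running log2-sum-exp
    for start in range(0, len(scores), block):
        qk = [s * qk_scale for s in scores[start : start + block]]
        m_ij = max(m_i, max(qk))
        l_ij = sum(2.0 ** (x - m_ij) for x in qk)
        # mirrors: lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)
        lse_i = m_ij + math.log2(2.0 ** (lse_i - m_ij) + l_ij)
        m_i = m_ij
    return lse_i


# The result should agree with a direct natural-log computation divided by
# ln(2), up to the precision of the truncated log2(e) constant.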
@triton.jit
def backward_sum_o_do(
o_ptr, # O: n x h x d
do_ptr, # dO: n x h x d
delta_ptr, # D: h x n
o_len,
HEAD_DIM,
stride_on,
stride_oh,
stride_od,
stride_don,
stride_doh,
stride_dod,
stride_dh,
stride_dn,
BLOCK_SIZE_O: tl.constexpr,
BLOCK_SIZE_D: tl.constexpr,
):
pid_n = tl.program_id(0)
pid_h = tl.program_id(1)
off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
off_d = tl.arange(0, BLOCK_SIZE_D)
o = tl.load(
o_ptr
+ off_n[:, None] * stride_on
+ pid_h * stride_oh
+ off_d[None, :] * stride_od,
mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
other=0,
).to(tl.float32)
do = tl.load(
do_ptr
+ off_n[:, None] * stride_don
+ pid_h * stride_doh
+ off_d[None, :] * stride_dod,
mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
other=0,
).to(tl.float32)
delta = tl.sum(o * do, axis=1)
tl.store(
delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len
)
@triton.jit
def backward_dkdv(
q_ptr, # Q: n x qh x d
k_ptr, # K: n x kh x d
v_ptr, # V: n x kh x d
lse_ptr, # LSE: qh x n
d_ptr, # Delta: qh x n
do_ptr,
dk_ptr, # DK: sh x n x kh x d
dv_ptr, # DV: sh x n x kh x d
kernel_size,
kernel_stride,
# seqlens
cu_seqlens_q,
cu_seqlens_k,
# shape
NUM_KV_HEADS,
NUM_SHARE_Q_HEADS,
HEAD_DIM,
# sm_scale
sm_scale,
# stride
stride_qn,
stride_qh,
stride_qd,
stride_kn,
stride_kh,
stride_kd,
stride_vn,
stride_vh,
stride_vd,
stride_lh,
stride_ln,
stride_dh,
stride_dn,
stride_don,
stride_doh,
stride_dod,
stride_dks,
stride_dkn,
stride_dkh,
stride_dkd,
stride_dvs,
stride_dvn,
stride_dvh,
stride_dvd,
# META parameters
BLOCK_SIZE_Q: tl.constexpr, # q block size
BLOCK_SIZE_K: tl.constexpr, # k block size
BLOCK_SIZE_D: tl.constexpr,
):
qk_scale = sm_scale * 1.44269504
# get batch id and head id
pid_b = tl.program_id(0)
pid_h = tl.program_id(1)
pid_kh = pid_h // NUM_SHARE_Q_HEADS
pid_sh = pid_h % NUM_SHARE_Q_HEADS
pid_k = tl.program_id(2)
# get q k start and len after rmpad
q_start = tl.load(cu_seqlens_q + pid_b)
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
k_start = tl.load(cu_seqlens_k + pid_b)
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
if BLOCK_SIZE_K * pid_k >= k_len:
return
# init pointers
k_ptrs = tl.make_block_ptr(
base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
shape=(k_len, HEAD_DIM),
strides=(stride_kn, stride_kd),
offsets=(pid_k * BLOCK_SIZE_K, 0),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
order=(1, 0),
)
dk_ptrs = tl.make_block_ptr(
base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
shape=(k_len, HEAD_DIM),
strides=(stride_dkn, stride_dkd),
offsets=(pid_k * BLOCK_SIZE_K, 0),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
order=(1, 0),
)
v_ptrs = tl.make_block_ptr(
base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
shape=(k_len, HEAD_DIM),
strides=(stride_vn, stride_vd),
offsets=(pid_k * BLOCK_SIZE_K, 0),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
order=(1, 0),
)
dv_ptrs = tl.make_block_ptr(
base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
shape=(k_len, HEAD_DIM),
strides=(stride_dvn, stride_dvd),
offsets=(pid_k * BLOCK_SIZE_K, 0),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
order=(1, 0),
)
# offsets
off_q = tl.arange(0, BLOCK_SIZE_Q)
off_k = (
pid_k * BLOCK_SIZE_K * kernel_stride
+ tl.arange(0, BLOCK_SIZE_K) * kernel_stride
+ kernel_size
- 1
)
# load k v and keep in SRAM
k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
# init dk dv
dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)
q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1
q_ptrs = tl.make_block_ptr(
base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
shape=(HEAD_DIM, q_len),
strides=(stride_qd, stride_qn),
offsets=(0, q_lo),
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
order=(0, 1),
)
do_ptrs = tl.make_block_ptr(
base=do_ptr + q_start * stride_don + pid_h * stride_doh,
shape=(HEAD_DIM, q_len),
strides=(stride_dod, stride_don),
offsets=(0, q_lo),
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),
order=(0, 1),
)
d_ptrs = tl.make_block_ptr(
base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
shape=(1, q_len),
strides=(0, stride_dn),
offsets=(0, q_lo),
block_shape=(1, BLOCK_SIZE_Q),
order=(1, 0),
)
lse_ptrs = tl.make_block_ptr(
base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
shape=(1, q_len),
strides=(0, stride_ln),
offsets=(0, q_lo),
block_shape=(1, BLOCK_SIZE_Q),
order=(0, 1),
)
# loop for q blocks
for i in range(q_lo, q_len, BLOCK_SIZE_Q):
# load
q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
# compute qk
# [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf"))
qk += tl.dot(k, q) * qk_scale
# compute p, ds
# [BLOCK_SIZE_K, BLOCK_SIZE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
p = tl.exp2(qk - lse)
# [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIZE_Q]
dp = tl.dot(v, do)
ds = sm_scale * p * (dp - d)
# cast dtype
p = p.to(do.dtype)
ds = ds.to(q.dtype)
# update dk and dv
# [BLOCK_SIZE_K, BLOCK_SIZE_Q] @ [BLOCK_SIZE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM]
dk += tl.dot(ds, tl.trans(q))
dv += tl.dot(p, tl.trans(do))
# increment pointers
q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q))
do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q))
lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q))
d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q))
# save dk dv
tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def backward_dq(
q_ptr, # Q: n x qh x d
k_ptr, # K: n x kh x d
v_ptr, # V: n x kh x d
lse_ptr, # LSE: qh x n
d_ptr, # Delta: qh x n
do_ptr,
dq_ptr,
kernel_size,
kernel_stride,
# seqlens
cu_seqlens_q,
cu_seqlens_k,
# shape
NUM_KV_HEADS,
NUM_SHARE_Q_HEADS,
HEAD_DIM,
# sm_scale
sm_scale,
# stride
stride_qn,
stride_qh,
stride_qd,
stride_kn,
stride_kh,
stride_kd,
stride_vn,
stride_vh,
stride_vd,
stride_lh,
stride_ln,
stride_dh,
stride_dn,
stride_don,
stride_doh,
stride_dod,
stride_dqn,
stride_dqh,
stride_dqd,
# META parameters
BLOCK_SIZE_Q: tl.constexpr, # q block size
BLOCK_SIZE_K: tl.constexpr, # k block size
BLOCK_SIZE_D: tl.constexpr,
):
qk_scale = sm_scale * 1.44269504
# get batch id and head id
pid_b = tl.program_id(0)
pid_h = tl.program_id(1)
pid_q = tl.program_id(2)
pid_kh = pid_h // NUM_SHARE_Q_HEADS
# get q k start and len after rmpad
q_start = tl.load(cu_seqlens_q + pid_b)
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
k_start = tl.load(cu_seqlens_k + pid_b)
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
# skip the first kernel_size - 1 queries, because they do not attend to any keys
q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1
if q_start_in_seq >= q_len:
return
# init pointers
q_ptrs = tl.make_block_ptr(
base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
shape=(q_len, HEAD_DIM),
strides=(stride_qn, stride_qd),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
order=(1, 0),
)
dq_ptrs = tl.make_block_ptr(
base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,
shape=(q_len, HEAD_DIM),
strides=(stride_dqn, stride_dqd),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
order=(1, 0),
)
k_ptrs = tl.make_block_ptr(
base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
shape=(k_len, HEAD_DIM),
strides=(stride_kn, stride_kd),
offsets=(0, 0),
block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
order=(1, 0),
)
v_ptrs = tl.make_block_ptr(
base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
shape=(HEAD_DIM, k_len),
strides=(stride_vd, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
order=(0, 1),
)
do_ptrs = tl.make_block_ptr(
base=do_ptr + q_start * stride_don + pid_h * stride_doh,
shape=(q_len, HEAD_DIM),
strides=(stride_don, stride_dod),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
order=(1, 0),
)
d_ptrs = tl.make_block_ptr(
base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
shape=(q_len, 1),
strides=(stride_dn, stride_dh),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, 1),
order=(0, 1),
)
lse_ptrs = tl.make_block_ptr(
base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
shape=(q_len, 1),
strides=(stride_ln, stride_lh),
offsets=(q_start_in_seq, 0),
block_shape=(BLOCK_SIZE_Q, 1),
order=(0, 1),
)
# offsets
off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq
off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1
# load q, do, lse, delta, and keep in SRAM
q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
# init dq
dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32)
lo = 0
hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)
for i in range(lo, hi, BLOCK_SIZE_K):
# load
k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
# compute qk
qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
qk += tl.where(
off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")
)
qk += tl.dot(q, tl.trans(k)) * qk_scale
# compute p, ds
p = tl.exp2(qk - lse)
dp = tl.dot(do, v)
ds = sm_scale * p * (dp - d)
# cast dtype
ds = ds.to(q.dtype)
# update dq
dq += tl.dot(ds, k)
# increment pointers
k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))
v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K))
# save dq
tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))
def _compressed_attention_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kernel_size: int,
kernel_stride: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
sm_scale: float,
):
# dtype check
assert k.dtype == q.dtype and v.dtype == q.dtype
assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
# shape
q_len, num_q_heads, head_dim = q.shape
k_len, num_k_heads, head_dim = k.shape
v_len, num_v_heads, head_dim = v.shape
batch_size = cu_seqlens_q.shape[0] - 1
assert k_len == v_len and q_len > k_len
# gqa
assert num_k_heads == num_v_heads
assert num_q_heads % num_k_heads == 0
num_share_q_heads = num_q_heads // num_k_heads
# output tensor
o = torch.zeros_like(q)
lse = torch.full(
(num_q_heads, q_len),
fill_value=-torch.inf,
dtype=torch.float32,
device=q.device,
)
# launch kernel
grid = lambda META: (
batch_size,
num_q_heads,
triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
)
BLOCK_SIZE_Q = 128
BLOCK_SIZE_K = 128
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
forward_kernel[grid](
q,
k,
v,
o,
lse,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
num_k_heads,
num_share_q_heads,
head_dim,
sm_scale,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
lse.stride(0),
lse.stride(1),
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_D=BLOCK_SIZE_D,
num_warps=num_warps,
num_stages=num_stages,
)
return o, lse
def _compressed_attention_bwd(
o: torch.Tensor,
do: torch.Tensor,
lse: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kernel_size: int,
kernel_stride: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
sm_scale: float,
):
q_len, num_q_heads, head_dim = q.shape
k_len, num_k_heads, head_dim = k.shape
v_len, num_v_heads, head_dim = v.shape
o_len, num_o_heads, head_dim = o.shape
num_share_q_heads = num_q_heads // num_k_heads
# compute D
delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads)
BLOCK_SIZE_O = 256
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)
backward_sum_o_do[grid](
o,
do,
delta,
o_len,
head_dim,
o.stride(0),
o.stride(1),
o.stride(2),
do.stride(0),
do.stride(1),
do.stride(2),
delta.stride(0),
delta.stride(1),
BLOCK_SIZE_O=BLOCK_SIZE_O,
BLOCK_SIZE_D=BLOCK_SIZE_D,
num_warps=num_warps,
num_stages=num_stages,
)
# compute dk dv
dk = torch.zeros(
num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
)
dv = torch.zeros(
num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
)
batch_size = cu_seqlens_q.shape[0] - 1
grid = lambda META: (
batch_size,
num_q_heads,
triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
)
BLOCK_SIZE_Q = 64
BLOCK_SIZE_K = 128
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)
backward_dkdv[grid](
q,
k,
v,
lse,
delta,
do,
dk,
dv,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
num_k_heads,
num_share_q_heads,
head_dim,
sm_scale,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
lse.stride(0),
lse.stride(1),
delta.stride(0),
delta.stride(1),
do.stride(0),
do.stride(1),
do.stride(2),
dk.stride(0),
dk.stride(1),
dk.stride(2),
dk.stride(3),
dv.stride(0),
dv.stride(1),
dv.stride(2),
dv.stride(3),
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_D=BLOCK_SIZE_D,
num_warps=num_warps,
num_stages=num_stages,
)
dk = dk.sum(0)
dv = dv.sum(0)
# compute dq
dq = torch.zeros_like(q)
grid = lambda META: (
batch_size,
num_q_heads,
triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
)
BLOCK_SIZE_Q = 128
BLOCK_SIZE_K = 64
num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
backward_dq[grid](
q,
k,
v,
lse,
delta,
do,
dq,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
num_k_heads,
num_share_q_heads,
head_dim,
sm_scale,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
lse.stride(0),
lse.stride(1),
delta.stride(0),
delta.stride(1),
do.stride(0),
do.stride(1),
do.stride(2),
dq.stride(0),
dq.stride(1),
dq.stride(2),
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_D=BLOCK_SIZE_D,
num_warps=num_warps,
num_stages=num_stages,
)
return dq, dk, dv
class CompressedAttention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kernel_size: int,
kernel_stride: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
sm_scale=None,
):
# dtype check
assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
assert q.dtype == k.dtype and k.dtype == v.dtype
assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
# softmax scale
if sm_scale is None:
sm_scale = 1 / math.sqrt(q.shape[-1])
o, lse = _compressed_attention_fwd(
q,
k,
v,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
sm_scale,
)
ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)
ctx.sm_scale = sm_scale
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.kernel_size = kernel_size
ctx.kernel_stride = kernel_stride
return o, lse
@staticmethod
def backward(ctx, do: torch.Tensor, *args) -> Any:
q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
max_seqlen_q = ctx.max_seqlen_q
max_seqlen_k = ctx.max_seqlen_k
sm_scale = ctx.sm_scale
kernel_size = ctx.kernel_size
kernel_stride = ctx.kernel_stride
dq, dk, dv = _compressed_attention_bwd(
o,
do,
lse,
q,
k,
v,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
sm_scale,
)
return dq, dk, dv, None, None, None, None, None, None, None
@triton.jit
def score_kernel(
q_ptr,
k_ptr,
lse_ptr,
s_ptr,
kernel_size,
kernel_stride,
# seqlens
cu_seqlens_q,
cu_seqlens_k,
# shape
NUM_KV_HEADS,
NUM_SHARE_Q_HEADS,
HEAD_DIM,
# sm_scale
sm_scale,
# stride
stride_qn,
stride_qh,
stride_qd,
stride_kn,
stride_kh,
stride_kd,
stride_lh,
stride_ln,
stride_sh,
stride_sq,
stride_sk,
# META parameters
BLOCK_SIZE_Q: tl.constexpr, # q block size
BLOCK_SIZE_K: tl.constexpr, # k block size
BLOCK_SIZE_D: tl.constexpr,
):
qk_scale = sm_scale * 1.44269504
# get batch id and head id
pid_bkh = tl.program_id(0)
pid_b = pid_bkh // NUM_KV_HEADS
pid_kh = pid_bkh % NUM_KV_HEADS
pid_q = tl.program_id(1)
pid_k = tl.program_id(2)
# get q k start and len after rmpad
q_start = tl.load(cu_seqlens_q + pid_b)
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
k_start = tl.load(cu_seqlens_k + pid_b)
k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:
return
# init k pointer and load k
k_ptrs = tl.make_block_ptr(
base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
shape=(HEAD_DIM, k_len),
strides=(stride_kd, stride_kn),
offsets=(0, pid_k * BLOCK_SIZE_K),
block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
order=(0, 1),
)
k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
# offsets
off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :]
# init score
s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
# loop over gqa heads
for h in range(NUM_SHARE_Q_HEADS):
pid_h = pid_kh * NUM_SHARE_Q_HEADS + h
q_ptrs = tl.make_block_ptr(
base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
shape=(q_len, HEAD_DIM),
strides=(stride_qn, stride_qd),
offsets=(pid_q * BLOCK_SIZE_Q, 0),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),
order=(1, 0),
)
lse_ptrs = tl.make_block_ptr(
base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
shape=(q_len, 1),
strides=(stride_ln, stride_lh),
offsets=(pid_q * BLOCK_SIZE_Q, 0),
block_shape=(BLOCK_SIZE_Q, 1),
order=(0, 1),
)
# load q and lse
q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
# compute qk
qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
qk += tl.dot(q, k) * qk_scale
# compute score
s += tl.where(causal_mask, tl.exp2(qk - lse), 0)
# save output
s_ptrs = tl.make_block_ptr(
base=s_ptr + pid_kh * stride_sh + q_start * stride_sq,
shape=(q_len, k_len),
strides=(stride_sq, stride_sk),
offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K),
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K),
order=(1, 0),
)
tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1))
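# --- Illustrative sketch, not used by the kernels in this file ---
# The kernels above share one causal rule: a query at position q may attend to
# compressed key j only if q >= j * kernel_stride + kernel_size - 1 (see
# causal_mask in score_kernel). The hypothetical helper below (an assumption
# for illustration, not part of the library API) counts how many compressed
# keys that rule leaves visible to one query; it is the per-query form of the
# `hi` bound used in forward_kernel and backward_dq.
def _sketch_num_visible_keys(q_pos, kernel_size, kernel_stride, k_len):
    """Number of compressed keys j in [0, k_len) visible to query q_pos."""
    if q_pos < kernel_size - 1:
        return 0  # these queries are the ones the kernels skip entirely
    return min(k_len, (q_pos - kernel_size + 1) // kernel_stride + 1)


# With kernel_size=32 and kernel_stride=16, queries 0..30 see no keys,
# queries 31..46 see one key, and query 47 is the first to see two.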
def _get_attention_score(
q: torch.Tensor, # [total_query_len, num_q_heads, head_dim]
k: torch.Tensor, # [total_key_len, num_k_heads, head_dim]
lse: torch.Tensor, # [num_q_heads, total_query_len]
kernel_size: int,
kernel_stride: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
sm_scale: float,
) -> torch.Tensor:
# dtype check
assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
assert q.dtype == k.dtype
assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
assert (
lse.dtype == torch.float32
) # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale)))
# shape
q_len, num_q_heads, head_dim = q.shape
k_len, num_k_heads, head_dim = k.shape
batch_size = cu_seqlens_q.shape[0] - 1
assert q_len > k_len
if sm_scale is None:
sm_scale = 1 / math.sqrt(head_dim)
# gqa
assert num_q_heads % num_k_heads == 0
num_share_q_heads = num_q_heads // num_k_heads
# init score
score = torch.zeros(
num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device
)
# launch kernel
grid = lambda META: (
batch_size * num_k_heads,
triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]),
triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]),
)
BLOCK_SIZE_Q = 128
BLOCK_SIZE_K = 128
BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
score_kernel[grid](
q,
k,
lse,
score,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
num_k_heads,
num_share_q_heads,
head_dim,
sm_scale,
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
lse.stride(0),
lse.stride(1),
score.stride(0),
score.stride(1),
score.stride(2),
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_D=BLOCK_SIZE_D,
num_warps=8,
num_stages=3,
)
return score
@triton.jit
def _transform_score_kernel(
s_ptr, # score, shape: [num_heads, q_len, k_len]
bs_ptr, # block wise score: [num_heads, q_len, num_k_block]
offs,
cu_seqlens_q,
# shape
num_heads,
num_offs,
max_k_len,
max_blocks,
pad_len,
# kernel & block size
block_size,
block_stride, # block_size // kernel_stride
init_blocks,
local_blocks,
# stride
stride_sh,
stride_sq,
stride_sk,
stride_bsh,
stride_bsq,
stride_bsk,
BLOCK_SIZE_Q: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_O: tl.constexpr,
):
pid_bh = tl.program_id(0)
pid_b = pid_bh // num_heads
pid_h = pid_bh % num_heads
pid_q = tl.program_id(1)
pid_k = tl.program_id(2)
q_start = tl.load(cu_seqlens_q + pid_b)
q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
k_start = pid_k * BLOCK_SIZE_K
if pid_q * BLOCK_SIZE_Q >= q_len:
return
# load weight
off_o = tl.arange(0, BLOCK_SIZE_O)
w = tl.load(offs + off_o, mask=off_o < num_offs, other=0)
# load score
off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len
off_k = off_k[None, :] + off_o[:, None]
s_ptrs = (
s_ptr
+ q_start * stride_sq
+ pid_h * stride_sh
+ off_q[:, None, None] * stride_sq
+ off_k[None, :, :] * stride_sk
)
# weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK]
s = tl.load(
s_ptrs,
mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len),
other=0,
)
s = s * w[None, :, None]
s = tl.sum(s, axis=1)
# init mask and local mask
off_bq = off_q // block_size
off_bk = k_start + tl.arange(0, BLOCK_SIZE_K)
s = tl.where(
(
(off_bq[:, None] >= off_bk[None, :]) # causal mask
& (off_bq[:, None] < off_bk[None, :] + local_blocks) # local window
)
| (off_bk[None, :] < init_blocks), # init window
float("inf"),
s,
)
# store block wise score
bs_ptrs = (
bs_ptr
+ q_start * stride_bsq
+ pid_h * stride_bsh
+ off_q[:, None] * stride_bsq
+ off_bk[None, :] * stride_bsk
)
tl.store(
bs_ptrs,
s,
mask=(off_q < q_len)[:, None] & (off_bk < max_blocks)[None, :],
)
def transform_score(
score: torch.Tensor,
kernel_size: int,
kernel_stride: int,
block_size: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
init_blocks: int = 1,
local_blocks: int = 2,
) -> torch.Tensor:
num_k_heads, total_query_len, max_key_len = score.shape
batch_size = cu_seqlens_q.shape[0] - 1
pad_len = kernel_size // kernel_stride - 1
max_blocks = math.ceil(max_seqlen_q / block_size)
block_score = torch.zeros(
num_k_heads,
total_query_len,
max_blocks,
dtype=torch.float32,
device=score.device,
)
offs = (
torch.arange(kernel_size // kernel_stride, device=score.device)[:, None]
+ torch.arange(block_size // kernel_stride, device=score.device)[None, :]
).view(-1)
offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max())
num_offs = int(offs.shape[0])
BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks))
BLOCK_SIZE_O = triton.next_power_of_2(num_offs)
BLOCK_SIZE_Q = 8
grid = (
num_k_heads * batch_size,
triton.cdiv(total_query_len, BLOCK_SIZE_Q),
triton.cdiv(max_blocks, BLOCK_SIZE_K),
)
_transform_score_kernel[grid](
score,
block_score,
offs,
cu_seqlens_q,
num_k_heads,
offs.shape[0],
max_key_len,
max_blocks,
pad_len,
block_size,
block_size // kernel_stride,
init_blocks,
local_blocks,
score.stride(0),
score.stride(1),
score.stride(2),
block_score.stride(0),
block_score.stride(1),
block_score.stride(2),
BLOCK_SIZE_Q=BLOCK_SIZE_Q,
BLOCK_SIZE_K=BLOCK_SIZE_K,
BLOCK_SIZE_O=BLOCK_SIZE_O,
num_warps=8,
num_stages=3,
)
return block_score
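# --- Illustrative sketch, not used by the code in this file ---
# transform_score builds `offs` as a histogram of all pairwise sums i + j with
# i < kernel_size // kernel_stride and j < block_size // kernel_stride; the
# kernel then uses those counts as the weights `w` of the weighted sum over
# neighbouring compressed-key scores. The hypothetical helper below (name is an
# assumption for illustration) reproduces the counts in pure Python.
from collections import Counter


def _sketch_block_weights(kernel_size, kernel_stride, block_size):
    """Return the list of weights, indexed by offset 0..max(i + j)."""
    sums = Counter(
        i + j
        for i in range(kernel_size // kernel_stride)
        for j in range(block_size // kernel_stride)
    )
    return [sums[s] for s in range(max(sums) + 1)]


# For kernel_size=32, kernel_stride=16, block_size=64 the weights are
# [1, 2, 2, 2, 1]: five neighbouring compressed scores feed each block score.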
def compressed_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kernel_size: int,
kernel_stride: int,
block_size: int,
topk: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int = None,
max_seqlen_k: int = None,
sm_scale: float = None,
init_blocks: int = 1,
local_blocks: int = 2,
parallel_topk_compute: Union[str, bool] = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.
Args:
q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
kernel_size (int): kernel size in compress_key_value
kernel_stride (int): stride of compress_key_value
block_size (int): key value block size for topk sparse attention.
topk (int): number of blocks for each query.
cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.
cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_varlen_func.
max_seqlen_q (int, optional): max q len of the batch. Defaults to None, in which case it is computed as (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item().
max_seqlen_k (int, optional): max k len of the batch. Defaults to None, in which case it is computed as (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item().
sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
parallel_topk_compute (Union[str, bool], optional): Whether to compute the topk indices for all kv heads in one pass. Set it to False only when the sequence is very long; this works around a known bug that will be fixed later.
Defaults to "auto", which uses False when the total query length (cu_seqlens_q[-1]) is greater than 32768 and True otherwise.
Returns:
Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention; topk_idx is None when topk <= 0.
"""
if max_seqlen_q is None:
max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
if max_seqlen_k is None:
max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
attn_output, lse = CompressedAttention.apply(
q,
k,
v,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
sm_scale,
)
# do not select topk index
if topk <= 0:
warnings.warn("topk <= 0, returned topk_idx will be None")
return attn_output, None
assert topk >= init_blocks + local_blocks
with torch.no_grad():
num_k_heads, num_q_heads = k.shape[1], q.shape[1]
num_shared_q_heads = num_q_heads // num_k_heads
batch_size = cu_seqlens_q.shape[0] - 1
q_idx = torch.cat(
[
torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device)
for i in range(batch_size)
],
dim=0,
)
q_idx = q_idx // block_size
# whether to use parallel version
if parallel_topk_compute == "auto":
parallel_topk_compute = cu_seqlens_q[-1] <= 32768
# parallel version
if parallel_topk_compute:
# recompute score
score = _get_attention_score(
q,
k,
lse,
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
sm_scale,
)
# transform score to block-wise score
score = transform_score(
score,
kernel_size,
kernel_stride,
block_size,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
init_blocks,
local_blocks,
)
# get topk
topk = min(topk, score.shape[-1])
topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
topk_idx[topk_idx > q_idx[None, :, None]] = -1
topk_idx = topk_idx.to(torch.int32)
# non-parallel version: process one kv head at a time to avoid a known bug when the sequence is very long
# FIXME: needs a proper fix later
else:
topk_idx_list = []
for h in range(num_k_heads):
# recompute score
score = _get_attention_score(
q[:, h * num_shared_q_heads : (h + 1) * num_shared_q_heads],
k[:, h : h + 1],
lse[h * num_shared_q_heads : (h + 1) * num_shared_q_heads],
kernel_size,
kernel_stride,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
sm_scale,
)
# transform score to block-wise score
score = transform_score(
score,
kernel_size,
kernel_stride,
block_size,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
init_blocks,
local_blocks,
)
# get topk
topk = min(topk, score.shape[-1])
topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
topk_idx[topk_idx > q_idx[None, :, None]] = -1
topk_idx = topk_idx.to(torch.int32)
topk_idx_list.append(topk_idx)
topk_idx = torch.cat(topk_idx_list, dim=0)
return attn_output, topk_idx
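# --- Illustrative sketch, not used by the code in this file ---
# After transform_score, compressed_attention takes the topk block scores per
# query, sorts the selected block indices in ascending order, and replaces any
# index beyond the query's own block with -1. The hypothetical helper below
# (an assumption for illustration) shows that post-processing for one query in
# pure Python. Ties are broken by lowest index here, which may differ from
# torch.topk.
def _sketch_select_blocks(block_scores, topk, q_block):
    """Return sorted top-k block indices, with non-causal picks set to -1."""
    topk = min(topk, len(block_scores))
    order = sorted(
        range(len(block_scores)), key=lambda b: block_scores[b], reverse=True
    )
    picked = sorted(order[:topk])
    return [b if b <= q_block else -1 for b in picked]


# Init and local blocks carry a score of +inf in transform_score, so they are
# always selected first; blocks after the query's own block end up as -1.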
================================================
FILE: native_sparse_attention/ops/triton/flash_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Optional

import torch
import triton
import triton.language as tl

from native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu

IS_HOPPER_GPU = is_hopper_gpu()


@triton.jit
def forward_kernel(
    q_ptr,  # Q: n x h x d
    k_ptr,  # K: n x h x d
    v_ptr,  # V: n x h x d
    o_ptr,  # O: n x h x d
    lse_ptr,  # LSE: h x n
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    qk_head_dim,
    v_head_dim,
    # sm_scale
    sm_scale,
    # causal
    causal,
    # gqa
    gqa_interleave,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_on,
    stride_oh,
    stride_od,
    stride_lh,
    stride_ln,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
):
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_q = tl.program_id(2)
    if gqa_interleave:
        pid_kh = pid_h % NUM_KV_HEADS
    else:
        pid_kh = pid_h // NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if BLOCK_SIZE_Q * pid_q >= q_len:
        return
    # init qkv pointer
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, qk_head_dim),
        strides=(stride_qn, stride_qd),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(qk_head_dim, k_len),
        strides=(stride_kd, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_K),
        order=(0, 1),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, v_head_dim),
        strides=(stride_vn, stride_vd),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # load q
    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init statistics
    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q
    off_k = tl.arange(0, BLOCK_SIZE_K)
    m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
    lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32)
    acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_VD), 0, dtype=tl.float32)
    # full attention or causal attention
    lo = 0
    if causal:
        hi = min(k_len, (pid_q + 1) * BLOCK_SIZE_Q)
    else:
        hi = k_len
    for i in range(lo, hi, BLOCK_SIZE_K):
        i = tl.multiple_of(i, BLOCK_SIZE_K)
        # load k
        k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero")
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        if causal:
            qk += tl.where(off_q[:, None] >= (i + off_k)[None, :], 0, float("-inf"))
        else:
            qk += tl.where((off_k < k_len - i)[None, :], 0, float("-inf"))
        qk += tl.dot(q, k) * qk_scale
        # compute m_ij and l_ij
        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))
        p = tl.math.exp2(qk - m_ij[:, None])
        l_ij = tl.sum(p, axis=1)
        # scale acc_o
        acc_o_scale = tl.math.exp2(m_i - m_ij)
        acc_o = acc_o * acc_o_scale[:, None]
        # load v and update acc_o
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # update statistics
        m_i = m_ij
        lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij)
        # update ptrs
        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))
        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))
    # final scale
    acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None]
    # save output
    o_ptrs = tl.make_block_ptr(
        base=o_ptr + q_start * stride_on + pid_h * stride_oh,
        shape=(q_len, v_head_dim),
        strides=(stride_on, stride_od),
        offsets=(pid_q * BLOCK_SIZE_Q, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))
    # save lse
    l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln
    tl.store(l_ptrs, lse_i, mask=off_q < q_len)
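
# Editorial sketch (not part of the repository): the running-max and
# log-sum-exp update that forward_kernel applies per query row, written in
# plain Python for one row with a scalar value per key. The function names are
# hypothetical. The base-2 trick (scale the scores by log2(e), then use
# exp2/log2) mirrors `qk_scale = sm_scale * 1.44269504` in the kernel; masking
# and the matrix products are left out.

```python
import math

LOG2_E = 1.4426950408889634  # log2(e); the kernel uses the rounded 1.44269504


def streaming_attention_row(scores, values, block_size, sm_scale=1.0):
    """Attention output for one query row, consuming keys block by block."""
    qk_scale = sm_scale * LOG2_E  # work in base 2
    m_i = float("-inf")    # running maximum of the scaled scores
    lse_i = float("-inf")  # running log2-sum-exp2 of the scaled scores
    acc = 0.0              # un-normalised output, relative to the running max
    for start in range(0, len(scores), block_size):
        s = [x * qk_scale for x in scores[start:start + block_size]]
        v = values[start:start + block_size]
        m_ij = max(m_i, max(s))
        p = [2.0 ** (x - m_ij) for x in s]
        l_ij = sum(p)
        # rescale what was accumulated under the old maximum, then add the block
        acc = acc * 2.0 ** (m_i - m_ij) + sum(pi * vi for pi, vi in zip(p, v))
        lse_i = m_ij + math.log2(2.0 ** (lse_i - m_ij) + l_ij)
        m_i = m_ij
    # acc is relative to m_i; the total mass is 2 ** (lse_i - m_i)
    return acc * 2.0 ** (m_i - lse_i), lse_i


def reference_attention_row(scores, values, sm_scale=1.0):
    """Direct softmax-weighted average, for comparison."""
    w = [math.exp(x * sm_scale) for x in scores]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)


scores = [0.5, -1.0, 2.0, 0.3, 1.7]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
out, lse2 = streaming_attention_row(scores, values, block_size=2)
ref = reference_attention_row(scores, values)
```

# The streamed result should agree with the direct softmax up to rounding, and
# lse2 / LOG2_E should equal the natural-log sum-exp of the scores.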


@triton.jit
def backward_sum_o_do(
    o_ptr,  # O: n x h x d
    do_ptr,  # dO: n x h x d
    delta_ptr,  # D: h x n
    o_len,
    HEAD_DIM,
    stride_on,
    stride_oh,
    stride_od,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dh,
    stride_dn,
    BLOCK_SIZE_O: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    pid_n = tl.program_id(0)
    pid_h = tl.program_id(1)
    off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)
    off_d = tl.arange(0, BLOCK_SIZE_D)
    o = tl.load(
        o_ptr
        + off_n[:, None] * stride_on
        + pid_h * stride_oh
        + off_d[None, :] * stride_od,
        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    do = tl.load(
        do_ptr
        + off_n[:, None] * stride_don
        + pid_h * stride_doh
        + off_d[None, :] * stride_dod,
        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),
        other=0,
    ).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    tl.store(
        delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len
    )
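
# Editorial sketch (not part of the repository): why a per-row sum of O * dO is
# worth precomputing. In the usual FlashAttention-style backward pass, which
# this file appears to follow, the softmax gradient needs sum_j P[j] * dP[j]
# for each query row, and because O = P @ V that sum equals O . dO. The helper
# names below are hypothetical; the example uses one query row, three keys and
# head dimension 2.

```python
import math


def softmax(row):
    m = max(row)
    e = [math.exp(x - m) for x in row]
    z = sum(e)
    return [x / z for x in e]


def rowsum_o_do(o, do):
    """delta[i] = sum_d o[i][d] * do[i][d]: one scalar per query row."""
    return [
        sum(a * b for a, b in zip(o_row, do_row)) for o_row, do_row in zip(o, do)
    ]


p = softmax([0.2, -0.5, 1.0])              # attention probabilities
v = [[1.0, 0.0], [0.5, 2.0], [-1.0, 1.0]]  # V rows
do = [[0.3, -0.7]]                         # upstream gradient dO
o = [[sum(p[j] * v[j][d] for j in range(3)) for d in range(2)]]  # O = P @ V

delta = rowsum_o_do(o, do)[0]
# dP[j] = dO . V[j]; the softmax backward needs sum_j P[j] * dP[j]
dp = [sum(do[0][d] * v[j][d] for d in range(2)) for j in range(3)]
delta_from_p = sum(p[j] * dp[j] for j in range(3))
```

# Both routes give the same scalar, so the backward kernels can read delta from
# memory instead of recomputing it from P.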


@triton.jit
def backward_dkdv(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dk_ptr,  # DK: sh x n x kh x d
    dv_ptr,  # DV: sh x n x kh x d
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    qk_head_dim,
    v_head_dim,
    # sm_scale
    sm_scale,
    # causal
    causal,
    # gqa
    gqa_interleave,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dks,
    stride_dkn,
    stride_dkh,
    stride_dkd,
    stride_dvs,
    stride_dvn,
    stride_dvh,
    stride_dvd,
    # META parameters
    BLOCK_SIZE_Q: tl.constexpr,  # q block size
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
):
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    # the q-head -> kv-head mapping must match the one used in forward_kernel
    if gqa_interleave:
        pid_kh = pid_h % NUM_KV_HEADS
        pid_sh = pid_h // NUM_KV_HEADS
    else:
        pid_kh = pid_h // NUM_SHARE_Q_HEADS
        pid_sh = pid_h % NUM_SHARE_Q_HEADS
    pid_k = tl.program_id(2)
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    if BLOCK_SIZE_K * pid_k >= k_len:
        return
    # init pointers
    k_ptrs = tl.make_block_ptr(
        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
        shape=(k_len, qk_head_dim),
        strides=(stride_kn, stride_kd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    dk_ptrs = tl.make_block_ptr(
        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,
        shape=(k_len, qk_head_dim),
        strides=(stride_dkn, stride_dkd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
        shape=(k_len, v_head_dim),
        strides=(stride_vn, stride_vd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    dv_ptrs = tl.make_block_ptr(
        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,
        shape=(k_len, v_head_dim),
        strides=(stride_dvn, stride_dvd),
        offsets=(pid_k * BLOCK_SIZE_K, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # offsets
    off_q = tl.arange(0, BLOCK_SIZE_Q)
    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K
    # load k v and keep in SRAM
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    # init dk dv
    dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_KD), dtype=tl.float32)
    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_VD), dtype=tl.float32)
    # causal
    if causal:
        q_lo = pid_k * BLOCK_SIZE_K
    else:
        q_lo = 0
    q_ptrs = tl.make_block_ptr(
        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,
        shape=(q_len, qk_head_dim),
        strides=(stride_qn, stride_qd),
        offsets=(q_lo, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + q_start * stride_don + pid_h * stride_doh,
        shape=(q_len, v_head_dim),
        strides=(stride_don, stride_dod),
        offsets=(q_lo, 0),
        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    d_ptrs = tl.make_block_ptr(
        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,
        shape=(q_len, 1),
        strides=(stride_dn, stride_dh),
        offsets=(q_lo, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    lse_ptrs = tl.make_block_ptr(
        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,
        shape=(q_len, 1),
        strides=(stride_ln, stride_lh),
        offsets=(q_lo, 0),
        block_shape=(BLOCK_SIZE_Q, 1),
        order=(0, 1),
    )
    # loop for q blocks
    for i in range(q_lo, q_len, BLOCK_SIZE_Q):
        # load
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
        # compute qk
        if causal:
            qk = tl.where(
SYMBOL INDEX (157 symbols across 26 files)
FILE: native_sparse_attention/infer/inference_func.py
function compress_infer (line 27) | def compress_infer(
function compressed_attention_infer (line 90) | def compressed_attention_infer(
function topk_sparse_attention_infer (line 148) | def topk_sparse_attention_infer(
function sliding_window_attention_infer (line 175) | def sliding_window_attention_infer(
FILE: native_sparse_attention/infer/nsa_inference.py
function nsa_infer (line 24) | def nsa_infer(
FILE: native_sparse_attention/model/toy_llama.py
class ToyLlamaConfig (line 22) | class ToyLlamaConfig:
class InferenceConfig (line 47) | class InferenceConfig:
class RMSNorm (line 53) | class RMSNorm(nn.Module):
method __init__ (line 54) | def __init__(self, hidden_size: int, eps: float = 1e-6):
method forward (line 59) | def forward(self, hidden_states: torch.Tensor):
class FFN (line 67) | class FFN(nn.Module):
method __init__ (line 68) | def __init__(self, hidden_size: int, intermediate_size: int):
method forward (line 77) | def forward(self, x):
class ToyLlamaLayer (line 82) | class ToyLlamaLayer(nn.Module):
method __init__ (line 83) | def __init__(
method forward (line 112) | def forward(self, x, cu_seqlens):
method inference (line 118) | def inference(self, x, cu_seqlens, step, kv_cache):
class ToyLlama (line 124) | class ToyLlama(nn.Module):
method __init__ (line 125) | def __init__(
method forward (line 163) | def forward(
method inference (line 180) | def inference(
method generate (line 212) | def generate(
FILE: native_sparse_attention/model/toy_nsa_llama.py
class ToyNSALlamaConfig (line 22) | class ToyNSALlamaConfig:
class InferenceConfig (line 56) | class InferenceConfig:
class RMSNorm (line 62) | class RMSNorm(nn.Module):
method __init__ (line 63) | def __init__(self, hidden_size: int, eps: float = 1e-6):
method forward (line 68) | def forward(self, hidden_states: torch.Tensor):
class FFN (line 76) | class FFN(nn.Module):
method __init__ (line 77) | def __init__(self, hidden_size: int, intermediate_size: int):
method forward (line 86) | def forward(self, x):
class ToyNSALlamaLayer (line 91) | class ToyNSALlamaLayer(nn.Module):
method __init__ (line 92) | def __init__(
method forward (line 145) | def forward(self, x, cu_seqlens):
method inference (line 151) | def inference(self, x, cu_seqlens, step, kv_cache):
class ToyNSALlama (line 157) | class ToyNSALlama(nn.Module):
method __init__ (line 158) | def __init__(
method forward (line 206) | def forward(
method inference (line 223) | def inference(
method generate (line 258) | def generate(
FILE: native_sparse_attention/module/kv_cache.py
class KVCache (line 21) | class KVCache:
method __init__ (line 22) | def __init__(
method reset (line 52) | def reset(self):
method update_kv (line 56) | def update_kv(
method _update_kv_prefill (line 78) | def _update_kv_prefill(
method _update_kv_decode (line 104) | def _update_kv_decode(
class NSACache (line 125) | class NSACache:
method __init__ (line 144) | def __init__(
method reset (line 223) | def reset(self):
method prepare_compress (line 233) | def prepare_compress(
method _prepare_compress_prefill (line 245) | def _prepare_compress_prefill(
method _prepare_compress_decode (line 274) | def _prepare_compress_decode(
method update_kv (line 304) | def update_kv(
method _update_kv_prefill (line 338) | def _update_kv_prefill(
method _update_kv_decode (line 401) | def _update_kv_decode(
function _fill_kv_cache_kernel (line 453) | def _fill_kv_cache_kernel(
function _fill_kv_cache (line 519) | def _fill_kv_cache(
FILE: native_sparse_attention/module/native_sparse_attention.py
class NativeSparseAttention (line 45) | class NativeSparseAttention(torch.nn.Module):
method __init__ (line 65) | def __init__(
method init_params (line 140) | def init_params(self):
method forward (line 145) | def forward(
method inference (line 241) | def inference(
FILE: native_sparse_attention/module/rope.py
class RopeConfig (line 28) | class RopeConfig:
method __post_init__ (line 47) | def __post_init__(self):
function rotate_half (line 53) | def rotate_half(x):
class RotaryEmbedding (line 62) | class RotaryEmbedding(nn.Module):
method __init__ (line 73) | def __init__(
method _dynamic_frequency_update (line 94) | def _dynamic_frequency_update(self, position_ids, device):
method generate_cos_sin (line 121) | def generate_cos_sin(self, x: torch.Tensor, position_ids):
method generate_pos_embs (line 158) | def generate_pos_embs(
method forward (line 198) | def forward(self, x, cu_seqlens, step=0, stride=1):
FILE: native_sparse_attention/module/self_attention.py
class SelfAttention (line 22) | class SelfAttention(torch.nn.Module):
method __init__ (line 33) | def __init__(
method init_params (line 70) | def init_params(self):
method forward (line 74) | def forward(
method inference (line 113) | def inference(
FILE: native_sparse_attention/ops/torch/compress_key_value.py
function avgpool_compress_torch (line 19) | def avgpool_compress_torch(
function weightedpool_compress_torch (line 84) | def weightedpool_compress_torch(
function linear_compress_torch (line 156) | def linear_compress_torch(
FILE: native_sparse_attention/ops/torch/compressed_attention.py
function transform_score (line 21) | def transform_score(
function compressed_attention_torch (line 74) | def compressed_attention_torch(
FILE: native_sparse_attention/ops/torch/compressed_attention_decode.py
function transform_score (line 21) | def transform_score(
function compressed_attention_decode (line 65) | def compressed_attention_decode(
FILE: native_sparse_attention/ops/torch/topk_sparse_attention.py
function topk_sparse_attention_torch (line 19) | def topk_sparse_attention_torch(
FILE: native_sparse_attention/ops/triton/compressed_attention.py
function forward_kernel (line 28) | def forward_kernel(
function backward_sum_o_do (line 162) | def backward_sum_o_do(
function backward_dkdv (line 206) | def backward_dkdv(
function backward_dq (line 385) | def backward_dq(
function _compressed_attention_fwd (line 538) | def _compressed_attention_fwd(
function _compressed_attention_bwd (line 618) | def _compressed_attention_bwd(
class CompressedAttention (line 783) | class CompressedAttention(torch.autograd.Function):
method forward (line 785) | def forward(
method backward (line 826) | def backward(ctx, do: torch.Tensor, *args) -> Any:
function score_kernel (line 852) | def score_kernel(
function _get_attention_score (line 954) | def _get_attention_score(
function _transform_score_kernel (line 1030) | def _transform_score_kernel(
function transform_score (line 1116) | def transform_score(
function compressed_attention (line 1182) | def compressed_attention(
FILE: native_sparse_attention/ops/triton/flash_attention.py
function forward_kernel (line 27) | def forward_kernel(
function backward_sum_o_do (line 169) | def backward_sum_o_do(
function backward_dkdv (line 213) | def backward_dkdv(
function backward_dq (line 400) | def backward_dq(
function _flash_attention_fwd (line 563) | def _flash_attention_fwd(
function _flash_attention_bwd (line 645) | def _flash_attention_bwd(
class FlashAttention (line 831) | class FlashAttention(torch.autograd.Function):
method forward (line 833) | def forward(
method backward (line 870) | def backward(ctx, do: torch.Tensor, *args) -> Any:
function flash_attention_varlen (line 895) | def flash_attention_varlen(
FILE: native_sparse_attention/ops/triton/flash_attention_decode.py
function decode_kernel (line 23) | def decode_kernel(
function flash_attention_decode (line 142) | def flash_attention_decode(
function torch_attention_decode (line 220) | def torch_attention_decode(
FILE: native_sparse_attention/ops/triton/linear_compress.py
function linear_compress_fwd_kernel (line 27) | def linear_compress_fwd_kernel(
function linear_compress_bwd_kernel (line 140) | def linear_compress_bwd_kernel(
class LinearCompress (line 309) | class LinearCompress(torch.autograd.Function):
method forward (line 311) | def forward(
method backward (line 420) | def backward(ctx, dy: torch.Tensor, *args) -> Any:
function linear_compress (line 497) | def linear_compress(
FILE: native_sparse_attention/ops/triton/topk_sparse_attention.py
function forward_kernel (line 27) | def forward_kernel(
function backward_sum_o_do (line 177) | def backward_sum_o_do(
function count_kernel (line 221) | def count_kernel(
function count_query (line 267) | def count_query(
function pad_topk_idx_kernel (line 305) | def pad_topk_idx_kernel(
function save_topk_idx_kernel (line 351) | def save_topk_idx_kernel(
function reorder_topk_idx (line 404) | def reorder_topk_idx(
function backward_dkdv (line 481) | def backward_dkdv(
function backward_dq (line 659) | def backward_dq(
function _topk_sparse_attention_fwd (line 828) | def _topk_sparse_attention_fwd(
function _topk_sparse_attention_bwd (line 912) | def _topk_sparse_attention_bwd(
class TopkSparseAttention (line 1112) | class TopkSparseAttention(torch.autograd.Function):
method forward (line 1114) | def forward(
method backward (line 1156) | def backward(ctx, do: torch.Tensor, *args) -> Any:
function topk_sparse_attention (line 1182) | def topk_sparse_attention(
FILE: native_sparse_attention/ops/triton/topk_sparse_attention_decode.py
function forward_kernel (line 23) | def forward_kernel(
function topk_sparse_attention_decode (line 151) | def topk_sparse_attention_decode(
function torch_topk_sparse_attention_decode (line 240) | def torch_topk_sparse_attention_decode(
function generate_topk_idx_example (line 301) | def generate_topk_idx_example(
FILE: native_sparse_attention/ops/triton/utils.py
function is_hopper_gpu (line 17) | def is_hopper_gpu():
function get_compressed_seqlens (line 25) | def get_compressed_seqlens(
function get_num_warps_stages (line 40) | def get_num_warps_stages(head_dim, block_size, is_hopper_gpu):
FILE: native_sparse_attention/ops/triton/weighted_pool.py
function sliding_pool_fwd_kernel (line 24) | def sliding_pool_fwd_kernel(
function sliding_pool_dxdw_kernel (line 89) | def sliding_pool_dxdw_kernel(
class SlidingWindowWeightedPool (line 182) | class SlidingWindowWeightedPool(torch.autograd.Function):
method forward (line 184) | def forward(
method backward (line 245) | def backward(ctx, dy, _):
function weightedpool_compress (line 301) | def weightedpool_compress(
function avgpool_compress (line 333) | def avgpool_compress(
FILE: test/test_compress_key_value.py
function benchmark (line 81) | def benchmark(N, H, D, provider):
FILE: test/test_compressed_attention.py
function benchmark (line 171) | def benchmark(N, H, D, provider):
function benchmark (line 267) | def benchmark(N, H, D, provider):
FILE: test/test_flash_attention.py
function benchmark (line 124) | def benchmark(N, H, D, provider):
function benchmark (line 176) | def benchmark(N, H, D, provider):
FILE: test/test_linear_compress.py
function test_linear_compress (line 21) | def test_linear_compress(
function benchmark_fwdbwd (line 220) | def benchmark_fwdbwd(N, H, D, provider):
FILE: test/test_nsa_module.py
function benchmark (line 121) | def benchmark(N, provider):
function benchmark (line 153) | def benchmark(N, provider):
FILE: test/test_topk_sparse_attention.py
function generate_topk_idx_example (line 34) | def generate_topk_idx_example(
function benchmark (line 174) | def benchmark(N, H, D, K, provider):
function benchmark (line 245) | def benchmark(N, H, D, K, provider):