Repository: XunhaoLai/native-sparse-attention-triton Branch: main Commit: 9bea856c911e Files: 44 Total size: 359.0 KB Directory structure: gitextract_5vrtmrhk/ ├── .gitignore ├── LICENSE ├── README.md ├── install_dependency.sh ├── native_sparse_attention/ │ ├── __init__.py │ ├── infer/ │ │ ├── __init__.py │ │ ├── inference_func.py │ │ └── nsa_inference.py │ ├── model/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── toy_llama.py │ │ └── toy_nsa_llama.py │ ├── module/ │ │ ├── __init__.py │ │ ├── kv_cache.py │ │ ├── native_sparse_attention.py │ │ ├── rope.py │ │ └── self_attention.py │ └── ops/ │ ├── README.md │ ├── __init__.py │ ├── torch/ │ │ ├── __init__.py │ │ ├── compress_key_value.py │ │ ├── compressed_attention.py │ │ ├── compressed_attention_decode.py │ │ └── topk_sparse_attention.py │ └── triton/ │ ├── __init__.py │ ├── compressed_attention.py │ ├── flash_attention.py │ ├── flash_attention_decode.py │ ├── linear_compress.py │ ├── topk_sparse_attention.py │ ├── topk_sparse_attention_decode.py │ ├── utils.py │ └── weighted_pool.py ├── setup.py └── test/ ├── test_compress_key_value.py ├── test_compressed_attention.py ├── test_flash_attention.py ├── test_kv_cache.py ├── test_linear_compress.py ├── test_nsa_infer.py ├── test_nsa_model.py ├── test_nsa_module.py ├── test_rope.py └── test_topk_sparse_attention.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into 
it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # UV # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. #uv.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. 
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control .pdm.toml .pdm-python .pdm-build/ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ # PyPI configuration file .pypirc ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================
# Native Sparse Attention Triton
This repository implements the sparse attention mechanism introduced in the paper [Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention](https://arxiv.org/abs/2502.11089) and provides an efficient training implementation based on [Triton](https://github.com/triton-lang/triton). 🎉 We now support both training and inference for Native Sparse Attention (variable-length version, including prefilling, decoding, and KV cache management). We provide a toy model at `model.ToyNSALlama`, which supports a `forward` function for training and a `generate` function for inference. Feel free to try it out! ## Requirements Ensure the following dependencies are installed: - PyTorch >= 2.1.0 - triton >= 3.0.0 - einops >= 0.7.0 - flash_attn >= 2.6.3 ## Usage ### Notes 1. PyTorch implementations (`ops.torch`) are intended for debugging only. 2. For production use, prefer Triton operators (`ops.triton`). 3. All implementations use the varlen approach, similar to `flash_attn_varlen_func`. Please concatenate the inputs of a batch before use. 4. Only attention head dimensions up to 128 are supported for now. ### Install You can install `native_sparse_attention` using pip: ```shell pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git ``` ### Functions The `ops` module implements the functions required for native sparse attention. For detailed usage instructions, please see [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/ops#readme).
You can import those functions from the `ops` module: ```python import torch from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention # input example num_q_heads = 64 num_kv_heads = 4 head_dim = 128 kernel_size = 32 kernel_stride = 16 block_size = 64 topk = 16 cu_seqlens = torch.Tensor([0, 1024, 8192, 16384]).to(torch.int32).cuda() query = torch.randn(16384, num_q_heads, head_dim).to(torch.bfloat16).cuda() key = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda() value = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda() # weight example w = ( torch.randn(num_kv_heads, kernel_size * head_dim, head_dim) .to(torch.bfloat16) .cuda() ) pe = torch.randn(num_kv_heads, kernel_size, head_dim).to(torch.bfloat16).cuda() # 1. key value compression compressed_key, compressed_cu_seqlens = linear_compress( key, w, cu_seqlens, kernel_size, kernel_stride, pe ) compressed_value, _ = linear_compress( value, w, cu_seqlens, kernel_size, kernel_stride, None ) # 2. attention between query and compressed key value compressed_attn_output, topk_idx = compressed_attention( query, compressed_key, compressed_value, kernel_size, kernel_stride, block_size, topk, cu_seqlens, compressed_cu_seqlens, init_blocks=1, local_blocks=2, ) # 3. topk sparse attention sparse_attn_output = topk_sparse_attention( query, key, value, topk_idx, block_size, cu_seqlens, ) ``` ### Module The `module` directory also provides implementations based on `torch.nn.Module` for easy integration into models.
```python from native_sparse_attention.module import NativeSparseAttention, RopeConfig NSA_Layer = NativeSparseAttention( compress_type="linear", hidden_size=4096, num_q_heads=64, num_kv_heads=4, head_dim=128, kernel_size=32, kernel_stride=16, block_size=64, topk=8, init_blocks=1, local_blocks=2, window_size=512, rope_config=RopeConfig( max_position_embeddings=32768, head_dim=128, rope_theta=500000, rope_scaling={ "factor": 4.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", }, ), ) ``` ### Model We offer two simplified LLaMA models in the `model` directory, featuring self-attention and native sparse attention. For more details on how to use these models, please refer to [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/model#readme). ```python from native_sparse_attention.model import ToyNSALlamaConfig, InferenceConfig, ToyNSALlama config = ToyNSALlamaConfig( hidden_size=4096, intermediate_size=14336, num_hidden_layers=8, num_attention_heads=32, num_key_value_heads=2, head_dim=128, rope_theta=500000.0, rope_scaling={ "factor": 8.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", }, compress_type="weightedpool", kernel_size=32, kernel_stride=16, block_size=64, topk=8, init_blocks=1, local_blocks=2, window_size=512, ) inference_config = InferenceConfig( max_batch_size=4, max_length=8192, max_new_tokens=128, ) model = ToyNSALlama(config, inference_config).cuda().bfloat16() ``` ## Testing Some test scripts are available in the `test` folder and can be run directly for unit testing.
For example: ```bash python test/test_topk_sparse_attention.py python test/test_nsa_module.py python test/test_nsa_model.py ``` ### Benchmarks Here are the speed benchmarks conducted on a single NVIDIA A100 GPU or H100 GPU for the `topk_sparse_attention` function: A100 GPU speed benchmarks: ```sh ** forward with block size 64 **: N Flash Triton-Flash Triton-Top8 Triton-Top16 0 2048.0 0.414144 0.635648 0.633440 1.009184 1 4096.0 1.400304 2.267552 1.179808 1.916736 2 8192.0 5.223776 8.528160 2.266816 3.723168 3 16384.0 20.225697 32.745537 4.468128 7.359168 4 32768.0 79.587715 128.951065 8.517440 14.142848 5 65536.0 321.240479 511.652100 17.249599 30.991360 6 131072.0 1349.810425 2063.245605 36.400482 67.884544 ** backward with block size 64 **: N Flash Triton-Flash Triton-Top8 Triton-Top16 0 2048.0 1.315440 2.348560 1.941568 2.691040 1 4096.0 4.271584 8.553184 3.647744 5.032160 2 8192.0 15.323984 32.665440 5.650144 9.066112 3 16384.0 58.753281 127.675964 11.160832 17.113279 4 32768.0 227.770462 504.572693 21.723392 34.715614 5 65536.0 899.181274 2059.718506 44.517181 76.309441 6 131072.0 3587.918701 8530.726562 105.344734 182.970169 ``` H100 GPU benchmarks: ```sh ** forward with block size 64 **: N Flash Triton-Flash Triton-Top8 Triton-Top16 0 2048.0 0.259552 0.293888 0.584544 0.917664 1 4096.0 0.846848 1.029904 1.094976 1.745136 2 8192.0 3.043744 3.843392 2.128256 3.396880 3 16384.0 11.743568 14.791360 4.190528 6.704192 4 32768.0 45.968513 57.532478 7.614496 12.417440 5 65536.0 187.234375 228.093948 14.840048 24.511856 6 131072.0 810.890381 914.693970 29.470400 48.990192 ** backward with block size 64 **: N Flash Triton-Flash Triton-Top8 Triton-Top16 0 2048.0 0.798976 1.096096 1.117312 1.380016 1 4096.0 2.545680 3.826336 1.669760 2.214880 2 8192.0 9.029760 14.411633 2.772096 3.947456 3 16384.0 34.144016 58.945698 5.201344 7.538912 4 32768.0 135.718369 233.369247 9.968864 15.154192 5 65536.0 541.053894 929.337646 21.089870 33.818878 6 131072.0 2139.974854 3785.540527 
54.918144 93.750717 ``` Here comes another speed benchmark result for testing `compressed_attention` function on a single NVIDIA A100 GPU or H100 GPU: A100 GPU speed benchmarks: ```sh ** forward with kernel 32 and stride 16 **: N Flash Triton-Flash Compressed Compressed-wo-Score 0 2048.0 0.413664 0.635488 0.655024 0.170816 1 4096.0 1.396416 2.247648 1.132304 0.377152 2 8192.0 5.234656 8.526400 2.879200 0.977952 3 16384.0 19.988865 32.755199 9.426448 2.943024 4 32768.0 79.419907 128.955170 30.284096 9.901120 5 65536.0 321.590210 511.615509 112.260544 36.001602 6 131072.0 1346.996338 2069.837891 423.099518 136.820038 ** backward with kernel 32 and stride 16 **: N Flash Triton-Flash Compressed 0 2048.0 1.322560 2.352000 0.486784 1 4096.0 4.270832 8.552608 0.971392 2 8192.0 15.515680 32.671329 2.603744 3 16384.0 59.345055 128.377472 8.499456 4 32768.0 230.626144 506.581238 30.064833 5 65536.0 919.260498 2068.642578 113.466560 6 131072.0 3646.603760 8498.374023 439.623444 ``` H100 GPU speed benchmarks: ```sh ** forward with kernel 32 and stride 16 **: N Flash Triton-Flash Compressed Compressed-wo-Score 0 2048.0 0.259488 0.297152 0.485920 0.103232 1 4096.0 0.847376 1.030400 0.710208 0.217760 2 8192.0 3.044016 3.875840 1.607360 0.516016 3 16384.0 11.823104 14.829360 4.970272 1.440288 4 32768.0 46.204750 57.527809 15.004992 4.584736 5 65536.0 187.324249 227.909958 53.009087 16.134224 6 131072.0 810.707214 910.106873 191.245728 60.154270 ** backward with kernel 32 and stride 16 **: N Flash Triton-Flash Compressed 0 2048.0 0.797728 1.090640 0.283104 1 4096.0 2.547088 3.834592 0.550464 2 8192.0 9.021520 14.421088 1.249184 3 16384.0 34.159508 58.793377 3.743440 4 32768.0 136.844070 233.447708 12.640032 5 65536.0 537.559814 929.360229 46.054817 6 131072.0 2135.629883 3782.351562 175.587296 ``` All the speed benchmarks above were tested with 64 query heads, 4 key/value heads, and a head dimension of 128. ## Contributing Contributions are welcome! 
Please open an issue to discuss major changes. ## Contact For any questions or feedback, please feel free to contact laixunhao@pku.edu.cn. ## Citations ```bibtex @inproceedings{Yuan2025NativeSA, title = {Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention}, author = {Jingyang Yuan and Huazuo Gao and Damai Dai and Junyu Luo and Liang Zhao and Zhengyan Zhang and Zhenda Xie and Y. X. Wei and Lean Wang and Zhiping Xiao and Yuqing Wang and Chong Ruan and Ming Zhang and Wenfeng Liang and Wangding Zeng}, year = {2025}, url = {https://api.semanticscholar.org/CorpusID:276408911} } ``` ================================================ FILE: install_dependency.sh ================================================ pip3 install packaging -i https://pypi.org/simple pip3 install numpy==1.26.4 -i https://pypi.org/simple pip3 install torch==2.4.0 -i https://pypi.org/simple pip3 install triton==3.0.0 -i https://pypi.org/simple pip3 install transformers==4.44.0 -i https://pypi.org/simple pip3 install flash_attn==2.6.3 -i https://pypi.org/simple pip3 install matplotlib==3.9.4 -i https://pypi.org/simple pip3 install pandas==2.2.3 -i https://pypi.org/simple ================================================ FILE: native_sparse_attention/__init__.py ================================================ ================================================ FILE: native_sparse_attention/infer/__init__.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. from native_sparse_attention.infer.nsa_inference import nsa_infer __all__ = [ "nsa_infer", ] ================================================ FILE: native_sparse_attention/infer/inference_func.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from typing import Tuple, Callable, Optional from flash_attn import flash_attn_varlen_func from native_sparse_attention.ops import ( flash_attention_decode, compressed_attention, compressed_attention_decode, topk_sparse_attention, topk_sparse_attention_decode, ) from native_sparse_attention.ops.triton.utils import get_compressed_seqlens def compress_infer( cu_seqlens: torch.Tensor, step: int, key: torch.Tensor, value: torch.Tensor, cache, weight: Tuple[torch.Tensor, torch.Tensor], compress_func: Tuple[Callable, Callable], intra_block_pe: Optional[torch.Tensor], kernel_size: int, kernel_stride: int, ): if step == 0: key, compress_cu_seqlens = compress_func[0]( key, weight[0], cu_seqlens, kernel_size, kernel_stride, intra_block_pe, ) value, _ = compress_func[1]( value, weight[1], cu_seqlens, kernel_size, kernel_stride, ) else: batch_size = cu_seqlens.shape[0] - 1 aux_cu_seqlens = ( torch.arange(batch_size + 1, dtype=torch.int32).to(cu_seqlens.device) * kernel_size ) key, _ = compress_func[0]( cache.before_compress_kv_cache[0, 
:batch_size].view( batch_size * kernel_size, cache.num_kv_heads, cache.head_dim ), weight[0], aux_cu_seqlens, kernel_size, kernel_stride, intra_block_pe, ) value, _ = compress_func[1]( cache.before_compress_kv_cache[1, :batch_size].view( batch_size * kernel_size, cache.num_kv_heads, cache.head_dim ), weight[1], aux_cu_seqlens, kernel_size, kernel_stride, ) # return actual compress_cu_seqlens before this token compress_cu_seqlens = torch.zeros( batch_size + 1, dtype=torch.int32, device=key.device ) compress_cu_seqlens[1:] = torch.cumsum( cache.compress_kv_len[:batch_size], dim=0 ) return key, value, compress_cu_seqlens def compressed_attention_infer( cu_seqlens, step, query, key, value, cache, kernel_size, kernel_stride, topk, block_size, init_blocks, local_blocks, ): if step == 0: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] compress_seqlens, compress_cu_seqlens = get_compressed_seqlens( cu_seqlens, kernel_size, kernel_stride ) attn_output, topk_idx = compressed_attention( query, key, value, kernel_size, kernel_stride, block_size, topk, cu_seqlens, compress_cu_seqlens, seqlens.max().item(), compress_seqlens.max().item(), None, init_blocks, local_blocks, ) else: batch_size = cu_seqlens.shape[0] - 1 seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + step attn_output, topk_idx = compressed_attention_decode( query, cache.compress_kv_cache[ 0, :batch_size, : cache.compress_kv_len[:batch_size].max() ], cache.compress_kv_cache[ 1, :batch_size, : cache.compress_kv_len[:batch_size].max() ], seqlens, cache.compress_kv_len[:batch_size], kernel_size, kernel_stride, block_size, topk, init_blocks, local_blocks, ) return attn_output, topk_idx def topk_sparse_attention_infer( cu_seqlens, step, query, key, value, cache, topk_idx, block_size, ): if step == 0: attn_output = topk_sparse_attention( query, key, value, topk_idx, block_size, cu_seqlens ) else: batch_size = cu_seqlens.shape[0] - 1 attn_output = topk_sparse_attention_decode( query, cache.sparse_kv_cache[0, :batch_size], 
cache.sparse_kv_cache[1, :batch_size], topk_idx, block_size, cache.sparse_kv_len[:batch_size], ) return attn_output def sliding_window_attention_infer( cu_seqlens, step, query, key, value, cache, window_size ): if step == 0: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] attn_output = flash_attn_varlen_func( query, key, value, cu_seqlens, cu_seqlens, seqlens.max().item(), seqlens.max().item(), causal=True, window_size=(window_size, -1), ) else: batch_size = cu_seqlens.shape[0] - 1 attn_output = flash_attention_decode( query, cache.sliding_kv_cache[0, :batch_size], cache.sliding_kv_cache[1, :batch_size], torch.minimum( cache.sliding_kv_len, torch.zeros_like(cache.sliding_kv_len) + window_size, )[:batch_size], ) return attn_output ================================================ FILE: native_sparse_attention/infer/nsa_inference.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch from typing import Tuple, Callable, Optional from native_sparse_attention.infer.inference_func import ( compress_infer, compressed_attention_infer, topk_sparse_attention_infer, sliding_window_attention_infer, ) def nsa_infer( cu_seqlens: torch.Tensor, step: int, # qkv for three parts query: torch.Tensor, key: torch.Tensor, # prefill: [total_len, num_heads, head_dim], decode: [batch_size, num_heads, head_dim] value: torch.Tensor, gate_value: torch.Tensor, # prefill: [total_len, num_heads, 3], decode: [batch_size, num_heads, 3] # rope and kv cache rope, cache, # weight for nsa compress compress_weight: Tuple[ torch.Tensor, torch.Tensor ], # compress weight for key and value compress_func: Tuple[Callable, Callable], # compress function for key and value intra_block_pe: Optional[torch.Tensor], # nsa parameters kernel_size: int, kernel_stride: int, block_size: int, topk: int, init_blocks: int, local_blocks: int, window_size: int, ) -> torch.Tensor: """Inference function for native sparse attention. Support prefill and decode with kv cache. Args: cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. step (int): current inference step, step == 0 means prefill, step > 0 means decode step. 
query (torch.Tensor): for prefill, shape [total_len, num_q_heads, head_dim]; for decode, shape [batch_size, num_q_heads, head_dim] key (torch.Tensor): for prefill, shape [total_len, num_kv_heads, head_dim]; for decode, shape [batch_size, num_kv_heads, head_dim] value (torch.Tensor): for prefill, shape [total_len, num_kv_heads, head_dim]; for decode, shape [batch_size, num_kv_heads, head_dim] gate_value (torch.Tensor): for prefill, shape [total_len, num_heads, 3]; for decode, shape [batch_size, num_heads, 3] rope (RotaryEmbedding): rope module, see native_sparse_attention.module.rope.RotaryEmbedding for details cache (NSACache): kv cache, seed native_sparse_attention.module.kv_cache.NSACache for details compress_weight (Tuple[torch.Tensor, torch.Tensor]): compress weight of key and value respectively compress_func (Tuple[Callable, Callable]): compress functions for key and value respectively intra_block_pe (Optional[torch.Tensor]): intra-block positonal embedding for compression, set to None if don't use it kernel_size (int): kernel size of compression kernel_stride (int): kernel stride ofr compression block_size (int): block size of sparse attention topk (int): topk of sparse attention init_blocks (int): number of blocks at the begining of the sequence, these blocks are force to be computed in sparse attention local_blocks (int): number of blocks at the local window of each query, these blocks are force to be computed in sparse attention window_size (int): window size for sliding window attention Returns: torch.Tensor: native sparse attention output, same shape as input query """ # reset kv cache at the begining of prefilling if step == 0: cache.reset() # prepare for compress cache.prepare_compress(cu_seqlens, step, key, value) # compressed key and value before rope compress_key, compress_value, compress_cu_seqlens = compress_infer( cu_seqlens, step, key, value, cache, compress_weight, compress_func, intra_block_pe, kernel_size, kernel_stride, ) # do rope query = 
rope(query, cu_seqlens, step) if step == 0: compress_key = rope( compress_key, compress_cu_seqlens, step, stride=cache.kernel_stride ) else: compress_key = rope( compress_key, compress_cu_seqlens, 1, stride=cache.kernel_stride ) key = rope(key, cu_seqlens, step) # update kv cache cache.update_kv( cu_seqlens, step, compress_key, compress_value, key, value, key, value, ) # compressed attention compress_attn_output, topk_idx = compressed_attention_infer( cu_seqlens, step, query, compress_key, compress_value, cache, kernel_size, kernel_stride, topk, block_size, init_blocks, local_blocks, ) # topk sparse attention sparse_attn_output = topk_sparse_attention_infer( cu_seqlens, step, query, key, value, cache, topk_idx, block_size, ) # sliding window attention sliding_attn_output = sliding_window_attention_infer( cu_seqlens, step, query, key, value, cache, window_size ) # combine 3 attn output attn_output = ( gate_value[..., 0, None] * compress_attn_output + gate_value[..., 1, None] * sparse_attn_output + gate_value[..., 2, None] * sliding_attn_output ) return attn_output ================================================ FILE: native_sparse_attention/model/README.md ================================================ # Guide for the ToyNSALlama Model The `ToyNSALlama` model is a custom implementation of a Llama-like transformer architecture featuring a Native Sparse Attention (NSA) module. This guide explains how to integrate the NSA module into your own model. ## Overview The `ToyNSALlama` model consists of: - **Configuration**: Defined by `ToyNSALlamaConfig` (model structure parameters) and `InferenceConfig` (inference-specific parameters). - **Components**: An embedding layer, multiple NativeSparseAttention modules, Feed-Forward Network (FFN) modules, normalization layers, and a language model head. ## Step-by-Step Instructions ### 1. 
Import Necessary Modules

```python
import torch
import torch.nn as nn
from native_sparse_attention.model import ToyNSALlama, ToyNSALlamaConfig, InferenceConfig
```

### 2. Define Configurations

Create instances of `ToyNSALlamaConfig` and `InferenceConfig` to set model and inference parameters.

#### Model Configuration

The model configuration aligns with the Transformers Llama model configuration. Adjust the following parameters to control the NSA module’s sparsity:

- `compress_type`: Compression method for keys/values. Supported options: `avgpool`, `weightedpool`, `linear`.
- `kernel_size` & `kernel_stride`: `kernel_size` determines how many tokens are compressed into one; `kernel_stride` sets the stride of the sliding compression window (`kernel_size` must be divisible by `kernel_stride`, e.g. 32 and 16 as below).
- `block_size`: Block size for sparse attention (recommended: 64 or 128).
- `topk`, `init_blocks`, `local_blocks`: `topk` specifies the number of blocks selected in sparse attention; `init_blocks` and `local_blocks` define the number of initial and local blocks that must be selected.
- `window_size`: Size of the sliding window for attention.

Example:

```python
config = ToyNSALlamaConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_hidden_layers=8,
    num_attention_heads=32,
    num_key_value_heads=2,
    head_dim=128,
    vocab_size=128288,
    max_position_embeddings=131072,
    rope_theta=500000.0,
    rope_scaling={
        "factor": 8.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    compress_type="weightedpool",
    kernel_size=32,
    kernel_stride=16,
    block_size=64,
    topk=8,
    init_blocks=1,
    local_blocks=2,
    window_size=512,
)
```

#### Inference Configuration

This configuration applies during inference, initializing the Key-Value (KV) Cache based on these settings. The full-length KV cache takes `max_batch_size × max_length × num_kv_heads × head_dim × num_layers × 2 × 2` bytes. One factor of 2 covers keys plus values, and the other is the 2-byte size of `bfloat16`. The NSA cache also keeps smaller buffers for compressed tokens and the sliding window. Currently, only greedy decoding is supported as an example.
Example: ```python inference_config = InferenceConfig( max_batch_size=4, max_length=8192, max_new_tokens=128, ) ``` ### 3. Initialize the Model Instantiate the model and move it to the GPU with the appropriate data type (currently, only `bfloat16` is supported). ```python model = ToyNSALlama(config, inference_config).cuda().to(torch.bfloat16) ``` ### 4. Forward & Generate The model supports two methods: - **`forward`**: Accepts `input_ids` and `cu_seqlens`, returning final logits after the language model head. Use this for training or evaluation. - **`generate`**: Accepts `input_ids` and `cu_seqlens`, generating output tokens via greedy sampling. This demonstrates KV cache usage for token generation (pre-filling and decoding). Example: ```python # Example input batch_size = 4 seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda") cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) input_ids = torch.randint(0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda") print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n") # Example forward logits = model(input_ids, cu_seqlens) print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n") # Example generate output_tokens = model.generate(input_ids, cu_seqlens) print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n") ``` ## Toy Llama Model with Self-Attention A simpler toy model with the Llama structure is available in `native_sparse_attention/model/toy_llama.py`. Compare `ToyLlama` and `ToyNSALlama` to see how to adapt a self-attention model into an NSA-based model. The primary difference lies in replacing the `SelfAttention` module with the `NativeSparseAttention` module, along with updates to the KV cache and inference function. These changes are straightforward and easy to implement. 
================================================ FILE: native_sparse_attention/model/__init__.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from native_sparse_attention.model.toy_llama import ( ToyLlamaConfig, InferenceConfig, ToyLlama, ) from native_sparse_attention.model.toy_nsa_llama import ( ToyNSALlamaConfig, InferenceConfig, ToyNSALlama, ) __all__ = [ "ToyLlamaConfig", "ToyNSALlamaConfig", "InferenceConfig", "ToyLlama", "ToyNSALlama", ] ================================================ FILE: native_sparse_attention/model/toy_llama.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Optional import torch import torch.nn as nn from dataclasses import dataclass, field from native_sparse_attention.module import SelfAttention, RopeConfig, KVCache @dataclass class ToyLlamaConfig: # embedding config vocab_size: int = 128288 max_position_embeddings: int = 131072 # model config hidden_size: int = 4096 intermediate_size: int = 14336 num_hidden_layers: int = 32 num_attention_heads: int = 32 num_key_value_heads: int = 2 head_dim: int = 128 # rope config rope_theta: float = 500000.0 rope_scaling: dict = field( default_factory=lambda: { "factor": 8.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", } ) @dataclass class InferenceConfig: max_batch_size: int = 32 max_length: int = 8192 max_new_tokens: int = 128 class RMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states: torch.Tensor): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) class FFN(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = nn.SiLU() def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class ToyLlamaLayer(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, num_q_heads: int, num_kv_heads: int, 
head_dim: int, rope_config: RopeConfig, ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads self.head_dim = head_dim self.rope_config = rope_config self.attn_norm = RMSNorm(self.hidden_size) self.self_attn = SelfAttention( hidden_size=self.hidden_size, num_q_heads=self.num_q_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim, rope_config=rope_config, ) self.ffn_norm = RMSNorm(self.hidden_size) self.ffn = FFN( hidden_size=self.hidden_size, intermediate_size=self.intermediate_size ) def forward(self, x, cu_seqlens): x = x + self.self_attn(self.attn_norm(x), cu_seqlens) x = x + self.ffn(self.ffn_norm(x)) return x @torch.no_grad() def inference(self, x, cu_seqlens, step, kv_cache): x = x + self.self_attn.inference(self.attn_norm(x), cu_seqlens, step, kv_cache) x = x + self.ffn(self.ffn_norm(x)) return x class ToyLlama(nn.Module): def __init__( self, config: ToyLlamaConfig, inference_config: Optional[InferenceConfig] = None ): super().__init__() self.config = config self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size) self.rope_config = RopeConfig( head_dim=self.config.head_dim, rope_theta=self.config.rope_theta, rope_scaling=self.config.rope_scaling, ) self.layers = nn.ModuleList( [ ToyLlamaLayer( hidden_size=self.config.hidden_size, intermediate_size=self.config.intermediate_size, num_q_heads=self.config.num_attention_heads, num_kv_heads=self.config.num_key_value_heads, head_dim=self.config.head_dim, rope_config=RopeConfig( self.config.max_position_embeddings, self.config.head_dim, self.config.rope_theta, self.config.rope_scaling, ), ) for _ in range(self.config.num_hidden_layers) ] ) self.norm = RMSNorm(self.config.hidden_size) self.lm_head = nn.Linear( self.config.hidden_size, self.config.vocab_size, bias=False ) # inference config and kv cache self.inference_config = inference_config self.kv_cache = None def forward( self, 
input_ids: torch.LongTensor, # shape: [total_length, ] cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ] ): # embedding x = self.embedding(input_ids).to(torch.bfloat16) # layers for layer in self.layers: x = layer(x, cu_seqlens) # final norm x = self.norm(x) # lanugauge head x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size] return x @torch.no_grad() def inference( self, input_ids: torch.LongTensor, # prefill shape: [total_length, ]; decode shape: [batch_size, ] cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ] step: int, ): # set kv cache if self.kv_cache is None if self.kv_cache is None: self.kv_cache = [ KVCache( max_batch_size=self.inference_config.max_batch_size, max_length=self.inference_config.max_length, num_kv_heads=self.config.num_key_value_heads, head_dim=self.config.head_dim, dtype=torch.bfloat16, device="cuda", ) for _ in range(self.config.num_hidden_layers) ] # embedding x = self.embedding(input_ids).to(torch.bfloat16) # layers for i, layer in enumerate(self.layers): x = layer.inference(x, cu_seqlens, step, self.kv_cache[i]) # final norm x = self.norm(x) # lanugauge head if step == 0: x = x[cu_seqlens[1:] - 1, :] x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size] return x def generate( self, input_ids: torch.LongTensor, cu_seqlens: torch.LongTensor, max_new_tokens: int = -1, ): output_tokens = [] if max_new_tokens <= 0: max_new_tokens = self.inference_config.max_new_tokens for step in range(max_new_tokens): logits = self.inference( input_ids, cu_seqlens, step ) # shape: [batch_size, vocab_size] next_token = torch.argmax(logits, dim=-1) # shape: [batch_size, ] input_ids = next_token output_tokens.append(next_token) output_tokens = torch.stack( output_tokens, dim=1 ) # shape: [batch_size, max_new_tokens] return output_tokens if __name__ == "__main__": torch.manual_seed(42) # initialize model config = ToyLlamaConfig( hidden_size=4096, intermediate_size=14336, num_hidden_layers=8, num_attention_heads=32, 
num_key_value_heads=2, head_dim=128, rope_theta=500000.0, rope_scaling={ "factor": 8.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", }, ) inference_config = InferenceConfig( max_batch_size=4, max_length=8192, max_new_tokens=128, ) model = ToyLlama(config, inference_config).cuda().bfloat16() print(f"\nMODEL CONFIG:\n{config}\n") print(f"\nINFERENCE CONFIG:\n{inference_config}\n") print(f"\nMODEL:\n{model}\n") # example input batch_size = 4 seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda") cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) input_ids = torch.randint( 0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda" ) print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n") # example output logits = model(input_ids, cu_seqlens) print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n") # example generate output_tokens = model.generate(input_ids, cu_seqlens, 64) print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n") ================================================ FILE: native_sparse_attention/model/toy_nsa_llama.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Optional import torch import torch.nn as nn from dataclasses import dataclass, field from native_sparse_attention.module import NativeSparseAttention, RopeConfig, NSACache @dataclass class ToyNSALlamaConfig: # embedding config vocab_size: int = 128288 max_position_embeddings: int = 131072 # model config hidden_size: int = 4096 intermediate_size: int = 14336 num_hidden_layers: int = 32 num_attention_heads: int = 32 num_key_value_heads: int = 2 head_dim: int = 128 # rope config rope_theta: float = 500000.0 rope_scaling: dict = field( default_factory=lambda: { "factor": 8.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", } ) # nsa config compress_type: str = "weightedpool" kernel_size: int = 32 kernel_stride: int = 16 block_size: int = 64 topk: int = 16 init_blocks: int = 1 local_blocks: int = 2 window_size: int = 512 @dataclass class InferenceConfig: max_batch_size: int = 32 max_length: int = 8192 max_new_tokens: int = 128 class RMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states: torch.Tensor): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) class FFN(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = nn.SiLU() def forward(self, x): down_proj = 
self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class ToyNSALlamaLayer(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, num_q_heads: int, num_kv_heads: int, head_dim: int, compress_type: str, kernel_size: int, kernel_stride: int, block_size: int, topk: int, init_blocks: int, local_blocks: int, window_size: int, rope_config: RopeConfig, ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads self.head_dim = head_dim self.compress_type = compress_type self.kernel_size = kernel_size self.kernel_stride = kernel_stride self.block_size = block_size self.topk = topk self.init_blocks = init_blocks self.local_blocks = local_blocks self.window_size = window_size self.rope_config = rope_config self.attn_norm = RMSNorm(self.hidden_size) self.nsa = NativeSparseAttention( hidden_size=self.hidden_size, num_q_heads=self.num_q_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim, compress_type=self.compress_type, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride, block_size=self.block_size, topk=self.topk, init_blocks=self.init_blocks, local_blocks=self.local_blocks, window_size=self.window_size, rope_config=rope_config, ) self.ffn_norm = RMSNorm(self.hidden_size) self.ffn = FFN( hidden_size=self.hidden_size, intermediate_size=self.intermediate_size ) def forward(self, x, cu_seqlens): x = x + self.nsa(self.attn_norm(x), cu_seqlens) x = x + self.ffn(self.ffn_norm(x)) return x @torch.no_grad() def inference(self, x, cu_seqlens, step, kv_cache): x = x + self.nsa.inference(self.attn_norm(x), cu_seqlens, step, kv_cache) x = x + self.ffn(self.ffn_norm(x)) return x class ToyNSALlama(nn.Module): def __init__( self, config: ToyNSALlamaConfig, inference_config: Optional[InferenceConfig] = None, ): super().__init__() self.config = config self.embedding = nn.Embedding(self.config.vocab_size, 
self.config.hidden_size) self.rope_config = RopeConfig( head_dim=self.config.head_dim, rope_theta=self.config.rope_theta, rope_scaling=self.config.rope_scaling, ) self.layers = nn.ModuleList( [ ToyNSALlamaLayer( hidden_size=self.config.hidden_size, intermediate_size=self.config.intermediate_size, num_q_heads=self.config.num_attention_heads, num_kv_heads=self.config.num_key_value_heads, head_dim=self.config.head_dim, compress_type=self.config.compress_type, kernel_size=self.config.kernel_size, kernel_stride=self.config.kernel_stride, block_size=self.config.block_size, topk=self.config.topk, init_blocks=self.config.init_blocks, local_blocks=self.config.local_blocks, window_size=self.config.window_size, rope_config=RopeConfig( self.config.max_position_embeddings, self.config.head_dim, self.config.rope_theta, self.config.rope_scaling, ), ) for _ in range(self.config.num_hidden_layers) ] ) self.norm = RMSNorm(self.config.hidden_size) self.lm_head = nn.Linear( self.config.hidden_size, self.config.vocab_size, bias=False ) # inference config and kv cache self.inference_config = inference_config self.kv_cache = None def forward( self, input_ids: torch.LongTensor, # shape: [batch_size, max_length] cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ] ): # embedding x = self.embedding(input_ids).to(torch.bfloat16) # layers for layer in self.layers: x = layer(x, cu_seqlens) # final norm x = self.norm(x) # lanugauge head x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size] return x @torch.no_grad() def inference( self, input_ids: torch.LongTensor, # prefill shape: [total_length, ]; decode shape: [batch_size, ] cu_seqlens: torch.LongTensor, # shape: [batch_size + 1, ] step: int, ): # set kv cache if self.kv_cache is None if self.kv_cache is None: self.kv_cache = [ NSACache( max_batch_size=self.inference_config.max_batch_size, max_length=self.inference_config.max_length, num_kv_heads=self.config.num_key_value_heads, head_dim=self.config.head_dim, 
kernel_size=self.config.kernel_size, kernel_stride=self.config.kernel_stride, window_size=self.config.window_size, dtype=torch.bfloat16, device="cuda", ) for _ in range(self.config.num_hidden_layers) ] # embedding x = self.embedding(input_ids).to(torch.bfloat16) # layers for i, layer in enumerate(self.layers): x = layer.inference(x, cu_seqlens, step, self.kv_cache[i]) # final norm x = self.norm(x) # lanugauge head if step == 0: x = x[cu_seqlens[1:] - 1, :] x = self.lm_head(x).to(torch.float32) # [total_len, vocab_size] return x def generate( self, input_ids: torch.LongTensor, cu_seqlens: torch.LongTensor, max_new_tokens: int = -1, ): output_tokens = [] if max_new_tokens <= 0: max_new_tokens = self.inference_config.max_new_tokens for step in range(max_new_tokens): logits = self.inference( input_ids, cu_seqlens, step ) # shape: [batch_size, vocab_size] next_token = torch.argmax(logits, dim=-1) # shape: [batch_size, ] input_ids = next_token output_tokens.append(next_token) output_tokens = torch.stack( output_tokens, dim=1 ) # shape: [batch_size, max_new_tokens] return output_tokens if __name__ == "__main__": torch.manual_seed(42) # initialize model config = ToyNSALlamaConfig( hidden_size=4096, intermediate_size=14336, num_hidden_layers=8, num_attention_heads=32, num_key_value_heads=2, head_dim=128, rope_theta=500000.0, rope_scaling={ "factor": 8.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", }, compress_type="weightedpool", kernel_size=32, kernel_stride=16, block_size=64, topk=8, init_blocks=1, local_blocks=2, window_size=512, ) inference_config = InferenceConfig( max_batch_size=4, max_length=8192, max_new_tokens=128, ) model = ToyNSALlama(config, inference_config).cuda().bfloat16() print(f"\nMODEL CONFIG:\n{config}\n") print(f"\nINFERENCE CONFIG:\n{inference_config}\n") print(f"\nMODEL:\n{model}\n") # example input batch_size = 4 seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, 
device="cuda") cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) input_ids = torch.randint( 0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda" ) print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n") # example output logits = model(input_ids, cu_seqlens) print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n") # example generate output_tokens = model.generate(input_ids, cu_seqlens, 64) print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n") ================================================ FILE: native_sparse_attention/module/__init__.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from native_sparse_attention.module.native_sparse_attention import NativeSparseAttention from native_sparse_attention.module.self_attention import SelfAttention from native_sparse_attention.module.rope import RotaryEmbedding, RopeConfig from native_sparse_attention.module.kv_cache import NSACache, KVCache __all__ = [ "SelfAttention", "NativeSparseAttention", "RotaryEmbedding", "RopeConfig", "NSACache", ] ================================================ FILE: native_sparse_attention/module/kv_cache.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import triton import triton.language as tl from typing import Union from native_sparse_attention.ops.triton.utils import get_compressed_seqlens class KVCache: def __init__( self, max_batch_size: int, max_length: int, num_kv_heads: int, head_dim: int, dtype: torch.dtype, device: Union[str, torch.device], ): self.max_batch_size = max_batch_size self.max_length = max_length self.num_kv_heads = num_kv_heads self.head_dim = head_dim self.dtype = dtype self.device = device # alloc kv cache tensor for topk sparse attention self.kv_cache = torch.zeros( 2, self.max_batch_size, self.max_length, self.num_kv_heads, self.head_dim, dtype=self.dtype, device=self.device, ) self.kv_len = torch.zeros( self.max_batch_size, dtype=torch.int32, device=self.device ) def reset(self): self.kv_cache.zero_() self.kv_len.zero_() def update_kv( self, cu_seqlens: torch.Tensor, step: int, key: torch.Tensor, value: torch.Tensor, ): if step == 0: self._update_kv_prefill( cu_seqlens, step, key, value, ) else: self._update_kv_decode( cu_seqlens, step, key, value, ) def _update_kv_prefill( self, cu_seqlens: torch.Tensor, step: int, key: torch.Tensor, value: torch.Tensor, ): assert step == 0 seqlens = cu_seqlens[1:] - cu_seqlens[:-1] batch_size = seqlens.shape[0] # sparse part kv, shape check total_len, num_heads, head_dim = key.shape assert key.shape == value.shape assert num_heads == self.num_kv_heads and head_dim == self.head_dim assert total_len == 
cu_seqlens[-1].item() # fill sparse part kv cache seq_start, seq_end = cu_seqlens[:-1], cu_seqlens[1:] _fill_kv_cache( self.kv_cache, key, value, seq_start, seq_end, ) self.kv_len[:batch_size] = seqlens def _update_kv_decode( self, cu_seqlens: torch.Tensor, step: int, key: torch.Tensor, value: torch.Tensor, ): assert step > 0 seqlens = cu_seqlens[1:] - cu_seqlens[:-1] # sparse part kv, shape check batch_size, num_heads, head_dim = key.shape assert batch_size == seqlens.shape[0] assert key.shape == value.shape assert num_heads == self.num_kv_heads and head_dim == self.head_dim # fill sparse part kv cache brange = torch.arange(batch_size, dtype=torch.int32, device=key.device) self.kv_cache[0, :batch_size][brange, self.kv_len[:batch_size]] = key self.kv_cache[1, :batch_size][brange, self.kv_len[:batch_size]] = value self.kv_len[:batch_size] += 1 class NSACache: """KV cache manager for native sparse attention. Args: max_batch_size (int): max batch size max_length (int): max length, including prompt len and reponse len num_kv_heads (int): number of key/value heads head_dim (int): head dim kernel_size (int): kernel size of compression kernel_stride (int): kernel stride ofr compression window_size (int): window size for sliding window attention dtype (torch.dtype): data type for kv cache, should be same as model weight dtype device (Union[str, torch.device]): default to 'cuda' Methods: reset: reset kv cache, should be called before prefilling prepare_compress: store keys/values for compression, should be called before key/value compression at both prefilling and decoding update_kv: update key/value cache, should be called after rope """ def __init__( self, max_batch_size: int, max_length: int, num_kv_heads: int, head_dim: int, kernel_size: int, kernel_stride: int, window_size: int, dtype: torch.dtype, device: Union[str, torch.device] = "cuda", ): self.max_batch_size = max_batch_size self.max_length = max_length self.num_kv_heads = num_kv_heads self.head_dim = head_dim 
self.kernel_size = kernel_size self.kernel_stride = kernel_stride self.window_size = window_size self.dtype = dtype self.device = device # alloc kv cache tensor for topk sparse attention self.sparse_kv_cache = torch.zeros( 2, self.max_batch_size, self.max_length, self.num_kv_heads, self.head_dim, dtype=self.dtype, device=self.device, ) self.sparse_kv_len = torch.zeros( self.max_batch_size, dtype=torch.int32, device=self.device ) # alloc kv cache tensor for compressed attention self.max_comp_length = ( self.max_length - self.kernel_size ) // self.kernel_stride + 1 self.compress_kv_cache = torch.zeros( 2, self.max_batch_size, self.max_comp_length, self.num_kv_heads, self.head_dim, dtype=self.dtype, device=self.device, ) self.compress_kv_len = torch.zeros( self.max_batch_size, dtype=torch.int32, device=self.device ) self.before_compress_kv_cache = torch.zeros( 2, self.max_batch_size, self.kernel_size, self.num_kv_heads, self.head_dim, dtype=self.dtype, device=self.device, ) self.before_compress_kv_len = torch.zeros( self.max_batch_size, dtype=torch.int32, device=self.device ) # alloc kv cache for sliding window attention self.sliding_kv_cache = torch.zeros( 2, self.max_batch_size, self.window_size, self.num_kv_heads, self.head_dim, dtype=self.dtype, device=self.device, ) self.sliding_kv_len = torch.zeros( self.max_batch_size, dtype=torch.int32, device=self.device ) def reset(self): self.compress_kv_cache.zero_() self.compress_kv_len.zero_() self.before_compress_kv_cache.zero_() self.before_compress_kv_len.zero_() self.sparse_kv_cache.zero_() self.sparse_kv_len.zero_() self.sliding_kv_cache.zero_() self.sliding_kv_len.zero_() def prepare_compress( self, cu_seqlens: torch.Tensor, step: int, key: torch.Tensor, value: torch.Tensor, ): if step == 0: self._prepare_compress_prefill(cu_seqlens, step, key, value) else: self._prepare_compress_decode(cu_seqlens, step, key, value) def _prepare_compress_prefill( self, cu_seqlens: torch.Tensor, step: int, key: torch.Tensor, value: 
torch.Tensor, ): assert step == 0 # compress part kv, shape check batch_size = cu_seqlens.shape[0] - 1 total_len, num_heads, head_dim = key.shape assert key.shape == value.shape assert num_heads == self.num_kv_heads and head_dim == self.head_dim comp_seqlens, comp_cu_seqlens = get_compressed_seqlens( cu_seqlens, self.kernel_size, self.kernel_stride ) assert total_len == cu_seqlens[-1].item() # fill tmp cache seq_start = cu_seqlens[:-1] + comp_seqlens * self.kernel_stride seq_end = cu_seqlens[1:] _fill_kv_cache( self.before_compress_kv_cache, key, value, seq_start, seq_end, ) self.before_compress_kv_len[:batch_size] = seq_end - seq_start def _prepare_compress_decode( self, cu_seqlens: torch.Tensor, step: int, key: torch.Tensor, value: torch.Tensor, ): assert step > 0 # compress part kv, shape check batch_size, num_heads, head_dim = key.shape assert key.shape == value.shape assert num_heads == self.num_kv_heads and head_dim == self.head_dim assert batch_size == cu_seqlens.shape[0] - 1 # sequence with full before_compress_kv idx = torch.where(self.before_compress_kv_len == self.kernel_size)[0].squeeze() self.before_compress_kv_cache[ :, idx, : self.kernel_size - self.kernel_stride, :, : ] = self.before_compress_kv_cache[:, idx, self.kernel_stride :, :, :] self.before_compress_kv_len[idx] -= self.kernel_stride # fill new kv brange = torch.arange(batch_size, dtype=torch.int32, device=key.device) self.before_compress_kv_cache[0, :batch_size][ brange, self.before_compress_kv_len[:batch_size] ] = key self.before_compress_kv_cache[1, :batch_size][ brange, self.before_compress_kv_len[:batch_size] ] = value # update kv len self.before_compress_kv_len[:batch_size] += 1 def update_kv( self, cu_seqlens: torch.Tensor, step: int, compress_key: torch.Tensor, compress_value: torch.Tensor, sparse_key: torch.Tensor, sparse_value: torch.Tensor, sliding_key: torch.Tensor, sliding_value: torch.Tensor, ): if step == 0: self._update_kv_prefill( cu_seqlens, step, compress_key, 
compress_value, sparse_key, sparse_value, sliding_key, sliding_value, ) else: self._update_kv_decode( cu_seqlens, step, compress_key, compress_value, sparse_key, sparse_value, sliding_key, sliding_value, ) def _update_kv_prefill( self, cu_seqlens: torch.Tensor, step: int, compress_key: torch.Tensor, compress_value: torch.Tensor, sparse_key: torch.Tensor, sparse_value: torch.Tensor, sliding_key: torch.Tensor, sliding_value: torch.Tensor, ): assert step == 0 seqlens = cu_seqlens[1:] - cu_seqlens[:-1] batch_size = seqlens.shape[0] # sparse part kv, shape check total_len, num_heads, head_dim = sparse_key.shape assert sparse_key.shape == sparse_value.shape assert num_heads == self.num_kv_heads and head_dim == self.head_dim assert total_len == cu_seqlens[-1].item() # compress part kv, shape check total_comp_len, num_heads, head_dim = compress_key.shape assert compress_key.shape == compress_value.shape assert num_heads == self.num_kv_heads and head_dim == self.head_dim comp_seqlens, comp_cu_seqlens = get_compressed_seqlens( cu_seqlens, self.kernel_size, self.kernel_stride ) assert total_comp_len == comp_cu_seqlens[-1].item() # sliding window part kv, shape check total_len, num_heads, head_dim = sliding_key.shape assert sliding_key.shape == sliding_value.shape # fill compress part kv cache seq_start, seq_end = comp_cu_seqlens[:-1], comp_cu_seqlens[1:] _fill_kv_cache( self.compress_kv_cache, compress_key, compress_value, seq_start, seq_end, ) self.compress_kv_len[:batch_size] = comp_seqlens # fill sparse part kv cache seq_start, seq_end = cu_seqlens[:-1], cu_seqlens[1:] _fill_kv_cache( self.sparse_kv_cache, sparse_key, sparse_value, seq_start, seq_end, ) self.sparse_kv_len[:batch_size] = seqlens # fill sliding part kv cache seq_start = torch.maximum(cu_seqlens[1:] - self.window_size, cu_seqlens[:-1]) seq_end = cu_seqlens[1:] _fill_kv_cache( self.sliding_kv_cache, sliding_key, sliding_value, seq_start, seq_end, ) self.sliding_kv_len[:batch_size] = seq_end - seq_start def 
_update_kv_decode( self, cu_seqlens: torch.Tensor, step: int, compress_key: torch.Tensor, compress_value: torch.Tensor, sparse_key: torch.Tensor, sparse_value: torch.Tensor, sliding_key: torch.Tensor, sliding_value: torch.Tensor, ): assert step > 0 seqlens = cu_seqlens[1:] - cu_seqlens[:-1] # sparse part kv, shape check batch_size, num_heads, head_dim = sparse_key.shape assert batch_size == seqlens.shape[0] assert sparse_key.shape == sparse_value.shape assert num_heads == self.num_kv_heads and head_dim == self.head_dim # compress part kv, shape check batch_size, num_heads, head_dim = compress_key.shape assert batch_size == seqlens.shape[0] assert compress_key.shape == compress_value.shape assert num_heads == self.num_kv_heads and head_dim == self.head_dim # sliding window part kv, shape check total_len, num_heads, head_dim = sliding_key.shape assert sliding_key.shape == sliding_value.shape # fill compress part kv cache, only full block need to be compress idx = torch.where(self.before_compress_kv_len == self.kernel_size)[0].squeeze() self.compress_kv_cache[0][idx, self.compress_kv_len[idx]] = compress_key[idx] self.compress_kv_cache[1][idx, self.compress_kv_len[idx]] = compress_value[idx] self.compress_kv_len[idx] += 1 # fill sparse part kv cache brange = torch.arange(batch_size, dtype=torch.int32, device=sparse_key.device) self.sparse_kv_cache[0, :batch_size][ brange, self.sparse_kv_len[:batch_size] ] = sparse_key self.sparse_kv_cache[1, :batch_size][ brange, self.sparse_kv_len[:batch_size] ] = sparse_value self.sparse_kv_len[:batch_size] += 1 # fill sliding window kv cache self.sliding_kv_cache[0, :batch_size][ brange, self.sliding_kv_len[:batch_size] % self.window_size ] = sliding_key self.sliding_kv_cache[1, :batch_size][ brange, self.sliding_kv_len[:batch_size] % self.window_size ] = sliding_value self.sliding_kv_len[:batch_size] += 1 @triton.jit def _fill_kv_cache_kernel( cache_ptr, k_ptr, v_ptr, seq_start, seq_end, head_dim, stride_c2, stride_cb, stride_cn, 
stride_ch, stride_cd, stride_kn, stride_kh, stride_kd, stride_vn, stride_vh, stride_vd, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_D: tl.constexpr, ): # get batch id and head id pid_2b = tl.program_id(0) pid_2 = pid_2b % 2 pid_b = pid_2b // 2 pid_h = tl.program_id(1) pid_n = tl.program_id(2) # get kv start and len after rmpad kv_start = tl.load(seq_start + pid_b) kv_end = tl.load(seq_end + pid_b) kv_len = kv_end - kv_start if pid_n * BLOCK_SIZE_N >= kv_len: return # load key or value if pid_2 == 0: kv_ptrs = tl.make_block_ptr( base=k_ptr + kv_start * stride_kn + pid_h * stride_kh, shape=(kv_len, head_dim), strides=(stride_kn, stride_kd), offsets=(pid_n * BLOCK_SIZE_N, 0), block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D), order=(1, 0), ) kv = tl.load(kv_ptrs, boundary_check=(0, 1)) else: kv_ptrs = tl.make_block_ptr( base=v_ptr + kv_start * stride_vn + pid_h * stride_vh, shape=(kv_len, head_dim), strides=(stride_vn, stride_vd), offsets=(pid_n * BLOCK_SIZE_N, 0), block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D), order=(1, 0), ) kv = tl.load(kv_ptrs, boundary_check=(0, 1)) # store to cache cache_ptrs = tl.make_block_ptr( base=cache_ptr + pid_2 * stride_c2 + pid_b * stride_cb + pid_h * stride_ch, shape=(kv_len, head_dim), strides=(stride_cn, stride_cd), offsets=(pid_n * BLOCK_SIZE_N, 0), block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D), order=(1, 0), ) tl.store(cache_ptrs, kv.to(cache_ptr.dtype.element_ty), boundary_check=(0, 1)) def _fill_kv_cache( kv_cache: torch.Tensor, # shape: [2, b, n, h, d] key: torch.Tensor, # shape: [l, h, d] value: torch.Tensor, # shape: [l, h, d] seq_start: torch.Tensor, # shape: [b] seq_end: torch.Tensor, # shape: [b] ): total_len, num_heads, head_dim = key.shape batch_size = seq_start.shape[0] max_kv_len = (seq_end - seq_start).max().item() # no kv cache to fill if max_kv_len == 0: return BLOCK_SIZE_N = min(1024, triton.next_power_of_2(max_kv_len)) BLOCK_SIZE_D = triton.next_power_of_2(head_dim) grid = (2 * batch_size, num_heads, triton.cdiv(max_kv_len, BLOCK_SIZE_N)) 
_fill_kv_cache_kernel[grid]( kv_cache, key, value, seq_start, seq_end, head_dim, kv_cache.stride(0), kv_cache.stride(1), kv_cache.stride(2), kv_cache.stride(3), kv_cache.stride(4), key.stride(0), key.stride(1), key.stride(2), value.stride(0), value.stride(1), value.stride(2), BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_D=BLOCK_SIZE_D, ) ================================================ FILE: native_sparse_attention/module/native_sparse_attention.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch from flash_attn import flash_attn_varlen_func from native_sparse_attention.ops import ( compressed_attention, topk_sparse_attention, avgpool_compress, weightedpool_compress, linear_compress, ) from einops import rearrange from native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding from native_sparse_attention.infer import nsa_infer from native_sparse_attention.module.kv_cache import NSACache COMPRESS_TYPE_TO_FUNC = { "avgpool": avgpool_compress, "weightedpool": weightedpool_compress, "linear": linear_compress, } COMPRESS_TYPE_TO_WEIGHT = { "avgpool": lambda num_heads, head_dim, kernel_size: None, "weightedpool": lambda num_heads, head_dim, kernel_size: torch.nn.Parameter( torch.zeros(num_heads, kernel_size) ), "linear": lambda num_heads, head_dim, kernel_size: torch.nn.Parameter( torch.zeros(num_heads, head_dim * kernel_size, head_dim) ), } class NativeSparseAttention(torch.nn.Module): """Native sparse attention module, support training and inference Args: compress_type (str): key value compression type, currently support ['linear', 'avgpool', 'weightedpool'] hidden_size (int): hidden dimension num_q_heads (int): number of query heads num_kv_heads (int): number of key/value heads, must be divisible by num_q_heads head_dim (int): head dim kernel_size (int): kernel size of compression kernel_stride (int): kernel stride ofr compression block_size (int): block size of sparse attention topk (int): topk of sparse attention init_blocks (int): number of blocks at the begining of the sequence, these blocks are force to be computed in sparse attention local_blocks (int): number of blocks at the local window of each query, these blocks are force to be computed in sparse attention window_size (int): window size for sliding window attention rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details rope_device (str): device used to store rope freqs """ def __init__( self, compress_type: str, 
hidden_size: int, num_q_heads: int, num_kv_heads: int, head_dim: int, kernel_size: int, kernel_stride: int, block_size: int, topk: int, init_blocks: int, local_blocks: int, window_size: int, rope_config: RopeConfig, rope_device: str = "cuda", ): super().__init__() # configs self.compress_type = compress_type self.hidden_size = hidden_size self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads self.head_dim = head_dim self.kernel_size = kernel_size self.kernel_stride = kernel_stride self.block_size = block_size self.topk = topk self.init_blocks = init_blocks self.local_blocks = local_blocks self.window_size = window_size self.rope_config = rope_config assert self.head_dim == self.rope_config.head_dim # qkv proj and o proj self.proj_q = torch.nn.Linear( self.hidden_size, self.num_q_heads * self.head_dim, bias=False ) self.proj_k = torch.nn.Linear( self.hidden_size, self.num_kv_heads * self.head_dim, bias=False ) self.proj_v = torch.nn.Linear( self.hidden_size, self.num_kv_heads * self.head_dim, bias=False ) self.proj_o = torch.nn.Linear( self.num_q_heads * self.head_dim, self.hidden_size, bias=False ) # nsa compress func self.compress_func = COMPRESS_TYPE_TO_FUNC[self.compress_type] # nsa parameteres self.compress_key = COMPRESS_TYPE_TO_WEIGHT[self.compress_type]( num_kv_heads, head_dim, kernel_size ) self.compress_value = COMPRESS_TYPE_TO_WEIGHT[self.compress_type]( num_kv_heads, head_dim, kernel_size ) self.intra_block_pe = torch.nn.Parameter( torch.zeros(self.num_kv_heads, self.kernel_size, self.head_dim) ) # gate function self.gate = torch.nn.Sequential( torch.nn.Linear(self.hidden_size, self.num_q_heads * 3, bias=False), torch.nn.Sigmoid(), ) # rope self.rope = RotaryEmbedding(self.rope_config, device=rope_device) # init parameters self.init_params() def init_params(self): for p in self.parameters(): if len(p.shape) > 1: torch.nn.init.xavier_uniform_(p) def forward( self, x: torch.Tensor, # shape: [total_len, hidden_size] cu_seqlens: torch.Tensor, # 
shape: [batch_size + 1] ): # dtype and shape check assert x.dtype == torch.bfloat16 or x.dtype == torch.float16 assert x.shape[-1] == self.hidden_size cu_seqlens = cu_seqlens.to(torch.int32) seqlens = cu_seqlens[1:] - cu_seqlens[:-1] # qkv proj q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim) k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim) v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim) # compressed key and value before rope compressed_k, compressed_cu_seqlens = self.compress_func( k, self.compress_key, cu_seqlens, self.kernel_size, self.kernel_stride, self.intra_block_pe, ) compressed_v, _ = self.compress_func( v, self.compress_value, cu_seqlens, self.kernel_size, self.kernel_stride, None, ) # do rope for query and compressed key q = self.rope(q, cu_seqlens) compressed_k = self.rope( compressed_k, compressed_cu_seqlens, stride=self.kernel_stride ) # attention between query and compressed key value compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1] compressed_attn_output, topk_idx = compressed_attention( q, compressed_k, compressed_v, self.kernel_size, self.kernel_stride, self.block_size, self.topk, cu_seqlens, compressed_cu_seqlens, seqlens.max().item(), compressed_seqlens.max().item(), None, self.init_blocks, self.local_blocks, ) # do rope for original key k = self.rope(k, cu_seqlens) # topk sparse attention sparse_attn_output = topk_sparse_attention( q, k, v, topk_idx, self.block_size, cu_seqlens, None ) # sliding window attention sliding_attn_output = flash_attn_varlen_func( q, k, v, cu_seqlens, cu_seqlens, seqlens.max().item(), seqlens.max().item(), causal=True, window_size=(self.window_size, -1), ) # gate average gate = self.gate(x) gate = rearrange(gate, "n (h g) -> n h g", g=3) attn_output = ( gate[..., 0:1] * compressed_attn_output + gate[..., 1:2] * sparse_attn_output + gate[..., 2:3] * sliding_attn_output ) # rearrange and output proj attn_output = rearrange(attn_output, "n h d -> n (h d)") 
attn_output = self.proj_o(attn_output) return attn_output @torch.no_grad() def inference( self, x: torch.Tensor, # shape: [total_len, hidden_size] cu_seqlens: torch.Tensor, # shape: [batch_size + 1] step: int, cache: NSACache, ): # dtype and shape check assert x.dtype == torch.bfloat16 or x.dtype == torch.float16 assert x.shape[-1] == self.hidden_size cu_seqlens = cu_seqlens.to(torch.int32) assert step >= 0 if step == 0: assert x.shape[0] == cu_seqlens[-1] else: assert x.shape[0] == cu_seqlens.shape[0] - 1 # qkv proj q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim) k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim) v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim) # gate proj gate = self.gate(x) gate = rearrange(gate, "n (h g) -> n h g", g=3) # nsa infer output = nsa_infer( cu_seqlens, step, q, k, v, gate, self.rope, cache, [self.compress_key, self.compress_value], [self.compress_func, self.compress_func], self.intra_block_pe, self.kernel_size, self.kernel_stride, self.block_size, self.topk, self.init_blocks, self.local_blocks, self.window_size, ) # output proj output = rearrange(output, "n h d -> n (h d)") output = self.proj_o(output) return output ================================================ FILE: native_sparse_attention/module/rope.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from dataclasses import dataclass, field from torch import nn from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS # default to llama3.1 rope config @dataclass class RopeConfig: """Config for RotaryEmbedding, similar to transformers llama.""" max_position_embeddings: int = 131072 head_dim: int = 128 rope_theta: float = 500000 rope_scaling: dict = field( default_factory=lambda: { "factor": 8.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", } ) # useless, just for compatibility, please use head_dim instead hidden_size: int = 1 num_attention_heads: int = 1 def __post_init__(self): self.num_attention_heads = 1 self.hidden_size = self.head_dim # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) # copy and modify from modify from hugigngface transformers # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py class RotaryEmbedding(nn.Module): """Rotary embedding Args: config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details device (str): default to 'cuda' """ cos = None sin = None def __init__( self, config: RopeConfig, device=torch.device(torch.cuda.current_device()) ): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not 
None: self.rope_type = config.rope_scaling.get( "rope_type", config.rope_scaling.get("type") ) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: 1 - growing beyond the cached sequence length (allow scaling) 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device, seq_len=seq_len ) self.register_buffer( "inv_freq", inv_freq, persistent=False ) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if ( seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len ): # reset # This .to() is needed if the model has been moved to a device after being initialized (because # the buffer is automatically moved, but not the original copy) self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() def generate_cos_sin(self, x: torch.Tensor, position_ids): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) # Core RoPE block inv_freq_expanded = ( self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) ) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) 
device_type = x.device.type device_type = ( device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" ) with torch.autocast(device_type=device_type, enabled=False): freqs = ( inv_freq_expanded.float() @ position_ids_expanded.float() ).transpose(1, 2) # # donot use this if use flash_attn # emb = torch.cat((freqs, freqs), dim=-1) emb = freqs cos = emb.cos() sin = emb.sin() # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention cos = (cos * self.attention_scaling).to(dtype=x.dtype).squeeze(0) sin = (sin * self.attention_scaling).to(dtype=x.dtype).squeeze(0) # save cos sin RotaryEmbedding.cos = torch.cat([cos, cos], dim=-1) RotaryEmbedding.sin = torch.cat([sin, sin], dim=-1) return RotaryEmbedding.cos, RotaryEmbedding.sin @torch.no_grad() def generate_pos_embs( self, x: torch.Tensor, cu_seqlens: torch.Tensor, seqlens: torch.Tensor, step: int = 0, stride: int = 1, ): if ( RotaryEmbedding.cos is None or seqlens.max() + step > RotaryEmbedding.cos.shape[0] ): self.generate_cos_sin( x, torch.arange(seqlens.max() + step).to(x.device)[None, :] ) cos_embs = [] sin_embs = [] bsz = len(cu_seqlens) - 1 for i in range(bsz): if step == 0: # prefilling r = cu_seqlens[i + 1] - cu_seqlens[i] cos_emb, sin_emb = ( RotaryEmbedding.cos[: r * stride : stride], RotaryEmbedding.sin[: r * stride : stride], ) elif step > 0: # decoding r = cu_seqlens[i + 1] - cu_seqlens[i] + step - 1 cos_emb, sin_emb = ( RotaryEmbedding.cos[r * stride : r * stride + 1], RotaryEmbedding.sin[r * stride : r * stride + 1], ) cos_embs.append(cos_emb) sin_embs.append(sin_emb) cos_embs = torch.cat(cos_embs, dim=0) sin_embs = torch.cat(sin_embs, dim=0) return cos_embs, sin_embs def forward(self, x, cu_seqlens, step=0, stride=1): seqlens = cu_seqlens[1:] - cu_seqlens[:-1] cos_embs, sin_embs = self.generate_pos_embs( x, cu_seqlens, seqlens, step=step, stride=stride, ) N, H, D = x.shape[0], x.shape[-2], x.shape[-1] # H: number of heads x = x * 
cos_embs.view(N, 1, D) + rotate_half(x) * sin_embs.view(N, 1, D) return x ================================================ FILE: native_sparse_attention/module/self_attention.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from flash_attn import flash_attn_varlen_func from einops import rearrange from native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding from native_sparse_attention.module.kv_cache import KVCache from native_sparse_attention.ops import flash_attention_decode class SelfAttention(torch.nn.Module): """self attention module Args: hidden_size (int): hidden dimension num_q_heads (int): number of query heads num_kv_heads (int): number of key/value heads, must be divisible by num_q_heads head_dim (int): head dim rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details """ def __init__( self, hidden_size: int, num_q_heads: int, num_kv_heads: int, head_dim: int, rope_config: RopeConfig, rope_device: str = "cuda", ): super().__init__() # configs self.hidden_size = hidden_size self.num_q_heads = num_q_heads self.num_kv_heads = num_kv_heads self.head_dim = head_dim self.rope_config = rope_config assert self.head_dim == self.rope_config.head_dim # qkv proj and o proj self.proj_q = torch.nn.Linear( self.hidden_size, self.num_q_heads * self.head_dim, bias=False ) self.proj_k = torch.nn.Linear( 
self.hidden_size, self.num_kv_heads * self.head_dim, bias=False ) self.proj_v = torch.nn.Linear( self.hidden_size, self.num_kv_heads * self.head_dim, bias=False ) self.proj_o = torch.nn.Linear( self.num_q_heads * self.head_dim, self.hidden_size, bias=False ) # rope self.rope = RotaryEmbedding(self.rope_config, device=rope_device) # init parameters self.init_params() def init_params(self): for p in self.parameters(): torch.nn.init.xavier_uniform_(p) def forward( self, x: torch.Tensor, # shape: [total_len, hidden_size] cu_seqlens: torch.Tensor, # shape: [batch_size + 1] ): # dtype and shape check assert x.dtype == torch.bfloat16 or x.dtype == torch.float16 assert x.shape[-1] == self.hidden_size cu_seqlens = cu_seqlens.to(torch.int32) seqlens = cu_seqlens[1:] - cu_seqlens[:-1] # qkv proj q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim) k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim) v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim) # do rope for query and compressed key q = self.rope(q, cu_seqlens) k = self.rope(k, cu_seqlens) # self attention attn_output = flash_attn_varlen_func( q, k, v, cu_seqlens, cu_seqlens, seqlens.max().item(), seqlens.max().item(), causal=True, ) # rearrange and output proj attn_output = rearrange(attn_output, "n h d -> n (h d)") attn_output = self.proj_o(attn_output) return attn_output @torch.no_grad() def inference( self, x: torch.Tensor, # shape: [total_len, hidden_size] cu_seqlens: torch.Tensor, # shape: [batch_size + 1] step: int, cache: KVCache, ): # dtype and shape check assert x.dtype == torch.bfloat16 or x.dtype == torch.float16 assert x.shape[-1] == self.hidden_size cu_seqlens = cu_seqlens.to(torch.int32) seqlens = cu_seqlens[1:] - cu_seqlens[:-1] assert step >= 0 if step == 0: assert x.shape[0] == cu_seqlens[-1] else: assert x.shape[0] == cu_seqlens.shape[0] - 1 batch_size = cu_seqlens.shape[0] - 1 # qkv proj q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim) k = 
self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim) v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim) # do rope for query and compressed key q = self.rope(q, cu_seqlens, step) k = self.rope(k, cu_seqlens, step) # reset and update kv cache if step == 0: cache.reset() cache.update_kv(cu_seqlens, step, k, v) # self attention if step == 0: cu_seqlens_q = cu_seqlens_k = cu_seqlens max_seqlen_in_batch_q = max_seqlen_in_batch_k = seqlens.max().item() output = flash_attn_varlen_func( q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, causal=True, ) else: output = flash_attention_decode( q, cache.kv_cache[0, :batch_size], cache.kv_cache[1, :batch_size], cache.kv_len[:batch_size], ) # rearrange and output proj output = rearrange(output, "n h d -> n (h d)") output = self.proj_o(output) return output ================================================ FILE: native_sparse_attention/ops/README.md ================================================ # Triton Functions for Native Sparse Attention This folder provides efficient Triton-based implementations of components for Native Sparse Attention. This README introduces the available functions, explains how to set them up, and offers guidance on their usage. --- ## Overview of Functions The functions are organized into two main categories: 1. **Compression Methods**: Techniques for compressing key and value tensors. 2. **Attention Mechanisms**: Methods for computing attention between queries and compressed key/value tensors, including top-k sparse attention. --- ## Function Descriptions ### Compression Methods These functions compress key and value tensors using a sliding window approach. Within each window, `kernel_size` tokens are compressed into a single token, with a stride of `kernel_stride`. All compression functions share similar input parameters and return formats. 
**Parameters:** - `x`: Input tensor (`total_len, num_heads, head_dim`) - `w`: Weight tensor (shape varies by compression method) - `cu_seqlens`: Cumulative sequence lengths (`batch_size + 1`) - `kernel_size`: Size of the compression window - `kernel_stride`: Stride of the compression window - `pe`: Optional positional embedding (`num_heads, kernel_size, head_dim`) **Returns:** - Compressed tensor (`total_compress_len, num_heads, head_dim`) - Cumulative sequence lengths (`compressed_cu_seqlens`) for the compressed tensor In the formulas below, $k_1, \dots, k_m$ are the $m$ = `kernel_size` key (or value) vectors inside one compression window. #### `weightedpool_compress` Compresses the input tensor using weighted pooling, applying a weighted sum over each block: $\hat{k} = w_1 k_1 + \dots + w_m k_m$ - **Weight shape**: `(num_heads, kernel_size)` #### `avgpool_compress` Compresses the input tensor using average pooling: $\hat{k} = (k_1 + \dots + k_m) / m$ - **Weight**: Must be `None` #### `linear_compress` Compresses the input tensor via linear projection, mapping each block to a single vector using learned weights: $\hat{k} = \text{cat}(k_1, \dots, k_m) W$ - **Weight shape**: `(num_heads, kernel_size * head_dim, head_dim)` --- ### Attention Mechanisms These functions compute attention using either full or sparse mechanisms. #### `flash_attention_varlen` A variable-length implementation of flash attention, similar to `flash_attn_varlen_func` from the `flash_attn` package. **Parameters:** - `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`) - `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for queries and keys - `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths in the batch - `causal`: Apply causal masking (default: `False`) - `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`) **Returns:** - Attention output tensor (`total_q_len, num_q_heads, head_dim`) #### `compressed_attention` Computes attention between a query and compressed key/value tensors, identifying the top-k blocks for sparse attention.
**Parameters:** - `q`: Query tensor (`total_len, num_heads, head_dim`) - `k`, `v`: Compressed key and value tensors (`total_compress_len, num_heads, head_dim`) - `kernel_size`, `kernel_stride`: Compression parameters - `block_size`: Size of blocks for sparse attention - `topk`: Number of top blocks to select - `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for query and compressed key/value - `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths for query and compressed key/value - `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`) - `init_blocks`: Number of initial blocks forced to be selected (default: `1`) - `local_blocks`: Number of local blocks forced to be selected (default: `2`) **Returns:** - Tuple containing: - Attention output tensor - Top-k block indices #### `topk_sparse_attention` Performs sparse attention using precomputed top-k block indices. If a query attends to fewer than `topk` key/value blocks, the `topk_idx` should be padded with `-1` on the right. 
**Parameters:** - `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`) - `topk_idx`: Precomputed top-k indices (`num_kv_heads, total_len, topk`) - `block_size`: Block size for sparse attention (recommended: `64` or `128`) - `cu_seqlens`: Cumulative sequence lengths - `softmax_scale`: Softmax scale (default: `1 / sqrt(head_dim)`) **Returns:** - Attention output tensor (`total_len, num_q_heads, head_dim`) --- ## Usage Example Below is a typical workflow demonstrating how to combine these sparse attention functions: ```python import torch from native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention # Example input setup num_q_heads = 64 num_kv_heads = 4 head_dim = 128 cu_seqlens = torch.tensor([0, 1024, 8192, 16384], dtype=torch.int32).cuda() # Query, key, and value tensors query = torch.randn(16384, num_q_heads, head_dim, dtype=torch.bfloat16).cuda() key = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda() value = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda() # Compression weights and positional embeddings kernel_size = 32 kernel_stride = 16 wk = torch.randn(num_kv_heads, kernel_size * head_dim, head_dim, dtype=torch.bfloat16).cuda() wv = torch.randn_like(wk) pe = torch.randn(num_kv_heads, kernel_size, head_dim, dtype=torch.bfloat16).cuda() # Parameters for top-k sparse attention block_size = 64 topk = 16 # 1. Compress key and value tensors compressed_key, compressed_cu_seqlens = linear_compress( key, wk, cu_seqlens, kernel_size, kernel_stride, pe ) compressed_value, _ = linear_compress( value, wv, cu_seqlens, kernel_size, kernel_stride, None ) # 2. Compute attention with compressed key/value and get top-k indices compressed_attn_output, topk_idx = compressed_attention( query, compressed_key, compressed_value, kernel_size, kernel_stride, block_size, topk, cu_seqlens, compressed_cu_seqlens, init_blocks=1, local_blocks=2, ) # 3. 
Perform top-k sparse attention sparse_attn_output = topk_sparse_attention( query, key, value, topk_idx, block_size, cu_seqlens, ) # 4. Combine attention outputs (e.g., average) attn_output = (compressed_attn_output + sparse_attn_output) / 2 ``` For a complete implementation of the Native Sparse Attention module, see `native_sparse_attention/module/native_sparse_attention.py`. ================================================ FILE: native_sparse_attention/ops/__init__.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# compress method from native_sparse_attention.ops.triton.weighted_pool import ( weightedpool_compress, avgpool_compress, ) from native_sparse_attention.ops.triton.linear_compress import linear_compress # prefill attention from native_sparse_attention.ops.triton.flash_attention import flash_attention_varlen from native_sparse_attention.ops.triton.compressed_attention import compressed_attention from native_sparse_attention.ops.triton.topk_sparse_attention import ( topk_sparse_attention, ) # decode attention from native_sparse_attention.ops.triton.flash_attention_decode import ( flash_attention_decode, ) from native_sparse_attention.ops.torch.compressed_attention_decode import ( compressed_attention_decode, ) from native_sparse_attention.ops.triton.topk_sparse_attention_decode import ( topk_sparse_attention_decode, ) __all__ = [ # compress method "avgpool_compress", "weightedpool_compress", "linear_compress", # prefill attention, trainable "flash_attention_varlen", "compressed_attention", "topk_sparse_attention", # decode attention, no grad "flash_attention_decode", "compressed_attention_decode", "topk_sparse_attention_decode", ] ================================================ FILE: native_sparse_attention/ops/torch/__init__.py ================================================ ================================================ FILE: native_sparse_attention/ops/torch/compress_key_value.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import torch from typing import Optional from einops import rearrange, einsum def avgpool_compress_torch( x: torch.Tensor, w: torch.Tensor, cu_seqlens, kernel_size: int, kernel_stride: int, pe: Optional[torch.Tensor] = None, ): """Compress key and value tensor with kernel_size and kernel_stride. Args: x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim) w (torch.Tensor): no weight for avgpool, must be None. cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim) kernel_stride (int): stride for each compress kernel pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens. 
""" # dtype check assert x.dtype == torch.float16 or x.dtype == torch.bfloat16 assert cu_seqlens.dtype == torch.int32 assert x.dtype == pe.dtype if pe is not None else True # shape check total_len, num_heads, head_dim = x.shape batch_size = cu_seqlens.shape[0] - 1 assert w is None, "don't need additional weight for avgpool" assert kernel_size % kernel_stride == 0 assert kernel_size in {16, 32, 64, 128} # compute seqlens after compression seqlens = cu_seqlens[1:] - cu_seqlens[:-1] y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1 # corner case, if sequence_length < kernel_size, no compression for this sequence y_seqlens[seqlens < kernel_size] = 0 y_cu_seqlens = torch.cat( [ torch.zeros(1, dtype=torch.int32, device="cuda"), torch.cumsum(y_seqlens, dim=0), ], dim=0, ).to(torch.int32) # pad and rearrange x x = rearrange(x, "n h d -> n (h d)") splited_x = torch.split(x, seqlens.tolist(), 0) x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True) x = rearrange(x, "b n d -> b d n") # avgpool y = torch.nn.functional.avg_pool1d(x, kernel_size=kernel_size, stride=kernel_stride) y = rearrange(y, "b (h d) n -> b n h d", h=num_heads) # only keep useful part y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0) # position embedding as a bias if pe is not None: bias = torch.mean(pe, dim=1) y = y + bias.unsqueeze(0) return y, y_cu_seqlens def weightedpool_compress_torch( x: torch.Tensor, w: torch.Tensor, # [num_heads, kernel_size] cu_seqlens, kernel_size: int, kernel_stride: int, pe: Optional[torch.Tensor] = None, ): """Compress key and value tensor with kernel_size and kernel_stride. Args: x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim) w (torch.Tensor): weight for each head, shape (num_heads, kernel_size) cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. 
kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim) kernel_stride (int): stride for each compress kernel pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens. """ # dtype check assert x.dtype == torch.float16 or x.dtype == torch.bfloat16 assert x.dtype == w.dtype assert x.dtype == pe.dtype if pe is not None else True assert cu_seqlens.dtype == torch.int32 # shape check total_len, num_heads, head_dim = x.shape batch_size = cu_seqlens.shape[0] - 1 assert w.shape[0] == num_heads assert w.shape[1] == kernel_size assert kernel_size % kernel_stride == 0 assert kernel_size in {16, 32, 64, 128} # compute seqlens after compression seqlens = cu_seqlens[1:] - cu_seqlens[:-1] y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1 # corner case, if sequence_length < kernel_size, no compression for this sequence y_seqlens[seqlens < kernel_size] = 0 y_cu_seqlens = torch.cat( [ torch.zeros(1, dtype=torch.int32, device="cuda"), torch.cumsum(y_seqlens, dim=0), ], dim=0, ).to(torch.int32) # pad and rearrange x x = rearrange(x, "n h d -> n (h d)") splited_x = torch.split(x, seqlens.tolist(), 0) x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True) x = rearrange(x, "b n (h d) -> b h n d", h=num_heads) x = x.as_strided( size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim), stride=( x.stride(0), x.stride(1), kernel_stride * x.stride(2), x.stride(2), x.stride(3), ), ) y = einsum(x, w, "b h n k d, h k -> b n h d") # only keep useful part y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0) # position embedding as a bias if pe is not None: bias = einsum(pe, w, "h k d, h k -> h d") y = y + bias.unsqueeze(0) return y, y_cu_seqlens def linear_compress_torch( x: torch.Tensor, w: 
torch.Tensor, # [num_heads, kernel_size * head_dim, head_dim] cu_seqlens, kernel_size: int, kernel_stride: int, pe: Optional[torch.Tensor] = None, ): """Compress key and value tensor with kernel_size and kernel_stride. Similar to conv_compress. Args: x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim) w (torch.Tensor): weight for each head, shape (num_heads, kernel_size * head_dim, head_dim) cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim) kernel_stride (int): stride for each compress kernel pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens. """ # dtype check assert x.dtype == torch.float16 or x.dtype == torch.bfloat16 assert x.dtype == w.dtype assert x.dtype == pe.dtype if pe is not None else True assert cu_seqlens.dtype == torch.int32 # shape check total_len, num_heads, head_dim = x.shape batch_size = cu_seqlens.shape[0] - 1 assert w.shape[0] == num_heads assert w.shape[1] == kernel_size * head_dim assert w.shape[2] == head_dim assert kernel_size % kernel_stride == 0 assert kernel_size in {16, 32, 64, 128} # compute seqlens after compression seqlens = cu_seqlens[1:] - cu_seqlens[:-1] y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1 # corner case, if sequence_length < kernel_size, no compression for this sequence y_seqlens[seqlens < kernel_size] = 0 y_cu_seqlens = torch.cat( [ torch.zeros(1, dtype=torch.int32, device="cuda"), torch.cumsum(y_seqlens, dim=0), ], dim=0, ).to(torch.int32) # pad and rearrange x x = rearrange(x, "n h d -> n (h d)") splited_x = torch.split(x, seqlens.tolist(), 0) x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True) x = rearrange(x, "b 
n (h d) -> b h n d", h=num_heads) x = x.as_strided( size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim), stride=( x.stride(0), x.stride(1), kernel_stride * x.stride(2), x.stride(2), x.stride(3), ), ) y = einsum( x, rearrange(w, "h (k d) D -> h k d D", k=kernel_size), "b h n k d, h k d D -> b n h D", ) # only keep useful part y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0) # position embedding as a bias if pe is not None: pe = rearrange(pe, "h k d -> h (k d)") bias = einsum(pe, w, "h D, h D d -> h d") y = y + bias.unsqueeze(0) return y, y_cu_seqlens ================================================ FILE: native_sparse_attention/ops/torch/compressed_attention.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch
import math
from typing import Tuple
from collections import Counter
from einops import rearrange


def transform_score(
    score: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    init_blocks: int = 1,
    local_blocks: int = 2,
) -> torch.Tensor:
    """Aggregate per-compressed-key scores into per-block scores for top-k selection.

    Args:
        score (torch.Tensor): shape [num_k_heads, total_query_len, num_compressed_keys],
            attention mass of every query on every compressed key.
        kernel_size (int): compression window size.
        kernel_stride (int): compression window stride.
        block_size (int): selection block size for topk sparse attention.
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], query cumulative lengths.
        cu_seqlens_k (torch.Tensor): not used in this function.
        max_seqlen_q (int): max query length; determines the number of blocks.
        max_seqlen_k (int): not used in this function.
        init_blocks (int, optional): leading blocks forced to +inf. Defaults to 1.
        local_blocks (int, optional): blocks ending at each query's own block that
            are forced to +inf. Defaults to 2.

    Returns:
        torch.Tensor: float32 block scores, shape [num_k_heads, total_query_len,
        ceil(max_seqlen_q / block_size)]. Only the first max_seqlen_q // block_size
        blocks accumulate scores; forced blocks are +inf.
    """
    num_k_heads, total_query_len, _ = score.shape
    # pad the compressed-key axis so every (kernel-offset, block-offset) shift below
    # stays in range
    pad_len = kernel_size // kernel_stride - 1
    score = torch.nn.functional.pad(score, (pad_len, pad_len), value=0)
    max_blocks = math.ceil(max_seqlen_q / block_size)
    full_blocks = max_seqlen_q // block_size
    block_score = torch.zeros(
        num_k_heads,
        total_query_len,
        max_blocks,
        dtype=torch.float32,
        device=score.device,
    )
    # offs maps each shift k into the padded score axis to the number of
    # (kernel-offset, block-offset) pairs that land on it; the block score is the
    # weighted sum of the corresponding strided slices (stride = compressed keys per block)
    offs = (
        torch.arange(kernel_size // kernel_stride)[:, None]
        + torch.arange(block_size // kernel_stride)[None, :]
    ).view(-1)
    offs = dict(Counter(offs.tolist()))
    for k, v in offs.items():
        block_score[..., :full_blocks] += (
            v * score[..., k :: block_size // kernel_stride][..., :full_blocks]
        )
    # set init block and local block score
    batch_size = cu_seqlens_q.shape[0] - 1
    # position of every query inside its own sequence, then its block index
    q_idx = torch.cat(
        [
            torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=score.device)
            for i in range(batch_size)
        ],
        dim=0,
    )
    q_idx = q_idx // block_size
    b_idx = torch.arange(max_blocks, device=score.device)
    block_score[..., :init_blocks] = torch.inf
    # local blocks: b_idx in (q_idx - local_blocks, q_idx]
    local_mask = (q_idx[:, None] >= b_idx[None, :]) & (
        q_idx[:, None] < b_idx[None, :] + local_blocks
    )
    local_mask = local_mask.unsqueeze(0).expand(num_k_heads, -1, -1)
    block_score[local_mask] = torch.inf
    return block_score


def compressed_attention_torch(
    q: torch.Tensor,  # [total_query_len, num_q_heads, head_dim]
    k: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
    v: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]
    kernel_size: int,
    kernel_stride: int,
    block_size: int,
    topk: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float = None,
    init_blocks: int = 1,
    local_blocks: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Attention between query and compressed key and value. Implemented with torch, only for debug.

    A query at position t of its sequence may attend compressed key j only when
    t >= j * kernel_stride + kernel_size - 1, i.e. after the whole window of key j
    has been seen. Queries with no visible compressed key get a zero output.

    Args:
        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]
        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]
        kernel_size (int): kernel size in compress_key_value
        kernel_stride (int): stride of compress_key_value
        block_size (int): key value block size for topk sparse attention.
        topk (int): number of blocks for each query (clipped to the number of blocks).
        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.
        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_varlen_func.
        max_seqlen_q (int): max q len of the batch.
        max_seqlen_k (int): max k len of the batch.
        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
        init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
        local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: attention output of shape
        [total_q_len, num_q_heads, head_dim], and int32 topk_idx of shape
        [num_kv_heads, total_q_len, topk] used in topk_sparse_attention. Indices are
        sorted ascending; entries pointing past the query's own block are -1.
    """
    assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0
    total_query_len, num_q_heads, head_dim = q.shape
    total_key_len, num_k_heads, _ = k.shape
    num_share_q_heads = num_q_heads // num_k_heads
    batch_size = cu_seqlens_q.shape[0] - 1
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim)
    # get mask
    mask = torch.zeros(
        (total_query_len, total_key_len), dtype=torch.bool, device=q.device
    )
    for b in range(batch_size):
        q_len, k_len = (
            cu_seqlens_q[b + 1] - cu_seqlens_q[b],
            cu_seqlens_k[b + 1] - cu_seqlens_k[b],
        )
        # last original token covered by each compressed key
        k_max_ids = (
            torch.arange(k_len, device=q.device) * kernel_stride + kernel_size - 1
        )
        q_ids = torch.arange(q_len, device=q.device)
        mask[
            cu_seqlens_q[b] : cu_seqlens_q[b + 1], cu_seqlens_k[b] : cu_seqlens_k[b + 1]
        ] = (q_ids[:, None] >= k_max_ids[None, :])
    # attention
    qk = (
        torch.einsum("qhd,khd->hqk", q, k.repeat_interleave(num_share_q_heads, 1))
        * sm_scale
    )
    qk = qk.masked_fill_(~mask[None, ...], -torch.inf)
    # query from beginning of the sequence can't attend to any compressed key
    # (fully masked rows softmax to NaN, which is zeroed below)
    qk = qk.softmax(dim=-1, dtype=torch.float32)
    qk = qk.nan_to_num(0)
    attn_output = torch.einsum(
        "hqk,khd->qhd", qk.to(v.dtype), v.repeat_interleave(num_share_q_heads, 1)
    )
    with torch.no_grad():
        # sum (not mean) of probabilities over the query heads sharing each kv head;
        # the top-k ranking is the same either way.
        # After the sum, qk has shape [num_k_heads, total_q_len, total_k_len];
        # score re-indexes keys per sequence: [num_k_heads, total_q_len, max_seqlen_k]
        score = torch.zeros(
            num_k_heads,
            cu_seqlens_q[-1],
            max_seqlen_k,
            dtype=torch.float32,
            device=q.device,
        )
        qk = rearrange(qk, "(h g) q k -> h g q k", h=num_k_heads).sum(1)
        for b in range(batch_size):
            score[
                :,
                cu_seqlens_q[b] : cu_seqlens_q[b + 1],
                : cu_seqlens_k[b + 1] - cu_seqlens_k[b],
            ] = qk[
                :,
                cu_seqlens_q[b] : cu_seqlens_q[b + 1],
                cu_seqlens_k[b] : cu_seqlens_k[b + 1],
            ]
        # transform score to block-wise score
        score = transform_score(
            score,
            kernel_size,
            kernel_stride,
            block_size,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            init_blocks,
            local_blocks,
        )
        # get topk
        batch_size = cu_seqlens_q.shape[0] - 1
        q_idx = torch.cat(
            [
                torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device)
                for i in range(batch_size)
            ],
            dim=0,
        )
        q_idx = q_idx // block_size
        topk = min(topk, score.shape[-1])
        topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values
        # blocks after the query's own block are not causal: mark them as padding
        topk_idx[topk_idx > q_idx[None, :, None]] = -1
        topk_idx = topk_idx.to(torch.int32)
    return attn_output, topk_idx

================================================
FILE: native_sparse_attention/ops/torch/compressed_attention_decode.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch import math from typing import Tuple, Optional from collections import Counter from einops import rearrange def transform_score( score: torch.Tensor, seqlens: torch.Tensor, kernel_size: int, kernel_stride: int, block_size: int, init_blocks: int = 1, local_blocks: int = 2, ) -> torch.Tensor: num_k_heads, batch_size, kv_len = score.shape pad_len = kernel_size // kernel_stride - 1 score = torch.nn.functional.pad(score, (pad_len, pad_len), value=0) max_seqlen = seqlens.max().item() max_blocks = math.ceil(max_seqlen / block_size) full_blocks = max_seqlen // block_size block_score = torch.zeros( num_k_heads, batch_size, max_blocks, dtype=torch.float32, device=score.device, ) offs = ( torch.arange(kernel_size // kernel_stride)[:, None] + torch.arange(block_size // kernel_stride)[None, :] ).view(-1) offs = dict(Counter(offs.tolist())) for k, v in offs.items(): block_score[..., :full_blocks] += ( v * score[..., k :: block_size // kernel_stride][..., :full_blocks] ) # set init block and local block score q_idx = (seqlens - 1) // block_size b_idx = torch.arange(max_blocks, device=score.device) block_score[..., :init_blocks] = torch.inf local_mask = (q_idx[:, None] >= b_idx[None, :]) & ( q_idx[:, None] < b_idx[None, :] + local_blocks ) local_mask = local_mask.unsqueeze(0).expand(num_k_heads, -1, -1) block_score[local_mask] = torch.inf block_score = block_score.nan_to_num(0, torch.inf, -torch.inf) return block_score def compressed_attention_decode( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor, compress_seqlens: torch.Tensor, kernel_size: int, kernel_stride: int, block_size: int, topk: int, init_blocks: int = 1, local_blocks: int = 2, sm_scale: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """_summary_ Args: q (torch.Tensor): shape [batch_size, num_q_heads, head_dim] k (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim] v (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim] seqlens 
(torch.Tensor): original kv length for each sequence compress_seqlens (torch.Tensor): kv length for each sequence after compression kernel_size (int): kernel size in compress_key_value kernel_stride (int): stride of compress_key_value block_size (int): key value block size for topk sparse attention. topk (int): number of blocks for each query. init_blocks (int, optional): Number of init blocks for each query. Defaults to 1. local_blocks (int, optional): Number of local blocks for each query. Defaults to 2. sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). Returns: Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention_decode """ assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0 batch_size, num_q_heads, head_dim = q.shape batch_size, kv_len, num_k_heads, _ = k.shape num_share_q_heads = num_q_heads // num_k_heads if sm_scale is None: sm_scale = 1.0 / math.sqrt(head_dim) # input is too short to have a valid block if kv_len == 0: return torch.zeros_like(q), torch.zeros( num_k_heads, batch_size, 1, device=q.device, dtype=torch.int32 ) # get mask mask = ( compress_seqlens[:, None] > torch.arange( kv_len, device=compress_seqlens.device, dtype=compress_seqlens.dtype )[None, :] ) # attention qk = ( torch.einsum( "bihgd, bjhgd -> bhgij", rearrange(q, "b (h g) d -> b 1 h g d", g=num_share_q_heads), rearrange(k, "b j h d -> b j h 1 d"), ) * sm_scale ) qk = qk.masked_fill_(~mask[:, None, None, None, :], -torch.inf) qk = qk.softmax(dim=-1, dtype=torch.float32) qk = qk.nan_to_num_(0) # qk is nan when seqlen == 0 attn_output = torch.einsum( "bhgij, bjhgd -> bihgd", qk.to(v.dtype), rearrange(v, "b k h d -> b k h 1 d"), ) attn_output = rearrange(attn_output, "b 1 h g d -> b (h g) d") # get score score = rearrange(qk.sum(2).squeeze(2), "b h j -> h b j") # transform score to block-wise score score = transform_score( score, seqlens, kernel_size, kernel_stride, block_size, init_blocks, 
local_blocks, ) # get topk q_idx = (seqlens - 1) // block_size topk = min(topk, score.shape[-1]) topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values topk_idx[topk_idx > q_idx[None, :, None]] = -1 topk_idx = topk_idx.to(torch.int32) return attn_output, topk_idx ================================================ FILE: native_sparse_attention/ops/torch/topk_sparse_attention.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import math from typing import Optional def topk_sparse_attention_torch( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, topk_idx: torch.Tensor, block_size_k: int, cu_seqlens: torch.Tensor, softmax_scale: Optional[float] = None, block_size_q: int = 1, ) -> torch.Tensor: """Simple topk sparse attention varlen version implemented in torch. Extremly slow, only for debugging. Args: q (torch.Tensor): shape [total_len, num_q_heads, head_dim] k (torch.Tensor): shape [total_len, num_kv_heads, head_dim] v (torch.Tensor): shape [total_len, num_kv_heads, head_dim] topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding. block_size_q (int): query block size. block_size_k (int): key value block size. cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen. softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). 
Returns: torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim] """ total_seqlen, num_q_heads, head_dim = q.shape total_seqlen, num_kv_heads, head_dim = k.shape num_share_q_heads = num_q_heads // num_kv_heads batch_size = cu_seqlens.shape[0] - 1 topk = topk_idx.shape[-1] seqlens = cu_seqlens[1:] - cu_seqlens[:-1] seqblocks_q = torch.ceil(seqlens / block_size_q).to(torch.int32) cu_seqblocks_q = torch.nn.functional.pad(seqblocks_q.cumsum(0), (1, 0), value=0) if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) # get mask mask = torch.zeros( (num_kv_heads, total_seqlen, total_seqlen), dtype=torch.bool, device=q.device ) for i in range(batch_size): num_q_blocks = math.ceil(seqlens[i] / block_size_q) num_kv_blocks = math.ceil(seqlens[i] / block_size_k) for h in range(num_kv_heads): temp_mask = torch.zeros( num_q_blocks, num_kv_blocks, dtype=torch.bool, device=q.device ) temp_idx = topk_idx[h, cu_seqblocks_q[i] : cu_seqblocks_q[i + 1]].clone() temp_idx[temp_idx < 0] = 0 temp_mask[torch.arange(num_q_blocks).to(q.device)[:, None], temp_idx] = True temp_mask = torch.repeat_interleave(temp_mask, block_size_q, dim=0) temp_mask = torch.repeat_interleave(temp_mask, block_size_k, dim=1) temp_mask = temp_mask[: seqlens[i], : seqlens[i]] mask[ h, cu_seqlens[i] : cu_seqlens[i + 1], cu_seqlens[i] : cu_seqlens[i + 1] ] = temp_mask mask = torch.tril(mask).repeat_interleave(num_share_q_heads, 0) # qk attn qk = ( torch.einsum("qhd,khd->hqk", q, k.repeat_interleave(num_share_q_heads, 1)) * softmax_scale ) qk = torch.masked_fill(qk, ~mask, -torch.inf) qk = torch.softmax(qk, dim=-1, dtype=torch.float32).to(q.dtype) o = torch.einsum("hqk,khd->qhd", qk, v.repeat_interleave(num_share_q_heads, 1)) return o ================================================ FILE: native_sparse_attention/ops/triton/__init__.py ================================================ ================================================ FILE: 
native_sparse_attention/ops/triton/compressed_attention.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import Any, Tuple, Union from collections import Counter import torch import triton import triton.language as tl import warnings from native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu IS_HOPPER_GPU = is_hopper_gpu() @triton.jit def forward_kernel( q_ptr, # Q: n x h x d k_ptr, # K: n x h x d v_ptr, # V: n x h x d o_ptr, # O: n x h x d lse_ptr, # LSE: h x n # size and stride at compresstion kernel_size, kernel_stride, # seqlens cu_seqlens_q, cu_seqlens_k, # shape NUM_KV_HEADS, NUM_SHARE_Q_HEADS, HEAD_DIM, # sm_scale sm_scale, # stride stride_qn, stride_qh, stride_qd, stride_kn, stride_kh, stride_kd, stride_vn, stride_vh, stride_vd, stride_on, stride_oh, stride_od, stride_lh, stride_ln, # META parameters BLOCK_SIZE_Q: tl.constexpr, # q block size BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_D: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_b = tl.program_id(0) pid_h = tl.program_id(1) pid_q = tl.program_id(2) pid_kh = pid_h // NUM_SHARE_Q_HEADS # get q k start and len after rmpad q_start = tl.load(cu_seqlens_q + pid_b) q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = tl.load(cu_seqlens_k + pid_b) k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start # skip first kernel_size 
query block, because they do no attend to any keys q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1 if q_start_in_seq >= q_len: return # init qkv pointer q_ptrs = tl.make_block_ptr( base=q_ptr + q_start * stride_qn + pid_h * stride_qh, shape=(q_len, HEAD_DIM), strides=(stride_qn, stride_qd), offsets=(q_start_in_seq, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), order=(1, 0), ) k_ptrs = tl.make_block_ptr( base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, shape=(HEAD_DIM, k_len), strides=(stride_kd, stride_kn), offsets=(0, 0), block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), order=(0, 1), ) v_ptrs = tl.make_block_ptr( base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, shape=(k_len, HEAD_DIM), strides=(stride_vn, stride_vd), offsets=(0, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) # load q q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") # init statistics off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32) # attention lo = 0 hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1) for i in range(lo, hi, BLOCK_SIZE_K): i = tl.multiple_of(i, BLOCK_SIZE_K) # load k k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero") # compute qk qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) qk += tl.where( off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf") ) qk += tl.dot(q, k) * qk_scale # compute m_ij and l_ij m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) p = tl.exp2(qk - m_ij[:, None]) l_ij = tl.sum(p, axis=1) # scale acc_o acc_o_scale = tl.exp2(m_i - m_ij) acc_o = acc_o * acc_o_scale[:, None] # load v and update acc_o v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") p = p.to(v.dtype) acc_o 
+= tl.dot(p, v) # update statistics m_i = m_ij lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) # update ptrs k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K)) v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) # final scale acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] # save output o_ptrs = tl.make_block_ptr( base=o_ptr + q_start * stride_on + pid_h * stride_oh, shape=(q_len, HEAD_DIM), strides=(stride_on, stride_od), offsets=(q_start_in_seq, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), order=(1, 0), ) tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) # save lse l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln tl.store(l_ptrs, lse_i, mask=off_q < q_len) @triton.jit def backward_sum_o_do( o_ptr, # O: n x h x d do_ptr, # dO: n x h x d delta_ptr, # D: h x n o_len, HEAD_DIM, stride_on, stride_oh, stride_od, stride_don, stride_doh, stride_dod, stride_dh, stride_dn, BLOCK_SIZE_O: tl.constexpr, BLOCK_SIZE_D: tl.constexpr, ): pid_n = tl.program_id(0) pid_h = tl.program_id(1) off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) off_d = tl.arange(0, BLOCK_SIZE_D) o = tl.load( o_ptr + off_n[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), other=0, ).to(tl.float32) do = tl.load( do_ptr + off_n[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), other=0, ).to(tl.float32) delta = tl.sum(o * do, axis=1) tl.store( delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len ) @triton.jit def backward_dkdv( q_ptr, # Q: n x qh x d k_ptr, # K: n x kh x d v_ptr, # V: n x kh x d lse_ptr, # LSE: qh x n d_ptr, # Delta: qh x n do_ptr, dk_ptr, # DK: sh x n x kh x d dv_ptr, # DV: sh x n x kh x d kernel_size, kernel_stride, # seqlens cu_seqlens_q, cu_seqlens_k, # shape NUM_KV_HEADS, NUM_SHARE_Q_HEADS, HEAD_DIM, # sm_scale sm_scale, # stride 
stride_qn, stride_qh, stride_qd, stride_kn, stride_kh, stride_kd, stride_vn, stride_vh, stride_vd, stride_lh, stride_ln, stride_dh, stride_dn, stride_don, stride_doh, stride_dod, stride_dks, stride_dkn, stride_dkh, stride_dkd, stride_dvs, stride_dvn, stride_dvh, stride_dvd, # META parameters BLOCK_SIZE_Q: tl.constexpr, # q block size BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_D: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_b = tl.program_id(0) pid_h = tl.program_id(1) pid_kh = pid_h // NUM_SHARE_Q_HEADS pid_sh = pid_h % NUM_SHARE_Q_HEADS pid_k = tl.program_id(2) # get q k start and len after rmpad q_start = tl.load(cu_seqlens_q + pid_b) q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = tl.load(cu_seqlens_k + pid_b) k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start if BLOCK_SIZE_K * pid_k >= k_len: return # init pointers k_ptrs = tl.make_block_ptr( base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, shape=(k_len, HEAD_DIM), strides=(stride_kn, stride_kd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) dk_ptrs = tl.make_block_ptr( base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, shape=(k_len, HEAD_DIM), strides=(stride_dkn, stride_dkd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) v_ptrs = tl.make_block_ptr( base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, shape=(k_len, HEAD_DIM), strides=(stride_vn, stride_vd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) dv_ptrs = tl.make_block_ptr( base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, shape=(k_len, HEAD_DIM), strides=(stride_dvn, stride_dvd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) # offsets off_q = tl.arange(0, BLOCK_SIZE_Q) off_k = ( pid_k * BLOCK_SIZE_K * kernel_stride + tl.arange(0, 
BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 ) # load k v and keep in SRAM k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") # init dk dv dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1 q_ptrs = tl.make_block_ptr( base=q_ptr + q_start * stride_qn + pid_h * stride_qh, shape=(HEAD_DIM, q_len), strides=(stride_qd, stride_qn), offsets=(0, q_lo), block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q), order=(0, 1), ) do_ptrs = tl.make_block_ptr( base=do_ptr + q_start * stride_don + pid_h * stride_doh, shape=(HEAD_DIM, q_len), strides=(stride_dod, stride_don), offsets=(0, q_lo), block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q), order=(0, 1), ) d_ptrs = tl.make_block_ptr( base=d_ptr + q_start * stride_dn + pid_h * stride_dh, shape=(1, q_len), strides=(0, stride_dn), offsets=(0, q_lo), block_shape=(1, BLOCK_SIZE_Q), order=(1, 0), ) lse_ptrs = tl.make_block_ptr( base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, shape=(1, q_len), strides=(0, stride_ln), offsets=(0, q_lo), block_shape=(1, BLOCK_SIZE_Q), order=(0, 1), ) # loop for q blocks for i in range(q_lo, q_len, BLOCK_SIZE_Q): # load q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") # compute qk # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf")) qk += tl.dot(k, q) * qk_scale # compute p, ds # [BLOCK_SIZE_K, BLOCK_SIE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] p = tl.exp2(qk - lse) # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] dp = tl.dot(v, 
do) ds = sm_scale * p * (dp - d) # cast dtype p = p.to(do.dtype) ds = ds.to(q.dtype) # update dk and dv # [BLOCK_SIZE_K, BLOCK_SIE_Q] @ [BLOCK_SIE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM] dk += tl.dot(ds, tl.trans(q)) dv += tl.dot(p, tl.trans(do)) # increment pointers q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q)) do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q)) lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q)) d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q)) # save dk dv tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) @triton.jit def backward_dq( q_ptr, # Q: n x qh x d k_ptr, # K: n x kh x d v_ptr, # V: n x kh x d lse_ptr, # LSE: qh x n d_ptr, # Delta: qh x n do_ptr, dq_ptr, kernel_size, kernel_stride, # seqlens cu_seqlens_q, cu_seqlens_k, # shape NUM_KV_HEADS, NUM_SHARE_Q_HEADS, HEAD_DIM, # sm_scale sm_scale, # stride stride_qn, stride_qh, stride_qd, stride_kn, stride_kh, stride_kd, stride_vn, stride_vh, stride_vd, stride_lh, stride_ln, stride_dh, stride_dn, stride_don, stride_doh, stride_dod, stride_dqn, stride_dqh, stride_dqd, # META parameters BLOCK_SIZE_Q: tl.constexpr, # q block size BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_D: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_b = tl.program_id(0) pid_h = tl.program_id(1) pid_q = tl.program_id(2) pid_kh = pid_h // NUM_SHARE_Q_HEADS # get q k start and len after rmpad q_start = tl.load(cu_seqlens_q + pid_b) q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = tl.load(cu_seqlens_k + pid_b) k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start # skip first kernel_size query block, because they do no attend to any keys q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1 if q_start_in_seq >= q_len: return # init pointers q_ptrs = tl.make_block_ptr( base=q_ptr + q_start * stride_qn + pid_h * stride_qh, shape=(q_len, HEAD_DIM), strides=(stride_qn, stride_qd), 
offsets=(q_start_in_seq, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), order=(1, 0), ) dq_ptrs = tl.make_block_ptr( base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh, shape=(q_len, HEAD_DIM), strides=(stride_dqn, stride_dqd), offsets=(q_start_in_seq, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), order=(1, 0), ) k_ptrs = tl.make_block_ptr( base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, shape=(k_len, HEAD_DIM), strides=(stride_kn, stride_kd), offsets=(0, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) v_ptrs = tl.make_block_ptr( base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, shape=(HEAD_DIM, k_len), strides=(stride_vd, stride_vn), offsets=(0, 0), block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), order=(0, 1), ) do_ptrs = tl.make_block_ptr( base=do_ptr + q_start * stride_don + pid_h * stride_doh, shape=(q_len, HEAD_DIM), strides=(stride_don, stride_dod), offsets=(q_start_in_seq, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), order=(1, 0), ) d_ptrs = tl.make_block_ptr( base=d_ptr + q_start * stride_dn + pid_h * stride_dh, shape=(q_len, 1), strides=(stride_dn, stride_dh), offsets=(q_start_in_seq, 0), block_shape=(BLOCK_SIZE_Q, 1), order=(0, 1), ) lse_ptrs = tl.make_block_ptr( base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, shape=(q_len, 1), strides=(stride_ln, stride_lh), offsets=(q_start_in_seq, 0), block_shape=(BLOCK_SIZE_Q, 1), order=(0, 1), ) # offsets off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 # load q, do, lse, delta, and keep in SRAM q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") # init dq dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32) lo = 0 hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // 
kernel_stride + 1) for i in range(lo, hi, BLOCK_SIZE_K): # load k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") # compute qk qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) qk += tl.where( off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf") ) qk += tl.dot(q, tl.trans(k)) * qk_scale # compute p, ds p = tl.exp2(qk - lse) dp = tl.dot(do, v) ds = sm_scale * p * (dp - d) # cast dtype ds = ds.to(q.dtype) # update dq dq += tl.dot(ds, k) # increment pointers k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0)) v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K)) # save dq tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) def _compressed_attention_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, kernel_size: int, kernel_stride: int, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, sm_scale: float, ): # dtype check assert k.dtype == q.dtype and v.dtype == q.dtype assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 # shape q_len, num_q_heads, head_dim = q.shape k_len, num_k_heads, head_dim = k.shape v_len, num_v_heads, head_dim = v.shape batch_size = cu_seqlens_q.shape[0] - 1 assert k_len == v_len and q_len > k_len # gqa assert num_k_heads == num_v_heads assert num_q_heads % num_k_heads == 0 num_share_q_heads = num_q_heads // num_k_heads # output tensor o = torch.zeros_like(q) lse = torch.full( (num_q_heads, q_len), fill_value=-torch.inf, dtype=torch.float32, device=q.device, ) # launch kernel grid = lambda META: ( batch_size, num_q_heads, triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), ) BLOCK_SIZE_Q = 128 BLOCK_SIZE_K = 128 BLOCK_SIZE_D = triton.next_power_of_2(head_dim) num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) forward_kernel[grid]( q, k, v, o, lse, kernel_size, kernel_stride, cu_seqlens_q, cu_seqlens_k, num_k_heads, 
num_share_q_heads, head_dim, sm_scale, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), o.stride(0), o.stride(1), o.stride(2), lse.stride(0), lse.stride(1), BLOCK_SIZE_Q=BLOCK_SIZE_Q, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_D=BLOCK_SIZE_D, num_warps=num_warps, num_stages=num_stages, ) return o, lse def _compressed_attention_bwd( o: torch.Tensor, do: torch.Tensor, lse: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, kernel_size: int, kernel_stride: int, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, sm_scale: float, ): q_len, num_q_heads, head_dim = q.shape k_len, num_k_heads, head_dim = k.shape v_len, num_v_heads, head_dim = v.shape o_len, num_o_heads, head_dim = o.shape num_share_q_heads = num_q_heads // num_k_heads # compute D delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads) BLOCK_SIZE_O = 256 BLOCK_SIZE_D = triton.next_power_of_2(head_dim) num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) backward_sum_o_do[grid]( o, do, delta, o_len, head_dim, o.stride(0), o.stride(1), o.stride(2), do.stride(0), do.stride(1), do.stride(2), delta.stride(0), delta.stride(1), BLOCK_SIZE_O=BLOCK_SIZE_O, BLOCK_SIZE_D=BLOCK_SIZE_D, num_warps=num_warps, num_stages=num_stages, ) # compute dk dv dk = torch.zeros( num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype ) dv = torch.zeros( num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype ) batch_size = cu_seqlens_q.shape[0] - 1 grid = lambda META: ( batch_size, num_q_heads, triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), ) BLOCK_SIZE_Q = 64 BLOCK_SIZE_K = 128 BLOCK_SIZE_D = triton.next_power_of_2(head_dim) num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) backward_dkdv[grid]( q, k, v, 
lse, delta, do, dk, dv, kernel_size, kernel_stride, cu_seqlens_q, cu_seqlens_k, num_k_heads, num_share_q_heads, head_dim, sm_scale, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), lse.stride(0), lse.stride(1), delta.stride(0), delta.stride(1), do.stride(0), do.stride(1), do.stride(2), dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3), dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), BLOCK_SIZE_Q=BLOCK_SIZE_Q, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_D=BLOCK_SIZE_D, num_warps=num_warps, num_stages=num_stages, ) dk = dk.sum(0) dv = dv.sum(0) # compute dq dq = torch.zeros_like(q) grid = lambda META: ( batch_size, num_q_heads, triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), ) BLOCK_SIZE_Q = 128 BLOCK_SIZE_K = 64 num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) backward_dq[grid]( q, k, v, lse, delta, do, dq, kernel_size, kernel_stride, cu_seqlens_q, cu_seqlens_k, num_k_heads, num_share_q_heads, head_dim, sm_scale, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), lse.stride(0), lse.stride(1), delta.stride(0), delta.stride(1), do.stride(0), do.stride(1), do.stride(2), dq.stride(0), dq.stride(1), dq.stride(2), BLOCK_SIZE_Q=BLOCK_SIZE_Q, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_D=BLOCK_SIZE_D, num_warps=num_warps, num_stages=num_stages, ) return dq, dk, dv class CompressedAttention(torch.autograd.Function): @staticmethod def forward( ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, kernel_size: int, kernel_stride: int, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, sm_scale=None, ): # dtype check assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 assert q.dtype == k.dtype and k.dtype == v.dtype assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 # softmax scale if sm_scale is None: sm_scale = 1 / 
math.sqrt(q.shape[-1]) o, lse = _compressed_attention_fwd( q, k, v, kernel_size, kernel_stride, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale, ) ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k) ctx.sm_scale = sm_scale ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k ctx.kernel_size = kernel_size ctx.kernel_stride = kernel_stride return o, lse @staticmethod def backward(ctx, do: torch.Tensor, *args) -> Any: q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors max_seqlen_q = ctx.max_seqlen_q max_seqlen_k = ctx.max_seqlen_k sm_scale = ctx.sm_scale kernel_size = ctx.kernel_size kernel_stride = ctx.kernel_stride dq, dk, dv = _compressed_attention_bwd( o, do, lse, q, k, v, kernel_size, kernel_stride, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale, ) return dq, dk, dv, None, None, None, None, None, None, None @triton.jit def score_kernel( q_ptr, k_ptr, lse_ptr, s_ptr, kernel_size, kernel_stride, # seqlens cu_seqlens_q, cu_seqlens_k, # shape NUM_KV_HEADS, NUM_SHARE_Q_HEADS, HEAD_DIM, # sm_scale sm_scale, # stride stride_qn, stride_qh, stride_qd, stride_kn, stride_kh, stride_kd, stride_lh, stride_ln, stride_sh, stride_sq, stride_sk, # META parameters BLOCK_SIZE_Q: tl.constexpr, # q block size BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_D: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_bkh = tl.program_id(0) pid_b = pid_bkh // NUM_KV_HEADS pid_kh = pid_bkh % NUM_KV_HEADS pid_q = tl.program_id(1) pid_k = tl.program_id(2) # get q k start and len after rmpad q_start = tl.load(cu_seqlens_q + pid_b) q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = tl.load(cu_seqlens_k + pid_b) k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len: return # init k pointer and load k k_ptrs = tl.make_block_ptr( base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, shape=(HEAD_DIM, k_len), 
strides=(stride_kd, stride_kn), offsets=(0, pid_k * BLOCK_SIZE_K), block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), order=(0, 1), ) k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") # offsets off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :] # init score s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) # loop over gqa heads for h in range(NUM_SHARE_Q_HEADS): pid_h = pid_kh * NUM_SHARE_Q_HEADS + h q_ptrs = tl.make_block_ptr( base=q_ptr + q_start * stride_qn + pid_h * stride_qh, shape=(q_len, HEAD_DIM), strides=(stride_qn, stride_qd), offsets=(pid_q * BLOCK_SIZE_Q, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), order=(1, 0), ) lse_ptrs = tl.make_block_ptr( base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, shape=(q_len, 1), strides=(stride_ln, stride_lh), offsets=(pid_q * BLOCK_SIZE_Q, 0), block_shape=(BLOCK_SIZE_Q, 1), order=(0, 1), ) # load q and lse q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") # compute qk qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) qk += tl.dot(q, k) * qk_scale # compute score s += tl.where(causal_mask, tl.exp2(qk - lse), 0) # save output s_ptrs = tl.make_block_ptr( base=s_ptr + pid_kh * stride_sh + q_start * stride_sq, shape=(q_len, k_len), strides=(stride_sq, stride_sk), offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K), order=(1, 0), ) tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1)) def _get_attention_score( q: torch.Tensor, # [total_query_len, num_q_heads, head_dim] k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] lse: torch.Tensor, # [num_q_heads, total_query_len] kernel_size: int, kernel_stride: int, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: 
int, sm_scale: float, ) -> torch.Tensor: # dtype check assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 assert q.dtype == k.dtype assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 assert ( lse.dtype == torch.float32 ) # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale))) # shape q_len, num_q_heads, head_dim = q.shape k_len, num_k_heads, head_dim = k.shape batch_size = cu_seqlens_q.shape[0] - 1 assert q_len > k_len if sm_scale is None: sm_scale = 1 / math.sqrt(head_dim) # gqa assert num_q_heads % num_k_heads == 0 num_share_q_heads = num_q_heads // num_k_heads # init score score = torch.zeros( num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device ) # launch kernel grid = lambda META: ( batch_size * num_k_heads, triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), ) BLOCK_SIZE_Q = 128 BLOCK_SIZE_K = 128 BLOCK_SIZE_D = triton.next_power_of_2(head_dim) score_kernel[grid]( q, k, lse, score, kernel_size, kernel_stride, cu_seqlens_q, cu_seqlens_k, num_k_heads, num_share_q_heads, head_dim, sm_scale, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), lse.stride(0), lse.stride(1), score.stride(0), score.stride(1), score.stride(2), BLOCK_SIZE_Q=BLOCK_SIZE_Q, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_D=BLOCK_SIZE_D, num_warps=8, num_stages=3, ) return score @triton.jit def _transform_score_kernel( s_ptr, # score, shape: [num_heads, q_len, k_len] bs_ptr, # block wise score: [num_heads, q_len, num_k_block] offs, cu_seqlens_q, # shape num_heads, num_offs, max_k_len, max_blocks, pad_len, # kernel & block size block_size, block_stride, # block_size // kernel_stride init_blocks, local_blocks, # stride stride_sh, stride_sq, stride_sk, stride_bsh, stride_bsq, stride_bsk, BLOCK_SIZE_Q: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_O: tl.constexpr, ): pid_bh = tl.program_id(0) pid_b = pid_bh // num_heads pid_h = pid_bh % num_heads pid_q = 
tl.program_id(1) pid_k = tl.program_id(2) q_start = tl.load(cu_seqlens_q + pid_b) q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = pid_k * BLOCK_SIZE_K if pid_q * BLOCK_SIZE_Q >= q_len: return # load weight off_o = tl.arange(0, BLOCK_SIZE_O) w = tl.load(offs + off_o, mask=off_o < num_offs, other=0) # load score off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len off_k = off_k[None, :] + off_o[:, None] s_ptrs = ( s_ptr + q_start * stride_sq + pid_h * stride_sh + off_q[:, None, None] * stride_sq + off_k[None, :, :] * stride_sk ) # weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK] s = tl.load( s_ptrs, mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len), other=0, ) s = s * w[None, :, None] s = tl.sum(s, axis=1) # init mask and local mask off_bq = off_q // block_size off_bk = k_start + tl.arange(0, BLOCK_SIZE_K) s = tl.where( ( (off_bq[:, None] >= off_bk[None, :]) # causal mask & (off_bq[:, None] < off_bk[None, :] + local_blocks) # local window ) | (off_bk[None, :] < init_blocks), # init window float("inf"), s, ) # store block wise score bs_ptrs = ( bs_ptr + q_start * stride_bsq + pid_h * stride_bsh + off_q[:, None] * stride_bsq + off_bk[None, :] * stride_bsk ) tl.store( bs_ptrs, s, mask=(off_q < q_len)[:, None] & (off_bk < max_blocks)[None, :], ) def transform_score( score: torch.Tensor, kernel_size: int, kernel_stride: int, block_size: int, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, init_blocks: int = 1, local_blocks: int = 2, ) -> torch.Tensor: num_k_heads, total_query_len, max_key_len = score.shape batch_size = cu_seqlens_q.shape[0] - 1 pad_len = kernel_size // kernel_stride - 1 max_blocks = math.ceil(max_seqlen_q / block_size) block_score = torch.zeros( num_k_heads, total_query_len, max_blocks, dtype=torch.float32, device=score.device, ) offs = ( torch.arange(kernel_size // 
kernel_stride, device=score.device)[:, None] + torch.arange(block_size // kernel_stride, device=score.device)[None, :] ).view(-1) offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max()) num_offs = int(offs.shape[0]) BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks)) BLOCK_SIZE_O = triton.next_power_of_2(num_offs) BLOCK_SIZE_Q = 8 grid = ( num_k_heads * batch_size, triton.cdiv(total_query_len, BLOCK_SIZE_Q), triton.cdiv(max_blocks, BLOCK_SIZE_K), ) _transform_score_kernel[grid]( score, block_score, offs, cu_seqlens_q, num_k_heads, offs.shape[0], max_key_len, max_blocks, pad_len, block_size, block_size // kernel_stride, init_blocks, local_blocks, score.stride(0), score.stride(1), score.stride(2), block_score.stride(0), block_score.stride(1), block_score.stride(2), BLOCK_SIZE_Q=BLOCK_SIZE_Q, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_O=BLOCK_SIZE_O, num_warps=8, num_stages=3, ) return block_score def compressed_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, kernel_size: int, kernel_stride: int, block_size: int, topk: int, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int = None, max_seqlen_k: int = None, sm_scale: float = None, init_blocks: int = 1, local_blocks: int = 2, parallel_topk_compute: Union[str, bool] = "auto", ) -> Tuple[torch.Tensor, torch.Tensor]: """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention. Args: q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim] k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] kernel_size (int): kernel size in compress_key_value kernel_stride (int): stride of compress_key_value block_size (int): key value block size for topk sparse attention. topk (int): number of blocks for each query. cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. 
cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen. max_seqlen_q (int): max q len of the batch. Defaults to None, means (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item(). max_seqlen_k (int): max k len of the batch. Defaults to None, means (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item(). sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). init_blocks (int, optional): Number of init blocks for each query. Defaults to 1. local_blocks (int, optional): Number of local blocks for each query. Defaults to 2. parallel_topk_compute (str, optional): Only set it to False when the sequence length is too long. This can avoid a current bug. We'll fix this issue later. Defaults to auto, it will be set to False when the sequence length is greater than 32k and True otherwise. Returns: Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention """ if max_seqlen_q is None: max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() if max_seqlen_k is None: max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() attn_output, lse = CompressedAttention.apply( q, k, v, kernel_size, kernel_stride, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale, ) # do not select topk index if topk <= 0: warnings.warn("topk <= 0, returned topk_idx will be None") return attn_output, None assert topk >= init_blocks + local_blocks with torch.no_grad(): num_k_heads, num_q_heads = k.shape[1], q.shape[1] num_shared_q_heads = num_q_heads // num_k_heads batch_size = cu_seqlens_q.shape[0] - 1 q_idx = torch.cat( [ torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) for i in range(batch_size) ], dim=0, ) q_idx = q_idx // block_size # whether to use parallel version if parallel_topk_compute == "auto": parallel_topk_compute = cu_seqlens_q[-1] <= 32768 # parallel version if parallel_topk_compute: # recompute score score = _get_attention_score( 
q, k, lse, kernel_size, kernel_stride, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale, ) # transform score to block-wise score score = transform_score( score, kernel_size, kernel_stride, block_size, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, init_blocks, local_blocks, ) # get topk topk = min(topk, score.shape[-1]) topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values topk_idx[topk_idx > q_idx[None, :, None]] = -1 topk_idx = topk_idx.to(torch.int32) # non parallel version, avoid some current bugs when sequence length is too long # FIXME: need to fix later else: topk_idx_list = [] for h in range(num_k_heads): # recompute score score = _get_attention_score( q[:, h * num_shared_q_heads : (h + 1) * num_shared_q_heads], k[:, h : h + 1], lse[h * num_shared_q_heads : (h + 1) * num_shared_q_heads], kernel_size, kernel_stride, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale, ) # transform score to block-wise score score = transform_score( score, kernel_size, kernel_stride, block_size, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, init_blocks, local_blocks, ) # get topk topk = min(topk, score.shape[-1]) topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values topk_idx[topk_idx > q_idx[None, :, None]] = -1 topk_idx = topk_idx.to(torch.int32) topk_idx_list.append(topk_idx) topk_idx = torch.cat(topk_idx_list, dim=0) return attn_output, topk_idx ================================================ FILE: native_sparse_attention/ops/triton/flash_attention.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import Any, Optional import torch import triton import triton.language as tl from native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu IS_HOPPER_GPU = is_hopper_gpu() @triton.jit def forward_kernel( q_ptr, # Q: n x h x d k_ptr, # K: n x h x d v_ptr, # V: n x h x d o_ptr, # O: n x h x d lse_ptr, # LSE: h x n # seqlens cu_seqlens_q, cu_seqlens_k, # shape NUM_KV_HEADS, NUM_SHARE_Q_HEADS, qk_head_dim, v_head_dim, # sm_scale sm_scale, # causal causal, # gqa gqa_interleave, # stride stride_qn, stride_qh, stride_qd, stride_kn, stride_kh, stride_kd, stride_vn, stride_vh, stride_vd, stride_on, stride_oh, stride_od, stride_lh, stride_ln, # META parameters BLOCK_SIZE_Q: tl.constexpr, # q block size BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_KD: tl.constexpr, BLOCK_SIZE_VD: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_b = tl.program_id(0) pid_h = tl.program_id(1) pid_q = tl.program_id(2) if gqa_interleave: pid_kh = pid_h % NUM_KV_HEADS else: pid_kh = pid_h // NUM_SHARE_Q_HEADS # get q k start and len after rmpad q_start = tl.load(cu_seqlens_q + pid_b) q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = tl.load(cu_seqlens_k + pid_b) k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start if BLOCK_SIZE_Q * pid_q >= q_len: return # init qkv pointer q_ptrs = tl.make_block_ptr( base=q_ptr + q_start * stride_qn + pid_h * stride_qh, shape=(q_len, qk_head_dim), strides=(stride_qn, stride_qd), offsets=(pid_q * BLOCK_SIZE_Q, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD), order=(1, 
0), ) k_ptrs = tl.make_block_ptr( base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, shape=(qk_head_dim, k_len), strides=(stride_kd, stride_kn), offsets=(0, 0), block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_K), order=(0, 1), ) v_ptrs = tl.make_block_ptr( base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, shape=(k_len, v_head_dim), strides=(stride_vn, stride_vd), offsets=(0, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD), order=(1, 0), ) # load q q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") # init statistics off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q off_k = tl.arange(0, BLOCK_SIZE_K) m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_VD), 0, dtype=tl.float32) # full attention or causal attention lo = 0 if causal: hi = min(k_len, (pid_q + 1) * BLOCK_SIZE_Q) else: hi = k_len for i in range(lo, hi, BLOCK_SIZE_K): i = tl.multiple_of(i, BLOCK_SIZE_K) # load k k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero") # compute qk qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) if causal: qk += tl.where(off_q[:, None] >= (i + off_k)[None, :], 0, float("-inf")) else: qk += tl.where((off_k < k_len - i)[None, :], 0, float("-inf")) qk += tl.dot(q, k) * qk_scale # compute m_ij and l_ij m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) p = tl.math.exp2(qk - m_ij[:, None]) l_ij = tl.sum(p, axis=1) # scale acc_o acc_o_scale = tl.math.exp2(m_i - m_ij) acc_o = acc_o * acc_o_scale[:, None] # load v and update acc_o v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") p = p.to(v.dtype) acc_o += tl.dot(p, v) # update statistics m_i = m_ij lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij) # update ptrs k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K)) v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) # final scale acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None] # save output o_ptrs = 
tl.make_block_ptr( base=o_ptr + q_start * stride_on + pid_h * stride_oh, shape=(q_len, v_head_dim), strides=(stride_on, stride_od), offsets=(pid_q * BLOCK_SIZE_Q, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD), order=(1, 0), ) tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) # save lse l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln tl.store(l_ptrs, lse_i, mask=off_q < q_len) @triton.jit def backward_sum_o_do( o_ptr, # O: n x h x d do_ptr, # dO: n x h x d delta_ptr, # D: h x n o_len, HEAD_DIM, stride_on, stride_oh, stride_od, stride_don, stride_doh, stride_dod, stride_dh, stride_dn, BLOCK_SIZE_O: tl.constexpr, BLOCK_SIZE_D: tl.constexpr, ): pid_n = tl.program_id(0) pid_h = tl.program_id(1) off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) off_d = tl.arange(0, BLOCK_SIZE_D) o = tl.load( o_ptr + off_n[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), other=0, ).to(tl.float32) do = tl.load( do_ptr + off_n[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), other=0, ).to(tl.float32) delta = tl.sum(o * do, axis=1) tl.store( delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len ) @triton.jit def backward_dkdv( q_ptr, # Q: n x qh x d k_ptr, # K: n x kh x d v_ptr, # V: n x kh x d lse_ptr, # LSE: qh x n d_ptr, # Delta: qh x n do_ptr, dk_ptr, # DK: sh x n x kh x d dv_ptr, # DV: sh x n x kh x d # seqlens cu_seqlens_q, cu_seqlens_k, # shape NUM_KV_HEADS, NUM_SHARE_Q_HEADS, qk_head_dim, v_head_dim, # sm_scale sm_scale, # causal causal, # gqa gqa_interleave, # stride stride_qn, stride_qh, stride_qd, stride_kn, stride_kh, stride_kd, stride_vn, stride_vh, stride_vd, stride_lh, stride_ln, stride_dh, stride_dn, stride_don, stride_doh, stride_dod, stride_dks, stride_dkn, stride_dkh, stride_dkd, stride_dvs, stride_dvn, stride_dvh, 
stride_dvd, # META parameters BLOCK_SIZE_Q: tl.constexpr, # q block size BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_KD: tl.constexpr, BLOCK_SIZE_VD: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_b = tl.program_id(0) pid_h = tl.program_id(1) if gqa_interleave: pid_kh = pid_h % NUM_SHARE_Q_HEADS pid_sh = pid_h // NUM_SHARE_Q_HEADS else: pid_kh = pid_h // NUM_SHARE_Q_HEADS pid_sh = pid_h % NUM_SHARE_Q_HEADS pid_k = tl.program_id(2) # get q k start and len after rmpad q_start = tl.load(cu_seqlens_q + pid_b) q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = tl.load(cu_seqlens_k + pid_b) k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start if BLOCK_SIZE_K * pid_k >= k_len: return # init pointers k_ptrs = tl.make_block_ptr( base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, shape=(k_len, qk_head_dim), strides=(stride_kn, stride_kd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD), order=(1, 0), ) dk_ptrs = tl.make_block_ptr( base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, shape=(k_len, qk_head_dim), strides=(stride_dkn, stride_dkd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD), order=(1, 0), ) v_ptrs = tl.make_block_ptr( base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, shape=(k_len, v_head_dim), strides=(stride_vn, stride_vd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD), order=(1, 0), ) dv_ptrs = tl.make_block_ptr( base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, shape=(k_len, v_head_dim), strides=(stride_dvn, stride_dvd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD), order=(1, 0), ) # offsets off_q = tl.arange(0, BLOCK_SIZE_Q) off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K # load k v and keep in SRAM k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") v = tl.load(v_ptrs, boundary_check=(0, 1), 
padding_option="zero") # init dk dv dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_KD), dtype=tl.float32) dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_VD), dtype=tl.float32) # causal if causal: q_lo = pid_k * BLOCK_SIZE_K else: q_lo = 0 q_ptrs = tl.make_block_ptr( base=q_ptr + q_start * stride_qn + pid_h * stride_qh, shape=(q_len, qk_head_dim), strides=(stride_qn, stride_qd), offsets=(q_lo, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD), order=(1, 0), ) do_ptrs = tl.make_block_ptr( base=do_ptr + q_start * stride_don + pid_h * stride_doh, shape=(q_len, v_head_dim), strides=(stride_don, stride_dod), offsets=(q_lo, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD), order=(1, 0), ) d_ptrs = tl.make_block_ptr( base=d_ptr + q_start * stride_dn + pid_h * stride_dh, shape=(q_len, 1), strides=(stride_dn, stride_dh), offsets=(q_lo, 0), block_shape=(BLOCK_SIZE_Q, 1), order=(0, 1), ) lse_ptrs = tl.make_block_ptr( base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, shape=(q_len, 1), strides=(stride_ln, stride_lh), offsets=(q_lo, 0), block_shape=(BLOCK_SIZE_Q, 1), order=(0, 1), ) # loop for q blocks for i in range(q_lo, q_len, BLOCK_SIZE_Q): # load q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") # compute qk if causal: qk = tl.where( (off_q + i)[:, None] >= off_k[None, :], float(0.0), float("-inf") ) else: qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) qk += tl.dot(q, k.T) * qk_scale # compute p, ds p = tl.math.exp2(qk - lse) dp = tl.dot(do, v.T) ds = sm_scale * p * (dp - d) # cast dtype p = p.to(do.dtype) ds = ds.to(q.dtype) # update dk and dv dk += tl.dot(ds.T, q) dv += tl.dot(p.T, do) # increment pointers q_ptrs = tl.advance(q_ptrs, (BLOCK_SIZE_Q, 0)) do_ptrs = tl.advance(do_ptrs, (BLOCK_SIZE_Q, 0)) lse_ptrs = tl.advance(lse_ptrs, (BLOCK_SIZE_Q, 0)) d_ptrs 
= tl.advance(d_ptrs, (BLOCK_SIZE_Q, 0)) # save dk dv tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) @triton.jit def backward_dq( q_ptr, # Q: n x qh x d k_ptr, # K: n x kh x d v_ptr, # V: n x kh x d lse_ptr, # LSE: qh x n d_ptr, # Delta: qh x n do_ptr, dq_ptr, # seqlens cu_seqlens_q, cu_seqlens_k, # shape NUM_KV_HEADS, NUM_SHARE_Q_HEADS, qk_head_dim, v_head_dim, # sm_scale sm_scale, # causal causal, # gqa gqa_interleave, # stride stride_qn, stride_qh, stride_qd, stride_kn, stride_kh, stride_kd, stride_vn, stride_vh, stride_vd, stride_lh, stride_ln, stride_dh, stride_dn, stride_don, stride_doh, stride_dod, stride_dqn, stride_dqh, stride_dqd, # META parameters BLOCK_SIZE_Q: tl.constexpr, # q block size BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_KD: tl.constexpr, BLOCK_SIZE_VD: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_b = tl.program_id(0) pid_h = tl.program_id(1) pid_q = tl.program_id(2) if gqa_interleave: pid_kh = pid_h % NUM_KV_HEADS else: pid_kh = pid_h // NUM_SHARE_Q_HEADS # get q k start and len after rmpad q_start = tl.load(cu_seqlens_q + pid_b) q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = tl.load(cu_seqlens_k + pid_b) k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start if BLOCK_SIZE_Q * pid_q >= q_len: return # init pointers q_ptrs = tl.make_block_ptr( base=q_ptr + q_start * stride_qn + pid_h * stride_qh, shape=(q_len, qk_head_dim), strides=(stride_qn, stride_qd), offsets=(pid_q * BLOCK_SIZE_Q, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD), order=(1, 0), ) dq_ptrs = tl.make_block_ptr( base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh, shape=(q_len, qk_head_dim), strides=(stride_dqn, stride_dqd), offsets=(pid_q * BLOCK_SIZE_Q, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD), order=(1, 0), ) k_ptrs = tl.make_block_ptr( base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, shape=(k_len, 
qk_head_dim), strides=(stride_kn, stride_kd), offsets=(0, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD), order=(1, 0), ) v_ptrs = tl.make_block_ptr( base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, shape=(k_len, qk_head_dim), strides=(stride_vn, stride_vd), offsets=(0, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD), order=(1, 0), ) do_ptrs = tl.make_block_ptr( base=do_ptr + q_start * stride_don + pid_h * stride_doh, shape=(q_len, qk_head_dim), strides=(stride_don, stride_dod), offsets=(pid_q * BLOCK_SIZE_Q, 0), block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD), order=(1, 0), ) d_ptrs = tl.make_block_ptr( base=d_ptr + q_start * stride_dn + pid_h * stride_dh, shape=(q_len, 1), strides=(stride_dn, stride_dh), offsets=(pid_q * BLOCK_SIZE_Q, 0), block_shape=(BLOCK_SIZE_Q, 1), order=(0, 1), ) lse_ptrs = tl.make_block_ptr( base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, shape=(q_len, 1), strides=(stride_ln, stride_lh), offsets=(pid_q * BLOCK_SIZE_Q, 0), block_shape=(BLOCK_SIZE_Q, 1), order=(0, 1), ) # offsets off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q off_k = tl.arange(0, BLOCK_SIZE_K) # load q, do, lse, delta, and keep in SRAM q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") # init dq dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_KD), dtype=tl.float32) # causal if causal: k_hi = (pid_q + 1) * BLOCK_SIZE_Q else: k_hi = k_len for j in range(0, k_hi, BLOCK_SIZE_K): # load k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") # compute qk if causal: qk = tl.where( off_q[:, None] >= (off_k + j)[None, :], float(0.0), float("-inf") ) else: qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) qk += tl.dot(q, k.T) * qk_scale # compute p, ds p = tl.math.exp2(qk - 
lse) dp = tl.dot(do, v.T) ds = sm_scale * p * (dp - d) # cast dtype ds = ds.to(q.dtype) # update dq dq += tl.dot(ds, k) # increment pointers k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0)) v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) # save dq tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) def _flash_attention_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, causal: bool, sm_scale: float, gqa_interleave: bool = False, ): # dtype check assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 assert k.dtype == q.dtype and v.dtype == q.dtype assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 # shape q_len, num_q_heads, qk_head_dim = q.shape k_len, num_k_heads, qk_head_dim = k.shape v_len, num_v_heads, v_head_dim = v.shape batch_size = cu_seqlens_q.shape[0] - 1 assert qk_head_dim <= 256 and v_head_dim <= 256, "head_dim must be less than 256" assert q_len == k_len and k_len == v_len # gqa assert num_k_heads == num_v_heads assert num_q_heads % num_k_heads == 0 num_share_q_heads = num_q_heads // num_k_heads # output tensor o = torch.empty(q.shape[0], q.shape[1], v.shape[-1], dtype=q.dtype, device=q.device) lse = torch.empty(num_q_heads, q_len, dtype=torch.float32, device=q.device) # launch kernel grid = lambda META: ( batch_size, num_q_heads, triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), ) BLOCK_SIZE_Q = 128 BLOCK_SIZE_K = 64 BLOCK_SIZE_KD = triton.next_power_of_2(qk_head_dim) BLOCK_SIZE_VD = triton.next_power_of_2(v_head_dim) num_warps, num_stages = get_num_warps_stages( max(qk_head_dim, v_head_dim), BLOCK_SIZE_Q, IS_HOPPER_GPU ) forward_kernel[grid]( q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, num_k_heads, num_share_q_heads, qk_head_dim, v_head_dim, sm_scale, causal, gqa_interleave, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), 
o.stride(0), o.stride(1), o.stride(2), lse.stride(0), lse.stride(1), BLOCK_SIZE_Q=BLOCK_SIZE_Q, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_KD=BLOCK_SIZE_KD, BLOCK_SIZE_VD=BLOCK_SIZE_VD, num_warps=num_warps, num_stages=num_stages, ) return o, lse def _flash_attention_bwd( o: torch.Tensor, do: torch.Tensor, lse: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, causal: bool, sm_scale: float, gqa_interleave: bool = False, ): q_len, num_q_heads, qk_head_dim = q.shape k_len, num_k_heads, qk_head_dim = k.shape v_len, num_v_heads, v_head_dim = v.shape o_len, num_o_heads, v_head_dim = o.shape num_share_q_heads = num_q_heads // num_k_heads # compute D delta = torch.empty([num_o_heads, o_len], device=o.device, dtype=torch.float32) grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads) BLOCK_SIZE_O = 256 BLOCK_SIZE_VD = triton.next_power_of_2(v_head_dim) num_warps, num_stages = get_num_warps_stages( max(qk_head_dim, v_head_dim), BLOCK_SIZE_O, IS_HOPPER_GPU ) backward_sum_o_do[grid]( o, do, delta, o_len, v_head_dim, o.stride(0), o.stride(1), o.stride(2), do.stride(0), do.stride(1), do.stride(2), delta.stride(0), delta.stride(1), BLOCK_SIZE_O=BLOCK_SIZE_O, BLOCK_SIZE_D=BLOCK_SIZE_VD, num_warps=num_warps, num_stages=num_stages, ) # compute dk dv dk = torch.empty( num_share_q_heads, k_len, num_k_heads, qk_head_dim, device=k.device, dtype=k.dtype, ) dv = torch.empty( num_share_q_heads, k_len, num_k_heads, v_head_dim, device=k.device, dtype=k.dtype, ) batch_size = cu_seqlens_q.shape[0] - 1 grid = lambda META: ( batch_size, num_q_heads, triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), ) BLOCK_SIZE_Q = 64 BLOCK_SIZE_K = 64 BLOCK_SIZE_KD = triton.next_power_of_2(qk_head_dim) BLOCK_SIZE_VD = triton.next_power_of_2(v_head_dim) num_warps, num_stages = get_num_warps_stages( max(qk_head_dim, v_head_dim), BLOCK_SIZE_K, IS_HOPPER_GPU ) backward_dkdv[grid]( q, k, v, 
lse, delta, do, dk, dv, cu_seqlens_q, cu_seqlens_k, num_k_heads, num_share_q_heads, qk_head_dim, v_head_dim, sm_scale, causal, gqa_interleave, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), lse.stride(0), lse.stride(1), delta.stride(0), delta.stride(1), do.stride(0), do.stride(1), do.stride(2), dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3), dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), BLOCK_SIZE_Q=BLOCK_SIZE_Q, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_KD=BLOCK_SIZE_KD, BLOCK_SIZE_VD=BLOCK_SIZE_VD, num_warps=num_warps, num_stages=num_stages, ) dk = dk.sum(0) dv = dv.sum(0) # compute dq dq = torch.empty_like(q) grid = lambda META: ( batch_size, num_q_heads, triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), ) BLOCK_SIZE_Q = 64 if max(qk_head_dim, v_head_dim) > 128 else 128 BLOCK_SIZE_K = 64 num_warps, num_stages = get_num_warps_stages( max(qk_head_dim, v_head_dim), BLOCK_SIZE_Q, IS_HOPPER_GPU ) backward_dq[grid]( q, k, v, lse, delta, do, dq, cu_seqlens_q, cu_seqlens_k, num_k_heads, num_share_q_heads, qk_head_dim, v_head_dim, sm_scale, causal, gqa_interleave, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), lse.stride(0), lse.stride(1), delta.stride(0), delta.stride(1), do.stride(0), do.stride(1), do.stride(2), dq.stride(0), dq.stride(1), dq.stride(2), BLOCK_SIZE_Q=BLOCK_SIZE_Q, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_KD=BLOCK_SIZE_KD, BLOCK_SIZE_VD=BLOCK_SIZE_VD, num_warps=num_warps, num_stages=num_stages, ) return dq, dk, dv class FlashAttention(torch.autograd.Function): @staticmethod def forward( ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, causal=True, sm_scale=None, gqa_interleave=False, ): # softmax scale if sm_scale is None: sm_scale = 1 / math.sqrt(q.shape[-1]) o, lse = _flash_attention_fwd( q, k, v, 
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, sm_scale, gqa_interleave, ) ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k) ctx.sm_scale = sm_scale ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k ctx.causal = causal ctx.gqa_interleave = gqa_interleave return o @staticmethod def backward(ctx, do: torch.Tensor, *args) -> Any: q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors max_seqlen_q = ctx.max_seqlen_q max_seqlen_k = ctx.max_seqlen_k sm_scale = ctx.sm_scale causal = ctx.causal gqa_interleave = ctx.gqa_interleave dq, dk, dv = _flash_attention_bwd( o, do, lse, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, sm_scale, gqa_interleave, ) return dq, dk, dv, None, None, None, None, None, None, None def flash_attention_varlen( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, causal: bool = False, sm_scale: Optional[float] = None, gqa_interleave: bool = False, ) -> torch.Tensor: """Flash attention with variable length based on triton. Args: q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim] k (torch.Tensor): shape [total_kv_len, num_q_heads, head_dim] v (torch.Tensor): shape [total_kv_len, num_q_heads, head_dim] cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen. max_seqlen_q (int): max q len of the batch. max_seqlen_k (int): max k len of the batch. causal (bool, optional): Causal mask. Defaults to False. sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). gqa_interleave (bool, optional): GQA pattern. Defaults to False, use Llama style GQA. 
Returns: torch.Tensor: attention output with shape [total_q_len, num_q_heads, head_dim] """ return FlashAttention.apply( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal, sm_scale, gqa_interleave, ) ================================================ FILE: native_sparse_attention/ops/triton/flash_attention_decode.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import torch import triton import triton.language as tl from typing import Optional @triton.jit def decode_kernel( q_ptr, # Q: b x h x d k_ptr, # K: b x n x h x d v_ptr, # V: b x n x h x d o_ptr, # O: b x h x d seqlens, # shape BATCH_SIZE, NUM_SHARE_Q_HEADS, HEAD_DIM, # sm_scale sm_scale, # stride stride_qb, stride_qh, stride_qd, stride_kb, stride_kn, stride_kh, stride_kd, stride_vb, stride_vn, stride_vh, stride_vd, stride_ob, stride_oh, stride_od, # META parameters BLOCK_SIZE_B: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_D: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_h = tl.program_id(0) pid_b = tl.program_id(1) pid_kh = pid_h // NUM_SHARE_Q_HEADS # get q k start and len after rmpad off_b = tl.arange(0, BLOCK_SIZE_B) kv_len = tl.load( seqlens + pid_b * BLOCK_SIZE_B + off_b, mask=pid_b * BLOCK_SIZE_B + off_b < BATCH_SIZE, other=0, ) max_kv_len = tl.max(kv_len) # init qkv pointer q_ptrs = tl.make_block_ptr( base=q_ptr + pid_h * stride_qh, 
shape=(BATCH_SIZE, HEAD_DIM), strides=(stride_qb, stride_qd), offsets=(pid_b * BLOCK_SIZE_B, 0), block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_D), order=(1, 0), ) k_ptrs = tl.make_block_ptr( base=k_ptr + pid_kh * stride_kh, shape=(BATCH_SIZE, max_kv_len, HEAD_DIM), strides=(stride_kb, stride_kn, stride_kd), offsets=(pid_b * BLOCK_SIZE_B, 0, 0), block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_K, BLOCK_SIZE_D), order=(2, 1, 0), ) v_ptrs = tl.make_block_ptr( base=v_ptr + pid_kh * stride_vh, shape=(BATCH_SIZE, max_kv_len, HEAD_DIM), strides=(stride_vb, stride_vn, stride_vd), offsets=(pid_b * BLOCK_SIZE_B, 0, 0), block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_K, BLOCK_SIZE_D), order=(2, 1, 0), ) # load q q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") # init statistics off_k = tl.arange(0, BLOCK_SIZE_K) m_i = tl.full((BLOCK_SIZE_B,), float("-inf"), dtype=tl.float32) lse_i = tl.full((BLOCK_SIZE_B,), float("-inf"), dtype=tl.float32) acc_o = tl.full((BLOCK_SIZE_B, BLOCK_SIZE_D), 0, dtype=tl.float32) # full attention or causal attention for i in range(0, max_kv_len, BLOCK_SIZE_K): i = tl.multiple_of(i, BLOCK_SIZE_K) # load k k = tl.load(k_ptrs, boundary_check=(0, 1, 2), padding_option="zero") # compute qk qk = tl.zeros((BLOCK_SIZE_B, BLOCK_SIZE_K), dtype=tl.float32) qk += tl.where(off_k[None, :] + i < kv_len[:, None], 0, float("-inf")) # [B, D], [B, K, D] -> [B, K] qk += tl.sum(q[:, None, :] * k, axis=2) * qk_scale # compute m_ij and l_ij m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) p = tl.math.exp2(qk - m_ij[:, None]) l_ij = tl.sum(p, axis=1) # scale acc_o acc_o_scale = tl.math.exp2(m_i - m_ij) acc_o = acc_o * acc_o_scale[:, None] # load v and update acc_o v = tl.load(v_ptrs, boundary_check=(0, 1, 2), padding_option="zero") p = p.to(v.dtype) # [B, K], [B, K, D] -> [B, D] acc_o += tl.sum(p[:, :, None] * v, axis=1) # update statistics m_i = m_ij lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij) # update ptrs k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K, 0)) v_ptrs = 
tl.advance(v_ptrs, (0, BLOCK_SIZE_K, 0)) # final scale acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None] # save output o_ptrs = tl.make_block_ptr( base=o_ptr + pid_h * stride_oh, shape=(BATCH_SIZE, HEAD_DIM), strides=(stride_ob, stride_od), offsets=(pid_b * BLOCK_SIZE_B, 0), block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_D), order=(1, 0), ) tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) def flash_attention_decode( q: torch.Tensor, # [batch_size, num_heads, head_dim] k: torch.Tensor, # [batch_size, max_len, num_heads, head_dim] v: torch.Tensor, seqlens: torch.Tensor, # [batch_size, ] sm_scale: Optional[float] = None, ) -> torch.Tensor: """flash attention for decode. Args: q (torch.Tensor): query, shape [batch_size, num_q_heads, head_dim] k (torch.Tensor): key, shape [batch_size, kv_len, num_kv_heads, head_dim] v (torch.Tensor): value, shape [batch_size, kv_len, num_kv_heads, head_dim] seqlens (torch.Tensor): kv length for each sequence sm_scale (Optional[float]): softmax scale, default to 1/sqrt(head_dim) Returns: torch.Tensor: attention output """ # dtype check assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 assert k.dtype == q.dtype and v.dtype == q.dtype assert seqlens.dtype == torch.int32 # shape batch_size, num_q_heads, head_dim = q.shape _, k_len, num_k_heads, head_dim = k.shape _, v_len, num_v_heads, head_dim = v.shape assert k_len == v_len and batch_size == seqlens.shape[0] # gqa assert num_k_heads == num_v_heads assert num_q_heads % num_k_heads == 0 num_share_q_heads = num_q_heads // num_k_heads # sm scale if sm_scale is None: sm_scale = 1 / math.sqrt(head_dim) # output tensor o = torch.zeros_like(q) # launch kernel num_warps = 4 if head_dim <= 64 else 8 num_stages = 3 # there is a bug for triton 3.0.0 if BLOCK_SIZE_B > 16 BLOCK_SIZE_B = min(16, triton.next_power_of_2(batch_size)) BLOCK_SIZE_K = 128 BLOCK_SIZE_D = triton.next_power_of_2(head_dim) grid = (num_q_heads, triton.cdiv(batch_size, BLOCK_SIZE_B)) decode_kernel[grid]( 
q, k, v, o, seqlens, batch_size, num_share_q_heads, head_dim, sm_scale, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), o.stride(0), o.stride(1), o.stride(2), BLOCK_SIZE_B=BLOCK_SIZE_B, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_D=BLOCK_SIZE_D, num_warps=num_warps, num_stages=num_stages, ) return o def torch_attention_decode( q: torch.Tensor, # [batch_size, num_heads, head_dim] k: torch.Tensor, # [batch_size, max_len, num_heads, head_dim] v: torch.Tensor, seqlens: torch.Tensor, # [batch_size, ] sm_scale: Optional[float] = None, ): # dtype check assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 assert k.dtype == q.dtype and v.dtype == q.dtype assert seqlens.dtype == torch.int32 # shape batch_size, num_q_heads, head_dim = q.shape _, k_len, num_k_heads, head_dim = k.shape _, v_len, num_v_heads, head_dim = v.shape assert k_len == v_len and batch_size == seqlens.shape[0] # gqa assert num_k_heads == num_v_heads assert num_q_heads % num_k_heads == 0 num_share_q_heads = num_q_heads // num_k_heads # sm scale if sm_scale is None: sm_scale = 1 / math.sqrt(head_dim) # attention attn = ( torch.einsum( "bqhd,bkhd->bhqk", q.unsqueeze(1), k.repeat_interleave(num_share_q_heads, dim=2), ) * sm_scale ) mask = torch.arange(k_len, device=q.device)[None, :] < seqlens[:, None] attn = attn.masked_fill(~mask[:, None, None, :], -torch.inf) attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype) out = torch.einsum( "bhqk,bkhd->bqhd", attn, v.repeat_interleave(num_share_q_heads, dim=2) ).squeeze(1) return out if __name__ == "__main__": torch.manual_seed(42) batch_size = 76 max_length = 8192 seqlens = torch.arange(batch_size, dtype=torch.int32).cuda() * 128 + 1 seqlens[seqlens > max_length] = max_length seqlens = seqlens[torch.randn_like(seqlens, dtype=torch.float32).argsort(-1)] q = ( torch.empty(batch_size, 32, 128, device="cuda") .uniform_(-1, 1) .to(torch.bfloat16) ) k = ( 
torch.empty(batch_size, max_length, 4, 128, device="cuda") .uniform_(-1, 1) .to(torch.bfloat16) ) v = ( torch.empty(batch_size, max_length, 4, 128, device="cuda") .uniform_(-1, 1) .to(torch.bfloat16) ) o1 = torch_attention_decode(q, k, v, seqlens) o2 = flash_attention_decode(q, k, v, seqlens) print(torch.allclose(o1, o2, atol=1e-2, rtol=1e-2)) ================================================ FILE: native_sparse_attention/ops/triton/linear_compress.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch from typing import Optional, Tuple, Any import triton import torch import triton.language as tl from einops import rearrange, einsum from native_sparse_attention.ops.triton.utils import is_hopper_gpu IS_HOPPER_GPU = is_hopper_gpu() @triton.jit def linear_compress_fwd_kernel( X, # input pointer [total_len, num_heads, head_dim] Y, # output pointer [total_compressed_len, num_heads, head_dim] W, # weight matrix pointer [num_heads, kernel_size, head_dim, head_dim] cu_seqlens_x, # cumulative sequence lengths before compression cu_seqlens_y, # cumulative sequence lengths after compression stride_xn, # stride for X's sequence dimension stride_xh, # stride for X's num head dimension stride_xd, # stride for X's head_dim dimension stride_wh, # stride for W's num head dimension stride_wk, # stride for W's kernel size dimension stride_wd, # stride for W's initial head dim dimension stride_wD, # stride for W's final head dim dimension stride_yn, # stride for Y's sequence dimension stride_yh, # stride for Y's num head dimension stride_yd, # stride for Y's head_dim dimension NUM_HEADS: tl.constexpr, # total num heads KERNEL_SIZE: tl.constexpr, # kernel size when calculate the output KERNEL_STRIDE: tl.constexpr, # kernel stride when calculate the output HEADd_DIM: tl.constexpr, # initial head dimension size HEADD_DIM: tl.constexpr, # final head dimension size BLOCK_OUTPUT_SEQ_SIZE: tl.constexpr, # Loaded output len BLOCK_KERNEL_SIZE: tl.constexpr, # Loaded kernel size when calculate the output BLOCK_HEADd_DIM: tl.constexpr, # Loaded orignal head dimension size BLOCK_HEADD_DIM: tl.constexpr, # loaded final head dimension size ): pid_bh = tl.program_id(0) pid_b = pid_bh // NUM_HEADS pid_h = pid_bh % NUM_HEADS pid_k = tl.program_id(1) pid_D = tl.program_id(2) x_start = tl.load(cu_seqlens_x + pid_b) x_end = tl.load(cu_seqlens_x + pid_b + 1) x_len = x_end - x_start y_start = tl.load(cu_seqlens_y + pid_b) y_end = tl.load(cu_seqlens_y + pid_b + 1) y_len = y_end - y_start if 
pid_k * BLOCK_OUTPUT_SEQ_SIZE >= y_len: return off_kernel_size = tl.arange(0, BLOCK_KERNEL_SIZE) off_d = tl.arange(0, BLOCK_HEADd_DIM) off_output_seq_size = tl.arange(0, BLOCK_OUTPUT_SEQ_SIZE) x_base_ptrs = ( X + pid_h * stride_xh + x_start * stride_xn + ( ( pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE + off_output_seq_size * KERNEL_STRIDE )[:, None] + off_kernel_size[None, :] )[:, :, None] * stride_xn + off_d[None, None, :] * stride_xd ) x_base_mask = ( ( ( pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE + off_output_seq_size * KERNEL_STRIDE )[:, None] + off_kernel_size[None, :] ) < x_len )[:, :, None] w_ptrs = tl.make_block_ptr( base=W + pid_h * stride_wh, shape=(KERNEL_SIZE, HEADd_DIM, HEADD_DIM), strides=(stride_wk, stride_wd, stride_wD), offsets=(0, 0, pid_D * BLOCK_HEADD_DIM), block_shape=(BLOCK_KERNEL_SIZE, BLOCK_HEADd_DIM, BLOCK_HEADD_DIM), order=(2, 1, 0), ) y_ptrs = tl.make_block_ptr( base=Y + y_start * stride_yn + pid_h * stride_yh, shape=(y_len, HEADD_DIM), strides=(stride_yn, stride_yd), offsets=(pid_k * BLOCK_OUTPUT_SEQ_SIZE, pid_D * BLOCK_HEADD_DIM), block_shape=(BLOCK_OUTPUT_SEQ_SIZE, BLOCK_HEADD_DIM), order=(1, 0), ) y_d = tl.full((BLOCK_OUTPUT_SEQ_SIZE, BLOCK_HEADD_DIM), 0, dtype=tl.float32) for i in range(0, HEADd_DIM, BLOCK_HEADd_DIM): x_ptrs = x_base_ptrs + i * stride_xd x_mask = x_base_mask & ((i + off_d) < HEADd_DIM)[None, None, :] x = tl.load(x_ptrs, mask=x_mask, other=0) x = tl.reshape(x, (BLOCK_OUTPUT_SEQ_SIZE, BLOCK_KERNEL_SIZE * BLOCK_HEADd_DIM)) # x : [n, k * bd] w = tl.load(w_ptrs, boundary_check=(0, 1, 2), padding_option="zero") w = tl.reshape(w, (BLOCK_KERNEL_SIZE * BLOCK_HEADd_DIM, BLOCK_HEADD_DIM)) # w: [k * bd, D] y_d += tl.dot(x, w) # y_d : [n, D] w_ptrs = tl.advance(w_ptrs, (0, BLOCK_HEADd_DIM, 0)) tl.store(y_ptrs, y_d.to(y_ptrs.dtype.element_ty), boundary_check=(0, 1)) @triton.jit def linear_compress_bwd_kernel( DX, # X's gradient pointer [total_len, num_heads, head_dim] DY, # Y's gradient pointer [total_compressed_len, 
num_heads, head_dim] DW, # weight's gradient pointer [num_heads, kernel_size, head_dim, head_dim] X, # x pointer [total_len, num_heads, head_dim] W, # weight matrix pointer [num_heads, kernel_size, head_dim, head_dim] cu_seqlens_x, # cumulative sequence lengths before compression cu_seqlens_y, # cumulative sequence lengths after compression stride_xn, # stride for X's sequence dimension stride_xh, # stride for X's num head dimension stride_xd, # stride for X's head_dim dimension stride_wh, # stride for W's num head dimension stride_wk, # stride for W's kernel size dimension stride_wd, # stride for W's initial head dim dimension stride_wD, # stride for W's final head dim dimension stride_dxn, # stride for DX's sequence dimension stride_dxh, # stride for DX's num head dimension stride_dxd, # stride for DX's head_dim dimension stride_dwh, # stride for DW's num head dimension stride_dwk, # stride for DW's kernel size dimension stride_dwd, # stride for DW's initial head dim dimension stride_dwD, # stride for DW's final head dim dimension stride_dyn, # stride for DY's sequence dimension stride_dyh, # stride for DY's num head dimension stride_dyd, # stride for DY's head_dim dimension NUM_HEADS: tl.constexpr, # total num heads NUM_PARALLEL_HEADd_NUM: tl.constexpr, # total parallel process among head dim d KERNEL_SIZE: tl.constexpr, # kernel size when calculate the output KERNEL_STRIDE: tl.constexpr, # kernel stride when calculate the output HEADd_DIM: tl.constexpr, # initial head dimension size HEADD_DIM: tl.constexpr, # final head dimension size BLOCK_KERNEL_SIZE: tl.constexpr, # Loaded kernel size when calculate the output BLOCK_HEADd_DIM: tl.constexpr, # Loaded orignal head dimension size BLOCK_HEADD_DIM: tl.constexpr, # loaded final head dimension size BLOCK_OUTPUT_SEQ_SIZE: tl.constexpr, # loaded final output seq parallel ): pid_bh = tl.program_id(0) pid_b = pid_bh // NUM_HEADS pid_h = pid_bh % NUM_HEADS pid_k = tl.program_id(1) pid_Dd = tl.program_id(2) pid_D = 
pid_Dd // NUM_PARALLEL_HEADd_NUM pid_d = pid_Dd % NUM_PARALLEL_HEADd_NUM x_start = tl.load(cu_seqlens_x + pid_b) x_end = tl.load(cu_seqlens_x + pid_b + 1) x_len = x_end - x_start y_start = tl.load(cu_seqlens_y + pid_b) y_end = tl.load(cu_seqlens_y + pid_b + 1) y_len = y_end - y_start if pid_k * BLOCK_OUTPUT_SEQ_SIZE >= y_len: return # pdb.set_trace() off_kernel_size = tl.arange(0, BLOCK_KERNEL_SIZE) off_d = tl.arange(0, BLOCK_HEADd_DIM) off_D = tl.arange(0, BLOCK_HEADD_DIM) off_output_seq_size = tl.arange(0, BLOCK_OUTPUT_SEQ_SIZE) x_ptrs = ( X + pid_h * stride_xh + x_start * stride_xn + ( ( pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE + off_output_seq_size * KERNEL_STRIDE )[:, None] + off_kernel_size[None, :] )[:, :, None] * stride_xn + (pid_d * BLOCK_HEADd_DIM + off_d)[None, None, :] * stride_xd ) x_mask = ( ( ( pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE + off_output_seq_size * KERNEL_STRIDE )[:, None] + off_kernel_size[None, :] ) < x_len )[:, :, None] & ((pid_d * BLOCK_HEADd_DIM + off_d) < HEADd_DIM)[None, None, :] dx_ptrs = ( DX + pid_h * stride_dxh + x_start * stride_dxn + ( ( pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE + off_output_seq_size * KERNEL_STRIDE )[:, None] + off_kernel_size[None, :] )[:, :, None] * stride_dxn + (pid_d * BLOCK_HEADd_DIM + off_d)[None, None, :] * stride_dxd ) w_ptrs = tl.make_block_ptr( base=W + pid_h * stride_wh, shape=(KERNEL_SIZE, HEADd_DIM, HEADD_DIM), strides=(stride_wk, stride_wd, stride_wD), offsets=(0, pid_d * BLOCK_HEADd_DIM, pid_D * BLOCK_HEADD_DIM), block_shape=(BLOCK_KERNEL_SIZE, BLOCK_HEADd_DIM, BLOCK_HEADD_DIM), order=(2, 1, 0), ) dw_ptrs = ( DW + pid_h * stride_dwh + off_kernel_size[:, None, None] * stride_dwk + (pid_d * BLOCK_HEADd_DIM + off_d)[None, :, None] * stride_dwd + (pid_D * BLOCK_HEADD_DIM + off_D)[None, None, :] * stride_dwD ) dw_mask = ( (off_kernel_size < KERNEL_SIZE)[:, None, None] & ((pid_d * BLOCK_HEADd_DIM + off_d) < HEADd_DIM)[None, :, None] & ((pid_D * BLOCK_HEADD_DIM + off_D) < HEADD_DIM)[None, 
None, :] ) dy_ptrs = tl.make_block_ptr( base=DY + y_start * stride_dyn + pid_h * stride_dyh, shape=(y_len, HEADD_DIM), strides=(stride_dyn, stride_dyd), offsets=(pid_k * BLOCK_OUTPUT_SEQ_SIZE, pid_D * BLOCK_HEADD_DIM), block_shape=(BLOCK_OUTPUT_SEQ_SIZE, BLOCK_HEADD_DIM), order=(1, 0), ) dy = tl.load(dy_ptrs, boundary_check=(0, 1), padding_option="zero") # dy : [by, D] # cal dx, start w = tl.load(w_ptrs, boundary_check=(0, 1, 2), padding_option="zero") # w: [k, bd, D] w = tl.reshape(w, (BLOCK_KERNEL_SIZE * BLOCK_HEADd_DIM, BLOCK_HEADD_DIM)) # w: [k * bd, D] dx = tl.dot(dy, tl.trans(w)) # dx: [by, k * bd] dx = tl.reshape(dx, (BLOCK_OUTPUT_SEQ_SIZE, BLOCK_KERNEL_SIZE, BLOCK_HEADd_DIM)) # dx: [by, k, bd] tl.atomic_add( dx_ptrs, dx.to(dx_ptrs.dtype.element_ty), mask=x_mask, ) # cal dx, end # cal dw, start x = tl.load(x_ptrs, mask=x_mask, other=0) x = tl.reshape(x, (BLOCK_OUTPUT_SEQ_SIZE, BLOCK_KERNEL_SIZE * BLOCK_HEADd_DIM)) # x : [by, k * bd] dw = tl.dot(tl.trans(x), dy) # dw: [k * bd, D] dw = tl.reshape(dw, (BLOCK_KERNEL_SIZE, BLOCK_HEADd_DIM, BLOCK_HEADD_DIM)) # dw: [k, bd, D] tl.atomic_add(dw_ptrs, dw.to(dw_ptrs.dtype.element_ty), mask=dw_mask) class LinearCompress(torch.autograd.Function): @staticmethod def forward( ctx, x: torch.Tensor, w: torch.Tensor, cu_seqlens: torch.Tensor, kernel_size: int, kernel_stride: int, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Compress key and value tensor with kernel_size and kernel_stride. Similar to conv_compress. 
Args: x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim) w (torch.Tensor): weight for each head, shape (num_heads, kernel_size * head_dim, head_dim) cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim) kernel_stride (int): stride for each compress kernel Returns: Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens. """ # dtype check assert x.dtype == torch.float16 or x.dtype == torch.bfloat16 assert x.dtype == w.dtype assert cu_seqlens.dtype == torch.int32 # shape check total_len, num_heads, head_dim = x.shape batch_size = cu_seqlens.shape[0] - 1 assert w.shape[0] == num_heads assert w.shape[1] == kernel_size * head_dim assert w.shape[2] == head_dim assert kernel_size % kernel_stride == 0 assert kernel_size in {16, 32, 64, 128} assert head_dim % 8 == 0 torch.cuda.set_device(x.device) # compute seqlens after compression seqlens = cu_seqlens[1:] - cu_seqlens[:-1] y_seqlens = ( torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1 ) # corner case: if sequence_length < kernel_size, no compression for this sequence y_seqlens[seqlens < kernel_size] = 0 y_cu_seqlens = torch.cat( [ torch.zeros(1, dtype=torch.int32, device=x.device), torch.cumsum(y_seqlens.to(x.device), dim=0), ], dim=0, ).to(torch.int32) y = torch.zeros( y_cu_seqlens[-1], num_heads, head_dim, dtype=x.dtype, device=x.device ) block_kernel_size = max(16, triton.next_power_of_2(kernel_size)) block_head_dim = 8 if IS_HOPPER_GPU else 4 block_headD_dim = 32 block_output_seq_size = 64 w = w.reshape(num_heads, kernel_size, head_dim, head_dim).contiguous() grid = lambda META: ( batch_size * num_heads, triton.cdiv(y_seqlens.max(0)[0].item(), META["BLOCK_OUTPUT_SEQ_SIZE"]), triton.cdiv(head_dim, META["BLOCK_HEADD_DIM"]), ) linear_compress_fwd_kernel[grid]( x, y, w, cu_seqlens, 
y_cu_seqlens, x.stride(0), x.stride(1), x.stride(2), w.stride(0), w.stride(1), w.stride(2), w.stride(3), y.stride(0), y.stride(1), y.stride(2), num_heads, kernel_size, kernel_stride, head_dim, head_dim, block_output_seq_size, block_kernel_size, block_head_dim, block_headD_dim, # num_warps=8, # num_stages=3, ) # save for backward ctx.save_for_backward(x, w, cu_seqlens, y_seqlens, y_cu_seqlens) # save value ctx.kernel_size = kernel_size ctx.kernel_stride = kernel_stride ctx.block_kernel_size = block_kernel_size ctx.block_headd_dim = block_head_dim ctx.block_headD_dim = block_headD_dim ctx.block_output_seq_size = block_output_seq_size return y, y_cu_seqlens @staticmethod def backward(ctx, dy: torch.Tensor, *args) -> Any: x, w, cu_seqlens, y_seqlens, y_cu_seqlens = ctx.saved_tensors kernel_size = ctx.kernel_size kernel_stride = ctx.kernel_stride block_kernel_size = ctx.block_kernel_size block_head_dim = ctx.block_headd_dim block_headD_dim = ctx.block_headD_dim block_output_seq_size = ctx.block_output_seq_size total_len, num_heads, head_dim = x.shape batch_size = cu_seqlens.shape[0] - 1 dx = torch.zeros( cu_seqlens[-1], num_heads, head_dim, dtype=torch.float32, device=x.device ) dw = torch.zeros( num_heads, kernel_size, head_dim, head_dim, dtype=torch.float32, device=x.device, ) grid = lambda META: ( batch_size * num_heads, triton.cdiv(y_seqlens.max(0)[0].item(), META["BLOCK_OUTPUT_SEQ_SIZE"]), triton.cdiv(head_dim, META["BLOCK_HEADD_DIM"]) * triton.cdiv(head_dim, META["BLOCK_HEADd_DIM"]), ) linear_compress_bwd_kernel[grid]( dx, dy, dw, x, w, cu_seqlens, y_cu_seqlens, x.stride(0), x.stride(1), x.stride(2), w.stride(0), w.stride(1), w.stride(2), w.stride(3), dx.stride(0), dx.stride(1), dx.stride(2), dw.stride(0), dw.stride(1), dw.stride(2), dw.stride(3), dy.stride(0), dy.stride(1), dy.stride(2), num_heads, head_dim // block_head_dim, kernel_size, kernel_stride, head_dim, head_dim, block_kernel_size, block_head_dim, block_headD_dim, block_output_seq_size, ) return ( 
dx.to(x.dtype), rearrange(dw.to(x.dtype), "n k d D -> n (k d) D"), None, None, None, ) def linear_compress( x: torch.Tensor, w: torch.Tensor, cu_seqlens: torch.Tensor, kernel_size: int, kernel_stride: int, pe: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compress key and value tensor with kernel_size and kernel_stride with linear projection. Args: x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim) w (torch.Tensor): weight for each head, shape (num_heads, kernel_size * head_dim, head_dim) cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim) kernel_stride (int): stride for each compress kernel pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens. """ y, y_cu_seqlens = LinearCompress.apply(x, w, cu_seqlens, kernel_size, kernel_stride) # position embedding as a bias if pe is not None: assert pe.dtype == x.dtype and pe.device == x.device pe = rearrange(pe, "h k d -> h (k d)") bias = einsum(pe, w, "h D, h D d -> h d") y = y + bias.unsqueeze(0) return y, y_cu_seqlens ================================================ FILE: native_sparse_attention/ops/triton/topk_sparse_attention.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import Any, Optional import torch import triton import triton.language as tl from native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu IS_HOPPER_GPU = is_hopper_gpu() @triton.jit def forward_kernel( q_ptr, # Q: n x h x d k_ptr, # K: n x kh x d v_ptr, # V: n x kh x d t_ptr, # topk_idx: kh x n x k o_ptr, # O: n x h x d lse_ptr, # LSE: h x n # seqlens cu_seqlens_q, cu_seqlens_k, # shape NUM_KV_HEADS, NUM_SHARE_Q_HEADS, HEAD_DIM, TOPK, # q loop num num_q_loop, # sm_scale sm_scale, # stride stride_qn, stride_qh, stride_qd, stride_kn, stride_kh, stride_kd, stride_vn, stride_vh, stride_vd, stride_th, stride_tn, stride_tk, stride_on, stride_oh, stride_od, stride_lh, stride_ln, # META parameters BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_D: tl.constexpr, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_T: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_b = tl.program_id(0) pid_kh = tl.program_id(1) pid_h = pid_kh * NUM_SHARE_Q_HEADS pid_q = tl.program_id(2) # get q k start and len after rmpad q_start = tl.load(cu_seqlens_q + pid_b) q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = tl.load(cu_seqlens_k + pid_b) k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start if pid_q * num_q_loop >= q_len: return real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop) for j in range(real_q_loop): pid_q_j = pid_q * num_q_loop + j # init topk idx pointer off_t = tl.arange(0, BLOCK_SIZE_T) t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th topk_idx = tl.load(t_ptr_j + 
off_t * stride_tk, mask=off_t < TOPK, other=-1) real_topk = tl.sum( tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0), axis=0, ) # init qkv pointer q_ptrs = tl.make_block_ptr( base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), strides=(stride_qh, stride_qd), offsets=(0, 0), block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), order=(1, 0), ) k_ptrs = tl.make_block_ptr( base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, shape=(HEAD_DIM, k_len), strides=(stride_kd, stride_kn), offsets=(0, 0), block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), order=(0, 1), ) v_ptrs = tl.make_block_ptr( base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, shape=(k_len, HEAD_DIM), strides=(stride_vn, stride_vd), offsets=(0, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) # load q q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") # init statistics off_h = tl.arange(0, BLOCK_SIZE_H) off_k = tl.arange(0, BLOCK_SIZE_K) m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32) # sparse attention for i in range(real_topk): # get current block start index c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K t_ptr_j = t_ptr_j + stride_tk # load k k = tl.load( tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero" ) # compute qk qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf")) # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K] qk += tl.dot(q, k) * qk_scale # compute m_ij and l_ij m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) p = tl.exp2(qk - m_ij[:, None]) l_ij = tl.sum(p, axis=1) # scale acc_o acc_o_scale = tl.exp2(m_i - m_ij) acc_o = acc_o * acc_o_scale[:, None] # load v and update acc_o v = tl.load( tl.advance(v_ptrs, (c, 0)), 
boundary_check=(0, 1), padding_option="zero" ) p = p.to(v.dtype) acc_o += tl.dot(p, v) # update statistics m_i = m_ij lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) # final scale acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] # save output o_ptrs = tl.make_block_ptr( base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh, shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), strides=(stride_oh, stride_od), offsets=(0, 0), block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), order=(1, 0), ) tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) # save lse lse_ptrs = ( lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh ) tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS) @triton.jit def backward_sum_o_do( o_ptr, # O: n x h x d do_ptr, # dO: n x h x d delta_ptr, # D: h x n o_len, HEAD_DIM, stride_on, stride_oh, stride_od, stride_don, stride_doh, stride_dod, stride_dh, stride_dn, BLOCK_SIZE_O: tl.constexpr, BLOCK_SIZE_D: tl.constexpr, ): pid_n = tl.program_id(0) pid_h = tl.program_id(1) off_o = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) off_d = tl.arange(0, BLOCK_SIZE_D) o = tl.load( o_ptr + off_o[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), other=0, ).to(tl.float32) do = tl.load( do_ptr + off_o[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), other=0, ).to(tl.float32) delta = tl.sum(o * do, axis=1) tl.store( delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len ) @triton.jit def count_kernel( x_ptr, # [num_kv_heads, total_len, topk] y_ptr, # [num_kv_heads, total_blocks] cu_seqlens, # [batch_size + 1] cu_seqblocks, # [batch_size + 1] topk, stride_xh, stride_xn, stride_xk, stride_yh, stride_yn, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_R: tl.constexpr, ): pid_h = tl.program_id(0) pid_b = tl.program_id(1) # 
get start and len after rmpad seq_start = tl.load(cu_seqlens + pid_b) seq_len = tl.load(cu_seqlens + pid_b + 1) - seq_start blocks_start = tl.load(cu_seqblocks + pid_b) num_blocks = tl.load(cu_seqblocks + pid_b + 1) - blocks_start # load x off_k = tl.arange(0, BLOCK_SIZE_K) off_n = tl.arange(0, BLOCK_SIZE_N) x_ptr = x_ptr + pid_h * stride_xh + seq_start * stride_xn x_ptrs = x_ptr + off_n[:, None] * stride_xn + off_k[None, :] * stride_xk # init y y = tl.zeros((BLOCK_SIZE_R,), dtype=tl.int32) # loop for i in range(0, seq_len, BLOCK_SIZE_N): x = tl.load( x_ptrs, mask=(off_n < seq_len - i)[:, None] & (off_k < topk)[None, :], other=-1, ) x = tl.ravel(x) y += tl.histogram(x, BLOCK_SIZE_R) x_ptrs += BLOCK_SIZE_N * stride_xn # store result off_r = tl.arange(0, BLOCK_SIZE_R) y_ptr = y_ptr + pid_h * stride_yh + blocks_start * stride_yn y_ptrs = y_ptr + off_r * stride_yn tl.store(y_ptrs, y.to(y_ptr.dtype.element_ty), mask=off_r < num_blocks) def count_query( topk_idx: torch.Tensor, cu_seqlens: torch.Tensor, cu_seqblocks: torch.Tensor, block_size: int, ): num_kv_heads, total_len, topk = topk_idx.shape seqlens = cu_seqlens[1:] - cu_seqlens[:-1] seqblocks = cu_seqblocks[1:] - cu_seqblocks[:-1] batch_size = seqlens.shape[0] BLOCK_SIZE_K = triton.next_power_of_2(topk) BLOCK_SIZE_N = triton.next_power_of_2(4096 // BLOCK_SIZE_K) BLOCK_SIZE_R = triton.next_power_of_2(seqblocks.max().item() + 2) active_query_count = torch.zeros( num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device ) grid = (num_kv_heads, batch_size) count_kernel[grid]( topk_idx, active_query_count, cu_seqlens, cu_seqblocks, topk, topk_idx.stride(0), topk_idx.stride(1), topk_idx.stride(2), active_query_count.stride(0), active_query_count.stride(1), BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_R=BLOCK_SIZE_R, num_warps=4, num_stages=3, ) return active_query_count @triton.jit def pad_topk_idx_kernel( t_ptr, p_ptr, cu_seqlens, topk, stride_th, stride_tn, stride_tk, stride_pb, 
stride_ph, stride_pn, stride_pk, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_T: tl.constexpr, ): pid_b = tl.program_id(0) pid_h = tl.program_id(1) pid_n = tl.program_id(2) # get q start and len after rmpad q_start = tl.load(cu_seqlens + pid_b) q_len = tl.load(cu_seqlens + pid_b + 1) - q_start if BLOCK_SIZE_N * pid_n >= q_len: return # init prts t_ptrs = tl.make_block_ptr( base=t_ptr + pid_h * stride_th + q_start * stride_tn, shape=(q_len, topk), strides=(stride_tn, stride_tk), offsets=(pid_n * BLOCK_SIZE_N, 0), block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T), order=(1, 0), ) p_ptrs = tl.make_block_ptr( base=p_ptr + pid_b * stride_pb + pid_h * stride_ph, shape=(q_len, topk), strides=(stride_pn, stride_pk), offsets=(pid_n * BLOCK_SIZE_N, 0), block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T), order=(1, 0), ) # load and save idxs = tl.load(t_ptrs, boundary_check=(0, 1)) tl.store(p_ptrs, idxs, boundary_check=(0, 1)) @triton.jit def save_topk_idx_kernel( p_ptr, t_ptr, cu_seqblocks, cu_topk_q_count, n_len, stride_pb, stride_ph, stride_pn, stride_th, stride_tn, stride_ch, stride_cn, BLOCK_SIZE_N: tl.constexpr, ): pid_b = tl.program_id(0) pid_h = tl.program_id(1) pid_n = tl.program_id(2) # get q start and len after rmpad q_block_start = tl.load(cu_seqblocks + pid_b) q_block_end = tl.load(cu_seqblocks + pid_b + 1) c_start = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_start * stride_cn) c_end = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_end * stride_cn) c_len = c_end - c_start if c_len <= 0: return if pid_n * BLOCK_SIZE_N >= c_len: return # init ptrs p_ptrs = tl.make_block_ptr( base=p_ptr + pid_b * stride_pb + pid_h * stride_ph + (n_len - c_len) * stride_pn, shape=(c_len,), strides=(stride_pn,), offsets=(pid_n * BLOCK_SIZE_N,), block_shape=(BLOCK_SIZE_N,), order=(0,), ) t_ptrs = tl.make_block_ptr( base=t_ptr + pid_h * stride_th + c_start * stride_tn, shape=(c_len,), strides=(stride_tn,), offsets=(pid_n * BLOCK_SIZE_N,), block_shape=(BLOCK_SIZE_N,), order=(0,), ) # load and 
save idxs = tl.load(p_ptrs, boundary_check=(0,)) tl.store(t_ptrs, idxs, boundary_check=(0,)) def reorder_topk_idx( topk_idx: torch.Tensor, cu_topk_q_count: torch.Tensor, cu_seqlens: torch.Tensor, cu_seqblocks: torch.Tensor, block_size: int, ): num_kv_heads, total_len, topk = topk_idx.shape batch_size = cu_seqlens.shape[0] - 1 seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen = seq_lens.max().item() # pad shape [num_kv_heads, total_seqlen, topk] to [batch_size, num_kv_heads, max_seqlen, topk] pad_topk_idx = torch.full( (batch_size, num_kv_heads, max_seqlen, topk), fill_value=-1, device=topk_idx.device, dtype=torch.int32, ) BLOCK_SIZE_T = triton.next_power_of_2(topk) BLOCK_SIZE_N = min( triton.next_power_of_2(max_seqlen), triton.next_power_of_2(8192 // BLOCK_SIZE_T) ) grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, BLOCK_SIZE_N)) pad_topk_idx_kernel[grid]( topk_idx, pad_topk_idx, cu_seqlens, topk, topk_idx.stride(0), topk_idx.stride(1), topk_idx.stride(2), pad_topk_idx.stride(0), pad_topk_idx.stride(1), pad_topk_idx.stride(2), pad_topk_idx.stride(3), BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_T=BLOCK_SIZE_T, ) # argsort pad_topk_q_idx = pad_topk_idx.view(batch_size, num_kv_heads, -1).argsort(-1) // topk pad_topk_q_idx = pad_topk_q_idx.to(torch.int32) # save as remove pad version topk_q_idx = torch.full( (num_kv_heads, cu_topk_q_count[:, -1].max().item()), fill_value=-1, device=topk_idx.device, dtype=torch.int32, ) max_len = ( ( cu_topk_q_count[:, cu_seqblocks][:, 1:] - cu_topk_q_count[:, cu_seqblocks][:, :-1] ) .max() .item() ) BLOCK_SIZE_N = min(triton.next_power_of_2(max_len), 8192) grid = (batch_size, num_kv_heads, triton.cdiv(max_len, BLOCK_SIZE_N)) save_topk_idx_kernel[grid]( pad_topk_q_idx, topk_q_idx, cu_seqblocks, cu_topk_q_count, pad_topk_q_idx.shape[-1], pad_topk_q_idx.stride(0), pad_topk_q_idx.stride(1), pad_topk_q_idx.stride(2), topk_q_idx.stride(0), topk_q_idx.stride(1), cu_topk_q_count.stride(0), cu_topk_q_count.stride(1), 
BLOCK_SIZE_N=BLOCK_SIZE_N, ) return topk_q_idx @triton.jit def backward_dkdv( q_ptr, # Q: n x qh x d k_ptr, # K: n x kh x d v_ptr, # V: n x kh x d tq_ptr, # topk_q_idx: kh x N lse_ptr, # LSE: qh x n d_ptr, # Delta: qh x n do_ptr, dk_ptr, # DK: sh x n x kh x d dv_ptr, # DK: sh x n x kh x d # seqlens cu_seqlens_q, # [batch_size + 1] cu_seqlens_k, # [batch_size + 1] cu_seqblocks, # [batch_size + 1] cu_topk_q_count, # [kh, total_blocks] # shape NUM_KV_HEADS, NUM_SHARE_Q_HEADS, HEAD_DIM, TOPK, # sm_scale sm_scale, # stride stride_qn, stride_qh, stride_qd, stride_kn, stride_kh, stride_kd, stride_vn, stride_vh, stride_vd, stride_tqh, stride_tqn, stride_ctqh, stride_ctqn, stride_lh, stride_ln, stride_dh, stride_dn, stride_don, stride_doh, stride_dod, stride_dks, stride_dkn, stride_dkh, stride_dkd, stride_dvs, stride_dvn, stride_dvh, stride_dvd, # META parameters BLOCK_SIZE_Q: tl.constexpr, # q block size BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_D: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_b = tl.program_id(0) pid_h = tl.program_id(1) pid_kh = pid_h // NUM_SHARE_Q_HEADS pid_sh = pid_h % NUM_SHARE_Q_HEADS pid_k = tl.program_id(2) # get q k start and len after rmpad q_start = tl.load(cu_seqlens_q + pid_b) q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start k_start = tl.load(cu_seqlens_k + pid_b) k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start if BLOCK_SIZE_K * pid_k >= k_len: return # get topk_q_idx b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence act_q_start = tl.load( cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn ) act_q_end = tl.load( cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn ) act_q_len = act_q_end - act_q_start tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn # init pointers k_ptrs = tl.make_block_ptr( base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, shape=(k_len, HEAD_DIM), strides=(stride_kn, 
stride_kd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) dk_ptrs = tl.make_block_ptr( base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, shape=(k_len, HEAD_DIM), strides=(stride_dkn, stride_dkd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) v_ptrs = tl.make_block_ptr( base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, shape=(k_len, HEAD_DIM), strides=(stride_vn, stride_vd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) dv_ptrs = tl.make_block_ptr( base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, shape=(k_len, HEAD_DIM), strides=(stride_dvn, stride_dvd), offsets=(pid_k * BLOCK_SIZE_K, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) # offsets off_q = tl.arange(0, BLOCK_SIZE_Q) off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K off_d = tl.arange(0, BLOCK_SIZE_D) # load k v and keep in SRAM k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") # init dk dv dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) # init ptrs q_ptrs = ( q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd ) do_ptrs = ( do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod ) d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh # loop for q blocks for i in range(0, act_q_len, BLOCK_SIZE_Q): # load idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to( tl.int32 ) q = tl.load( q_ptrs + idx_q[:, None] * stride_qn, mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], other=0, ) do = tl.load( do_ptrs + idx_q[:, None] * stride_don, mask=(off_q < act_q_len - i)[:, None] & (off_d 
< HEAD_DIM)[None, :],
            other=0,
        )
        lse = tl.load(
            lse_ptrs + idx_q[:, None] * stride_ln,
            mask=(off_q < act_q_len - i)[:, None],
            other=0,
        )
        d = tl.load(
            d_ptrs + idx_q[:, None] * stride_dn,
            mask=(off_q < act_q_len - i)[:, None],
            other=0,
        )
        # compute qk
        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)
        qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf"))
        qk += tl.dot(q, k.T) * qk_scale
        # compute p, ds
        p = tl.exp2(qk - lse)
        dp = tl.dot(do, v.T)
        ds = sm_scale * p * (dp - d)
        # cast dtype
        p = p.to(do.dtype)
        ds = ds.to(q.dtype)
        # update dk and dv
        dk += tl.dot(ds.T, q)
        dv += tl.dot(p.T, do)
    # save dk dv
    tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))


# backward_dq: gradient w.r.t. q for top-k block-sparse attention.
# Launched by _topk_sparse_attention_bwd with grid
# (batch, kv head, cdiv(max_seqlen_q, num_q_loop)). Each program handles up to
# `num_q_loop` consecutive query tokens of one packed sequence and, for each
# token, all NUM_SHARE_Q_HEADS query heads mapped to kv head `pid_kh` at once
# (rows padded up to BLOCK_SIZE_H and masked by the block-pointer bounds).
# For every selected key block it recomputes
#     p  = exp2(q.k * sm_scale * log2(e) - lse)
#     ds = sm_scale * p * (do.v - delta)
# and accumulates dq += ds @ k.
# NOTE(review): p = exp2(qk - lse) implies `lse` is stored in base-2 units by
# the forward kernel (not visible here) — confirm.
# NOTE(review): the inner loop reads the first `real_topk` entries of the
# token's topk row in order, while `real_topk` only counts entries that are
# >= 0 and <= the token's own block. Valid block ids are therefore assumed to
# precede -1 padding (and any future blocks) in each row — confirm callers
# guarantee this ordering.
@triton.jit
def backward_dq(
    q_ptr,  # Q: n x qh x d
    k_ptr,  # K: n x kh x d
    v_ptr,  # V: n x kh x d
    t_ptr,  # topk_idx: kh x n x k
    lse_ptr,  # LSE: qh x n
    d_ptr,  # Delta: qh x n
    do_ptr,
    dq_ptr,
    # seqlens
    cu_seqlens_q,
    cu_seqlens_k,
    # shape
    NUM_KV_HEADS,
    NUM_SHARE_Q_HEADS,
    HEAD_DIM,
    TOPK,
    # q loop num
    num_q_loop,
    # sm_scale
    sm_scale,
    # stride
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kh,
    stride_kd,
    stride_vn,
    stride_vh,
    stride_vd,
    stride_th,
    stride_tn,
    stride_tk,
    stride_lh,
    stride_ln,
    stride_dh,
    stride_dn,
    stride_don,
    stride_doh,
    stride_dod,
    stride_dqn,
    stride_dqh,
    stride_dqd,
    # META parameters
    BLOCK_SIZE_K: tl.constexpr,  # k block size
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_T: tl.constexpr,
):
    # fold log2(e) into the scale so the softmax can use exp2
    qk_scale = sm_scale * 1.44269504
    # get batch id and head id
    pid_b = tl.program_id(0)
    pid_kh = tl.program_id(1)
    pid_q = tl.program_id(2)
    # first query head served by this kv head (GQA grouping)
    pid_h = pid_kh * NUM_SHARE_Q_HEADS
    # get q k start and len after rmpad
    q_start = tl.load(cu_seqlens_q + pid_b)
    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start
    k_start = tl.load(cu_seqlens_k + pid_b)
    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start
    # grid is sized for the longest sequence; shorter ones exit early
    if pid_q * num_q_loop >= q_len:
        return
    real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop)
    for j in range(real_q_loop):
        # sequence-relative index of the query token handled in this iteration
        pid_q_j = pid_q * num_q_loop + j
        # init topk idx pointer
        off_t = tl.arange(0, BLOCK_SIZE_T)
        t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th
        topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)
        # number of usable blocks: non-padding and not after the query's own block
        real_topk = tl.sum(
            tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0),
            axis=0,
        )
        # init pointers
        q_ptrs = tl.make_block_ptr(
            base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_qh, stride_qd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        dq_ptrs = tl.make_block_ptr(
            base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_dqh, stride_dqd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        k_ptrs = tl.make_block_ptr(
            base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,
            shape=(k_len, HEAD_DIM),
            strides=(stride_kn, stride_kd),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
            order=(1, 0),
        )
        # V is addressed transposed ([HEAD_DIM, k_len]) so do @ v yields dp directly
        v_ptrs = tl.make_block_ptr(
            base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,
            shape=(HEAD_DIM, k_len),
            strides=(stride_vd, stride_vn),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
            order=(0, 1),
        )
        do_ptrs = tl.make_block_ptr(
            base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh,
            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),
            strides=(stride_doh, stride_dod),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),
            order=(1, 0),
        )
        # delta and lse are read as [heads, 1] columns so they broadcast over keys
        d_ptrs = tl.make_block_ptr(
            base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh,
            shape=(NUM_SHARE_Q_HEADS, 1),
            strides=(stride_dh, stride_dn),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, 1),
            order=(1, 0),
        )
        lse_ptrs = tl.make_block_ptr(
            base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh,
            shape=(NUM_SHARE_Q_HEADS, 1),
            strides=(stride_lh, stride_ln),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_H, 1),
            order=(1, 0),
        )
        # offsets
        off_k = tl.arange(0, BLOCK_SIZE_K)
        # load q, do, lse, delta; reused for every selected key block below
        q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero")
        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero")
        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero")
        # init dq
        dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32)
        # iterate over the selected key blocks
        for i in range(real_topk):
            # get current block start index (in tokens, sequence-relative)
            c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K
            t_ptr_j = t_ptr_j + stride_tk
            # load
            k = tl.load(
                tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero"
            )
            v = tl.load(
                tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero"
            )
            # compute qk; causal mask hides keys after the query position
            qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)
            qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf"))
            # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K]
            qk += tl.dot(q, tl.trans(k)) * qk_scale
            # compute p, ds
            p = tl.exp2(qk - lse)
            dp = tl.dot(do, v)
            ds = sm_scale * p * (dp - d)
            # cast dtype
            ds = ds.to(q.dtype)
            # update dq
            dq += tl.dot(ds, k)
        # save dq for this query token
        tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))


def _topk_sparse_attention_fwd(
    q: torch.Tensor,  # [total_len, num_q_heads, head_dim]
    k: torch.Tensor,  # [total_len, num_k_heads, head_dim]
    v: torch.Tensor,  # [total_len, num_k_heads, head_dim]
    topk_idx: torch.Tensor,  # [num_kv_heads, total_len, topk]
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
):
    """Validate inputs and launch ``forward_kernel`` for varlen top-k sparse attention.

    q, k and v are packed along dim 0 and must have the same total length.
    ``topk_idx`` holds, for every kv head and query token, the key-block indices
    to attend to.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(o, lse)``. ``o`` has q's shape and
        dtype. ``lse`` is float32 with shape [num_q_heads, total_len] and is filled
        by the kernel (presumably the per-row log-sum-exp reused by the backward
        pass — kernel not visible here).
    """
    # dtype check
    assert k.dtype == q.dtype and v.dtype == q.dtype
    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
    assert block_size in {32, 64, 128, 256}
    # shape
    q_len, num_q_heads, head_dim = q.shape
    k_len, num_k_heads, head_dim = k.shape
    v_len, num_v_heads, head_dim = v.shape
    batch_size = cu_seqlens_q.shape[0] - 1
    assert q_len == k_len and k_len == v_len
    topk = topk_idx.shape[-1]
    assert topk_idx.shape[0] == num_k_heads
    assert topk_idx.shape[1] == q_len
    # gqa
    assert num_k_heads == num_v_heads
    assert num_q_heads % num_k_heads == 0
    num_share_q_heads = num_q_heads // num_k_heads
    # output tensor
    o = torch.zeros_like(q)
    lse = torch.zeros(num_q_heads, q_len, dtype=torch.float32, device=q.device)
    # launch kernel
    num_q_loop = (
        max_seqlen_q // 32768 + 1
    )  # process several queries per program for long sequences; keeps grid dim 2 <= 32768
    grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop))
    BLOCK_SIZE_K = triton.next_power_of_2(block_size)
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    # all query heads sharing one kv head are processed together; at least 16 rows
    # (presumably the minimum M dimension for tl.dot)
    BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads))
    BLOCK_SIZE_T = triton.next_power_of_2(topk)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)
    forward_kernel[grid](
        q,
        k,
        v,
        topk_idx,
        o,
        lse,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        topk,
        num_q_loop,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        topk_idx.stride(0),
        topk_idx.stride(1),
        topk_idx.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        lse.stride(0),
        lse.stride(1),
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        BLOCK_SIZE_H=BLOCK_SIZE_H,
        BLOCK_SIZE_T=BLOCK_SIZE_T,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, lse


def _topk_sparse_attention_bwd(
    o: torch.Tensor,
    do: torch.Tensor,
    lse: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    topk_idx: torch.Tensor,
    block_size: int,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: float,
):
    """Compute ``(dq, dk, dv)`` for top-k sparse attention in three kernel stages.

    1. ``backward_sum_o_do`` fills ``delta`` (float32, [num_o_heads, o_len]) from
       ``o`` and ``do``.
    2. The top-k selection is inverted to a key-block-major layout
       (``count_query`` / ``reorder_topk_idx``), so that ``backward_dkdv`` can
       visit, per key block, only the queries that selected it. dk/dv are
       accumulated per shared query head and then summed over that axis.
    3. ``backward_dq`` accumulates dq per query token over its selected blocks.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ``(dq, dk, dv)`` with
        the shapes of q, k and v.
    """
    assert block_size in {32, 64, 128, 256}
    q_len, num_q_heads, head_dim = q.shape
    k_len, num_k_heads, head_dim = k.shape
    v_len, num_v_heads, head_dim = v.shape
    o_len, num_o_heads, head_dim = o.shape
    num_share_q_heads = num_q_heads // num_k_heads
    topk = topk_idx.shape[-1]
    # compute D (delta) from o and do
    delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)
    BLOCK_SIZE_O = 256
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)
    grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads)
    backward_sum_o_do[grid](
        o,
        do,
        delta,
        o_len,
        head_dim,
        o.stride(0),
        o.stride(1),
        o.stride(2),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        delta.stride(0),
        delta.stride(1),
        BLOCK_SIZE_O=BLOCK_SIZE_O,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    # count active queries for each key block, shape: (num_k_heads, total_k_blocks)
    # NOTE: key blocks are counted from cu_seqlens_q; q and k lengths are equal
    # per the forward-pass assertions.
    seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    seqblocks = torch.ceil(seqlens / block_size).to(torch.int32)
    cu_seqblocks = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=topk_idx.device),
            torch.cumsum(seqblocks, dim=0),
        ]
    ).to(torch.int32)
    topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size)
    cu_topk_q_count = torch.cat(
        [
            torch.zeros(
                topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device
            ),
            torch.cumsum(topk_q_count, dim=-1),
        ],
        dim=-1,
    ).to(torch.int32)
    # active query idx for each key block
    # how to get active query idx for sequence b, head h, kv block i?
    # topk_q_idx[h, cu_topk_q_count[h, cu_seqblocks[b] + i]:cu_topk_q_count[h, cu_seqblocks[b] + i + 1]]
    topk_q_idx = reorder_topk_idx(
        topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size
    )
    # compute dk dv: one partial result per shared query head (leading dim),
    # summed after the kernel
    dk = torch.zeros(
        num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
    )
    dv = torch.zeros(
        num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype
    )
    batch_size = cu_seqlens_q.shape[0] - 1
    BLOCK_SIZE_K = triton.next_power_of_2(block_size)
    BLOCK_SIZE_Q = 64
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)
    grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K))
    backward_dkdv[grid](
        q,
        k,
        v,
        topk_q_idx,
        lse,
        delta,
        do,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqblocks,
        cu_topk_q_count,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        topk,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        topk_q_idx.stride(0),
        topk_q_idx.stride(1),
        cu_topk_q_count.stride(0),
        cu_topk_q_count.stride(1),
        lse.stride(0),
        lse.stride(1),
        delta.stride(0),
        delta.stride(1),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        dk.stride(0),
        dk.stride(1),
        dk.stride(2),
        dk.stride(3),
        dv.stride(0),
        dv.stride(1),
        dv.stride(2),
        dv.stride(3),
        BLOCK_SIZE_Q=BLOCK_SIZE_Q,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    dk = dk.sum(0)
    dv = dv.sum(0)
    # compute dq
    dq = torch.zeros_like(q)
    num_q_loop = (
        max_seqlen_q // 32768 + 1
    )  # process several queries per program for long sequences; keeps grid dim 2 <= 32768
    grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop))
    BLOCK_SIZE_K = block_size
    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
    BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads))
    BLOCK_SIZE_T = triton.next_power_of_2(topk)
    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)
    backward_dq[grid](
        q,
        k,
        v,
        topk_idx,
        lse,
        delta,
        do,
        dq,
        cu_seqlens_q,
        cu_seqlens_k,
        num_k_heads,
        num_share_q_heads,
        head_dim,
        topk,
        num_q_loop,
        sm_scale,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        topk_idx.stride(0),
        topk_idx.stride(1),
        topk_idx.stride(2),
        lse.stride(0),
        lse.stride(1),
        delta.stride(0),
        delta.stride(1),
        do.stride(0),
        do.stride(1),
        do.stride(2),
        dq.stride(0),
        dq.stride(1),
        dq.stride(2),
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        BLOCK_SIZE_H=BLOCK_SIZE_H,
        BLOCK_SIZE_T=BLOCK_SIZE_T,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return dq, dk, dv


class TopkSparseAttention(torch.autograd.Function):
    """Autograd wrapper around the Triton top-k sparse attention kernels.

    Only q, k and v receive gradients; all other forward inputs get ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,  # [total_len, num_q_heads, head_dim]
        k: torch.Tensor,  # [total_len, num_k_heads, head_dim]
        v: torch.Tensor,  # [total_len, num_k_heads, head_dim]
        topk_idx: torch.Tensor,  # [num_kv_heads, total_len, topk]
        block_size: int,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        sm_scale=None,
    ):
        """Run the forward pass and save what the backward pass needs.

        Requires fp16/bf16 q/k/v of one dtype and int32 ``topk_idx`` and
        ``cu_seqlens_*``. ``sm_scale`` defaults to ``1 / sqrt(head_dim)``.
        """
        # dtype check
        assert q.dtype == torch.bfloat16 or q.dtype == torch.float16
        assert q.dtype == k.dtype and k.dtype == v.dtype
        assert topk_idx.dtype == torch.int32
        assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
        # softmax scale
        if sm_scale is None:
            sm_scale = 1 / math.sqrt(q.shape[-1])
        o, lse = _topk_sparse_attention_fwd(
            q,
            k,
            v,
            topk_idx,
            block_size,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )
        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx)
        ctx.sm_scale = sm_scale
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.block_size = block_size
        # return
        return o

    @staticmethod
    def backward(ctx, do: torch.Tensor, *args) -> Any:
        """Return gradients for (q, k, v) followed by ``None`` for the rest."""
        q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors
        max_seqlen_q = ctx.max_seqlen_q
        max_seqlen_k = ctx.max_seqlen_k
        sm_scale = ctx.sm_scale
        block_size = ctx.block_size
        assert block_size in {32, 64, 128, 256}
        dq, dk, dv = _topk_sparse_attention_bwd(
            o,
            do,
            lse,
            q,
            k,
            v,
            topk_idx,
            block_size,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            sm_scale,
        )
        # NOTE(review): 11 values are returned for 10 forward inputs; autograd
        # tolerates surplus trailing Nones, but 7 Nones would match exactly.
        return dq, dk, dv, None, None, None, None, None, None, None, None


def topk_sparse_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    topk_idx: torch.Tensor,
    block_size: int,
    cu_seqlens: torch.Tensor,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Topk sparse attention varlen version implemented in triton.

    Queries and keys share the same packing, so ``cu_seqlens`` is used for both.

    Args:
        q (torch.Tensor): shape [total_len, num_q_heads, head_dim], fp16 or bf16.
        k (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
        v (torch.Tensor): shape [total_len, num_kv_heads, head_dim]
        topk_idx (torch.Tensor): int32 topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding.
        block_size (int): key value block size, one of 32, 64, 128, 256.
        cu_seqlens (torch.Tensor): int32, shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen.
        softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim).

    Returns:
        torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim]
    """
    # .item() copies the value to the host (a device sync for CUDA tensors)
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    return TopkSparseAttention.apply(
        q,
        k,
        v,
        topk_idx,
        block_size,
        cu_seqlens,
        cu_seqlens,
        max_seqlen,
        max_seqlen,
        softmax_scale,
    )
================================================
FILE: native_sparse_attention/ops/triton/topk_sparse_attention_decode.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. import math import torch import triton import triton.language as tl from typing import Optional @triton.jit def forward_kernel( q_ptr, # Q: b x h x d k_ptr, # K: b x n x kh x d v_ptr, # V: b x n x kh x d t_ptr, # topk_idx: kh x b x k o_ptr, # O: b x h x d # seqlens seqlens, # shape NUM_SHARE_Q_HEADS, HEAD_DIM, TOPK, # sm_scale sm_scale, # stride stride_qb, stride_qh, stride_qd, stride_kb, stride_kn, stride_kh, stride_kd, stride_vb, stride_vn, stride_vh, stride_vd, stride_th, stride_tb, stride_tk, stride_ob, stride_oh, stride_od, # META parameters BLOCK_SIZE_K: tl.constexpr, # k block size BLOCK_SIZE_D: tl.constexpr, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_T: tl.constexpr, ): qk_scale = sm_scale * 1.44269504 # get batch id and head id pid_b = tl.program_id(0) pid_kh = tl.program_id(1) # get kv_len kv_len = tl.load(seqlens + pid_b) # init topk idx pointer off_t = tl.arange(0, BLOCK_SIZE_T) t_ptr = t_ptr + pid_b * stride_tb + pid_kh * stride_th topk_idx = tl.load(t_ptr + off_t * stride_tk, mask=off_t < TOPK, other=-1) real_topk = tl.sum( tl.where((topk_idx >= 0) & (topk_idx <= (kv_len - 1) // BLOCK_SIZE_K), 1, 0), axis=0, ) # init qkv pointer q_ptrs = tl.make_block_ptr( base=q_ptr + pid_b * stride_qb + pid_kh * NUM_SHARE_Q_HEADS * stride_qh, shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), strides=(stride_qh, stride_qd), offsets=(0, 0), block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), order=(1, 0), ) k_ptrs = tl.make_block_ptr( base=k_ptr + pid_b * stride_kb + pid_kh * stride_kh, shape=(HEAD_DIM, kv_len), strides=(stride_kd, stride_kn), offsets=(0, 0), block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), order=(0, 1), ) v_ptrs = tl.make_block_ptr( base=v_ptr + pid_b * stride_vb + pid_kh * stride_vh, shape=(kv_len, HEAD_DIM), strides=(stride_vn, stride_vd), offsets=(0, 0), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), order=(1, 0), ) # load q q = tl.load(q_ptrs, boundary_check=(0, 1), 
padding_option="zero") # init statistics off_k = tl.arange(0, BLOCK_SIZE_K) m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32) # sparse attention for i in range(real_topk): # get current block start index c = tl.load(t_ptr).to(tl.int32) * BLOCK_SIZE_K t_ptr = t_ptr + stride_tk # load k k = tl.load( tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero" ) # compute qk qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) qk += tl.where((kv_len > c + off_k)[None, :], 0, float("-inf")) # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K] qk += tl.dot(q, k) * qk_scale # compute m_ij and l_ij m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) p = tl.exp2(qk - m_ij[:, None]) l_ij = tl.sum(p, axis=1) # scale acc_o acc_o_scale = tl.exp2(m_i - m_ij) acc_o = acc_o * acc_o_scale[:, None] # load v and update acc_o v = tl.load( tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero" ) p = p.to(v.dtype) acc_o += tl.dot(p, v) # update statistics m_i = m_ij lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) # final scale acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] # save output o_ptrs = tl.make_block_ptr( base=o_ptr + pid_b * stride_ob + pid_kh * NUM_SHARE_Q_HEADS * stride_oh, shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), strides=(stride_oh, stride_od), offsets=(0, 0), block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), order=(1, 0), ) tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) def topk_sparse_attention_decode( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, topk_idx: torch.Tensor, block_size: int, seqlens: torch.Tensor, sm_scale: Optional[float] = None, ) -> torch.Tensor: """_summary_ Args: q (torch.Tensor): shape [batch_size, num_q_heads, head_dim] k (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim] v (torch.Tensor): shape 
[batch_size, kv_len, num_kv_heads, head_dim] topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, batch_size, topk]. -1 means padding. block_size (int): key value block size. seqlens (torch.Tensor): max kv length for each sequence softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). Returns: torch.Tensor: sparse attention output """ # dtype check assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 assert k.dtype == q.dtype and v.dtype == q.dtype assert seqlens.dtype == torch.int32 # shape batch_size, num_q_heads, head_dim = q.shape _, k_len, num_k_heads, head_dim = k.shape _, v_len, num_v_heads, head_dim = v.shape assert k_len == v_len and batch_size == seqlens.shape[0] assert num_k_heads == topk_idx.shape[0] and batch_size == topk_idx.shape[1] topk = topk_idx.shape[-1] # gqa assert num_k_heads == num_v_heads assert num_q_heads % num_k_heads == 0 num_share_q_heads = num_q_heads // num_k_heads # sm scale if sm_scale is None: sm_scale = 1 / math.sqrt(head_dim) # output tensor o = torch.zeros_like(q) # launch kernel grid = (batch_size, num_k_heads) num_warps = 4 if head_dim <= 64 else 8 num_stages = 3 BLOCK_SIZE_K = triton.next_power_of_2(block_size) BLOCK_SIZE_D = triton.next_power_of_2(head_dim) BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) BLOCK_SIZE_T = triton.next_power_of_2(topk) forward_kernel[grid]( q, k, v, topk_idx, o, seqlens, num_share_q_heads, head_dim, topk, sm_scale, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), topk_idx.stride(0), topk_idx.stride(1), topk_idx.stride(2), o.stride(0), o.stride(1), o.stride(2), BLOCK_SIZE_K=BLOCK_SIZE_K, BLOCK_SIZE_D=BLOCK_SIZE_D, BLOCK_SIZE_H=BLOCK_SIZE_H, BLOCK_SIZE_T=BLOCK_SIZE_T, num_warps=num_warps, num_stages=num_stages, ) return o def torch_topk_sparse_attention_decode( q: torch.Tensor, # [batch_size, num_q_heads, head_dim] k: 
torch.Tensor, # [batch_size, kv_len, num_k_heads, head_dim] v: torch.Tensor, # [batch_size, kv_len, num_k_heads, head_dim] topk_idx: torch.Tensor, # [num_k_heads, batch_size, topk] block_size: int, seqlens: torch.Tensor, # [batch_size, ] sm_scale: Optional[float] = None, ): # dtype check assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 assert k.dtype == q.dtype and v.dtype == q.dtype assert seqlens.dtype == torch.int32 # shape batch_size, num_q_heads, head_dim = q.shape _, k_len, num_k_heads, head_dim = k.shape _, v_len, num_v_heads, head_dim = v.shape assert k_len == v_len and batch_size == seqlens.shape[0] assert num_k_heads == topk_idx.shape[0] and batch_size == topk_idx.shape[1] topk = topk_idx.shape[-1] # gqa assert num_k_heads == num_v_heads assert num_q_heads % num_k_heads == 0 num_share_q_heads = num_q_heads // num_k_heads # sm scale if sm_scale is None: sm_scale = 1 / math.sqrt(head_dim) # mask mask = torch.zeros( (batch_size, num_k_heads, k_len), dtype=torch.bool, device=q.device ) for b in range(batch_size): for h in range(num_k_heads): for t in range(topk): if topk_idx[h, b, t] != -1: mask[ b, h, topk_idx[h, b, t] * block_size : (topk_idx[h, b, t] + 1) * block_size, ] = True mask = mask & ( (seqlens - 1)[:, None, None] >= torch.arange(k_len).cuda()[None, None, :] ) mask = mask.repeat_interleave(num_share_q_heads, 1) # attention attn = ( torch.einsum( "bqhd,bkhd->bhqk", q.unsqueeze(1), k.repeat_interleave(num_share_q_heads, 2) ) * sm_scale ) attn = attn.masked_fill(~mask.unsqueeze(2), -torch.inf) attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype) out = torch.einsum( "bhqk,bkhd->bqhd", attn, v.repeat_interleave(num_share_q_heads, 2) ).squeeze(1) return out def generate_topk_idx_example( seqlens: torch.Tensor, block_size: int, topk: int, num_heads: int ) -> torch.Tensor: batch_size = seqlens.shape[0] num_blocks = torch.ceil(seqlens / block_size).to(torch.int32) topk_idx_all_heads = [] for _ in range(num_heads): topk_idx = [ 
torch.randn(1, num_blocks[i], device="cuda") .topk(min(topk, num_blocks[i]), dim=-1) .indices.to(torch.int32) for i in range(batch_size) ] topk_idx = [ torch.nn.functional.pad( topk_idx[i], (0, topk - topk_idx[i].shape[-1]), value=topk ) for i in range(batch_size) ] topk_idx = torch.cat(topk_idx, dim=0) topk_idx = torch.sort(topk_idx, dim=1).values topk_idx[:, 0] = 0 q_idx = seqlens - 1 topk_idx[topk_idx > (q_idx // block_size)[:, None]] = -1 # -1 means padding topk_idx_all_heads.append(topk_idx) topk_idx = torch.stack(topk_idx_all_heads, dim=0) return topk_idx if __name__ == "__main__": torch.manual_seed(42) topk = 16 block_size = 64 batch_size = 76 max_length = 8192 seqlens = torch.arange(batch_size, dtype=torch.int32).cuda() * 128 + 1 seqlens[seqlens > max_length] = max_length seqlens = seqlens[torch.randn_like(seqlens, dtype=torch.float32).argsort(-1)] q = ( torch.empty(batch_size, 32, 128, device="cuda") .uniform_(-1, 1) .to(torch.float16) ) k = ( torch.empty(batch_size, max_length, 4, 128, device="cuda") .uniform_(-1, 1) .to(torch.float16) ) v = ( torch.empty(batch_size, max_length, 4, 128, device="cuda") .uniform_(-1, 1) .to(torch.float16) ) topk_idx = generate_topk_idx_example(seqlens, block_size, topk, 4) o1 = torch_topk_sparse_attention_decode(q, k, v, topk_idx, block_size, seqlens) o2 = topk_sparse_attention_decode(q, k, v, topk_idx, block_size, seqlens) print(torch.allclose(o1, o2, atol=1e-3, rtol=1e-3)) print((o1 - o2).abs().max()) ================================================ FILE: native_sparse_attention/ops/triton/utils.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch def is_hopper_gpu(): if torch.cuda.is_available(): device_capability = torch.cuda.get_device_capability() major, minor = device_capability return major == 9 return False def get_compressed_seqlens( cu_seqlens: torch.Tensor, kernel_size: int, kernel_stride: int ): # compute seqlens after compression seqlens = cu_seqlens[1:] - cu_seqlens[:-1] y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1 # corner case, if sequence_length < kernel_size, no compression for this sequence y_seqlens[seqlens < kernel_size] = 0 y_cu_seqlens = torch.zeros( y_seqlens.shape[0] + 1, dtype=torch.int32, device=cu_seqlens.device ) y_cu_seqlens[1:] = torch.cumsum(y_seqlens, dim=0) return y_seqlens, y_cu_seqlens def get_num_warps_stages(head_dim, block_size, is_hopper_gpu): """ Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton. Args: head_dim (int): Size of the head dimension. block_size (int): Size of the block in the attention matrix. is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU. Returns: tuple: (num_warps, num_stages) recommended values. 
""" # Determine if head_dim and block_size exceed 64 head_large = head_dim > 64 block_large = block_size > 64 if is_hopper_gpu: # Hopper GPU recommendations if head_large and block_large: num_warps = 8 num_stages = 3 elif head_large or block_large: num_warps = 4 num_stages = 3 else: num_warps = 2 num_stages = 2 else: # Ampere GPU recommendations if head_large and block_large: num_warps = 8 num_stages = 3 elif head_large or block_large: num_warps = 8 num_stages = 3 else: num_warps = 2 num_stages = 2 if head_dim > 128: num_stages = 2 return num_warps, num_stages ================================================ FILE: native_sparse_attention/ops/triton/weighted_pool.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Optional
from einops import einsum
import torch
import triton
import triton.language as tl
from native_sparse_attention.ops.triton.utils import get_compressed_seqlens


# sliding_pool_fwd_kernel: one program per (sequence, head, output window).
# Window `pid_k` reads BLOCK_SIZE_K rows starting at pid_k * kernel_stride
# (zero-padded past x_len). It writes sum(w * x) over the rows when w_ptr is
# given, or sum(x) / kernel_size otherwise. The unweighted branch relies on
# BLOCK_SIZE_K == kernel_size, which holds because forward() asserts a
# power-of-two kernel_size and passes BLOCK_SIZE_K = next_power_of_2(kernel_size).
@triton.jit
def sliding_pool_fwd_kernel(
    x_ptr,
    y_ptr,
    w_ptr,
    cu_seqlens,
    y_cu_seqlens,
    head_dim,
    kernel_size,
    kernel_stride,
    stride_xn,
    stride_xh,
    stride_xd,
    stride_yn,
    stride_yh,
    stride_yd,
    stride_wh,
    stride_wk,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_k = tl.program_id(2)
    # get start and len after rmpad
    x_start = tl.load(cu_seqlens + pid_b)
    x_len = tl.load(cu_seqlens + pid_b + 1) - x_start
    y_start = tl.load(y_cu_seqlens + pid_b)
    y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start
    # grid is sized for the longest sequence; shorter ones exit early
    if pid_k >= y_len:
        return
    if w_ptr is not None:
        # load w as a [BLOCK_SIZE_K, 1] column so it broadcasts over head_dim
        w_ptrs = tl.make_block_ptr(
            base=w_ptr + pid_h * stride_wh,
            shape=(kernel_size, 1),
            strides=(stride_wk, 0),
            offsets=(0, 0),
            block_shape=(BLOCK_SIZE_K, 1),
            order=(0, 1),
        )
        w = tl.load(w_ptrs, boundary_check=(0, 1), padding_option="zero")
    # load x
    x_ptrs = tl.make_block_ptr(
        base=x_ptr + x_start * stride_xn + pid_h * stride_xh,
        shape=(x_len, head_dim),
        strides=(stride_xn, stride_xd),
        offsets=(pid_k * kernel_stride, 0),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),
        order=(1, 0),
    )
    x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option="zero")
    # compute y
    if w_ptr is not None:
        y = tl.sum(x * w, axis=0)
    else:
        y = tl.sum(x, axis=0) / kernel_size
    off_d = tl.arange(0, BLOCK_SIZE_D)
    tl.store(
        y_ptr + (y_start + pid_k) * stride_yn + pid_h * stride_yh + off_d * stride_yd,
        y.to(y_ptr.dtype.element_ty),
        mask=off_d < head_dim,
    )


# sliding_pool_dxdw_kernel: backward of sliding_pool_fwd_kernel, one program per
# (sequence, head, output window).
#   dx[rows of window] += dy * w   (or dy / kernel_size without weights)
#   dw[h, window, :]    = sum_d dy * x   (one row per window; summed later)
# dx uses atomic_add because windows overlap when kernel_stride < kernel_size.
@triton.jit
def sliding_pool_dxdw_kernel(
    x_ptr,
    dx_ptr,
    dy_ptr,
    w_ptr,
    dw_ptr,
    cu_seqlens,
    y_cu_seqlens,
    head_dim,
    kernel_size,
    kernel_stride,
    stride_xn,
    stride_xh,
    stride_xd,
    stride_dxn,
    stride_dxh,
    stride_dxd,
    stride_dyn,
    stride_dyh,
    stride_dyd,
    stride_wh,
    stride_wk,
    stride_dwh,
    stride_dwn,
    stride_dwk,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_k = tl.program_id(2)
    # get start and len after rmpad
    x_start = tl.load(cu_seqlens + pid_b)
    x_len = tl.load(cu_seqlens + pid_b + 1) - x_start
    y_start = tl.load(y_cu_seqlens + pid_b)
    y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start
    if pid_k >= y_len:
        return
    # offsets
    off_d = tl.arange(0, BLOCK_SIZE_D)
    off_k = tl.arange(0, BLOCK_SIZE_K)
    if w_ptr is not None:
        # load w
        w_ptrs = w_ptr + pid_h * stride_wh + off_k * stride_wk
        w = tl.load(w_ptrs, mask=off_k < kernel_size, other=0)
    # load x, transposed to [BLOCK_SIZE_D, BLOCK_SIZE_K]
    x_ptrs = tl.make_block_ptr(
        base=x_ptr + x_start * stride_xn + pid_h * stride_xh,
        shape=(head_dim, x_len),
        strides=(stride_xd, stride_xn),
        offsets=(0, pid_k * kernel_stride),
        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),
        order=(0, 1),
    )
    x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option="zero")
    # load dy
    dy_ptrs = (
        dy_ptr + pid_h * stride_dyh + (y_start + pid_k) * stride_dyn + off_d * stride_dyd
    )
    dy = tl.load(dy_ptrs, mask=off_d < head_dim, other=0)
    if w_ptr is not None:
        # compute dx, [1, D] x [K, 1] -> [K, D]
        dx = dy[None, :] * w[:, None]
        # compute dw, [D, 1] x [D, K] -> [D, K] -> [K]
        dw = tl.sum(dy[:, None] * x, axis=0)
        # store dw
        dw_ptrs = (
            dw_ptr + pid_h * stride_dwh + (y_start + pid_k) * stride_dwn + off_k * stride_dwk
        )
        tl.store(dw_ptrs, dw.to(dw_ptr.dtype.element_ty), mask=off_k < kernel_size)
    else:
        dx = dy[None, :] / kernel_size
    # store dx; rows past the end of the sequence are masked out
    dx_ptrs = (
        dx_ptr
        + pid_h * stride_dxh
        + (x_start + pid_k * kernel_stride + off_k[:, None]) * stride_dxn
        + off_d[None, :] * stride_dxd
    )
    tl.atomic_add(
        dx_ptrs,
        dx.to(dx_ptr.dtype.element_ty),
        mask=(off_k < x_len - pid_k * kernel_stride)[:, None]
        & (off_d < head_dim)[None, :],
    )


class SlidingWindowWeightedPool(torch.autograd.Function):
    """Autograd function for (weighted) sliding-window pooling over packed sequences.

    With ``w`` given, each window of ``kernel_size`` rows (step ``kernel_stride``)
    is reduced per head as ``sum_k w[h, k] * x[row_k, h, :]``. With ``w=None``
    the window mean is taken instead.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,  # [total_len, num_heads, head_dim]
        w: torch.Tensor,  # [num_heads, kernel_size]
        cu_seqlens: torch.Tensor,
        kernel_size: int,
        kernel_stride: int,
    ):
        """Pool ``x`` and return ``(y, y_cu_seqlens)``.

        ``y`` has shape [y_cu_seqlens[-1], num_heads, head_dim] and x's dtype.
        Requires fp16/bf16 ``x``, int32 ``cu_seqlens``, kernel_size in
        {16, 32, 64, 128}, and kernel_size divisible by kernel_stride.
        """
        # dtype check
        assert x.dtype == torch.float16 or x.dtype == torch.bfloat16
        if w is not None:
            assert x.dtype == w.dtype
        assert cu_seqlens.dtype == torch.int32
        # shape check
        total_len, num_heads, head_dim = x.shape
        batch_size = cu_seqlens.shape[0] - 1
        if w is not None:
            assert w.shape[0] == num_heads
            assert w.shape[1] == kernel_size
        assert kernel_size % kernel_stride == 0
        assert kernel_size in {16, 32, 64, 128}
        # compute seqlens after compression
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        y_seqlens, y_cu_seqlens = get_compressed_seqlens(
            cu_seqlens, kernel_size, kernel_stride
        )
        # output buffer
        y = torch.zeros(
            y_cu_seqlens[-1], num_heads, head_dim, dtype=x.dtype, device=x.device
        )
        # launch kernel: one program per (sequence, head, output window)
        BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
        BLOCK_SIZE_K = triton.next_power_of_2(kernel_size)
        grid = (batch_size, num_heads, y_seqlens.max().item())
        sliding_pool_fwd_kernel[grid](
            x,
            y,
            w,
            cu_seqlens,
            y_cu_seqlens,
            head_dim,
            kernel_size,
            kernel_stride,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            y.stride(0),
            y.stride(1),
            y.stride(2),
            w.stride(0) if w is not None else None,
            w.stride(1) if w is not None else None,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            BLOCK_SIZE_D=BLOCK_SIZE_D,
        )
        ctx.save_for_backward(x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens)
        ctx.kernel_size = kernel_size
        ctx.kernel_stride = kernel_stride
        ctx.head_dim = head_dim
        return y, y_cu_seqlens

    @staticmethod
    def backward(ctx, dy, _):
        """Return ``(dx, dw, None, None, None)``; ``dw`` is None when ``w`` was None."""
        x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens = ctx.saved_tensors
        kernel_size = ctx.kernel_size
        kernel_stride = ctx.kernel_stride
        head_dim = ctx.head_dim
        batch_size = cu_seqlens.shape[0] - 1
        num_heads = x.shape[1]
        # compute dx (float32 buffer, accumulated with atomics in the kernel)
        dx = torch.zeros_like(x, dtype=torch.float32)
        if w is not None:
            # one dw row per output window; reduced over windows below
            dw = torch.zeros(
                num_heads,
                y_cu_seqlens[-1],
                kernel_size,
                dtype=torch.float32,
                device=w.device,
            )
        BLOCK_SIZE_D = triton.next_power_of_2(head_dim)
        BLOCK_SIZE_K = triton.next_power_of_2(kernel_size)
        grid = (batch_size, num_heads, y_seqlens.max().item())
        sliding_pool_dxdw_kernel[grid](
            x,
            dx,
            dy,
            w,
            dw if w is not None else None,
            cu_seqlens,
            y_cu_seqlens,
            head_dim,
            kernel_size,
            kernel_stride,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            dx.stride(0),
            dx.stride(1),
            dx.stride(2),
            dy.stride(0),
            dy.stride(1),
            dy.stride(2),
            w.stride(0) if w is not None else None,
            w.stride(1) if w is not None else None,
            dw.stride(0) if w is not None else None,
            dw.stride(1) if w is not None else None,
            dw.stride(2) if w is not None else None,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            BLOCK_SIZE_D=BLOCK_SIZE_D,
        )
        dx = dx.to(x.dtype)
        if w is None:
            dw = None
        else:
            dw = dw.sum(1).to(w.dtype)
        return dx, dw, None, None, None


def weightedpool_compress(
    x: torch.Tensor,  # [total_len, num_heads, head_dim]
    w: torch.Tensor,  # [num_heads, kernel_size]
    cu_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress key and value tensor with kernel_size and kernel_stride with weighted pooling

    Args:
        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim), fp16 or bf16.
        w (torch.Tensor): weight for each head, shape (num_heads, kernel_size), same dtype as x.
        cu_seqlens (torch.Tensor): int32, shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)
        kernel_stride (int): stride for each compress kernel
        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim).
            It is added as the bias sum_k w[h, k] * pe[h, k, :]. Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: compressed states, shape
        (total_compressed_len, num_heads, head_dim), and corresponding cu_seqlens.
    """
    y, y_cu_seqlens = SlidingWindowWeightedPool.apply(
        x, w, cu_seqlens, kernel_size, kernel_stride
    )
    # position embedding as a bias: pooling (x + pe) with w equals pooled x plus
    # the w-weighted sum of pe
    if pe is not None:
        assert pe.dtype == x.dtype and pe.device == x.device
        bias = einsum(pe, w, "h k d, h k -> h d")
        y = y + bias.unsqueeze(0)
    return y, y_cu_seqlens


def avgpool_compress(
    x: torch.Tensor,  # [total_len, num_heads, head_dim]
    w: torch.Tensor,  # don't need weight
    cu_seqlens: torch.Tensor,
    kernel_size: int,
    kernel_stride: int,
    pe: Optional[torch.Tensor] = None,
):
    """Compress key and value tensor with kernel_size and kernel_stride with average pooling.

    Args:
        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim), fp16 or bf16.
        w (torch.Tensor): must be None; average pooling takes no weights (kept so the
            signature matches weightedpool_compress).
        cu_seqlens (torch.Tensor): int32, shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.
        kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)
        kernel_stride (int): stride for each compress kernel
        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim).
            Its mean over the kernel dimension is added as a bias. Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: compressed states, shape
        (total_compressed_len, num_heads, head_dim), and corresponding cu_seqlens.
    """
    assert w is None, "don't need additional weight for avgpool"
    y, y_cu_seqlens = SlidingWindowWeightedPool.apply(
        x, w, cu_seqlens, kernel_size, kernel_stride
    )
    # position embedding as a bias: the mean of (x + pe) is mean(x) + mean(pe)
    if pe is not None:
        assert pe.dtype == x.dtype and pe.device == x.device
        bias = torch.mean(pe, dim=1)
        y = y + bias.unsqueeze(0)
    return y, y_cu_seqlens


if __name__ == "__main__":
    from native_sparse_attention.ops.torch.compress_key_value import (
        weightedpool_compress_torch,
    )

    torch.manual_seed(42)
    num_heads = 4
    head_dim = 128
    kernel_size = 32
    kernel_stride = 16
    seqlens = torch.LongTensor([12, 1000, 2000, 4096]).int().cuda()
    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device="cuda"),
            torch.cumsum(seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    x = (
        torch.zeros(cu_seqlens[-1], num_heads, head_dim)
        .uniform_(-1, 1)
        .to(torch.bfloat16)
        .cuda()
        .requires_grad_()
    )
    w = (
        torch.zeros(num_heads, kernel_size)
        .uniform_(-1 / 32, 1 / 32)
        .cuda()
        .to(torch.bfloat16)
        .requires_grad_()
    )
    # compare against the PyTorch reference implementation (forward and backward)
    y, y_cu_seqlens = weightedpool_compress_torch(
        x, w, cu_seqlens, kernel_size, kernel_stride, None
    )
    x1 = x.clone().detach().requires_grad_()
    w1 = w.clone().detach().requires_grad_()
    y1, y1_cu_seqlens = weightedpool_compress(
        x1, w1, cu_seqlens, kernel_size, kernel_stride
    )
    print(torch.allclose(y, y1, rtol=1e-2, atol=1e-2))
    print(torch.abs(y - y1).max().item())
    randn = torch.randn_like(y)
    randn1 = randn.clone().detach()
    loss = (y * randn).sum()
    loss1 = (y1 * randn1).sum()
    loss.backward()
    loss1.backward()
    print((x.grad - x1.grad).abs().max())
    print(((w.grad - w1.grad).abs() / w.grad.abs()).max())
================================================
FILE: setup.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from setuptools import setup, find_packages # Read the README.md file for the long description with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() # Define the setup configuration setup( name="native-sparse-attention-triton", version="0.1.0", description="An efficient implementation of Native Sparse Attention using Triton", long_description=long_description, long_description_content_type="text/markdown", author="XunhaoLai", author_email="laixunhao@pku.edu.cn", # Replace with your actual email url="https://github.com/XunhaoLai/native-sparse-attention-triton", packages=find_packages(), install_requires=[ "torch>=2.1.0", "triton>=3.0.0", "einops>=0.7.0", "flash-attn>=2.6.3", "transformers>=4.44.0", ], classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", ], python_requires=">=3.9", ) ================================================ FILE: test/test_compress_key_value.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

# Smoke test and forward-speed benchmark for
# native_sparse_attention.ops.linear_compress. Results are printed, not
# asserted. Requires a CUDA device.
import torch
import triton
from native_sparse_attention.ops import linear_compress

if __name__ == "__main__":
    torch.manual_seed(42)
    num_heads = 4
    # head_dim is not a power of two here
    head_dim = 192
    kernel_size = 32
    kernel_stride = 16
    seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda()
    # cu_seqlens = [0, cumsum(seqlens)...] as int32
    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device="cuda"),
            torch.cumsum(seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    x = (
        torch.zeros(cu_seqlens[-1], num_heads, head_dim)
        .uniform_(-1, 1)
        .cuda()
        .bfloat16()
        .requires_grad_()
    )
    # per-head weight of shape (num_heads, kernel_size * head_dim, head_dim)
    w = (
        torch.zeros(num_heads, kernel_size * head_dim, head_dim)
        .uniform_(-1, 1)
        .cuda()
        .bfloat16()
        .requires_grad_()
    )
    # intra-block positional embedding, shape (num_heads, kernel_size, head_dim)
    pe = (
        torch.zeros(num_heads, kernel_size, head_dim)
        .uniform_(-1, 1)
        .cuda()
        .bfloat16()
        .requires_grad_()
    )
    y, y_cu_seqlens = linear_compress(x, w, cu_seqlens, kernel_size, kernel_stride, pe)
    loss = (y * torch.randn_like(y)).mean()
    loss.backward()
    print(y.shape, y_cu_seqlens)
    print(y.norm(), x.grad.norm())
    print(
        w.grad.norm() if w.grad is not None else None,
        pe.grad.norm() if pe.grad is not None else None,
    )

    # benchmark
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 6)],
            line_arg="provider",
            line_vals=["batch1", "batch8", "batch32"],
            line_names=["batch1", "batch8", "batch32"],
            styles=[("green", "-"), ("blue", "-"), ("blue", "--")],
            ylabel="ms",
            plot_name="** forward **",
            args={"H": 4, "D": 128},
        )
    )
    def benchmark(N, H, D, provider):
        # forward-only timing on N total tokens split into 1, 8 or 32
        # equal-length sequences
        K, S = 32, 16
        x = torch.zeros(N, H, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1)
        w = torch.zeros(H, K * D, D, device="cuda", dtype=torch.bfloat16).uniform_(
            -1, 1
        )
        pe = torch.zeros(H, K, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1)
        # per-sequence lengths with a leading 0; the cumsum below turns them
        # into cu_seqlens
        cu_seqlens_b1 = torch.LongTensor([0, N]).int().cuda()
        cu_seqlens_b8 = (
            torch.LongTensor([N // 8 if i > 0 else 0 for i in range(9)]).int().cuda()
        )
        cu_seqlens_b32 = (
            torch.LongTensor([N // 32 if i > 0 else 0 for i in range(33)]).int().cuda()
        )
        cu_seqlens_b1 = cu_seqlens_b1.cumsum(0).to(torch.int32)
        cu_seqlens_b8 = cu_seqlens_b8.cumsum(0).to(torch.int32)
        cu_seqlens_b32 = cu_seqlens_b32.cumsum(0).to(torch.int32)
        quantiles = [0.5, 0.2, 0.8]
        if provider == "batch1":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: linear_compress(x, w, cu_seqlens_b1, K, S, pe),
                quantiles=quantiles,
            )
        if provider == "batch8":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: linear_compress(x, w, cu_seqlens_b8, K, S, pe),
                quantiles=quantiles,
            )
        if provider == "batch32":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: linear_compress(x, w, cu_seqlens_b32, K, S, pe),
                quantiles=quantiles,
            )
        return ms, min_ms, max_ms

    benchmark.run(show_plots=True, print_data=True)


================================================
FILE: test/test_compressed_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Compares the Triton compressed_attention with the PyTorch reference
# (outputs, gradients, top-k block indices), then benchmarks forward and
# backward speed against flash_attn and the Triton flash attention.
# Requires a CUDA device.
import torch
import triton
import math
from native_sparse_attention.ops.torch.compressed_attention import (
    compressed_attention_torch,
)
from native_sparse_attention.ops.triton.compressed_attention import (
    compressed_attention,
    _compressed_attention_bwd,
)
from native_sparse_attention.ops import avgpool_compress
from native_sparse_attention.ops.triton.flash_attention import (
    flash_attention_varlen,
    _flash_attention_bwd,
)
from flash_attn import flash_attn_varlen_func
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward

if __name__ == "__main__":
    torch.manual_seed(42)
    num_heads = 32
    head_dim = 96
    kernel_size = 32
    kernel_stride = 16
    block_size = 64
    topk = 16
    seqlens = torch.LongTensor([1000, 4000, 8192]).int().cuda()
    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device="cuda"),
            torch.cumsum(seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    max_seqlen = seqlens.max().item()
    # k / v have num_heads // 4 heads, i.e. 4 query heads per kv head
    q = (
        torch.empty(cu_seqlens[-1], num_heads, head_dim, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    k = (
        torch.empty(cu_seqlens[-1], num_heads // 4, head_dim, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    v = (
        torch.empty(cu_seqlens[-1], num_heads // 4, head_dim, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    q.requires_grad = True
    k.requires_grad = True
    v.requires_grad = True
    # avgpool_compress provides the compressed shape and ck_cu_seqlens; the
    # compressed keys / values themselves are replaced with uniform noise
    ck, ck_cu_seqlens = avgpool_compress(
        k, None, cu_seqlens, kernel_size, kernel_stride
    )
    ck = torch.empty_like(ck).uniform_(-1, 1)
    cv = torch.empty_like(ck).uniform_(-1, 1)
    ck.requires_grad = True
    cv.requires_grad = True
    ck_seqlens = ck_cu_seqlens[1:] - ck_cu_seqlens[:-1]
    ck_max_seqlen = ck_seqlens.max().item()
    # reference (PyTorch)
    o, topk_idx = compressed_attention_torch(
        q,
        ck,
        cv,
        kernel_size,
        kernel_stride,
        block_size,
        topk,
        cu_seqlens,
        ck_cu_seqlens,
        max_seqlen,
        ck_max_seqlen,
    )
    randn = torch.randn_like(o)
    loss = (o * randn).sum()
    loss.backward()

    # Triton implementation on detached leaf copies, same upstream gradient
    torch.manual_seed(42)
    q1 = q.detach().clone().requires_grad_()
    ck1 = ck.detach().clone().requires_grad_()
    cv1 = cv.detach().clone().requires_grad_()
    o1, topk_idx1 = compressed_attention(
        q1,
        ck1,
        cv1,
        kernel_size,
        kernel_stride,
        block_size,
        topk,
        cu_seqlens,
        ck_cu_seqlens,
        max_seqlen,
        ck_max_seqlen,
    )
    randn1 = randn.clone().detach()
    loss1 = (o1 * randn1).sum()
    loss1.backward()

    print("Same Output:", torch.allclose(o, o1, atol=0.01, rtol=0.01))
    print("Max Error:", (o - o1).abs().max().item())
    print()
    print("Same Query Gradient:", torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01))
    print("Max Query Gradient Error:", (q.grad - q1.grad).abs().max().item())
    print()
    print("Same Key Gradient:", torch.allclose(ck.grad, ck1.grad, atol=0.01, rtol=0.01))
    print("Max Key Gradient Error:", (ck.grad - ck1.grad).abs().max().item())
    print()
    print(
        "Same Value Gradient:", torch.allclose(cv.grad, cv1.grad, atol=0.01, rtol=0.01)
    )
    print("Max Value Gradient Error:", (cv.grad - cv1.grad).abs().max().item())
    print()

    # There are some discrepancies in the topk indices (about 3%). These might be due to bugs and will be addressed later.
    # Error rate = fraction of reference top-k indices that are missing from
    # the Triton result; negative entries (presumably padding) are ignored.
    all_num = 0
    err_num = 0
    for h in range(topk_idx.shape[0]):
        for i in range(topk_idx.shape[1]):
            s = set(topk_idx[h, i][topk_idx[h, i] >= 0].tolist())
            s1 = set(topk_idx1[h, i][topk_idx1[h, i] >= 0].tolist())
            all_num += len(s)
            err_num += len(s) - len(s1 & s)
    print("Topk Idx Error Rate:", err_num / all_num)

    # benchmark
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 8)],
            line_arg="provider",
            line_vals=[
                "flash",
                "triton-flash",
                "triton-compressed",
                "triton-compressed-wo-score",
            ],
            line_names=[
                "Flash",
                "Triton-Flash",
                "Compressed",
                "Compressed-wo-Score",
            ],
            styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
            ylabel="ms",
            plot_name="** forward speed for compressed attention (kernel 32 stride 16) **",
            args={"H": 64, "D": 128},
        )
    )
    def benchmark(N, H, D, provider):
        # one sequence of N tokens; k / v use H // 16 heads
        q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        sm_scale = 1 / math.sqrt(D)
        com_k, com_cu_seqlens = avgpool_compress(k, None, cu_seqlens, 32, 16, None)
        com_v, com_cu_seqlens = avgpool_compress(v, None, cu_seqlens, 32, 16, None)
        M = (com_cu_seqlens[1:] - com_cu_seqlens[:-1]).max().item()
        quantiles = [0.5, 0.2, 0.8]
        if provider == "flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: flash_attn_varlen_func(
                    q,
                    k,
                    v,
                    cu_seqlens,
                    cu_seqlens,
                    N,
                    N,
                    dropout_p=0.0,
                    causal=True,
                    softmax_scale=sm_scale,
                ),
                quantiles=quantiles,
            )
        if provider == "triton-flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: flash_attention_varlen(
                    q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
                ),
                quantiles=quantiles,
            )
        if provider == "triton-compressed":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: compressed_attention(
                    q,
                    com_k,
                    com_v,
                    32,
                    16,
                    64,
                    16,
                    cu_seqlens,
                    com_cu_seqlens,
                    N,
                    M,
                    sm_scale,
                ),
                quantiles=quantiles,
            )
        # same call with topk=-1; per the "wo-score" label this presumably
        # skips the top-k score computation (confirm in compressed_attention)
        if provider == "triton-compressed-wo-score":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: compressed_attention(
                    q,
                    com_k,
                    com_v,
                    32,
                    16,
                    64,
                    -1,
                    cu_seqlens,
                    com_cu_seqlens,
                    N,
                    M,
                    sm_scale,
                ),
                quantiles=quantiles,
            )
        return ms, min_ms, max_ms

    benchmark.run(show_plots=True, print_data=True)

    # benchmark
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 8)],
            line_arg="provider",
            line_vals=[
                "flash",
                "triton-flash",
                "triton-compressed",
            ],
            line_names=[
                "Flash",
                "Triton-Flash",
                "Compressed",
            ],
            styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
            ylabel="ms",
            plot_name="** backward speed for compressed attention (kernel 32 stride 16) **",
            args={"H": 64, "D": 128},
        )
    )
    def benchmark(N, H, D, provider):
        # o, do and lse are random stand-ins: only kernel timing is meaningful
        q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        o = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        do = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        lse = torch.randn((H, N), device="cuda", dtype=torch.float32)
        sm_scale = 1 / math.sqrt(D)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)
        com_k, com_cu_seqlens = avgpool_compress(k, None, cu_seqlens, 32, 16, None)
        com_v, com_cu_seqlens = avgpool_compress(v, None, cu_seqlens, 32, 16, None)
        M = (com_cu_seqlens[1:] - com_cu_seqlens[:-1]).max().item()
        quantiles = [0.5, 0.2, 0.8]
        if provider == "flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attn_varlen_backward(
                    do,
                    q,
                    k,
                    v,
                    o,
                    lse.transpose(0, 1),
                    dq,
                    dk,
                    dv,
                    cu_seqlens,
                    cu_seqlens,
                    N,
                    N,
                    dropout_p=0.0,
                    causal=True,
                    softmax_scale=sm_scale,
                    window_size=(-1, -1),
                    softcap=0.0,
                    alibi_slopes=None,
                    deterministic=False,
                ),
                quantiles=quantiles,
            )
        if provider == "triton-flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attention_bwd(
                    o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
                ),
                quantiles=quantiles,
            )
        if provider == "triton-compressed":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _compressed_attention_bwd(
                    o,
                    do,
                    lse,
                    q,
                    com_k,
                    com_v,
                    32,
                    16,
                    cu_seqlens,
                    com_cu_seqlens,
                    N,
                    M,
                    sm_scale,
                ),
                quantiles=quantiles,
            )
        return ms, min_ms, max_ms

    benchmark.run(show_plots=True, print_data=True)


================================================
FILE: test/test_flash_attention.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Compares the Triton flash_attention_varlen with flash_attn's
# flash_attn_varlen_func (full and causal; outputs and gradients), then
# benchmarks forward and backward speed. Requires a CUDA device.
import torch
import triton
import math
from native_sparse_attention.ops.triton.flash_attention import (
    flash_attention_varlen,
    _flash_attention_fwd,
    _flash_attention_bwd,
)
from flash_attn import flash_attn_varlen_func
from flash_attn.flash_attn_interface import (
    _flash_attn_varlen_forward,
    _flash_attn_varlen_backward,
)

if __name__ == "__main__":
    for causal in [False, True]:
        # reference: flash_attn (flash_attn_varlen_func)
        torch.manual_seed(42)
        q = torch.randn(
            1000, 32, 128, dtype=torch.float16, device="cuda", requires_grad=True
        )
        k = torch.randn(
            1000, 16, 128, dtype=torch.float16, device="cuda", requires_grad=True
        )
        v = torch.randn(
            1000, 16, 128, dtype=torch.float16, device="cuda", requires_grad=True
        )
        cu_seqlens_q = torch.Tensor([0, 100, 384, 1000]).cuda().to(torch.int32)
        cu_seqlens_k = torch.Tensor([0, 100, 384, 1000]).cuda().to(torch.int32)
        # NOTE: max_seqlen_q / max_seqlen_k are 0-d tensors here, not Python ints
        max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max()
        max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max()
        o = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            causal=causal,
        )
        randn = torch.randn_like(o)
        loss = (o * randn).sum()
        loss.backward()
        # implementation under test: Triton flash_attention_varlen on detached
        # leaf copies, same upstream gradient
        torch.manual_seed(42)
        q1 = q.clone().detach().requires_grad_()
        k1 = k.clone().detach().requires_grad_()
        v1 = v.clone().detach().requires_grad_()
        cu_seqlens_q1 = cu_seqlens_q.clone().detach()
        cu_seqlens_k1 = cu_seqlens_k.clone().detach()
        max_seqlen_q1 = (cu_seqlens_q1[1:] - cu_seqlens_q1[:-1]).max()
        max_seqlen_k1 = (cu_seqlens_k1[1:] - cu_seqlens_k1[:-1]).max()
        o1 = flash_attention_varlen(
            q1,
            k1,
            v1,
            cu_seqlens_q1,
            cu_seqlens_k1,
            max_seqlen_q1,
            max_seqlen_k1,
            causal=causal,
        )
        randn2 = randn.clone().detach()
        loss2 = (o1 * randn2).sum()
        loss2.backward()
        # diff
        print(
            f"=== Flash Attention Backward Test ({'causal' if causal else 'full'}) ==="
        )
        print("Same Output:", torch.allclose(o, o1, atol=0.01, rtol=0.01))
        print("Max Error:", (o - o1).abs().max().item())
        print()
        print(
            "Same Query Gradient:",
            torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01),
        )
        print("Max Query Gradient Error:", (q.grad - q1.grad).abs().max().item())
        print()
        print(
            "Same Key Gradient:", torch.allclose(k.grad, k1.grad, atol=0.01, rtol=0.01)
        )
        print("Max Key Gradient Error:", (k.grad - k1.grad).abs().max().item())
        print()
        print(
            "Same Value Gradient:",
            torch.allclose(v.grad, v1.grad, atol=0.01, rtol=0.01),
        )
        print("Max Value Gradient Error:", (v.grad - v1.grad).abs().max().item())
        print()

    # benchmark
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 6)],
            line_arg="provider",
            line_vals=["flash", "triton-flash"],
            line_names=[
                "Flash",
                "Triton-Flash",
            ],
            styles=[("green", "-"), ("green", "--")],
            ylabel="ms",
            plot_name="** forward **",
            args={"H": 64, "D": 128},
        )
    )
    def benchmark(N, H, D, provider):
        # one causal sequence of N tokens; k / v use H // 16 heads
        q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        sm_scale = 1 / math.sqrt(D)
        quantiles = [0.5, 0.2, 0.8]
        if provider == "flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attn_varlen_forward(
                    q,
                    k,
                    v,
                    cu_seqlens,
                    cu_seqlens,
                    N,
                    N,
                    dropout_p=0.0,
                    causal=True,
                    softmax_scale=sm_scale,
                ),
                quantiles=quantiles,
            )
        if provider == "triton-flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attention_fwd(
                    q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
                ),
                quantiles=quantiles,
            )
        return ms, min_ms, max_ms

    benchmark.run(show_plots=True, print_data=True)

    # benchmark
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 6)],
            line_arg="provider",
            line_vals=["flash", "triton-flash"],
            line_names=[
                "Flash",
                "Triton-Flash",
            ],
            styles=[("green", "-"), ("green", "--")],
            ylabel="ms",
            plot_name="** backward **",
            args={"H": 64, "D": 128},
        )
    )
    def benchmark(N, H, D, provider):
        # o, do and lse are random stand-ins: only kernel timing is meaningful
        q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        o = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        do = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        lse = torch.randn((H, N), device="cuda", dtype=torch.float32)
        sm_scale = 1 / math.sqrt(D)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)
        quantiles = [0.5, 0.2, 0.8]
        if provider == "flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attn_varlen_backward(
                    do,
                    q,
                    k,
                    v,
                    o,
                    lse.transpose(0, 1),
                    dq,
                    dk,
                    dv,
                    cu_seqlens,
                    cu_seqlens,
                    N,
                    N,
                    dropout_p=0.0,
                    causal=True,
                    softmax_scale=sm_scale,
                    window_size=(-1, -1),
                    softcap=0.0,
                    alibi_slopes=None,
                    deterministic=False,
                ),
                quantiles=quantiles,
            )
        if provider == "triton-flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attention_bwd(
                    o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
                ),
                quantiles=quantiles,
            )
        return ms, min_ms, max_ms

    benchmark.run(show_plots=True, print_data=True)


================================================
FILE: test/test_kv_cache.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from native_sparse_attention.module.kv_cache import NSACache

# Smoke test of NSACache: one prefill step followed by one decode step.
# Nothing is asserted; the script only checks that the calls run.
# Requires a CUDA device.
if __name__ == "__main__":
    from native_sparse_attention.ops import avgpool_compress

    torch.manual_seed(42)
    num_heads = 4
    head_dim = 128
    # the first sequence (12 tokens) is shorter than the kernel size 32 used below
    seqlens = torch.tensor([12, 576, 12000]).to(torch.int32).cuda()
    batch_size = seqlens.shape[0]
    cu_seqlens = torch.zeros(seqlens.shape[0] + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = seqlens.cumsum(0)
    # init cache
    # NOTE(review): the positional arguments are presumably (max batch size,
    # max length, num kv heads, head_dim, kernel_size, kernel_stride,
    # window_size, dtype, device) -- confirm against NSACache.__init__
    cache = NSACache(4, 16384, num_heads, head_dim, 32, 16, 512, torch.bfloat16, "cuda")
    # test prefill
    step = 0
    k = torch.randn(cu_seqlens[-1], num_heads, head_dim).cuda().bfloat16()
    v = torch.randn_like(k)
    ck, _ = avgpool_compress(k, None, cu_seqlens, 32, 16, None)
    cv, _ = avgpool_compress(v, None, cu_seqlens, 32, 16, None)
    cache.prepare_compress(cu_seqlens, step, k, v)
    # the same k / v tensors are passed for both of the trailing kv slots
    cache.update_kv(cu_seqlens, step, ck, cv, k, v, k, v)
    # test decode: one new row per sequence
    step = 1
    k = torch.randn(batch_size, num_heads, head_dim).cuda().bfloat16()
    v = torch.randn_like(k)
    ck = torch.randn_like(k)
    cv = torch.randn_like(v)
    cache.prepare_compress(cu_seqlens, step, k, v)
    cache.update_kv(cu_seqlens, step, ck, cv, k, v, k, v)


================================================
FILE: test/test_linear_compress.py
================================================
# Copyright 2025 Xunhao Lai & Jianqiao Lu.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch import triton from native_sparse_attention.ops.torch.compress_key_value import linear_compress_torch from native_sparse_attention.ops.triton.linear_compress import linear_compress def test_linear_compress( batch_size: int = 1, num_heads: int = 1, head_dim: int = 32, max_seqlen: int = 32, kernel_sizes: list = [16, 32], kernel_strides: list = [8, 16], use_pe: bool = True, dtype: torch.dtype = torch.float32, device: str = "cuda", ): """ Test both PyTorch and Triton implementations of linear_compress for equivalence, including forward and backward passes. Args: batch_size: Number of sequences in the batch num_heads: Number of attention heads head_dim: Dimension of each attention head max_seqlen: Maximum sequence length kernel_sizes: List of kernel sizes to test kernel_strides: List of kernel strides to test use_pe: Whether to test with positional encoding dtype: Data type for tensors device: Device to run the test on """ torch.manual_seed(42) # Generate random sequence lengths for each batch seqlens = torch.randint( low=kernel_sizes[0], # minimum length should be at least kernel_size high=max_seqlen + 1, size=(batch_size,), device=device, ) # seqlens[:] = max_seqlen cu_seqlens = torch.cat( [ torch.tensor([0], device=device, dtype=torch.int32), torch.cumsum(seqlens, dim=0).to(torch.int32), ] ) total_len = cu_seqlens[-1].item() for kernel_size, kernel_stride in zip(kernel_sizes, kernel_strides): print(f"\nTesting kernel_size={kernel_size}, kernel_stride={kernel_stride}") # Create input tensors with requires_grad=True x_torch = torch.zeros( (total_len, num_heads, head_dim), dtype=dtype, device=device, ).uniform_(-1, 1) x_torch.requires_grad_(True) x_triton = x_torch.clone().detach().requires_grad_(True) w_torch = ( torch.ones( (num_heads, kernel_size * head_dim, head_dim), dtype=dtype, device=device, ) / kernel_size ) w_torch.requires_grad_(True) w_triton = w_torch.clone().detach().requires_grad_(True) pe_torch = None pe_triton = None if use_pe: pe_torch = 
torch.randn( (num_heads, kernel_size, head_dim), dtype=dtype, device=device, requires_grad=True, ) pe_triton = pe_torch.clone().detach().requires_grad_(True) # Run forward passes y_torch, y_cu_seqlens_torch = linear_compress_torch( x=x_torch, w=w_torch, cu_seqlens=cu_seqlens, kernel_size=kernel_size, kernel_stride=kernel_stride, pe=pe_torch, ) y_triton, y_cu_seqlens_triton = linear_compress( x=x_triton, w=w_triton, cu_seqlens=cu_seqlens, kernel_size=kernel_size, kernel_stride=kernel_stride, pe=pe_triton, ) # Check forward pass numerical equivalence atol, rtol = 1e-2, 1e-2 values_match = torch.allclose(y_torch, y_triton, atol=atol, rtol=rtol) print( f"Forward pass - Output values match (atol={atol}, rtol={rtol}): {values_match}" ) if not values_match: max_diff = (y_torch - y_triton).abs().max().item() print(f"Forward pass - Maximum difference: {max_diff}") print("\nSample values (first batch, first head):") print("Torch:", y_torch[0, 0, :5]) print("Triton:", y_triton[0, 0, :5]) # Create random output gradients for backward pass grad_output = torch.randn_like(y_torch) # Run backward passes y_torch.backward(grad_output) y_triton.backward(grad_output) # Check gradient equivalence print("\nTesting backward pass:") # Check x gradients x_grads_match = torch.allclose( x_torch.grad, x_triton.grad, atol=atol, rtol=rtol ) print(f"x gradients match (atol={atol}, rtol={rtol}): {x_grads_match}") if not x_grads_match: max_diff = (x_torch.grad - x_triton.grad).abs().max().item() print(f"x gradients - Maximum difference: {max_diff}") print("\nSample x gradients (first batch, first head):") print("Torch:", x_torch.grad[0, 0, :5]) print("Triton:", x_triton.grad[0, 0, :5]) # Check w gradients w_grads_match = torch.allclose( w_torch.grad, w_triton.grad, atol=atol, rtol=rtol ) print(f"w gradients match (atol={atol}, rtol={rtol}): {w_grads_match}") if not w_grads_match: max_diff = (w_torch.grad - w_triton.grad).abs().max().item() print(f"w gradients - Maximum difference: {max_diff}") 
print("\nSample w gradients (first head):") print("Torch:", w_torch.grad[0, :5, 0]) print("Triton:", w_triton.grad[0, :5, 0]) # Check pe gradients if used if use_pe: pe_grads_match = torch.allclose( pe_torch.grad, pe_triton.grad, atol=atol, rtol=rtol ) print(f"pe gradients match (atol={atol}, rtol={rtol}): {pe_grads_match}") if not pe_grads_match: max_diff = (pe_torch.grad - pe_triton.grad).abs().max().item() print(f"pe gradients - Maximum difference: {max_diff}") print("\nSample pe gradients (first head):") print("Torch:", pe_torch.grad[0, :5, 0]) print("Triton:", pe_triton.grad[0, :5, 0]) # Clean up gradients for next iteration x_torch.grad = None x_triton.grad = None w_torch.grad = None w_triton.grad = None if use_pe: pe_torch.grad = None pe_triton.grad = None if __name__ == "__main__": # Run tests test_linear_compress( batch_size=16, num_heads=8, head_dim=128, max_seqlen=2048, kernel_sizes=[32], kernel_strides=[16], use_pe=False, dtype=torch.float16, device="cuda", ) # benchmark @triton.testing.perf_report( triton.testing.Benchmark( x_names=["N"], x_vals=[1024 * 2**i for i in range(1, 8)], line_arg="provider", line_vals=["torch", "triton"], line_names=["torch", "triton"], styles=[("green", "-"), ("blue", "-")], ylabel="ms", plot_name="** forward + backward **", args={"H": 4, "D": 64}, ) ) def benchmark_fwdbwd(N, H, D, provider): K, S = 32, 16 # Input tensors x = torch.zeros(N, H, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1) x.requires_grad = True w = torch.zeros(H, K * D, D, device="cuda", dtype=torch.bfloat16).uniform_( -1, 1 ) w.requires_grad = True pe = torch.zeros(H, K, D, device="cuda", dtype=torch.bfloat16).uniform_(-1, 1) cu_seqlens_b32 = ( torch.LongTensor( [0 if i == 0 else 32 if i > 1 else N - 32 * 31 for i in range(33)] ) .int() .cuda() ) cu_seqlens_b32 = cu_seqlens_b32.cumsum(0).to(torch.int32) quantiles = [0.5, 0.2, 0.8] def fwd_bwd(): if provider == "torch": out, _ = linear_compress_torch(x, w, cu_seqlens_b32, K, S, pe) else: out, _ = 
linear_compress(x, w, cu_seqlens_b32, K, S, pe) out.backward(out) # Using output as gradient for simplicity return out ms, min_ms, max_ms = triton.testing.do_bench(fwd_bwd, quantiles=quantiles) return ms, min_ms, max_ms benchmark_fwdbwd.run(show_plots=True, print_data=True) ================================================ FILE: test/test_nsa_infer.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from native_sparse_attention.ops import linear_compress, weightedpool_compress from native_sparse_attention.module import NSACache, RotaryEmbedding, RopeConfig from native_sparse_attention.infer import nsa_infer if __name__ == "__main__": torch.manual_seed(42) num_heads = 4 head_dim = 128 kernel_size = 32 kernel_stride = 16 block_size = 64 window_size = 512 topk = 16 init_blocks = 1 local_blocks = 2 # init seqlens seqlens = torch.tensor([12, 576, 12000]).to(torch.int32).cuda() batch_size = seqlens.shape[0] cu_seqlens = torch.zeros(seqlens.shape[0] + 1, dtype=torch.int32, device="cuda") cu_seqlens[1:] = seqlens.cumsum(0) step = 0 # init cache and weight and rope cache = NSACache(4, 16384, num_heads, head_dim, 32, 16, 512, torch.bfloat16, "cuda") compress_weight = [ torch.ones(num_heads, kernel_size * head_dim, head_dim).cuda().bfloat16() / (kernel_size * head_dim), torch.ones(num_heads, kernel_size).cuda().bfloat16() / kernel_size, ] compress_func = [linear_compress, 
weightedpool_compress] rope = RotaryEmbedding( RopeConfig( max_position_embeddings=131072, head_dim=128, rope_theta=500000, rope_scaling={ "factor": 8.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", }, ) ) # test prefill q = torch.randn(cu_seqlens[-1], num_heads * 16, head_dim).cuda().bfloat16() k = torch.randn(cu_seqlens[-1], num_heads, head_dim).cuda().bfloat16() v = torch.randn_like(k) g = torch.rand(cu_seqlens[-1], num_heads * 16, 3).cuda().bfloat16() o = nsa_infer( cu_seqlens, step, q, k, v, g, rope, cache, compress_weight, compress_func, None, kernel_size, kernel_stride, block_size, topk, init_blocks, local_blocks, window_size, ) print(o.shape, o.norm()) # test decode q = torch.randn(cu_seqlens.shape[0] - 1, num_heads * 16, head_dim).cuda().bfloat16() k = torch.randn(cu_seqlens.shape[0] - 1, num_heads, head_dim).cuda().bfloat16() v = torch.randn_like(k) g = torch.rand(cu_seqlens.shape[0] - 1, num_heads * 16, 3).cuda().bfloat16() step = 1 o = nsa_infer( cu_seqlens, step, q, k, v, g, rope, cache, compress_weight, compress_func, None, kernel_size, kernel_stride, block_size, topk, init_blocks, local_blocks, window_size, ) print(o.shape, o.norm()) ================================================ FILE: test/test_nsa_model.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch from native_sparse_attention.model import ( ToyNSALlamaConfig, InferenceConfig, ToyNSALlama, ) if __name__ == "__main__": torch.manual_seed(42) # initialize model config = ToyNSALlamaConfig( hidden_size=4096, intermediate_size=14336, num_hidden_layers=8, num_attention_heads=32, num_key_value_heads=2, head_dim=128, rope_theta=500000.0, rope_scaling={ "factor": 8.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", }, compress_type="weightedpool", kernel_size=32, kernel_stride=16, block_size=64, topk=8, init_blocks=1, local_blocks=2, window_size=512, ) inference_config = InferenceConfig( max_batch_size=4, max_length=8192, max_new_tokens=128, ) model = ToyNSALlama(config, inference_config).cuda().bfloat16() print(f"\nMODEL CONFIG:\n{config}\n") print(f"\nINFERENCE CONFIG:\n{inference_config}\n") print(f"\nMODEL:\n{model}\n") # example input batch_size = 4 seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device="cuda") cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) input_ids = torch.randint( 0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device="cuda" ) print(f"\nEXAMPLE INPUT:\ncu_seqlens: {cu_seqlens}\ninput_ids: {input_ids.shape}\n") # example output logits = model(input_ids, cu_seqlens) print(f"\nEXAMPLE OUTPUT:\nlogits: {logits.shape}\n") # example generate output_tokens = model.generate(input_ids, cu_seqlens, 64) print(f"\nEXAMPLE GENERATE:\noutput_tokens: {output_tokens}\n") ================================================ FILE: test/test_nsa_module.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific import torch import triton from native_sparse_attention.module import ( SelfAttention, NativeSparseAttention, RopeConfig, ) if __name__ == "__main__": torch.manual_seed(42) NSA = ( NativeSparseAttention( compress_type="avgpool", hidden_size=8192, num_q_heads=64, num_kv_heads=4, head_dim=128, kernel_size=32, kernel_stride=16, block_size=64, topk=16, init_blocks=1, local_blocks=2, window_size=512, rope_config=RopeConfig( max_position_embeddings=131072, head_dim=128, rope_theta=500000, rope_scaling={ "factor": 8.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", }, ), ) .cuda() .to(torch.bfloat16) ) print("======= Init Moduel: Native Sparse Attention =======\n") for name, param in NSA.named_parameters(): print(f"NSA Parameters, {name}, shape: {param.shape}\n") # random input seqlens = torch.LongTensor([4000, 8192, 16384]).int().cuda() cu_seqlens = torch.cat( [ torch.zeros(1, dtype=torch.int32, device="cuda"), torch.cumsum(seqlens, dim=0), ], dim=0, ).to(torch.int32) x = torch.zeros(cu_seqlens[-1], 8192, device="cuda", dtype=torch.bfloat16).uniform_( -1, 1 ) # forward test print("======= NSA Forward & Backward Test =======\n") y = NSA(x, cu_seqlens) print(f"Forward, output shape: {y.shape}, output norm: {y.norm()}\n") # backward test loss = (y * torch.randn_like(y)).sum(-1).mean() loss.backward() for name, param in NSA.named_parameters(): print( f"Backward, {name}, grad shape: {param.grad.shape}, grad norm: {param.grad.norm()}\n" ) # speed benchmark SelfAttn = ( SelfAttention( hidden_size=8192, num_q_heads=64, num_kv_heads=4, head_dim=128, rope_config=RopeConfig( 
max_position_embeddings=131072, head_dim=128, rope_theta=500000, rope_scaling={ "factor": 8.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", }, ), ) .cuda() .to(torch.bfloat16) ) @triton.testing.perf_report( triton.testing.Benchmark( x_names=["N"], x_vals=[1024 * 2**i for i in range(1, 8)], line_arg="provider", line_vals=["Self-Attention", "Native-Sparse-Attention"], line_names=["Self-Attention", "Native-Sparse-Attention"], styles=[("green", "-"), ("blue", "-")], ylabel="ms", plot_name="** NSA forward speed benchmark **", args={}, ) ) def benchmark(N, provider): x = torch.randn(N, 8192, device="cuda", dtype=torch.bfloat16) cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32) quantiles = [0.5, 0.2, 0.8] with torch.no_grad(): if provider == "Self-Attention": ms, min_ms, max_ms = triton.testing.do_bench( lambda: SelfAttn(x, cu_seqlens), quantiles=quantiles, ) if provider == "Native-Sparse-Attention": ms, min_ms, max_ms = triton.testing.do_bench( lambda: NSA(x, cu_seqlens), quantiles=quantiles, ) return ms, min_ms, max_ms benchmark.run(show_plots=True, print_data=True) @triton.testing.perf_report( triton.testing.Benchmark( x_names=["N"], x_vals=[1024 * 2**i for i in range(1, 8)], line_arg="provider", line_vals=["Self-Attention", "Native-Sparse-Attention"], line_names=["Self-Attention", "Native-Sparse-Attention"], styles=[("green", "-"), ("blue", "-")], ylabel="ms", plot_name="** NSA backward speed benchmark **", args={}, ) ) def benchmark(N, provider): x = torch.randn(N, 8192, device="cuda", dtype=torch.bfloat16) cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32) quantiles = [0.5, 0.2, 0.8] if provider == "Self-Attention": loss = SelfAttn(x.clone().detach().requires_grad_(), cu_seqlens).mean() ms, min_ms, max_ms = triton.testing.do_bench( lambda: loss.backward(retain_graph=True), quantiles=quantiles, ) elif provider == "Native-Sparse-Attention": loss = 
NSA(x.clone().detach().requires_grad_(), cu_seqlens).mean() ms, min_ms, max_ms = triton.testing.do_bench( lambda: loss.backward(retain_graph=True), quantiles=quantiles, ) return ms, min_ms, max_ms benchmark.run(show_plots=True, print_data=True) ================================================ FILE: test/test_rope.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from native_sparse_attention.module import RopeConfig, RotaryEmbedding if __name__ == "__main__": rope_config = RopeConfig( max_position_embeddings=131072, head_dim=128, rope_theta=500000, rope_scaling={ "factor": 8.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3", }, ) rope = RotaryEmbedding(rope_config, "cuda") # random input torch.manual_seed(42) seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda() cu_seqlens = torch.cat( [ torch.zeros(1, dtype=torch.int32, device="cuda"), torch.cumsum(seqlens, dim=0), ], dim=0, ).to(torch.int32) x = torch.zeros( cu_seqlens[-1], 32, 128, device="cuda", dtype=torch.bfloat16 ).uniform_(-1, 1) y = rope(x, cu_seqlens) ================================================ FILE: test/test_topk_sparse_attention.py ================================================ # Copyright 2025 Xunhao Lai & Jianqiao Lu. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import triton
from native_sparse_attention.ops.torch.topk_sparse_attention import (
    topk_sparse_attention_torch,
)
from native_sparse_attention.ops.triton.topk_sparse_attention import (
    topk_sparse_attention,
    _topk_sparse_attention_fwd,
    _topk_sparse_attention_bwd,
)
from native_sparse_attention.ops.triton.flash_attention import (
    _flash_attention_fwd,
    _flash_attention_bwd,
)
from flash_attn.flash_attn_interface import (
    _flash_attn_varlen_forward,
    _flash_attn_varlen_backward,
)


def generate_topk_idx_example(
    seqlens: torch.Tensor,
    block_size_k: int,
    topk: int,
    num_heads: int,
    block_size_q: int = 1,
) -> torch.Tensor:
    """Generate a random causal top-k block index tensor for tests.

    For every head and every kept query, picks up to ``topk`` distinct random
    key/value blocks, forces block 0 to be selected, and replaces with -1 any
    block that starts after the query's own block (causal mask).

    Args:
        seqlens (torch.Tensor): shape [batch_size], the length of each
            sequence. The packed offsets (cu_seqlens) are derived internally,
            so pass lengths, not cumulative offsets.
        block_size_k (int): key value block size
        topk (int): selected topk
        num_heads (int): number of key value heads
        block_size_q (int): query block size. Only every ``block_size_q``-th
            query of each sequence (starting from its first token) gets a row.
            Defaults to 1, i.e. one row per query token.

    Returns:
        torch.Tensor: int32 CUDA tensor of shape [num_heads, total_q, topk],
            where total_q = sum_i ceil(seqlens[i] / block_size_q) (the total
            sequence length when block_size_q == 1). In each row the valid
            block ids come first in ascending order, followed by -1 padding.
    """
    batch_size = seqlens.shape[0]
    num_blocks = torch.ceil(seqlens / block_size_k).to(torch.int32)
    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), pad=(1, 0), value=0)
    # Key block index that each query token falls into. It is the same for
    # every head, so compute it once instead of once per head.
    q_idx = torch.cat(
        [torch.arange(seqlens[i], device="cuda") for i in range(batch_size)], dim=0
    )
    q_block_idx = (q_idx // block_size_k)[:, None]
    topk_idx_all_heads = []
    for _ in range(num_heads):
        # random scores -> distinct block ids per query (fewer than topk when
        # the sequence has fewer than topk blocks)
        topk_idx = [
            torch.randn(seqlens[i], num_blocks[i], device="cuda")
            .topk(min(topk, num_blocks[i]), dim=-1)
            .indices.to(torch.int32)
            for i in range(batch_size)
        ]
        # Pad short rows to topk columns with the value `topk`. Such sequences
        # have fewer than topk blocks, so this value exceeds every query's
        # block index and the causal mask below turns it into -1.
        topk_idx = [
            torch.nn.functional.pad(
                topk_idx[i], (0, topk - topk_idx[i].shape[-1]), value=topk
            )
            for i in range(batch_size)
        ]
        topk_idx = torch.cat(topk_idx, dim=0)
        topk_idx = torch.sort(topk_idx, dim=1).values
        topk_idx[:, 0] = 0  # the first block is always selected
        topk_idx[topk_idx > q_block_idx] = -1  # -1 means padding
        # keep one row per query block of each sequence
        topk_idx = torch.cat(
            [
                topk_idx[cu_seqlens[i] : cu_seqlens[i + 1]][0::block_size_q]
                for i in range(batch_size)
            ],
            dim=0,
        )
        topk_idx_all_heads.append(topk_idx)
    topk_idx = torch.stack(topk_idx_all_heads, dim=0)
    return topk_idx


if __name__ == "__main__":
    # Correctness: compare the Triton top-k sparse attention (forward and
    # q/k/v gradients) against the PyTorch reference, then benchmark it.
    torch.manual_seed(42)
    seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda()
    cu_seqlens = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device="cuda"),
            torch.cumsum(seqlens, dim=0),
        ],
        dim=0,
    ).to(torch.int32)
    num_q_heads, num_kv_heads, head_dim = 64, 8, 96
    q = (
        torch.empty(cu_seqlens[-1], num_q_heads, head_dim, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    k = (
        torch.empty(cu_seqlens[-1], num_kv_heads, head_dim, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    v = (
        torch.empty(cu_seqlens[-1], num_kv_heads, head_dim, device="cuda")
        .uniform_(-1, 1)
        .to(torch.float16)
    )
    q.requires_grad = True
    k.requires_grad = True
    v.requires_grad = True
    block_size = 64
    topk = 5
    topk_idx = generate_topk_idx_example(seqlens, block_size, topk, num_kv_heads)

    # reference implementation
    o = topk_sparse_attention_torch(q, k, v, topk_idx, block_size, cu_seqlens)
    randn = torch.randn_like(o)
    loss = (o * randn).sum()
    loss.backward()

    # triton implementation on independent copies of the same inputs
    torch.manual_seed(42)
    q1 = q.clone().detach().requires_grad_()
    k1 = k.clone().detach().requires_grad_()
    v1 = v.clone().detach().requires_grad_()
    topk_idx1 = topk_idx.clone().detach()
    cu_seqlens1 = cu_seqlens.clone().detach()
    o1 = topk_sparse_attention(q1, k1, v1, topk_idx1, block_size, cu_seqlens1)
    randn2 = randn.clone().detach()
    loss2 = (o1 * randn2).sum()
    loss2.backward()

    print("Same Output:", torch.allclose(o, o1, atol=0.01, rtol=0.01))
    print("Max Error:", (o - o1).abs().max().item())
    print()
    print("Same Query Gradient:", torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01))
    print("Max Query Gradient Error:", (q.grad - q1.grad).abs().max().item())
    print()
    print("Same Key Gradient:", torch.allclose(k.grad, k1.grad, atol=0.01, rtol=0.01))
    print("Max Key Gradient Error:", (k.grad - k1.grad).abs().max().item())
    print()
    print("Same Value Gradient:", torch.allclose(v.grad, v1.grad, atol=0.01, rtol=0.01))
    print("Max Value Gradient Error:", (v.grad - v1.grad).abs().max().item())
    print()

    # benchmark
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 8)],
            line_arg="provider",
            line_vals=["flash", "triton-flash", "triton-top8", "triton-top16"],
            line_names=[
                "Flash",
                "Triton-Flash",
                "Triton-Top8",
                "Triton-Top16",
            ],
            styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
            ylabel="ms",
            plot_name="** forward with block size 64 **",
            args={"H": 64, "D": 128, "K": 64},
        )
    )
    def benchmark_fwd(N, H, D, K, provider):
        """Time the forward kernels on one sequence of length N (H // 16 kv heads)."""
        q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        sm_scale = 1 / math.sqrt(D)
        # cu_seqlens[1:] == [N] is the per-sequence length tensor
        top8_idx = generate_topk_idx_example(cu_seqlens[1:], K, 8, H // 16)
        top16_idx = generate_topk_idx_example(cu_seqlens[1:], K, 16, H // 16)
        quantiles = [0.5, 0.2, 0.8]
        if provider == "flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attn_varlen_forward(
                    q,
                    k,
                    v,
                    cu_seqlens,
                    cu_seqlens,
                    N,
                    N,
                    dropout_p=0.0,
                    causal=True,
                    softmax_scale=sm_scale,
                ),
                quantiles=quantiles,
            )
        elif provider == "triton-flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attention_fwd(
                    q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
                ),
                quantiles=quantiles,
            )
        elif provider == "triton-top8":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _topk_sparse_attention_fwd(
                    q, k, v, top8_idx, K, cu_seqlens, cu_seqlens, N, N, sm_scale
                ),
                quantiles=quantiles,
            )
        elif provider == "triton-top16":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _topk_sparse_attention_fwd(
                    q, k, v, top16_idx, K, cu_seqlens, cu_seqlens, N, N, sm_scale
                ),
                quantiles=quantiles,
            )
        return ms, min_ms, max_ms

    benchmark_fwd.run(show_plots=True, print_data=True)

    # benchmark
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[1024 * 2**i for i in range(1, 8)],
            line_arg="provider",
            line_vals=["flash", "triton-flash", "triton-top8", "triton-top16"],
            line_names=[
                "Flash",
                "Triton-Flash",
                "Triton-Top8",
                "Triton-Top16",
            ],
            styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
            ylabel="ms",
            plot_name="** backward with block size 64 **",
            args={"H": 64, "D": 128, "K": 64},
        )
    )
    def benchmark_bwd(N, H, D, K, provider):
        """Time the backward kernels on random o/do/lse for one sequence of length N."""
        q = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        k = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        v = torch.randn((N, H // 16, D), device="cuda", dtype=torch.bfloat16)
        o = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        do = torch.randn((N, H, D), device="cuda", dtype=torch.bfloat16)
        lse = torch.randn((H, N), device="cuda", dtype=torch.float32)
        sm_scale = 1 / math.sqrt(D)
        cu_seqlens = torch.tensor([0, N], device="cuda", dtype=torch.int32)
        top8_idx = generate_topk_idx_example(cu_seqlens[1:], K, 8, H // 16)
        top16_idx = generate_topk_idx_example(cu_seqlens[1:], K, 16, H // 16)
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)
        quantiles = [0.5, 0.2, 0.8]
        if provider == "flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attn_varlen_backward(
                    do,
                    q,
                    k,
                    v,
                    o,
                    lse.transpose(0, 1),
                    dq,
                    dk,
                    dv,
                    cu_seqlens,
                    cu_seqlens,
                    N,
                    N,
                    dropout_p=0.0,
                    causal=True,
                    softmax_scale=sm_scale,
                    window_size=(-1, -1),
                    softcap=0.0,
                    alibi_slopes=None,
                    deterministic=False,
                ),
                quantiles=quantiles,
            )
        elif provider == "triton-flash":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _flash_attention_bwd(
                    o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale
                ),
                quantiles=quantiles,
            )
        elif provider == "triton-top8":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _topk_sparse_attention_bwd(
                    o,
                    do,
                    lse,
                    q,
                    k,
                    v,
                    top8_idx,
                    K,
                    cu_seqlens,
                    cu_seqlens,
                    N,
                    N,
                    sm_scale,
                ),
                quantiles=quantiles,
            )
        elif provider == "triton-top16":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: _topk_sparse_attention_bwd(
                    o,
                    do,
                    lse,
                    q,
                    k,
                    v,
                    top16_idx,
                    K,
                    cu_seqlens,
                    cu_seqlens,
                    N,
                    N,
                    sm_scale,
                ),
                quantiles=quantiles,
            )
        return ms, min_ms, max_ms

    benchmark_bwd.run(show_plots=True, print_data=True)