[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# UV\n#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#uv.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in 
version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control\n.pdm.toml\n.pdm-python\n.pdm-build/\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n\n# PyPI configuration file\n.pypirc\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n# Native Sparse Attention Triton\n\n</div>\n\nThis repository implements the sparse attention mechanism introduced in the paper [Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention](https://arxiv.org/abs/2502.11089) and provides an efficient training implementation based on [Triton](https://github.com/triton-lang/triton).\n\n🎉 We now support both training and inference for Native Sparse Attention (variable-length version, including prefilling, decoding, and KV cache management). We have provided a toy model at `model.ToyNSALlama`, which supports a `forward` function for training and a `generate` function for inference. Feel free to try it out!\n\n## Requirements\nEnsure the following dependencies are installed:\n- PyTorch >= 2.1.0\n- triton >= 3.0.0\n- einops >= 0.7.0\n- flash_attn >= 2.6.3\n\n## Usage\n\n### Notes\n1. PyTorch implementations (`ops.torch`) are intended for debugging only.\n2. For production use, prefer Triton operators (`ops.triton`).\n3. All implementations use the varlen (variable-length) approach, similar to `flash_attn_varlen_func`. Please concatenate the inputs of a batch before use.\n4. Only attention head dimensions up to 128 are supported for now.\n\n### Install\n\nYou can install `native_sparse_attention` using pip:\n\n```shell\npip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git\n```\n\n### Functions\n\nThe `ops` module implements several functions required for native sparse attention. 
For detailed usage instructions, please see [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/ops#readme).\n\nYou can import those functions from the `ops` module:\n\n```python\nimport torch\nfrom native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention\n\n# input example\nnum_q_heads = 64\nnum_kv_heads = 4\nhead_dim = 128\nkernel_size = 32\nkernel_stride = 16\nblock_size = 64\ntopk = 16\ncu_seqlens = torch.Tensor([0, 1024, 8192, 16384]).to(torch.int32).cuda()\nquery = torch.randn(16384, num_q_heads, head_dim).to(torch.bfloat16).cuda()\nkey = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda()\nvalue = torch.randn(16384, num_kv_heads, head_dim).to(torch.bfloat16).cuda()\n\n# weight example\nw = (\n    torch.randn(num_kv_heads, kernel_size * head_dim, head_dim)\n    .to(torch.bfloat16)\n    .cuda()\n)\npe = torch.randn(num_kv_heads, kernel_size, head_dim).to(torch.bfloat16).cuda()\n\n# 1. key value compression\ncompressed_key, compressed_cu_seqlens = linear_compress(\n    key, w, cu_seqlens, kernel_size, kernel_stride, pe\n)\ncompressed_value, _ = linear_compress(\n    value, w, cu_seqlens, kernel_size, kernel_stride, None\n)\n\n# 2. attention between query and compressed key value\ncompressed_attn_output, topk_idx = compressed_attention(\n    query,\n    compressed_key,\n    compressed_value,\n    kernel_size,\n    kernel_stride,\n    block_size,\n    topk,\n    cu_seqlens,\n    compressed_cu_seqlens,\n    init_blocks=1,\n    local_blocks=2,\n)\n\n# 3. 
topk sparse attention\nsparse_attn_output = topk_sparse_attention(\n    query,\n    key,\n    value,\n    topk_idx,\n    block_size,\n    cu_seqlens,\n)\n```\n\n### Module\n\nThe `modules` directory also provides implementations based on `torch.nn.Module` for easy integration into models.\n\n```python\nfrom native_sparse_attention.modules import NativeSparseAttention, RopeConfig\n\nNSA_Layer = NativeSparseAttention(\n    compress_type=\"linear\",\n    hidden_size=4096,\n    num_q_heads=64,\n    num_kv_heads=4,\n    head_dim=128,\n    kernel_size=32,\n    kernel_stride=16,\n    block_size=64,\n    topk=8,\n    init_blocks=1,\n    local_blocks=2,\n    window_size=512,\n    rope_config=RopeConfig(\n        max_position_embeddings=32768,\n        head_dim=128,\n        rope_theta=500000,\n        rope_scaling={\n            \"factor\": 4.0,\n            \"high_freq_factor\": 4.0,\n            \"low_freq_factor\": 1.0,\n            \"original_max_position_embeddings\": 8192,\n            \"rope_type\": \"llama3\",\n        },\n    ),\n)\n```\n\n### Model\n\nWe offer two simplified LLaMA models in the `model` directory, featuring self-attention and native sparse attention. 
For more details on how to use these models, please refer to [this link](https://github.com/XunhaoLai/native-sparse-attention-triton/tree/main/native_sparse_attention/model#readme).\n\n\n```python\nfrom native_sparse_attention.model import ToyNSALlamaConfig, InferenceConfig, ToyNSALlama\n\nconfig = ToyNSALlamaConfig(\n    hidden_size=4096,\n    intermediate_size=14336,\n    num_hidden_layers=8,\n    num_attention_heads=32,\n    num_key_value_heads=2,\n    head_dim=128,\n    rope_theta=500000.0,\n    rope_scaling={\n        \"factor\": 8.0,\n        \"high_freq_factor\": 4.0,\n        \"low_freq_factor\": 1.0,\n        \"original_max_position_embeddings\": 8192,\n        \"rope_type\": \"llama3\",\n    },\n    compress_type=\"weightedpool\",\n    kernel_size=32,\n    kernel_stride=16,\n    block_size=64,\n    topk=8,\n    init_blocks=1,\n    local_blocks=2,\n    window_size=512,\n)\ninference_config = InferenceConfig(\n    max_batch_size=4,\n    max_length=8192,\n    max_new_tokens=128,\n)\nmodel = ToyNSALlama(config, inference_config).cuda().bfloat16()\n```\n\n## Testing\n\nSome test scripts are available in the `test` folder and can be run directly for unit testing. 
For example:\n\n```bash\npython test/test_topk_sparse_attention.py\npython test/test_nsa_module.py\npython test/test_nsa_model.py\n```\n\n### Benchmarks\n\nHere are the speed benchmarks conducted on a single NVIDIA A100 GPU or H100 GPU for the `topk_sparse_attention` function: \n\nA100 GPU speed benchmarks:\n```sh\n** forward with block size 64 **:\n          N       Flash  Triton-Flash  Triton-Top8  Triton-Top16\n0    2048.0     0.414144      0.635648     0.633440      1.009184\n1    4096.0     1.400304      2.267552     1.179808      1.916736\n2    8192.0     5.223776      8.528160     2.266816      3.723168\n3   16384.0    20.225697     32.745537     4.468128      7.359168\n4   32768.0    79.587715    128.951065     8.517440     14.142848\n5   65536.0   321.240479    511.652100    17.249599     30.991360\n6  131072.0  1349.810425   2063.245605    36.400482     67.884544\n\n** backward with block size 64 **:\n          N        Flash  Triton-Flash  Triton-Top8  Triton-Top16\n0    2048.0     1.315440      2.348560     1.941568      2.691040\n1    4096.0     4.271584      8.553184     3.647744      5.032160\n2    8192.0    15.323984     32.665440     5.650144      9.066112\n3   16384.0    58.753281    127.675964    11.160832     17.113279\n4   32768.0   227.770462    504.572693    21.723392     34.715614\n5   65536.0   899.181274   2059.718506    44.517181     76.309441\n6  131072.0  3587.918701   8530.726562   105.344734    182.970169\n```\n\nH100 GPU benchmarks:\n```sh\n** forward with block size 64 **:\n          N       Flash  Triton-Flash  Triton-Top8  Triton-Top16\n0    2048.0    0.259552      0.293888     0.584544      0.917664\n1    4096.0    0.846848      1.029904     1.094976      1.745136\n2    8192.0    3.043744      3.843392     2.128256      3.396880\n3   16384.0   11.743568     14.791360     4.190528      6.704192\n4   32768.0   45.968513     57.532478     7.614496     12.417440\n5   65536.0  187.234375    228.093948    14.840048     24.511856\n6  
131072.0  810.890381    914.693970    29.470400     48.990192\n\n** backward with block size 64 **:\n          N        Flash  Triton-Flash  Triton-Top8  Triton-Top16\n0    2048.0     0.798976      1.096096     1.117312      1.380016\n1    4096.0     2.545680      3.826336     1.669760      2.214880\n2    8192.0     9.029760     14.411633     2.772096      3.947456\n3   16384.0    34.144016     58.945698     5.201344      7.538912\n4   32768.0   135.718369    233.369247     9.968864     15.154192\n5   65536.0   541.053894    929.337646    21.089870     33.818878\n6  131072.0  2139.974854   3785.540527    54.918144     93.750717\n```\n\nHere comes another speed benchmark result for testing `compressed_attention` function on a single NVIDIA A100 GPU or H100 GPU:\n\nA100 GPU speed benchmarks:\n```sh\n** forward with kernel 32 and stride 16 **:\n          N       Flash  Triton-Flash  Compressed  Compressed-wo-Score\n0    2048.0     0.413664      0.635488    0.655024             0.170816\n1    4096.0     1.396416      2.247648    1.132304             0.377152\n2    8192.0     5.234656      8.526400    2.879200             0.977952\n3   16384.0    19.988865     32.755199    9.426448             2.943024\n4   32768.0    79.419907    128.955170   30.284096             9.901120\n5   65536.0   321.590210    511.615509  112.260544            36.001602\n6  131072.0  1346.996338   2069.837891  423.099518           136.820038\n\n** backward with kernel 32 and stride 16 **:\n          N        Flash  Triton-Flash  Compressed\n0    2048.0     1.322560      2.352000    0.486784\n1    4096.0     4.270832      8.552608    0.971392\n2    8192.0    15.515680     32.671329    2.603744\n3   16384.0    59.345055    128.377472    8.499456\n4   32768.0   230.626144    506.581238   30.064833\n5   65536.0   919.260498   2068.642578  113.466560\n6  131072.0  3646.603760   8498.374023  439.623444\n```\n\nH100 GPU speed benchmarks:\n```sh\n** forward with kernel 32 and stride 16 **:\n          N 
      Flash  Triton-Flash  Compressed  Compressed-wo-Score\n0    2048.0    0.259488      0.297152    0.485920             0.103232\n1    4096.0    0.847376      1.030400    0.710208             0.217760\n2    8192.0    3.044016      3.875840    1.607360             0.516016\n3   16384.0   11.823104     14.829360    4.970272             1.440288\n4   32768.0   46.204750     57.527809   15.004992             4.584736\n5   65536.0  187.324249    227.909958   53.009087            16.134224\n6  131072.0  810.707214    910.106873  191.245728            60.154270\n\n** backward with kernel 32 and stride 16 **:\n          N        Flash  Triton-Flash  Compressed\n0    2048.0     0.797728      1.090640    0.283104\n1    4096.0     2.547088      3.834592    0.550464\n2    8192.0     9.021520     14.421088    1.249184\n3   16384.0    34.159508     58.793377    3.743440\n4   32768.0   136.844070    233.447708   12.640032\n5   65536.0   537.559814    929.360229   46.054817\n6  131072.0  2135.629883   3782.351562  175.587296\n```\n\nAll the speed benchmarks above were tested with 64 query heads, 4 key/value heads, and a head dimension of 128.\n\n## Contributing\nContributions are welcome! Please open an issue to discuss major changes.\n\n## Contact\n\nFor any questions or feedback, please feel free to contact laixunhao@pku.edu.cn.\n\n## Citations\n\n```bibtex\n@inproceedings{Yuan2025NativeSA,\n    title   = {Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention},\n    author  = {Jingyang Yuan and Huazuo Gao and Damai Dai and Junyu Luo and Liang Zhao and Zhengyan Zhang and Zhenda Xie and Y. X. Wei and Lean Wang and Zhiping Xiao and Yuqing Wang and Chong Ruan and Ming Zhang and Wenfeng Liang and Wangding Zeng},\n    year    = {2025},\n    url     = {https://api.semanticscholar.org/CorpusID:276408911}\n}\n```\n"
  },
  {
    "path": "install_dependency.sh",
    "content": "pip3 install packaging -i https://pypi.org/simple\npip3 install numpy==1.26.4 -i https://pypi.org/simple\npip3 install torch==2.4.0 -i https://pypi.org/simple\npip3 install triton==3.0.0 -i https://pypi.org/simple\npip3 install transformers==4.44.0 -i https://pypi.org/simple\npip3 install flash_attn==2.6.3 -i https://pypi.org/simple\npip3 install matplotlib==3.9.4 -i https://pypi.org/simple\npip3 install pandas==2.2.3 -i https://pypi.org/simple"
  },
  {
    "path": "native_sparse_attention/__init__.py",
    "content": ""
  },
  {
    "path": "native_sparse_attention/infer/__init__.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom native_sparse_attention.infer.nsa_inference import nsa_infer\n\n__all__ = [\n    \"nsa_infer\",\n]\n"
  },
  {
    "path": "native_sparse_attention/infer/inference_func.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom typing import Tuple, Callable, Optional\nfrom flash_attn import flash_attn_varlen_func\nfrom native_sparse_attention.ops import (\n    flash_attention_decode,\n    compressed_attention,\n    compressed_attention_decode,\n    topk_sparse_attention,\n    topk_sparse_attention_decode,\n)\nfrom native_sparse_attention.ops.triton.utils import get_compressed_seqlens\n\n\ndef compress_infer(\n    cu_seqlens: torch.Tensor,\n    step: int,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    cache,\n    weight: Tuple[torch.Tensor, torch.Tensor],\n    compress_func: Tuple[Callable, Callable],\n    intra_block_pe: Optional[torch.Tensor],\n    kernel_size: int,\n    kernel_stride: int,\n):\n    if step == 0:\n        key, compress_cu_seqlens = compress_func[0](\n            key,\n            weight[0],\n            cu_seqlens,\n            kernel_size,\n            kernel_stride,\n            intra_block_pe,\n        )\n        value, _ = compress_func[1](\n            value,\n            weight[1],\n            cu_seqlens,\n            kernel_size,\n            kernel_stride,\n        )\n    else:\n        batch_size = cu_seqlens.shape[0] - 1\n        aux_cu_seqlens = (\n            torch.arange(batch_size + 1, dtype=torch.int32).to(cu_seqlens.device)\n            * kernel_size\n        )\n        key, _ = compress_func[0](\n            
cache.before_compress_kv_cache[0, :batch_size].view(\n                batch_size * kernel_size, cache.num_kv_heads, cache.head_dim\n            ),\n            weight[0],\n            aux_cu_seqlens,\n            kernel_size,\n            kernel_stride,\n            intra_block_pe,\n        )\n        value, _ = compress_func[1](\n            cache.before_compress_kv_cache[1, :batch_size].view(\n                batch_size * kernel_size, cache.num_kv_heads, cache.head_dim\n            ),\n            weight[1],\n            aux_cu_seqlens,\n            kernel_size,\n            kernel_stride,\n        )\n        # return actual compress_cu_seqlens before this token\n        compress_cu_seqlens = torch.zeros(\n            batch_size + 1, dtype=torch.int32, device=key.device\n        )\n        compress_cu_seqlens[1:] = torch.cumsum(\n            cache.compress_kv_len[:batch_size], dim=0\n        )\n    return key, value, compress_cu_seqlens\n\n\ndef compressed_attention_infer(\n    cu_seqlens,\n    step,\n    query,\n    key,\n    value,\n    cache,\n    kernel_size,\n    kernel_stride,\n    topk,\n    block_size,\n    init_blocks,\n    local_blocks,\n):\n    if step == 0:\n        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n        compress_seqlens, compress_cu_seqlens = get_compressed_seqlens(\n            cu_seqlens, kernel_size, kernel_stride\n        )\n        attn_output, topk_idx = compressed_attention(\n            query,\n            key,\n            value,\n            kernel_size,\n            kernel_stride,\n            block_size,\n            topk,\n            cu_seqlens,\n            compress_cu_seqlens,\n            seqlens.max().item(),\n            compress_seqlens.max().item(),\n            None,\n            init_blocks,\n            local_blocks,\n        )\n    else:\n        batch_size = cu_seqlens.shape[0] - 1\n        seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + step\n        attn_output, topk_idx = compressed_attention_decode(\n           
 query,\n            cache.compress_kv_cache[\n                0, :batch_size, : cache.compress_kv_len[:batch_size].max()\n            ],\n            cache.compress_kv_cache[\n                1, :batch_size, : cache.compress_kv_len[:batch_size].max()\n            ],\n            seqlens,\n            cache.compress_kv_len[:batch_size],\n            kernel_size,\n            kernel_stride,\n            block_size,\n            topk,\n            init_blocks,\n            local_blocks,\n        )\n    return attn_output, topk_idx\n\n\ndef topk_sparse_attention_infer(\n    cu_seqlens,\n    step,\n    query,\n    key,\n    value,\n    cache,\n    topk_idx,\n    block_size,\n):\n    if step == 0:\n        attn_output = topk_sparse_attention(\n            query, key, value, topk_idx, block_size, cu_seqlens\n        )\n    else:\n        batch_size = cu_seqlens.shape[0] - 1\n        attn_output = topk_sparse_attention_decode(\n            query,\n            cache.sparse_kv_cache[0, :batch_size],\n            cache.sparse_kv_cache[1, :batch_size],\n            topk_idx,\n            block_size,\n            cache.sparse_kv_len[:batch_size],\n        )\n    return attn_output\n\n\ndef sliding_window_attention_infer(\n    cu_seqlens, step, query, key, value, cache, window_size\n):\n    if step == 0:\n        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n        attn_output = flash_attn_varlen_func(\n            query,\n            key,\n            value,\n            cu_seqlens,\n            cu_seqlens,\n            seqlens.max().item(),\n            seqlens.max().item(),\n            causal=True,\n            window_size=(window_size, -1),\n        )\n    else:\n        batch_size = cu_seqlens.shape[0] - 1\n        attn_output = flash_attention_decode(\n            query,\n            cache.sliding_kv_cache[0, :batch_size],\n            cache.sliding_kv_cache[1, :batch_size],\n            torch.minimum(\n                cache.sliding_kv_len,\n                
torch.zeros_like(cache.sliding_kv_len) + window_size,\n            )[:batch_size],\n        )\n    return attn_output\n"
  },
  {
    "path": "native_sparse_attention/infer/nsa_inference.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom typing import Tuple, Callable, Optional\nfrom native_sparse_attention.infer.inference_func import (\n    compress_infer,\n    compressed_attention_infer,\n    topk_sparse_attention_infer,\n    sliding_window_attention_infer,\n)\n\n\ndef nsa_infer(\n    cu_seqlens: torch.Tensor,\n    step: int,\n    # qkv for three parts\n    query: torch.Tensor,\n    key: torch.Tensor,  # prefill: [total_len, num_heads, head_dim], decode: [batch_size, num_heads, head_dim]\n    value: torch.Tensor,\n    gate_value: torch.Tensor,  # prefill: [total_len, num_heads, 3], decode: [batch_size, num_heads, 3]\n    # rope and kv cache\n    rope,\n    cache,\n    # weight for nsa compress\n    compress_weight: Tuple[\n        torch.Tensor, torch.Tensor\n    ],  # compress weight for key and value\n    compress_func: Tuple[Callable, Callable],  # compress function for key and value\n    intra_block_pe: Optional[torch.Tensor],\n    # nsa parameters\n    kernel_size: int,\n    kernel_stride: int,\n    block_size: int,\n    topk: int,\n    init_blocks: int,\n    local_blocks: int,\n    window_size: int,\n) -> torch.Tensor:\n    \"\"\"Inference function for native sparse attention. 
Support prefill and decode with kv cache.\n\n    Args:\n        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_varlen_func.\n        step (int): current inference step, step == 0 means prefill, step > 0 means decode step.\n        query (torch.Tensor): for prefill, shape [total_len, num_q_heads, head_dim]; for decode, shape [batch_size, num_q_heads, head_dim]\n        key (torch.Tensor): for prefill, shape [total_len, num_kv_heads, head_dim]; for decode, shape [batch_size, num_kv_heads, head_dim]\n        value (torch.Tensor): for prefill, shape [total_len, num_kv_heads, head_dim]; for decode, shape [batch_size, num_kv_heads, head_dim]\n        gate_value (torch.Tensor): for prefill, shape [total_len, num_heads, 3]; for decode, shape [batch_size, num_heads, 3]\n        rope (RotaryEmbedding): rope module, see native_sparse_attention.module.rope.RotaryEmbedding for details\n        cache (NSACache): kv cache, see native_sparse_attention.module.kv_cache.NSACache for details\n        compress_weight (Tuple[torch.Tensor, torch.Tensor]): compress weight of key and value respectively\n        compress_func (Tuple[Callable, Callable]): compress functions for key and value respectively\n        intra_block_pe (Optional[torch.Tensor]): intra-block positional embedding for compression, set to None if not used\n        kernel_size (int): kernel size of compression\n        kernel_stride (int): kernel stride for compression\n        block_size (int): block size of sparse attention\n        topk (int): topk of sparse attention\n        init_blocks (int): number of blocks at the beginning of the sequence, these blocks are forced to be computed in sparse attention\n        local_blocks (int): number of blocks at the local window of each query, these blocks are forced to be computed in sparse attention\n        window_size (int): window size for sliding window attention\n\n    Returns:\n        torch.Tensor: native sparse attention output, same shape as input query\n    \"\"\"\n    # reset kv cache at the beginning of prefilling\n    if step == 0:\n        cache.reset()\n    # prepare for compress\n    cache.prepare_compress(cu_seqlens, step, key, value)\n    # compressed key and value before rope\n    compress_key, compress_value, compress_cu_seqlens = compress_infer(\n        cu_seqlens,\n        step,\n        key,\n        value,\n        cache,\n        compress_weight,\n        compress_func,\n        intra_block_pe,\n        kernel_size,\n        kernel_stride,\n    )\n    # do rope\n    query = rope(query, cu_seqlens, step)\n    if step == 0:\n        compress_key = rope(\n            compress_key, compress_cu_seqlens, step, stride=cache.kernel_stride\n        )\n    else:\n        compress_key = rope(\n            compress_key, compress_cu_seqlens, 1, stride=cache.kernel_stride\n        )\n    key = rope(key, cu_seqlens, step)\n    # update kv cache\n    cache.update_kv(\n        cu_seqlens,\n        step,\n        compress_key,\n        compress_value,\n        key,\n        value,\n        key,\n        value,\n    )\n    # compressed attention\n    compress_attn_output, topk_idx = compressed_attention_infer(\n        cu_seqlens,\n        step,\n        query,\n        compress_key,\n        compress_value,\n        cache,\n        kernel_size,\n        kernel_stride,\n        topk,\n        block_size,\n        init_blocks,\n        local_blocks,\n    )\n    # topk sparse attention\n    sparse_attn_output = topk_sparse_attention_infer(\n        cu_seqlens,\n        step,\n        query,\n        key,\n        value,\n        cache,\n        topk_idx,\n        block_size,\n    )\n    # sliding window attention\n    sliding_attn_output = sliding_window_attention_infer(\n        cu_seqlens, step, query, key, value, cache, window_size\n    )\n    # combine 3 attn output\n    attn_output = (\n        gate_value[..., 0, None] * compress_attn_output\n        + gate_value[..., 1, None] * sparse_attn_output\n        + gate_value[..., 2, None] * sliding_attn_output\n    )\n    return attn_output\n"
  },
  {
    "path": "native_sparse_attention/model/README.md",
    "content": "# Guide for the ToyNSALlama Model\n\nThe `ToyNSALlama` model is a custom implementation of a Llama-like transformer architecture featuring a Native Sparse Attention (NSA) module. This guide explains how to integrate the NSA module into your own model.\n\n## Overview\n\nThe `ToyNSALlama` model consists of:\n- **Configuration**: Defined by `ToyNSALlamaConfig` (model structure parameters) and `InferenceConfig` (inference-specific parameters).\n- **Components**: An embedding layer, multiple NativeSparseAttention modules, Feed-Forward Network (FFN) modules, normalization layers, and a language model head.\n\n## Step-by-Step Instructions\n\n### 1. Import Necessary Modules\n```python\nimport torch\nimport torch.nn as nn\nfrom native_sparse_attention.model import ToyNSALlama, ToyNSALlamaConfig, InferenceConfig\n```\n\n### 2. Define Configurations\nCreate instances of `ToyNSALlamaConfig` and `InferenceConfig` to set model and inference parameters.\n\n#### Model Configuration\nThe model configuration aligns with the Transformers Llama model configuration. Adjust the following parameters to control the NSA module’s sparsity:\n- `compress_type`: Compression method for keys/values. 
Supported options: `avgpool`, `weightedpool`, `linear`.\n- `kernel_size` & `kernel_stride`: `kernel_size` determines how many tokens are compressed into one; `kernel_stride` sets the stride of the compression sliding window (`kernel_size` must be divisible by `kernel_stride`, e.g. 32 and 16).\n- `block_size`: Block size for sparse attention (recommended: 64 or 128).\n- `topk`, `init_blocks`, `local_blocks`: `topk` specifies the number of blocks selected in sparse attention; `init_blocks` and `local_blocks` define the number of initial and local blocks that must be selected.\n- `window_size`: Size of the sliding window for attention.\n\nExample:\n```python\nconfig = ToyNSALlamaConfig(\n    hidden_size=4096,\n    intermediate_size=14336,\n    num_hidden_layers=8,\n    num_attention_heads=32,\n    num_key_value_heads=2,\n    head_dim=128,\n    vocab_size=128288,\n    max_position_embeddings=131072,\n    rope_theta=500000.0,\n    rope_scaling={\n        \"factor\": 8.0,\n        \"high_freq_factor\": 4.0,\n        \"low_freq_factor\": 1.0,\n        \"original_max_position_embeddings\": 8192,\n        \"rope_type\": \"llama3\",\n    },\n    compress_type=\"weightedpool\",\n    kernel_size=32,\n    kernel_stride=16,\n    block_size=64,\n    topk=8,\n    init_blocks=1,\n    local_blocks=2,\n    window_size=512,\n)\n```\n\n#### Inference Configuration\nThis configuration applies during inference, initializing the Key-Value (KV) Cache based on these settings. The uncompressed key/value cache takes `max_batch_size × max_length × num_kv_heads × head_dim × num_layers × 2 × 2` bytes (one factor of 2 for keys and values, one for the 2-byte `bfloat16` elements). Currently, only greedy decoding is implemented, as an example.\n\nExample:\n```python\ninference_config = InferenceConfig(\n    max_batch_size=4,\n    max_length=8192,\n    max_new_tokens=128,\n)\n```\n\n### 3. Initialize the Model\nInstantiate the model and move it to the GPU with the appropriate data type (currently, only `bfloat16` is supported).\n\n```python\nmodel = ToyNSALlama(config, inference_config).cuda().to(torch.bfloat16)\n```\n\n### 4. Forward & Generate\nThe model supports two methods:\n- **`forward`**: Accepts `input_ids` and `cu_seqlens`, returning final logits after the language model head. Use this for training or evaluation.\n- **`generate`**: Accepts `input_ids` and `cu_seqlens`, generating output tokens via greedy decoding. This demonstrates KV cache usage for token generation (prefilling and decoding).\n\nExample:\n```python\n# Example input (every sequence must be non-empty)\nbatch_size = 4\nseqlens = torch.randint(1, 4096, (batch_size,), dtype=torch.int32, device=\"cuda\")\ncu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=\"cuda\")\ncu_seqlens[1:] = torch.cumsum(seqlens, dim=0)\ninput_ids = torch.randint(0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device=\"cuda\")\nprint(f\"\\nEXAMPLE INPUT:\\ncu_seqlens: {cu_seqlens}\\ninput_ids: {input_ids.shape}\\n\")\n\n# Example forward\nlogits = model(input_ids, cu_seqlens)\nprint(f\"\\nEXAMPLE OUTPUT:\\nlogits: {logits.shape}\\n\")\n\n# Example generate\noutput_tokens = model.generate(input_ids, cu_seqlens)\nprint(f\"\\nEXAMPLE GENERATE:\\noutput_tokens: {output_tokens}\\n\")\n```\n\n## Toy Llama Model with Self-Attention\nA simpler toy model with the Llama structure is available in `native_sparse_attention/model/toy_llama.py`. Compare `ToyLlama` and `ToyNSALlama` to see how to adapt a self-attention model into an NSA-based model.\n\nThe primary difference lies in replacing the `SelfAttention` module with the `NativeSparseAttention` module, along with updates to the KV cache and inference function. These changes are straightforward and easy to implement.\n"
  },
  {
    "path": "native_sparse_attention/model/__init__.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom native_sparse_attention.model.toy_llama import (\n    ToyLlamaConfig,\n    InferenceConfig,\n    ToyLlama,\n)\nfrom native_sparse_attention.model.toy_nsa_llama import (\n    ToyNSALlamaConfig,\n    InferenceConfig,\n    ToyNSALlama,\n)\n\n__all__ = [\n    \"ToyLlamaConfig\",\n    \"ToyNSALlamaConfig\",\n    \"InferenceConfig\",\n    \"ToyLlama\",\n    \"ToyNSALlama\",\n]\n"
  },
  {
    "path": "native_sparse_attention/model/toy_llama.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Optional\nimport torch\nimport torch.nn as nn\nfrom dataclasses import dataclass, field\nfrom native_sparse_attention.module import SelfAttention, RopeConfig, KVCache\n\n\n@dataclass\nclass ToyLlamaConfig:\n    # embedding config\n    vocab_size: int = 128288\n    max_position_embeddings: int = 131072\n    # model config\n    hidden_size: int = 4096\n    intermediate_size: int = 14336\n    num_hidden_layers: int = 32\n    num_attention_heads: int = 32\n    num_key_value_heads: int = 2\n    head_dim: int = 128\n    # rope config\n    rope_theta: float = 500000.0\n    rope_scaling: dict = field(\n        default_factory=lambda: {\n            \"factor\": 8.0,\n            \"high_freq_factor\": 4.0,\n            \"low_freq_factor\": 1.0,\n            \"original_max_position_embeddings\": 8192,\n            \"rope_type\": \"llama3\",\n        }\n    )\n\n\n@dataclass\nclass InferenceConfig:\n    max_batch_size: int = 32\n    max_length: int = 8192\n    max_new_tokens: int = 128\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, hidden_size: int, eps: float = 1e-6):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states: torch.Tensor):\n        input_dtype = hidden_states.dtype\n        hidden_states = 
hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nclass FFN(nn.Module):\n    def __init__(self, hidden_size: int, intermediate_size: int):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = nn.SiLU()\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\nclass ToyLlamaLayer(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        intermediate_size: int,\n        num_q_heads: int,\n        num_kv_heads: int,\n        head_dim: int,\n        rope_config: RopeConfig,\n    ):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_q_heads = num_q_heads\n        self.num_kv_heads = num_kv_heads\n        self.head_dim = head_dim\n        self.rope_config = rope_config\n        self.attn_norm = RMSNorm(self.hidden_size)\n        self.self_attn = SelfAttention(\n            hidden_size=self.hidden_size,\n            num_q_heads=self.num_q_heads,\n            num_kv_heads=self.num_kv_heads,\n            head_dim=self.head_dim,\n            rope_config=rope_config,\n        )\n        self.ffn_norm = RMSNorm(self.hidden_size)\n        self.ffn = FFN(\n            hidden_size=self.hidden_size, intermediate_size=self.intermediate_size\n        )\n\n    def forward(self, x, cu_seqlens):\n        x = x + self.self_attn(self.attn_norm(x), 
cu_seqlens)\n        x = x + self.ffn(self.ffn_norm(x))\n        return x\n\n    @torch.no_grad()\n    def inference(self, x, cu_seqlens, step, kv_cache):\n        x = x + self.self_attn.inference(self.attn_norm(x), cu_seqlens, step, kv_cache)\n        x = x + self.ffn(self.ffn_norm(x))\n        return x\n\n\nclass ToyLlama(nn.Module):\n    def __init__(\n        self, config: ToyLlamaConfig, inference_config: Optional[InferenceConfig] = None\n    ):\n        super().__init__()\n        self.config = config\n        self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)\n        self.rope_config = RopeConfig(\n            head_dim=self.config.head_dim,\n            rope_theta=self.config.rope_theta,\n            rope_scaling=self.config.rope_scaling,\n        )\n        self.layers = nn.ModuleList(\n            [\n                ToyLlamaLayer(\n                    hidden_size=self.config.hidden_size,\n                    intermediate_size=self.config.intermediate_size,\n                    num_q_heads=self.config.num_attention_heads,\n                    num_kv_heads=self.config.num_key_value_heads,\n                    head_dim=self.config.head_dim,\n                    rope_config=RopeConfig(\n                        self.config.max_position_embeddings,\n                        self.config.head_dim,\n                        self.config.rope_theta,\n                        self.config.rope_scaling,\n                    ),\n                )\n                for _ in range(self.config.num_hidden_layers)\n            ]\n        )\n        self.norm = RMSNorm(self.config.hidden_size)\n        self.lm_head = nn.Linear(\n            self.config.hidden_size, self.config.vocab_size, bias=False\n        )\n\n        # inference config and kv cache\n        self.inference_config = inference_config\n        self.kv_cache = None\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor,  # shape: [total_length, ]\n        cu_seqlens: 
torch.LongTensor,  # shape: [batch_size + 1, ]\n    ):\n        # embedding\n        x = self.embedding(input_ids).to(torch.bfloat16)\n        # layers\n        for layer in self.layers:\n            x = layer(x, cu_seqlens)\n        # final norm\n        x = self.norm(x)\n        # lanugauge head\n        x = self.lm_head(x).to(torch.float32)  # [total_len, vocab_size]\n        return x\n\n    @torch.no_grad()\n    def inference(\n        self,\n        input_ids: torch.LongTensor,  # prefill shape: [total_length, ]; decode shape: [batch_size, ]\n        cu_seqlens: torch.LongTensor,  # shape: [batch_size + 1, ]\n        step: int,\n    ):\n        # set kv cache if self.kv_cache is None\n        if self.kv_cache is None:\n            self.kv_cache = [\n                KVCache(\n                    max_batch_size=self.inference_config.max_batch_size,\n                    max_length=self.inference_config.max_length,\n                    num_kv_heads=self.config.num_key_value_heads,\n                    head_dim=self.config.head_dim,\n                    dtype=torch.bfloat16,\n                    device=\"cuda\",\n                )\n                for _ in range(self.config.num_hidden_layers)\n            ]\n        # embedding\n        x = self.embedding(input_ids).to(torch.bfloat16)\n        # layers\n        for i, layer in enumerate(self.layers):\n            x = layer.inference(x, cu_seqlens, step, self.kv_cache[i])\n        # final norm\n        x = self.norm(x)\n        # lanugauge head\n        if step == 0:\n            x = x[cu_seqlens[1:] - 1, :]\n        x = self.lm_head(x).to(torch.float32)  # [total_len, vocab_size]\n        return x\n\n    def generate(\n        self,\n        input_ids: torch.LongTensor,\n        cu_seqlens: torch.LongTensor,\n        max_new_tokens: int = -1,\n    ):\n        output_tokens = []\n        if max_new_tokens <= 0:\n            max_new_tokens = self.inference_config.max_new_tokens\n        for step in 
range(max_new_tokens):\n            logits = self.inference(\n                input_ids, cu_seqlens, step\n            )  # shape: [batch_size, vocab_size]\n            next_token = torch.argmax(logits, dim=-1)  # shape: [batch_size, ]\n            input_ids = next_token\n            output_tokens.append(next_token)\n        output_tokens = torch.stack(\n            output_tokens, dim=1\n        )  # shape: [batch_size, max_new_tokens]\n        return output_tokens\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(42)\n    # initialize model\n    config = ToyLlamaConfig(\n        hidden_size=4096,\n        intermediate_size=14336,\n        num_hidden_layers=8,\n        num_attention_heads=32,\n        num_key_value_heads=2,\n        head_dim=128,\n        rope_theta=500000.0,\n        rope_scaling={\n            \"factor\": 8.0,\n            \"high_freq_factor\": 4.0,\n            \"low_freq_factor\": 1.0,\n            \"original_max_position_embeddings\": 8192,\n            \"rope_type\": \"llama3\",\n        },\n    )\n    inference_config = InferenceConfig(\n        max_batch_size=4,\n        max_length=8192,\n        max_new_tokens=128,\n    )\n    model = ToyLlama(config, inference_config).cuda().bfloat16()\n    print(f\"\\nMODEL CONFIG:\\n{config}\\n\")\n    print(f\"\\nINFERENCE CONFIG:\\n{inference_config}\\n\")\n    print(f\"\\nMODEL:\\n{model}\\n\")\n\n    # example input\n    batch_size = 4\n    seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device=\"cuda\")\n    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=\"cuda\")\n    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)\n    input_ids = torch.randint(\n        0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device=\"cuda\"\n    )\n    print(f\"\\nEXAMPLE INPUT:\\ncu_seqlens: {cu_seqlens}\\ninput_ids: {input_ids.shape}\\n\")\n\n    # example output\n    logits = model(input_ids, cu_seqlens)\n    print(f\"\\nEXAMPLE OUTPUT:\\nlogits: {logits.shape}\\n\")\n\n   
 # example generate\n    output_tokens = model.generate(input_ids, cu_seqlens, 64)\n    print(f\"\\nEXAMPLE GENERATE:\\noutput_tokens: {output_tokens}\\n\")\n"
  },
  {
    "path": "native_sparse_attention/model/toy_nsa_llama.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Optional\nimport torch\nimport torch.nn as nn\nfrom dataclasses import dataclass, field\nfrom native_sparse_attention.module import NativeSparseAttention, RopeConfig, NSACache\n\n\n@dataclass\nclass ToyNSALlamaConfig:\n    # embedding config\n    vocab_size: int = 128288\n    max_position_embeddings: int = 131072\n    # model config\n    hidden_size: int = 4096\n    intermediate_size: int = 14336\n    num_hidden_layers: int = 32\n    num_attention_heads: int = 32\n    num_key_value_heads: int = 2\n    head_dim: int = 128\n    # rope config\n    rope_theta: float = 500000.0\n    rope_scaling: dict = field(\n        default_factory=lambda: {\n            \"factor\": 8.0,\n            \"high_freq_factor\": 4.0,\n            \"low_freq_factor\": 1.0,\n            \"original_max_position_embeddings\": 8192,\n            \"rope_type\": \"llama3\",\n        }\n    )\n    # nsa config\n    compress_type: str = \"weightedpool\"\n    kernel_size: int = 32\n    kernel_stride: int = 16\n    block_size: int = 64\n    topk: int = 16\n    init_blocks: int = 1\n    local_blocks: int = 2\n    window_size: int = 512\n\n\n@dataclass\nclass InferenceConfig:\n    max_batch_size: int = 32\n    max_length: int = 8192\n    max_new_tokens: int = 128\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, hidden_size: int, eps: float = 1e-6):\n        
super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states: torch.Tensor):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nclass FFN(nn.Module):\n    def __init__(self, hidden_size: int, intermediate_size: int):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = nn.SiLU()\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\nclass ToyNSALlamaLayer(nn.Module):\n    def __init__(\n        self,\n        hidden_size: int,\n        intermediate_size: int,\n        num_q_heads: int,\n        num_kv_heads: int,\n        head_dim: int,\n        compress_type: str,\n        kernel_size: int,\n        kernel_stride: int,\n        block_size: int,\n        topk: int,\n        init_blocks: int,\n        local_blocks: int,\n        window_size: int,\n        rope_config: RopeConfig,\n    ):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_q_heads = num_q_heads\n        self.num_kv_heads = num_kv_heads\n        self.head_dim = head_dim\n        self.compress_type = compress_type\n        self.kernel_size = kernel_size\n        self.kernel_stride = kernel_stride\n        self.block_size = 
block_size\n        self.topk = topk\n        self.init_blocks = init_blocks\n        self.local_blocks = local_blocks\n        self.window_size = window_size\n        self.rope_config = rope_config\n        self.attn_norm = RMSNorm(self.hidden_size)\n        self.nsa = NativeSparseAttention(\n            hidden_size=self.hidden_size,\n            num_q_heads=self.num_q_heads,\n            num_kv_heads=self.num_kv_heads,\n            head_dim=self.head_dim,\n            compress_type=self.compress_type,\n            kernel_size=self.kernel_size,\n            kernel_stride=self.kernel_stride,\n            block_size=self.block_size,\n            topk=self.topk,\n            init_blocks=self.init_blocks,\n            local_blocks=self.local_blocks,\n            window_size=self.window_size,\n            rope_config=rope_config,\n        )\n        self.ffn_norm = RMSNorm(self.hidden_size)\n        self.ffn = FFN(\n            hidden_size=self.hidden_size, intermediate_size=self.intermediate_size\n        )\n\n    def forward(self, x, cu_seqlens):\n        x = x + self.nsa(self.attn_norm(x), cu_seqlens)\n        x = x + self.ffn(self.ffn_norm(x))\n        return x\n\n    @torch.no_grad()\n    def inference(self, x, cu_seqlens, step, kv_cache):\n        x = x + self.nsa.inference(self.attn_norm(x), cu_seqlens, step, kv_cache)\n        x = x + self.ffn(self.ffn_norm(x))\n        return x\n\n\nclass ToyNSALlama(nn.Module):\n    def __init__(\n        self,\n        config: ToyNSALlamaConfig,\n        inference_config: Optional[InferenceConfig] = None,\n    ):\n        super().__init__()\n        self.config = config\n        self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)\n        self.rope_config = RopeConfig(\n            head_dim=self.config.head_dim,\n            rope_theta=self.config.rope_theta,\n            rope_scaling=self.config.rope_scaling,\n        )\n        self.layers = nn.ModuleList(\n            [\n                
ToyNSALlamaLayer(\n                    hidden_size=self.config.hidden_size,\n                    intermediate_size=self.config.intermediate_size,\n                    num_q_heads=self.config.num_attention_heads,\n                    num_kv_heads=self.config.num_key_value_heads,\n                    head_dim=self.config.head_dim,\n                    compress_type=self.config.compress_type,\n                    kernel_size=self.config.kernel_size,\n                    kernel_stride=self.config.kernel_stride,\n                    block_size=self.config.block_size,\n                    topk=self.config.topk,\n                    init_blocks=self.config.init_blocks,\n                    local_blocks=self.config.local_blocks,\n                    window_size=self.config.window_size,\n                    rope_config=RopeConfig(\n                        self.config.max_position_embeddings,\n                        self.config.head_dim,\n                        self.config.rope_theta,\n                        self.config.rope_scaling,\n                    ),\n                )\n                for _ in range(self.config.num_hidden_layers)\n            ]\n        )\n        self.norm = RMSNorm(self.config.hidden_size)\n        self.lm_head = nn.Linear(\n            self.config.hidden_size, self.config.vocab_size, bias=False\n        )\n\n        # inference config and kv cache\n        self.inference_config = inference_config\n        self.kv_cache = None\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor,  # shape: [batch_size, max_length]\n        cu_seqlens: torch.LongTensor,  # shape: [batch_size + 1, ]\n    ):\n        # embedding\n        x = self.embedding(input_ids).to(torch.bfloat16)\n        # layers\n        for layer in self.layers:\n            x = layer(x, cu_seqlens)\n        # final norm\n        x = self.norm(x)\n        # lanugauge head\n        x = self.lm_head(x).to(torch.float32)  # [total_len, vocab_size]\n        return x\n\n    
@torch.no_grad()\n    def inference(\n        self,\n        input_ids: torch.LongTensor,  # prefill shape: [total_length, ]; decode shape: [batch_size, ]\n        cu_seqlens: torch.LongTensor,  # shape: [batch_size + 1, ]\n        step: int,\n    ):\n        # set kv cache if self.kv_cache is None\n        if self.kv_cache is None:\n            self.kv_cache = [\n                NSACache(\n                    max_batch_size=self.inference_config.max_batch_size,\n                    max_length=self.inference_config.max_length,\n                    num_kv_heads=self.config.num_key_value_heads,\n                    head_dim=self.config.head_dim,\n                    kernel_size=self.config.kernel_size,\n                    kernel_stride=self.config.kernel_stride,\n                    window_size=self.config.window_size,\n                    dtype=torch.bfloat16,\n                    device=\"cuda\",\n                )\n                for _ in range(self.config.num_hidden_layers)\n            ]\n        # embedding\n        x = self.embedding(input_ids).to(torch.bfloat16)\n        # layers\n        for i, layer in enumerate(self.layers):\n            x = layer.inference(x, cu_seqlens, step, self.kv_cache[i])\n        # final norm\n        x = self.norm(x)\n        # lanugauge head\n        if step == 0:\n            x = x[cu_seqlens[1:] - 1, :]\n        x = self.lm_head(x).to(torch.float32)  # [total_len, vocab_size]\n        return x\n\n    def generate(\n        self,\n        input_ids: torch.LongTensor,\n        cu_seqlens: torch.LongTensor,\n        max_new_tokens: int = -1,\n    ):\n        output_tokens = []\n        if max_new_tokens <= 0:\n            max_new_tokens = self.inference_config.max_new_tokens\n        for step in range(max_new_tokens):\n            logits = self.inference(\n                input_ids, cu_seqlens, step\n            )  # shape: [batch_size, vocab_size]\n            next_token = torch.argmax(logits, dim=-1)  # shape: [batch_size, ]\n 
           input_ids = next_token\n            output_tokens.append(next_token)\n        output_tokens = torch.stack(\n            output_tokens, dim=1\n        )  # shape: [batch_size, max_new_tokens]\n        return output_tokens\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(42)\n    # initialize model\n    config = ToyNSALlamaConfig(\n        hidden_size=4096,\n        intermediate_size=14336,\n        num_hidden_layers=8,\n        num_attention_heads=32,\n        num_key_value_heads=2,\n        head_dim=128,\n        rope_theta=500000.0,\n        rope_scaling={\n            \"factor\": 8.0,\n            \"high_freq_factor\": 4.0,\n            \"low_freq_factor\": 1.0,\n            \"original_max_position_embeddings\": 8192,\n            \"rope_type\": \"llama3\",\n        },\n        compress_type=\"weightedpool\",\n        kernel_size=32,\n        kernel_stride=16,\n        block_size=64,\n        topk=8,\n        init_blocks=1,\n        local_blocks=2,\n        window_size=512,\n    )\n    inference_config = InferenceConfig(\n        max_batch_size=4,\n        max_length=8192,\n        max_new_tokens=128,\n    )\n    model = ToyNSALlama(config, inference_config).cuda().bfloat16()\n    print(f\"\\nMODEL CONFIG:\\n{config}\\n\")\n    print(f\"\\nINFERENCE CONFIG:\\n{inference_config}\\n\")\n    print(f\"\\nMODEL:\\n{model}\\n\")\n\n    # example input\n    batch_size = 4\n    seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, device=\"cuda\")\n    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=\"cuda\")\n    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)\n    input_ids = torch.randint(\n        0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device=\"cuda\"\n    )\n    print(f\"\\nEXAMPLE INPUT:\\ncu_seqlens: {cu_seqlens}\\ninput_ids: {input_ids.shape}\\n\")\n\n    # example output\n    logits = model(input_ids, cu_seqlens)\n    print(f\"\\nEXAMPLE OUTPUT:\\nlogits: {logits.shape}\\n\")\n\n    # example generate\n  
  output_tokens = model.generate(input_ids, cu_seqlens, 64)\n    print(f\"\\nEXAMPLE GENERATE:\\noutput_tokens: {output_tokens}\\n\")\n"
  },
  {
    "path": "native_sparse_attention/module/__init__.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom native_sparse_attention.module.native_sparse_attention import NativeSparseAttention\nfrom native_sparse_attention.module.self_attention import SelfAttention\nfrom native_sparse_attention.module.rope import RotaryEmbedding, RopeConfig\nfrom native_sparse_attention.module.kv_cache import NSACache, KVCache\n\n__all__ = [\n    \"SelfAttention\",\n    \"NativeSparseAttention\",\n    \"RotaryEmbedding\",\n    \"RopeConfig\",\n    \"NSACache\",\n]\n"
  },
  {
    "path": "native_sparse_attention/module/kv_cache.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Union\nfrom native_sparse_attention.ops.triton.utils import get_compressed_seqlens\n\n\nclass KVCache:\n    def __init__(\n        self,\n        max_batch_size: int,\n        max_length: int,\n        num_kv_heads: int,\n        head_dim: int,\n        dtype: torch.dtype,\n        device: Union[str, torch.device],\n    ):\n        self.max_batch_size = max_batch_size\n        self.max_length = max_length\n        self.num_kv_heads = num_kv_heads\n        self.head_dim = head_dim\n        self.dtype = dtype\n        self.device = device\n\n        # alloc kv cache tensor for topk sparse attention\n        self.kv_cache = torch.zeros(\n            2,\n            self.max_batch_size,\n            self.max_length,\n            self.num_kv_heads,\n            self.head_dim,\n            dtype=self.dtype,\n            device=self.device,\n        )\n        self.kv_len = torch.zeros(\n            self.max_batch_size, dtype=torch.int32, device=self.device\n        )\n\n    def reset(self):\n        self.kv_cache.zero_()\n        self.kv_len.zero_()\n\n    def update_kv(\n        self,\n        cu_seqlens: torch.Tensor,\n        step: int,\n        key: torch.Tensor,\n        value: torch.Tensor,\n    ):\n        if step == 0:\n            self._update_kv_prefill(\n                
cu_seqlens,\n                step,\n                key,\n                value,\n            )\n        else:\n            self._update_kv_decode(\n                cu_seqlens,\n                step,\n                key,\n                value,\n            )\n\n    def _update_kv_prefill(\n        self,\n        cu_seqlens: torch.Tensor,\n        step: int,\n        key: torch.Tensor,\n        value: torch.Tensor,\n    ):\n        assert step == 0\n        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n        batch_size = seqlens.shape[0]\n        # sparse part kv, shape check\n        total_len, num_heads, head_dim = key.shape\n        assert key.shape == value.shape\n        assert num_heads == self.num_kv_heads and head_dim == self.head_dim\n        assert total_len == cu_seqlens[-1].item()\n        # fill sparse part kv cache\n        seq_start, seq_end = cu_seqlens[:-1], cu_seqlens[1:]\n        _fill_kv_cache(\n            self.kv_cache,\n            key,\n            value,\n            seq_start,\n            seq_end,\n        )\n        self.kv_len[:batch_size] = seqlens\n\n    def _update_kv_decode(\n        self,\n        cu_seqlens: torch.Tensor,\n        step: int,\n        key: torch.Tensor,\n        value: torch.Tensor,\n    ):\n        assert step > 0\n        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n        # sparse part kv, shape check\n        batch_size, num_heads, head_dim = key.shape\n        assert batch_size == seqlens.shape[0]\n        assert key.shape == value.shape\n        assert num_heads == self.num_kv_heads and head_dim == self.head_dim\n        # fill sparse part kv cache\n        brange = torch.arange(batch_size, dtype=torch.int32, device=key.device)\n        self.kv_cache[0, :batch_size][brange, self.kv_len[:batch_size]] = key\n        self.kv_cache[1, :batch_size][brange, self.kv_len[:batch_size]] = value\n        self.kv_len[:batch_size] += 1\n\n\nclass NSACache:\n    \"\"\"KV cache manager for native sparse attention.\n    
Args:\n        max_batch_size (int): max batch size\n        max_length (int): max length, including prompt len and response len\n        num_kv_heads (int): number of key/value heads\n        head_dim (int): head dim\n        kernel_size (int): kernel size of compression\n        kernel_stride (int): kernel stride for compression\n        window_size (int): window size for sliding window attention\n        dtype (torch.dtype): data type for kv cache, should be same as model weight dtype\n        device (Union[str, torch.device]): defaults to 'cuda'\n\n    Methods:\n        reset: reset kv cache, should be called before prefilling\n        prepare_compress: store keys/values for compression, should be called before key/value compression at both prefilling and decoding\n        update_kv: update key/value cache, should be called after rope\n    \"\"\"\n\n    def __init__(\n        self,\n        max_batch_size: int,\n        max_length: int,\n        num_kv_heads: int,\n        head_dim: int,\n        kernel_size: int,\n        kernel_stride: int,\n        window_size: int,\n        dtype: torch.dtype,\n        device: Union[str, torch.device] = \"cuda\",\n    ):\n        self.max_batch_size = max_batch_size\n        self.max_length = max_length\n        self.num_kv_heads = num_kv_heads\n        self.head_dim = head_dim\n        self.kernel_size = kernel_size\n        self.kernel_stride = kernel_stride\n        self.window_size = window_size\n        self.dtype = dtype\n        self.device = device\n\n        # alloc kv cache tensor for topk sparse attention\n        self.sparse_kv_cache = torch.zeros(\n            2,\n            self.max_batch_size,\n            self.max_length,\n            self.num_kv_heads,\n            self.head_dim,\n            dtype=self.dtype,\n            device=self.device,\n        )\n        self.sparse_kv_len = torch.zeros(\n            self.max_batch_size, dtype=torch.int32, device=self.device\n        )\n\n        # alloc kv cache 
tensor for compressed attention\n        self.max_comp_length = (\n            self.max_length - self.kernel_size\n        ) // self.kernel_stride + 1\n        self.compress_kv_cache = torch.zeros(\n            2,\n            self.max_batch_size,\n            self.max_comp_length,\n            self.num_kv_heads,\n            self.head_dim,\n            dtype=self.dtype,\n            device=self.device,\n        )\n        self.compress_kv_len = torch.zeros(\n            self.max_batch_size, dtype=torch.int32, device=self.device\n        )\n        self.before_compress_kv_cache = torch.zeros(\n            2,\n            self.max_batch_size,\n            self.kernel_size,\n            self.num_kv_heads,\n            self.head_dim,\n            dtype=self.dtype,\n            device=self.device,\n        )\n        self.before_compress_kv_len = torch.zeros(\n            self.max_batch_size, dtype=torch.int32, device=self.device\n        )\n\n        # alloc kv cache for sliding window attention\n        self.sliding_kv_cache = torch.zeros(\n            2,\n            self.max_batch_size,\n            self.window_size,\n            self.num_kv_heads,\n            self.head_dim,\n            dtype=self.dtype,\n            device=self.device,\n        )\n        self.sliding_kv_len = torch.zeros(\n            self.max_batch_size, dtype=torch.int32, device=self.device\n        )\n\n    def reset(self):\n        self.compress_kv_cache.zero_()\n        self.compress_kv_len.zero_()\n        self.before_compress_kv_cache.zero_()\n        self.before_compress_kv_len.zero_()\n        self.sparse_kv_cache.zero_()\n        self.sparse_kv_len.zero_()\n        self.sliding_kv_cache.zero_()\n        self.sliding_kv_len.zero_()\n\n    def prepare_compress(\n        self,\n        cu_seqlens: torch.Tensor,\n        step: int,\n        key: torch.Tensor,\n        value: torch.Tensor,\n    ):\n        if step == 0:\n            self._prepare_compress_prefill(cu_seqlens, step, key, 
value)\n        else:\n            self._prepare_compress_decode(cu_seqlens, step, key, value)\n\n    def _prepare_compress_prefill(\n        self,\n        cu_seqlens: torch.Tensor,\n        step: int,\n        key: torch.Tensor,\n        value: torch.Tensor,\n    ):\n        assert step == 0\n        # compress part kv, shape check\n        batch_size = cu_seqlens.shape[0] - 1\n        total_len, num_heads, head_dim = key.shape\n        assert key.shape == value.shape\n        assert num_heads == self.num_kv_heads and head_dim == self.head_dim\n        comp_seqlens, comp_cu_seqlens = get_compressed_seqlens(\n            cu_seqlens, self.kernel_size, self.kernel_stride\n        )\n        assert total_len == cu_seqlens[-1].item()\n        # fill tmp cache\n        seq_start = cu_seqlens[:-1] + comp_seqlens * self.kernel_stride\n        seq_end = cu_seqlens[1:]\n        _fill_kv_cache(\n            self.before_compress_kv_cache,\n            key,\n            value,\n            seq_start,\n            seq_end,\n        )\n        self.before_compress_kv_len[:batch_size] = seq_end - seq_start\n\n    def _prepare_compress_decode(\n        self,\n        cu_seqlens: torch.Tensor,\n        step: int,\n        key: torch.Tensor,\n        value: torch.Tensor,\n    ):\n        assert step > 0\n        # compress part kv, shape check\n        batch_size, num_heads, head_dim = key.shape\n        assert key.shape == value.shape\n        assert num_heads == self.num_kv_heads and head_dim == self.head_dim\n        assert batch_size == cu_seqlens.shape[0] - 1\n        # sequence with full before_compress_kv\n        idx = torch.where(self.before_compress_kv_len == self.kernel_size)[0].squeeze()\n        self.before_compress_kv_cache[\n            :, idx, : self.kernel_size - self.kernel_stride, :, :\n        ] = self.before_compress_kv_cache[:, idx, self.kernel_stride :, :, :]\n        self.before_compress_kv_len[idx] -= self.kernel_stride\n        # fill new kv\n        
brange = torch.arange(batch_size, dtype=torch.int32, device=key.device)\n        self.before_compress_kv_cache[0, :batch_size][\n            brange, self.before_compress_kv_len[:batch_size]\n        ] = key\n        self.before_compress_kv_cache[1, :batch_size][\n            brange, self.before_compress_kv_len[:batch_size]\n        ] = value\n        # update kv len\n        self.before_compress_kv_len[:batch_size] += 1\n\n    def update_kv(\n        self,\n        cu_seqlens: torch.Tensor,\n        step: int,\n        compress_key: torch.Tensor,\n        compress_value: torch.Tensor,\n        sparse_key: torch.Tensor,\n        sparse_value: torch.Tensor,\n        sliding_key: torch.Tensor,\n        sliding_value: torch.Tensor,\n    ):\n        if step == 0:\n            self._update_kv_prefill(\n                cu_seqlens,\n                step,\n                compress_key,\n                compress_value,\n                sparse_key,\n                sparse_value,\n                sliding_key,\n                sliding_value,\n            )\n        else:\n            self._update_kv_decode(\n                cu_seqlens,\n                step,\n                compress_key,\n                compress_value,\n                sparse_key,\n                sparse_value,\n                sliding_key,\n                sliding_value,\n            )\n\n    def _update_kv_prefill(\n        self,\n        cu_seqlens: torch.Tensor,\n        step: int,\n        compress_key: torch.Tensor,\n        compress_value: torch.Tensor,\n        sparse_key: torch.Tensor,\n        sparse_value: torch.Tensor,\n        sliding_key: torch.Tensor,\n        sliding_value: torch.Tensor,\n    ):\n        assert step == 0\n        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n        batch_size = seqlens.shape[0]\n        # sparse part kv, shape check\n        total_len, num_heads, head_dim = sparse_key.shape\n        assert sparse_key.shape == sparse_value.shape\n        assert num_heads == 
self.num_kv_heads and head_dim == self.head_dim\n        assert total_len == cu_seqlens[-1].item()\n        # compress part kv, shape check\n        total_comp_len, num_heads, head_dim = compress_key.shape\n        assert compress_key.shape == compress_value.shape\n        assert num_heads == self.num_kv_heads and head_dim == self.head_dim\n        comp_seqlens, comp_cu_seqlens = get_compressed_seqlens(\n            cu_seqlens, self.kernel_size, self.kernel_stride\n        )\n        assert total_comp_len == comp_cu_seqlens[-1].item()\n        # sliding window part kv, shape check\n        total_len, num_heads, head_dim = sliding_key.shape\n        assert sliding_key.shape == sliding_value.shape\n\n        # fill compress part kv cache\n        seq_start, seq_end = comp_cu_seqlens[:-1], comp_cu_seqlens[1:]\n        _fill_kv_cache(\n            self.compress_kv_cache,\n            compress_key,\n            compress_value,\n            seq_start,\n            seq_end,\n        )\n        self.compress_kv_len[:batch_size] = comp_seqlens\n        # fill sparse part kv cache\n        seq_start, seq_end = cu_seqlens[:-1], cu_seqlens[1:]\n        _fill_kv_cache(\n            self.sparse_kv_cache,\n            sparse_key,\n            sparse_value,\n            seq_start,\n            seq_end,\n        )\n        self.sparse_kv_len[:batch_size] = seqlens\n        # fill sliding part kv cache\n        seq_start = torch.maximum(cu_seqlens[1:] - self.window_size, cu_seqlens[:-1])\n        seq_end = cu_seqlens[1:]\n        _fill_kv_cache(\n            self.sliding_kv_cache,\n            sliding_key,\n            sliding_value,\n            seq_start,\n            seq_end,\n        )\n        self.sliding_kv_len[:batch_size] = seq_end - seq_start\n\n    def _update_kv_decode(\n        self,\n        cu_seqlens: torch.Tensor,\n        step: int,\n        compress_key: torch.Tensor,\n        compress_value: torch.Tensor,\n        sparse_key: torch.Tensor,\n        sparse_value: 
torch.Tensor,\n        sliding_key: torch.Tensor,\n        sliding_value: torch.Tensor,\n    ):\n        assert step > 0\n        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n        # sparse part kv, shape check\n        batch_size, num_heads, head_dim = sparse_key.shape\n        assert batch_size == seqlens.shape[0]\n        assert sparse_key.shape == sparse_value.shape\n        assert num_heads == self.num_kv_heads and head_dim == self.head_dim\n        # compress part kv, shape check\n        batch_size, num_heads, head_dim = compress_key.shape\n        assert batch_size == seqlens.shape[0]\n        assert compress_key.shape == compress_value.shape\n        assert num_heads == self.num_kv_heads and head_dim == self.head_dim\n        # sliding window part kv, shape check\n        total_len, num_heads, head_dim = sliding_key.shape\n        assert sliding_key.shape == sliding_value.shape\n\n        # fill compress part kv cache, only full block need to be compress\n        idx = torch.where(self.before_compress_kv_len == self.kernel_size)[0].squeeze()\n        self.compress_kv_cache[0][idx, self.compress_kv_len[idx]] = compress_key[idx]\n        self.compress_kv_cache[1][idx, self.compress_kv_len[idx]] = compress_value[idx]\n        self.compress_kv_len[idx] += 1\n        # fill sparse part kv cache\n        brange = torch.arange(batch_size, dtype=torch.int32, device=sparse_key.device)\n        self.sparse_kv_cache[0, :batch_size][\n            brange, self.sparse_kv_len[:batch_size]\n        ] = sparse_key\n        self.sparse_kv_cache[1, :batch_size][\n            brange, self.sparse_kv_len[:batch_size]\n        ] = sparse_value\n        self.sparse_kv_len[:batch_size] += 1\n        # fill sliding window kv cache\n        self.sliding_kv_cache[0, :batch_size][\n            brange, self.sliding_kv_len[:batch_size] % self.window_size\n        ] = sliding_key\n        self.sliding_kv_cache[1, :batch_size][\n            brange, self.sliding_kv_len[:batch_size] % 
self.window_size\n        ] = sliding_value\n        self.sliding_kv_len[:batch_size] += 1\n\n\n@triton.jit\ndef _fill_kv_cache_kernel(\n    cache_ptr,\n    k_ptr,\n    v_ptr,\n    seq_start,\n    seq_end,\n    head_dim,\n    stride_c2,\n    stride_cb,\n    stride_cn,\n    stride_ch,\n    stride_cd,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n    stride_vn,\n    stride_vh,\n    stride_vd,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_D: tl.constexpr,\n):\n    # get batch id and head id\n    pid_2b = tl.program_id(0)\n    pid_2 = pid_2b % 2\n    pid_b = pid_2b // 2\n    pid_h = tl.program_id(1)\n    pid_n = tl.program_id(2)\n    # get kv start and len after rmpad\n    kv_start = tl.load(seq_start + pid_b)\n    kv_end = tl.load(seq_end + pid_b)\n    kv_len = kv_end - kv_start\n    if pid_n * BLOCK_SIZE_N >= kv_len:\n        return\n    # load key or value\n    if pid_2 == 0:\n        kv_ptrs = tl.make_block_ptr(\n            base=k_ptr + kv_start * stride_kn + pid_h * stride_kh,\n            shape=(kv_len, head_dim),\n            strides=(stride_kn, stride_kd),\n            offsets=(pid_n * BLOCK_SIZE_N, 0),\n            block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),\n            order=(1, 0),\n        )\n        kv = tl.load(kv_ptrs, boundary_check=(0, 1))\n    else:\n        kv_ptrs = tl.make_block_ptr(\n            base=v_ptr + kv_start * stride_vn + pid_h * stride_vh,\n            shape=(kv_len, head_dim),\n            strides=(stride_vn, stride_vd),\n            offsets=(pid_n * BLOCK_SIZE_N, 0),\n            block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),\n            order=(1, 0),\n        )\n        kv = tl.load(kv_ptrs, boundary_check=(0, 1))\n    # store to cache\n    cache_ptrs = tl.make_block_ptr(\n        base=cache_ptr + pid_2 * stride_c2 + pid_b * stride_cb + pid_h * stride_ch,\n        shape=(kv_len, head_dim),\n        strides=(stride_cn, stride_cd),\n        offsets=(pid_n * BLOCK_SIZE_N, 0),\n        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),\n        
order=(1, 0),\n    )\n    tl.store(cache_ptrs, kv.to(cache_ptr.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef _fill_kv_cache(\n    kv_cache: torch.Tensor,  # shape: [2, b, n, h, d]\n    key: torch.Tensor,  # shape: [l, h, d]\n    value: torch.Tensor,  # shape: [l, h, d]\n    seq_start: torch.Tensor,  # shape: [b]\n    seq_end: torch.Tensor,  # shape: [b]\n):\n    total_len, num_heads, head_dim = key.shape\n    batch_size = seq_start.shape[0]\n    max_kv_len = (seq_end - seq_start).max().item()\n    # no kv cache to fill\n    if max_kv_len == 0:\n        return\n    BLOCK_SIZE_N = min(1024, triton.next_power_of_2(max_kv_len))\n    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n    grid = (2 * batch_size, num_heads, triton.cdiv(max_kv_len, BLOCK_SIZE_N))\n    _fill_kv_cache_kernel[grid](\n        kv_cache,\n        key,\n        value,\n        seq_start,\n        seq_end,\n        head_dim,\n        kv_cache.stride(0),\n        kv_cache.stride(1),\n        kv_cache.stride(2),\n        kv_cache.stride(3),\n        kv_cache.stride(4),\n        key.stride(0),\n        key.stride(1),\n        key.stride(2),\n        value.stride(0),\n        value.stride(1),\n        value.stride(2),\n        BLOCK_SIZE_N=BLOCK_SIZE_N,\n        BLOCK_SIZE_D=BLOCK_SIZE_D,\n    )\n"
  },
  {
    "path": "native_sparse_attention/module/native_sparse_attention.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom flash_attn import flash_attn_varlen_func\nfrom native_sparse_attention.ops import (\n    compressed_attention,\n    topk_sparse_attention,\n    avgpool_compress,\n    weightedpool_compress,\n    linear_compress,\n)\nfrom einops import rearrange\nfrom native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding\nfrom native_sparse_attention.infer import nsa_infer\nfrom native_sparse_attention.module.kv_cache import NSACache\n\nCOMPRESS_TYPE_TO_FUNC = {\n    \"avgpool\": avgpool_compress,\n    \"weightedpool\": weightedpool_compress,\n    \"linear\": linear_compress,\n}\n\nCOMPRESS_TYPE_TO_WEIGHT = {\n    \"avgpool\": lambda num_heads, head_dim, kernel_size: None,\n    \"weightedpool\": lambda num_heads, head_dim, kernel_size: torch.nn.Parameter(\n        torch.zeros(num_heads, kernel_size)\n    ),\n    \"linear\": lambda num_heads, head_dim, kernel_size: torch.nn.Parameter(\n        torch.zeros(num_heads, head_dim * kernel_size, head_dim)\n    ),\n}\n\n\nclass NativeSparseAttention(torch.nn.Module):\n    \"\"\"Native sparse attention module, support training and inference\n\n    Args:\n        compress_type (str): key value compression type, currently support ['linear', 'avgpool', 'weightedpool']\n        hidden_size (int): hidden dimension\n        num_q_heads (int): number of query heads\n        num_kv_heads (int): number of 
key/value heads, num_q_heads must be divisible by num_kv_heads\n        head_dim (int): head dim\n        kernel_size (int): kernel size of compression\n        kernel_stride (int): kernel stride for compression\n        block_size (int): block size of sparse attention\n        topk (int): topk of sparse attention\n        init_blocks (int): number of blocks at the beginning of the sequence, these blocks are forced to be computed in sparse attention\n        local_blocks (int): number of blocks at the local window of each query, these blocks are forced to be computed in sparse attention\n        window_size (int): window size for sliding window attention\n        rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details\n        rope_device (str): device used to store rope freqs\n    \"\"\"\n\n    def __init__(\n        self,\n        compress_type: str,\n        hidden_size: int,\n        num_q_heads: int,\n        num_kv_heads: int,\n        head_dim: int,\n        kernel_size: int,\n        kernel_stride: int,\n        block_size: int,\n        topk: int,\n        init_blocks: int,\n        local_blocks: int,\n        window_size: int,\n        rope_config: RopeConfig,\n        rope_device: str = \"cuda\",\n    ):\n        super().__init__()\n        # configs\n        self.compress_type = compress_type\n        self.hidden_size = hidden_size\n        self.num_q_heads = num_q_heads\n        self.num_kv_heads = num_kv_heads\n        self.head_dim = head_dim\n        self.kernel_size = kernel_size\n        self.kernel_stride = kernel_stride\n        self.block_size = block_size\n        self.topk = topk\n        self.init_blocks = init_blocks\n        self.local_blocks = local_blocks\n        self.window_size = window_size\n        self.rope_config = rope_config\n        assert self.head_dim == self.rope_config.head_dim\n\n        # qkv proj and o proj\n        self.proj_q = torch.nn.Linear(\n            
self.hidden_size, self.num_q_heads * self.head_dim, bias=False\n        )\n        self.proj_k = torch.nn.Linear(\n            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False\n        )\n        self.proj_v = torch.nn.Linear(\n            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False\n        )\n        self.proj_o = torch.nn.Linear(\n            self.num_q_heads * self.head_dim, self.hidden_size, bias=False\n        )\n\n        # nsa compress func\n        self.compress_func = COMPRESS_TYPE_TO_FUNC[self.compress_type]\n\n        # nsa parameteres\n        self.compress_key = COMPRESS_TYPE_TO_WEIGHT[self.compress_type](\n            num_kv_heads, head_dim, kernel_size\n        )\n\n        self.compress_value = COMPRESS_TYPE_TO_WEIGHT[self.compress_type](\n            num_kv_heads, head_dim, kernel_size\n        )\n        self.intra_block_pe = torch.nn.Parameter(\n            torch.zeros(self.num_kv_heads, self.kernel_size, self.head_dim)\n        )\n\n        # gate function\n        self.gate = torch.nn.Sequential(\n            torch.nn.Linear(self.hidden_size, self.num_q_heads * 3, bias=False),\n            torch.nn.Sigmoid(),\n        )\n\n        # rope\n        self.rope = RotaryEmbedding(self.rope_config, device=rope_device)\n\n        # init parameters\n        self.init_params()\n\n    def init_params(self):\n        for p in self.parameters():\n            if len(p.shape) > 1:\n                torch.nn.init.xavier_uniform_(p)\n\n    def forward(\n        self,\n        x: torch.Tensor,  # shape: [total_len, hidden_size]\n        cu_seqlens: torch.Tensor,  # shape: [batch_size + 1]\n    ):\n        # dtype and shape check\n        assert x.dtype == torch.bfloat16 or x.dtype == torch.float16\n        assert x.shape[-1] == self.hidden_size\n        cu_seqlens = cu_seqlens.to(torch.int32)\n        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n\n        # qkv proj\n        q = self.proj_q(x).view(-1, self.num_q_heads, 
self.head_dim)\n        k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)\n        v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)\n\n        # compressed key and value before rope\n        compressed_k, compressed_cu_seqlens = self.compress_func(\n            k,\n            self.compress_key,\n            cu_seqlens,\n            self.kernel_size,\n            self.kernel_stride,\n            self.intra_block_pe,\n        )\n        compressed_v, _ = self.compress_func(\n            v,\n            self.compress_value,\n            cu_seqlens,\n            self.kernel_size,\n            self.kernel_stride,\n            None,\n        )\n\n        # do rope for query and compressed key\n        q = self.rope(q, cu_seqlens)\n        compressed_k = self.rope(\n            compressed_k, compressed_cu_seqlens, stride=self.kernel_stride\n        )\n\n        # attention between query and compressed key value\n        compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]\n        compressed_attn_output, topk_idx = compressed_attention(\n            q,\n            compressed_k,\n            compressed_v,\n            self.kernel_size,\n            self.kernel_stride,\n            self.block_size,\n            self.topk,\n            cu_seqlens,\n            compressed_cu_seqlens,\n            seqlens.max().item(),\n            compressed_seqlens.max().item(),\n            None,\n            self.init_blocks,\n            self.local_blocks,\n        )\n\n        # do rope for original key\n        k = self.rope(k, cu_seqlens)\n\n        # topk sparse attention\n        sparse_attn_output = topk_sparse_attention(\n            q, k, v, topk_idx, self.block_size, cu_seqlens, None\n        )\n\n        # sliding window attention\n        sliding_attn_output = flash_attn_varlen_func(\n            q,\n            k,\n            v,\n            cu_seqlens,\n            cu_seqlens,\n            seqlens.max().item(),\n            
seqlens.max().item(),\n            causal=True,\n            window_size=(self.window_size, -1),\n        )\n\n        # gate average\n        gate = self.gate(x)\n        gate = rearrange(gate, \"n (h g) -> n h g\", g=3)\n        attn_output = (\n            gate[..., 0:1] * compressed_attn_output\n            + gate[..., 1:2] * sparse_attn_output\n            + gate[..., 2:3] * sliding_attn_output\n        )\n\n        # rearrange and output proj\n        attn_output = rearrange(attn_output, \"n h d -> n (h d)\")\n        attn_output = self.proj_o(attn_output)\n\n        return attn_output\n\n    @torch.no_grad()\n    def inference(\n        self,\n        x: torch.Tensor,  # shape: [total_len, hidden_size]\n        cu_seqlens: torch.Tensor,  # shape: [batch_size + 1]\n        step: int,\n        cache: NSACache,\n    ):\n        # dtype and shape check\n        assert x.dtype == torch.bfloat16 or x.dtype == torch.float16\n        assert x.shape[-1] == self.hidden_size\n        cu_seqlens = cu_seqlens.to(torch.int32)\n        assert step >= 0\n        if step == 0:\n            assert x.shape[0] == cu_seqlens[-1]\n        else:\n            assert x.shape[0] == cu_seqlens.shape[0] - 1\n        # qkv proj\n        q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)\n        k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)\n        v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)\n        # gate proj\n        gate = self.gate(x)\n        gate = rearrange(gate, \"n (h g) -> n h g\", g=3)\n        # nsa infer\n        output = nsa_infer(\n            cu_seqlens,\n            step,\n            q,\n            k,\n            v,\n            gate,\n            self.rope,\n            cache,\n            [self.compress_key, self.compress_value],\n            [self.compress_func, self.compress_func],\n            self.intra_block_pe,\n            self.kernel_size,\n            self.kernel_stride,\n            self.block_size,\n         
   self.topk,\n            self.init_blocks,\n            self.local_blocks,\n            self.window_size,\n        )\n        # output proj\n        output = rearrange(output, \"n h d -> n (h d)\")\n        output = self.proj_o(output)\n        return output\n"
  },
  {
    "path": "native_sparse_attention/module/rope.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom dataclasses import dataclass, field\nfrom torch import nn\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS\n\n\n# default to llama3.1 rope config\n@dataclass\nclass RopeConfig:\n    \"\"\"Config for RotaryEmbedding, similar to transformers llama.\"\"\"\n\n    max_position_embeddings: int = 131072\n    head_dim: int = 128\n    rope_theta: float = 500000\n    rope_scaling: dict = field(\n        default_factory=lambda: {\n            \"factor\": 8.0,\n            \"high_freq_factor\": 4.0,\n            \"low_freq_factor\": 1.0,\n            \"original_max_position_embeddings\": 8192,\n            \"rope_type\": \"llama3\",\n        }\n    )\n    # useless, just for compatibility, please use head_dim instead\n    hidden_size: int = 1\n    num_attention_heads: int = 1\n\n    def __post_init__(self):\n        self.num_attention_heads = 1\n        self.hidden_size = self.head_dim\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef 
rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# copied and modified from huggingface transformers\n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py\nclass RotaryEmbedding(nn.Module):\n    \"\"\"Rotary embedding\n\n    Args:\n        config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details\n        device (torch.device): device for the inv_freq buffer, defaults to the current CUDA device (evaluated once, when the class is defined)\n    \"\"\"\n\n    # cos/sin tables are cached on the class itself, so they are shared by all RotaryEmbedding instances\n    cos = None\n    sin = None\n\n    def __init__(\n        self, config: RopeConfig, device=torch.device(torch.cuda.current_device())\n    ):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        if hasattr(config, \"rope_scaling\") and config.rope_scaling is not None:\n            self.rope_type = config.rope_scaling.get(\n                \"rope_type\", config.rope_scaling.get(\"type\")\n            )\n        else:\n            self.rope_type = \"default\"\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    def _dynamic_frequency_update(self, position_ids, device):\n        \"\"\"\n        dynamic RoPE layers should recompute `inv_freq` in the following situations:\n        1 - growing beyond the cached sequence length (allow scaling)\n        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)\n        \"\"\"\n        seq_len = torch.max(position_ids) + 1\n        if seq_len > 
self.max_seq_len_cached:  # growth\n            inv_freq, self.attention_scaling = self.rope_init_fn(\n                self.config, device, seq_len=seq_len\n            )\n            self.register_buffer(\n                \"inv_freq\", inv_freq, persistent=False\n            )  # TODO joao: may break with compilation\n            self.max_seq_len_cached = seq_len\n\n        if (\n            seq_len < self.original_max_seq_len\n            and self.max_seq_len_cached > self.original_max_seq_len\n        ):  # reset\n            # This .to() is needed if the model has been moved to a device after being initialized (because\n            # the buffer is automatically moved, but not the original copy)\n            self.original_inv_freq = self.original_inv_freq.to(device)\n            self.register_buffer(\"inv_freq\", self.original_inv_freq, persistent=False)\n            self.max_seq_len_cached = self.original_max_seq_len\n\n    @torch.no_grad()\n    def generate_cos_sin(self, x: torch.Tensor, position_ids):\n        if \"dynamic\" in self.rope_type:\n            self._dynamic_frequency_update(position_ids, device=x.device)\n\n        # Core RoPE block\n        inv_freq_expanded = (\n            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        )\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)\n        device_type = x.device.type\n        device_type = (\n            device_type\n            if isinstance(device_type, str) and device_type != \"mps\"\n            else \"cpu\"\n        )\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (\n                inv_freq_expanded.float() @ position_ids_expanded.float()\n            ).transpose(1, 2)\n            # # donot use this if use flash_attn\n            # emb = torch.cat((freqs, freqs), dim=-1)\n            emb = freqs\n            cos = 
emb.cos()\n            sin = emb.sin()\n\n        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention\n        cos = (cos * self.attention_scaling).to(dtype=x.dtype).squeeze(0)\n        sin = (sin * self.attention_scaling).to(dtype=x.dtype).squeeze(0)\n\n        # save cos sin\n        RotaryEmbedding.cos = torch.cat([cos, cos], dim=-1)\n        RotaryEmbedding.sin = torch.cat([sin, sin], dim=-1)\n\n        return RotaryEmbedding.cos, RotaryEmbedding.sin\n\n    @torch.no_grad()\n    def generate_pos_embs(\n        self,\n        x: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        seqlens: torch.Tensor,\n        step: int = 0,\n        stride: int = 1,\n    ):\n        if (\n            RotaryEmbedding.cos is None\n            or seqlens.max() + step > RotaryEmbedding.cos.shape[0]\n        ):\n            self.generate_cos_sin(\n                x, torch.arange(seqlens.max() + step).to(x.device)[None, :]\n            )\n\n        cos_embs = []\n        sin_embs = []\n        bsz = len(cu_seqlens) - 1\n\n        for i in range(bsz):\n            if step == 0:  # prefilling\n                r = cu_seqlens[i + 1] - cu_seqlens[i]\n                cos_emb, sin_emb = (\n                    RotaryEmbedding.cos[: r * stride : stride],\n                    RotaryEmbedding.sin[: r * stride : stride],\n                )\n            elif step > 0:  # decoding\n                r = cu_seqlens[i + 1] - cu_seqlens[i] + step - 1\n                cos_emb, sin_emb = (\n                    RotaryEmbedding.cos[r * stride : r * stride + 1],\n                    RotaryEmbedding.sin[r * stride : r * stride + 1],\n                )\n            cos_embs.append(cos_emb)\n            sin_embs.append(sin_emb)\n\n        cos_embs = torch.cat(cos_embs, dim=0)\n        sin_embs = torch.cat(sin_embs, dim=0)\n        return cos_embs, sin_embs\n\n    def forward(self, x, cu_seqlens, step=0, stride=1):\n        seqlens = cu_seqlens[1:] - 
cu_seqlens[:-1]\n        cos_embs, sin_embs = self.generate_pos_embs(\n            x,\n            cu_seqlens,\n            seqlens,\n            step=step,\n            stride=stride,\n        )\n        N, H, D = x.shape[0], x.shape[-2], x.shape[-1]  # H: number of heads\n        x = x * cos_embs.view(N, 1, D) + rotate_half(x) * sin_embs.view(N, 1, D)\n        return x\n"
  },
  {
    "path": "native_sparse_attention/module/self_attention.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom flash_attn import flash_attn_varlen_func\nfrom einops import rearrange\nfrom native_sparse_attention.module.rope import RopeConfig, RotaryEmbedding\nfrom native_sparse_attention.module.kv_cache import KVCache\nfrom native_sparse_attention.ops import flash_attention_decode\n\n\nclass SelfAttention(torch.nn.Module):\n    \"\"\"self attention module\n\n    Args:\n        hidden_size (int): hidden dimension\n        num_q_heads (int): number of query heads\n        num_kv_heads (int): number of key/value heads, num_q_heads must be divisible by num_kv_heads\n        head_dim (int): head dim, must equal rope_config.head_dim\n        rope_config (RopeConfig): config for rotary embedding, see native_sparse_attention.module.rope.RopeConfig for details\n        rope_device (str): device passed to RotaryEmbedding, defaults to 'cuda'\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size: int,\n        num_q_heads: int,\n        num_kv_heads: int,\n        head_dim: int,\n        rope_config: RopeConfig,\n        rope_device: str = \"cuda\",\n    ):\n        super().__init__()\n        # configs\n        self.hidden_size = hidden_size\n        self.num_q_heads = num_q_heads\n        self.num_kv_heads = num_kv_heads\n        self.head_dim = head_dim\n        self.rope_config = rope_config\n        assert self.head_dim == self.rope_config.head_dim\n\n        # qkv proj and o proj\n        self.proj_q = torch.nn.Linear(\n            self.hidden_size, 
self.num_q_heads * self.head_dim, bias=False\n        )\n        self.proj_k = torch.nn.Linear(\n            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False\n        )\n        self.proj_v = torch.nn.Linear(\n            self.hidden_size, self.num_kv_heads * self.head_dim, bias=False\n        )\n        self.proj_o = torch.nn.Linear(\n            self.num_q_heads * self.head_dim, self.hidden_size, bias=False\n        )\n        # rope\n        self.rope = RotaryEmbedding(self.rope_config, device=rope_device)\n\n        # init parameters\n        self.init_params()\n\n    def init_params(self):\n        for p in self.parameters():\n            torch.nn.init.xavier_uniform_(p)\n\n    def forward(\n        self,\n        x: torch.Tensor,  # shape: [total_len, hidden_size]\n        cu_seqlens: torch.Tensor,  # shape: [batch_size + 1]\n    ):\n        # dtype and shape check\n        assert x.dtype == torch.bfloat16 or x.dtype == torch.float16\n        assert x.shape[-1] == self.hidden_size\n        cu_seqlens = cu_seqlens.to(torch.int32)\n        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n\n        # qkv proj\n        q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)\n        k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)\n        v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)\n\n        # do rope for query and compressed key\n        q = self.rope(q, cu_seqlens)\n        k = self.rope(k, cu_seqlens)\n\n        # self attention\n        attn_output = flash_attn_varlen_func(\n            q,\n            k,\n            v,\n            cu_seqlens,\n            cu_seqlens,\n            seqlens.max().item(),\n            seqlens.max().item(),\n            causal=True,\n        )\n\n        # rearrange and output proj\n        attn_output = rearrange(attn_output, \"n h d -> n (h d)\")\n        attn_output = self.proj_o(attn_output)\n\n        return attn_output\n\n    @torch.no_grad()\n    def inference(\n        
self,\n        x: torch.Tensor,  # shape: [total_len, hidden_size]\n        cu_seqlens: torch.Tensor,  # shape: [batch_size + 1]\n        step: int,\n        cache: KVCache,\n    ):\n        # dtype and shape check\n        assert x.dtype == torch.bfloat16 or x.dtype == torch.float16\n        assert x.shape[-1] == self.hidden_size\n        cu_seqlens = cu_seqlens.to(torch.int32)\n        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n        assert step >= 0\n        if step == 0:\n            assert x.shape[0] == cu_seqlens[-1]\n        else:\n            assert x.shape[0] == cu_seqlens.shape[0] - 1\n        batch_size = cu_seqlens.shape[0] - 1\n        # qkv proj\n        q = self.proj_q(x).view(-1, self.num_q_heads, self.head_dim)\n        k = self.proj_k(x).view(-1, self.num_kv_heads, self.head_dim)\n        v = self.proj_v(x).view(-1, self.num_kv_heads, self.head_dim)\n        # do rope for query and compressed key\n        q = self.rope(q, cu_seqlens, step)\n        k = self.rope(k, cu_seqlens, step)\n        # reset and update kv cache\n        if step == 0:\n            cache.reset()\n        cache.update_kv(cu_seqlens, step, k, v)\n        # self attention\n        if step == 0:\n            cu_seqlens_q = cu_seqlens_k = cu_seqlens\n            max_seqlen_in_batch_q = max_seqlen_in_batch_k = seqlens.max().item()\n            output = flash_attn_varlen_func(\n                q,\n                k,\n                v,\n                cu_seqlens_q=cu_seqlens_q,\n                cu_seqlens_k=cu_seqlens_k,\n                max_seqlen_q=max_seqlen_in_batch_q,\n                max_seqlen_k=max_seqlen_in_batch_k,\n                causal=True,\n            )\n        else:\n            output = flash_attention_decode(\n                q,\n                cache.kv_cache[0, :batch_size],\n                cache.kv_cache[1, :batch_size],\n                cache.kv_len[:batch_size],\n            )\n        # rearrange and output proj\n        output = rearrange(output, \"n 
h d -> n (h d)\")\n        output = self.proj_o(output)\n        return output\n"
  },
  {
    "path": "native_sparse_attention/ops/README.md",
    "content": "# Triton Functions for Native Sparse Attention\n\nThis folder provides efficient Triton-based implementations of components for Native Sparse Attention. This README introduces the available functions, explains how to set them up, and offers guidance on their usage.\n\n---\n\n## Overview of Functions\n\nThe functions are organized into two main categories:\n\n1. **Compression Methods**: Techniques for compressing key and value tensors.\n2. **Attention Mechanisms**: Methods for computing attention between queries and compressed key/value tensors, including top-k sparse attention.\n\n---\n\n## Function Descriptions\n\n### Compression Methods\n\nThese functions compress key and value tensors using a sliding window approach. Within each window, `kernel_size` tokens are compressed into a single token, with a stride of `kernel_stride`. All compression functions share similar input parameters and return formats.\n\n**Parameters:**\n- `x`: Input tensor (`total_len, num_heads, head_dim`)\n- `w`: Weight tensor (shape varies by compression method)\n- `cu_seqlens`: Cumulative sequence lengths (`batch_size + 1`)\n- `kernel_size`: Size of the compression window\n- `kernel_stride`: Stride of the compression window\n- `pe`: Optional positional embedding (`num_heads, kernel_size, head_dim`)\n\n**Returns:**\n- Compressed tensor (`total_compress_len, num_heads, head_dim`)\n- Cumulative sequence lengths (`com_cu_seqlens`) for the compressed tensor\n\n#### `weightedpool_compress`\nCompresses the input tensor using weighted pooling, applying a weighted sum over each block:  \n$\\hat{k} = w_1 k_1 + \\dots + w_m k_m$  \n- **Weight shape**: `(num_heads, kernel_size)`\n\n#### `avgpool_compress`\nCompresses the input tensor using average pooling:  \n$\\hat{k} = (k_1 + \\dots + k_m) / m$  \n- **Weight**: Must be `None`\n\n#### `linear_compress`\nCompresses the input tensor via linear projection, mapping each block to a single vector using learned weights:  \n$\\hat{k} = 
\\text{cat}(k_1, \\dots, k_m) W$  \n- **Weight shape**: `(num_heads, kernel_size * head_dim, head_dim)`\n\n---\n\n### Attention Mechanisms\n\nThese functions compute attention using either full or sparse mechanisms.\n\n#### `flash_attention_varlen`\nA variable-length implementation of flash attention, similar to `flash_attn_varlen_func` from the `flash_attn` package.\n\n**Parameters:**\n- `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`)\n- `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for queries and keys\n- `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths in the batch\n- `causal`: Apply causal masking (default: `False`)\n- `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)\n\n**Returns:**\n- Attention output tensor (`total_q_len, num_q_heads, head_dim`)\n\n#### `compressed_attention`\nComputes attention between a query and compressed key/value tensors, identifying the top-k blocks for sparse attention.\n\n**Parameters:**\n- `q`: Query tensor (`total_len, num_heads, head_dim`)\n- `k`, `v`: Compressed key and value tensors (`total_compress_len, num_heads, head_dim`)\n- `kernel_size`, `kernel_stride`: Compression parameters\n- `block_size`: Size of blocks for sparse attention\n- `topk`: Number of top blocks to select\n- `cu_seqlens_q`, `cu_seqlens_k`: Cumulative sequence lengths for query and compressed key/value\n- `max_seqlen_q`, `max_seqlen_k`: Maximum sequence lengths for query and compressed key/value\n- `sm_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)\n- `init_blocks`: Number of initial blocks forced to be selected (default: `1`)\n- `local_blocks`: Number of local blocks forced to be selected (default: `2`)\n\n**Returns:**\n- Tuple containing:\n  - Attention output tensor\n  - Top-k block indices\n\n#### `topk_sparse_attention`\nPerforms sparse attention using precomputed top-k block indices. 
If a query attends to fewer than `topk` key/value blocks, the `topk_idx` should be padded with `-1` on the right.\n\n**Parameters:**\n- `q`, `k`, `v`: Query, key, and value tensors (`total_len, num_heads, head_dim`)\n- `topk_idx`: Precomputed top-k indices (`num_kv_heads, total_len, topk`)\n- `block_size`: Block size for sparse attention (recommended: `64` or `128`)\n- `cu_seqlens`: Cumulative sequence lengths\n- `softmax_scale`: Softmax scale (default: `1 / sqrt(head_dim)`)\n\n**Returns:**\n- Attention output tensor (`total_len, num_q_heads, head_dim`)\n\n---\n\n## Usage Example\n\nBelow is a typical workflow demonstrating how to combine these sparse attention functions:\n\n```python\nimport torch\nfrom native_sparse_attention.ops import linear_compress, compressed_attention, topk_sparse_attention\n\n# Example input setup\nnum_q_heads = 64\nnum_kv_heads = 4\nhead_dim = 128\ncu_seqlens = torch.tensor([0, 1024, 8192, 16384], dtype=torch.int32).cuda()\n\n# Query, key, and value tensors\nquery = torch.randn(16384, num_q_heads, head_dim, dtype=torch.bfloat16).cuda()\nkey = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda()\nvalue = torch.randn(16384, num_kv_heads, head_dim, dtype=torch.bfloat16).cuda()\n\n# Compression weights and positional embeddings\nkernel_size = 32\nkernel_stride = 16\nwk = torch.randn(num_kv_heads, kernel_size * head_dim, head_dim, dtype=torch.bfloat16).cuda()\nwv = torch.randn_like(wk)\npe = torch.randn(num_kv_heads, kernel_size, head_dim, dtype=torch.bfloat16).cuda()\n\n# Parameters for top-k sparse attention\nblock_size = 64\ntopk = 16\n\n# 1. Compress key and value tensors\ncompressed_key, compressed_cu_seqlens = linear_compress(\n    key, wk, cu_seqlens, kernel_size, kernel_stride, pe\n)\ncompressed_value, _ = linear_compress(\n    value, wv, cu_seqlens, kernel_size, kernel_stride, None\n)\n\n# 2. 
Compute attention with compressed key/value and get top-k indices\ncompressed_attn_output, topk_idx = compressed_attention(\n    query,\n    compressed_key,\n    compressed_value,\n    kernel_size,\n    kernel_stride,\n    block_size,\n    topk,\n    cu_seqlens,\n    compressed_cu_seqlens,\n    init_blocks=1,\n    local_blocks=2,\n)\n\n# 3. Perform top-k sparse attention\nsparse_attn_output = topk_sparse_attention(\n    query,\n    key,\n    value,\n    topk_idx,\n    block_size,\n    cu_seqlens,\n)\n\n# 4. Combine attention outputs (e.g., average)\nattn_output = (compressed_attn_output + sparse_attn_output) / 2\n```\n\nFor a complete implementation of the Native Sparse Attention module, see `native_sparse_attention/module/native_sparse_attention.py`.\n"
  },
  {
    "path": "native_sparse_attention/ops/__init__.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# compress method\nfrom native_sparse_attention.ops.triton.weighted_pool import (\n    weightedpool_compress,\n    avgpool_compress,\n)\nfrom native_sparse_attention.ops.triton.linear_compress import linear_compress\n\n# prefill attention\nfrom native_sparse_attention.ops.triton.flash_attention import flash_attention_varlen\nfrom native_sparse_attention.ops.triton.compressed_attention import compressed_attention\nfrom native_sparse_attention.ops.triton.topk_sparse_attention import (\n    topk_sparse_attention,\n)\n\n# decode attention\nfrom native_sparse_attention.ops.triton.flash_attention_decode import (\n    flash_attention_decode,\n)\nfrom native_sparse_attention.ops.torch.compressed_attention_decode import (\n    compressed_attention_decode,\n)\nfrom native_sparse_attention.ops.triton.topk_sparse_attention_decode import (\n    topk_sparse_attention_decode,\n)\n\n__all__ = [\n    # compress method\n    \"avgpool_compress\",\n    \"weightedpool_compress\",\n    \"linear_compress\",\n    # prefill attention, trainable\n    \"flash_attention_varlen\",\n    \"compressed_attention\",\n    \"topk_sparse_attention\",\n    # decode attention, no grad\n    \"flash_attention_decode\",\n    \"compressed_attention_decode\",\n    \"topk_sparse_attention_decode\",\n]\n"
  },
  {
    "path": "native_sparse_attention/ops/torch/__init__.py",
    "content": ""
  },
  {
    "path": "native_sparse_attention/ops/torch/compress_key_value.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom typing import Optional\nfrom einops import rearrange, einsum\n\n\ndef avgpool_compress_torch(\n    x: torch.Tensor,\n    w: torch.Tensor,\n    cu_seqlens,\n    kernel_size: int,\n    kernel_stride: int,\n    pe: Optional[torch.Tensor] = None,\n):\n    \"\"\"Compress key and value tensor with kernel_size and kernel_stride.\n\n    Args:\n        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)\n        w (torch.Tensor): no weight for avgpool, must be None.\n        cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.\n        kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)\n        kernel_stride (int): stride for each compress kernel\n        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). 
Defaults to None.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.\n    \"\"\"\n    # dtype check\n    assert x.dtype == torch.float16 or x.dtype == torch.bfloat16\n    assert cu_seqlens.dtype == torch.int32\n    assert x.dtype == pe.dtype if pe is not None else True\n\n    # shape check\n    total_len, num_heads, head_dim = x.shape\n    batch_size = cu_seqlens.shape[0] - 1\n    assert w is None, \"don't need additional weight for avgpool\"\n    assert kernel_size % kernel_stride == 0\n    assert kernel_size in {16, 32, 64, 128}\n\n    # compute seqlens after compression\n    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n    y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1\n    # corner case, if sequence_length < kernel_size, no compression for this sequence\n    y_seqlens[seqlens < kernel_size] = 0\n    y_cu_seqlens = torch.cat(\n        [\n            torch.zeros(1, dtype=torch.int32, device=\"cuda\"),\n            torch.cumsum(y_seqlens, dim=0),\n        ],\n        dim=0,\n    ).to(torch.int32)\n\n    # pad and rearrange x\n    x = rearrange(x, \"n h d -> n (h d)\")\n    splited_x = torch.split(x, seqlens.tolist(), 0)\n    x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True)\n    x = rearrange(x, \"b n d -> b d n\")\n    # avgpool\n    y = torch.nn.functional.avg_pool1d(x, kernel_size=kernel_size, stride=kernel_stride)\n    y = rearrange(y, \"b (h d) n -> b n h d\", h=num_heads)\n    # only keep useful part\n    y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0)\n\n    # position embedding as a bias\n    if pe is not None:\n        bias = torch.mean(pe, dim=1)\n        y = y + bias.unsqueeze(0)\n\n    return y, y_cu_seqlens\n\n\ndef weightedpool_compress_torch(\n    x: torch.Tensor,\n    w: torch.Tensor,  # [num_heads, kernel_size]\n    cu_seqlens,\n    kernel_size: int,\n    kernel_stride: int,\n    pe: Optional[torch.Tensor] = 
None,\n):\n    \"\"\"Compress key and value tensor with kernel_size and kernel_stride.\n\n    Args:\n        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)\n        w (torch.Tensor): weight for each head, shape (num_heads, kernel_size)\n        cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.\n        kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)\n        kernel_stride (int): stride for each compress kernel\n        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.\n    \"\"\"\n    # dtype check\n    assert x.dtype == torch.float16 or x.dtype == torch.bfloat16\n    assert x.dtype == w.dtype\n    assert x.dtype == pe.dtype if pe is not None else True\n    assert cu_seqlens.dtype == torch.int32\n    # shape check\n    total_len, num_heads, head_dim = x.shape\n    batch_size = cu_seqlens.shape[0] - 1\n    assert w.shape[0] == num_heads\n    assert w.shape[1] == kernel_size\n    assert kernel_size % kernel_stride == 0\n    assert kernel_size in {16, 32, 64, 128}\n    # compute seqlens after compression\n    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n    y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1\n    # corner case, if sequence_length < kernel_size, no compression for this sequence\n    y_seqlens[seqlens < kernel_size] = 0\n    y_cu_seqlens = torch.cat(\n        [\n            torch.zeros(1, dtype=torch.int32, device=\"cuda\"),\n            torch.cumsum(y_seqlens, dim=0),\n        ],\n        dim=0,\n    ).to(torch.int32)\n    # pad and rearrange x\n    x = rearrange(x, \"n h d -> n (h d)\")\n    splited_x = torch.split(x, seqlens.tolist(), 0)\n    x = torch.nn.utils.rnn.pad_sequence(splited_x, 
batch_first=True)\n    x = rearrange(x, \"b n (h d) -> b h n d\", h=num_heads)\n    x = x.as_strided(\n        size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim),\n        stride=(\n            x.stride(0),\n            x.stride(1),\n            kernel_stride * x.stride(2),\n            x.stride(2),\n            x.stride(3),\n        ),\n    )\n    y = einsum(x, w, \"b h n k d, h k -> b n h d\")\n    # only keep useful part\n    y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0)\n\n    # position embedding as a bias\n    if pe is not None:\n        bias = einsum(pe, w, \"h k d, h k -> h d\")\n        y = y + bias.unsqueeze(0)\n\n    return y, y_cu_seqlens\n\n\ndef linear_compress_torch(\n    x: torch.Tensor,\n    w: torch.Tensor,  # [num_heads, kernel_size * head_dim, head_dim]\n    cu_seqlens,\n    kernel_size: int,\n    kernel_stride: int,\n    pe: Optional[torch.Tensor] = None,\n):\n    \"\"\"Compress key and value tensor with kernel_size and kernel_stride. Similar to conv_compress.\n\n    Args:\n        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)\n        w (torch.Tensor): weight for each head, shape (num_heads, kernel_size * head_dim, head_dim)\n        cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.\n        kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)\n        kernel_stride (int): stride for each compress kernel\n        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). 
Defaults to None.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.\n    \"\"\"\n    # dtype check\n    assert x.dtype == torch.float16 or x.dtype == torch.bfloat16\n    assert x.dtype == w.dtype\n    assert x.dtype == pe.dtype if pe is not None else True\n    assert cu_seqlens.dtype == torch.int32\n    # shape check\n    total_len, num_heads, head_dim = x.shape\n    batch_size = cu_seqlens.shape[0] - 1\n    assert w.shape[0] == num_heads\n    assert w.shape[1] == kernel_size * head_dim\n    assert w.shape[2] == head_dim\n    assert kernel_size % kernel_stride == 0\n    assert kernel_size in {16, 32, 64, 128}\n    # compute seqlens after compression\n    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n    y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1\n    # corner case, if sequence_length < kernel_size, no compression for this sequence\n    y_seqlens[seqlens < kernel_size] = 0\n    y_cu_seqlens = torch.cat(\n        [\n            torch.zeros(1, dtype=torch.int32, device=\"cuda\"),\n            torch.cumsum(y_seqlens, dim=0),\n        ],\n        dim=0,\n    ).to(torch.int32)\n    # pad and rearrange x\n    x = rearrange(x, \"n h d -> n (h d)\")\n    splited_x = torch.split(x, seqlens.tolist(), 0)\n    x = torch.nn.utils.rnn.pad_sequence(splited_x, batch_first=True)\n    x = rearrange(x, \"b n (h d) -> b h n d\", h=num_heads)\n    x = x.as_strided(\n        size=(batch_size, num_heads, y_seqlens.max().item(), kernel_size, head_dim),\n        stride=(\n            x.stride(0),\n            x.stride(1),\n            kernel_stride * x.stride(2),\n            x.stride(2),\n            x.stride(3),\n        ),\n    )\n    y = einsum(\n        x,\n        rearrange(w, \"h (k d) D -> h k d D\", k=kernel_size),\n        \"b h n k d, h k d D -> b n h D\",\n    )\n    # only keep useful part\n    y = torch.cat([y[i, : y_seqlens[i]] for i in range(batch_size)], dim=0)\n\n    # 
position embedding as a bias\n    if pe is not None:\n        pe = rearrange(pe, \"h k d -> h (k d)\")\n        bias = einsum(pe, w, \"h D, h D d -> h d\")\n        y = y + bias.unsqueeze(0)\n\n    return y, y_cu_seqlens\n"
  },
  {
    "path": "native_sparse_attention/ops/torch/compressed_attention.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nimport math\nfrom typing import Tuple\nfrom collections import Counter\nfrom einops import rearrange\n\n\ndef transform_score(\n    score: torch.Tensor,\n    kernel_size: int,\n    kernel_stride: int,\n    block_size: int,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    init_blocks: int = 1,\n    local_blocks: int = 2,\n) -> torch.Tensor:\n    num_k_heads, total_query_len, _ = score.shape\n    pad_len = kernel_size // kernel_stride - 1\n    score = torch.nn.functional.pad(score, (pad_len, pad_len), value=0)\n    max_blocks = math.ceil(max_seqlen_q / block_size)\n    full_blocks = max_seqlen_q // block_size\n    block_score = torch.zeros(\n        num_k_heads,\n        total_query_len,\n        max_blocks,\n        dtype=torch.float32,\n        device=score.device,\n    )\n    offs = (\n        torch.arange(kernel_size // kernel_stride)[:, None]\n        + torch.arange(block_size // kernel_stride)[None, :]\n    ).view(-1)\n    offs = dict(Counter(offs.tolist()))\n    for k, v in offs.items():\n        block_score[..., :full_blocks] += (\n            v * score[..., k :: block_size // kernel_stride][..., :full_blocks]\n        )\n    # set init block and local block score\n    batch_size = cu_seqlens_q.shape[0] - 1\n    q_idx = torch.cat(\n        [\n            
torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=score.device)\n            for i in range(batch_size)\n        ],\n        dim=0,\n    )\n    q_idx = q_idx // block_size\n    b_idx = torch.arange(max_blocks, device=score.device)\n    block_score[..., :init_blocks] = torch.inf\n    local_mask = (q_idx[:, None] >= b_idx[None, :]) & (\n        q_idx[:, None] < b_idx[None, :] + local_blocks\n    )\n    local_mask = local_mask.unsqueeze(0).expand(num_k_heads, -1, -1)\n    block_score[local_mask] = torch.inf\n    return block_score\n\n\ndef compressed_attention_torch(\n    q: torch.Tensor,  # [total_query_len, num_q_heads, head_dim]\n    k: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]\n    v: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]\n    kernel_size: int,\n    kernel_stride: int,\n    block_size: int,\n    topk: int,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    sm_scale: float = None,\n    init_blocks: int = 1,\n    local_blocks: int = 2,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Attention between query and compressed key and value. 
Implemented with torch, only for debug.\n\n    Args:\n        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]\n        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]\n        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]\n        kernel_size (int): kernel size in compress_key_value\n        kernel_stride (int): stride of compress_key_value\n        block_size (int): key value block size for topk sparse attention.\n        topk (int): number of blocks for each query.\n        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.\n        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.\n        max_seqlen_q (int): max q len of the batch.\n        max_seqlen_k (int): max k len of the batch.\n        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).\n        init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.\n        local_blocks (int, optional): Number of local blocks for each query. 
Defaults to 2.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention\n    \"\"\"\n    assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0\n    total_query_len, num_q_heads, head_dim = q.shape\n    total_key_len, num_k_heads, _ = k.shape\n    num_share_q_heads = num_q_heads // num_k_heads\n    batch_size = cu_seqlens_q.shape[0] - 1\n    if sm_scale is None:\n        sm_scale = 1.0 / math.sqrt(head_dim)\n    # get mask\n    mask = torch.zeros(\n        (total_query_len, total_key_len), dtype=torch.bool, device=q.device\n    )\n    for b in range(batch_size):\n        q_len, k_len = (\n            cu_seqlens_q[b + 1] - cu_seqlens_q[b],\n            cu_seqlens_k[b + 1] - cu_seqlens_k[b],\n        )\n        k_max_ids = (\n            torch.arange(k_len, device=q.device) * kernel_stride + kernel_size - 1\n        )\n        q_ids = torch.arange(q_len, device=q.device)\n        mask[\n            cu_seqlens_q[b] : cu_seqlens_q[b + 1], cu_seqlens_k[b] : cu_seqlens_k[b + 1]\n        ] = (q_ids[:, None] >= k_max_ids[None, :])\n    # attention\n    qk = (\n        torch.einsum(\"qhd,khd->hqk\", q, k.repeat_interleave(num_share_q_heads, 1))\n        * sm_scale\n    )\n    qk = qk.masked_fill_(~mask[None, ...], -torch.inf)\n    # query from beginning of the sequence can't attend to any compressed key\n    qk = qk.softmax(dim=-1, dtype=torch.float32)\n    qk = qk.nan_to_num(0)\n    attn_output = torch.einsum(\n        \"hqk,khd->qhd\", qk.to(v.dtype), v.repeat_interleave(num_share_q_heads, 1)\n    )\n    with torch.no_grad():\n        # get avg score over gqa heads\n        # qk shape [num_k_heads, total_q_len, total_k_len]\n        score = torch.zeros(\n            num_k_heads,\n            cu_seqlens_q[-1],\n            max_seqlen_k,\n            dtype=torch.float32,\n            device=q.device,\n        )\n        qk = rearrange(qk, \"(h g) q k -> h g q k\", h=num_k_heads).sum(1)\n       
 for b in range(batch_size):\n            score[\n                :,\n                cu_seqlens_q[b] : cu_seqlens_q[b + 1],\n                : cu_seqlens_k[b + 1] - cu_seqlens_k[b],\n            ] = qk[\n                :,\n                cu_seqlens_q[b] : cu_seqlens_q[b + 1],\n                cu_seqlens_k[b] : cu_seqlens_k[b + 1],\n            ]\n        # transform score to block-wise score\n        score = transform_score(\n            score,\n            kernel_size,\n            kernel_stride,\n            block_size,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            init_blocks,\n            local_blocks,\n        )\n        # get topk\n        batch_size = cu_seqlens_q.shape[0] - 1\n        q_idx = torch.cat(\n            [\n                torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device)\n                for i in range(batch_size)\n            ],\n            dim=0,\n        )\n        q_idx = q_idx // block_size\n        topk = min(topk, score.shape[-1])\n        topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values\n        topk_idx[topk_idx > q_idx[None, :, None]] = -1\n        topk_idx = topk_idx.to(torch.int32)\n    return attn_output, topk_idx\n"
  },
  {
    "path": "native_sparse_attention/ops/torch/compressed_attention_decode.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nimport math\nfrom typing import Tuple, Optional\nfrom collections import Counter\nfrom einops import rearrange\n\n\ndef transform_score(\n    score: torch.Tensor,\n    seqlens: torch.Tensor,\n    kernel_size: int,\n    kernel_stride: int,\n    block_size: int,\n    init_blocks: int = 1,\n    local_blocks: int = 2,\n) -> torch.Tensor:\n    num_k_heads, batch_size, kv_len = score.shape\n    pad_len = kernel_size // kernel_stride - 1\n    score = torch.nn.functional.pad(score, (pad_len, pad_len), value=0)\n    max_seqlen = seqlens.max().item()\n    max_blocks = math.ceil(max_seqlen / block_size)\n    full_blocks = max_seqlen // block_size\n    block_score = torch.zeros(\n        num_k_heads,\n        batch_size,\n        max_blocks,\n        dtype=torch.float32,\n        device=score.device,\n    )\n    offs = (\n        torch.arange(kernel_size // kernel_stride)[:, None]\n        + torch.arange(block_size // kernel_stride)[None, :]\n    ).view(-1)\n    offs = dict(Counter(offs.tolist()))\n    for k, v in offs.items():\n        block_score[..., :full_blocks] += (\n            v * score[..., k :: block_size // kernel_stride][..., :full_blocks]\n        )\n    # set init block and local block score\n    q_idx = (seqlens - 1) // block_size\n    b_idx = torch.arange(max_blocks, device=score.device)\n    block_score[..., :init_blocks] = torch.inf\n  
  local_mask = (q_idx[:, None] >= b_idx[None, :]) & (\n        q_idx[:, None] < b_idx[None, :] + local_blocks\n    )\n    local_mask = local_mask.unsqueeze(0).expand(num_k_heads, -1, -1)\n    block_score[local_mask] = torch.inf\n    block_score = block_score.nan_to_num(0, torch.inf, -torch.inf)\n    return block_score\n\n\ndef compressed_attention_decode(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    seqlens: torch.Tensor,\n    compress_seqlens: torch.Tensor,\n    kernel_size: int,\n    kernel_stride: int,\n    block_size: int,\n    topk: int,\n    init_blocks: int = 1,\n    local_blocks: int = 2,\n    sm_scale: Optional[float] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"_summary_\n\n    Args:\n        q (torch.Tensor): shape [batch_size, num_q_heads, head_dim]\n        k (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]\n        v (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]\n        seqlens (torch.Tensor): original kv length for each sequence\n        compress_seqlens (torch.Tensor): kv length for each sequence after compression\n        kernel_size (int): kernel size in compress_key_value\n        kernel_stride (int): stride of compress_key_value\n        block_size (int): key value block size for topk sparse attention.\n        topk (int): number of blocks for each query.\n        init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.\n        local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.\n        sm_scale (float, optional): softmax scale. 
Defaults to None, means 1/sqrt(head_dim).\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention_decode\n    \"\"\"\n    assert block_size % kernel_size == 0 and kernel_size % kernel_stride == 0\n    batch_size, num_q_heads, head_dim = q.shape\n    batch_size, kv_len, num_k_heads, _ = k.shape\n    num_share_q_heads = num_q_heads // num_k_heads\n    if sm_scale is None:\n        sm_scale = 1.0 / math.sqrt(head_dim)\n    # input is too short to have a valid block\n    if kv_len == 0:\n        return torch.zeros_like(q), torch.zeros(\n            num_k_heads, batch_size, 1, device=q.device, dtype=torch.int32\n        )\n    # get mask\n    mask = (\n        compress_seqlens[:, None]\n        > torch.arange(\n            kv_len, device=compress_seqlens.device, dtype=compress_seqlens.dtype\n        )[None, :]\n    )\n    # attention\n    qk = (\n        torch.einsum(\n            \"bihgd, bjhgd -> bhgij\",\n            rearrange(q, \"b (h g) d -> b 1 h g d\", g=num_share_q_heads),\n            rearrange(k, \"b j h d -> b j h 1 d\"),\n        )\n        * sm_scale\n    )\n    qk = qk.masked_fill_(~mask[:, None, None, None, :], -torch.inf)\n    qk = qk.softmax(dim=-1, dtype=torch.float32)\n    qk = qk.nan_to_num_(0)  # qk is nan when seqlen == 0\n    attn_output = torch.einsum(\n        \"bhgij, bjhgd -> bihgd\",\n        qk.to(v.dtype),\n        rearrange(v, \"b k h d -> b k h 1 d\"),\n    )\n    attn_output = rearrange(attn_output, \"b 1 h g d -> b (h g) d\")\n\n    # get score\n    score = rearrange(qk.sum(2).squeeze(2), \"b h j -> h b j\")\n    # transform score to block-wise score\n    score = transform_score(\n        score,\n        seqlens,\n        kernel_size,\n        kernel_stride,\n        block_size,\n        init_blocks,\n        local_blocks,\n    )\n    # get topk\n    q_idx = (seqlens - 1) // block_size\n    topk = min(topk, score.shape[-1])\n    topk_idx = score.topk(topk, 
dim=-1).indices.sort(-1).values\n    topk_idx[topk_idx > q_idx[None, :, None]] = -1\n    topk_idx = topk_idx.to(torch.int32)\n    return attn_output, topk_idx\n"
  },
  {
    "path": "native_sparse_attention/ops/torch/topk_sparse_attention.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nimport math\nfrom typing import Optional\n\n\ndef topk_sparse_attention_torch(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    topk_idx: torch.Tensor,\n    block_size_k: int,\n    cu_seqlens: torch.Tensor,\n    softmax_scale: Optional[float] = None,\n    block_size_q: int = 1,\n) -> torch.Tensor:\n    \"\"\"Simple topk sparse attention varlen version implemented in torch. Extremly slow, only for debugging.\n\n    Args:\n        q (torch.Tensor): shape [total_len, num_q_heads, head_dim]\n        k (torch.Tensor): shape [total_len, num_kv_heads, head_dim]\n        v (torch.Tensor): shape [total_len, num_kv_heads, head_dim]\n        topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. 
-1 means padding.\n        block_size_q (int): query block size.\n        block_size_k (int): key value block size.\n        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen.\n        softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim).\n\n    Returns:\n        torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim]\n    \"\"\"\n    total_seqlen, num_q_heads, head_dim = q.shape\n    total_seqlen, num_kv_heads, head_dim = k.shape\n    num_share_q_heads = num_q_heads // num_kv_heads\n    batch_size = cu_seqlens.shape[0] - 1\n    topk = topk_idx.shape[-1]\n    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n    seqblocks_q = torch.ceil(seqlens / block_size_q).to(torch.int32)\n    cu_seqblocks_q = torch.nn.functional.pad(seqblocks_q.cumsum(0), (1, 0), value=0)\n    if softmax_scale is None:\n        softmax_scale = 1.0 / math.sqrt(head_dim)\n    # get mask\n    mask = torch.zeros(\n        (num_kv_heads, total_seqlen, total_seqlen), dtype=torch.bool, device=q.device\n    )\n    for i in range(batch_size):\n        num_q_blocks = math.ceil(seqlens[i] / block_size_q)\n        num_kv_blocks = math.ceil(seqlens[i] / block_size_k)\n        for h in range(num_kv_heads):\n            temp_mask = torch.zeros(\n                num_q_blocks, num_kv_blocks, dtype=torch.bool, device=q.device\n            )\n            temp_idx = topk_idx[h, cu_seqblocks_q[i] : cu_seqblocks_q[i + 1]].clone()\n            temp_idx[temp_idx < 0] = 0\n            temp_mask[torch.arange(num_q_blocks).to(q.device)[:, None], temp_idx] = True\n            temp_mask = torch.repeat_interleave(temp_mask, block_size_q, dim=0)\n            temp_mask = torch.repeat_interleave(temp_mask, block_size_k, dim=1)\n            temp_mask = temp_mask[: seqlens[i], : seqlens[i]]\n            mask[\n                h, cu_seqlens[i] : cu_seqlens[i + 1], cu_seqlens[i] : cu_seqlens[i + 1]\n            ] = temp_mask\n    mask = 
torch.tril(mask).repeat_interleave(num_share_q_heads, 0)\n    # qk attn\n    qk = (\n        torch.einsum(\"qhd,khd->hqk\", q, k.repeat_interleave(num_share_q_heads, 1))\n        * softmax_scale\n    )\n    qk = torch.masked_fill(qk, ~mask, -torch.inf)\n    qk = torch.softmax(qk, dim=-1, dtype=torch.float32).to(q.dtype)\n    o = torch.einsum(\"hqk,khd->qhd\", qk, v.repeat_interleave(num_share_q_heads, 1))\n    return o\n"
  },
  {
    "path": "native_sparse_attention/ops/triton/__init__.py",
    "content": ""
  },
  {
    "path": "native_sparse_attention/ops/triton/compressed_attention.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport math\nfrom typing import Any, Tuple, Union\nfrom collections import Counter\nimport torch\nimport triton\nimport triton.language as tl\nimport warnings\nfrom native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu\n\n\nIS_HOPPER_GPU = is_hopper_gpu()\n\n\n@triton.jit\ndef forward_kernel(\n    q_ptr,  # Q: n x h x d\n    k_ptr,  # K: n x h x d\n    v_ptr,  # V: n x h x d\n    o_ptr,  # O: n x h x d\n    lse_ptr,  # LSE: h x n\n    # size and stride at compresstion\n    kernel_size,\n    kernel_stride,\n    # seqlens\n    cu_seqlens_q,\n    cu_seqlens_k,\n    # shape\n    NUM_KV_HEADS,\n    NUM_SHARE_Q_HEADS,\n    HEAD_DIM,\n    # sm_scale\n    sm_scale,\n    # stride\n    stride_qn,\n    stride_qh,\n    stride_qd,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n    stride_vn,\n    stride_vh,\n    stride_vd,\n    stride_on,\n    stride_oh,\n    stride_od,\n    stride_lh,\n    stride_ln,\n    # META parameters\n    BLOCK_SIZE_Q: tl.constexpr,  # q block size\n    BLOCK_SIZE_K: tl.constexpr,  # k block size\n    BLOCK_SIZE_D: tl.constexpr,\n):\n    qk_scale = sm_scale * 1.44269504\n    # get batch id and head id\n    pid_b = tl.program_id(0)\n    pid_h = tl.program_id(1)\n    pid_q = tl.program_id(2)\n    pid_kh = pid_h // NUM_SHARE_Q_HEADS\n    # get q k start and len after rmpad\n    q_start = tl.load(cu_seqlens_q + 
pid_b)\n    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start\n    k_start = tl.load(cu_seqlens_k + pid_b)\n    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start\n    # skip first kernel_size query block, because they do no attend to any keys\n    q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1\n    if q_start_in_seq >= q_len:\n        return\n    # init qkv pointer\n    q_ptrs = tl.make_block_ptr(\n        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,\n        shape=(q_len, HEAD_DIM),\n        strides=(stride_qn, stride_qd),\n        offsets=(q_start_in_seq, 0),\n        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    k_ptrs = tl.make_block_ptr(\n        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,\n        shape=(HEAD_DIM, k_len),\n        strides=(stride_kd, stride_kn),\n        offsets=(0, 0),\n        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),\n        order=(0, 1),\n    )\n    v_ptrs = tl.make_block_ptr(\n        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,\n        shape=(k_len, HEAD_DIM),\n        strides=(stride_vn, stride_vd),\n        offsets=(0, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    # load q\n    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # init statistics\n    off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq\n    off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1\n    m_i = tl.full((BLOCK_SIZE_Q,), float(\"-inf\"), dtype=tl.float32)\n    lse_i = tl.full((BLOCK_SIZE_Q,), float(\"-inf\"), dtype=tl.float32)\n    acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32)\n    # attention\n    lo = 0\n    hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)\n    for i in range(lo, hi, BLOCK_SIZE_K):\n        i = tl.multiple_of(i, BLOCK_SIZE_K)\n        # load k\n        k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option=\"zero\")\n        # 
compute qk\n        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)\n        qk += tl.where(\n            off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float(\"-inf\")\n        )\n        qk += tl.dot(q, k) * qk_scale\n        # compute m_ij and l_ij\n        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))\n        p = tl.exp2(qk - m_ij[:, None])\n        l_ij = tl.sum(p, axis=1)\n        # scale acc_o\n        acc_o_scale = tl.exp2(m_i - m_ij)\n        acc_o = acc_o * acc_o_scale[:, None]\n        # load v and update acc_o\n        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        p = p.to(v.dtype)\n        acc_o += tl.dot(p, v)\n        # update statistics\n        m_i = m_ij\n        lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)\n        # update ptrs\n        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))\n        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))\n    # final scale\n    acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]\n    # save output\n    o_ptrs = tl.make_block_ptr(\n        base=o_ptr + q_start * stride_on + pid_h * stride_oh,\n        shape=(q_len, HEAD_DIM),\n        strides=(stride_on, stride_od),\n        offsets=(q_start_in_seq, 0),\n        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))\n    # save lse\n    l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln\n    tl.store(l_ptrs, lse_i, mask=off_q < q_len)\n\n\n@triton.jit\ndef backward_sum_o_do(\n    o_ptr,  # O: n x h x d\n    do_ptr,  # dO: n x h x d\n    delta_ptr,  # D: h x n\n    o_len,\n    HEAD_DIM,\n    stride_on,\n    stride_oh,\n    stride_od,\n    stride_don,\n    stride_doh,\n    stride_dod,\n    stride_dh,\n    stride_dn,\n    BLOCK_SIZE_O: tl.constexpr,\n    BLOCK_SIZE_D: tl.constexpr,\n):\n    pid_n = tl.program_id(0)\n    pid_h = tl.program_id(1)\n    off_n = pid_n * BLOCK_SIZE_O + 
tl.arange(0, BLOCK_SIZE_O)\n    off_d = tl.arange(0, BLOCK_SIZE_D)\n    o = tl.load(\n        o_ptr\n        + off_n[:, None] * stride_on\n        + pid_h * stride_oh\n        + off_d[None, :] * stride_od,\n        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),\n        other=0,\n    ).to(tl.float32)\n    do = tl.load(\n        do_ptr\n        + off_n[:, None] * stride_don\n        + pid_h * stride_doh\n        + off_d[None, :] * stride_dod,\n        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),\n        other=0,\n    ).to(tl.float32)\n    delta = tl.sum(o * do, axis=1)\n    tl.store(\n        delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len\n    )\n\n\n@triton.jit\ndef backward_dkdv(\n    q_ptr,  # Q: n x qh x d\n    k_ptr,  # K: n x kh x d\n    v_ptr,  # V: n x kh x d\n    lse_ptr,  # LSE: qh x n\n    d_ptr,  # Delta: qh x n\n    do_ptr,\n    dk_ptr,  # DK: sh x n x kh x d\n    dv_ptr,  # DV: sh x n x kh x d\n    kernel_size,\n    kernel_stride,\n    # seqlens\n    cu_seqlens_q,\n    cu_seqlens_k,\n    # shape\n    NUM_KV_HEADS,\n    NUM_SHARE_Q_HEADS,\n    HEAD_DIM,\n    # sm_scale\n    sm_scale,\n    # stride\n    stride_qn,\n    stride_qh,\n    stride_qd,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n    stride_vn,\n    stride_vh,\n    stride_vd,\n    stride_lh,\n    stride_ln,\n    stride_dh,\n    stride_dn,\n    stride_don,\n    stride_doh,\n    stride_dod,\n    stride_dks,\n    stride_dkn,\n    stride_dkh,\n    stride_dkd,\n    stride_dvs,\n    stride_dvn,\n    stride_dvh,\n    stride_dvd,\n    # META parameters\n    BLOCK_SIZE_Q: tl.constexpr,  # q block size\n    BLOCK_SIZE_K: tl.constexpr,  # k block size\n    BLOCK_SIZE_D: tl.constexpr,\n):\n    qk_scale = sm_scale * 1.44269504\n    # get batch id and head id\n    pid_b = tl.program_id(0)\n    pid_h = tl.program_id(1)\n    pid_kh = pid_h // NUM_SHARE_Q_HEADS\n    pid_sh = pid_h % NUM_SHARE_Q_HEADS\n    pid_k = tl.program_id(2)\n    # get q k 
start and len after rmpad\n    q_start = tl.load(cu_seqlens_q + pid_b)\n    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start\n    k_start = tl.load(cu_seqlens_k + pid_b)\n    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start\n    if BLOCK_SIZE_K * pid_k >= k_len:\n        return\n    # init pointers\n    k_ptrs = tl.make_block_ptr(\n        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,\n        shape=(k_len, HEAD_DIM),\n        strides=(stride_kn, stride_kd),\n        offsets=(pid_k * BLOCK_SIZE_K, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    dk_ptrs = tl.make_block_ptr(\n        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,\n        shape=(k_len, HEAD_DIM),\n        strides=(stride_dkn, stride_dkd),\n        offsets=(pid_k * BLOCK_SIZE_K, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    v_ptrs = tl.make_block_ptr(\n        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,\n        shape=(k_len, HEAD_DIM),\n        strides=(stride_vn, stride_vd),\n        offsets=(pid_k * BLOCK_SIZE_K, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    dv_ptrs = tl.make_block_ptr(\n        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,\n        shape=(k_len, HEAD_DIM),\n        strides=(stride_dvn, stride_dvd),\n        offsets=(pid_k * BLOCK_SIZE_K, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    # offsets\n    off_q = tl.arange(0, BLOCK_SIZE_Q)\n    off_k = (\n        pid_k * BLOCK_SIZE_K * kernel_stride\n        + tl.arange(0, BLOCK_SIZE_K) * kernel_stride\n        + kernel_size\n        - 1\n    )\n    # load k v and keep in SRAM\n    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # init dk dv\n    dk = tl.zeros((BLOCK_SIZE_K, 
BLOCK_SIZE_D), dtype=tl.float32)\n    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)\n    q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1\n    q_ptrs = tl.make_block_ptr(\n        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,\n        shape=(HEAD_DIM, q_len),\n        strides=(stride_qd, stride_qn),\n        offsets=(0, q_lo),\n        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),\n        order=(0, 1),\n    )\n    do_ptrs = tl.make_block_ptr(\n        base=do_ptr + q_start * stride_don + pid_h * stride_doh,\n        shape=(HEAD_DIM, q_len),\n        strides=(stride_dod, stride_don),\n        offsets=(0, q_lo),\n        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q),\n        order=(0, 1),\n    )\n    d_ptrs = tl.make_block_ptr(\n        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,\n        shape=(1, q_len),\n        strides=(0, stride_dn),\n        offsets=(0, q_lo),\n        block_shape=(1, BLOCK_SIZE_Q),\n        order=(1, 0),\n    )\n    lse_ptrs = tl.make_block_ptr(\n        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,\n        shape=(1, q_len),\n        strides=(0, stride_ln),\n        offsets=(0, q_lo),\n        block_shape=(1, BLOCK_SIZE_Q),\n        order=(0, 1),\n    )\n    # loop for q blocks\n    for i in range(q_lo, q_len, BLOCK_SIZE_Q):\n        # load\n        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        # compute qk\n        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q]\n        qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float(\"-inf\"))\n        qk += tl.dot(k, q) * qk_scale\n        # compute p, ds\n        # [BLOCK_SIZE_K, BLOCK_SIE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, 
BLOCK_SIE_Q]\n        p = tl.exp2(qk - lse)\n        # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q]\n        dp = tl.dot(v, do)\n        ds = sm_scale * p * (dp - d)\n        # cast dtype\n        p = p.to(do.dtype)\n        ds = ds.to(q.dtype)\n        # update dk and dv\n        # [BLOCK_SIZE_K, BLOCK_SIE_Q] @ [BLOCK_SIE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM]\n        dk += tl.dot(ds, tl.trans(q))\n        dv += tl.dot(p, tl.trans(do))\n        # increment pointers\n        q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q))\n        do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q))\n        lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q))\n        d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q))\n    # save dk dv\n    tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))\n    tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef backward_dq(\n    q_ptr,  # Q: n x qh x d\n    k_ptr,  # K: n x kh x d\n    v_ptr,  # V: n x kh x d\n    lse_ptr,  # LSE: qh x n\n    d_ptr,  # Delta: qh x n\n    do_ptr,\n    dq_ptr,\n    kernel_size,\n    kernel_stride,\n    # seqlens\n    cu_seqlens_q,\n    cu_seqlens_k,\n    # shape\n    NUM_KV_HEADS,\n    NUM_SHARE_Q_HEADS,\n    HEAD_DIM,\n    # sm_scale\n    sm_scale,\n    # stride\n    stride_qn,\n    stride_qh,\n    stride_qd,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n    stride_vn,\n    stride_vh,\n    stride_vd,\n    stride_lh,\n    stride_ln,\n    stride_dh,\n    stride_dn,\n    stride_don,\n    stride_doh,\n    stride_dod,\n    stride_dqn,\n    stride_dqh,\n    stride_dqd,\n    # META parameters\n    BLOCK_SIZE_Q: tl.constexpr,  # q block size\n    BLOCK_SIZE_K: tl.constexpr,  # k block size\n    BLOCK_SIZE_D: tl.constexpr,\n):\n    qk_scale = sm_scale * 1.44269504\n    # get batch id and head id\n    pid_b = tl.program_id(0)\n    pid_h = tl.program_id(1)\n    pid_q = tl.program_id(2)\n    pid_kh = pid_h // NUM_SHARE_Q_HEADS\n  
  # get q k start and len after rmpad\n    q_start = tl.load(cu_seqlens_q + pid_b)\n    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start\n    k_start = tl.load(cu_seqlens_k + pid_b)\n    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start\n    # skip first kernel_size query block, because they do no attend to any keys\n    q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1\n    if q_start_in_seq >= q_len:\n        return\n    # init pointers\n    q_ptrs = tl.make_block_ptr(\n        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,\n        shape=(q_len, HEAD_DIM),\n        strides=(stride_qn, stride_qd),\n        offsets=(q_start_in_seq, 0),\n        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    dq_ptrs = tl.make_block_ptr(\n        base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,\n        shape=(q_len, HEAD_DIM),\n        strides=(stride_dqn, stride_dqd),\n        offsets=(q_start_in_seq, 0),\n        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    k_ptrs = tl.make_block_ptr(\n        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,\n        shape=(k_len, HEAD_DIM),\n        strides=(stride_kn, stride_kd),\n        offsets=(0, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    v_ptrs = tl.make_block_ptr(\n        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,\n        shape=(HEAD_DIM, k_len),\n        strides=(stride_vd, stride_vn),\n        offsets=(0, 0),\n        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),\n        order=(0, 1),\n    )\n    do_ptrs = tl.make_block_ptr(\n        base=do_ptr + q_start * stride_don + pid_h * stride_doh,\n        shape=(q_len, HEAD_DIM),\n        strides=(stride_don, stride_dod),\n        offsets=(q_start_in_seq, 0),\n        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    d_ptrs = tl.make_block_ptr(\n        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,\n        
shape=(q_len, 1),\n        strides=(stride_dn, stride_dh),\n        offsets=(q_start_in_seq, 0),\n        block_shape=(BLOCK_SIZE_Q, 1),\n        order=(0, 1),\n    )\n    lse_ptrs = tl.make_block_ptr(\n        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,\n        shape=(q_len, 1),\n        strides=(stride_ln, stride_lh),\n        offsets=(q_start_in_seq, 0),\n        block_shape=(BLOCK_SIZE_Q, 1),\n        order=(0, 1),\n    )\n    # offsets\n    off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq\n    off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1\n    # load q, do, lse, delta, and keep in SRAM\n    q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option=\"zero\")\n    do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # init dq\n    dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32)\n    lo = 0\n    hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1)\n    for i in range(lo, hi, BLOCK_SIZE_K):\n        # load\n        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        # compute qk\n        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)\n        qk += tl.where(\n            off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float(\"-inf\")\n        )\n        qk += tl.dot(q, tl.trans(k)) * qk_scale\n        # compute p, ds\n        p = tl.exp2(qk - lse)\n        dp = tl.dot(do, v)\n        ds = sm_scale * p * (dp - d)\n        # cast dtype\n        ds = ds.to(q.dtype)\n        # update dq\n        dq += tl.dot(ds, k)\n        # increment pointers\n        k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0))\n        v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K))\n    # save dq\n    tl.store(dq_ptrs, 
dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef _compressed_attention_fwd(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    kernel_size: int,\n    kernel_stride: int,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    sm_scale: float,\n):\n    # dtype check\n    assert k.dtype == q.dtype and v.dtype == q.dtype\n    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32\n    # shape\n    q_len, num_q_heads, head_dim = q.shape\n    k_len, num_k_heads, head_dim = k.shape\n    v_len, num_v_heads, head_dim = v.shape\n    batch_size = cu_seqlens_q.shape[0] - 1\n    assert k_len == v_len and q_len > k_len\n    # gqa\n    assert num_k_heads == num_v_heads\n    assert num_q_heads % num_k_heads == 0\n    num_share_q_heads = num_q_heads // num_k_heads\n    # output tensor\n    o = torch.zeros_like(q)\n    lse = torch.full(\n        (num_q_heads, q_len),\n        fill_value=-torch.inf,\n        dtype=torch.float32,\n        device=q.device,\n    )\n    # launch kernel\n    grid = lambda META: (\n        batch_size,\n        num_q_heads,\n        triton.cdiv(max_seqlen_q, META[\"BLOCK_SIZE_Q\"]),\n    )\n    BLOCK_SIZE_Q = 128\n    BLOCK_SIZE_K = 128\n    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)\n    forward_kernel[grid](\n        q,\n        k,\n        v,\n        o,\n        lse,\n        kernel_size,\n        kernel_stride,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        num_k_heads,\n        num_share_q_heads,\n        head_dim,\n        sm_scale,\n        q.stride(0),\n        q.stride(1),\n        q.stride(2),\n        k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        v.stride(0),\n        v.stride(1),\n        v.stride(2),\n        o.stride(0),\n        o.stride(1),\n        o.stride(2),\n        lse.stride(0),\n        
lse.stride(1),\n        BLOCK_SIZE_Q=BLOCK_SIZE_Q,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_D=BLOCK_SIZE_D,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    return o, lse\n\n\ndef _compressed_attention_bwd(\n    o: torch.Tensor,\n    do: torch.Tensor,\n    lse: torch.Tensor,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    kernel_size: int,\n    kernel_stride: int,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    sm_scale: float,\n):\n    q_len, num_q_heads, head_dim = q.shape\n    k_len, num_k_heads, head_dim = k.shape\n    v_len, num_v_heads, head_dim = v.shape\n    o_len, num_o_heads, head_dim = o.shape\n    num_share_q_heads = num_q_heads // num_k_heads\n    # compute D\n    delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)\n    grid = lambda META: (triton.cdiv(o_len, META[\"BLOCK_SIZE_O\"]), num_o_heads)\n    BLOCK_SIZE_O = 256\n    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)\n    backward_sum_o_do[grid](\n        o,\n        do,\n        delta,\n        o_len,\n        head_dim,\n        o.stride(0),\n        o.stride(1),\n        o.stride(2),\n        do.stride(0),\n        do.stride(1),\n        do.stride(2),\n        delta.stride(0),\n        delta.stride(1),\n        BLOCK_SIZE_O=BLOCK_SIZE_O,\n        BLOCK_SIZE_D=BLOCK_SIZE_D,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    # compute dk dv\n    dk = torch.zeros(\n        num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype\n    )\n    dv = torch.zeros(\n        num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype\n    )\n    batch_size = cu_seqlens_q.shape[0] - 1\n    grid = lambda META: (\n        batch_size,\n        num_q_heads,\n        triton.cdiv(max_seqlen_k, 
META[\"BLOCK_SIZE_K\"]),\n    )\n    BLOCK_SIZE_Q = 64\n    BLOCK_SIZE_K = 128\n    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)\n    backward_dkdv[grid](\n        q,\n        k,\n        v,\n        lse,\n        delta,\n        do,\n        dk,\n        dv,\n        kernel_size,\n        kernel_stride,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        num_k_heads,\n        num_share_q_heads,\n        head_dim,\n        sm_scale,\n        q.stride(0),\n        q.stride(1),\n        q.stride(2),\n        k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        v.stride(0),\n        v.stride(1),\n        v.stride(2),\n        lse.stride(0),\n        lse.stride(1),\n        delta.stride(0),\n        delta.stride(1),\n        do.stride(0),\n        do.stride(1),\n        do.stride(2),\n        dk.stride(0),\n        dk.stride(1),\n        dk.stride(2),\n        dk.stride(3),\n        dv.stride(0),\n        dv.stride(1),\n        dv.stride(2),\n        dv.stride(3),\n        BLOCK_SIZE_Q=BLOCK_SIZE_Q,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_D=BLOCK_SIZE_D,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    dk = dk.sum(0)\n    dv = dv.sum(0)\n    # compute dq\n    dq = torch.zeros_like(q)\n    grid = lambda META: (\n        batch_size,\n        num_q_heads,\n        triton.cdiv(max_seqlen_q, META[\"BLOCK_SIZE_Q\"]),\n    )\n    BLOCK_SIZE_Q = 128\n    BLOCK_SIZE_K = 64\n    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)\n    backward_dq[grid](\n        q,\n        k,\n        v,\n        lse,\n        delta,\n        do,\n        dq,\n        kernel_size,\n        kernel_stride,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        num_k_heads,\n        num_share_q_heads,\n        head_dim,\n        sm_scale,\n        q.stride(0),\n        q.stride(1),\n        q.stride(2),\n        
k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        v.stride(0),\n        v.stride(1),\n        v.stride(2),\n        lse.stride(0),\n        lse.stride(1),\n        delta.stride(0),\n        delta.stride(1),\n        do.stride(0),\n        do.stride(1),\n        do.stride(2),\n        dq.stride(0),\n        dq.stride(1),\n        dq.stride(2),\n        BLOCK_SIZE_Q=BLOCK_SIZE_Q,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_D=BLOCK_SIZE_D,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    return dq, dk, dv\n\n\nclass CompressedAttention(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        kernel_size: int,\n        kernel_stride: int,\n        cu_seqlens_q: torch.Tensor,\n        cu_seqlens_k: torch.Tensor,\n        max_seqlen_q: int,\n        max_seqlen_k: int,\n        sm_scale=None,\n    ):\n        # dtype check\n        assert q.dtype == torch.bfloat16 or q.dtype == torch.float16\n        assert q.dtype == k.dtype and k.dtype == v.dtype\n        assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32\n        # softmax scale\n        if sm_scale is None:\n            sm_scale = 1 / math.sqrt(q.shape[-1])\n        o, lse = _compressed_attention_fwd(\n            q,\n            k,\n            v,\n            kernel_size,\n            kernel_stride,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            sm_scale,\n        )\n        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)\n        ctx.sm_scale = sm_scale\n        ctx.max_seqlen_q = max_seqlen_q\n        ctx.max_seqlen_k = max_seqlen_k\n        ctx.kernel_size = kernel_size\n        ctx.kernel_stride = kernel_stride\n        return o, lse\n\n    @staticmethod\n    def backward(ctx, do: torch.Tensor, *args) -> Any:\n        q, k, v, o, lse, 
cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors\n        max_seqlen_q = ctx.max_seqlen_q\n        max_seqlen_k = ctx.max_seqlen_k\n        sm_scale = ctx.sm_scale\n        kernel_size = ctx.kernel_size\n        kernel_stride = ctx.kernel_stride\n        dq, dk, dv = _compressed_attention_bwd(\n            o,\n            do,\n            lse,\n            q,\n            k,\n            v,\n            kernel_size,\n            kernel_stride,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            sm_scale,\n        )\n        return dq, dk, dv, None, None, None, None, None, None, None\n\n\n@triton.jit\ndef score_kernel(\n    q_ptr,\n    k_ptr,\n    lse_ptr,\n    s_ptr,\n    kernel_size,\n    kernel_stride,\n    # seqlens\n    cu_seqlens_q,\n    cu_seqlens_k,\n    # shape\n    NUM_KV_HEADS,\n    NUM_SHARE_Q_HEADS,\n    HEAD_DIM,\n    # sm_scale\n    sm_scale,\n    # stride\n    stride_qn,\n    stride_qh,\n    stride_qd,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n    stride_lh,\n    stride_ln,\n    stride_sh,\n    stride_sq,\n    stride_sk,\n    # META parameters\n    BLOCK_SIZE_Q: tl.constexpr,  # q block size\n    BLOCK_SIZE_K: tl.constexpr,  # k block size\n    BLOCK_SIZE_D: tl.constexpr,\n):\n    qk_scale = sm_scale * 1.44269504\n    # get batch id and head id\n    pid_bkh = tl.program_id(0)\n    pid_b = pid_bkh // NUM_KV_HEADS\n    pid_kh = pid_bkh % NUM_KV_HEADS\n    pid_q = tl.program_id(1)\n    pid_k = tl.program_id(2)\n    # get q k start and len after rmpad\n    q_start = tl.load(cu_seqlens_q + pid_b)\n    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start\n    k_start = tl.load(cu_seqlens_k + pid_b)\n    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start\n    if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len:\n        return\n    # init k pointer and load k\n    k_ptrs = tl.make_block_ptr(\n        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,\n        
shape=(HEAD_DIM, k_len),\n        strides=(stride_kd, stride_kn),\n        offsets=(0, pid_k * BLOCK_SIZE_K),\n        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),\n        order=(0, 1),\n    )\n    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # offsets\n    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q\n    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K\n    causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :]\n    # init score\n    s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)\n    # loop over gqa heads\n    for h in range(NUM_SHARE_Q_HEADS):\n        pid_h = pid_kh * NUM_SHARE_Q_HEADS + h\n        q_ptrs = tl.make_block_ptr(\n            base=q_ptr + q_start * stride_qn + pid_h * stride_qh,\n            shape=(q_len, HEAD_DIM),\n            strides=(stride_qn, stride_qd),\n            offsets=(pid_q * BLOCK_SIZE_Q, 0),\n            block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D),\n            order=(1, 0),\n        )\n        lse_ptrs = tl.make_block_ptr(\n            base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,\n            shape=(q_len, 1),\n            strides=(stride_ln, stride_lh),\n            offsets=(pid_q * BLOCK_SIZE_Q, 0),\n            block_shape=(BLOCK_SIZE_Q, 1),\n            order=(0, 1),\n        )\n        # load q and lse\n        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        # compute qk\n        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)\n        qk += tl.dot(q, k) * qk_scale\n        # compute score\n        s += tl.where(causal_mask, tl.exp2(qk - lse), 0)\n    # save output\n    s_ptrs = tl.make_block_ptr(\n        base=s_ptr + pid_kh * stride_sh + q_start * stride_sq,\n        shape=(q_len, k_len),\n        strides=(stride_sq, stride_sk),\n        offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K),\n        
block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K),\n        order=(1, 0),\n    )\n    tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef _get_attention_score(\n    q: torch.Tensor,  # [total_query_len, num_q_heads, head_dim]\n    k: torch.Tensor,  # [total_key_len, num_k_heads, head_dim]\n    lse: torch.Tensor,  # [num_q_heads, total_query_len]\n    kernel_size: int,\n    kernel_stride: int,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    sm_scale: float,\n) -> torch.Tensor:\n    # dtype check\n    assert q.dtype == torch.bfloat16 or q.dtype == torch.float16\n    assert q.dtype == k.dtype\n    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32\n    assert (\n        lse.dtype == torch.float32\n    )  # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale)))\n    # shape\n    q_len, num_q_heads, head_dim = q.shape\n    k_len, num_k_heads, head_dim = k.shape\n    batch_size = cu_seqlens_q.shape[0] - 1\n    assert q_len > k_len\n    if sm_scale is None:\n        sm_scale = 1 / math.sqrt(head_dim)\n    # gqa\n    assert num_q_heads % num_k_heads == 0\n    num_share_q_heads = num_q_heads // num_k_heads\n    # init score\n    score = torch.zeros(\n        num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device\n    )\n    # launch kernel\n    grid = lambda META: (\n        batch_size * num_k_heads,\n        triton.cdiv(max_seqlen_q, META[\"BLOCK_SIZE_Q\"]),\n        triton.cdiv(max_seqlen_k, META[\"BLOCK_SIZE_K\"]),\n    )\n    BLOCK_SIZE_Q = 128\n    BLOCK_SIZE_K = 128\n    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n    score_kernel[grid](\n        q,\n        k,\n        lse,\n        score,\n        kernel_size,\n        kernel_stride,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        num_k_heads,\n        num_share_q_heads,\n        head_dim,\n        sm_scale,\n        q.stride(0),\n        q.stride(1),\n        
q.stride(2),\n        k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        lse.stride(0),\n        lse.stride(1),\n        score.stride(0),\n        score.stride(1),\n        score.stride(2),\n        BLOCK_SIZE_Q=BLOCK_SIZE_Q,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_D=BLOCK_SIZE_D,\n        num_warps=8,\n        num_stages=3,\n    )\n    return score\n\n\n@triton.jit\ndef _transform_score_kernel(\n    s_ptr,  # score, shape: [num_heads, q_len, k_len]\n    bs_ptr,  # block wise score: [num_heads, q_len, num_k_block]\n    offs,\n    cu_seqlens_q,\n    # shape\n    num_heads,\n    num_offs,\n    max_k_len,\n    max_blocks,\n    pad_len,\n    # kernel & block size\n    block_size,\n    block_stride,  # block_size // kernel_stride\n    init_blocks,\n    local_blocks,\n    # stride\n    stride_sh,\n    stride_sq,\n    stride_sk,\n    stride_bsh,\n    stride_bsq,\n    stride_bsk,\n    BLOCK_SIZE_Q: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    BLOCK_SIZE_O: tl.constexpr,\n):\n    pid_bh = tl.program_id(0)\n    pid_b = pid_bh // num_heads\n    pid_h = pid_bh % num_heads\n    pid_q = tl.program_id(1)\n    pid_k = tl.program_id(2)\n    q_start = tl.load(cu_seqlens_q + pid_b)\n    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start\n    k_start = pid_k * BLOCK_SIZE_K\n    if pid_q * BLOCK_SIZE_Q >= q_len:\n        return\n    # load weight\n    off_o = tl.arange(0, BLOCK_SIZE_O)\n    w = tl.load(offs + off_o, mask=off_o < num_offs, other=0)\n    # load score\n    off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)\n    off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len\n    off_k = off_k[None, :] + off_o[:, None]\n    s_ptrs = (\n        s_ptr\n        + q_start * stride_sq\n        + pid_h * stride_sh\n        + off_q[:, None, None] * stride_sq\n        + off_k[None, :, :] * stride_sk\n    )\n    # weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK]\n    s = tl.load(\n        s_ptrs,\n        
mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len),\n        other=0,\n    )\n    s = s * w[None, :, None]\n    s = tl.sum(s, axis=1)\n    # init mask and local mask\n    off_bq = off_q // block_size\n    off_bk = k_start + tl.arange(0, BLOCK_SIZE_K)\n    s = tl.where(\n        (\n            (off_bq[:, None] >= off_bk[None, :])  # causal mask\n            & (off_bq[:, None] < off_bk[None, :] + local_blocks)  # local window\n        )\n        | (off_bk[None, :] < init_blocks),  # init window\n        float(\"inf\"),\n        s,\n    )\n    # store block wise score\n    bs_ptrs = (\n        bs_ptr\n        + q_start * stride_bsq\n        + pid_h * stride_bsh\n        + off_q[:, None] * stride_bsq\n        + off_bk[None, :] * stride_bsk\n    )\n    tl.store(\n        bs_ptrs,\n        s,\n        mask=(off_q < q_len)[:, None] & (off_bk < max_blocks)[None, :],\n    )\n\n\ndef transform_score(\n    score: torch.Tensor,\n    kernel_size: int,\n    kernel_stride: int,\n    block_size: int,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    init_blocks: int = 1,\n    local_blocks: int = 2,\n) -> torch.Tensor:\n    num_k_heads, total_query_len, max_key_len = score.shape\n    batch_size = cu_seqlens_q.shape[0] - 1\n    pad_len = kernel_size // kernel_stride - 1\n    max_blocks = math.ceil(max_seqlen_q / block_size)\n    block_score = torch.zeros(\n        num_k_heads,\n        total_query_len,\n        max_blocks,\n        dtype=torch.float32,\n        device=score.device,\n    )\n    offs = (\n        torch.arange(kernel_size // kernel_stride, device=score.device)[:, None]\n        + torch.arange(block_size // kernel_stride, device=score.device)[None, :]\n    ).view(-1)\n    offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max())\n    num_offs = int(offs.shape[0])\n    BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks))\n    BLOCK_SIZE_O = 
triton.next_power_of_2(num_offs)\n    BLOCK_SIZE_Q = 8\n    grid = (\n        num_k_heads * batch_size,\n        triton.cdiv(total_query_len, BLOCK_SIZE_Q),\n        triton.cdiv(max_blocks, BLOCK_SIZE_K),\n    )\n    _transform_score_kernel[grid](\n        score,\n        block_score,\n        offs,\n        cu_seqlens_q,\n        num_k_heads,\n        offs.shape[0],\n        max_key_len,\n        max_blocks,\n        pad_len,\n        block_size,\n        block_size // kernel_stride,\n        init_blocks,\n        local_blocks,\n        score.stride(0),\n        score.stride(1),\n        score.stride(2),\n        block_score.stride(0),\n        block_score.stride(1),\n        block_score.stride(2),\n        BLOCK_SIZE_Q=BLOCK_SIZE_Q,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_O=BLOCK_SIZE_O,\n        num_warps=8,\n        num_stages=3,\n    )\n    return block_score\n\n\ndef compressed_attention(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    kernel_size: int,\n    kernel_stride: int,\n    block_size: int,\n    topk: int,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int = None,\n    max_seqlen_k: int = None,\n    sm_scale: float = None,\n    init_blocks: int = 1,\n    local_blocks: int = 2,\n    parallel_topk_compute: Union[str, bool] = \"auto\",\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Attention between query and compressed key and value. 
Compute attention output and topk block idx used in topk_sparse_attention.\n\n    Args:\n        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]\n        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]\n        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]\n        kernel_size (int): kernel size in compress_key_value\n        kernel_stride (int): stride of compress_key_value\n        block_size (int): key value block size for topk sparse attention.\n        topk (int): number of blocks for each query.\n        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.\n        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.\n        max_seqlen_q (int): max q len of the batch. Defaults to None, means (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item().\n        max_seqlen_k (int): max k len of the batch. Defaults to None, means (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item().\n        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).\n        init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.\n        local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.\n        parallel_topk_compute (str, optional): Only set it to False when the sequence length is too long. This can avoid a current bug.\n            We'll fix this issue later. 
Defaults to auto, it will be set to False when the sequence length is greater than 32k and True otherwise.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention\n    \"\"\"\n    if max_seqlen_q is None:\n        max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()\n    if max_seqlen_k is None:\n        max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()\n    attn_output, lse = CompressedAttention.apply(\n        q,\n        k,\n        v,\n        kernel_size,\n        kernel_stride,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        sm_scale,\n    )\n\n    # do not select topk index\n    if topk <= 0:\n        warnings.warn(\"topk <= 0, returned topk_idx will be None\")\n        return attn_output, None\n\n    assert topk >= init_blocks + local_blocks\n    with torch.no_grad():\n        num_k_heads, num_q_heads = k.shape[1], q.shape[1]\n        num_shared_q_heads = num_q_heads // num_k_heads\n        batch_size = cu_seqlens_q.shape[0] - 1\n        q_idx = torch.cat(\n            [\n                torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device)\n                for i in range(batch_size)\n            ],\n            dim=0,\n        )\n        q_idx = q_idx // block_size\n        # whether to use parallel version\n        if parallel_topk_compute == \"auto\":\n            parallel_topk_compute = cu_seqlens_q[-1] <= 32768\n        # parallel version\n        if parallel_topk_compute:\n            # recompute score\n            score = _get_attention_score(\n                q,\n                k,\n                lse,\n                kernel_size,\n                kernel_stride,\n                cu_seqlens_q,\n                cu_seqlens_k,\n                max_seqlen_q,\n                max_seqlen_k,\n                sm_scale,\n            )\n            # transform score to block-wise score\n            
score = transform_score(\n                score,\n                kernel_size,\n                kernel_stride,\n                block_size,\n                cu_seqlens_q,\n                cu_seqlens_k,\n                max_seqlen_q,\n                max_seqlen_k,\n                init_blocks,\n                local_blocks,\n            )\n            # get topk\n            topk = min(topk, score.shape[-1])\n            topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values\n            topk_idx[topk_idx > q_idx[None, :, None]] = -1\n            topk_idx = topk_idx.to(torch.int32)\n        # non parallel version, avoid some current bugs when sequence length is too long\n        # FIXME: need to fix later\n        else:\n            topk_idx_list = []\n            for h in range(num_k_heads):\n                # recompute score\n                score = _get_attention_score(\n                    q[:, h * num_shared_q_heads : (h + 1) * num_shared_q_heads],\n                    k[:, h : h + 1],\n                    lse[h * num_shared_q_heads : (h + 1) * num_shared_q_heads],\n                    kernel_size,\n                    kernel_stride,\n                    cu_seqlens_q,\n                    cu_seqlens_k,\n                    max_seqlen_q,\n                    max_seqlen_k,\n                    sm_scale,\n                )\n                # transform score to block-wise score\n                score = transform_score(\n                    score,\n                    kernel_size,\n                    kernel_stride,\n                    block_size,\n                    cu_seqlens_q,\n                    cu_seqlens_k,\n                    max_seqlen_q,\n                    max_seqlen_k,\n                    init_blocks,\n                    local_blocks,\n                )\n                # get topk\n                topk = min(topk, score.shape[-1])\n                topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values\n                topk_idx[topk_idx > 
q_idx[None, :, None]] = -1\n                topk_idx = topk_idx.to(torch.int32)\n                topk_idx_list.append(topk_idx)\n            topk_idx = torch.cat(topk_idx_list, dim=0)\n    return attn_output, topk_idx\n"
  },
  {
    "path": "native_sparse_attention/ops/triton/flash_attention.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport math\nfrom typing import Any, Optional\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu\n\n\nIS_HOPPER_GPU = is_hopper_gpu()\n\n\n@triton.jit\ndef forward_kernel(\n    q_ptr,  # Q: n x h x d\n    k_ptr,  # K: n x h x d\n    v_ptr,  # V: n x h x d\n    o_ptr,  # O: n x h x d\n    lse_ptr,  # LSE: h x n\n    # seqlens\n    cu_seqlens_q,\n    cu_seqlens_k,\n    # shape\n    NUM_KV_HEADS,\n    NUM_SHARE_Q_HEADS,\n    qk_head_dim,\n    v_head_dim,\n    # sm_scale\n    sm_scale,\n    # causal\n    causal,\n    # gqa\n    gqa_interleave,\n    # stride\n    stride_qn,\n    stride_qh,\n    stride_qd,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n    stride_vn,\n    stride_vh,\n    stride_vd,\n    stride_on,\n    stride_oh,\n    stride_od,\n    stride_lh,\n    stride_ln,\n    # META parameters\n    BLOCK_SIZE_Q: tl.constexpr,  # q block size\n    BLOCK_SIZE_K: tl.constexpr,  # k block size\n    BLOCK_SIZE_KD: tl.constexpr,\n    BLOCK_SIZE_VD: tl.constexpr,\n):\n    qk_scale = sm_scale * 1.44269504\n    # get batch id and head id\n    pid_b = tl.program_id(0)\n    pid_h = tl.program_id(1)\n    pid_q = tl.program_id(2)\n    if gqa_interleave:\n        pid_kh = pid_h % NUM_KV_HEADS\n    else:\n        pid_kh = pid_h // NUM_SHARE_Q_HEADS\n    # get q k start 
and len after rmpad\n    q_start = tl.load(cu_seqlens_q + pid_b)\n    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start\n    k_start = tl.load(cu_seqlens_k + pid_b)\n    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start\n    if BLOCK_SIZE_Q * pid_q >= q_len:\n        return\n    # init qkv pointer\n    q_ptrs = tl.make_block_ptr(\n        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,\n        shape=(q_len, qk_head_dim),\n        strides=(stride_qn, stride_qd),\n        offsets=(pid_q * BLOCK_SIZE_Q, 0),\n        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD),\n        order=(1, 0),\n    )\n    k_ptrs = tl.make_block_ptr(\n        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,\n        shape=(qk_head_dim, k_len),\n        strides=(stride_kd, stride_kn),\n        offsets=(0, 0),\n        block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_K),\n        order=(0, 1),\n    )\n    v_ptrs = tl.make_block_ptr(\n        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,\n        shape=(k_len, v_head_dim),\n        strides=(stride_vn, stride_vd),\n        offsets=(0, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),\n        order=(1, 0),\n    )\n    # load q\n    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # init statistics\n    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q\n    off_k = tl.arange(0, BLOCK_SIZE_K)\n    m_i = tl.full((BLOCK_SIZE_Q,), float(\"-inf\"), dtype=tl.float32)\n    lse_i = tl.full((BLOCK_SIZE_Q,), float(\"-inf\"), dtype=tl.float32)\n    acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_VD), 0, dtype=tl.float32)\n    # full attention or causal attention\n    lo = 0\n    if causal:\n        hi = min(k_len, (pid_q + 1) * BLOCK_SIZE_Q)\n    else:\n        hi = k_len\n    for i in range(lo, hi, BLOCK_SIZE_K):\n        i = tl.multiple_of(i, BLOCK_SIZE_K)\n        # load k\n        k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option=\"zero\")\n        # compute qk\n        qk = tl.zeros((BLOCK_SIZE_Q, 
BLOCK_SIZE_K), dtype=tl.float32)\n        if causal:\n            qk += tl.where(off_q[:, None] >= (i + off_k)[None, :], 0, float(\"-inf\"))\n        else:\n            qk += tl.where((off_k < k_len - i)[None, :], 0, float(\"-inf\"))\n        qk += tl.dot(q, k) * qk_scale\n        # compute m_ij and l_ij\n        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))\n        p = tl.math.exp2(qk - m_ij[:, None])\n        l_ij = tl.sum(p, axis=1)\n        # scale acc_o\n        acc_o_scale = tl.math.exp2(m_i - m_ij)\n        acc_o = acc_o * acc_o_scale[:, None]\n        # load v and update acc_o\n        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        p = p.to(v.dtype)\n        acc_o += tl.dot(p, v)\n        # update statistics\n        m_i = m_ij\n        lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij)\n        # update ptrs\n        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K))\n        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))\n    # final scale\n    acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None]\n    # save output\n    o_ptrs = tl.make_block_ptr(\n        base=o_ptr + q_start * stride_on + pid_h * stride_oh,\n        shape=(q_len, v_head_dim),\n        strides=(stride_on, stride_od),\n        offsets=(pid_q * BLOCK_SIZE_Q, 0),\n        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD),\n        order=(1, 0),\n    )\n    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))\n    # save lse\n    l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln\n    tl.store(l_ptrs, lse_i, mask=off_q < q_len)\n\n\n@triton.jit\ndef backward_sum_o_do(\n    o_ptr,  # O: n x h x d\n    do_ptr,  # dO: n x h x d\n    delta_ptr,  # D: h x n\n    o_len,\n    HEAD_DIM,\n    stride_on,\n    stride_oh,\n    stride_od,\n    stride_don,\n    stride_doh,\n    stride_dod,\n    stride_dh,\n    stride_dn,\n    BLOCK_SIZE_O: tl.constexpr,\n    BLOCK_SIZE_D: tl.constexpr,\n):\n    pid_n = tl.program_id(0)\n    pid_h = 
tl.program_id(1)\n    off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)\n    off_d = tl.arange(0, BLOCK_SIZE_D)\n    o = tl.load(\n        o_ptr\n        + off_n[:, None] * stride_on\n        + pid_h * stride_oh\n        + off_d[None, :] * stride_od,\n        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),\n        other=0,\n    ).to(tl.float32)\n    do = tl.load(\n        do_ptr\n        + off_n[:, None] * stride_don\n        + pid_h * stride_doh\n        + off_d[None, :] * stride_dod,\n        mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),\n        other=0,\n    ).to(tl.float32)\n    delta = tl.sum(o * do, axis=1)\n    tl.store(\n        delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len\n    )\n\n\n@triton.jit\ndef backward_dkdv(\n    q_ptr,  # Q: n x qh x d\n    k_ptr,  # K: n x kh x d\n    v_ptr,  # V: n x kh x d\n    lse_ptr,  # LSE: qh x n\n    d_ptr,  # Delta: qh x n\n    do_ptr,\n    dk_ptr,  # DK: sh x n x kh x d\n    dv_ptr,  # DV: sh x n x kh x d\n    # seqlens\n    cu_seqlens_q,\n    cu_seqlens_k,\n    # shape\n    NUM_KV_HEADS,\n    NUM_SHARE_Q_HEADS,\n    qk_head_dim,\n    v_head_dim,\n    # sm_scale\n    sm_scale,\n    # causal\n    causal,\n    # gqa\n    gqa_interleave,\n    # stride\n    stride_qn,\n    stride_qh,\n    stride_qd,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n    stride_vn,\n    stride_vh,\n    stride_vd,\n    stride_lh,\n    stride_ln,\n    stride_dh,\n    stride_dn,\n    stride_don,\n    stride_doh,\n    stride_dod,\n    stride_dks,\n    stride_dkn,\n    stride_dkh,\n    stride_dkd,\n    stride_dvs,\n    stride_dvn,\n    stride_dvh,\n    stride_dvd,\n    # META parameters\n    BLOCK_SIZE_Q: tl.constexpr,  # q block size\n    BLOCK_SIZE_K: tl.constexpr,  # k block size\n    BLOCK_SIZE_KD: tl.constexpr,\n    BLOCK_SIZE_VD: tl.constexpr,\n):\n    qk_scale = sm_scale * 1.44269504\n    # get batch id and head id\n    pid_b = tl.program_id(0)\n    pid_h = 
tl.program_id(1)\n    if gqa_interleave:\n        # NOTE(review): forward_kernel and backward_dq map an interleaved q head\n        # to kv head pid_h % NUM_KV_HEADS, but this branch uses NUM_SHARE_Q_HEADS,\n        # so dk/dv land on the wrong (possibly out-of-range) kv head whenever\n        # NUM_SHARE_Q_HEADS != NUM_KV_HEADS. The consistent mapping would be\n        # pid_kh = pid_h % NUM_KV_HEADS and pid_sh = pid_h // NUM_KV_HEADS.\n        pid_kh = pid_h % NUM_SHARE_Q_HEADS\n        pid_sh = pid_h // NUM_SHARE_Q_HEADS\n    else:\n        pid_kh = pid_h // NUM_SHARE_Q_HEADS\n        pid_sh = pid_h % NUM_SHARE_Q_HEADS\n    pid_k = tl.program_id(2)\n    # get q k start and len after rmpad\n    q_start = tl.load(cu_seqlens_q + pid_b)\n    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start\n    k_start = tl.load(cu_seqlens_k + pid_b)\n    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start\n    if BLOCK_SIZE_K * pid_k >= k_len:\n        return\n    # init pointers\n    k_ptrs = tl.make_block_ptr(\n        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,\n        shape=(k_len, qk_head_dim),\n        strides=(stride_kn, stride_kd),\n        offsets=(pid_k * BLOCK_SIZE_K, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD),\n        order=(1, 0),\n    )\n    dk_ptrs = tl.make_block_ptr(\n        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,\n        shape=(k_len, qk_head_dim),\n        strides=(stride_dkn, stride_dkd),\n        offsets=(pid_k * BLOCK_SIZE_K, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD),\n        order=(1, 0),\n    )\n    v_ptrs = tl.make_block_ptr(\n        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,\n        shape=(k_len, v_head_dim),\n        strides=(stride_vn, stride_vd),\n        offsets=(pid_k * BLOCK_SIZE_K, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),\n        order=(1, 0),\n    )\n    dv_ptrs = tl.make_block_ptr(\n        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,\n        shape=(k_len, v_head_dim),\n        strides=(stride_dvn, stride_dvd),\n        offsets=(pid_k * BLOCK_SIZE_K, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),\n        order=(1, 0),\n    )\n    # offsets\n    off_q = tl.arange(0, BLOCK_SIZE_Q)\n    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K\n    # load k v and keep in SRAM\n    k = 
tl.load(k_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # init dk dv\n    dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_KD), dtype=tl.float32)\n    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_VD), dtype=tl.float32)\n    # causal\n    if causal:\n        q_lo = pid_k * BLOCK_SIZE_K\n    else:\n        q_lo = 0\n    q_ptrs = tl.make_block_ptr(\n        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,\n        shape=(q_len, qk_head_dim),\n        strides=(stride_qn, stride_qd),\n        offsets=(q_lo, 0),\n        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD),\n        order=(1, 0),\n    )\n    do_ptrs = tl.make_block_ptr(\n        base=do_ptr + q_start * stride_don + pid_h * stride_doh,\n        shape=(q_len, v_head_dim),\n        strides=(stride_don, stride_dod),\n        offsets=(q_lo, 0),\n        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD),\n        order=(1, 0),\n    )\n    d_ptrs = tl.make_block_ptr(\n        base=d_ptr + q_start * stride_dn + pid_h * stride_dh,\n        shape=(q_len, 1),\n        strides=(stride_dn, stride_dh),\n        offsets=(q_lo, 0),\n        block_shape=(BLOCK_SIZE_Q, 1),\n        order=(0, 1),\n    )\n    lse_ptrs = tl.make_block_ptr(\n        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,\n        shape=(q_len, 1),\n        strides=(stride_ln, stride_lh),\n        offsets=(q_lo, 0),\n        block_shape=(BLOCK_SIZE_Q, 1),\n        order=(0, 1),\n    )\n    # loop for q blocks\n    for i in range(q_lo, q_len, BLOCK_SIZE_Q):\n        # load\n        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        # compute qk\n        if causal:\n            qk = tl.where(\n                (off_q + i)[:, None] 
>= off_k[None, :], float(0.0), float(\"-inf\")\n            )\n        else:\n            qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)\n        qk += tl.dot(q, k.T) * qk_scale\n        # compute p, ds\n        p = tl.math.exp2(qk - lse)\n        dp = tl.dot(do, v.T)\n        ds = sm_scale * p * (dp - d)\n        # cast dtype\n        p = p.to(do.dtype)\n        ds = ds.to(q.dtype)\n        # update dk and dv\n        dk += tl.dot(ds.T, q)\n        dv += tl.dot(p.T, do)\n        # increment pointers\n        q_ptrs = tl.advance(q_ptrs, (BLOCK_SIZE_Q, 0))\n        do_ptrs = tl.advance(do_ptrs, (BLOCK_SIZE_Q, 0))\n        lse_ptrs = tl.advance(lse_ptrs, (BLOCK_SIZE_Q, 0))\n        d_ptrs = tl.advance(d_ptrs, (BLOCK_SIZE_Q, 0))\n    # save dk dv\n    tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))\n    tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef backward_dq(\n    q_ptr,  # Q: n x qh x d\n    k_ptr,  # K: n x kh x d\n    v_ptr,  # V: n x kh x d\n    lse_ptr,  # LSE: qh x n\n    d_ptr,  # Delta: qh x n\n    do_ptr,\n    dq_ptr,\n    # seqlens\n    cu_seqlens_q,\n    cu_seqlens_k,\n    # shape\n    NUM_KV_HEADS,\n    NUM_SHARE_Q_HEADS,\n    qk_head_dim,\n    v_head_dim,\n    # sm_scale\n    sm_scale,\n    # causal\n    causal,\n    # gqa\n    gqa_interleave,\n    # stride\n    stride_qn,\n    stride_qh,\n    stride_qd,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n    stride_vn,\n    stride_vh,\n    stride_vd,\n    stride_lh,\n    stride_ln,\n    stride_dh,\n    stride_dn,\n    stride_don,\n    stride_doh,\n    stride_dod,\n    stride_dqn,\n    stride_dqh,\n    stride_dqd,\n    # META parameters\n    BLOCK_SIZE_Q: tl.constexpr,  # q block size\n    BLOCK_SIZE_K: tl.constexpr,  # k block size\n    BLOCK_SIZE_KD: tl.constexpr,\n    BLOCK_SIZE_VD: tl.constexpr,\n):\n    qk_scale = sm_scale * 1.44269504\n    # get batch id and head id\n    pid_b = tl.program_id(0)\n    pid_h = 
tl.program_id(1)\n    pid_q = tl.program_id(2)\n    if gqa_interleave:\n        pid_kh = pid_h % NUM_KV_HEADS\n    else:\n        pid_kh = pid_h // NUM_SHARE_Q_HEADS\n    # get q k start and len after rmpad\n    q_start = tl.load(cu_seqlens_q + pid_b)\n    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start\n    k_start = tl.load(cu_seqlens_k + pid_b)\n    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start\n    if BLOCK_SIZE_Q * pid_q >= q_len:\n        return\n    # init pointers\n    q_ptrs = tl.make_block_ptr(\n        base=q_ptr + q_start * stride_qn + pid_h * stride_qh,\n        shape=(q_len, qk_head_dim),\n        strides=(stride_qn, stride_qd),\n        offsets=(pid_q * BLOCK_SIZE_Q, 0),\n        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD),\n        order=(1, 0),\n    )\n    dq_ptrs = tl.make_block_ptr(\n        base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh,\n        shape=(q_len, qk_head_dim),\n        strides=(stride_dqn, stride_dqd),\n        offsets=(pid_q * BLOCK_SIZE_Q, 0),\n        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_KD),\n        order=(1, 0),\n    )\n    k_ptrs = tl.make_block_ptr(\n        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,\n        shape=(k_len, qk_head_dim),\n        strides=(stride_kn, stride_kd),\n        offsets=(0, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_KD),\n        order=(1, 0),\n    )\n    # NOTE(review): V and dO have v_head_dim columns (forward_kernel and\n    # backward_dkdv build their V/dO block pointers with v_head_dim), but the\n    # v_ptrs and do_ptrs shapes below use qk_head_dim, so their boundary checks\n    # are wrong whenever qk_head_dim != v_head_dim. Both should use v_head_dim.\n    v_ptrs = tl.make_block_ptr(\n        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,\n        shape=(k_len, qk_head_dim),\n        strides=(stride_vn, stride_vd),\n        offsets=(0, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_VD),\n        order=(1, 0),\n    )\n    do_ptrs = tl.make_block_ptr(\n        base=do_ptr + q_start * stride_don + pid_h * stride_doh,\n        shape=(q_len, qk_head_dim),\n        strides=(stride_don, stride_dod),\n        offsets=(pid_q * BLOCK_SIZE_Q, 0),\n        block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_VD),\n        order=(1, 0),\n    )\n    d_ptrs = tl.make_block_ptr(\n   
     base=d_ptr + q_start * stride_dn + pid_h * stride_dh,\n        shape=(q_len, 1),\n        strides=(stride_dn, stride_dh),\n        offsets=(pid_q * BLOCK_SIZE_Q, 0),\n        block_shape=(BLOCK_SIZE_Q, 1),\n        order=(0, 1),\n    )\n    lse_ptrs = tl.make_block_ptr(\n        base=lse_ptr + q_start * stride_ln + pid_h * stride_lh,\n        shape=(q_len, 1),\n        strides=(stride_ln, stride_lh),\n        offsets=(pid_q * BLOCK_SIZE_Q, 0),\n        block_shape=(BLOCK_SIZE_Q, 1),\n        order=(0, 1),\n    )\n    # offsets\n    off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q\n    off_k = tl.arange(0, BLOCK_SIZE_K)\n    # load q, do, lse, delta, and keep in SRAM\n    q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option=\"zero\")\n    do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # init dq\n    dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_KD), dtype=tl.float32)\n    # causal\n    if causal:\n        k_hi = (pid_q + 1) * BLOCK_SIZE_Q\n    else:\n        k_hi = k_len\n    for j in range(0, k_hi, BLOCK_SIZE_K):\n        # load\n        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        # compute qk\n        if causal:\n            qk = tl.where(\n                off_q[:, None] >= (off_k + j)[None, :], float(0.0), float(\"-inf\")\n            )\n        else:\n            qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)\n        qk += tl.dot(q, k.T) * qk_scale\n        # compute p, ds\n        p = tl.math.exp2(qk - lse)\n        dp = tl.dot(do, v.T)\n        ds = sm_scale * p * (dp - d)\n        # cast dtype\n        ds = ds.to(q.dtype)\n        # update dq\n        dq += tl.dot(ds, k)\n        # increment pointers\n        k_ptrs = tl.advance(k_ptrs, 
(BLOCK_SIZE_K, 0))\n        v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0))\n    # save dq\n    tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef _flash_attention_fwd(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    causal: bool,\n    sm_scale: float,\n    gqa_interleave: bool = False,\n):\n    # dtype check\n    assert q.dtype == torch.bfloat16 or q.dtype == torch.float16\n    assert k.dtype == q.dtype and v.dtype == q.dtype\n    assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32\n    # shape\n    q_len, num_q_heads, qk_head_dim = q.shape\n    k_len, num_k_heads, qk_head_dim = k.shape\n    v_len, num_v_heads, v_head_dim = v.shape\n    batch_size = cu_seqlens_q.shape[0] - 1\n    assert qk_head_dim <= 256 and v_head_dim <= 256, \"head_dim must be less than 256\"\n    assert q_len == k_len and k_len == v_len\n    # gqa\n    assert num_k_heads == num_v_heads\n    assert num_q_heads % num_k_heads == 0\n    num_share_q_heads = num_q_heads // num_k_heads\n    # output tensor\n    o = torch.empty(q.shape[0], q.shape[1], v.shape[-1], dtype=q.dtype, device=q.device)\n    lse = torch.empty(num_q_heads, q_len, dtype=torch.float32, device=q.device)\n    # launch kernel\n    grid = lambda META: (\n        batch_size,\n        num_q_heads,\n        triton.cdiv(max_seqlen_q, META[\"BLOCK_SIZE_Q\"]),\n    )\n    BLOCK_SIZE_Q = 128\n    BLOCK_SIZE_K = 64\n    BLOCK_SIZE_KD = triton.next_power_of_2(qk_head_dim)\n    BLOCK_SIZE_VD = triton.next_power_of_2(v_head_dim)\n    num_warps, num_stages = get_num_warps_stages(\n        max(qk_head_dim, v_head_dim), BLOCK_SIZE_Q, IS_HOPPER_GPU\n    )\n    forward_kernel[grid](\n        q,\n        k,\n        v,\n        o,\n        lse,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        num_k_heads,\n        num_share_q_heads,\n        
qk_head_dim,\n        v_head_dim,\n        sm_scale,\n        causal,\n        gqa_interleave,\n        q.stride(0),\n        q.stride(1),\n        q.stride(2),\n        k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        v.stride(0),\n        v.stride(1),\n        v.stride(2),\n        o.stride(0),\n        o.stride(1),\n        o.stride(2),\n        lse.stride(0),\n        lse.stride(1),\n        BLOCK_SIZE_Q=BLOCK_SIZE_Q,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_KD=BLOCK_SIZE_KD,\n        BLOCK_SIZE_VD=BLOCK_SIZE_VD,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    return o, lse\n\n\ndef _flash_attention_bwd(\n    o: torch.Tensor,\n    do: torch.Tensor,\n    lse: torch.Tensor,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    causal: bool,\n    sm_scale: float,\n    gqa_interleave: bool = False,\n):\n    q_len, num_q_heads, qk_head_dim = q.shape\n    k_len, num_k_heads, qk_head_dim = k.shape\n    v_len, num_v_heads, v_head_dim = v.shape\n    o_len, num_o_heads, v_head_dim = o.shape\n    num_share_q_heads = num_q_heads // num_k_heads\n    # compute D\n    delta = torch.empty([num_o_heads, o_len], device=o.device, dtype=torch.float32)\n    grid = lambda META: (triton.cdiv(o_len, META[\"BLOCK_SIZE_O\"]), num_o_heads)\n    BLOCK_SIZE_O = 256\n    BLOCK_SIZE_VD = triton.next_power_of_2(v_head_dim)\n    num_warps, num_stages = get_num_warps_stages(\n        max(qk_head_dim, v_head_dim), BLOCK_SIZE_O, IS_HOPPER_GPU\n    )\n    backward_sum_o_do[grid](\n        o,\n        do,\n        delta,\n        o_len,\n        v_head_dim,\n        o.stride(0),\n        o.stride(1),\n        o.stride(2),\n        do.stride(0),\n        do.stride(1),\n        do.stride(2),\n        delta.stride(0),\n        delta.stride(1),\n        BLOCK_SIZE_O=BLOCK_SIZE_O,\n        
BLOCK_SIZE_D=BLOCK_SIZE_VD,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    # compute dk dv\n    dk = torch.empty(\n        num_share_q_heads,\n        k_len,\n        num_k_heads,\n        qk_head_dim,\n        device=k.device,\n        dtype=k.dtype,\n    )\n    dv = torch.empty(\n        num_share_q_heads,\n        k_len,\n        num_k_heads,\n        v_head_dim,\n        device=k.device,\n        dtype=k.dtype,\n    )\n    batch_size = cu_seqlens_q.shape[0] - 1\n    grid = lambda META: (\n        batch_size,\n        num_q_heads,\n        triton.cdiv(max_seqlen_k, META[\"BLOCK_SIZE_K\"]),\n    )\n    BLOCK_SIZE_Q = 64\n    BLOCK_SIZE_K = 64\n    BLOCK_SIZE_KD = triton.next_power_of_2(qk_head_dim)\n    BLOCK_SIZE_VD = triton.next_power_of_2(v_head_dim)\n    num_warps, num_stages = get_num_warps_stages(\n        max(qk_head_dim, v_head_dim), BLOCK_SIZE_K, IS_HOPPER_GPU\n    )\n    backward_dkdv[grid](\n        q,\n        k,\n        v,\n        lse,\n        delta,\n        do,\n        dk,\n        dv,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        num_k_heads,\n        num_share_q_heads,\n        qk_head_dim,\n        v_head_dim,\n        sm_scale,\n        causal,\n        gqa_interleave,\n        q.stride(0),\n        q.stride(1),\n        q.stride(2),\n        k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        v.stride(0),\n        v.stride(1),\n        v.stride(2),\n        lse.stride(0),\n        lse.stride(1),\n        delta.stride(0),\n        delta.stride(1),\n        do.stride(0),\n        do.stride(1),\n        do.stride(2),\n        dk.stride(0),\n        dk.stride(1),\n        dk.stride(2),\n        dk.stride(3),\n        dv.stride(0),\n        dv.stride(1),\n        dv.stride(2),\n        dv.stride(3),\n        BLOCK_SIZE_Q=BLOCK_SIZE_Q,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_KD=BLOCK_SIZE_KD,\n        BLOCK_SIZE_VD=BLOCK_SIZE_VD,\n        num_warps=num_warps,\n        
num_stages=num_stages,\n    )\n    dk = dk.sum(0)\n    dv = dv.sum(0)\n    # compute dq\n    dq = torch.empty_like(q)\n    grid = lambda META: (\n        batch_size,\n        num_q_heads,\n        triton.cdiv(max_seqlen_q, META[\"BLOCK_SIZE_Q\"]),\n    )\n    BLOCK_SIZE_Q = 64 if max(qk_head_dim, v_head_dim) > 128 else 128\n    BLOCK_SIZE_K = 64\n    num_warps, num_stages = get_num_warps_stages(\n        max(qk_head_dim, v_head_dim), BLOCK_SIZE_Q, IS_HOPPER_GPU\n    )\n    backward_dq[grid](\n        q,\n        k,\n        v,\n        lse,\n        delta,\n        do,\n        dq,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        num_k_heads,\n        num_share_q_heads,\n        qk_head_dim,\n        v_head_dim,\n        sm_scale,\n        causal,\n        gqa_interleave,\n        q.stride(0),\n        q.stride(1),\n        q.stride(2),\n        k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        v.stride(0),\n        v.stride(1),\n        v.stride(2),\n        lse.stride(0),\n        lse.stride(1),\n        delta.stride(0),\n        delta.stride(1),\n        do.stride(0),\n        do.stride(1),\n        do.stride(2),\n        dq.stride(0),\n        dq.stride(1),\n        dq.stride(2),\n        BLOCK_SIZE_Q=BLOCK_SIZE_Q,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_KD=BLOCK_SIZE_KD,\n        BLOCK_SIZE_VD=BLOCK_SIZE_VD,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    return dq, dk, dv\n\n\nclass FlashAttention(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        cu_seqlens_q: torch.Tensor,\n        cu_seqlens_k: torch.Tensor,\n        max_seqlen_q: int,\n        max_seqlen_k: int,\n        causal=True,\n        sm_scale=None,\n        gqa_interleave=False,\n    ):\n        # softmax scale\n        if sm_scale is None:\n            sm_scale = 1 / math.sqrt(q.shape[-1])\n        o, lse = 
_flash_attention_fwd(\n            q,\n            k,\n            v,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            causal,\n            sm_scale,\n            gqa_interleave,\n        )\n        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k)\n        ctx.sm_scale = sm_scale\n        ctx.max_seqlen_q = max_seqlen_q\n        ctx.max_seqlen_k = max_seqlen_k\n        ctx.causal = causal\n        ctx.gqa_interleave = gqa_interleave\n        return o\n\n    @staticmethod\n    def backward(ctx, do: torch.Tensor, *args) -> Any:\n        q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors\n        max_seqlen_q = ctx.max_seqlen_q\n        max_seqlen_k = ctx.max_seqlen_k\n        sm_scale = ctx.sm_scale\n        causal = ctx.causal\n        gqa_interleave = ctx.gqa_interleave\n        dq, dk, dv = _flash_attention_bwd(\n            o,\n            do,\n            lse,\n            q,\n            k,\n            v,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            causal,\n            sm_scale,\n            gqa_interleave,\n        )\n        return dq, dk, dv, None, None, None, None, None, None, None\n\n\ndef flash_attention_varlen(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    causal: bool = False,\n    sm_scale: Optional[float] = None,\n    gqa_interleave: bool = False,\n) -> torch.Tensor:\n    \"\"\"Flash attention with variable length based on triton.\n\n    Args:\n        q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim]\n        k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]; num_q_heads must be divisible by num_kv_heads.\n        v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim]\n        cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q 
in flash_attn_func_varlen.\n        cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen.\n        max_seqlen_q (int): max q len of the batch.\n        max_seqlen_k (int): max k len of the batch.\n        causal (bool, optional): Causal mask. Defaults to False.\n        sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).\n        gqa_interleave (bool, optional): GQA pattern. Defaults to False, use Llama style GQA.\n\n    Returns:\n        torch.Tensor: attention output with shape [total_q_len, num_q_heads, head_dim]\n    \"\"\"\n    return FlashAttention.apply(\n        q,\n        k,\n        v,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        causal,\n        sm_scale,\n        gqa_interleave,\n    )\n"
  },
  {
    "path": "native_sparse_attention/ops/triton/flash_attention_decode.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n\n@triton.jit\ndef decode_kernel(\n    q_ptr,  # Q: b x h x d\n    k_ptr,  # K: b x n x h x d\n    v_ptr,  # V: b x n x h x d\n    o_ptr,  # O: b x h x d\n    seqlens,\n    # shape\n    BATCH_SIZE,\n    NUM_SHARE_Q_HEADS,\n    HEAD_DIM,\n    # sm_scale\n    sm_scale,\n    # stride\n    stride_qb,\n    stride_qh,\n    stride_qd,\n    stride_kb,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n    stride_vb,\n    stride_vn,\n    stride_vh,\n    stride_vd,\n    stride_ob,\n    stride_oh,\n    stride_od,\n    # META parameters\n    BLOCK_SIZE_B: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    BLOCK_SIZE_D: tl.constexpr,\n):\n    qk_scale = sm_scale * 1.44269504\n    # get batch id and head id\n    pid_h = tl.program_id(0)\n    pid_b = tl.program_id(1)\n    pid_kh = pid_h // NUM_SHARE_Q_HEADS\n    # get q k start and len after rmpad\n    off_b = tl.arange(0, BLOCK_SIZE_B)\n    kv_len = tl.load(\n        seqlens + pid_b * BLOCK_SIZE_B + off_b,\n        mask=pid_b * BLOCK_SIZE_B + off_b < BATCH_SIZE,\n        other=0,\n    )\n    max_kv_len = tl.max(kv_len)\n    # init qkv pointer\n    q_ptrs = tl.make_block_ptr(\n        base=q_ptr + pid_h * stride_qh,\n        shape=(BATCH_SIZE, HEAD_DIM),\n        strides=(stride_qb, stride_qd),\n        offsets=(pid_b 
* BLOCK_SIZE_B, 0),\n        block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    k_ptrs = tl.make_block_ptr(\n        base=k_ptr + pid_kh * stride_kh,\n        shape=(BATCH_SIZE, max_kv_len, HEAD_DIM),\n        strides=(stride_kb, stride_kn, stride_kd),\n        offsets=(pid_b * BLOCK_SIZE_B, 0, 0),\n        block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(2, 1, 0),\n    )\n    v_ptrs = tl.make_block_ptr(\n        base=v_ptr + pid_kh * stride_vh,\n        shape=(BATCH_SIZE, max_kv_len, HEAD_DIM),\n        strides=(stride_vb, stride_vn, stride_vd),\n        offsets=(pid_b * BLOCK_SIZE_B, 0, 0),\n        block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(2, 1, 0),\n    )\n    # load q\n    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # init statistics\n    off_k = tl.arange(0, BLOCK_SIZE_K)\n    m_i = tl.full((BLOCK_SIZE_B,), float(\"-inf\"), dtype=tl.float32)\n    lse_i = tl.full((BLOCK_SIZE_B,), float(\"-inf\"), dtype=tl.float32)\n    acc_o = tl.full((BLOCK_SIZE_B, BLOCK_SIZE_D), 0, dtype=tl.float32)\n    # iterate over kv blocks up to the longest kv_len in this batch block;\n    # positions at or beyond each sequence's own kv_len are masked to -inf below\n    for i in range(0, max_kv_len, BLOCK_SIZE_K):\n        i = tl.multiple_of(i, BLOCK_SIZE_K)\n        # load k\n        k = tl.load(k_ptrs, boundary_check=(0, 1, 2), padding_option=\"zero\")\n        # compute qk\n        qk = tl.zeros((BLOCK_SIZE_B, BLOCK_SIZE_K), dtype=tl.float32)\n        qk += tl.where(off_k[None, :] + i < kv_len[:, None], 0, float(\"-inf\"))\n        # [B, D], [B, K, D] -> [B, K]\n        qk += tl.sum(q[:, None, :] * k, axis=2) * qk_scale\n        # compute m_ij and l_ij\n        m_ij = tl.maximum(m_i, tl.max(qk, axis=1))\n        p = tl.math.exp2(qk - m_ij[:, None])\n        l_ij = tl.sum(p, axis=1)\n        # scale acc_o\n        acc_o_scale = tl.math.exp2(m_i - m_ij)\n        acc_o = acc_o * acc_o_scale[:, None]\n        # load v and update acc_o\n        v = tl.load(v_ptrs, boundary_check=(0, 1, 2), 
padding_option=\"zero\")\n        p = p.to(v.dtype)\n        # [B, K], [B, K, D] -> [B, D]\n        acc_o += tl.sum(p[:, :, None] * v, axis=1)\n        # update statistics\n        m_i = m_ij\n        lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij)\n        # update ptrs\n        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K, 0))\n        v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K, 0))\n    # final scale\n    acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None]\n    # save output\n    o_ptrs = tl.make_block_ptr(\n        base=o_ptr + pid_h * stride_oh,\n        shape=(BATCH_SIZE, HEAD_DIM),\n        strides=(stride_ob, stride_od),\n        offsets=(pid_b * BLOCK_SIZE_B, 0),\n        block_shape=(BLOCK_SIZE_B, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef flash_attention_decode(\n    q: torch.Tensor,  # [batch_size, num_heads, head_dim]\n    k: torch.Tensor,  # [batch_size, max_len, num_heads, head_dim]\n    v: torch.Tensor,\n    seqlens: torch.Tensor,  # [batch_size, ]\n    sm_scale: Optional[float] = None,\n) -> torch.Tensor:\n    \"\"\"flash attention for decode.\n\n    Args:\n        q (torch.Tensor): query, shape [batch_size, num_q_heads, head_dim]\n        k (torch.Tensor): key, shape [batch_size, kv_len, num_kv_heads, head_dim]\n        v (torch.Tensor): value, shape [batch_size, kv_len, num_kv_heads, head_dim]\n        seqlens (torch.Tensor): kv length for each sequence\n        sm_scale (Optional[float]): softmax scale, default to 1/sqrt(head_dim)\n\n    Returns:\n        torch.Tensor: attention output\n    \"\"\"\n    # dtype check\n    assert q.dtype == torch.bfloat16 or q.dtype == torch.float16\n    assert k.dtype == q.dtype and v.dtype == q.dtype\n    assert seqlens.dtype == torch.int32\n    # shape\n    batch_size, num_q_heads, head_dim = q.shape\n    _, k_len, num_k_heads, head_dim = k.shape\n    _, v_len, num_v_heads, head_dim = v.shape\n    assert 
k_len == v_len and batch_size == seqlens.shape[0]\n    # gqa\n    assert num_k_heads == num_v_heads\n    assert num_q_heads % num_k_heads == 0\n    num_share_q_heads = num_q_heads // num_k_heads\n    # sm scale\n    if sm_scale is None:\n        sm_scale = 1 / math.sqrt(head_dim)\n    # output tensor\n    o = torch.zeros_like(q)\n    # launch kernel\n    num_warps = 4 if head_dim <= 64 else 8\n    num_stages = 3\n    # there is a bug for triton 3.0.0 if BLOCK_SIZE_B > 16\n    BLOCK_SIZE_B = min(16, triton.next_power_of_2(batch_size))\n    BLOCK_SIZE_K = 128\n    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n    grid = (num_q_heads, triton.cdiv(batch_size, BLOCK_SIZE_B))\n    decode_kernel[grid](\n        q,\n        k,\n        v,\n        o,\n        seqlens,\n        batch_size,\n        num_share_q_heads,\n        head_dim,\n        sm_scale,\n        q.stride(0),\n        q.stride(1),\n        q.stride(2),\n        k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        k.stride(3),\n        v.stride(0),\n        v.stride(1),\n        v.stride(2),\n        v.stride(3),\n        o.stride(0),\n        o.stride(1),\n        o.stride(2),\n        BLOCK_SIZE_B=BLOCK_SIZE_B,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_D=BLOCK_SIZE_D,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    return o\n\n\ndef torch_attention_decode(\n    q: torch.Tensor,  # [batch_size, num_heads, head_dim]\n    k: torch.Tensor,  # [batch_size, max_len, num_heads, head_dim]\n    v: torch.Tensor,\n    seqlens: torch.Tensor,  # [batch_size, ]\n    sm_scale: Optional[float] = None,\n):\n    # dtype check\n    assert q.dtype == torch.bfloat16 or q.dtype == torch.float16\n    assert k.dtype == q.dtype and v.dtype == q.dtype\n    assert seqlens.dtype == torch.int32\n    # shape\n    batch_size, num_q_heads, head_dim = q.shape\n    _, k_len, num_k_heads, head_dim = k.shape\n    _, v_len, num_v_heads, head_dim = v.shape\n    assert k_len == v_len and 
batch_size == seqlens.shape[0]\n    # gqa\n    assert num_k_heads == num_v_heads\n    assert num_q_heads % num_k_heads == 0\n    num_share_q_heads = num_q_heads // num_k_heads\n    # sm scale\n    if sm_scale is None:\n        sm_scale = 1 / math.sqrt(head_dim)\n    # attention\n    attn = (\n        torch.einsum(\n            \"bqhd,bkhd->bhqk\",\n            q.unsqueeze(1),\n            k.repeat_interleave(num_share_q_heads, dim=2),\n        )\n        * sm_scale\n    )\n    mask = torch.arange(k_len, device=q.device)[None, :] < seqlens[:, None]\n    attn = attn.masked_fill(~mask[:, None, None, :], -torch.inf)\n    attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)\n    out = torch.einsum(\n        \"bhqk,bkhd->bqhd\", attn, v.repeat_interleave(num_share_q_heads, dim=2)\n    ).squeeze(1)\n    return out\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(42)\n    batch_size = 76\n    max_length = 8192\n    seqlens = torch.arange(batch_size, dtype=torch.int32).cuda() * 128 + 1\n    seqlens[seqlens > max_length] = max_length\n    seqlens = seqlens[torch.randn_like(seqlens, dtype=torch.float32).argsort(-1)]\n    q = (\n        torch.empty(batch_size, 32, 128, device=\"cuda\")\n        .uniform_(-1, 1)\n        .to(torch.bfloat16)\n    )\n    k = (\n        torch.empty(batch_size, max_length, 4, 128, device=\"cuda\")\n        .uniform_(-1, 1)\n        .to(torch.bfloat16)\n    )\n    v = (\n        torch.empty(batch_size, max_length, 4, 128, device=\"cuda\")\n        .uniform_(-1, 1)\n        .to(torch.bfloat16)\n    )\n\n    o1 = torch_attention_decode(q, k, v, seqlens)\n    o2 = flash_attention_decode(q, k, v, seqlens)\n\n    print(torch.allclose(o1, o2, atol=1e-2, rtol=1e-2))\n"
  },
  {
    "path": "native_sparse_attention/ops/triton/linear_compress.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom typing import Optional, Tuple, Any\nimport triton\nimport torch\nimport triton.language as tl\nfrom einops import rearrange, einsum\nfrom native_sparse_attention.ops.triton.utils import is_hopper_gpu\n\nIS_HOPPER_GPU = is_hopper_gpu()\n\n\n@triton.jit\ndef linear_compress_fwd_kernel(\n    X,  # input pointer [total_len, num_heads, head_dim]\n    Y,  # output pointer [total_compressed_len, num_heads, head_dim]\n    W,  # weight matrix pointer [num_heads, kernel_size, head_dim, head_dim]\n    cu_seqlens_x,  # cumulative sequence lengths before compression\n    cu_seqlens_y,  # cumulative sequence lengths after compression\n    stride_xn,  # stride for X's sequence dimension\n    stride_xh,  # stride for X's num head dimension\n    stride_xd,  # stride for X's head_dim dimension\n    stride_wh,  # stride for W's num head dimension\n    stride_wk,  # stride for W's kernel size  dimension\n    stride_wd,  # stride for W's initial head dim dimension\n    stride_wD,  # stride for W's final head dim dimension\n    stride_yn,  # stride for Y's sequence dimension\n    stride_yh,  # stride for Y's num head dimension\n    stride_yd,  # stride for Y's head_dim dimension\n    NUM_HEADS: tl.constexpr,  # total num heads\n    KERNEL_SIZE: tl.constexpr,  # kernel size when calculate the output\n    KERNEL_STRIDE: tl.constexpr,  # kernel stride when 
calculate the output\n    HEADd_DIM: tl.constexpr,  # initial head dimension size\n    HEADD_DIM: tl.constexpr,  # final head dimension size\n    BLOCK_OUTPUT_SEQ_SIZE: tl.constexpr,  # Loaded output len\n    BLOCK_KERNEL_SIZE: tl.constexpr,  # Loaded kernel size when calculate the output\n    BLOCK_HEADd_DIM: tl.constexpr,  # Loaded  orignal head dimension size\n    BLOCK_HEADD_DIM: tl.constexpr,  # loaded final head dimension size\n):\n    pid_bh = tl.program_id(0)\n    pid_b = pid_bh // NUM_HEADS\n    pid_h = pid_bh % NUM_HEADS\n    pid_k = tl.program_id(1)\n    pid_D = tl.program_id(2)\n\n    x_start = tl.load(cu_seqlens_x + pid_b)\n    x_end = tl.load(cu_seqlens_x + pid_b + 1)\n    x_len = x_end - x_start\n\n    y_start = tl.load(cu_seqlens_y + pid_b)\n    y_end = tl.load(cu_seqlens_y + pid_b + 1)\n    y_len = y_end - y_start\n    if pid_k * BLOCK_OUTPUT_SEQ_SIZE >= y_len:\n        return\n\n    off_kernel_size = tl.arange(0, BLOCK_KERNEL_SIZE)\n    off_d = tl.arange(0, BLOCK_HEADd_DIM)\n    off_output_seq_size = tl.arange(0, BLOCK_OUTPUT_SEQ_SIZE)\n\n    x_base_ptrs = (\n        X\n        + pid_h * stride_xh\n        + x_start * stride_xn\n        + (\n            (\n                pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE\n                + off_output_seq_size * KERNEL_STRIDE\n            )[:, None]\n            + off_kernel_size[None, :]\n        )[:, :, None]\n        * stride_xn\n        + off_d[None, None, :] * stride_xd\n    )\n    x_base_mask = (\n        (\n            (\n                pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE\n                + off_output_seq_size * KERNEL_STRIDE\n            )[:, None]\n            + off_kernel_size[None, :]\n        )\n        < x_len\n    )[:, :, None]\n\n    w_ptrs = tl.make_block_ptr(\n        base=W + pid_h * stride_wh,\n        shape=(KERNEL_SIZE, HEADd_DIM, HEADD_DIM),\n        strides=(stride_wk, stride_wd, stride_wD),\n        offsets=(0, 0, pid_D * BLOCK_HEADD_DIM),\n        
block_shape=(BLOCK_KERNEL_SIZE, BLOCK_HEADd_DIM, BLOCK_HEADD_DIM),\n        order=(2, 1, 0),\n    )\n\n    y_ptrs = tl.make_block_ptr(\n        base=Y + y_start * stride_yn + pid_h * stride_yh,\n        shape=(y_len, HEADD_DIM),\n        strides=(stride_yn, stride_yd),\n        offsets=(pid_k * BLOCK_OUTPUT_SEQ_SIZE, pid_D * BLOCK_HEADD_DIM),\n        block_shape=(BLOCK_OUTPUT_SEQ_SIZE, BLOCK_HEADD_DIM),\n        order=(1, 0),\n    )\n\n    y_d = tl.full((BLOCK_OUTPUT_SEQ_SIZE, BLOCK_HEADD_DIM), 0, dtype=tl.float32)\n\n    for i in range(0, HEADd_DIM, BLOCK_HEADd_DIM):\n\n        x_ptrs = x_base_ptrs + i * stride_xd\n        x_mask = x_base_mask & ((i + off_d) < HEADd_DIM)[None, None, :]\n\n        x = tl.load(x_ptrs, mask=x_mask, other=0)\n        x = tl.reshape(x, (BLOCK_OUTPUT_SEQ_SIZE, BLOCK_KERNEL_SIZE * BLOCK_HEADd_DIM))\n        # x : [n, k * bd]\n\n        w = tl.load(w_ptrs, boundary_check=(0, 1, 2), padding_option=\"zero\")\n        w = tl.reshape(w, (BLOCK_KERNEL_SIZE * BLOCK_HEADd_DIM, BLOCK_HEADD_DIM))\n        # w: [k * bd, D]\n\n        y_d += tl.dot(x, w)\n        # y_d : [n, D]\n\n        w_ptrs = tl.advance(w_ptrs, (0, BLOCK_HEADd_DIM, 0))\n\n    tl.store(y_ptrs, y_d.to(y_ptrs.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef linear_compress_bwd_kernel(\n    DX,  # X's gradient pointer [total_len, num_heads, head_dim]\n    DY,  # Y's gradient pointer [total_compressed_len, num_heads, head_dim]\n    DW,  # weight's gradient pointer [num_heads, kernel_size, head_dim, head_dim]\n    X,  # x pointer [total_len, num_heads, head_dim]\n    W,  # weight matrix pointer [num_heads, kernel_size, head_dim, head_dim]\n    cu_seqlens_x,  # cumulative sequence lengths before compression\n    cu_seqlens_y,  # cumulative sequence lengths after compression\n    stride_xn,  # stride for X's sequence dimension\n    stride_xh,  # stride for X's num head dimension\n    stride_xd,  # stride for X's head_dim dimension\n    stride_wh,  # stride for W's num 
head dimension\n    stride_wk,  # stride for W's kernel size  dimension\n    stride_wd,  # stride for W's initial head dim dimension\n    stride_wD,  # stride for W's final head dim dimension\n    stride_dxn,  # stride for DX's sequence dimension\n    stride_dxh,  # stride for DX's num head dimension\n    stride_dxd,  # stride for DX's head_dim dimension\n    stride_dwh,  # stride for DW's num head dimension\n    stride_dwk,  # stride for DW's kernel size  dimension\n    stride_dwd,  # stride for DW's initial head dim dimension\n    stride_dwD,  # stride for DW's final head dim dimension\n    stride_dyn,  # stride for DY's sequence dimension\n    stride_dyh,  # stride for DY's num head dimension\n    stride_dyd,  # stride for DY's head_dim dimension\n    NUM_HEADS: tl.constexpr,  # total num heads\n    NUM_PARALLEL_HEADd_NUM: tl.constexpr,  # total parallel process among head dim d\n    KERNEL_SIZE: tl.constexpr,  # kernel size when calculate the output\n    KERNEL_STRIDE: tl.constexpr,  # kernel stride when calculate the output\n    HEADd_DIM: tl.constexpr,  # initial head dimension size\n    HEADD_DIM: tl.constexpr,  # final head dimension size\n    BLOCK_KERNEL_SIZE: tl.constexpr,  # Loaded kernel size when calculate the output\n    BLOCK_HEADd_DIM: tl.constexpr,  # Loaded  orignal head dimension size\n    BLOCK_HEADD_DIM: tl.constexpr,  # loaded final head dimension size\n    BLOCK_OUTPUT_SEQ_SIZE: tl.constexpr,  # loaded final output seq parallel\n):\n    pid_bh = tl.program_id(0)\n    pid_b = pid_bh // NUM_HEADS\n    pid_h = pid_bh % NUM_HEADS\n    pid_k = tl.program_id(1)\n    pid_Dd = tl.program_id(2)\n    pid_D = pid_Dd // NUM_PARALLEL_HEADd_NUM\n    pid_d = pid_Dd % NUM_PARALLEL_HEADd_NUM\n\n    x_start = tl.load(cu_seqlens_x + pid_b)\n    x_end = tl.load(cu_seqlens_x + pid_b + 1)\n    x_len = x_end - x_start\n\n    y_start = tl.load(cu_seqlens_y + pid_b)\n    y_end = tl.load(cu_seqlens_y + pid_b + 1)\n    y_len = y_end - y_start\n    if pid_k * 
BLOCK_OUTPUT_SEQ_SIZE >= y_len:\n        return\n\n    # pdb.set_trace()\n\n    off_kernel_size = tl.arange(0, BLOCK_KERNEL_SIZE)\n    off_d = tl.arange(0, BLOCK_HEADd_DIM)\n    off_D = tl.arange(0, BLOCK_HEADD_DIM)\n    off_output_seq_size = tl.arange(0, BLOCK_OUTPUT_SEQ_SIZE)\n\n    x_ptrs = (\n        X\n        + pid_h * stride_xh\n        + x_start * stride_xn\n        + (\n            (\n                pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE\n                + off_output_seq_size * KERNEL_STRIDE\n            )[:, None]\n            + off_kernel_size[None, :]\n        )[:, :, None]\n        * stride_xn\n        + (pid_d * BLOCK_HEADd_DIM + off_d)[None, None, :] * stride_xd\n    )\n\n    x_mask = (\n        (\n            (\n                pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE\n                + off_output_seq_size * KERNEL_STRIDE\n            )[:, None]\n            + off_kernel_size[None, :]\n        )\n        < x_len\n    )[:, :, None] & ((pid_d * BLOCK_HEADd_DIM + off_d) < HEADd_DIM)[None, None, :]\n\n    dx_ptrs = (\n        DX\n        + pid_h * stride_dxh\n        + x_start * stride_dxn\n        + (\n            (\n                pid_k * BLOCK_OUTPUT_SEQ_SIZE * KERNEL_STRIDE\n                + off_output_seq_size * KERNEL_STRIDE\n            )[:, None]\n            + off_kernel_size[None, :]\n        )[:, :, None]\n        * stride_dxn\n        + (pid_d * BLOCK_HEADd_DIM + off_d)[None, None, :] * stride_dxd\n    )\n\n    w_ptrs = tl.make_block_ptr(\n        base=W + pid_h * stride_wh,\n        shape=(KERNEL_SIZE, HEADd_DIM, HEADD_DIM),\n        strides=(stride_wk, stride_wd, stride_wD),\n        offsets=(0, pid_d * BLOCK_HEADd_DIM, pid_D * BLOCK_HEADD_DIM),\n        block_shape=(BLOCK_KERNEL_SIZE, BLOCK_HEADd_DIM, BLOCK_HEADD_DIM),\n        order=(2, 1, 0),\n    )\n\n    dw_ptrs = (\n        DW\n        + pid_h * stride_dwh\n        + off_kernel_size[:, None, None] * stride_dwk\n        + (pid_d * BLOCK_HEADd_DIM + off_d)[None, :, None] 
* stride_dwd\n        + (pid_D * BLOCK_HEADD_DIM + off_D)[None, None, :] * stride_dwD\n    )\n\n    dw_mask = (\n        (off_kernel_size < KERNEL_SIZE)[:, None, None]\n        & ((pid_d * BLOCK_HEADd_DIM + off_d) < HEADd_DIM)[None, :, None]\n        & ((pid_D * BLOCK_HEADD_DIM + off_D) < HEADD_DIM)[None, None, :]\n    )\n\n    dy_ptrs = tl.make_block_ptr(\n        base=DY + y_start * stride_dyn + pid_h * stride_dyh,\n        shape=(y_len, HEADD_DIM),\n        strides=(stride_dyn, stride_dyd),\n        offsets=(pid_k * BLOCK_OUTPUT_SEQ_SIZE, pid_D * BLOCK_HEADD_DIM),\n        block_shape=(BLOCK_OUTPUT_SEQ_SIZE, BLOCK_HEADD_DIM),\n        order=(1, 0),\n    )\n\n    dy = tl.load(dy_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # dy : [by, D]\n\n    # cal dx, start\n    w = tl.load(w_ptrs, boundary_check=(0, 1, 2), padding_option=\"zero\")\n    # w: [k, bd, D]\n    w = tl.reshape(w, (BLOCK_KERNEL_SIZE * BLOCK_HEADd_DIM, BLOCK_HEADD_DIM))\n    # w: [k * bd, D]\n\n    dx = tl.dot(dy, tl.trans(w))\n    # dx: [by, k * bd]\n\n    dx = tl.reshape(dx, (BLOCK_OUTPUT_SEQ_SIZE, BLOCK_KERNEL_SIZE, BLOCK_HEADd_DIM))\n    # dx: [by, k, bd]\n\n    tl.atomic_add(\n        dx_ptrs,\n        dx.to(dx_ptrs.dtype.element_ty),\n        mask=x_mask,\n    )\n    # cal dx, end\n\n    # cal dw, start\n    x = tl.load(x_ptrs, mask=x_mask, other=0)\n    x = tl.reshape(x, (BLOCK_OUTPUT_SEQ_SIZE, BLOCK_KERNEL_SIZE * BLOCK_HEADd_DIM))\n    # x : [by, k * bd]\n\n    dw = tl.dot(tl.trans(x), dy)\n    # dw: [k * bd, D]\n    dw = tl.reshape(dw, (BLOCK_KERNEL_SIZE, BLOCK_HEADd_DIM, BLOCK_HEADD_DIM))\n    # dw: [k, bd, D]\n\n    tl.atomic_add(dw_ptrs, dw.to(dw_ptrs.dtype.element_ty), mask=dw_mask)\n\n\nclass LinearCompress(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        x: torch.Tensor,\n        w: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        kernel_size: int,\n        kernel_stride: int,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n     
   \"\"\"\n        Compress key and value tensor with kernel_size and kernel_stride. Similar to conv_compress.\n\n        Args:\n            x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)\n            w (torch.Tensor): weight for each head, shape (num_heads, kernel_size * head_dim, head_dim)\n            cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen\n            kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)\n            kernel_stride (int): stride for each compress kernel\n\n        Returns:\n            Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.\n        \"\"\"\n        # dtype check\n        assert x.dtype == torch.float16 or x.dtype == torch.bfloat16\n        assert x.dtype == w.dtype\n        assert cu_seqlens.dtype == torch.int32\n\n        # shape check\n        total_len, num_heads, head_dim = x.shape\n        batch_size = cu_seqlens.shape[0] - 1\n        assert w.shape[0] == num_heads\n        assert w.shape[1] == kernel_size * head_dim\n        assert w.shape[2] == head_dim\n        assert kernel_size % kernel_stride == 0\n        assert kernel_size in {16, 32, 64, 128}\n        assert head_dim % 8 == 0\n\n        torch.cuda.set_device(x.device)\n\n        # compute seqlens after compression\n        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n        y_seqlens = (\n            torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1\n        )\n        # corner case: if sequence_length < kernel_size, no compression for this sequence\n        y_seqlens[seqlens < kernel_size] = 0\n        y_cu_seqlens = torch.cat(\n            [\n                torch.zeros(1, dtype=torch.int32, device=x.device),\n                torch.cumsum(y_seqlens.to(x.device), dim=0),\n            ],\n            dim=0,\n        ).to(torch.int32)\n\n        y = torch.zeros(\n       
     y_cu_seqlens[-1], num_heads, head_dim, dtype=x.dtype, device=x.device\n        )\n\n        block_kernel_size = max(16, triton.next_power_of_2(kernel_size))\n        block_head_dim = 8 if IS_HOPPER_GPU else 4\n        block_headD_dim = 32\n        block_output_seq_size = 64\n        w = w.reshape(num_heads, kernel_size, head_dim, head_dim).contiguous()\n\n        grid = lambda META: (\n            batch_size * num_heads,\n            triton.cdiv(y_seqlens.max(0)[0].item(), META[\"BLOCK_OUTPUT_SEQ_SIZE\"]),\n            triton.cdiv(head_dim, META[\"BLOCK_HEADD_DIM\"]),\n        )\n\n        linear_compress_fwd_kernel[grid](\n            x,\n            y,\n            w,\n            cu_seqlens,\n            y_cu_seqlens,\n            x.stride(0),\n            x.stride(1),\n            x.stride(2),\n            w.stride(0),\n            w.stride(1),\n            w.stride(2),\n            w.stride(3),\n            y.stride(0),\n            y.stride(1),\n            y.stride(2),\n            num_heads,\n            kernel_size,\n            kernel_stride,\n            head_dim,\n            head_dim,\n            block_output_seq_size,\n            block_kernel_size,\n            block_head_dim,\n            block_headD_dim,\n            # num_warps=8,\n            # num_stages=3,\n        )\n        # save for backward\n        ctx.save_for_backward(x, w, cu_seqlens, y_seqlens, y_cu_seqlens)\n        # save value\n        ctx.kernel_size = kernel_size\n        ctx.kernel_stride = kernel_stride\n        ctx.block_kernel_size = block_kernel_size\n        ctx.block_headd_dim = block_head_dim\n        ctx.block_headD_dim = block_headD_dim\n        ctx.block_output_seq_size = block_output_seq_size\n        return y, y_cu_seqlens\n\n    @staticmethod\n    def backward(ctx, dy: torch.Tensor, *args) -> Any:\n        x, w, cu_seqlens, y_seqlens, y_cu_seqlens = ctx.saved_tensors\n        kernel_size = ctx.kernel_size\n        kernel_stride = ctx.kernel_stride\n        
block_kernel_size = ctx.block_kernel_size\n        block_head_dim = ctx.block_headd_dim\n        block_headD_dim = ctx.block_headD_dim\n        block_output_seq_size = ctx.block_output_seq_size\n\n        total_len, num_heads, head_dim = x.shape\n        batch_size = cu_seqlens.shape[0] - 1\n\n        dx = torch.zeros(\n            cu_seqlens[-1], num_heads, head_dim, dtype=torch.float32, device=x.device\n        )\n\n        dw = torch.zeros(\n            num_heads,\n            kernel_size,\n            head_dim,\n            head_dim,\n            dtype=torch.float32,\n            device=x.device,\n        )\n\n        grid = lambda META: (\n            batch_size * num_heads,\n            triton.cdiv(y_seqlens.max(0)[0].item(), META[\"BLOCK_OUTPUT_SEQ_SIZE\"]),\n            triton.cdiv(head_dim, META[\"BLOCK_HEADD_DIM\"])\n            * triton.cdiv(head_dim, META[\"BLOCK_HEADd_DIM\"]),\n        )\n\n        linear_compress_bwd_kernel[grid](\n            dx,\n            dy,\n            dw,\n            x,\n            w,\n            cu_seqlens,\n            y_cu_seqlens,\n            x.stride(0),\n            x.stride(1),\n            x.stride(2),\n            w.stride(0),\n            w.stride(1),\n            w.stride(2),\n            w.stride(3),\n            dx.stride(0),\n            dx.stride(1),\n            dx.stride(2),\n            dw.stride(0),\n            dw.stride(1),\n            dw.stride(2),\n            dw.stride(3),\n            dy.stride(0),\n            dy.stride(1),\n            dy.stride(2),\n            num_heads,\n            head_dim // block_head_dim,\n            kernel_size,\n            kernel_stride,\n            head_dim,\n            head_dim,\n            block_kernel_size,\n            block_head_dim,\n            block_headD_dim,\n            block_output_seq_size,\n        )\n        return (\n            dx.to(x.dtype),\n            rearrange(dw.to(x.dtype), \"n k d D -> n (k d) D\"),\n            None,\n            
None,\n            None,\n        )\n\n\ndef linear_compress(\n    x: torch.Tensor,\n    w: torch.Tensor,\n    cu_seqlens: torch.Tensor,\n    kernel_size: int,\n    kernel_stride: int,\n    pe: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Compress key and value tensor with kernel_size and kernel_stride with linear projection.\n\n    Args:\n        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)\n        w (torch.Tensor): weight for each head, shape (num_heads, kernel_size * head_dim, head_dim)\n        cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.\n        kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)\n        kernel_stride (int): stride for each compress kernel\n        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.\n    \"\"\"\n    y, y_cu_seqlens = LinearCompress.apply(x, w, cu_seqlens, kernel_size, kernel_stride)\n    # position embedding as a bias\n    if pe is not None:\n        assert pe.dtype == x.dtype and pe.device == x.device\n        pe = rearrange(pe, \"h k d -> h (k d)\")\n        bias = einsum(pe, w, \"h D, h D d -> h d\")\n        y = y + bias.unsqueeze(0)\n    return y, y_cu_seqlens\n"
  },
  {
    "path": "native_sparse_attention/ops/triton/topk_sparse_attention.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport math\nfrom typing import Any, Optional\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom native_sparse_attention.ops.triton.utils import get_num_warps_stages, is_hopper_gpu\n\n\nIS_HOPPER_GPU = is_hopper_gpu()\n\n\n@triton.jit\ndef forward_kernel(\n    q_ptr,  # Q: n x h x d\n    k_ptr,  # K: n x kh x d\n    v_ptr,  # V: n x kh x d\n    t_ptr,  # topk_idx: kh x n x k\n    o_ptr,  # O: n x h x d\n    lse_ptr,  # LSE: h x n\n    # seqlens\n    cu_seqlens_q,\n    cu_seqlens_k,\n    # shape\n    NUM_KV_HEADS,\n    NUM_SHARE_Q_HEADS,\n    HEAD_DIM,\n    TOPK,\n    # q loop num\n    num_q_loop,\n    # sm_scale\n    sm_scale,\n    # stride\n    stride_qn,\n    stride_qh,\n    stride_qd,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n    stride_vn,\n    stride_vh,\n    stride_vd,\n    stride_th,\n    stride_tn,\n    stride_tk,\n    stride_on,\n    stride_oh,\n    stride_od,\n    stride_lh,\n    stride_ln,\n    # META parameters\n    BLOCK_SIZE_K: tl.constexpr,  # k block size\n    BLOCK_SIZE_D: tl.constexpr,\n    BLOCK_SIZE_H: tl.constexpr,\n    BLOCK_SIZE_T: tl.constexpr,\n):\n    qk_scale = sm_scale * 1.44269504\n    # get batch id and head id\n    pid_b = tl.program_id(0)\n    pid_kh = tl.program_id(1)\n    pid_h = pid_kh * NUM_SHARE_Q_HEADS\n    pid_q = tl.program_id(2)\n    # get q k start and len after rmpad\n    q_start = 
tl.load(cu_seqlens_q + pid_b)\n    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start\n    k_start = tl.load(cu_seqlens_k + pid_b)\n    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start\n    if pid_q * num_q_loop >= q_len:\n        return\n    real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop)\n    for j in range(real_q_loop):\n        pid_q_j = pid_q * num_q_loop + j\n        # init topk idx pointer\n        off_t = tl.arange(0, BLOCK_SIZE_T)\n        t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th\n        topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)\n        real_topk = tl.sum(\n            tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0),\n            axis=0,\n        )\n        # init qkv pointer\n        q_ptrs = tl.make_block_ptr(\n            base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,\n            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),\n            strides=(stride_qh, stride_qd),\n            offsets=(0, 0),\n            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),\n            order=(1, 0),\n        )\n        k_ptrs = tl.make_block_ptr(\n            base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,\n            shape=(HEAD_DIM, k_len),\n            strides=(stride_kd, stride_kn),\n            offsets=(0, 0),\n            block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),\n            order=(0, 1),\n        )\n        v_ptrs = tl.make_block_ptr(\n            base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,\n            shape=(k_len, HEAD_DIM),\n            strides=(stride_vn, stride_vd),\n            offsets=(0, 0),\n            block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n            order=(1, 0),\n        )\n        # load q\n        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        # init statistics\n        off_h = tl.arange(0, BLOCK_SIZE_H)\n        off_k = tl.arange(0, BLOCK_SIZE_K)\n        m_i = 
tl.full((BLOCK_SIZE_H,), float(\"-inf\"), dtype=tl.float32)\n        lse_i = tl.full((BLOCK_SIZE_H,), float(\"-inf\"), dtype=tl.float32)\n        acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32)\n        # sparse attention\n        for i in range(real_topk):\n            # get current block start index\n            c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K\n            t_ptr_j = t_ptr_j + stride_tk\n            # load k\n            k = tl.load(\n                tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option=\"zero\"\n            )\n            # compute qk\n            qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)\n            qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float(\"-inf\"))\n            # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K]\n            qk += tl.dot(q, k) * qk_scale\n            # compute m_ij and l_ij\n            m_ij = tl.maximum(m_i, tl.max(qk, axis=1))\n            p = tl.exp2(qk - m_ij[:, None])\n            l_ij = tl.sum(p, axis=1)\n            # scale acc_o\n            acc_o_scale = tl.exp2(m_i - m_ij)\n            acc_o = acc_o * acc_o_scale[:, None]\n            # load v and update acc_o\n            v = tl.load(\n                tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option=\"zero\"\n            )\n            p = p.to(v.dtype)\n            acc_o += tl.dot(p, v)\n            # update statistics\n            m_i = m_ij\n            lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)\n        # final scale\n        acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]\n        # save output\n        o_ptrs = tl.make_block_ptr(\n            base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh,\n            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),\n            strides=(stride_oh, stride_od),\n            offsets=(0, 0),\n            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),\n            order=(1, 0),\n    
    )\n        tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))\n        # save lse\n        lse_ptrs = (\n            lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh\n        )\n        tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS)\n\n\n@triton.jit\ndef backward_sum_o_do(\n    o_ptr,  # O: n x h x d\n    do_ptr,  # dO: n x h x d\n    delta_ptr,  # D: h x n\n    o_len,\n    HEAD_DIM,\n    stride_on,\n    stride_oh,\n    stride_od,\n    stride_don,\n    stride_doh,\n    stride_dod,\n    stride_dh,\n    stride_dn,\n    BLOCK_SIZE_O: tl.constexpr,\n    BLOCK_SIZE_D: tl.constexpr,\n):\n    pid_n = tl.program_id(0)\n    pid_h = tl.program_id(1)\n    off_o = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O)\n    off_d = tl.arange(0, BLOCK_SIZE_D)\n    o = tl.load(\n        o_ptr\n        + off_o[:, None] * stride_on\n        + pid_h * stride_oh\n        + off_d[None, :] * stride_od,\n        mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),\n        other=0,\n    ).to(tl.float32)\n    do = tl.load(\n        do_ptr\n        + off_o[:, None] * stride_don\n        + pid_h * stride_doh\n        + off_d[None, :] * stride_dod,\n        mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM),\n        other=0,\n    ).to(tl.float32)\n    delta = tl.sum(o * do, axis=1)\n    tl.store(\n        delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len\n    )\n\n\n@triton.jit\ndef count_kernel(\n    x_ptr,  # [num_kv_heads, total_len, topk]\n    y_ptr,  # [num_kv_heads, total_blocks]\n    cu_seqlens,  # [batch_size + 1]\n    cu_seqblocks,  # [batch_size + 1]\n    topk,\n    stride_xh,\n    stride_xn,\n    stride_xk,\n    stride_yh,\n    stride_yn,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_K: tl.constexpr,\n    BLOCK_SIZE_R: tl.constexpr,\n):\n    pid_h = tl.program_id(0)\n    pid_b = tl.program_id(1)\n    # get start and len after rmpad\n    seq_start = tl.load(cu_seqlens + pid_b)\n 
   seq_len = tl.load(cu_seqlens + pid_b + 1) - seq_start\n    blocks_start = tl.load(cu_seqblocks + pid_b)\n    num_blocks = tl.load(cu_seqblocks + pid_b + 1) - blocks_start\n    # load x\n    off_k = tl.arange(0, BLOCK_SIZE_K)\n    off_n = tl.arange(0, BLOCK_SIZE_N)\n    x_ptr = x_ptr + pid_h * stride_xh + seq_start * stride_xn\n    x_ptrs = x_ptr + off_n[:, None] * stride_xn + off_k[None, :] * stride_xk\n    # init y\n    y = tl.zeros((BLOCK_SIZE_R,), dtype=tl.int32)\n    # loop\n    for i in range(0, seq_len, BLOCK_SIZE_N):\n        x = tl.load(\n            x_ptrs,\n            mask=(off_n < seq_len - i)[:, None] & (off_k < topk)[None, :],\n            other=-1,\n        )\n        x = tl.ravel(x)\n        y += tl.histogram(x, BLOCK_SIZE_R)\n        x_ptrs += BLOCK_SIZE_N * stride_xn\n    # store result\n    off_r = tl.arange(0, BLOCK_SIZE_R)\n    y_ptr = y_ptr + pid_h * stride_yh + blocks_start * stride_yn\n    y_ptrs = y_ptr + off_r * stride_yn\n    tl.store(y_ptrs, y.to(y_ptr.dtype.element_ty), mask=off_r < num_blocks)\n\n\ndef count_query(\n    topk_idx: torch.Tensor,\n    cu_seqlens: torch.Tensor,\n    cu_seqblocks: torch.Tensor,\n    block_size: int,\n):\n    num_kv_heads, total_len, topk = topk_idx.shape\n    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n    seqblocks = cu_seqblocks[1:] - cu_seqblocks[:-1]\n    batch_size = seqlens.shape[0]\n    BLOCK_SIZE_K = triton.next_power_of_2(topk)\n    BLOCK_SIZE_N = triton.next_power_of_2(4096 // BLOCK_SIZE_K)\n    BLOCK_SIZE_R = triton.next_power_of_2(seqblocks.max().item() + 2)\n    active_query_count = torch.zeros(\n        num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device\n    )\n    grid = (num_kv_heads, batch_size)\n    count_kernel[grid](\n        topk_idx,\n        active_query_count,\n        cu_seqlens,\n        cu_seqblocks,\n        topk,\n        topk_idx.stride(0),\n        topk_idx.stride(1),\n        topk_idx.stride(2),\n        active_query_count.stride(0),\n        
active_query_count.stride(1),\n        BLOCK_SIZE_N=BLOCK_SIZE_N,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_R=BLOCK_SIZE_R,\n        num_warps=4,\n        num_stages=3,\n    )\n    return active_query_count\n\n\n@triton.jit\ndef pad_topk_idx_kernel(\n    t_ptr,\n    p_ptr,\n    cu_seqlens,\n    topk,\n    stride_th,\n    stride_tn,\n    stride_tk,\n    stride_pb,\n    stride_ph,\n    stride_pn,\n    stride_pk,\n    BLOCK_SIZE_N: tl.constexpr,\n    BLOCK_SIZE_T: tl.constexpr,\n):\n    pid_b = tl.program_id(0)\n    pid_h = tl.program_id(1)\n    pid_n = tl.program_id(2)\n    # get q start and len after rmpad\n    q_start = tl.load(cu_seqlens + pid_b)\n    q_len = tl.load(cu_seqlens + pid_b + 1) - q_start\n    if BLOCK_SIZE_N * pid_n >= q_len:\n        return\n    # init prts\n    t_ptrs = tl.make_block_ptr(\n        base=t_ptr + pid_h * stride_th + q_start * stride_tn,\n        shape=(q_len, topk),\n        strides=(stride_tn, stride_tk),\n        offsets=(pid_n * BLOCK_SIZE_N, 0),\n        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T),\n        order=(1, 0),\n    )\n    p_ptrs = tl.make_block_ptr(\n        base=p_ptr + pid_b * stride_pb + pid_h * stride_ph,\n        shape=(q_len, topk),\n        strides=(stride_pn, stride_pk),\n        offsets=(pid_n * BLOCK_SIZE_N, 0),\n        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T),\n        order=(1, 0),\n    )\n    # load and save\n    idxs = tl.load(t_ptrs, boundary_check=(0, 1))\n    tl.store(p_ptrs, idxs, boundary_check=(0, 1))\n\n\n@triton.jit\ndef save_topk_idx_kernel(\n    p_ptr,\n    t_ptr,\n    cu_seqblocks,\n    cu_topk_q_count,\n    n_len,\n    stride_pb,\n    stride_ph,\n    stride_pn,\n    stride_th,\n    stride_tn,\n    stride_ch,\n    stride_cn,\n    BLOCK_SIZE_N: tl.constexpr,\n):\n    pid_b = tl.program_id(0)\n    pid_h = tl.program_id(1)\n    pid_n = tl.program_id(2)\n    # get q start and len after rmpad\n    q_block_start = tl.load(cu_seqblocks + pid_b)\n    q_block_end = tl.load(cu_seqblocks + pid_b 
+ 1)\n    c_start = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_start * stride_cn)\n    c_end = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_end * stride_cn)\n    c_len = c_end - c_start\n    if c_len <= 0:\n        return\n    if pid_n * BLOCK_SIZE_N >= c_len:\n        return\n    # init ptrs\n    p_ptrs = tl.make_block_ptr(\n        base=p_ptr\n        + pid_b * stride_pb\n        + pid_h * stride_ph\n        + (n_len - c_len) * stride_pn,\n        shape=(c_len,),\n        strides=(stride_pn,),\n        offsets=(pid_n * BLOCK_SIZE_N,),\n        block_shape=(BLOCK_SIZE_N,),\n        order=(0,),\n    )\n    t_ptrs = tl.make_block_ptr(\n        base=t_ptr + pid_h * stride_th + c_start * stride_tn,\n        shape=(c_len,),\n        strides=(stride_tn,),\n        offsets=(pid_n * BLOCK_SIZE_N,),\n        block_shape=(BLOCK_SIZE_N,),\n        order=(0,),\n    )\n    # load and save\n    idxs = tl.load(p_ptrs, boundary_check=(0,))\n    tl.store(t_ptrs, idxs, boundary_check=(0,))\n\n\ndef reorder_topk_idx(\n    topk_idx: torch.Tensor,\n    cu_topk_q_count: torch.Tensor,\n    cu_seqlens: torch.Tensor,\n    cu_seqblocks: torch.Tensor,\n    block_size: int,\n):\n    num_kv_heads, total_len, topk = topk_idx.shape\n    batch_size = cu_seqlens.shape[0] - 1\n    seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]\n    max_seqlen = seq_lens.max().item()\n    # pad shape [num_kv_heads, total_seqlen, topk] to [batch_size, num_kv_heads, max_seqlen, topk]\n    pad_topk_idx = torch.full(\n        (batch_size, num_kv_heads, max_seqlen, topk),\n        fill_value=-1,\n        device=topk_idx.device,\n        dtype=torch.int32,\n    )\n    BLOCK_SIZE_T = triton.next_power_of_2(topk)\n    BLOCK_SIZE_N = min(\n        triton.next_power_of_2(max_seqlen), triton.next_power_of_2(8192 // BLOCK_SIZE_T)\n    )\n    grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, BLOCK_SIZE_N))\n    pad_topk_idx_kernel[grid](\n        topk_idx,\n        pad_topk_idx,\n        cu_seqlens,\n   
     topk,\n        topk_idx.stride(0),\n        topk_idx.stride(1),\n        topk_idx.stride(2),\n        pad_topk_idx.stride(0),\n        pad_topk_idx.stride(1),\n        pad_topk_idx.stride(2),\n        pad_topk_idx.stride(3),\n        BLOCK_SIZE_N=BLOCK_SIZE_N,\n        BLOCK_SIZE_T=BLOCK_SIZE_T,\n    )\n    # argsort\n    pad_topk_q_idx = pad_topk_idx.view(batch_size, num_kv_heads, -1).argsort(-1) // topk\n    pad_topk_q_idx = pad_topk_q_idx.to(torch.int32)\n    # save in the padding-removed (varlen) layout\n    topk_q_idx = torch.full(\n        (num_kv_heads, cu_topk_q_count[:, -1].max().item()),\n        fill_value=-1,\n        device=topk_idx.device,\n        dtype=torch.int32,\n    )\n    max_len = (\n        (\n            cu_topk_q_count[:, cu_seqblocks][:, 1:]\n            - cu_topk_q_count[:, cu_seqblocks][:, :-1]\n        )\n        .max()\n        .item()\n    )\n    BLOCK_SIZE_N = min(triton.next_power_of_2(max_len), 8192)\n    grid = (batch_size, num_kv_heads, triton.cdiv(max_len, BLOCK_SIZE_N))\n    save_topk_idx_kernel[grid](\n        pad_topk_q_idx,\n        topk_q_idx,\n        cu_seqblocks,\n        cu_topk_q_count,\n        pad_topk_q_idx.shape[-1],\n        pad_topk_q_idx.stride(0),\n        pad_topk_q_idx.stride(1),\n        pad_topk_q_idx.stride(2),\n        topk_q_idx.stride(0),\n        topk_q_idx.stride(1),\n        cu_topk_q_count.stride(0),\n        cu_topk_q_count.stride(1),\n        BLOCK_SIZE_N=BLOCK_SIZE_N,\n    )\n    return topk_q_idx\n\n\n@triton.jit\ndef backward_dkdv(\n    q_ptr,  # Q: n x qh x d\n    k_ptr,  # K: n x kh x d\n    v_ptr,  # V: n x kh x d\n    tq_ptr,  # topk_q_idx: kh x N\n    lse_ptr,  # LSE: qh x n\n    d_ptr,  # Delta: qh x n\n    do_ptr,\n    dk_ptr,  # DK: sh x n x kh x d\n    dv_ptr,  # DV: sh x n x kh x d\n    # seqlens\n    cu_seqlens_q,  # [batch_size + 1]\n    cu_seqlens_k,  # [batch_size + 1]\n    cu_seqblocks,  # [batch_size + 1]\n    cu_topk_q_count,  # [kh, total_blocks]\n    # shape\n    NUM_KV_HEADS,\n    
NUM_SHARE_Q_HEADS,\n    HEAD_DIM,\n    TOPK,\n    # sm_scale\n    sm_scale,\n    # stride\n    stride_qn,\n    stride_qh,\n    stride_qd,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n    stride_vn,\n    stride_vh,\n    stride_vd,\n    stride_tqh,\n    stride_tqn,\n    stride_ctqh,\n    stride_ctqn,\n    stride_lh,\n    stride_ln,\n    stride_dh,\n    stride_dn,\n    stride_don,\n    stride_doh,\n    stride_dod,\n    stride_dks,\n    stride_dkn,\n    stride_dkh,\n    stride_dkd,\n    stride_dvs,\n    stride_dvn,\n    stride_dvh,\n    stride_dvd,\n    # META parameters\n    BLOCK_SIZE_Q: tl.constexpr,  # q block size\n    BLOCK_SIZE_K: tl.constexpr,  # k block size\n    BLOCK_SIZE_D: tl.constexpr,\n):\n    qk_scale = sm_scale * 1.44269504\n    # get batch id and head id\n    pid_b = tl.program_id(0)\n    pid_h = tl.program_id(1)\n    pid_kh = pid_h // NUM_SHARE_Q_HEADS\n    pid_sh = pid_h % NUM_SHARE_Q_HEADS\n    pid_k = tl.program_id(2)\n    # get q k start and len after rmpad\n    q_start = tl.load(cu_seqlens_q + pid_b)\n    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start\n    k_start = tl.load(cu_seqlens_k + pid_b)\n    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start\n    if BLOCK_SIZE_K * pid_k >= k_len:\n        return\n    # get topk_q_idx\n    b_start = tl.load(cu_seqblocks + pid_b)  # how many blocks before current sequence\n    act_q_start = tl.load(\n        cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn\n    )\n    act_q_end = tl.load(\n        cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn\n    )\n    act_q_len = act_q_end - act_q_start\n    tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn\n    # init pointers\n    k_ptrs = tl.make_block_ptr(\n        base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,\n        shape=(k_len, HEAD_DIM),\n        strides=(stride_kn, stride_kd),\n        offsets=(pid_k * BLOCK_SIZE_K, 0),\n        block_shape=(BLOCK_SIZE_K, 
BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    dk_ptrs = tl.make_block_ptr(\n        base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks,\n        shape=(k_len, HEAD_DIM),\n        strides=(stride_dkn, stride_dkd),\n        offsets=(pid_k * BLOCK_SIZE_K, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    v_ptrs = tl.make_block_ptr(\n        base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,\n        shape=(k_len, HEAD_DIM),\n        strides=(stride_vn, stride_vd),\n        offsets=(pid_k * BLOCK_SIZE_K, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    dv_ptrs = tl.make_block_ptr(\n        base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs,\n        shape=(k_len, HEAD_DIM),\n        strides=(stride_dvn, stride_dvd),\n        offsets=(pid_k * BLOCK_SIZE_K, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    # offsets\n    off_q = tl.arange(0, BLOCK_SIZE_Q)\n    off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K\n    off_d = tl.arange(0, BLOCK_SIZE_D)\n    # load k v and keep in SRAM\n    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # init dk dv\n    dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)\n    dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32)\n    # init ptrs\n    q_ptrs = (\n        q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd\n    )\n    do_ptrs = (\n        do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod\n    )\n    d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh\n    lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh\n    # loop for q blocks\n    for i in range(0, act_q_len, BLOCK_SIZE_Q):\n        # load\n        idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < 
act_q_len - i, other=0).to(\n            tl.int32\n        )\n        q = tl.load(\n            q_ptrs + idx_q[:, None] * stride_qn,\n            mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],\n            other=0,\n        )\n        do = tl.load(\n            do_ptrs + idx_q[:, None] * stride_don,\n            mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :],\n            other=0,\n        )\n        lse = tl.load(\n            lse_ptrs + idx_q[:, None] * stride_ln,\n            mask=(off_q < act_q_len - i)[:, None],\n            other=0,\n        )\n        d = tl.load(\n            d_ptrs + idx_q[:, None] * stride_dn,\n            mask=(off_q < act_q_len - i)[:, None],\n            other=0,\n        )\n        # compute qk\n        qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32)\n        qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float(\"-inf\"))\n        qk += tl.dot(q, k.T) * qk_scale\n        # compute p, ds\n        p = tl.exp2(qk - lse)\n        dp = tl.dot(do, v.T)\n        ds = sm_scale * p * (dp - d)\n        # cast dtype\n        p = p.to(do.dtype)\n        ds = ds.to(q.dtype)\n        # update dk and dv\n        dk += tl.dot(ds.T, q)\n        dv += tl.dot(p.T, do)\n    # save dk dv\n    tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1))\n    tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef backward_dq(\n    q_ptr,  # Q: n x qh x d\n    k_ptr,  # K: n x kh x d\n    v_ptr,  # V: n x kh x d\n    t_ptr,  # topk_idx: kh x n x k\n    lse_ptr,  # LSE: qh x n\n    d_ptr,  # Delta: qh x n\n    do_ptr,\n    dq_ptr,\n    # seqlens\n    cu_seqlens_q,\n    cu_seqlens_k,\n    # shape\n    NUM_KV_HEADS,\n    NUM_SHARE_Q_HEADS,\n    HEAD_DIM,\n    TOPK,\n    # q loop num\n    num_q_loop,\n    # sm_scale\n    sm_scale,\n    # stride\n    stride_qn,\n    stride_qh,\n    stride_qd,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n  
  stride_vn,\n    stride_vh,\n    stride_vd,\n    stride_th,\n    stride_tn,\n    stride_tk,\n    stride_lh,\n    stride_ln,\n    stride_dh,\n    stride_dn,\n    stride_don,\n    stride_doh,\n    stride_dod,\n    stride_dqn,\n    stride_dqh,\n    stride_dqd,\n    # META parameters\n    BLOCK_SIZE_K: tl.constexpr,  # k block size\n    BLOCK_SIZE_D: tl.constexpr,\n    BLOCK_SIZE_H: tl.constexpr,\n    BLOCK_SIZE_T: tl.constexpr,\n):\n    qk_scale = sm_scale * 1.44269504\n    # get batch id and head id\n    pid_b = tl.program_id(0)\n    pid_kh = tl.program_id(1)\n    pid_q = tl.program_id(2)\n    pid_h = pid_kh * NUM_SHARE_Q_HEADS\n    # get q k start and len after rmpad\n    q_start = tl.load(cu_seqlens_q + pid_b)\n    q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start\n    k_start = tl.load(cu_seqlens_k + pid_b)\n    k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start\n    if pid_q * num_q_loop >= q_len:\n        return\n    real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop)\n    for j in range(real_q_loop):\n        pid_q_j = pid_q * num_q_loop + j\n        # init topk idx pointer\n        off_t = tl.arange(0, BLOCK_SIZE_T)\n        t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th\n        topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1)\n        real_topk = tl.sum(\n            tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0),\n            axis=0,\n        )\n        # init pointers\n        q_ptrs = tl.make_block_ptr(\n            base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh,\n            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),\n            strides=(stride_qh, stride_qd),\n            offsets=(0, 0),\n            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),\n            order=(1, 0),\n        )\n        dq_ptrs = tl.make_block_ptr(\n            base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh,\n            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),\n         
   strides=(stride_dqh, stride_dqd),\n            offsets=(0, 0),\n            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),\n            order=(1, 0),\n        )\n        k_ptrs = tl.make_block_ptr(\n            base=k_ptr + k_start * stride_kn + pid_kh * stride_kh,\n            shape=(k_len, HEAD_DIM),\n            strides=(stride_kn, stride_kd),\n            offsets=(0, 0),\n            block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n            order=(1, 0),\n        )\n        v_ptrs = tl.make_block_ptr(\n            base=v_ptr + k_start * stride_vn + pid_kh * stride_vh,\n            shape=(HEAD_DIM, k_len),\n            strides=(stride_vd, stride_vn),\n            offsets=(0, 0),\n            block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),\n            order=(0, 1),\n        )\n        do_ptrs = tl.make_block_ptr(\n            base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh,\n            shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),\n            strides=(stride_doh, stride_dod),\n            offsets=(0, 0),\n            block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),\n            order=(1, 0),\n        )\n        d_ptrs = tl.make_block_ptr(\n            base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh,\n            shape=(NUM_SHARE_Q_HEADS, 1),\n            strides=(stride_dh, stride_dn),\n            offsets=(0, 0),\n            block_shape=(BLOCK_SIZE_H, 1),\n            order=(1, 0),\n        )\n        lse_ptrs = tl.make_block_ptr(\n            base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh,\n            shape=(NUM_SHARE_Q_HEADS, 1),\n            strides=(stride_lh, stride_ln),\n            offsets=(0, 0),\n            block_shape=(BLOCK_SIZE_H, 1),\n            order=(1, 0),\n        )\n        # offsets\n        off_k = tl.arange(0, BLOCK_SIZE_K)\n        # load q, do, lse, delta, and keep in SRAM\n        q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option=\"zero\")\n        do = tl.load(do_ptrs, boundary_check=(0, 1), 
padding_option=\"zero\")\n        lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n        # init dq\n        dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32)\n        # sparse\n        for i in range(real_topk):\n            # get current block start index\n            c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K\n            t_ptr_j = t_ptr_j + stride_tk\n            # load\n            k = tl.load(\n                tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option=\"zero\"\n            )\n            v = tl.load(\n                tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option=\"zero\"\n            )\n            # compute qk\n            qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)\n            qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float(\"-inf\"))\n            # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K]\n            qk += tl.dot(q, tl.trans(k)) * qk_scale\n            # compute p, ds\n            p = tl.exp2(qk - lse)\n            dp = tl.dot(do, v)\n            ds = sm_scale * p * (dp - d)\n            # cast dtype\n            ds = ds.to(q.dtype)\n            # update dq\n            dq += tl.dot(ds, k)\n        # save dq\n        tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef _topk_sparse_attention_fwd(\n    q: torch.Tensor,  # [total_len, num_q_heads, head_dim]\n    k: torch.Tensor,  # [total_len, num_k_heads, head_dim]\n    v: torch.Tensor,  # [total_len, num_k_heads, head_dim]\n    topk_idx: torch.Tensor,  # [num_kv_heads, total_len, topk]\n    block_size: int,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    sm_scale: float,\n):\n    # dtype check\n    assert k.dtype == q.dtype and v.dtype == q.dtype\n    assert 
cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32\n    assert block_size in {32, 64, 128, 256}\n    # shape\n    q_len, num_q_heads, head_dim = q.shape\n    k_len, num_k_heads, head_dim = k.shape\n    v_len, num_v_heads, head_dim = v.shape\n    batch_size = cu_seqlens_q.shape[0] - 1\n    assert q_len == k_len and k_len == v_len\n    topk = topk_idx.shape[-1]\n    assert topk_idx.shape[0] == num_k_heads\n    assert topk_idx.shape[1] == q_len\n    # gqa\n    assert num_k_heads == num_v_heads\n    assert num_q_heads % num_k_heads == 0\n    num_share_q_heads = num_q_heads // num_k_heads\n    # output tensor\n    o = torch.zeros_like(q)\n    lse = torch.zeros(num_q_heads, q_len, dtype=torch.float32, device=q.device)\n    # launch kernel\n    num_q_loop = (\n        max_seqlen_q // 32768 + 1\n    )  # calculate multiple querys in one kernel if seqlence length is too long\n    grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop))\n    BLOCK_SIZE_K = triton.next_power_of_2(block_size)\n    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n    BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads))\n    BLOCK_SIZE_T = triton.next_power_of_2(topk)\n    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)\n    forward_kernel[grid](\n        q,\n        k,\n        v,\n        topk_idx,\n        o,\n        lse,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        num_k_heads,\n        num_share_q_heads,\n        head_dim,\n        topk,\n        num_q_loop,\n        sm_scale,\n        q.stride(0),\n        q.stride(1),\n        q.stride(2),\n        k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        v.stride(0),\n        v.stride(1),\n        v.stride(2),\n        topk_idx.stride(0),\n        topk_idx.stride(1),\n        topk_idx.stride(2),\n        o.stride(0),\n        o.stride(1),\n        o.stride(2),\n        lse.stride(0),\n        lse.stride(1),\n        
BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_D=BLOCK_SIZE_D,\n        BLOCK_SIZE_H=BLOCK_SIZE_H,\n        BLOCK_SIZE_T=BLOCK_SIZE_T,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    return o, lse\n\n\ndef _topk_sparse_attention_bwd(\n    o: torch.Tensor,\n    do: torch.Tensor,\n    lse: torch.Tensor,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    topk_idx: torch.Tensor,\n    block_size: int,\n    cu_seqlens_q: torch.Tensor,\n    cu_seqlens_k: torch.Tensor,\n    max_seqlen_q: int,\n    max_seqlen_k: int,\n    sm_scale: float,\n):\n    assert block_size in {32, 64, 128, 256}\n    q_len, num_q_heads, head_dim = q.shape\n    k_len, num_k_heads, head_dim = k.shape\n    v_len, num_v_heads, head_dim = v.shape\n    o_len, num_o_heads, head_dim = o.shape\n    num_share_q_heads = num_q_heads // num_k_heads\n    topk = topk_idx.shape[-1]\n    # compute D\n    delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32)\n    BLOCK_SIZE_O = 256\n    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU)\n    grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads)\n    backward_sum_o_do[grid](\n        o,\n        do,\n        delta,\n        o_len,\n        head_dim,\n        o.stride(0),\n        o.stride(1),\n        o.stride(2),\n        do.stride(0),\n        do.stride(1),\n        do.stride(2),\n        delta.stride(0),\n        delta.stride(1),\n        BLOCK_SIZE_O=BLOCK_SIZE_O,\n        BLOCK_SIZE_D=BLOCK_SIZE_D,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    # count active querys for each key block, shape: (num_k_heads, total_k_blocks)\n    seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]\n    seqblocks = torch.ceil(seqlens / block_size).to(torch.int32)\n    cu_seqblocks = torch.cat(\n        [\n            torch.zeros(1, dtype=torch.int32, device=topk_idx.device),\n            
torch.cumsum(seqblocks, dim=0),\n        ]\n    ).to(torch.int32)\n    topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size)\n    cu_topk_q_count = torch.cat(\n        [\n            torch.zeros(\n                topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device\n            ),\n            torch.cumsum(topk_q_count, dim=-1),\n        ],\n        dim=-1,\n    ).to(torch.int32)\n    # active query idx for each key block\n    # how to get active query idx for sequence b, head h, kv block i?\n    # topk_q_idx[h, cu_topk_q_count[h, cu_seqblocks[b] + i]:cu_topk_q_count[h, cu_seqblocks[b] + i + 1]]\n    topk_q_idx = reorder_topk_idx(\n        topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size\n    )\n    # compute dk dv\n    dk = torch.zeros(\n        num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype\n    )\n    dv = torch.zeros(\n        num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype\n    )\n    batch_size = cu_seqlens_q.shape[0] - 1\n    BLOCK_SIZE_K = triton.next_power_of_2(block_size)\n    BLOCK_SIZE_Q = 64\n    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU)\n    grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K))\n    backward_dkdv[grid](\n        q,\n        k,\n        v,\n        topk_q_idx,\n        lse,\n        delta,\n        do,\n        dk,\n        dv,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        cu_seqblocks,\n        cu_topk_q_count,\n        num_k_heads,\n        num_share_q_heads,\n        head_dim,\n        topk,\n        sm_scale,\n        q.stride(0),\n        q.stride(1),\n        q.stride(2),\n        k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        v.stride(0),\n        v.stride(1),\n        v.stride(2),\n        topk_q_idx.stride(0),\n        topk_q_idx.stride(1),\n        
cu_topk_q_count.stride(0),\n        cu_topk_q_count.stride(1),\n        lse.stride(0),\n        lse.stride(1),\n        delta.stride(0),\n        delta.stride(1),\n        do.stride(0),\n        do.stride(1),\n        do.stride(2),\n        dk.stride(0),\n        dk.stride(1),\n        dk.stride(2),\n        dk.stride(3),\n        dv.stride(0),\n        dv.stride(1),\n        dv.stride(2),\n        dv.stride(3),\n        BLOCK_SIZE_Q=BLOCK_SIZE_Q,\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_D=BLOCK_SIZE_D,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    dk = dk.sum(0)\n    dv = dv.sum(0)\n    # compute dq\n    dq = torch.zeros_like(q)\n    num_q_loop = (\n        max_seqlen_q // 32768 + 1\n    )  # calculate multiple querys in one kernel if seqlence length is too long\n    grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop))\n    BLOCK_SIZE_K = block_size\n    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n    BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads))\n    BLOCK_SIZE_T = triton.next_power_of_2(topk)\n    num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU)\n    backward_dq[grid](\n        q,\n        k,\n        v,\n        topk_idx,\n        lse,\n        delta,\n        do,\n        dq,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        num_k_heads,\n        num_share_q_heads,\n        head_dim,\n        topk,\n        num_q_loop,\n        sm_scale,\n        q.stride(0),\n        q.stride(1),\n        q.stride(2),\n        k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        v.stride(0),\n        v.stride(1),\n        v.stride(2),\n        topk_idx.stride(0),\n        topk_idx.stride(1),\n        topk_idx.stride(2),\n        lse.stride(0),\n        lse.stride(1),\n        delta.stride(0),\n        delta.stride(1),\n        do.stride(0),\n        do.stride(1),\n        do.stride(2),\n        dq.stride(0),\n        dq.stride(1),\n       
 dq.stride(2),\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_D=BLOCK_SIZE_D,\n        BLOCK_SIZE_H=BLOCK_SIZE_H,\n        BLOCK_SIZE_T=BLOCK_SIZE_T,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    return dq, dk, dv\n\n\nclass TopkSparseAttention(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        q: torch.Tensor,  # [total_len, num_q_heads, head_dim]\n        k: torch.Tensor,  # [total_len, num_k_heads, head_dim]\n        v: torch.Tensor,  # [total_len, num_k_heads, head_dim]\n        topk_idx: torch.Tensor,  # [num_kv_heads, total_len, topk]\n        block_size: int,\n        cu_seqlens_q: torch.Tensor,\n        cu_seqlens_k: torch.Tensor,\n        max_seqlen_q: int,\n        max_seqlen_k: int,\n        sm_scale=None,\n    ):\n        # dtype check\n        assert q.dtype == torch.bfloat16 or q.dtype == torch.float16\n        assert q.dtype == k.dtype and k.dtype == v.dtype\n        assert topk_idx.dtype == torch.int32\n        assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32\n        # softmax scale\n        if sm_scale is None:\n            sm_scale = 1 / math.sqrt(q.shape[-1])\n        o, lse = _topk_sparse_attention_fwd(\n            q,\n            k,\n            v,\n            topk_idx,\n            block_size,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            sm_scale,\n        )\n        ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx)\n        ctx.sm_scale = sm_scale\n        ctx.max_seqlen_q = max_seqlen_q\n        ctx.max_seqlen_k = max_seqlen_k\n        ctx.block_size = block_size\n        # return\n        return o\n\n    @staticmethod\n    def backward(ctx, do: torch.Tensor, *args) -> Any:\n        q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors\n        max_seqlen_q = ctx.max_seqlen_q\n        max_seqlen_k = ctx.max_seqlen_k\n 
       sm_scale = ctx.sm_scale\n        block_size = ctx.block_size\n        assert block_size in {32, 64, 128, 256}\n\n        dq, dk, dv = _topk_sparse_attention_bwd(\n            o,\n            do,\n            lse,\n            q,\n            k,\n            v,\n            topk_idx,\n            block_size,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            sm_scale,\n        )\n        return dq, dk, dv, None, None, None, None, None, None, None, None\n\n\ndef topk_sparse_attention(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    topk_idx: torch.Tensor,\n    block_size: int,\n    cu_seqlens: torch.Tensor,\n    softmax_scale: Optional[float] = None,\n) -> torch.Tensor:\n    \"\"\"Topk sparse attention varlen version implemented in triton.\n\n    Args:\n        q (torch.Tensor): shape [total_len, num_q_heads, head_dim]\n        k (torch.Tensor): shape [total_len, num_kv_heads, head_dim]\n        v (torch.Tensor): shape [total_len, num_kv_heads, head_dim]\n        topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding.\n        block_size (int): key value block size.\n        cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen.\n        softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim).\n\n    Returns:\n        torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim]\n    \"\"\"\n    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()\n    return TopkSparseAttention.apply(\n        q,\n        k,\n        v,\n        topk_idx,\n        block_size,\n        cu_seqlens,\n        cu_seqlens,\n        max_seqlen,\n        max_seqlen,\n        softmax_scale,\n    )\n"
  },
  {
    "path": "native_sparse_attention/ops/triton/topk_sparse_attention_decode.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n\n@triton.jit\ndef forward_kernel(\n    q_ptr,  # Q: b x h x d\n    k_ptr,  # K: b x n x kh x d\n    v_ptr,  # V: b x n x kh x d\n    t_ptr,  # topk_idx: kh x b x k\n    o_ptr,  # O: b x h x d\n    # seqlens\n    seqlens,\n    # shape\n    NUM_SHARE_Q_HEADS,\n    HEAD_DIM,\n    TOPK,\n    # sm_scale\n    sm_scale,\n    # stride\n    stride_qb,\n    stride_qh,\n    stride_qd,\n    stride_kb,\n    stride_kn,\n    stride_kh,\n    stride_kd,\n    stride_vb,\n    stride_vn,\n    stride_vh,\n    stride_vd,\n    stride_th,\n    stride_tb,\n    stride_tk,\n    stride_ob,\n    stride_oh,\n    stride_od,\n    # META parameters\n    BLOCK_SIZE_K: tl.constexpr,  # k block size\n    BLOCK_SIZE_D: tl.constexpr,\n    BLOCK_SIZE_H: tl.constexpr,\n    BLOCK_SIZE_T: tl.constexpr,\n):\n    qk_scale = sm_scale * 1.44269504\n    # get batch id and head id\n    pid_b = tl.program_id(0)\n    pid_kh = tl.program_id(1)\n    # get kv_len\n    kv_len = tl.load(seqlens + pid_b)\n    # init topk idx pointer\n    off_t = tl.arange(0, BLOCK_SIZE_T)\n    t_ptr = t_ptr + pid_b * stride_tb + pid_kh * stride_th\n    topk_idx = tl.load(t_ptr + off_t * stride_tk, mask=off_t < TOPK, other=-1)\n    real_topk = tl.sum(\n        tl.where((topk_idx >= 0) & (topk_idx <= (kv_len - 1) // 
BLOCK_SIZE_K), 1, 0),\n        axis=0,\n    )\n    # init qkv pointer\n    q_ptrs = tl.make_block_ptr(\n        base=q_ptr + pid_b * stride_qb + pid_kh * NUM_SHARE_Q_HEADS * stride_qh,\n        shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),\n        strides=(stride_qh, stride_qd),\n        offsets=(0, 0),\n        block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    k_ptrs = tl.make_block_ptr(\n        base=k_ptr + pid_b * stride_kb + pid_kh * stride_kh,\n        shape=(HEAD_DIM, kv_len),\n        strides=(stride_kd, stride_kn),\n        offsets=(0, 0),\n        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),\n        order=(0, 1),\n    )\n    v_ptrs = tl.make_block_ptr(\n        base=v_ptr + pid_b * stride_vb + pid_kh * stride_vh,\n        shape=(kv_len, HEAD_DIM),\n        strides=(stride_vn, stride_vd),\n        offsets=(0, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    # load q\n    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # init statistics\n    off_k = tl.arange(0, BLOCK_SIZE_K)\n    m_i = tl.full((BLOCK_SIZE_H,), float(\"-inf\"), dtype=tl.float32)\n    lse_i = tl.full((BLOCK_SIZE_H,), float(\"-inf\"), dtype=tl.float32)\n    acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32)\n    # sparse attention\n    for i in range(real_topk):\n        # get current block start index\n        c = tl.load(t_ptr).to(tl.int32) * BLOCK_SIZE_K\n        t_ptr = t_ptr + stride_tk\n        # load k\n        k = tl.load(\n            tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option=\"zero\"\n        )\n        # compute qk\n        qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32)\n        qk += tl.where((kv_len > c + off_k)[None, :], 0, float(\"-inf\"))\n        # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K]\n        qk += tl.dot(q, k) * qk_scale\n        # compute m_ij and l_ij\n        m_ij = tl.maximum(m_i, 
tl.max(qk, axis=1))\n        p = tl.exp2(qk - m_ij[:, None])\n        l_ij = tl.sum(p, axis=1)\n        # scale acc_o\n        acc_o_scale = tl.exp2(m_i - m_ij)\n        acc_o = acc_o * acc_o_scale[:, None]\n        # load v and update acc_o\n        v = tl.load(\n            tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option=\"zero\"\n        )\n        p = p.to(v.dtype)\n        acc_o += tl.dot(p, v)\n        # update statistics\n        m_i = m_ij\n        lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij)\n    # final scale\n    acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None]\n    # save output\n    o_ptrs = tl.make_block_ptr(\n        base=o_ptr + pid_b * stride_ob + pid_kh * NUM_SHARE_Q_HEADS * stride_oh,\n        shape=(NUM_SHARE_Q_HEADS, HEAD_DIM),\n        strides=(stride_oh, stride_od),\n        offsets=(0, 0),\n        block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef topk_sparse_attention_decode(\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    topk_idx: torch.Tensor,\n    block_size: int,\n    seqlens: torch.Tensor,\n    sm_scale: Optional[float] = None,\n) -> torch.Tensor:\n    \"\"\"Topk sparse attention decode version (one query token per sequence) implemented in triton.\n\n    Each query attends only to the key/value blocks listed in topk_idx.\n\n    Args:\n        q (torch.Tensor): shape [batch_size, num_q_heads, head_dim]\n        k (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]\n        v (torch.Tensor): shape [batch_size, kv_len, num_kv_heads, head_dim]\n        topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, batch_size, topk]. 
-1 means padding.\n        block_size (int): key value block size.\n        seqlens (torch.Tensor): max kv length for each sequence\n        softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim).\n\n    Returns:\n        torch.Tensor: sparse attention output\n    \"\"\"\n    # dtype check\n    assert q.dtype == torch.bfloat16 or q.dtype == torch.float16\n    assert k.dtype == q.dtype and v.dtype == q.dtype\n    assert seqlens.dtype == torch.int32\n    # shape\n    batch_size, num_q_heads, head_dim = q.shape\n    _, k_len, num_k_heads, head_dim = k.shape\n    _, v_len, num_v_heads, head_dim = v.shape\n    assert k_len == v_len and batch_size == seqlens.shape[0]\n    assert num_k_heads == topk_idx.shape[0] and batch_size == topk_idx.shape[1]\n    topk = topk_idx.shape[-1]\n    # gqa\n    assert num_k_heads == num_v_heads\n    assert num_q_heads % num_k_heads == 0\n    num_share_q_heads = num_q_heads // num_k_heads\n    # sm scale\n    if sm_scale is None:\n        sm_scale = 1 / math.sqrt(head_dim)\n    # output tensor\n    o = torch.zeros_like(q)\n    # launch kernel\n    grid = (batch_size, num_k_heads)\n    num_warps = 4 if head_dim <= 64 else 8\n    num_stages = 3\n    BLOCK_SIZE_K = triton.next_power_of_2(block_size)\n    BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n    BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads))\n    BLOCK_SIZE_T = triton.next_power_of_2(topk)\n    forward_kernel[grid](\n        q,\n        k,\n        v,\n        topk_idx,\n        o,\n        seqlens,\n        num_share_q_heads,\n        head_dim,\n        topk,\n        sm_scale,\n        q.stride(0),\n        q.stride(1),\n        q.stride(2),\n        k.stride(0),\n        k.stride(1),\n        k.stride(2),\n        k.stride(3),\n        v.stride(0),\n        v.stride(1),\n        v.stride(2),\n        v.stride(3),\n        topk_idx.stride(0),\n        topk_idx.stride(1),\n        topk_idx.stride(2),\n        o.stride(0),\n        
o.stride(1),\n        o.stride(2),\n        BLOCK_SIZE_K=BLOCK_SIZE_K,\n        BLOCK_SIZE_D=BLOCK_SIZE_D,\n        BLOCK_SIZE_H=BLOCK_SIZE_H,\n        BLOCK_SIZE_T=BLOCK_SIZE_T,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n    return o\n\n\ndef torch_topk_sparse_attention_decode(\n    q: torch.Tensor,  # [batch_size, num_q_heads, head_dim]\n    k: torch.Tensor,  # [batch_size, kv_len, num_k_heads, head_dim]\n    v: torch.Tensor,  # [batch_size, kv_len, num_k_heads, head_dim]\n    topk_idx: torch.Tensor,  # [num_k_heads, batch_size, topk]\n    block_size: int,\n    seqlens: torch.Tensor,  # [batch_size, ]\n    sm_scale: Optional[float] = None,\n):\n    # dtype check\n    assert q.dtype == torch.bfloat16 or q.dtype == torch.float16\n    assert k.dtype == q.dtype and v.dtype == q.dtype\n    assert seqlens.dtype == torch.int32\n    # shape\n    batch_size, num_q_heads, head_dim = q.shape\n    _, k_len, num_k_heads, head_dim = k.shape\n    _, v_len, num_v_heads, head_dim = v.shape\n    assert k_len == v_len and batch_size == seqlens.shape[0]\n    assert num_k_heads == topk_idx.shape[0] and batch_size == topk_idx.shape[1]\n    topk = topk_idx.shape[-1]\n    # gqa\n    assert num_k_heads == num_v_heads\n    assert num_q_heads % num_k_heads == 0\n    num_share_q_heads = num_q_heads // num_k_heads\n    # sm scale\n    if sm_scale is None:\n        sm_scale = 1 / math.sqrt(head_dim)\n    # mask\n    mask = torch.zeros(\n        (batch_size, num_k_heads, k_len), dtype=torch.bool, device=q.device\n    )\n    for b in range(batch_size):\n        for h in range(num_k_heads):\n            for t in range(topk):\n                if topk_idx[h, b, t] != -1:\n                    mask[\n                        b,\n                        h,\n                        topk_idx[h, b, t]\n                        * block_size : (topk_idx[h, b, t] + 1)\n                        * block_size,\n                    ] = True\n    mask = mask & (\n        (seqlens - 1)[:, 
None, None] >= torch.arange(k_len).cuda()[None, None, :]\n    )\n    mask = mask.repeat_interleave(num_share_q_heads, 1)\n    # attention\n    attn = (\n        torch.einsum(\n            \"bqhd,bkhd->bhqk\", q.unsqueeze(1), k.repeat_interleave(num_share_q_heads, 2)\n        )\n        * sm_scale\n    )\n    attn = attn.masked_fill(~mask.unsqueeze(2), -torch.inf)\n    attn = torch.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)\n    out = torch.einsum(\n        \"bhqk,bkhd->bqhd\", attn, v.repeat_interleave(num_share_q_heads, 2)\n    ).squeeze(1)\n    return out\n\n\ndef generate_topk_idx_example(\n    seqlens: torch.Tensor, block_size: int, topk: int, num_heads: int\n) -> torch.Tensor:\n    batch_size = seqlens.shape[0]\n    num_blocks = torch.ceil(seqlens / block_size).to(torch.int32)\n    topk_idx_all_heads = []\n    for _ in range(num_heads):\n        topk_idx = [\n            torch.randn(1, num_blocks[i], device=\"cuda\")\n            .topk(min(topk, num_blocks[i]), dim=-1)\n            .indices.to(torch.int32)\n            for i in range(batch_size)\n        ]\n        topk_idx = [\n            torch.nn.functional.pad(\n                topk_idx[i], (0, topk - topk_idx[i].shape[-1]), value=topk\n            )\n            for i in range(batch_size)\n        ]\n        topk_idx = torch.cat(topk_idx, dim=0)\n        topk_idx = torch.sort(topk_idx, dim=1).values\n        topk_idx[:, 0] = 0\n        q_idx = seqlens - 1\n        topk_idx[topk_idx > (q_idx // block_size)[:, None]] = -1  # -1 means padding\n        topk_idx_all_heads.append(topk_idx)\n    topk_idx = torch.stack(topk_idx_all_heads, dim=0)\n    return topk_idx\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(42)\n    topk = 16\n    block_size = 64\n    batch_size = 76\n    max_length = 8192\n    seqlens = torch.arange(batch_size, dtype=torch.int32).cuda() * 128 + 1\n    seqlens[seqlens > max_length] = max_length\n    seqlens = seqlens[torch.randn_like(seqlens, 
dtype=torch.float32).argsort(-1)]\n    q = (\n        torch.empty(batch_size, 32, 128, device=\"cuda\")\n        .uniform_(-1, 1)\n        .to(torch.float16)\n    )\n    k = (\n        torch.empty(batch_size, max_length, 4, 128, device=\"cuda\")\n        .uniform_(-1, 1)\n        .to(torch.float16)\n    )\n    v = (\n        torch.empty(batch_size, max_length, 4, 128, device=\"cuda\")\n        .uniform_(-1, 1)\n        .to(torch.float16)\n    )\n    topk_idx = generate_topk_idx_example(seqlens, block_size, topk, 4)\n\n    o1 = torch_topk_sparse_attention_decode(q, k, v, topk_idx, block_size, seqlens)\n    o2 = topk_sparse_attention_decode(q, k, v, topk_idx, block_size, seqlens)\n\n    print(torch.allclose(o1, o2, atol=1e-3, rtol=1e-3))\n    print((o1 - o2).abs().max())\n"
  },
  {
    "path": "native_sparse_attention/ops/triton/utils.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\n\n\ndef is_hopper_gpu():\n    if torch.cuda.is_available():\n        device_capability = torch.cuda.get_device_capability()\n        major, minor = device_capability\n        return major == 9\n    return False\n\n\ndef get_compressed_seqlens(\n    cu_seqlens: torch.Tensor, kernel_size: int, kernel_stride: int\n):\n    # compute seqlens after compression\n    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n    y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1\n    # corner case, if sequence_length < kernel_size, no compression for this sequence\n    y_seqlens[seqlens < kernel_size] = 0\n    y_cu_seqlens = torch.zeros(\n        y_seqlens.shape[0] + 1, dtype=torch.int32, device=cu_seqlens.device\n    )\n    y_cu_seqlens[1:] = torch.cumsum(y_seqlens, dim=0)\n    return y_seqlens, y_cu_seqlens\n\n\ndef get_num_warps_stages(head_dim, block_size, is_hopper_gpu):\n    \"\"\"\n    Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton.\n\n    Args:\n        head_dim (int): Size of the head dimension.\n        block_size (int): Size of the block in the attention matrix.\n        is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU.\n\n    Returns:\n        tuple: (num_warps, num_stages) recommended values.\n    \"\"\"\n    # Determine if head_dim and block_size exceed 64\n    
head_large = head_dim > 64\n    block_large = block_size > 64\n\n    if is_hopper_gpu:\n        # Hopper GPU recommendations\n        if head_large and block_large:\n            num_warps = 8\n            num_stages = 3\n        elif head_large or block_large:\n            num_warps = 4\n            num_stages = 3\n        else:\n            num_warps = 2\n            num_stages = 2\n    else:\n        # Ampere GPU recommendations\n        if head_large and block_large:\n            num_warps = 8\n            num_stages = 3\n        elif head_large or block_large:\n            num_warps = 8\n            num_stages = 3\n        else:\n            num_warps = 2\n            num_stages = 2\n    if head_dim > 128:\n        num_stages = 2\n    return num_warps, num_stages\n"
  },
  {
    "path": "native_sparse_attention/ops/triton/weighted_pool.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\nfrom einops import einsum\nimport torch\nimport triton\nimport triton.language as tl\nfrom native_sparse_attention.ops.triton.utils import get_compressed_seqlens\n\n\n@triton.jit\ndef sliding_pool_fwd_kernel(\n    x_ptr,\n    y_ptr,\n    w_ptr,\n    cu_seqlens,\n    y_cu_seqlens,\n    head_dim,\n    kernel_size,\n    kernel_stride,\n    stride_xn,\n    stride_xh,\n    stride_xd,\n    stride_yn,\n    stride_yh,\n    stride_yd,\n    stride_wh,\n    stride_wk,\n    BLOCK_SIZE_K: tl.constexpr,\n    BLOCK_SIZE_D: tl.constexpr,\n):\n    pid_b = tl.program_id(0)\n    pid_h = tl.program_id(1)\n    pid_k = tl.program_id(2)\n    # get start and len after rmpad\n    x_start = tl.load(cu_seqlens + pid_b)\n    x_len = tl.load(cu_seqlens + pid_b + 1) - x_start\n    y_start = tl.load(y_cu_seqlens + pid_b)\n    y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start\n    if pid_k >= y_len:\n        return\n    if w_ptr is not None:\n        # load w\n        w_ptrs = tl.make_block_ptr(\n            base=w_ptr + pid_h * stride_wh,\n            shape=(kernel_size, 1),\n            strides=(stride_wk, 0),\n            offsets=(0, 0),\n            block_shape=(BLOCK_SIZE_K, 1),\n            order=(0, 1),\n        )\n        w = tl.load(w_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # load x\n    x_ptrs = tl.make_block_ptr(\n        
base=x_ptr + x_start * stride_xn + pid_h * stride_xh,\n        shape=(x_len, head_dim),\n        strides=(stride_xn, stride_xd),\n        offsets=(pid_k * kernel_stride, 0),\n        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D),\n        order=(1, 0),\n    )\n    x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # compute y\n    if w_ptr is not None:\n        y = tl.sum(x * w, axis=0)\n    else:\n        y = tl.sum(x, axis=0) / kernel_size\n    off_d = tl.arange(0, BLOCK_SIZE_D)\n    tl.store(\n        y_ptr + (y_start + pid_k) * stride_yn + pid_h * stride_yh + off_d * stride_yd,\n        y.to(y_ptr.dtype.element_ty),\n        mask=off_d < head_dim,\n    )\n\n\n@triton.jit\ndef sliding_pool_dxdw_kernel(\n    x_ptr,\n    dx_ptr,\n    dy_ptr,\n    w_ptr,\n    dw_ptr,\n    cu_seqlens,\n    y_cu_seqlens,\n    head_dim,\n    kernel_size,\n    kernel_stride,\n    stride_xn,\n    stride_xh,\n    stride_xd,\n    stride_dxn,\n    stride_dxh,\n    stride_dxd,\n    stride_dyn,\n    stride_dyh,\n    stride_dyd,\n    stride_wh,\n    stride_wk,\n    stride_dwh,\n    stride_dwn,\n    stride_dwk,\n    BLOCK_SIZE_K: tl.constexpr,\n    BLOCK_SIZE_D: tl.constexpr,\n):\n    pid_b = tl.program_id(0)\n    pid_h = tl.program_id(1)\n    pid_k = tl.program_id(2)\n    # get start and len after rmpad\n    x_start = tl.load(cu_seqlens + pid_b)\n    x_len = tl.load(cu_seqlens + pid_b + 1) - x_start\n    y_start = tl.load(y_cu_seqlens + pid_b)\n    y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start\n    if pid_k >= y_len:\n        return\n    # offsets\n    off_d = tl.arange(0, BLOCK_SIZE_D)\n    off_k = tl.arange(0, BLOCK_SIZE_K)\n    if w_ptr is not None:\n        # load w\n        w_ptrs = w_ptr + pid_h * stride_wh + off_k * stride_wk\n        w = tl.load(w_ptrs, mask=off_k < kernel_size, other=0)\n    # load x\n    x_ptrs = tl.make_block_ptr(\n        base=x_ptr + x_start * stride_xn + pid_h * stride_xh,\n        shape=(head_dim, x_len),\n        strides=(stride_xd, 
stride_xn),\n        offsets=(0, pid_k * kernel_stride),\n        block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K),\n        order=(0, 1),\n    )\n    x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option=\"zero\")\n    # load dy\n    dy_ptrs = (\n        dy_ptr\n        + pid_h * stride_dyh\n        + (y_start + pid_k) * stride_dyn\n        + off_d * stride_dyd\n    )\n    dy = tl.load(dy_ptrs, mask=off_d < head_dim, other=0)\n    if w_ptr is not None:\n        # compute dx, [1, D] x [K, 1] -> [K, D]\n        dx = dy[None, :] * w[:, None]\n        # compute dw, [D, 1] x [D, K] -> [D, K] -> [K]\n        dw = tl.sum(dy[:, None] * x, axis=0)\n        # store dw\n        dw_ptrs = (\n            dw_ptr\n            + pid_h * stride_dwh\n            + (y_start + pid_k) * stride_dwn\n            + off_k * stride_dwk\n        )\n        tl.store(dw_ptrs, dw.to(dw_ptr.dtype.element_ty), mask=off_k < kernel_size)\n    else:\n        dx = dy[None, :] / kernel_size\n    # store dx\n    dx_ptrs = (\n        dx_ptr\n        + pid_h * stride_dxh\n        + (x_start + pid_k * kernel_stride + off_k[:, None]) * stride_dxn\n        + off_d[None, :] * stride_dxd\n    )\n    tl.atomic_add(\n        dx_ptrs,\n        dx.to(dx_ptr.dtype.element_ty),\n        mask=(off_k < x_len - pid_k * kernel_stride)[:, None]\n        & (off_d < head_dim)[None, :],\n    )\n\n\nclass SlidingWindowWeightedPool(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        x: torch.Tensor,  # [total_len, num_heads, head_dim]\n        w: torch.Tensor,  # [num_heads, kernel_size]\n        cu_seqlens: torch.Tensor,\n        kernel_size: int,\n        kernel_stride: int,\n    ):\n        # dtype check\n        assert x.dtype == torch.float16 or x.dtype == torch.bfloat16\n        if w is not None:\n            assert x.dtype == w.dtype\n        assert cu_seqlens.dtype == torch.int32\n        # shape check\n        total_len, num_heads, head_dim = x.shape\n        batch_size = 
cu_seqlens.shape[0] - 1\n        if w is not None:\n            assert w.shape[0] == num_heads\n            assert w.shape[1] == kernel_size\n        assert kernel_size % kernel_stride == 0\n        assert kernel_size in {16, 32, 64, 128}\n        # compute seqlens after compression\n        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]\n        y_seqlens, y_cu_seqlens = get_compressed_seqlens(\n            cu_seqlens, kernel_size, kernel_stride\n        )\n        # output buffer\n        y = torch.zeros(\n            y_cu_seqlens[-1], num_heads, head_dim, dtype=x.dtype, device=x.device\n        )\n        # launch kernel\n        BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n        BLOCK_SIZE_K = triton.next_power_of_2(kernel_size)\n        grid = (batch_size, num_heads, y_seqlens.max().item())\n        sliding_pool_fwd_kernel[grid](\n            x,\n            y,\n            w,\n            cu_seqlens,\n            y_cu_seqlens,\n            head_dim,\n            kernel_size,\n            kernel_stride,\n            x.stride(0),\n            x.stride(1),\n            x.stride(2),\n            y.stride(0),\n            y.stride(1),\n            y.stride(2),\n            w.stride(0) if w is not None else None,\n            w.stride(1) if w is not None else None,\n            BLOCK_SIZE_K=BLOCK_SIZE_K,\n            BLOCK_SIZE_D=BLOCK_SIZE_D,\n        )\n        ctx.save_for_backward(x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens)\n        ctx.kernel_size = kernel_size\n        ctx.kernel_stride = kernel_stride\n        ctx.head_dim = head_dim\n        return y, y_cu_seqlens\n\n    @staticmethod\n    def backward(ctx, dy, _):\n        x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens = ctx.saved_tensors\n        kernel_size = ctx.kernel_size\n        kernel_stride = ctx.kernel_stride\n        head_dim = ctx.head_dim\n        batch_size = cu_seqlens.shape[0] - 1\n        num_heads = x.shape[1]\n        # compute dx\n        dx = torch.zeros_like(x, 
dtype=torch.float32)\n        if w is not None:\n            dw = torch.zeros(\n                num_heads,\n                y_cu_seqlens[-1],\n                kernel_size,\n                dtype=torch.float32,\n                device=w.device,\n            )\n        BLOCK_SIZE_D = triton.next_power_of_2(head_dim)\n        BLOCK_SIZE_K = triton.next_power_of_2(kernel_size)\n        grid = (batch_size, num_heads, y_seqlens.max().item())\n        sliding_pool_dxdw_kernel[grid](\n            x,\n            dx,\n            dy,\n            w,\n            dw if w is not None else None,\n            cu_seqlens,\n            y_cu_seqlens,\n            head_dim,\n            kernel_size,\n            kernel_stride,\n            x.stride(0),\n            x.stride(1),\n            x.stride(2),\n            dx.stride(0),\n            dx.stride(1),\n            dx.stride(2),\n            dy.stride(0),\n            dy.stride(1),\n            dy.stride(2),\n            w.stride(0) if w is not None else None,\n            w.stride(1) if w is not None else None,\n            dw.stride(0) if w is not None else None,\n            dw.stride(1) if w is not None else None,\n            dw.stride(2) if w is not None else None,\n            BLOCK_SIZE_K=BLOCK_SIZE_K,\n            BLOCK_SIZE_D=BLOCK_SIZE_D,\n        )\n        dx = dx.to(x.dtype)\n        if w is None:\n            dw = None\n        else:\n            dw = dw.sum(1).to(w.dtype)\n        return dx, dw, None, None, None\n\n\ndef weightedpool_compress(\n    x: torch.Tensor,  # [total_len, num_heads, head_dim]\n    w: torch.Tensor,  # [num_heads, kernel_size]\n    cu_seqlens: torch.Tensor,\n    kernel_size: int,\n    kernel_stride: int,\n    pe: Optional[torch.Tensor] = None,\n):\n    \"\"\"Compress key and value tensor with kernel_size and kernel_stride with weighted pooling\n\n    Args:\n        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)\n        w (torch.Tensor): weight for 
each head, shape (num_heads, kernel_size)\n        cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.\n        kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)\n        kernel_stride (int): stride for each compress kernel\n        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). Defaults to None.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.\n    \"\"\"\n    y, y_cu_seqlens = SlidingWindowWeightedPool.apply(\n        x, w, cu_seqlens, kernel_size, kernel_stride\n    )\n    # position embedding as a bias\n    if pe is not None:\n        assert pe.dtype == x.dtype and pe.device == x.device\n        bias = einsum(pe, w, \"h k d, h k -> h d\")\n        y = y + bias.unsqueeze(0)\n    return y, y_cu_seqlens\n\n\ndef avgpool_compress(\n    x: torch.Tensor,  # [total_len, num_heads, head_dim]\n    w: torch.Tensor,  # don't need weight\n    cu_seqlens: torch.Tensor,\n    kernel_size: int,\n    kernel_stride: int,\n    pe: Optional[torch.Tensor] = None,\n):\n    \"\"\"Compress key and value tensor with kernel_size and kernel_stride with average pooling.\n\n    Args:\n        x (torch.Tensor): key_states or value_states, shape (total_len, num_heads, head_dim)\n        w (torch.Tensor): weight for each head, shape (num_heads, kernel_size)\n        cu_seqlens (_type_): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen.\n        kernel_size (int): kernel_size, each (kernel_size, head_dim) blocks will be compressed to (1, head_dim)\n        kernel_stride (int): stride for each compress kernel\n        pe (Optional[torch.Tensor], optional): intra-block positional embedding with shape (num_heads, kernel_size, head_dim). 
Defaults to None.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: compressed states and corresponding cu_seqlens.\n    \"\"\"\n    assert w is None, \"don't need additional weight for avgpool\"\n    y, y_cu_seqlens = SlidingWindowWeightedPool.apply(\n        x, w, cu_seqlens, kernel_size, kernel_stride\n    )\n    # position embedding as a bias\n    if pe is not None:\n        assert pe.dtype == x.dtype and pe.device == x.device\n        bias = torch.mean(pe, dim=1)\n        y = y + bias.unsqueeze(0)\n    return y, y_cu_seqlens\n\n\nif __name__ == \"__main__\":\n    from native_sparse_attention.ops.torch.compress_key_value import (\n        weightedpool_compress_torch,\n    )\n\n    torch.manual_seed(42)\n    num_heads = 4\n    head_dim = 128\n    kernel_size = 32\n    kernel_stride = 16\n    seqlens = torch.LongTensor([12, 1000, 2000, 4096]).int().cuda()\n    cu_seqlens = torch.cat(\n        [\n            torch.zeros(1, dtype=torch.int32, device=\"cuda\"),\n            torch.cumsum(seqlens, dim=0),\n        ],\n        dim=0,\n    ).to(torch.int32)\n\n    x = (\n        torch.zeros(cu_seqlens[-1], num_heads, head_dim)\n        .uniform_(-1, 1)\n        .to(torch.bfloat16)\n        .cuda()\n        .requires_grad_()\n    )\n    w = (\n        torch.zeros(num_heads, kernel_size)\n        .uniform_(-1 / 32, 1 / 32)\n        .cuda()\n        .to(torch.bfloat16)\n        .requires_grad_()\n    )\n\n    y, y_cu_seqlens = weightedpool_compress_torch(\n        x, w, cu_seqlens, kernel_size, kernel_stride, None\n    )\n\n    x1 = x.clone().detach().requires_grad_()\n    w1 = w.clone().detach().requires_grad_()\n\n    y1, y1_cu_seqlens = weightedpool_compress(\n        x1, w1, cu_seqlens, kernel_size, kernel_stride\n    )\n\n    print(torch.allclose(y, y1, rtol=1e-2, atol=1e-2))\n    print(torch.abs(y - y1).max().item())\n\n    randn = torch.randn_like(y)\n    randn1 = randn.clone().detach()\n\n    loss = (y * randn).sum()\n    loss1 = (y1 * randn1).sum()\n    
loss.backward()\n    loss1.backward()\n\n    print((x.grad - x1.grad).abs().max())\n    print(((w.grad - w1.grad).abs() / w.grad.abs()).max())\n"
  },
  {
    "path": "setup.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom setuptools import setup, find_packages\n\n# Read the README.md file for the long description\nwith open(\"README.md\", \"r\", encoding=\"utf-8\") as fh:\n    long_description = fh.read()\n\n# Define the setup configuration\nsetup(\n    name=\"native-sparse-attention-triton\",\n    version=\"0.1.0\",\n    description=\"An efficient implementation of Native Sparse Attention using Triton\",\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n    author=\"XunhaoLai\",\n    author_email=\"laixunhao@pku.edu.cn\",  # Replace with your actual email\n    url=\"https://github.com/XunhaoLai/native-sparse-attention-triton\",\n    packages=find_packages(),\n    install_requires=[\n        \"torch>=2.1.0\",\n        \"triton>=3.0.0\",\n        \"einops>=0.7.0\",\n        \"flash-attn>=2.6.3\",\n        \"transformers>=4.44.0\",\n    ],\n    classifiers=[\n        \"Programming Language :: Python :: 3\",\n        \"License :: OSI Approved :: Apache Software License\",\n        \"Operating System :: OS Independent\",\n    ],\n    python_requires=\">=3.9\",\n)\n"
  },
  {
    "path": "test/test_compress_key_value.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific\nimport torch\nimport triton\nfrom native_sparse_attention.ops import linear_compress\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(42)\n    num_heads = 4\n    head_dim = 192\n    kernel_size = 32\n    kernel_stride = 16\n    seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda()\n    cu_seqlens = torch.cat(\n        [\n            torch.zeros(1, dtype=torch.int32, device=\"cuda\"),\n            torch.cumsum(seqlens, dim=0),\n        ],\n        dim=0,\n    ).to(torch.int32)\n\n    x = (\n        torch.zeros(cu_seqlens[-1], num_heads, head_dim)\n        .uniform_(-1, 1)\n        .cuda()\n        .bfloat16()\n        .requires_grad_()\n    )\n    w = (\n        torch.zeros(num_heads, kernel_size * head_dim, head_dim)\n        .uniform_(-1, 1)\n        .cuda()\n        .bfloat16()\n        .requires_grad_()\n    )\n    pe = (\n        torch.zeros(num_heads, kernel_size, head_dim)\n        .uniform_(-1, 1)\n        .cuda()\n        .bfloat16()\n        .requires_grad_()\n    )\n\n    y, y_cu_seqlens = linear_compress(x, w, cu_seqlens, kernel_size, kernel_stride, pe)\n\n    loss = (y * torch.randn_like(y)).mean()\n    loss.backward()\n\n    print(y.shape, y_cu_seqlens)\n    print(y.norm(), x.grad.norm())\n    print(\n        w.grad.norm() if w.grad is not None else None,\n        pe.grad.norm() if pe.grad is not None else None,\n    )\n\n    # benchmark\n    @triton.testing.perf_report(\n        
triton.testing.Benchmark(\n            x_names=[\"N\"],\n            x_vals=[1024 * 2**i for i in range(1, 6)],\n            line_arg=\"provider\",\n            line_vals=[\"batch1\", \"batch8\", \"batch32\"],\n            line_names=[\"batch1\", \"batch8\", \"batch32\"],\n            styles=[(\"green\", \"-\"), (\"blue\", \"-\"), (\"blue\", \"--\")],\n            ylabel=\"ms\",\n            plot_name=\"** forward **\",\n            args={\"H\": 4, \"D\": 128},\n        )\n    )\n    def benchmark(N, H, D, provider):\n        K, S = 32, 16\n        x = torch.zeros(N, H, D, device=\"cuda\", dtype=torch.bfloat16).uniform_(-1, 1)\n        w = torch.zeros(H, K * D, D, device=\"cuda\", dtype=torch.bfloat16).uniform_(\n            -1, 1\n        )\n        pe = torch.zeros(H, K, D, device=\"cuda\", dtype=torch.bfloat16).uniform_(-1, 1)\n        cu_seqlens_b1 = torch.LongTensor([0, N]).int().cuda()\n        cu_seqlens_b8 = (\n            torch.LongTensor([N // 8 if i > 0 else 0 for i in range(9)]).int().cuda()\n        )\n        cu_seqlens_b32 = (\n            torch.LongTensor([N // 32 if i > 0 else 0 for i in range(33)]).int().cuda()\n        )\n        cu_seqlens_b1 = cu_seqlens_b1.cumsum(0).to(torch.int32)\n        cu_seqlens_b8 = cu_seqlens_b8.cumsum(0).to(torch.int32)\n        cu_seqlens_b32 = cu_seqlens_b32.cumsum(0).to(torch.int32)\n\n        quantiles = [0.5, 0.2, 0.8]\n        if provider == \"batch1\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: linear_compress(x, w, cu_seqlens_b1, K, S, pe),\n                quantiles=quantiles,\n            )\n        if provider == \"batch8\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: linear_compress(x, w, cu_seqlens_b8, K, S, pe),\n                quantiles=quantiles,\n            )\n        if provider == \"batch32\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: linear_compress(x, w, cu_seqlens_b32, K, 
S, pe),\n                quantiles=quantiles,\n            )\n        return ms, min_ms, max_ms\n\n    benchmark.run(show_plots=True, print_data=True)\n"
  },
  {
    "path": "test/test_compressed_attention.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific\nimport torch\nimport triton\nimport math\nfrom native_sparse_attention.ops.torch.compressed_attention import (\n    compressed_attention_torch,\n)\nfrom native_sparse_attention.ops.triton.compressed_attention import (\n    compressed_attention,\n    _compressed_attention_bwd,\n)\nfrom native_sparse_attention.ops import avgpool_compress\nfrom native_sparse_attention.ops.triton.flash_attention import (\n    flash_attention_varlen,\n    _flash_attention_bwd,\n)\nfrom flash_attn import flash_attn_varlen_func\nfrom flash_attn.flash_attn_interface import _flash_attn_varlen_backward\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(42)\n    num_heads = 32\n    head_dim = 96\n    kernel_size = 32\n    kernel_stride = 16\n    block_size = 64\n    topk = 16\n    seqlens = torch.LongTensor([1000, 4000, 8192]).int().cuda()\n    cu_seqlens = torch.cat(\n        [\n            torch.zeros(1, dtype=torch.int32, device=\"cuda\"),\n            torch.cumsum(seqlens, dim=0),\n        ],\n        dim=0,\n    ).to(torch.int32)\n    max_seqlen = seqlens.max().item()\n    q = (\n        torch.empty(cu_seqlens[-1], num_heads, head_dim, device=\"cuda\")\n        .uniform_(-1, 1)\n        .to(torch.float16)\n    )\n    k = (\n        torch.empty(cu_seqlens[-1], num_heads // 4, head_dim, device=\"cuda\")\n        .uniform_(-1, 1)\n        .to(torch.float16)\n    )\n    v = (\n        torch.empty(cu_seqlens[-1], num_heads // 4, head_dim, 
device=\"cuda\")\n        .uniform_(-1, 1)\n        .to(torch.float16)\n    )\n    q.requires_grad = True\n    k.requires_grad = True\n    v.requires_grad = True\n\n    ck, ck_cu_seqlens = avgpool_compress(\n        k, None, cu_seqlens, kernel_size, kernel_stride\n    )\n\n    ck = torch.empty_like(ck).uniform_(-1, 1)\n    cv = torch.empty_like(ck).uniform_(-1, 1)\n    ck.requires_grad = True\n    cv.requires_grad = True\n\n    ck_seqlens = ck_cu_seqlens[1:] - ck_cu_seqlens[:-1]\n    ck_max_seqlen = ck_seqlens.max().item()\n\n    o, topk_idx = compressed_attention_torch(\n        q,\n        ck,\n        cv,\n        kernel_size,\n        kernel_stride,\n        block_size,\n        topk,\n        cu_seqlens,\n        ck_cu_seqlens,\n        max_seqlen,\n        ck_max_seqlen,\n    )\n\n    randn = torch.randn_like(o)\n    loss = (o * randn).sum()\n    loss.backward()\n\n    torch.manual_seed(42)\n\n    q1 = q.detach().clone().requires_grad_()\n    ck1 = ck.detach().clone().requires_grad_()\n    cv1 = cv.detach().clone().requires_grad_()\n\n    o1, topk_idx1 = compressed_attention(\n        q1,\n        ck1,\n        cv1,\n        kernel_size,\n        kernel_stride,\n        block_size,\n        topk,\n        cu_seqlens,\n        ck_cu_seqlens,\n        max_seqlen,\n        ck_max_seqlen,\n    )\n    randn1 = randn.clone().detach()\n    loss1 = (o1 * randn1).sum()\n    loss1.backward()\n\n    print(\"Same Output:\", torch.allclose(o, o1, atol=0.01, rtol=0.01))\n    print(\"Max Error:\", (o - o1).abs().max().item())\n    print()\n    print(\"Same Query Gradient:\", torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01))\n    print(\"Max Query Gradient Error:\", (q.grad - q1.grad).abs().max().item())\n    print()\n    print(\"Same Key Gradient:\", torch.allclose(ck.grad, ck1.grad, atol=0.01, rtol=0.01))\n    print(\"Max Key Gradient Error:\", (ck.grad - ck1.grad).abs().max().item())\n    print()\n    print(\n        \"Same Value Gradient:\", torch.allclose(cv.grad, 
cv1.grad, atol=0.01, rtol=0.01)\n    )\n    print(\"Max Value Gradient Error:\", (cv.grad - cv1.grad).abs().max().item())\n    print()\n\n    # There are some discrepancies in the topk indices (about 3%). These might be due to bugs and will be addressed later.\n    all_num = 0\n    err_num = 0\n    for h in range(topk_idx.shape[0]):\n        for i in range(topk_idx.shape[1]):\n            s = set(topk_idx[h, i][topk_idx[h, i] >= 0].tolist())\n            s1 = set(topk_idx1[h, i][topk_idx1[h, i] >= 0].tolist())\n            all_num += len(s)\n            err_num += len(s) - len(s1 & s)\n    print(\"Topk Idx Error Rate:\", err_num / all_num)\n\n    # benchmark\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"N\"],\n            x_vals=[1024 * 2**i for i in range(1, 8)],\n            line_arg=\"provider\",\n            line_vals=[\n                \"flash\",\n                \"triton-flash\",\n                \"triton-compressed\",\n                \"triton-compressed-wo-score\",\n            ],\n            line_names=[\n                \"Flash\",\n                \"Triton-Flash\",\n                \"Compressed\",\n                \"Compressed-wo-Score\",\n            ],\n            styles=[(\"green\", \"-\"), (\"green\", \"--\"), (\"blue\", \"-\"), (\"blue\", \"--\")],\n            ylabel=\"ms\",\n            plot_name=\"** forward speed for compressed attention (kernel 32 stride 16) **\",\n            args={\"H\": 64, \"D\": 128},\n        )\n    )\n    def benchmark(N, H, D, provider):\n        q = torch.randn((N, H, D), device=\"cuda\", dtype=torch.bfloat16)\n        k = torch.randn((N, H // 16, D), device=\"cuda\", dtype=torch.bfloat16)\n        v = torch.randn((N, H // 16, D), device=\"cuda\", dtype=torch.bfloat16)\n        cu_seqlens = torch.tensor([0, N], device=\"cuda\", dtype=torch.int32)\n        sm_scale = 1 / math.sqrt(D)\n        com_k, com_cu_seqlens = avgpool_compress(k, None, cu_seqlens, 32, 16, None)\n     
   com_v, com_cu_seqlens = avgpool_compress(v, None, cu_seqlens, 32, 16, None)\n        M = (com_cu_seqlens[1:] - com_cu_seqlens[:-1]).max().item()\n\n        quantiles = [0.5, 0.2, 0.8]\n        if provider == \"flash\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: flash_attn_varlen_func(\n                    q,\n                    k,\n                    v,\n                    cu_seqlens,\n                    cu_seqlens,\n                    N,\n                    N,\n                    dropout_p=0.0,\n                    causal=True,\n                    softmax_scale=sm_scale,\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-flash\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: flash_attention_varlen(\n                    q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-compressed\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: compressed_attention(\n                    q,\n                    com_k,\n                    com_v,\n                    32,\n                    16,\n                    64,\n                    16,\n                    cu_seqlens,\n                    com_cu_seqlens,\n                    N,\n                    M,\n                    sm_scale,\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-compressed-wo-score\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: compressed_attention(\n                    q,\n                    com_k,\n                    com_v,\n                    32,\n                    16,\n                    64,\n                    -1,\n                    cu_seqlens,\n                    com_cu_seqlens,\n                    N,\n      
              M,\n                    sm_scale,\n                ),\n                quantiles=quantiles,\n            )\n        return ms, min_ms, max_ms\n\n    benchmark.run(show_plots=True, print_data=True)\n\n    # benchmark\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"N\"],\n            x_vals=[1024 * 2**i for i in range(1, 8)],\n            line_arg=\"provider\",\n            line_vals=[\n                \"flash\",\n                \"triton-flash\",\n                \"triton-compressed\",\n            ],\n            line_names=[\n                \"Flash\",\n                \"Triton-Flash\",\n                \"Compressed\",\n            ],\n            styles=[(\"green\", \"-\"), (\"green\", \"--\"), (\"blue\", \"-\"), (\"blue\", \"--\")],\n            ylabel=\"ms\",\n            plot_name=\"** backward speed for compressed attention (kernel 32 stride 16) **\",\n            args={\"H\": 64, \"D\": 128},\n        )\n    )\n    def benchmark(N, H, D, provider):\n        q = torch.randn((N, H, D), device=\"cuda\", dtype=torch.bfloat16)\n        k = torch.randn((N, H // 16, D), device=\"cuda\", dtype=torch.bfloat16)\n        v = torch.randn((N, H // 16, D), device=\"cuda\", dtype=torch.bfloat16)\n        o = torch.randn((N, H, D), device=\"cuda\", dtype=torch.bfloat16)\n        do = torch.randn((N, H, D), device=\"cuda\", dtype=torch.bfloat16)\n        lse = torch.randn((H, N), device=\"cuda\", dtype=torch.float32)\n        sm_scale = 1 / math.sqrt(D)\n        cu_seqlens = torch.tensor([0, N], device=\"cuda\", dtype=torch.int32)\n        dq = torch.zeros_like(q)\n        dk = torch.zeros_like(k)\n        dv = torch.zeros_like(v)\n\n        com_k, com_cu_seqlens = avgpool_compress(k, None, cu_seqlens, 32, 16, None)\n        com_v, com_cu_seqlens = avgpool_compress(v, None, cu_seqlens, 32, 16, None)\n        M = (com_cu_seqlens[1:] - com_cu_seqlens[:-1]).max().item()\n\n        quantiles = [0.5, 0.2, 0.8]\n        if 
provider == \"flash\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _flash_attn_varlen_backward(\n                    do,\n                    q,\n                    k,\n                    v,\n                    o,\n                    lse.transpose(0, 1),\n                    dq,\n                    dk,\n                    dv,\n                    cu_seqlens,\n                    cu_seqlens,\n                    N,\n                    N,\n                    dropout_p=0.0,\n                    causal=True,\n                    softmax_scale=sm_scale,\n                    window_size=(-1, -1),\n                    softcap=0.0,\n                    alibi_slopes=None,\n                    deterministic=False,\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-flash\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _flash_attention_bwd(\n                    o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-compressed\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _compressed_attention_bwd(\n                    o,\n                    do,\n                    lse,\n                    q,\n                    com_k,\n                    com_v,\n                    32,\n                    16,\n                    cu_seqlens,\n                    com_cu_seqlens,\n                    N,\n                    M,\n                    sm_scale,\n                ),\n                quantiles=quantiles,\n            )\n        return ms, min_ms, max_ms\n\n    benchmark.run(show_plots=True, print_data=True)\n"
  },
  {
    "path": "test/test_flash_attention.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific\nimport torch\nimport triton\nimport math\nfrom native_sparse_attention.ops.triton.flash_attention import (\n    flash_attention_varlen,\n    _flash_attention_fwd,\n    _flash_attention_bwd,\n)\nfrom flash_attn import flash_attn_varlen_func\nfrom flash_attn.flash_attn_interface import (\n    _flash_attn_varlen_forward,\n    _flash_attn_varlen_backward,\n)\n\n\nif __name__ == \"__main__\":\n    for causal in [False, True]:\n        # triton flash attention\n        torch.manual_seed(42)\n        q = torch.randn(\n            1000, 32, 128, dtype=torch.float16, device=\"cuda\", requires_grad=True\n        )\n        k = torch.randn(\n            1000, 16, 128, dtype=torch.float16, device=\"cuda\", requires_grad=True\n        )\n        v = torch.randn(\n            1000, 16, 128, dtype=torch.float16, device=\"cuda\", requires_grad=True\n        )\n        cu_seqlens_q = torch.Tensor([0, 100, 384, 1000]).cuda().to(torch.int32)\n        cu_seqlens_k = torch.Tensor([0, 100, 384, 1000]).cuda().to(torch.int32)\n        max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max()\n        max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max()\n        o = flash_attn_varlen_func(\n            q,\n            k,\n            v,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            max_seqlen_q,\n            max_seqlen_k,\n            causal=causal,\n        )\n        randn = torch.randn_like(o)\n        loss = (o * 
randn).sum()\n        loss.backward()\n\n        # flash attention\n        torch.manual_seed(42)\n        q1 = q.clone().detach().requires_grad_()\n        k1 = k.clone().detach().requires_grad_()\n        v1 = v.clone().detach().requires_grad_()\n        cu_seqlens_q1 = cu_seqlens_q.clone().detach()\n        cu_seqlens_k1 = cu_seqlens_k.clone().detach()\n        max_seqlen_q1 = (cu_seqlens_q1[1:] - cu_seqlens_q1[:-1]).max()\n        max_seqlen_k1 = (cu_seqlens_k1[1:] - cu_seqlens_k1[:-1]).max()\n        o1 = flash_attention_varlen(\n            q1,\n            k1,\n            v1,\n            cu_seqlens_q1,\n            cu_seqlens_k1,\n            max_seqlen_q1,\n            max_seqlen_k1,\n            causal=causal,\n        )\n        randn2 = randn.clone().detach()\n        loss2 = (o1 * randn2).sum()\n        loss2.backward()\n\n        # diff\n        print(\n            f\"=== Flash Attention Backward Test ({'causal' if causal else 'full'}) ===\"\n        )\n        print(\"Same Output:\", torch.allclose(o, o1, atol=0.01, rtol=0.01))\n        print(\"Max Error:\", (o - o1).abs().max().item())\n        print()\n        print(\n            \"Same Query Gradient:\",\n            torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01),\n        )\n        print(\"Max Query Gradient Error:\", (q.grad - q1.grad).abs().max().item())\n        print()\n        print(\n            \"Same Key Gradient:\", torch.allclose(k.grad, k1.grad, atol=0.01, rtol=0.01)\n        )\n        print(\"Max Key Gradient Error:\", (k.grad - k1.grad).abs().max().item())\n        print()\n        print(\n            \"Same Value Gradient:\",\n            torch.allclose(v.grad, v1.grad, atol=0.01, rtol=0.01),\n        )\n        print(\"Max Value Gradient Error:\", (v.grad - v1.grad).abs().max().item())\n        print()\n\n    # benchmark\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"N\"],\n            x_vals=[1024 * 2**i for i in range(1, 
6)],\n            line_arg=\"provider\",\n            line_vals=[\"flash\", \"triton-flash\"],\n            line_names=[\n                \"Flash\",\n                \"Triton-Flash\",\n            ],\n            styles=[(\"green\", \"-\"), (\"green\", \"--\")],\n            ylabel=\"ms\",\n            plot_name=\"** forward **\",\n            args={\"H\": 64, \"D\": 128},\n        )\n    )\n    def benchmark(N, H, D, provider):\n        q = torch.randn((N, H, D), device=\"cuda\", dtype=torch.bfloat16)\n        k = torch.randn((N, H // 16, D), device=\"cuda\", dtype=torch.bfloat16)\n        v = torch.randn((N, H // 16, D), device=\"cuda\", dtype=torch.bfloat16)\n        cu_seqlens = torch.tensor([0, N], device=\"cuda\", dtype=torch.int32)\n        sm_scale = 1 / math.sqrt(D)\n\n        quantiles = [0.5, 0.2, 0.8]\n        if provider == \"flash\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _flash_attn_varlen_forward(\n                    q,\n                    k,\n                    v,\n                    cu_seqlens,\n                    cu_seqlens,\n                    N,\n                    N,\n                    dropout_p=0.0,\n                    causal=True,\n                    softmax_scale=sm_scale,\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-flash\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _flash_attention_fwd(\n                    q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale\n                ),\n                quantiles=quantiles,\n            )\n        return ms, min_ms, max_ms\n\n    benchmark.run(show_plots=True, print_data=True)\n\n    # benchmark\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"N\"],\n            x_vals=[1024 * 2**i for i in range(1, 6)],\n            line_arg=\"provider\",\n            line_vals=[\"flash\", \"triton-flash\"],\n   
         line_names=[\n                \"Flash\",\n                \"Triton-Flash\",\n            ],\n            styles=[(\"green\", \"-\"), (\"green\", \"--\")],\n            ylabel=\"ms\",\n            plot_name=\"** backward **\",\n            args={\"H\": 64, \"D\": 128},\n        )\n    )\n    def benchmark(N, H, D, provider):\n        q = torch.randn((N, H, D), device=\"cuda\", dtype=torch.bfloat16)\n        k = torch.randn((N, H // 16, D), device=\"cuda\", dtype=torch.bfloat16)\n        v = torch.randn((N, H // 16, D), device=\"cuda\", dtype=torch.bfloat16)\n        o = torch.randn((N, H, D), device=\"cuda\", dtype=torch.bfloat16)\n        do = torch.randn((N, H, D), device=\"cuda\", dtype=torch.bfloat16)\n        lse = torch.randn((H, N), device=\"cuda\", dtype=torch.float32)\n        sm_scale = 1 / math.sqrt(D)\n        cu_seqlens = torch.tensor([0, N], device=\"cuda\", dtype=torch.int32)\n        dq = torch.zeros_like(q)\n        dk = torch.zeros_like(k)\n        dv = torch.zeros_like(v)\n\n        quantiles = [0.5, 0.2, 0.8]\n        if provider == \"flash\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _flash_attn_varlen_backward(\n                    do,\n                    q,\n                    k,\n                    v,\n                    o,\n                    lse.transpose(0, 1),\n                    dq,\n                    dk,\n                    dv,\n                    cu_seqlens,\n                    cu_seqlens,\n                    N,\n                    N,\n                    dropout_p=0.0,\n                    causal=True,\n                    softmax_scale=sm_scale,\n                    window_size=(-1, -1),\n                    softcap=0.0,\n                    alibi_slopes=None,\n                    deterministic=False,\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-flash\":\n            ms, min_ms, max_ms = 
triton.testing.do_bench(\n                lambda: _flash_attention_bwd(\n                    o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale\n                ),\n                quantiles=quantiles,\n            )\n        return ms, min_ms, max_ms\n\n    benchmark.run(show_plots=True, print_data=True)\n"
  },
  {
    "path": "test/test_kv_cache.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom native_sparse_attention.module.kv_cache import NSACache\n\n\nif __name__ == \"__main__\":\n    from native_sparse_attention.ops import avgpool_compress\n\n    torch.manual_seed(42)\n\n    num_heads = 4\n    head_dim = 128\n    seqlens = torch.tensor([12, 576, 12000]).to(torch.int32).cuda()\n    batch_size = seqlens.shape[0]\n    cu_seqlens = torch.zeros(seqlens.shape[0] + 1, dtype=torch.int32, device=\"cuda\")\n    cu_seqlens[1:] = seqlens.cumsum(0)\n\n    # init cache\n    cache = NSACache(4, 16384, num_heads, head_dim, 32, 16, 512, torch.bfloat16, \"cuda\")\n\n    # test prefill\n    step = 0\n    k = torch.randn(cu_seqlens[-1], num_heads, head_dim).cuda().bfloat16()\n    v = torch.randn_like(k)\n    ck, _ = avgpool_compress(k, None, cu_seqlens, 32, 16, None)\n    cv, _ = avgpool_compress(v, None, cu_seqlens, 32, 16, None)\n    cache.prepare_compress(cu_seqlens, step, k, v)\n    cache.update_kv(cu_seqlens, step, ck, cv, k, v, k, v)\n\n    # test decode\n    step = 1\n    k = torch.randn(batch_size, num_heads, head_dim).cuda().bfloat16()\n    v = torch.randn_like(k)\n    ck = torch.randn_like(k)\n    cv = torch.randn_like(v)\n    cache.prepare_compress(cu_seqlens, step, k, v)\n    cache.update_kv(cu_seqlens, step, ck, cv, k, v, k, v)\n"
  },
  {
    "path": "test/test_linear_compress.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport triton\nfrom native_sparse_attention.ops.torch.compress_key_value import linear_compress_torch\nfrom native_sparse_attention.ops.triton.linear_compress import linear_compress\n\n\ndef test_linear_compress(\n    batch_size: int = 1,\n    num_heads: int = 1,\n    head_dim: int = 32,\n    max_seqlen: int = 32,\n    kernel_sizes: list = [16, 32],\n    kernel_strides: list = [8, 16],\n    use_pe: bool = True,\n    dtype: torch.dtype = torch.float32,\n    device: str = \"cuda\",\n):\n    \"\"\"\n    Test both PyTorch and Triton implementations of linear_compress for equivalence,\n    including forward and backward passes.\n\n    Args:\n        batch_size: Number of sequences in the batch\n        num_heads: Number of attention heads\n        head_dim: Dimension of each attention head\n        max_seqlen: Maximum sequence length\n        kernel_sizes: List of kernel sizes to test\n        kernel_strides: List of kernel strides to test\n        use_pe: Whether to test with positional encoding\n        dtype: Data type for tensors\n        device: Device to run the test on\n    \"\"\"\n    torch.manual_seed(42)\n\n    # Generate random sequence lengths for each batch\n\n    seqlens = torch.randint(\n        low=kernel_sizes[0],  # minimum length should be at least kernel_size\n        high=max_seqlen + 1,\n        size=(batch_size,),\n       
 device=device,\n    )\n    # seqlens[:] = max_seqlen\n    cu_seqlens = torch.cat(\n        [\n            torch.tensor([0], device=device, dtype=torch.int32),\n            torch.cumsum(seqlens, dim=0).to(torch.int32),\n        ]\n    )\n\n    total_len = cu_seqlens[-1].item()\n\n    for kernel_size, kernel_stride in zip(kernel_sizes, kernel_strides):\n        print(f\"\\nTesting kernel_size={kernel_size}, kernel_stride={kernel_stride}\")\n\n        # Create input tensors with requires_grad=True\n        x_torch = torch.zeros(\n            (total_len, num_heads, head_dim),\n            dtype=dtype,\n            device=device,\n        ).uniform_(-1, 1)\n        x_torch.requires_grad_(True)\n\n        x_triton = x_torch.clone().detach().requires_grad_(True)\n\n        w_torch = (\n            torch.ones(\n                (num_heads, kernel_size * head_dim, head_dim),\n                dtype=dtype,\n                device=device,\n            )\n            / kernel_size\n        )\n        w_torch.requires_grad_(True)\n\n        w_triton = w_torch.clone().detach().requires_grad_(True)\n\n        pe_torch = None\n        pe_triton = None\n        if use_pe:\n            pe_torch = torch.randn(\n                (num_heads, kernel_size, head_dim),\n                dtype=dtype,\n                device=device,\n                requires_grad=True,\n            )\n            pe_triton = pe_torch.clone().detach().requires_grad_(True)\n\n        # Run forward passes\n        y_torch, y_cu_seqlens_torch = linear_compress_torch(\n            x=x_torch,\n            w=w_torch,\n            cu_seqlens=cu_seqlens,\n            kernel_size=kernel_size,\n            kernel_stride=kernel_stride,\n            pe=pe_torch,\n        )\n\n        y_triton, y_cu_seqlens_triton = linear_compress(\n            x=x_triton,\n            w=w_triton,\n            cu_seqlens=cu_seqlens,\n            kernel_size=kernel_size,\n            kernel_stride=kernel_stride,\n            pe=pe_triton,\n  
      )\n\n        # Check forward pass numerical equivalence\n        atol, rtol = 1e-2, 1e-2\n        values_match = torch.allclose(y_torch, y_triton, atol=atol, rtol=rtol)\n        print(\n            f\"Forward pass - Output values match (atol={atol}, rtol={rtol}): {values_match}\"\n        )\n        if not values_match:\n            max_diff = (y_torch - y_triton).abs().max().item()\n            print(f\"Forward pass - Maximum difference: {max_diff}\")\n            print(\"\\nSample values (first batch, first head):\")\n            print(\"Torch:\", y_torch[0, 0, :5])\n            print(\"Triton:\", y_triton[0, 0, :5])\n\n        # Create random output gradients for backward pass\n        grad_output = torch.randn_like(y_torch)\n\n        # Run backward passes\n        y_torch.backward(grad_output)\n        y_triton.backward(grad_output)\n\n        # Check gradient equivalence\n        print(\"\\nTesting backward pass:\")\n\n        # Check x gradients\n        x_grads_match = torch.allclose(\n            x_torch.grad, x_triton.grad, atol=atol, rtol=rtol\n        )\n        print(f\"x gradients match (atol={atol}, rtol={rtol}): {x_grads_match}\")\n        if not x_grads_match:\n            max_diff = (x_torch.grad - x_triton.grad).abs().max().item()\n            print(f\"x gradients - Maximum difference: {max_diff}\")\n            print(\"\\nSample x gradients (first batch, first head):\")\n            print(\"Torch:\", x_torch.grad[0, 0, :5])\n            print(\"Triton:\", x_triton.grad[0, 0, :5])\n\n        # Check w gradients\n        w_grads_match = torch.allclose(\n            w_torch.grad, w_triton.grad, atol=atol, rtol=rtol\n        )\n        print(f\"w gradients match (atol={atol}, rtol={rtol}): {w_grads_match}\")\n        if not w_grads_match:\n            max_diff = (w_torch.grad - w_triton.grad).abs().max().item()\n            print(f\"w gradients - Maximum difference: {max_diff}\")\n            print(\"\\nSample w gradients (first head):\")\n    
        print(\"Torch:\", w_torch.grad[0, :5, 0])\n            print(\"Triton:\", w_triton.grad[0, :5, 0])\n\n        # Check pe gradients if used\n        if use_pe:\n            pe_grads_match = torch.allclose(\n                pe_torch.grad, pe_triton.grad, atol=atol, rtol=rtol\n            )\n            print(f\"pe gradients match (atol={atol}, rtol={rtol}): {pe_grads_match}\")\n            if not pe_grads_match:\n                max_diff = (pe_torch.grad - pe_triton.grad).abs().max().item()\n                print(f\"pe gradients - Maximum difference: {max_diff}\")\n                print(\"\\nSample pe gradients (first head):\")\n                print(\"Torch:\", pe_torch.grad[0, :5, 0])\n                print(\"Triton:\", pe_triton.grad[0, :5, 0])\n\n        # Clean up gradients for next iteration\n        x_torch.grad = None\n        x_triton.grad = None\n        w_torch.grad = None\n        w_triton.grad = None\n        if use_pe:\n            pe_torch.grad = None\n            pe_triton.grad = None\n\n\nif __name__ == \"__main__\":\n    # Run tests\n    test_linear_compress(\n        batch_size=16,\n        num_heads=8,\n        head_dim=128,\n        max_seqlen=2048,\n        kernel_sizes=[32],\n        kernel_strides=[16],\n        use_pe=False,\n        dtype=torch.float16,\n        device=\"cuda\",\n    )\n\n    # benchmark\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"N\"],\n            x_vals=[1024 * 2**i for i in range(1, 8)],\n            line_arg=\"provider\",\n            line_vals=[\"torch\", \"triton\"],\n            line_names=[\"torch\", \"triton\"],\n            styles=[(\"green\", \"-\"), (\"blue\", \"-\")],\n            ylabel=\"ms\",\n            plot_name=\"** forward + backward **\",\n            args={\"H\": 4, \"D\": 64},\n        )\n    )\n    def benchmark_fwdbwd(N, H, D, provider):\n        K, S = 32, 16\n        # Input tensors\n        x = torch.zeros(N, H, D, device=\"cuda\", 
dtype=torch.bfloat16).uniform_(-1, 1)\n        x.requires_grad = True\n        w = torch.zeros(H, K * D, D, device=\"cuda\", dtype=torch.bfloat16).uniform_(\n            -1, 1\n        )\n        w.requires_grad = True\n        pe = torch.zeros(H, K, D, device=\"cuda\", dtype=torch.bfloat16).uniform_(-1, 1)\n        cu_seqlens_b32 = (\n            torch.LongTensor(\n                [0 if i == 0 else 32 if i > 1 else N - 32 * 31 for i in range(33)]\n            )\n            .int()\n            .cuda()\n        )\n        cu_seqlens_b32 = cu_seqlens_b32.cumsum(0).to(torch.int32)\n\n        quantiles = [0.5, 0.2, 0.8]\n\n        def fwd_bwd():\n            if provider == \"torch\":\n                out, _ = linear_compress_torch(x, w, cu_seqlens_b32, K, S, pe)\n            else:\n                out, _ = linear_compress(x, w, cu_seqlens_b32, K, S, pe)\n            out.backward(out)  # Using output as gradient for simplicity\n            return out\n\n        ms, min_ms, max_ms = triton.testing.do_bench(fwd_bwd, quantiles=quantiles)\n        return ms, min_ms, max_ms\n\n    benchmark_fwdbwd.run(show_plots=True, print_data=True)\n"
  },
  {
    "path": "test/test_nsa_infer.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom native_sparse_attention.ops import linear_compress, weightedpool_compress\nfrom native_sparse_attention.module import NSACache, RotaryEmbedding, RopeConfig\nfrom native_sparse_attention.infer import nsa_infer\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(42)\n\n    num_heads = 4\n    head_dim = 128\n    kernel_size = 32\n    kernel_stride = 16\n    block_size = 64\n    window_size = 512\n    topk = 16\n    init_blocks = 1\n    local_blocks = 2\n\n    # init seqlens\n    seqlens = torch.tensor([12, 576, 12000]).to(torch.int32).cuda()\n    batch_size = seqlens.shape[0]\n    cu_seqlens = torch.zeros(seqlens.shape[0] + 1, dtype=torch.int32, device=\"cuda\")\n    cu_seqlens[1:] = seqlens.cumsum(0)\n    step = 0\n\n    # init cache and weight and rope\n    cache = NSACache(4, 16384, num_heads, head_dim, 32, 16, 512, torch.bfloat16, \"cuda\")\n    compress_weight = [\n        torch.ones(num_heads, kernel_size * head_dim, head_dim).cuda().bfloat16()\n        / (kernel_size * head_dim),\n        torch.ones(num_heads, kernel_size).cuda().bfloat16() / kernel_size,\n    ]\n    compress_func = [linear_compress, weightedpool_compress]\n    rope = RotaryEmbedding(\n        RopeConfig(\n            max_position_embeddings=131072,\n            head_dim=128,\n            rope_theta=500000,\n            rope_scaling={\n                
\"factor\": 8.0,\n                \"high_freq_factor\": 4.0,\n                \"low_freq_factor\": 1.0,\n                \"original_max_position_embeddings\": 8192,\n                \"rope_type\": \"llama3\",\n            },\n        )\n    )\n\n    # test prefill\n    q = torch.randn(cu_seqlens[-1], num_heads * 16, head_dim).cuda().bfloat16()\n    k = torch.randn(cu_seqlens[-1], num_heads, head_dim).cuda().bfloat16()\n    v = torch.randn_like(k)\n    g = torch.rand(cu_seqlens[-1], num_heads * 16, 3).cuda().bfloat16()\n    o = nsa_infer(\n        cu_seqlens,\n        step,\n        q,\n        k,\n        v,\n        g,\n        rope,\n        cache,\n        compress_weight,\n        compress_func,\n        None,\n        kernel_size,\n        kernel_stride,\n        block_size,\n        topk,\n        init_blocks,\n        local_blocks,\n        window_size,\n    )\n    print(o.shape, o.norm())\n\n    # test decode\n    q = torch.randn(cu_seqlens.shape[0] - 1, num_heads * 16, head_dim).cuda().bfloat16()\n    k = torch.randn(cu_seqlens.shape[0] - 1, num_heads, head_dim).cuda().bfloat16()\n    v = torch.randn_like(k)\n    g = torch.rand(cu_seqlens.shape[0] - 1, num_heads * 16, 3).cuda().bfloat16()\n    step = 1\n    o = nsa_infer(\n        cu_seqlens,\n        step,\n        q,\n        k,\n        v,\n        g,\n        rope,\n        cache,\n        compress_weight,\n        compress_func,\n        None,\n        kernel_size,\n        kernel_stride,\n        block_size,\n        topk,\n        init_blocks,\n        local_blocks,\n        window_size,\n    )\n    print(o.shape, o.norm())\n"
  },
  {
    "path": "test/test_nsa_model.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom native_sparse_attention.model import (\n    ToyNSALlamaConfig,\n    InferenceConfig,\n    ToyNSALlama,\n)\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(42)\n    # initialize model\n    config = ToyNSALlamaConfig(\n        hidden_size=4096,\n        intermediate_size=14336,\n        num_hidden_layers=8,\n        num_attention_heads=32,\n        num_key_value_heads=2,\n        head_dim=128,\n        rope_theta=500000.0,\n        rope_scaling={\n            \"factor\": 8.0,\n            \"high_freq_factor\": 4.0,\n            \"low_freq_factor\": 1.0,\n            \"original_max_position_embeddings\": 8192,\n            \"rope_type\": \"llama3\",\n        },\n        compress_type=\"weightedpool\",\n        kernel_size=32,\n        kernel_stride=16,\n        block_size=64,\n        topk=8,\n        init_blocks=1,\n        local_blocks=2,\n        window_size=512,\n    )\n    inference_config = InferenceConfig(\n        max_batch_size=4,\n        max_length=8192,\n        max_new_tokens=128,\n    )\n    model = ToyNSALlama(config, inference_config).cuda().bfloat16()\n    print(f\"\\nMODEL CONFIG:\\n{config}\\n\")\n    print(f\"\\nINFERENCE CONFIG:\\n{inference_config}\\n\")\n    print(f\"\\nMODEL:\\n{model}\\n\")\n\n    # example input\n    batch_size = 4\n    seqlens = torch.randint(0, 4096, (batch_size,), dtype=torch.int32, 
device=\"cuda\")\n    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=\"cuda\")\n    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)\n    input_ids = torch.randint(\n        0, 128288, (cu_seqlens[-1],), dtype=torch.int64, device=\"cuda\"\n    )\n    print(f\"\\nEXAMPLE INPUT:\\ncu_seqlens: {cu_seqlens}\\ninput_ids: {input_ids.shape}\\n\")\n\n    # example output\n    logits = model(input_ids, cu_seqlens)\n    print(f\"\\nEXAMPLE OUTPUT:\\nlogits: {logits.shape}\\n\")\n\n    # example generate\n    output_tokens = model.generate(input_ids, cu_seqlens, 64)\n    print(f\"\\nEXAMPLE GENERATE:\\noutput_tokens: {output_tokens}\\n\")\n"
  },
  {
    "path": "test/test_nsa_module.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nimport triton\nfrom native_sparse_attention.module import (\n    SelfAttention,\n    NativeSparseAttention,\n    RopeConfig,\n)\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(42)\n    NSA = (\n        NativeSparseAttention(\n            compress_type=\"avgpool\",\n            hidden_size=8192,\n            num_q_heads=64,\n            num_kv_heads=4,\n            head_dim=128,\n            kernel_size=32,\n            kernel_stride=16,\n            block_size=64,\n            topk=16,\n            init_blocks=1,\n            local_blocks=2,\n            window_size=512,\n            rope_config=RopeConfig(\n                max_position_embeddings=131072,\n                head_dim=128,\n                rope_theta=500000,\n                rope_scaling={\n                    \"factor\": 8.0,\n                    \"high_freq_factor\": 4.0,\n                    \"low_freq_factor\": 1.0,\n                    \"original_max_position_embeddings\": 8192,\n                    \"rope_type\": \"llama3\",\n                },\n            ),\n        )\n        .cuda()\n        .to(torch.bfloat16)\n    )\n    print(\"======= Init Moduel: Native Sparse Attention =======\\n\")\n    for name, param in NSA.named_parameters():\n        print(f\"NSA Parameters, {name}, shape: {param.shape}\\n\")\n\n    # random input\n    seqlens = torch.LongTensor([4000, 8192, 16384]).int().cuda()\n    cu_seqlens = torch.cat(\n        
[\n            torch.zeros(1, dtype=torch.int32, device=\"cuda\"),\n            torch.cumsum(seqlens, dim=0),\n        ],\n        dim=0,\n    ).to(torch.int32)\n    x = torch.zeros(cu_seqlens[-1], 8192, device=\"cuda\", dtype=torch.bfloat16).uniform_(\n        -1, 1\n    )\n\n    # forward test\n    print(\"======= NSA Forward & Backward Test =======\\n\")\n    y = NSA(x, cu_seqlens)\n    print(f\"Forward, output shape: {y.shape}, output norm: {y.norm()}\\n\")\n\n    # backward test\n    loss = (y * torch.randn_like(y)).sum(-1).mean()\n    loss.backward()\n    for name, param in NSA.named_parameters():\n        print(\n            f\"Backward, {name}, grad shape: {param.grad.shape}, grad norm: {param.grad.norm()}\\n\"\n        )\n\n    # speed benchmark\n    SelfAttn = (\n        SelfAttention(\n            hidden_size=8192,\n            num_q_heads=64,\n            num_kv_heads=4,\n            head_dim=128,\n            rope_config=RopeConfig(\n                max_position_embeddings=131072,\n                head_dim=128,\n                rope_theta=500000,\n                rope_scaling={\n                    \"factor\": 8.0,\n                    \"high_freq_factor\": 4.0,\n                    \"low_freq_factor\": 1.0,\n                    \"original_max_position_embeddings\": 8192,\n                    \"rope_type\": \"llama3\",\n                },\n            ),\n        )\n        .cuda()\n        .to(torch.bfloat16)\n    )\n\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"N\"],\n            x_vals=[1024 * 2**i for i in range(1, 8)],\n            line_arg=\"provider\",\n            line_vals=[\"Self-Attention\", \"Native-Sparse-Attention\"],\n            line_names=[\"Self-Attention\", \"Native-Sparse-Attention\"],\n            styles=[(\"green\", \"-\"), (\"blue\", \"-\")],\n            ylabel=\"ms\",\n            plot_name=\"** NSA forward speed benchmark **\",\n            args={},\n        )\n    )\n    def 
benchmark(N, provider):\n        x = torch.randn(N, 8192, device=\"cuda\", dtype=torch.bfloat16)\n        cu_seqlens = torch.tensor([0, N], device=\"cuda\", dtype=torch.int32)\n        quantiles = [0.5, 0.2, 0.8]\n        with torch.no_grad():\n            if provider == \"Self-Attention\":\n                ms, min_ms, max_ms = triton.testing.do_bench(\n                    lambda: SelfAttn(x, cu_seqlens),\n                    quantiles=quantiles,\n                )\n            if provider == \"Native-Sparse-Attention\":\n                ms, min_ms, max_ms = triton.testing.do_bench(\n                    lambda: NSA(x, cu_seqlens),\n                    quantiles=quantiles,\n                )\n        return ms, min_ms, max_ms\n\n    benchmark.run(show_plots=True, print_data=True)\n\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"N\"],\n            x_vals=[1024 * 2**i for i in range(1, 8)],\n            line_arg=\"provider\",\n            line_vals=[\"Self-Attention\", \"Native-Sparse-Attention\"],\n            line_names=[\"Self-Attention\", \"Native-Sparse-Attention\"],\n            styles=[(\"green\", \"-\"), (\"blue\", \"-\")],\n            ylabel=\"ms\",\n            plot_name=\"** NSA backward speed benchmark **\",\n            args={},\n        )\n    )\n    def benchmark(N, provider):\n        x = torch.randn(N, 8192, device=\"cuda\", dtype=torch.bfloat16)\n        cu_seqlens = torch.tensor([0, N], device=\"cuda\", dtype=torch.int32)\n        quantiles = [0.5, 0.2, 0.8]\n        if provider == \"Self-Attention\":\n            loss = SelfAttn(x.clone().detach().requires_grad_(), cu_seqlens).mean()\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: loss.backward(retain_graph=True),\n                quantiles=quantiles,\n            )\n        elif provider == \"Native-Sparse-Attention\":\n            loss = NSA(x.clone().detach().requires_grad_(), cu_seqlens).mean()\n            ms, 
min_ms, max_ms = triton.testing.do_bench(\n                lambda: loss.backward(retain_graph=True),\n                quantiles=quantiles,\n            )\n        return ms, min_ms, max_ms\n\n    benchmark.run(show_plots=True, print_data=True)\n"
  },
  {
    "path": "test/test_rope.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom native_sparse_attention.module import RopeConfig, RotaryEmbedding\n\nif __name__ == \"__main__\":\n    rope_config = RopeConfig(\n        max_position_embeddings=131072,\n        head_dim=128,\n        rope_theta=500000,\n        rope_scaling={\n            \"factor\": 8.0,\n            \"high_freq_factor\": 4.0,\n            \"low_freq_factor\": 1.0,\n            \"original_max_position_embeddings\": 8192,\n            \"rope_type\": \"llama3\",\n        },\n    )\n    rope = RotaryEmbedding(rope_config, \"cuda\")\n\n    # random input\n    torch.manual_seed(42)\n    seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda()\n    cu_seqlens = torch.cat(\n        [\n            torch.zeros(1, dtype=torch.int32, device=\"cuda\"),\n            torch.cumsum(seqlens, dim=0),\n        ],\n        dim=0,\n    ).to(torch.int32)\n    x = torch.zeros(\n        cu_seqlens[-1], 32, 128, device=\"cuda\", dtype=torch.bfloat16\n    ).uniform_(-1, 1)\n    y = rope(x, cu_seqlens)\n"
  },
  {
    "path": "test/test_topk_sparse_attention.py",
    "content": "# Copyright 2025 Xunhao Lai & Jianqiao Lu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nimport triton\nimport math\nfrom native_sparse_attention.ops.torch.topk_sparse_attention import (\n    topk_sparse_attention_torch,\n)\nfrom native_sparse_attention.ops.triton.topk_sparse_attention import (\n    topk_sparse_attention,\n    _topk_sparse_attention_fwd,\n    _topk_sparse_attention_bwd,\n)\nfrom native_sparse_attention.ops.triton.flash_attention import (\n    _flash_attention_fwd,\n    _flash_attention_bwd,\n)\nfrom flash_attn.flash_attn_interface import (\n    _flash_attn_varlen_forward,\n    _flash_attn_varlen_backward,\n)\n\n\ndef generate_topk_idx_example(\n    seqlens: torch.Tensor,\n    block_size_k: int,\n    topk: int,\n    num_heads: int,\n    block_size_q: int = 1,\n) -> torch.Tensor:\n    \"\"\"Generate topk idx example for test.\n\n    Args:\n        seqlens (torch.Tensor): shape [batch_size], length of each sequence (not cumulative; cu_seqlens is built from it internally).\n        block_size_k (int): key value block size\n        topk (int): number of key value blocks selected per query\n        num_heads (int): number of key value heads\n        block_size_q (int): query block size; only the first query of each query block keeps a row. Defaults to 1.\n\n    Returns:\n        torch.Tensor: shape [num_heads, total_q_blocks, topk] (total_q_blocks == total_seqlen when block_size_q == 1), topk key value block idx for each query. 
-1 means padding.\n    \"\"\"\n    batch_size = seqlens.shape[0]\n    num_blocks = torch.ceil(seqlens / block_size_k).to(torch.int32)\n    topk_idx_all_heads = []\n    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), pad=(1, 0), value=0)\n    for _ in range(num_heads):\n        topk_idx = [\n            torch.randn(seqlens[i], num_blocks[i], device=\"cuda\")\n            .topk(min(topk, num_blocks[i]), dim=-1)\n            .indices.to(torch.int32)\n            for i in range(batch_size)\n        ]\n        topk_idx = [\n            torch.nn.functional.pad(\n                topk_idx[i], (0, topk - topk_idx[i].shape[-1]), value=topk\n            )\n            for i in range(batch_size)\n        ]\n        topk_idx = torch.cat(topk_idx, dim=0)\n        topk_idx = torch.sort(topk_idx, dim=1).values\n        topk_idx[:, 0] = 0\n        q_idx = torch.cat(\n            [torch.arange(seqlens[i], device=\"cuda\") for i in range(batch_size)], dim=0\n        )\n        topk_idx[topk_idx > (q_idx // block_size_k)[:, None]] = -1  # -1 means padding\n        topk_idx = torch.cat(\n            [\n                topk_idx[cu_seqlens[i] : cu_seqlens[i + 1]][0::block_size_q]\n                for i in range(batch_size)\n            ],\n            dim=0,\n        )\n        topk_idx_all_heads.append(topk_idx)\n    topk_idx = torch.stack(topk_idx_all_heads, dim=0)\n    return topk_idx\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(42)\n    batch_size = 3\n    seqlens = torch.LongTensor([1000, 2000, 4096]).int().cuda()\n    cu_seqlens = torch.cat(\n        [\n            torch.zeros(1, dtype=torch.int32, device=\"cuda\"),\n            torch.cumsum(seqlens, dim=0),\n        ],\n        dim=0,\n    ).to(torch.int32)\n    max_seqlen = seqlens.max().item()\n    q = (\n        torch.empty(cu_seqlens[-1], 64, 96, device=\"cuda\")\n        .uniform_(-1, 1)\n        .to(torch.float16)\n    )\n    k = (\n        torch.empty(cu_seqlens[-1], 8, 96, device=\"cuda\")\n        
.uniform_(-1, 1)\n        .to(torch.float16)\n    )\n    v = (\n        torch.empty(cu_seqlens[-1], 8, 96, device=\"cuda\")\n        .uniform_(-1, 1)\n        .to(torch.float16)\n    )\n    q.requires_grad = True\n    k.requires_grad = True\n    v.requires_grad = True\n    block_size = 64\n    topk = 5\n    topk_idx = generate_topk_idx_example(seqlens, block_size, topk, 8)\n\n    o = topk_sparse_attention_torch(q, k, v, topk_idx, block_size, cu_seqlens)\n\n    randn = torch.randn_like(o)\n    loss = (o * randn).sum()\n    loss.backward()\n\n    torch.manual_seed(42)\n    q1 = q.clone().detach().requires_grad_()\n    k1 = k.clone().detach().requires_grad_()\n    v1 = v.clone().detach().requires_grad_()\n    topk_idx1 = topk_idx.clone().detach()\n    cu_seqlens1 = cu_seqlens.clone().detach()\n\n    o1 = topk_sparse_attention(q1, k1, v1, topk_idx, block_size, cu_seqlens)\n\n    randn2 = randn.clone().detach()\n    loss2 = (o1 * randn2).sum()\n    loss2.backward()\n\n    print(\"Same Output:\", torch.allclose(o, o1, atol=0.01, rtol=0.01))\n    print(\"Max Error:\", (o - o1).abs().max().item())\n    print()\n    print(\"Same Query Gradient:\", torch.allclose(q.grad, q1.grad, atol=0.01, rtol=0.01))\n    print(\"Max Query Gradient Error:\", (q.grad - q1.grad).abs().max().item())\n    print()\n    print(\"Same Key Gradient:\", torch.allclose(k.grad, k1.grad, atol=0.01, rtol=0.01))\n    print(\"Max Key Gradient Error:\", (k.grad - k1.grad).abs().max().item())\n    print()\n    print(\"Same Value Gradient:\", torch.allclose(v.grad, v1.grad, atol=0.01, rtol=0.01))\n    print(\"Max Value Gradient Error:\", (v.grad - v1.grad).abs().max().item())\n    print()\n\n    # benchmark\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"N\"],\n            x_vals=[1024 * 2**i for i in range(1, 8)],\n            line_arg=\"provider\",\n            line_vals=[\"flash\", \"triton-flash\", \"triton-top8\", \"triton-top16\"],\n            
line_names=[\n                \"Flash\",\n                \"Triton-Flash\",\n                \"Triton-Top8\",\n                \"Triton-Top16\",\n            ],\n            styles=[(\"green\", \"-\"), (\"green\", \"--\"), (\"blue\", \"-\"), (\"blue\", \"--\")],\n            ylabel=\"ms\",\n            plot_name=\"** forward with block size 64 **\",\n            args={\"H\": 64, \"D\": 128, \"K\": 64},\n        )\n    )\n    def benchmark(N, H, D, K, provider):\n        q = torch.randn((N, H, D), device=\"cuda\", dtype=torch.bfloat16)\n        k = torch.randn((N, H // 16, D), device=\"cuda\", dtype=torch.bfloat16)\n        v = torch.randn((N, H // 16, D), device=\"cuda\", dtype=torch.bfloat16)\n        cu_seqlens = torch.tensor([0, N], device=\"cuda\", dtype=torch.int32)\n        sm_scale = 1 / math.sqrt(D)\n\n        top8_idx = generate_topk_idx_example(cu_seqlens[1:], K, 8, H // 16)\n        top16_idx = generate_topk_idx_example(cu_seqlens[1:], K, 16, H // 16)\n\n        quantiles = [0.5, 0.2, 0.8]\n        if provider == \"flash\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _flash_attn_varlen_forward(\n                    q,\n                    k,\n                    v,\n                    cu_seqlens,\n                    cu_seqlens,\n                    N,\n                    N,\n                    dropout_p=0.0,\n                    causal=True,\n                    softmax_scale=sm_scale,\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-flash\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _flash_attention_fwd(\n                    q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-top8\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _topk_sparse_attention_fwd(\n             
       q, k, v, top8_idx, K, cu_seqlens, cu_seqlens, N, N, sm_scale\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-top16\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _topk_sparse_attention_fwd(\n                    q, k, v, top16_idx, K, cu_seqlens, cu_seqlens, N, N, sm_scale\n                ),\n                quantiles=quantiles,\n            )\n        return ms, min_ms, max_ms\n\n    benchmark.run(show_plots=True, print_data=True)\n\n    # benchmark\n    @triton.testing.perf_report(\n        triton.testing.Benchmark(\n            x_names=[\"N\"],\n            x_vals=[1024 * 2**i for i in range(1, 8)],\n            line_arg=\"provider\",\n            line_vals=[\"flash\", \"triton-flash\", \"triton-top8\", \"triton-top16\"],\n            line_names=[\n                \"Flash\",\n                \"Triton-Flash\",\n                \"Triton-Top8\",\n                \"Triton-Top16\",\n            ],\n            styles=[(\"green\", \"-\"), (\"green\", \"--\"), (\"blue\", \"-\"), (\"blue\", \"--\")],\n            ylabel=\"ms\",\n            plot_name=\"** backward with block size 64 **\",\n            args={\"H\": 64, \"D\": 128, \"K\": 64},\n        )\n    )\n    def benchmark(N, H, D, K, provider):\n        q = torch.randn((N, H, D), device=\"cuda\", dtype=torch.bfloat16)\n        k = torch.randn((N, H // 16, D), device=\"cuda\", dtype=torch.bfloat16)\n        v = torch.randn((N, H // 16, D), device=\"cuda\", dtype=torch.bfloat16)\n        o = torch.randn((N, H, D), device=\"cuda\", dtype=torch.bfloat16)\n        do = torch.randn((N, H, D), device=\"cuda\", dtype=torch.bfloat16)\n        lse = torch.randn((H, N), device=\"cuda\", dtype=torch.float32)\n        sm_scale = 1 / math.sqrt(D)\n        cu_seqlens = torch.tensor([0, N], device=\"cuda\", dtype=torch.int32)\n        top8_idx = generate_topk_idx_example(cu_seqlens[1:], K, 8, H // 16)\n        top16_idx = 
generate_topk_idx_example(cu_seqlens[1:], K, 16, H // 16)\n        dq = torch.zeros_like(q)\n        dk = torch.zeros_like(k)\n        dv = torch.zeros_like(v)\n\n        quantiles = [0.5, 0.2, 0.8]\n        if provider == \"flash\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _flash_attn_varlen_backward(\n                    do,\n                    q,\n                    k,\n                    v,\n                    o,\n                    lse.transpose(0, 1),\n                    dq,\n                    dk,\n                    dv,\n                    cu_seqlens,\n                    cu_seqlens,\n                    N,\n                    N,\n                    dropout_p=0.0,\n                    causal=True,\n                    softmax_scale=sm_scale,\n                    window_size=(-1, -1),\n                    softcap=0.0,\n                    alibi_slopes=None,\n                    deterministic=False,\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-flash\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _flash_attention_bwd(\n                    o, do, lse, q, k, v, cu_seqlens, cu_seqlens, N, N, True, sm_scale\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-top8\":\n            ms, min_ms, max_ms = triton.testing.do_bench(\n                lambda: _topk_sparse_attention_bwd(\n                    o,\n                    do,\n                    lse,\n                    q,\n                    k,\n                    v,\n                    top8_idx,\n                    K,\n                    cu_seqlens,\n                    cu_seqlens,\n                    N,\n                    N,\n                    sm_scale,\n                ),\n                quantiles=quantiles,\n            )\n        if provider == \"triton-top16\":\n            ms, min_ms, 
max_ms = triton.testing.do_bench(\n                lambda: _topk_sparse_attention_bwd(\n                    o,\n                    do,\n                    lse,\n                    q,\n                    k,\n                    v,\n                    top16_idx,\n                    K,\n                    cu_seqlens,\n                    cu_seqlens,\n                    N,\n                    N,\n                    sm_scale,\n                ),\n                quantiles=quantiles,\n            )\n        return ms, min_ms, max_ms\n\n    benchmark.run(show_plots=True, print_data=True)\n"
  }
]